274 lines
6.7 KiB
Go
274 lines
6.7 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"os/signal"
|
|
"strings"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
|
|
"gemserve/lib/apperrors"
|
|
"gemserve/lib/logging"
|
|
|
|
"gemserve/config"
|
|
"gemserve/gemini"
|
|
"gemserve/server"
|
|
|
|
"git.antanst.com/antanst/uid"
|
|
)
|
|
|
|
func main() {
|
|
config.CONFIG = *config.GetConfig()
|
|
logging.SetupLogging()
|
|
logger := logging.Logger
|
|
ctx := logging.WithLogger(context.Background(), logger)
|
|
err := runApp(ctx)
|
|
if err != nil {
|
|
logger.Error(fmt.Sprintf("Fatal Error: %v", err))
|
|
panic(fmt.Sprintf("Fatal Error: %v", err))
|
|
}
|
|
os.Exit(0)
|
|
}
|
|
|
|
func runApp(ctx context.Context) error {
|
|
logger := logging.FromContext(ctx)
|
|
logger.Info("Starting up. Press Ctrl+C to exit")
|
|
|
|
listenAddr := config.CONFIG.ListenAddr
|
|
|
|
signals := make(chan os.Signal, 1)
|
|
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
|
|
|
fatalErrors := make(chan error)
|
|
|
|
// Root server context, used to cancel
|
|
// connections and graceful shutdown.
|
|
serverCtx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
// WaitGroup to track active connections
|
|
var wg sync.WaitGroup
|
|
|
|
// Spawn server on the background.
|
|
// Returned errors are considered fatal.
|
|
go func() {
|
|
err := startServer(serverCtx, listenAddr, &wg, fatalErrors)
|
|
if err != nil {
|
|
fatalErrors <- apperrors.NewFatalError(fmt.Errorf("server startup failed: %w", err))
|
|
}
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case <-signals:
|
|
logger.Warn("Received SIGINT or SIGTERM signal, shutting down gracefully")
|
|
cancel()
|
|
wg.Wait()
|
|
return nil
|
|
case fatalError := <-fatalErrors:
|
|
cancel()
|
|
wg.Wait()
|
|
return fatalError
|
|
}
|
|
}
|
|
}
|
|
|
|
func startServer(ctx context.Context, listenAddr string, wg *sync.WaitGroup, fatalErrors chan<- error) (err error) {
|
|
logger := logging.FromContext(ctx)
|
|
|
|
cert, err := tls.LoadX509KeyPair(config.CONFIG.TLSCert, config.CONFIG.TLSKey)
|
|
if err != nil {
|
|
return apperrors.NewFatalError(fmt.Errorf("failed to load TLS certificate/key: %w", err))
|
|
}
|
|
|
|
logger.Debug("Using TLS cert", "path", config.CONFIG.TLSCert)
|
|
logger.Debug("Using TLS key", "path", config.CONFIG.TLSKey)
|
|
|
|
tlsConfig := &tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
MinVersion: tls.VersionTLS12,
|
|
}
|
|
|
|
listener, err := tls.Listen("tcp", listenAddr, tlsConfig)
|
|
if err != nil {
|
|
return apperrors.NewFatalError(err)
|
|
}
|
|
|
|
defer func(listener net.Listener) {
|
|
_ = listener.Close()
|
|
}(listener)
|
|
|
|
// If context is cancelled, close listener
|
|
// to unblock Accept() inside main loop.
|
|
go func() {
|
|
<-ctx.Done()
|
|
_ = listener.Close()
|
|
}()
|
|
|
|
logger.Info("Server listening", "address", listenAddr)
|
|
|
|
for {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
if ctx.Err() != nil {
|
|
return nil
|
|
} // ctx cancellation
|
|
logger.Info("Failed to accept connection: %v", "error", err)
|
|
continue
|
|
}
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
// Type assert the connection to TLS connection
|
|
tlsConn, ok := conn.(*tls.Conn)
|
|
if !ok {
|
|
logger.Error("Connection is not a TLS connection")
|
|
_ = conn.Close()
|
|
return
|
|
}
|
|
|
|
remoteAddr := conn.RemoteAddr().String()
|
|
connId := uid.UID()
|
|
|
|
// Create a timeout context for this connection
|
|
connCtx, cancel := context.WithTimeout(ctx, time.Duration(config.CONFIG.ResponseTimeout)*time.Second)
|
|
defer cancel()
|
|
|
|
err := handleConnection(connCtx, tlsConn, connId, remoteAddr)
|
|
if err != nil {
|
|
if apperrors.IsFatal(err) {
|
|
fatalErrors <- err
|
|
return
|
|
}
|
|
logger.Info("Connection failed", "id", connId, "remoteAddr", remoteAddr, "error", err)
|
|
}
|
|
}()
|
|
}
|
|
}
|
|
|
|
func closeConnection(conn *tls.Conn) error {
|
|
err := conn.CloseWrite()
|
|
if err != nil {
|
|
return apperrors.NewNetworkError(fmt.Errorf("failed to close TLS connection: %w", err))
|
|
}
|
|
err = conn.Close()
|
|
if err != nil {
|
|
return apperrors.NewNetworkError(fmt.Errorf("failed to close connection: %w", err))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func handleConnection(ctx context.Context, conn *tls.Conn, connId string, remoteAddr string) (err error) {
|
|
logger := logging.FromContext(ctx)
|
|
start := time.Now()
|
|
var outputBytes []byte
|
|
|
|
// Set connection deadline based on context
|
|
if deadline, ok := ctx.Deadline(); ok {
|
|
_ = conn.SetDeadline(deadline)
|
|
}
|
|
|
|
defer func(conn *tls.Conn) {
|
|
end := time.Now()
|
|
tookMs := end.Sub(start).Milliseconds()
|
|
var responseHeader string
|
|
|
|
// On non-errors, just log response and close connection.
|
|
if err == nil {
|
|
// Log non-erroneous responses
|
|
if i := bytes.Index(outputBytes, []byte{'\r'}); i >= 0 {
|
|
responseHeader = string(outputBytes[:i])
|
|
}
|
|
logger.Info("Response", "connId", connId, "remoteAddr", remoteAddr, "responseHeader", responseHeader, "ms", tookMs)
|
|
_ = closeConnection(conn)
|
|
return
|
|
}
|
|
|
|
// Handle context cancellation/timeout
|
|
if errors.Is(err, context.DeadlineExceeded) {
|
|
logger.Info("Connection timeout", "id", connId, "remoteAddr", remoteAddr, "ms", tookMs)
|
|
responseHeader = fmt.Sprintf("%d Request timeout", gemini.StatusCGIError)
|
|
_, _ = conn.Write([]byte(responseHeader + "\r\n"))
|
|
_ = closeConnection(conn)
|
|
return
|
|
}
|
|
if errors.Is(err, context.Canceled) {
|
|
logger.Info("Connection cancelled", "id", connId, "remoteAddr", remoteAddr, "ms", tookMs)
|
|
_ = closeConnection(conn)
|
|
return
|
|
}
|
|
|
|
var code int
|
|
var responseMsg string
|
|
if apperrors.IsFatal(err) {
|
|
_ = closeConnection(conn)
|
|
return
|
|
}
|
|
if apperrors.IsGeminiError(err) {
|
|
code = apperrors.GetStatusCode(err)
|
|
responseMsg = "server error"
|
|
} else {
|
|
code = gemini.StatusPermanentFailure
|
|
responseMsg = "server error"
|
|
}
|
|
responseHeader = fmt.Sprintf("%d %s", code, responseMsg)
|
|
_, _ = conn.Write([]byte(responseHeader + "\r\n"))
|
|
_ = closeConnection(conn)
|
|
}(conn)
|
|
|
|
// Check context before starting
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Gemini is supposed to have a 1kb limit
|
|
// on input requests.
|
|
buffer := make([]byte, 1025)
|
|
|
|
n, err := conn.Read(buffer)
|
|
if err != nil && err != io.EOF {
|
|
return apperrors.NewGeminiError(fmt.Errorf("failed to read connection data: %w", err), gemini.StatusBadRequest)
|
|
}
|
|
if n == 0 {
|
|
return apperrors.NewGeminiError(fmt.Errorf("client did not send data"), gemini.StatusBadRequest)
|
|
}
|
|
if n > 1024 {
|
|
return apperrors.NewGeminiError(fmt.Errorf("client request size %d > 1024 bytes", n), gemini.StatusBadRequest)
|
|
}
|
|
|
|
// Check context after read
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
dataBytes := buffer[:n]
|
|
dataString := string(dataBytes)
|
|
|
|
logger.Info("Request", "id", connId, "remoteAddr", remoteAddr, "data", strings.TrimSpace(dataString), "size", len(dataBytes))
|
|
outputBytes, err = server.GenerateResponse(ctx, conn, connId, dataString)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Check context before write
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = conn.Write(outputBytes)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|