package main import ( "bytes" "context" "crypto/tls" "errors" "fmt" "io" "net" "os" "os/signal" "strings" "sync" "syscall" "time" "gemserve/lib/apperrors" "gemserve/lib/logging" "gemserve/config" "gemserve/gemini" "gemserve/server" "git.antanst.com/antanst/uid" ) func main() { config.CONFIG = *config.GetConfig() logging.SetupLogging() logger := logging.Logger ctx := logging.WithLogger(context.Background(), logger) err := runApp(ctx) if err != nil { logger.Error(fmt.Sprintf("Fatal Error: %v", err)) panic(fmt.Sprintf("Fatal Error: %v", err)) } os.Exit(0) } func runApp(ctx context.Context) error { logger := logging.FromContext(ctx) logger.Info("Starting up. Press Ctrl+C to exit") listenAddr := config.CONFIG.ListenAddr signals := make(chan os.Signal, 1) signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) fatalErrors := make(chan error) // Root server context, used to cancel // connections and graceful shutdown. serverCtx, cancel := context.WithCancel(ctx) defer cancel() // WaitGroup to track active connections var wg sync.WaitGroup // Spawn server on the background. // Returned errors are considered fatal. go func() { err := startServer(serverCtx, listenAddr, &wg, fatalErrors) if err != nil { fatalErrors <- apperrors.NewFatalError(fmt.Errorf("server startup failed: %w", err)) } }() for { select { case <-signals: logger.Warn("Received SIGINT or SIGTERM signal, shutting down gracefully") cancel() wg.Wait() return nil case fatalError := <-fatalErrors: cancel() wg.Wait() return fatalError } } } func startServer(ctx context.Context, listenAddr string, wg *sync.WaitGroup, fatalErrors chan<- error) (err error) { logger := logging.FromContext(ctx) cert, err := tls.LoadX509KeyPair(config.CONFIG.TLSCert, config.CONFIG.TLSKey) if err != nil { return apperrors.NewFatalError(fmt.Errorf("failed to load TLS certificate/key: %w", err)) } logger.Debug("Using TLS cert", "path", config.CONFIG.TLSCert) logger.Debug("Using TLS key", "path", config.CONFIG.TLSKey) tlsConfig := &tls.Config{ Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12, } listener, err := tls.Listen("tcp", listenAddr, tlsConfig) if err != nil { return apperrors.NewFatalError(err) } defer func(listener net.Listener) { _ = listener.Close() }(listener) // If context is cancelled, close listener // to unblock Accept() inside main loop. go func() { <-ctx.Done() _ = listener.Close() }() logger.Info("Server listening", "address", listenAddr) for { conn, err := listener.Accept() if err != nil { if ctx.Err() != nil { return nil } // ctx cancellation logger.Info("Failed to accept connection: %v", "error", err) continue } wg.Add(1) go func() { defer wg.Done() // Type assert the connection to TLS connection tlsConn, ok := conn.(*tls.Conn) if !ok { logger.Error("Connection is not a TLS connection") _ = conn.Close() return } remoteAddr := conn.RemoteAddr().String() connId := uid.UID() // Create a timeout context for this connection connCtx, cancel := context.WithTimeout(ctx, time.Duration(config.CONFIG.ResponseTimeout)*time.Second) defer cancel() err := handleConnection(connCtx, tlsConn, connId, remoteAddr) if err != nil { if apperrors.IsFatal(err) { fatalErrors <- err return } logger.Info("Connection failed", "id", connId, "remoteAddr", remoteAddr, "error", err) } }() } } func closeConnection(conn *tls.Conn) error { err := conn.CloseWrite() if err != nil { return apperrors.NewNetworkError(fmt.Errorf("failed to close TLS connection: %w", err)) } err = conn.Close() if err != nil { return apperrors.NewNetworkError(fmt.Errorf("failed to close connection: %w", err)) } return nil } func handleConnection(ctx context.Context, conn *tls.Conn, connId string, remoteAddr string) (err error) { logger := logging.FromContext(ctx) start := time.Now() var outputBytes []byte // Set connection deadline based on context if deadline, ok := ctx.Deadline(); ok { _ = conn.SetDeadline(deadline) } defer func(conn *tls.Conn) { end := time.Now() tookMs := end.Sub(start).Milliseconds() var responseHeader string // On non-errors, just log response and close connection. if err == nil { // Log non-erroneous responses if i := bytes.Index(outputBytes, []byte{'\r'}); i >= 0 { responseHeader = string(outputBytes[:i]) } logger.Info("Response", "connId", connId, "remoteAddr", remoteAddr, "responseHeader", responseHeader, "ms", tookMs) _ = closeConnection(conn) return } // Handle context cancellation/timeout if errors.Is(err, context.DeadlineExceeded) { logger.Info("Connection timeout", "id", connId, "remoteAddr", remoteAddr, "ms", tookMs) responseHeader = fmt.Sprintf("%d Request timeout", gemini.StatusCGIError) _, _ = conn.Write([]byte(responseHeader + "\r\n")) _ = closeConnection(conn) return } if errors.Is(err, context.Canceled) { logger.Info("Connection cancelled", "id", connId, "remoteAddr", remoteAddr, "ms", tookMs) _ = closeConnection(conn) return } var code int var responseMsg string if apperrors.IsFatal(err) { _ = closeConnection(conn) return } if apperrors.IsGeminiError(err) { code = apperrors.GetStatusCode(err) responseMsg = "server error" } else { code = gemini.StatusPermanentFailure responseMsg = "server error" } responseHeader = fmt.Sprintf("%d %s", code, responseMsg) _, _ = conn.Write([]byte(responseHeader + "\r\n")) _ = closeConnection(conn) }(conn) // Check context before starting if err := ctx.Err(); err != nil { return err } // Gemini is supposed to have a 1kb limit // on input requests. buffer := make([]byte, 1025) n, err := conn.Read(buffer) if err != nil && err != io.EOF { return apperrors.NewGeminiError(fmt.Errorf("failed to read connection data: %w", err), gemini.StatusBadRequest) } if n == 0 { return apperrors.NewGeminiError(fmt.Errorf("client did not send data"), gemini.StatusBadRequest) } if n > 1024 { return apperrors.NewGeminiError(fmt.Errorf("client request size %d > 1024 bytes", n), gemini.StatusBadRequest) } // Check context after read if err := ctx.Err(); err != nil { return err } dataBytes := buffer[:n] dataString := string(dataBytes) logger.Info("Request", "id", connId, "remoteAddr", remoteAddr, "data", strings.TrimSpace(dataString), "size", len(dataBytes)) outputBytes, err = server.GenerateResponse(ctx, conn, connId, dataString) if err != nil { return err } // Check context before write if err := ctx.Err(); err != nil { return err } _, err = conn.Write(outputBytes) if err != nil { return err } return nil }