Files
gemserve/main.go
antanst 3a5835fc42 .
2025-10-09 17:43:23 +03:00

274 lines
6.7 KiB
Go

package main
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"
"gemserve/lib/apperrors"
"gemserve/lib/logging"
"gemserve/config"
"gemserve/gemini"
"gemserve/server"
"git.antanst.com/antanst/uid"
)
// main loads the global configuration, initialises logging and runs the
// server until it stops. A fatal error returned by runApp is logged and then
// escalated to a panic; a clean shutdown exits with status 0.
func main() {
	config.CONFIG = *config.GetConfig()
	logging.SetupLogging()
	rootLogger := logging.Logger
	appCtx := logging.WithLogger(context.Background(), rootLogger)

	if runErr := runApp(appCtx); runErr != nil {
		msg := fmt.Sprintf("Fatal Error: %v", runErr)
		rootLogger.Error(msg)
		panic(msg)
	}
	os.Exit(0)
}
// runApp starts the Gemini server in the background and blocks until either
// SIGINT/SIGTERM is received (returns nil) or a fatal error is reported on
// the fatal-error channel (returns that error).
//
// In both cases the server context is cancelled, the accept loop is allowed
// to exit, and all in-flight connection goroutines are waited for before
// returning.
func runApp(ctx context.Context) error {
	logger := logging.FromContext(ctx)
	logger.Info("Starting up. Press Ctrl+C to exit")
	listenAddr := config.CONFIG.ListenAddr

	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(signals)

	fatalErrors := make(chan error)

	// Root server context, used to cancel connections and for graceful
	// shutdown.
	serverCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	// WaitGroup tracking active connections (Add is called by startServer).
	var wg sync.WaitGroup
	// Closed once startServer has returned, i.e. no more wg.Add calls.
	serverDone := make(chan struct{})

	// Spawn the server in the background. Returned errors are fatal.
	go func() {
		defer close(serverDone)
		err := startServer(serverCtx, listenAddr, &wg, fatalErrors)
		if err != nil {
			fatalErrors <- apperrors.NewFatalError(fmt.Errorf("server startup failed: %w", err))
		}
	}()

	var result error
	select {
	case <-signals:
		logger.Warn("Received SIGINT or SIGTERM signal, shutting down gracefully")
	case result = <-fatalErrors:
	}

	cancel()

	// Wait for the accept loop to exit before wg.Wait(): a wg.Add racing
	// with Wait is a WaitGroup misuse. Meanwhile keep draining fatalErrors.
	// It is unbuffered, so a late sender (e.g. a second failing connection)
	// would otherwise block forever inside its wg-tracked goroutine and
	// deadlock the shutdown.
	done := make(chan struct{})
	go func() {
		<-serverDone
		wg.Wait()
		close(done)
	}()
	for {
		select {
		case <-done:
			return result
		case err := <-fatalErrors:
			logger.Error("Additional fatal error during shutdown", "error", err)
		}
	}
}
// startServer loads the configured TLS key pair, listens on listenAddr and
// serves Gemini connections until ctx is cancelled, at which point it
// returns nil.
//
// Each accepted connection is handled in its own goroutine, tracked by wg
// and bounded by config.CONFIG.ResponseTimeout seconds. Fatal errors from a
// connection handler are sent on fatalErrors. If ctx is already done, the
// receiver may be gone, so they are logged instead and the goroutine can
// still finish.
//
// A non-nil error is returned only when the server cannot start: the
// certificate/key cannot be loaded or the listen call fails.
func startServer(ctx context.Context, listenAddr string, wg *sync.WaitGroup, fatalErrors chan<- error) (err error) {
	logger := logging.FromContext(ctx)

	cert, err := tls.LoadX509KeyPair(config.CONFIG.TLSCert, config.CONFIG.TLSKey)
	if err != nil {
		return apperrors.NewFatalError(fmt.Errorf("failed to load TLS certificate/key: %w", err))
	}
	logger.Debug("Using TLS cert", "path", config.CONFIG.TLSCert)
	logger.Debug("Using TLS key", "path", config.CONFIG.TLSKey)

	tlsConfig := &tls.Config{
		Certificates: []tls.Certificate{cert},
		MinVersion:   tls.VersionTLS12,
	}
	listener, err := tls.Listen("tcp", listenAddr, tlsConfig)
	if err != nil {
		return apperrors.NewFatalError(err)
	}
	// The listener may also be closed by the goroutine below; the error of
	// a second Close is deliberately ignored.
	defer func(listener net.Listener) {
		_ = listener.Close()
	}(listener)

	// If the context is cancelled, close the listener to unblock Accept()
	// inside the main loop.
	go func() {
		<-ctx.Done()
		_ = listener.Close()
	}()

	logger.Info("Server listening", "address", listenAddr)

	// Exponential backoff for Accept failures (e.g. file-descriptor
	// exhaustion) so a persistent error does not spin the CPU and flood
	// the log.
	const (
		minAcceptDelay = 5 * time.Millisecond
		maxAcceptDelay = time.Second
	)
	var acceptDelay time.Duration

	for {
		conn, acceptErr := listener.Accept()
		if acceptErr != nil {
			if ctx.Err() != nil {
				return nil // ctx cancellation closed the listener
			}
			if acceptDelay == 0 {
				acceptDelay = minAcceptDelay
			} else {
				acceptDelay *= 2
			}
			if acceptDelay > maxAcceptDelay {
				acceptDelay = maxAcceptDelay
			}
			logger.Warn("Failed to accept connection", "error", acceptErr, "retryIn", acceptDelay)
			select {
			case <-ctx.Done():
				return nil
			case <-time.After(acceptDelay):
			}
			continue
		}
		acceptDelay = 0

		wg.Add(1)
		go func() {
			defer wg.Done()
			// tls.Listen yields *tls.Conn; the check is defensive.
			tlsConn, ok := conn.(*tls.Conn)
			if !ok {
				logger.Error("Connection is not a TLS connection")
				_ = conn.Close()
				return
			}
			remoteAddr := conn.RemoteAddr().String()
			connId := uid.UID()

			// Per-connection timeout, derived from the server context so
			// shutdown also cancels in-flight requests.
			connCtx, cancel := context.WithTimeout(ctx, time.Duration(config.CONFIG.ResponseTimeout)*time.Second)
			defer cancel()

			handleErr := handleConnection(connCtx, tlsConn, connId, remoteAddr)
			if handleErr == nil {
				return
			}
			if apperrors.IsFatal(handleErr) {
				// Never block forever: once shutdown has begun the receiver
				// may no longer be listening.
				select {
				case fatalErrors <- handleErr:
				case <-ctx.Done():
					logger.Error("Dropping fatal connection error during shutdown", "id", connId, "remoteAddr", remoteAddr, "error", handleErr)
				}
				return
			}
			logger.Info("Connection failed", "id", connId, "remoteAddr", remoteAddr, "error", handleErr)
		}()
	}
}
// closeConnection shuts down a TLS connection: CloseWrite half-closes the
// TLS session, then Close releases the underlying socket.
//
// Close is always attempted, even when CloseWrite fails. CloseWrite returns
// an error, for example, when the handshake never completed or a write
// deadline has passed. Returning early there would leak the socket. If both
// calls fail, the CloseWrite error is reported. Errors are wrapped with
// apperrors.NewNetworkError.
func closeConnection(conn *tls.Conn) error {
	writeErr := conn.CloseWrite()
	closeErr := conn.Close()
	if writeErr != nil {
		return apperrors.NewNetworkError(fmt.Errorf("failed to close TLS connection: %w", writeErr))
	}
	if closeErr != nil {
		return apperrors.NewNetworkError(fmt.Errorf("failed to close connection: %w", closeErr))
	}
	return nil
}
// handleConnection serves a single Gemini request on conn.
//
// It reads one request line (at most maxRequestSize bytes, terminated by
// CRLF), passes it to server.GenerateResponse and writes the result back.
// All connection I/O is bounded by ctx's deadline. A deferred handler logs
// the outcome, writes a status header where appropriate and always closes
// the connection. The handler branches on the returned error:
//   - nil: the response header (text before the first '\r') is logged.
//   - context.DeadlineExceeded: a "<StatusCGIError> Request timeout" header is sent.
//   - context.Canceled: the connection is closed without a response.
//   - apperrors.IsFatal errors: the connection is closed without a response.
//   - anything else: "<code> server error" is sent, with the code taken from
//     a Gemini error or gemini.StatusPermanentFailure otherwise.
func handleConnection(ctx context.Context, conn *tls.Conn, connId string, remoteAddr string) (err error) {
	const (
		// Gemini URLs must not exceed 1024 bytes; the request line is the
		// URL followed by CRLF.
		maxURLSize     = 1024
		maxRequestSize = maxURLSize + 2
		// Write window granted to the final status header and TLS close once
		// the request deadline may already have expired.
		closeGracePeriod = 2 * time.Second
	)

	logger := logging.FromContext(ctx)
	start := time.Now()
	var outputBytes []byte

	// Bound all I/O on the connection (the TLS handshake runs on the first
	// Read) by the context deadline.
	if deadline, ok := ctx.Deadline(); ok {
		_ = conn.SetDeadline(deadline)
	}

	defer func() {
		tookMs := time.Since(start).Milliseconds()
		// The connection deadline equals ctx's deadline. Once it has passed,
		// every Write (including CloseWrite's close_notify) fails
		// immediately, so give the final header and close a fresh window.
		_ = conn.SetWriteDeadline(time.Now().Add(closeGracePeriod))

		switch {
		case err == nil:
			// Log the response header: everything before the first '\r'.
			var responseHeader string
			if i := bytes.IndexByte(outputBytes, '\r'); i >= 0 {
				responseHeader = string(outputBytes[:i])
			}
			logger.Info("Response", "connId", connId, "remoteAddr", remoteAddr, "responseHeader", responseHeader, "ms", tookMs)
		case errors.Is(err, context.DeadlineExceeded):
			logger.Info("Connection timeout", "id", connId, "remoteAddr", remoteAddr, "ms", tookMs)
			responseHeader := fmt.Sprintf("%d Request timeout", gemini.StatusCGIError)
			_, _ = conn.Write([]byte(responseHeader + "\r\n"))
		case errors.Is(err, context.Canceled):
			logger.Info("Connection cancelled", "id", connId, "remoteAddr", remoteAddr, "ms", tookMs)
		case apperrors.IsFatal(err):
			// Fatal errors are propagated by the caller; no response is sent.
		default:
			var code int
			if apperrors.IsGeminiError(err) {
				code = apperrors.GetStatusCode(err)
			} else {
				code = gemini.StatusPermanentFailure
			}
			responseHeader := fmt.Sprintf("%d %s", code, "server error")
			_, _ = conn.Write([]byte(responseHeader + "\r\n"))
		}
		_ = closeConnection(conn)
	}()

	// Check context before starting.
	if err := ctx.Err(); err != nil {
		return err
	}

	// Read until the CRLF terminator arrives. A single Read may return only
	// part of the line when the client's request spans several TLS records.
	// The one spare byte in the buffer lets oversized requests be detected.
	buffer := make([]byte, maxRequestSize+1)
	n := 0
	for n < len(buffer) && !bytes.Contains(buffer[:n], []byte("\r\n")) {
		m, readErr := conn.Read(buffer[n:])
		n += m
		if readErr == io.EOF {
			break
		}
		if readErr != nil {
			return apperrors.NewGeminiError(fmt.Errorf("failed to read connection data: %w", readErr), gemini.StatusBadRequest)
		}
	}
	if n == 0 {
		return apperrors.NewGeminiError(fmt.Errorf("client did not send data"), gemini.StatusBadRequest)
	}
	if n > maxRequestSize {
		return apperrors.NewGeminiError(fmt.Errorf("client request size %d > %d bytes", n, maxRequestSize), gemini.StatusBadRequest)
	}

	// Check context after read.
	if err := ctx.Err(); err != nil {
		return err
	}

	dataBytes := buffer[:n]
	dataString := string(dataBytes)
	logger.Info("Request", "id", connId, "remoteAddr", remoteAddr, "data", strings.TrimSpace(dataString), "size", len(dataBytes))

	outputBytes, err = server.GenerateResponse(ctx, conn, connId, dataString)
	if err != nil {
		return err
	}

	// Check context before write.
	if err := ctx.Err(); err != nil {
		return err
	}

	_, err = conn.Write(outputBytes)
	if err != nil {
		return err
	}
	return nil
}