diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..9fea5c5 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,87 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +Gemserve is a simple Gemini protocol server written in Go that serves static files over TLS-encrypted connections. The Gemini protocol is a lightweight, privacy-focused alternative to HTTP designed for serving text-based content. + +## Development Commands + +```bash +# Build, test, and format everything +make + +# Run tests only +make test + +# Build binary to ./dist/gemserve +make build + +# Format code with gofumpt and gci +make fmt + +# Run golangci-lint +make lint + +# Run linter with auto-fix +make lintfix + +# Clean build artifacts +make clean + +# Run the server (after building) +./dist/gemserve + +# Generate TLS certificates for development +certs/generate.sh +``` + +## Architecture + +### Core Components + +- **main.go**: Entry point with TLS server setup, signal handling, and graceful shutdown +- **server/**: Request processing, file serving, and Gemini protocol response handling +- **gemini/**: Gemini protocol implementation (URL parsing, status codes, path normalization) +- **config/**: CLI-based configuration system +- **uid/**: Connection ID generation for logging + +### Key Patterns + +- **Security First**: All file operations use `filepath.IsLocal()` and path cleaning to prevent directory traversal +- **Error Handling**: Uses structured errors with `xerrors` package for consistent error propagation +- **Logging**: Structured logging with configurable levels via internal logging package +- **Testing**: Table-driven tests with parallel execution, heavy focus on security edge cases + +### Request Flow + +1. TLS connection established on port 1965 +2. Read up to 1KB request (Gemini spec limit) +3. Parse and normalize Gemini URL +4. Validate path security (prevent traversal) +5. Serve file or directory index with appropriate MIME type +6. Send response with proper Gemini status codes + +## Configuration + +Server configured via CLI flags: +- `--listen`: Server address (default: localhost:1965) +- `--root-path`: Directory to serve files from +- `--dir-indexing`: Enable directory browsing +- `--log-level`: Logging verbosity +- `--response-timeout`: Response timeout in seconds + +## Testing Strategy + +- **server/server_test.go**: Path security and file serving tests +- **gemini/url_test.go**: URL parsing and normalization tests +- Focus on security edge cases (Unicode, traversal attempts, malformed URLs) +- Use parallel test execution for performance + +## Security Considerations + +- All connections require TLS certificates (stored in certs/) +- Path traversal protection is critical - test thoroughly when modifying file serving logic +- Request size limited to 1KB per Gemini specification +- Input validation on all URLs and paths \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md index 9fea5c5..43c994c 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,87 +1 @@ -# CLAUDE.md - -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. - -## Project Overview - -Gemserve is a simple Gemini protocol server written in Go that serves static files over TLS-encrypted connections. The Gemini protocol is a lightweight, privacy-focused alternative to HTTP designed for serving text-based content. - -## Development Commands - -```bash -# Build, test, and format everything -make - -# Run tests only -make test - -# Build binary to ./dist/gemserve -make build - -# Format code with gofumpt and gci -make fmt - -# Run golangci-lint -make lint - -# Run linter with auto-fix -make lintfix - -# Clean build artifacts -make clean - -# Run the server (after building) -./dist/gemserve - -# Generate TLS certificates for development -certs/generate.sh -``` - -## Architecture - -### Core Components - -- **main.go**: Entry point with TLS server setup, signal handling, and graceful shutdown -- **server/**: Request processing, file serving, and Gemini protocol response handling -- **gemini/**: Gemini protocol implementation (URL parsing, status codes, path normalization) -- **config/**: CLI-based configuration system -- **uid/**: Connection ID generation for logging - -### Key Patterns - -- **Security First**: All file operations use `filepath.IsLocal()` and path cleaning to prevent directory traversal -- **Error Handling**: Uses structured errors with `xerrors` package for consistent error propagation -- **Logging**: Structured logging with configurable levels via internal logging package -- **Testing**: Table-driven tests with parallel execution, heavy focus on security edge cases - -### Request Flow - -1. TLS connection established on port 1965 -2. Read up to 1KB request (Gemini spec limit) -3. Parse and normalize Gemini URL -4. Validate path security (prevent traversal) -5. Serve file or directory index with appropriate MIME type -6. Send response with proper Gemini status codes - -## Configuration - -Server configured via CLI flags: -- `--listen`: Server address (default: localhost:1965) -- `--root-path`: Directory to serve files from -- `--dir-indexing`: Enable directory browsing -- `--log-level`: Logging verbosity -- `--response-timeout`: Response timeout in seconds - -## Testing Strategy - -- **server/server_test.go**: Path security and file serving tests -- **gemini/url_test.go**: URL parsing and normalization tests -- Focus on security edge cases (Unicode, traversal attempts, malformed URLs) -- Use parallel test execution for performance - -## Security Considerations - -- All connections require TLS certificates (stored in certs/) -- Path traversal protection is critical - test thoroughly when modifying file serving logic -- Request size limited to 1KB per Gemini specification -- Input validation on all URLs and paths \ No newline at end of file +@AGENTS.md diff --git a/Makefile b/Makefile index 9ff6a24..f7aebe4 100644 --- a/Makefile +++ b/Makefile @@ -34,7 +34,8 @@ lintfix: fmt build: clean mkdir -p ./dist - go build -mod=vendor -o ./dist/gemserve ./main.go + go build -mod=vendor -o ./dist/gemserve ./cmd/gemserve/gemserve.go + go build -mod=vendor -o ./dist/gemget ./cmd/gemget/gemget.go build-docker: build docker build -t gemserve . diff --git a/cmd/gemget/gemget.go b/cmd/gemget/gemget.go new file mode 100644 index 0000000..ac7f75a --- /dev/null +++ b/cmd/gemget/gemget.go @@ -0,0 +1,86 @@ +package main + +import ( + "crypto/tls" + "flag" + "fmt" + "io" + "net/url" + "os" + "strings" +) + +func main() { + // Parse command-line flags + insecure := flag.Bool("insecure", false, "Skip TLS certificate verification") + flag.Parse() + + // Get the URL from arguments + args := flag.Args() + if len(args) != 1 { + fmt.Fprintf(os.Stderr, "Usage: gemget [--insecure] \n") + os.Exit(1) + } + + geminiURL := args[0] + + host := validateUrl(geminiURL) + + connect(geminiURL, host, *insecure) +} + +func validateUrl(u string) string { + // Parse the URL + parsedURL, err := url.Parse(u) + if err != nil { + fmt.Fprintf(os.Stderr, "Error parsing URL: %v\n", err) + os.Exit(1) + } + + // Ensure it's a gemini URL + if parsedURL.Scheme != "gemini" { + fmt.Fprintf(os.Stderr, "Error: URL must use gemini:// scheme\n") + os.Exit(1) + } + + // Get host and port + host := parsedURL.Host + if !strings.Contains(host, ":") { + host = host + ":1965" // Default Gemini port + } + + return host +} + +func connect(url string, host string, insecure bool) { + // Configure TLS + tlsConfig := &tls.Config{ + InsecureSkipVerify: insecure, + MinVersion: tls.VersionTLS12, + } + + // Connect to the server + conn, err := tls.Dial("tcp", host, tlsConfig) + if err != nil { + fmt.Fprintf(os.Stderr, "Error connecting to server: %v\n", err) + os.Exit(1) + } + defer func() { + _ = conn.Close() + }() + + // Send the request (URL + CRLF) + request := url + "\r\n" + _, err = conn.Write([]byte(request)) + if err != nil { + fmt.Fprintf(os.Stderr, "Error sending request: %v\n", err) + os.Exit(1) + } + + // Read and print the response to stdout + _, err = io.Copy(os.Stdout, conn) + if err != nil { + fmt.Fprintf(os.Stderr, "Error reading response: %v\n", err) + os.Exit(1) + } +} diff --git a/cmd/gemserve/gemserve.go b/cmd/gemserve/gemserve.go new file mode 100644 index 0000000..96c319f --- /dev/null +++ b/cmd/gemserve/gemserve.go @@ -0,0 +1,160 @@ +package main + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "os" + "os/signal" + "sync" + "syscall" + "time" + + "gemserve/lib/apperrors" + "gemserve/lib/logging" + + "gemserve/config" + "gemserve/server" + + "git.antanst.com/antanst/uid" +) + +func main() { + config.CONFIG = *config.GetConfig() + logging.SetupLogging() + logger := logging.Logger + ctx := logging.WithLogger(context.Background(), logger) + err := runApp(ctx) + if err != nil { + logger.Error("Fatal Error", "err", err) + panic(fmt.Sprintf("Fatal Error: %v", err)) + } + os.Exit(0) +} + +func runApp(ctx context.Context) error { + logger := logging.FromContext(ctx) + logger.Info("Starting up. Press Ctrl+C to exit") + + listenAddr := config.CONFIG.ListenAddr + + signals := make(chan os.Signal, 1) + signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) + + fatalErrors := make(chan error) + + // Root server context, used to cancel + // connections and graceful shutdown. + serverCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // WaitGroup to track active connections + // in order to be able to wait until + // they are properly dropped + var wg sync.WaitGroup + + // Spawn server on the background. + // Returned errors are considered fatal. + go func() { + err := startServer(serverCtx, listenAddr, &wg, fatalErrors) + if err != nil { + fatalErrors <- apperrors.NewFatalError(fmt.Errorf("server startup failed: %w", err)) + } + }() + + for { + select { + case <-signals: + logger.Warn("Received SIGINT or SIGTERM signal, shutting down gracefully") + cancel() + wg.Wait() + return nil + case fatalError := <-fatalErrors: + cancel() + wg.Wait() + return fatalError + } + } +} + +func startServer(ctx context.Context, listenAddr string, wg *sync.WaitGroup, fatalErrors chan<- error) (err error) { + logger := logging.FromContext(ctx) + + cert, err := tls.LoadX509KeyPair(config.CONFIG.TLSCert, config.CONFIG.TLSKey) + if err != nil { + return apperrors.NewFatalError(fmt.Errorf("failed to load TLS certificate/key: %w", err)) + } + + logger.Debug("Using TLS cert", "path", config.CONFIG.TLSCert) + logger.Debug("Using TLS key", "path", config.CONFIG.TLSKey) + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS12, + } + + listener, err := tls.Listen("tcp", listenAddr, tlsConfig) + if err != nil { + return apperrors.NewFatalError(err) + } + + defer func(listener net.Listener) { + _ = listener.Close() + }(listener) + + // If context is cancelled, close listener + // to unblock Accept() inside main loop. + go func() { + <-ctx.Done() + _ = listener.Close() + }() + + logger.Info("Server listening", "address", listenAddr) + + for { + conn, err := listener.Accept() + if err != nil { + if ctx.Err() != nil { + return nil + } // ctx cancellation + logger.Info("Failed to accept connection: %v", "error", err) + continue + } + + // At this point we have a new connection. + wg.Add(1) + go func() { + defer wg.Done() + + // Type assert the connection to TLS connection + tlsConn, ok := conn.(*tls.Conn) + if !ok { + logger.Error("Connection is not a TLS connection") + _ = conn.Close() + return + } + + remoteAddr := conn.RemoteAddr().String() + connId := uid.UID() + + // Create cancellable connection context + // with connection ID. + connLogger := logging.WithAttr(logger, "id", connId) + connLogger = logging.WithAttr(connLogger, "remoteAddr", remoteAddr) + connCtx := context.WithValue(ctx, server.CtxConnIdKey, connId) + connCtx = context.WithValue(connCtx, logging.CtxLoggerKey, connLogger) + connCtx, cancel := context.WithTimeout(connCtx, time.Duration(config.CONFIG.ResponseTimeout)*time.Second) + defer cancel() + + err := server.HandleConnection(connCtx, tlsConn) + if err != nil { + if apperrors.IsFatal(err) { + fatalErrors <- err + return + } + connLogger.Info("Connection failed", "error", err) + } + }() + } +} diff --git a/config/config.go b/config/config.go index 628b10c..4c972ab 100644 --- a/config/config.go +++ b/config/config.go @@ -17,6 +17,7 @@ type Config struct { ListenAddr string // Address to listen on TLSCert string // TLS certificate file TLSKey string // TLS key file + MaxResponseSize int // Max response size in bytes } var CONFIG Config //nolint:gochecknoglobals @@ -47,6 +48,7 @@ func GetConfig() *Config { listen := flag.String("listen", "localhost:1965", "Address to listen on") tlsCert := flag.String("tls-cert", "certs/server.crt", "TLS certificate file") tlsKey := flag.String("tls-key", "certs/server.key", "TLS key file") + maxResponseSize := flag.Int("max-response-size", 5_242_880, "Max response size in bytes") flag.Parse() @@ -71,5 +73,6 @@ func GetConfig() *Config { ListenAddr: *listen, TLSCert: *tlsCert, TLSKey: *tlsKey, + MaxResponseSize: *maxResponseSize, } } diff --git a/gemini/statusCodes.go b/gemini/statusCodes.go index 2f62b62..e5e706f 100644 --- a/gemini/statusCodes.go +++ b/gemini/statusCodes.go @@ -6,8 +6,7 @@ const ( // Input group StatusInputExpected = 10 StatusInputExpectedSensitive = 11 - - StatusSuccess = 20 + StatusSuccess = 20 // Redirect group StatusRedirectTemporary = 30 diff --git a/lib/logging/logging.go b/lib/logging/logging.go index 66d8a36..3b5aa18 100644 --- a/lib/logging/logging.go +++ b/lib/logging/logging.go @@ -11,9 +11,9 @@ import ( "github.com/lmittmann/tint" ) -type contextKey int +type contextKey string -const loggerKey contextKey = 0 +const CtxLoggerKey contextKey = "logger" var ( programLevel *slog.LevelVar = new(slog.LevelVar) // Info by default @@ -21,11 +21,15 @@ var ( ) func WithLogger(ctx context.Context, logger *slog.Logger) context.Context { - return context.WithValue(ctx, loggerKey, logger) + return context.WithValue(ctx, CtxLoggerKey, logger) +} + +func WithAttr(logger *slog.Logger, attrName string, attrValue interface{}) *slog.Logger { + return logger.With(attrName, attrValue) } func FromContext(ctx context.Context) *slog.Logger { - if logger, ok := ctx.Value(loggerKey).(*slog.Logger); ok { + if logger, ok := ctx.Value(CtxLoggerKey).(*slog.Logger); ok { return logger } // Return default logger instead of panicking diff --git a/main.go b/main.go deleted file mode 100644 index 70f59a9..0000000 --- a/main.go +++ /dev/null @@ -1,273 +0,0 @@ -package main - -import ( - "bytes" - "context" - "crypto/tls" - "errors" - "fmt" - "io" - "net" - "os" - "os/signal" - "strings" - "sync" - "syscall" - "time" - - "gemserve/lib/apperrors" - "gemserve/lib/logging" - - "gemserve/config" - "gemserve/gemini" - "gemserve/server" - - "git.antanst.com/antanst/uid" -) - -func main() { - config.CONFIG = *config.GetConfig() - logging.SetupLogging() - logger := logging.Logger - ctx := logging.WithLogger(context.Background(), logger) - err := runApp(ctx) - if err != nil { - logger.Error(fmt.Sprintf("Fatal Error: %v", err)) - panic(fmt.Sprintf("Fatal Error: %v", err)) - } - os.Exit(0) -} - -func runApp(ctx context.Context) error { - logger := logging.FromContext(ctx) - logger.Info("Starting up. Press Ctrl+C to exit") - - listenAddr := config.CONFIG.ListenAddr - - signals := make(chan os.Signal, 1) - signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) - - fatalErrors := make(chan error) - - // Root server context, used to cancel - // connections and graceful shutdown. - serverCtx, cancel := context.WithCancel(ctx) - defer cancel() - - // WaitGroup to track active connections - var wg sync.WaitGroup - - // Spawn server on the background. - // Returned errors are considered fatal. - go func() { - err := startServer(serverCtx, listenAddr, &wg, fatalErrors) - if err != nil { - fatalErrors <- apperrors.NewFatalError(fmt.Errorf("server startup failed: %w", err)) - } - }() - - for { - select { - case <-signals: - logger.Warn("Received SIGINT or SIGTERM signal, shutting down gracefully") - cancel() - wg.Wait() - return nil - case fatalError := <-fatalErrors: - cancel() - wg.Wait() - return fatalError - } - } -} - -func startServer(ctx context.Context, listenAddr string, wg *sync.WaitGroup, fatalErrors chan<- error) (err error) { - logger := logging.FromContext(ctx) - - cert, err := tls.LoadX509KeyPair(config.CONFIG.TLSCert, config.CONFIG.TLSKey) - if err != nil { - return apperrors.NewFatalError(fmt.Errorf("failed to load TLS certificate/key: %w", err)) - } - - logger.Debug("Using TLS cert", "path", config.CONFIG.TLSCert) - logger.Debug("Using TLS key", "path", config.CONFIG.TLSKey) - - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{cert}, - MinVersion: tls.VersionTLS12, - } - - listener, err := tls.Listen("tcp", listenAddr, tlsConfig) - if err != nil { - return apperrors.NewFatalError(err) - } - - defer func(listener net.Listener) { - _ = listener.Close() - }(listener) - - // If context is cancelled, close listener - // to unblock Accept() inside main loop. - go func() { - <-ctx.Done() - _ = listener.Close() - }() - - logger.Info("Server listening", "address", listenAddr) - - for { - conn, err := listener.Accept() - if err != nil { - if ctx.Err() != nil { - return nil - } // ctx cancellation - logger.Info("Failed to accept connection: %v", "error", err) - continue - } - - wg.Add(1) - go func() { - defer wg.Done() - - // Type assert the connection to TLS connection - tlsConn, ok := conn.(*tls.Conn) - if !ok { - logger.Error("Connection is not a TLS connection") - _ = conn.Close() - return - } - - remoteAddr := conn.RemoteAddr().String() - connId := uid.UID() - - // Create a timeout context for this connection - connCtx, cancel := context.WithTimeout(ctx, time.Duration(config.CONFIG.ResponseTimeout)*time.Second) - defer cancel() - - err := handleConnection(connCtx, tlsConn, connId, remoteAddr) - if err != nil { - if apperrors.IsFatal(err) { - fatalErrors <- err - return - } - logger.Info("Connection failed", "id", connId, "remoteAddr", remoteAddr, "error", err) - } - }() - } -} - -func closeConnection(conn *tls.Conn) error { - err := conn.CloseWrite() - if err != nil { - return apperrors.NewNetworkError(fmt.Errorf("failed to close TLS connection: %w", err)) - } - err = conn.Close() - if err != nil { - return apperrors.NewNetworkError(fmt.Errorf("failed to close connection: %w", err)) - } - return nil -} - -func handleConnection(ctx context.Context, conn *tls.Conn, connId string, remoteAddr string) (err error) { - logger := logging.FromContext(ctx) - start := time.Now() - var outputBytes []byte - - // Set connection deadline based on context - if deadline, ok := ctx.Deadline(); ok { - _ = conn.SetDeadline(deadline) - } - - defer func(conn *tls.Conn) { - end := time.Now() - tookMs := end.Sub(start).Milliseconds() - var responseHeader string - - // On non-errors, just log response and close connection. - if err == nil { - // Log non-erroneous responses - if i := bytes.Index(outputBytes, []byte{'\r'}); i >= 0 { - responseHeader = string(outputBytes[:i]) - } - logger.Info("Response", "connId", connId, "remoteAddr", remoteAddr, "responseHeader", responseHeader, "ms", tookMs) - _ = closeConnection(conn) - return - } - - // Handle context cancellation/timeout - if errors.Is(err, context.DeadlineExceeded) { - logger.Info("Connection timeout", "id", connId, "remoteAddr", remoteAddr, "ms", tookMs) - responseHeader = fmt.Sprintf("%d Request timeout", gemini.StatusCGIError) - _, _ = conn.Write([]byte(responseHeader + "\r\n")) - _ = closeConnection(conn) - return - } - if errors.Is(err, context.Canceled) { - logger.Info("Connection cancelled", "id", connId, "remoteAddr", remoteAddr, "ms", tookMs) - _ = closeConnection(conn) - return - } - - var code int - var responseMsg string - if apperrors.IsFatal(err) { - _ = closeConnection(conn) - return - } - if apperrors.IsGeminiError(err) { - code = apperrors.GetStatusCode(err) - responseMsg = "server error" - } else { - code = gemini.StatusPermanentFailure - responseMsg = "server error" - } - responseHeader = fmt.Sprintf("%d %s", code, responseMsg) - _, _ = conn.Write([]byte(responseHeader + "\r\n")) - _ = closeConnection(conn) - }(conn) - - // Check context before starting - if err := ctx.Err(); err != nil { - return err - } - - // Gemini is supposed to have a 1kb limit - // on input requests. - buffer := make([]byte, 1025) - - n, err := conn.Read(buffer) - if err != nil && err != io.EOF { - return apperrors.NewGeminiError(fmt.Errorf("failed to read connection data: %w", err), gemini.StatusBadRequest) - } - if n == 0 { - return apperrors.NewGeminiError(fmt.Errorf("client did not send data"), gemini.StatusBadRequest) - } - if n > 1024 { - return apperrors.NewGeminiError(fmt.Errorf("client request size %d > 1024 bytes", n), gemini.StatusBadRequest) - } - - // Check context after read - if err := ctx.Err(); err != nil { - return err - } - - dataBytes := buffer[:n] - dataString := string(dataBytes) - - logger.Info("Request", "id", connId, "remoteAddr", remoteAddr, "data", strings.TrimSpace(dataString), "size", len(dataBytes)) - outputBytes, err = server.GenerateResponse(ctx, conn, connId, dataString) - if err != nil { - return err - } - - // Check context before write - if err := ctx.Err(); err != nil { - return err - } - - _, err = conn.Write(outputBytes) - if err != nil { - return err - } - return nil -} diff --git a/server/server.go b/server/server.go index 0742982..cb355df 100644 --- a/server/server.go +++ b/server/server.go @@ -1,9 +1,12 @@ package server import ( + "bytes" "context" "crypto/tls" + "errors" "fmt" + "io" "net" "net/url" "os" @@ -11,6 +14,7 @@ import ( "path/filepath" "strconv" "strings" + "time" "unicode/utf8" "gemserve/lib/apperrors" @@ -22,11 +26,27 @@ import ( "github.com/gabriel-vasile/mimetype" ) +type contextKey string + +const CtxConnIdKey contextKey = "connId" + type ServerConfig interface { DirIndexingEnabled() bool RootPath() string } +func CloseConnection(conn *tls.Conn) error { + err := conn.CloseWrite() + if err != nil { + return apperrors.NewNetworkError(fmt.Errorf("failed to close TLS connection: %w", err)) + } + err = conn.Close() + if err != nil { + return apperrors.NewNetworkError(fmt.Errorf("failed to close connection: %w", err)) + } + return nil +} + func checkRequestURL(url *gemini.URL) error { if !utf8.ValidString(url.String()) { return apperrors.NewGeminiError(fmt.Errorf("invalid URL"), gemini.StatusBadRequest) @@ -45,13 +65,19 @@ func checkRequestURL(url *gemini.URL) error { return apperrors.NewGeminiError(fmt.Errorf("invalid server listen port: %w", err), gemini.StatusBadRequest) } if url.Port != listenPort { - return apperrors.NewGeminiError(fmt.Errorf("failed to parse URL: %w", err), gemini.StatusProxyRequestRefused) + return apperrors.NewGeminiError(fmt.Errorf("port mismatch"), gemini.StatusProxyRequestRefused) } return nil } -func GenerateResponse(ctx context.Context, conn *tls.Conn, connId string, input string) ([]byte, error) { +func GenerateResponse(ctx context.Context, conn *tls.Conn, input string) ([]byte, error) { logger := logging.FromContext(ctx) + connId := ctx.Value(CtxConnIdKey).(string) + + if err := ctx.Err(); err != nil { + return nil, err + } + trimmedInput := strings.TrimSpace(input) // url will have a cleaned and normalized path after this url, err := gemini.ParseURL(trimmedInput, "", true) @@ -80,12 +106,15 @@ func GenerateResponse(ctx context.Context, conn *tls.Conn, connId string, input // Handle directory. if info.IsDir() { - return generateResponseDir(localPath) + return generateResponseDir(ctx, localPath) } - return generateResponseFile(localPath) + return generateResponseFile(ctx, localPath) } -func generateResponseFile(localPath string) ([]byte, error) { +func generateResponseFile(ctx context.Context, localPath string) ([]byte, error) { + if err := ctx.Err(); err != nil { + return nil, err + } data, err := os.ReadFile(localPath) if err != nil { return nil, apperrors.NewGeminiError(fmt.Errorf("failed to access path: %w", err), gemini.StatusNotFound) @@ -102,7 +131,10 @@ func generateResponseFile(localPath string) ([]byte, error) { return response, nil } -func generateResponseDir(localPath string) (output []byte, err error) { +func generateResponseDir(ctx context.Context, localPath string) (output []byte, err error) { + if err := ctx.Err(); err != nil { + return nil, err + } entries, err := os.ReadDir(localPath) if err != nil { return nil, apperrors.NewGeminiError(fmt.Errorf("failed to access path: %w", err), gemini.StatusNotFound) @@ -127,7 +159,7 @@ func generateResponseDir(localPath string) (output []byte, err error) { return response, nil } filePath := filepath.Join(localPath, "index.gmi") - return generateResponseFile(filePath) + return generateResponseFile(ctx, filePath) } func calculateLocalPath(input string, basePath string) (string, error) { @@ -155,3 +187,110 @@ func calculateLocalPath(input string, basePath string) (string, error) { filePath = path.Join(basePath, localPath) return filePath, nil } + +func HandleConnection(ctx context.Context, conn *tls.Conn) (err error) { + logger := logging.FromContext(ctx) + start := time.Now() + var outputBytes []byte + + // Set connection deadline based on context + if deadline, ok := ctx.Deadline(); ok { + _ = conn.SetDeadline(deadline) + } + + defer func(conn *tls.Conn) { + end := time.Now() + tookMs := end.Sub(start).Milliseconds() + var responseHeader string + + // On non-errors, just log response and close connection. + if err == nil { + // Log non-erroneous responses + if i := bytes.Index(outputBytes, []byte{'\r'}); i >= 0 { + responseHeader = string(outputBytes[:i]) + } + logger.Debug("Response", "responseHeader", responseHeader, "ms", tookMs) + _ = CloseConnection(conn) + return + } + + // Handle context cancellation/timeout + if errors.Is(err, context.DeadlineExceeded) { + logger.Info("Connection timeout", "ms", tookMs) + responseHeader = fmt.Sprintf("%d Request timeout", gemini.StatusCGIError) + _, _ = conn.Write([]byte(responseHeader + "\r\n")) + _ = CloseConnection(conn) + return + } + if errors.Is(err, context.Canceled) { + logger.Info("Connection cancelled", "ms", tookMs) + _ = CloseConnection(conn) + return + } + + var code int + var responseMsg string + if apperrors.IsFatal(err) { + _ = CloseConnection(conn) + return + } + if apperrors.IsGeminiError(err) { + code = apperrors.GetStatusCode(err) + responseMsg = "server error" + } else { + code = gemini.StatusPermanentFailure + responseMsg = "server error" + } + responseHeader = fmt.Sprintf("%d %s", code, responseMsg) + _, _ = conn.Write([]byte(responseHeader + "\r\n")) + _ = CloseConnection(conn) + }(conn) + + // Check context before starting + if err := ctx.Err(); err != nil { + return err + } + + // Gemini is supposed to have a 1kb limit + // on input requests. + buffer := make([]byte, 1025) + + n, err := conn.Read(buffer) + if err != nil && err != io.EOF { + return apperrors.NewGeminiError(fmt.Errorf("failed to read connection data: %w", err), gemini.StatusBadRequest) + } + if n == 0 { + return apperrors.NewGeminiError(fmt.Errorf("client did not send data"), gemini.StatusBadRequest) + } + if n > 1024 { + return apperrors.NewGeminiError(fmt.Errorf("client request size %d > 1024 bytes", n), gemini.StatusBadRequest) + } + + // Check context after read + if err := ctx.Err(); err != nil { + return err + } + + dataBytes := buffer[:n] + dataString := string(dataBytes) + + logger.Info("Request", "data", strings.TrimSpace(dataString), "size", len(dataBytes)) + outputBytes, err = GenerateResponse(ctx, conn, dataString) + if len(outputBytes) > config.CONFIG.MaxResponseSize { + return apperrors.NewGeminiError(fmt.Errorf("max response size reached"), gemini.StatusTemporaryFailure) + } + if err != nil { + return err + } + + // Check context before write + if err := ctx.Err(); err != nil { + return err + } + + _, err = conn.Write(outputBytes) + if err != nil { + return err + } + return nil +}