This commit is contained in:
antanst
2025-10-10 15:20:45 +03:00
parent 3a5835fc42
commit d336bdffba
10 changed files with 494 additions and 374 deletions

87
AGENTS.md Normal file
View File

@@ -0,0 +1,87 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project Overview
Gemserve is a simple Gemini protocol server written in Go that serves static files over TLS-encrypted connections. The Gemini protocol is a lightweight, privacy-focused alternative to HTTP designed for serving text-based content.
## Development Commands
```bash
# Build, test, and format everything
make
# Run tests only
make test
# Build binary to ./dist/gemserve
make build
# Format code with gofumpt and gci
make fmt
# Run golangci-lint
make lint
# Run linter with auto-fix
make lintfix
# Clean build artifacts
make clean
# Run the server (after building)
./dist/gemserve
# Generate TLS certificates for development
certs/generate.sh
```
## Architecture
### Core Components
- **main.go**: Entry point with TLS server setup, signal handling, and graceful shutdown
- **server/**: Request processing, file serving, and Gemini protocol response handling
- **gemini/**: Gemini protocol implementation (URL parsing, status codes, path normalization)
- **config/**: CLI-based configuration system
- **uid/**: Connection ID generation for logging
### Key Patterns
- **Security First**: All file operations use `filepath.IsLocal()` and path cleaning to prevent directory traversal
- **Error Handling**: Uses structured errors with `xerrors` package for consistent error propagation
- **Logging**: Structured logging with configurable levels via internal logging package
- **Testing**: Table-driven tests with parallel execution, heavy focus on security edge cases
### Request Flow
1. TLS connection established on port 1965
2. Read up to 1KB request (Gemini spec limit)
3. Parse and normalize Gemini URL
4. Validate path security (prevent traversal)
5. Serve file or directory index with appropriate MIME type
6. Send response with proper Gemini status codes
## Configuration
Server configured via CLI flags:
- `--listen`: Server address (default: localhost:1965)
- `--root-path`: Directory to serve files from
- `--dir-indexing`: Enable directory browsing
- `--log-level`: Logging verbosity
- `--response-timeout`: Response timeout in seconds
## Testing Strategy
- **server/server_test.go**: Path security and file serving tests
- **gemini/url_test.go**: URL parsing and normalization tests
- Focus on security edge cases (Unicode, traversal attempts, malformed URLs)
- Use parallel test execution for performance
## Security Considerations
- All connections require TLS certificates (stored in certs/)
- Path traversal protection is critical - test thoroughly when modifying file serving logic
- Request size limited to 1KB per Gemini specification
- Input validation on all URLs and paths

View File

@@ -1,87 +1 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project Overview
Gemserve is a simple Gemini protocol server written in Go that serves static files over TLS-encrypted connections. The Gemini protocol is a lightweight, privacy-focused alternative to HTTP designed for serving text-based content.
## Development Commands
```bash
# Build, test, and format everything
make
# Run tests only
make test
# Build binary to ./dist/gemserve
make build
# Format code with gofumpt and gci
make fmt
# Run golangci-lint
make lint
# Run linter with auto-fix
make lintfix
# Clean build artifacts
make clean
# Run the server (after building)
./dist/gemserve
# Generate TLS certificates for development
certs/generate.sh
```
## Architecture
### Core Components
- **main.go**: Entry point with TLS server setup, signal handling, and graceful shutdown
- **server/**: Request processing, file serving, and Gemini protocol response handling
- **gemini/**: Gemini protocol implementation (URL parsing, status codes, path normalization)
- **config/**: CLI-based configuration system
- **uid/**: Connection ID generation for logging
### Key Patterns
- **Security First**: All file operations use `filepath.IsLocal()` and path cleaning to prevent directory traversal
- **Error Handling**: Uses structured errors with `xerrors` package for consistent error propagation
- **Logging**: Structured logging with configurable levels via internal logging package
- **Testing**: Table-driven tests with parallel execution, heavy focus on security edge cases
### Request Flow
1. TLS connection established on port 1965
2. Read up to 1KB request (Gemini spec limit)
3. Parse and normalize Gemini URL
4. Validate path security (prevent traversal)
5. Serve file or directory index with appropriate MIME type
6. Send response with proper Gemini status codes
## Configuration
Server configured via CLI flags:
- `--listen`: Server address (default: localhost:1965)
- `--root-path`: Directory to serve files from
- `--dir-indexing`: Enable directory browsing
- `--log-level`: Logging verbosity
- `--response-timeout`: Response timeout in seconds
## Testing Strategy
- **server/server_test.go**: Path security and file serving tests
- **gemini/url_test.go**: URL parsing and normalization tests
- Focus on security edge cases (Unicode, traversal attempts, malformed URLs)
- Use parallel test execution for performance
## Security Considerations
- All connections require TLS certificates (stored in certs/)
- Path traversal protection is critical - test thoroughly when modifying file serving logic
- Request size limited to 1KB per Gemini specification
- Input validation on all URLs and paths
@AGENTS.md

View File

@@ -34,7 +34,8 @@ lintfix: fmt
build: clean
mkdir -p ./dist
go build -mod=vendor -o ./dist/gemserve ./main.go
go build -mod=vendor -o ./dist/gemserve ./cmd/gemserve/gemserve.go
go build -mod=vendor -o ./dist/gemget ./cmd/gemget/gemget.go
build-docker: build
docker build -t gemserve .

86
cmd/gemget/gemget.go Normal file
View File

@@ -0,0 +1,86 @@
package main
import (
"crypto/tls"
"flag"
"fmt"
"io"
"net/url"
"os"
"strings"
)
func main() {
// Parse command-line flags
insecure := flag.Bool("insecure", false, "Skip TLS certificate verification")
flag.Parse()
// Get the URL from arguments
args := flag.Args()
if len(args) != 1 {
fmt.Fprintf(os.Stderr, "Usage: gemget [--insecure] <gemini-url>\n")
os.Exit(1)
}
geminiURL := args[0]
host := validateUrl(geminiURL)
connect(geminiURL, host, *insecure)
}
func validateUrl(u string) string {
// Parse the URL
parsedURL, err := url.Parse(u)
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing URL: %v\n", err)
os.Exit(1)
}
// Ensure it's a gemini URL
if parsedURL.Scheme != "gemini" {
fmt.Fprintf(os.Stderr, "Error: URL must use gemini:// scheme\n")
os.Exit(1)
}
// Get host and port
host := parsedURL.Host
if !strings.Contains(host, ":") {
host = host + ":1965" // Default Gemini port
}
return host
}
func connect(url string, host string, insecure bool) {
// Configure TLS
tlsConfig := &tls.Config{
InsecureSkipVerify: insecure,
MinVersion: tls.VersionTLS12,
}
// Connect to the server
conn, err := tls.Dial("tcp", host, tlsConfig)
if err != nil {
fmt.Fprintf(os.Stderr, "Error connecting to server: %v\n", err)
os.Exit(1)
}
defer func() {
_ = conn.Close()
}()
// Send the request (URL + CRLF)
request := url + "\r\n"
_, err = conn.Write([]byte(request))
if err != nil {
fmt.Fprintf(os.Stderr, "Error sending request: %v\n", err)
os.Exit(1)
}
// Read and print the response to stdout
_, err = io.Copy(os.Stdout, conn)
if err != nil {
fmt.Fprintf(os.Stderr, "Error reading response: %v\n", err)
os.Exit(1)
}
}

160
cmd/gemserve/gemserve.go Normal file
View File

@@ -0,0 +1,160 @@
package main
import (
"context"
"crypto/tls"
"fmt"
"net"
"os"
"os/signal"
"sync"
"syscall"
"time"
"gemserve/lib/apperrors"
"gemserve/lib/logging"
"gemserve/config"
"gemserve/server"
"git.antanst.com/antanst/uid"
)
func main() {
config.CONFIG = *config.GetConfig()
logging.SetupLogging()
logger := logging.Logger
ctx := logging.WithLogger(context.Background(), logger)
err := runApp(ctx)
if err != nil {
logger.Error("Fatal Error", "err", err)
panic(fmt.Sprintf("Fatal Error: %v", err))
}
os.Exit(0)
}
func runApp(ctx context.Context) error {
logger := logging.FromContext(ctx)
logger.Info("Starting up. Press Ctrl+C to exit")
listenAddr := config.CONFIG.ListenAddr
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
fatalErrors := make(chan error)
// Root server context, used to cancel
// connections and graceful shutdown.
serverCtx, cancel := context.WithCancel(ctx)
defer cancel()
// WaitGroup to track active connections
// in order to be able to wait until
// they are properly dropped
var wg sync.WaitGroup
// Spawn server on the background.
// Returned errors are considered fatal.
go func() {
err := startServer(serverCtx, listenAddr, &wg, fatalErrors)
if err != nil {
fatalErrors <- apperrors.NewFatalError(fmt.Errorf("server startup failed: %w", err))
}
}()
for {
select {
case <-signals:
logger.Warn("Received SIGINT or SIGTERM signal, shutting down gracefully")
cancel()
wg.Wait()
return nil
case fatalError := <-fatalErrors:
cancel()
wg.Wait()
return fatalError
}
}
}
func startServer(ctx context.Context, listenAddr string, wg *sync.WaitGroup, fatalErrors chan<- error) (err error) {
logger := logging.FromContext(ctx)
cert, err := tls.LoadX509KeyPair(config.CONFIG.TLSCert, config.CONFIG.TLSKey)
if err != nil {
return apperrors.NewFatalError(fmt.Errorf("failed to load TLS certificate/key: %w", err))
}
logger.Debug("Using TLS cert", "path", config.CONFIG.TLSCert)
logger.Debug("Using TLS key", "path", config.CONFIG.TLSKey)
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12,
}
listener, err := tls.Listen("tcp", listenAddr, tlsConfig)
if err != nil {
return apperrors.NewFatalError(err)
}
defer func(listener net.Listener) {
_ = listener.Close()
}(listener)
// If context is cancelled, close listener
// to unblock Accept() inside main loop.
go func() {
<-ctx.Done()
_ = listener.Close()
}()
logger.Info("Server listening", "address", listenAddr)
for {
conn, err := listener.Accept()
if err != nil {
if ctx.Err() != nil {
return nil
} // ctx cancellation
logger.Info("Failed to accept connection: %v", "error", err)
continue
}
// At this point we have a new connection.
wg.Add(1)
go func() {
defer wg.Done()
// Type assert the connection to TLS connection
tlsConn, ok := conn.(*tls.Conn)
if !ok {
logger.Error("Connection is not a TLS connection")
_ = conn.Close()
return
}
remoteAddr := conn.RemoteAddr().String()
connId := uid.UID()
// Create cancellable connection context
// with connection ID.
connLogger := logging.WithAttr(logger, "id", connId)
connLogger = logging.WithAttr(connLogger, "remoteAddr", remoteAddr)
connCtx := context.WithValue(ctx, server.CtxConnIdKey, connId)
connCtx = context.WithValue(connCtx, logging.CtxLoggerKey, connLogger)
connCtx, cancel := context.WithTimeout(connCtx, time.Duration(config.CONFIG.ResponseTimeout)*time.Second)
defer cancel()
err := server.HandleConnection(connCtx, tlsConn)
if err != nil {
if apperrors.IsFatal(err) {
fatalErrors <- err
return
}
connLogger.Info("Connection failed", "error", err)
}
}()
}
}

View File

@@ -17,6 +17,7 @@ type Config struct {
ListenAddr string // Address to listen on
TLSCert string // TLS certificate file
TLSKey string // TLS key file
MaxResponseSize int // Max response size in bytes
}
var CONFIG Config //nolint:gochecknoglobals
@@ -47,6 +48,7 @@ func GetConfig() *Config {
listen := flag.String("listen", "localhost:1965", "Address to listen on")
tlsCert := flag.String("tls-cert", "certs/server.crt", "TLS certificate file")
tlsKey := flag.String("tls-key", "certs/server.key", "TLS key file")
maxResponseSize := flag.Int("max-response-size", 5_242_880, "Max response size in bytes")
flag.Parse()
@@ -71,5 +73,6 @@ func GetConfig() *Config {
ListenAddr: *listen,
TLSCert: *tlsCert,
TLSKey: *tlsKey,
MaxResponseSize: *maxResponseSize,
}
}

View File

@@ -6,8 +6,7 @@ const (
// Input group
StatusInputExpected = 10
StatusInputExpectedSensitive = 11
StatusSuccess = 20
StatusSuccess = 20
// Redirect group
StatusRedirectTemporary = 30

View File

@@ -11,9 +11,9 @@ import (
"github.com/lmittmann/tint"
)
type contextKey int
type contextKey string
const loggerKey contextKey = 0
const CtxLoggerKey contextKey = "logger"
var (
programLevel *slog.LevelVar = new(slog.LevelVar) // Info by default
@@ -21,11 +21,15 @@ var (
)
func WithLogger(ctx context.Context, logger *slog.Logger) context.Context {
return context.WithValue(ctx, loggerKey, logger)
return context.WithValue(ctx, CtxLoggerKey, logger)
}
func WithAttr(logger *slog.Logger, attrName string, attrValue interface{}) *slog.Logger {
return logger.With(attrName, attrValue)
}
func FromContext(ctx context.Context) *slog.Logger {
if logger, ok := ctx.Value(loggerKey).(*slog.Logger); ok {
if logger, ok := ctx.Value(CtxLoggerKey).(*slog.Logger); ok {
return logger
}
// Return default logger instead of panicking

273
main.go
View File

@@ -1,273 +0,0 @@
package main
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"
"gemserve/lib/apperrors"
"gemserve/lib/logging"
"gemserve/config"
"gemserve/gemini"
"gemserve/server"
"git.antanst.com/antanst/uid"
)
func main() {
config.CONFIG = *config.GetConfig()
logging.SetupLogging()
logger := logging.Logger
ctx := logging.WithLogger(context.Background(), logger)
err := runApp(ctx)
if err != nil {
logger.Error(fmt.Sprintf("Fatal Error: %v", err))
panic(fmt.Sprintf("Fatal Error: %v", err))
}
os.Exit(0)
}
func runApp(ctx context.Context) error {
logger := logging.FromContext(ctx)
logger.Info("Starting up. Press Ctrl+C to exit")
listenAddr := config.CONFIG.ListenAddr
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
fatalErrors := make(chan error)
// Root server context, used to cancel
// connections and graceful shutdown.
serverCtx, cancel := context.WithCancel(ctx)
defer cancel()
// WaitGroup to track active connections
var wg sync.WaitGroup
// Spawn server on the background.
// Returned errors are considered fatal.
go func() {
err := startServer(serverCtx, listenAddr, &wg, fatalErrors)
if err != nil {
fatalErrors <- apperrors.NewFatalError(fmt.Errorf("server startup failed: %w", err))
}
}()
for {
select {
case <-signals:
logger.Warn("Received SIGINT or SIGTERM signal, shutting down gracefully")
cancel()
wg.Wait()
return nil
case fatalError := <-fatalErrors:
cancel()
wg.Wait()
return fatalError
}
}
}
func startServer(ctx context.Context, listenAddr string, wg *sync.WaitGroup, fatalErrors chan<- error) (err error) {
logger := logging.FromContext(ctx)
cert, err := tls.LoadX509KeyPair(config.CONFIG.TLSCert, config.CONFIG.TLSKey)
if err != nil {
return apperrors.NewFatalError(fmt.Errorf("failed to load TLS certificate/key: %w", err))
}
logger.Debug("Using TLS cert", "path", config.CONFIG.TLSCert)
logger.Debug("Using TLS key", "path", config.CONFIG.TLSKey)
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12,
}
listener, err := tls.Listen("tcp", listenAddr, tlsConfig)
if err != nil {
return apperrors.NewFatalError(err)
}
defer func(listener net.Listener) {
_ = listener.Close()
}(listener)
// If context is cancelled, close listener
// to unblock Accept() inside main loop.
go func() {
<-ctx.Done()
_ = listener.Close()
}()
logger.Info("Server listening", "address", listenAddr)
for {
conn, err := listener.Accept()
if err != nil {
if ctx.Err() != nil {
return nil
} // ctx cancellation
logger.Info("Failed to accept connection: %v", "error", err)
continue
}
wg.Add(1)
go func() {
defer wg.Done()
// Type assert the connection to TLS connection
tlsConn, ok := conn.(*tls.Conn)
if !ok {
logger.Error("Connection is not a TLS connection")
_ = conn.Close()
return
}
remoteAddr := conn.RemoteAddr().String()
connId := uid.UID()
// Create a timeout context for this connection
connCtx, cancel := context.WithTimeout(ctx, time.Duration(config.CONFIG.ResponseTimeout)*time.Second)
defer cancel()
err := handleConnection(connCtx, tlsConn, connId, remoteAddr)
if err != nil {
if apperrors.IsFatal(err) {
fatalErrors <- err
return
}
logger.Info("Connection failed", "id", connId, "remoteAddr", remoteAddr, "error", err)
}
}()
}
}
func closeConnection(conn *tls.Conn) error {
err := conn.CloseWrite()
if err != nil {
return apperrors.NewNetworkError(fmt.Errorf("failed to close TLS connection: %w", err))
}
err = conn.Close()
if err != nil {
return apperrors.NewNetworkError(fmt.Errorf("failed to close connection: %w", err))
}
return nil
}
func handleConnection(ctx context.Context, conn *tls.Conn, connId string, remoteAddr string) (err error) {
logger := logging.FromContext(ctx)
start := time.Now()
var outputBytes []byte
// Set connection deadline based on context
if deadline, ok := ctx.Deadline(); ok {
_ = conn.SetDeadline(deadline)
}
defer func(conn *tls.Conn) {
end := time.Now()
tookMs := end.Sub(start).Milliseconds()
var responseHeader string
// On non-errors, just log response and close connection.
if err == nil {
// Log non-erroneous responses
if i := bytes.Index(outputBytes, []byte{'\r'}); i >= 0 {
responseHeader = string(outputBytes[:i])
}
logger.Info("Response", "connId", connId, "remoteAddr", remoteAddr, "responseHeader", responseHeader, "ms", tookMs)
_ = closeConnection(conn)
return
}
// Handle context cancellation/timeout
if errors.Is(err, context.DeadlineExceeded) {
logger.Info("Connection timeout", "id", connId, "remoteAddr", remoteAddr, "ms", tookMs)
responseHeader = fmt.Sprintf("%d Request timeout", gemini.StatusCGIError)
_, _ = conn.Write([]byte(responseHeader + "\r\n"))
_ = closeConnection(conn)
return
}
if errors.Is(err, context.Canceled) {
logger.Info("Connection cancelled", "id", connId, "remoteAddr", remoteAddr, "ms", tookMs)
_ = closeConnection(conn)
return
}
var code int
var responseMsg string
if apperrors.IsFatal(err) {
_ = closeConnection(conn)
return
}
if apperrors.IsGeminiError(err) {
code = apperrors.GetStatusCode(err)
responseMsg = "server error"
} else {
code = gemini.StatusPermanentFailure
responseMsg = "server error"
}
responseHeader = fmt.Sprintf("%d %s", code, responseMsg)
_, _ = conn.Write([]byte(responseHeader + "\r\n"))
_ = closeConnection(conn)
}(conn)
// Check context before starting
if err := ctx.Err(); err != nil {
return err
}
// Gemini is supposed to have a 1kb limit
// on input requests.
buffer := make([]byte, 1025)
n, err := conn.Read(buffer)
if err != nil && err != io.EOF {
return apperrors.NewGeminiError(fmt.Errorf("failed to read connection data: %w", err), gemini.StatusBadRequest)
}
if n == 0 {
return apperrors.NewGeminiError(fmt.Errorf("client did not send data"), gemini.StatusBadRequest)
}
if n > 1024 {
return apperrors.NewGeminiError(fmt.Errorf("client request size %d > 1024 bytes", n), gemini.StatusBadRequest)
}
// Check context after read
if err := ctx.Err(); err != nil {
return err
}
dataBytes := buffer[:n]
dataString := string(dataBytes)
logger.Info("Request", "id", connId, "remoteAddr", remoteAddr, "data", strings.TrimSpace(dataString), "size", len(dataBytes))
outputBytes, err = server.GenerateResponse(ctx, conn, connId, dataString)
if err != nil {
return err
}
// Check context before write
if err := ctx.Err(); err != nil {
return err
}
_, err = conn.Write(outputBytes)
if err != nil {
return err
}
return nil
}

View File

@@ -1,9 +1,12 @@
package server
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/url"
"os"
@@ -11,6 +14,7 @@ import (
"path/filepath"
"strconv"
"strings"
"time"
"unicode/utf8"
"gemserve/lib/apperrors"
@@ -22,11 +26,27 @@ import (
"github.com/gabriel-vasile/mimetype"
)
type contextKey string
const CtxConnIdKey contextKey = "connId"
type ServerConfig interface {
DirIndexingEnabled() bool
RootPath() string
}
func CloseConnection(conn *tls.Conn) error {
err := conn.CloseWrite()
if err != nil {
return apperrors.NewNetworkError(fmt.Errorf("failed to close TLS connection: %w", err))
}
err = conn.Close()
if err != nil {
return apperrors.NewNetworkError(fmt.Errorf("failed to close connection: %w", err))
}
return nil
}
func checkRequestURL(url *gemini.URL) error {
if !utf8.ValidString(url.String()) {
return apperrors.NewGeminiError(fmt.Errorf("invalid URL"), gemini.StatusBadRequest)
@@ -45,13 +65,19 @@ func checkRequestURL(url *gemini.URL) error {
return apperrors.NewGeminiError(fmt.Errorf("invalid server listen port: %w", err), gemini.StatusBadRequest)
}
if url.Port != listenPort {
return apperrors.NewGeminiError(fmt.Errorf("failed to parse URL: %w", err), gemini.StatusProxyRequestRefused)
return apperrors.NewGeminiError(fmt.Errorf("port mismatch"), gemini.StatusProxyRequestRefused)
}
return nil
}
func GenerateResponse(ctx context.Context, conn *tls.Conn, connId string, input string) ([]byte, error) {
func GenerateResponse(ctx context.Context, conn *tls.Conn, input string) ([]byte, error) {
logger := logging.FromContext(ctx)
connId := ctx.Value(CtxConnIdKey).(string)
if err := ctx.Err(); err != nil {
return nil, err
}
trimmedInput := strings.TrimSpace(input)
// url will have a cleaned and normalized path after this
url, err := gemini.ParseURL(trimmedInput, "", true)
@@ -80,12 +106,15 @@ func GenerateResponse(ctx context.Context, conn *tls.Conn, connId string, input
// Handle directory.
if info.IsDir() {
return generateResponseDir(localPath)
return generateResponseDir(ctx, localPath)
}
return generateResponseFile(localPath)
return generateResponseFile(ctx, localPath)
}
func generateResponseFile(localPath string) ([]byte, error) {
func generateResponseFile(ctx context.Context, localPath string) ([]byte, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
data, err := os.ReadFile(localPath)
if err != nil {
return nil, apperrors.NewGeminiError(fmt.Errorf("failed to access path: %w", err), gemini.StatusNotFound)
@@ -102,7 +131,10 @@ func generateResponseFile(localPath string) ([]byte, error) {
return response, nil
}
func generateResponseDir(localPath string) (output []byte, err error) {
func generateResponseDir(ctx context.Context, localPath string) (output []byte, err error) {
if err := ctx.Err(); err != nil {
return nil, err
}
entries, err := os.ReadDir(localPath)
if err != nil {
return nil, apperrors.NewGeminiError(fmt.Errorf("failed to access path: %w", err), gemini.StatusNotFound)
@@ -127,7 +159,7 @@ func generateResponseDir(localPath string) (output []byte, err error) {
return response, nil
}
filePath := filepath.Join(localPath, "index.gmi")
return generateResponseFile(filePath)
return generateResponseFile(ctx, filePath)
}
func calculateLocalPath(input string, basePath string) (string, error) {
@@ -155,3 +187,110 @@ func calculateLocalPath(input string, basePath string) (string, error) {
filePath = path.Join(basePath, localPath)
return filePath, nil
}
func HandleConnection(ctx context.Context, conn *tls.Conn) (err error) {
logger := logging.FromContext(ctx)
start := time.Now()
var outputBytes []byte
// Set connection deadline based on context
if deadline, ok := ctx.Deadline(); ok {
_ = conn.SetDeadline(deadline)
}
defer func(conn *tls.Conn) {
end := time.Now()
tookMs := end.Sub(start).Milliseconds()
var responseHeader string
// On non-errors, just log response and close connection.
if err == nil {
// Log non-erroneous responses
if i := bytes.Index(outputBytes, []byte{'\r'}); i >= 0 {
responseHeader = string(outputBytes[:i])
}
logger.Debug("Response", "responseHeader", responseHeader, "ms", tookMs)
_ = CloseConnection(conn)
return
}
// Handle context cancellation/timeout
if errors.Is(err, context.DeadlineExceeded) {
logger.Info("Connection timeout", "ms", tookMs)
responseHeader = fmt.Sprintf("%d Request timeout", gemini.StatusCGIError)
_, _ = conn.Write([]byte(responseHeader + "\r\n"))
_ = CloseConnection(conn)
return
}
if errors.Is(err, context.Canceled) {
logger.Info("Connection cancelled", "ms", tookMs)
_ = CloseConnection(conn)
return
}
var code int
var responseMsg string
if apperrors.IsFatal(err) {
_ = CloseConnection(conn)
return
}
if apperrors.IsGeminiError(err) {
code = apperrors.GetStatusCode(err)
responseMsg = "server error"
} else {
code = gemini.StatusPermanentFailure
responseMsg = "server error"
}
responseHeader = fmt.Sprintf("%d %s", code, responseMsg)
_, _ = conn.Write([]byte(responseHeader + "\r\n"))
_ = CloseConnection(conn)
}(conn)
// Check context before starting
if err := ctx.Err(); err != nil {
return err
}
// Gemini is supposed to have a 1kb limit
// on input requests.
buffer := make([]byte, 1025)
n, err := conn.Read(buffer)
if err != nil && err != io.EOF {
return apperrors.NewGeminiError(fmt.Errorf("failed to read connection data: %w", err), gemini.StatusBadRequest)
}
if n == 0 {
return apperrors.NewGeminiError(fmt.Errorf("client did not send data"), gemini.StatusBadRequest)
}
if n > 1024 {
return apperrors.NewGeminiError(fmt.Errorf("client request size %d > 1024 bytes", n), gemini.StatusBadRequest)
}
// Check context after read
if err := ctx.Err(); err != nil {
return err
}
dataBytes := buffer[:n]
dataString := string(dataBytes)
logger.Info("Request", "data", strings.TrimSpace(dataString), "size", len(dataBytes))
outputBytes, err = GenerateResponse(ctx, conn, dataString)
if len(outputBytes) > config.CONFIG.MaxResponseSize {
return apperrors.NewGeminiError(fmt.Errorf("max response size reached"), gemini.StatusTemporaryFailure)
}
if err != nil {
return err
}
// Check context before write
if err := ctx.Err(); err != nil {
return err
}
_, err = conn.Write(outputBytes)
if err != nil {
return err
}
return nil
}