Compare commits

..

2 Commits

Author SHA1 Message Date
antanst
0e218cb57b . 2025-10-14 16:58:38 +03:00
antanst
d336bdffba . 2025-10-10 15:20:45 +03:00
11 changed files with 690 additions and 378 deletions

98
AGENTS.md Normal file
View File

@@ -0,0 +1,98 @@
# CLAUDE.md
This file provides guidance to AI Agents such as Claude Code or ChatGPT Codex when working with code in this repository.
## General guidelines
Use idiomatic Go as possible. Prefer simple code than complex.
## Project Overview
Gemserve is a simple Gemini protocol server written in Go that serves static files over TLS-encrypted connections. The Gemini protocol is a lightweight, privacy-focused alternative to HTTP designed for serving text-based content.
### Development Commands
```bash
# Build, test, and format everything
make
# Run tests only
make test
# Build binaries to ./dist/ (gemserve, gemget, gembench)
make build
# Format code with gofumpt and gci
make fmt
# Run golangci-lint
make lint
# Run linter with auto-fix
make lintfix
# Clean build artifacts
make clean
# Run the server (after building)
./dist/gemserve
# Generate TLS certificates for development
certs/generate.sh
```
### Architecture
Core Components
- **cmd/gemserve/gemserve.go**: Entry point with TLS server setup, signal handling, and graceful shutdown
- **cmd/gemget/**: Gemini protocol client for fetching content
- **cmd/gembench/**: Benchmarking tool for Gemini servers
- **server/**: Request processing, file serving, and Gemini protocol response handling
- **gemini/**: Gemini protocol implementation (URL parsing, status codes, path normalization)
- **config/**: CLI-based configuration system
- **lib/logging/**: Structured logging package with context-aware loggers
- **lib/apperrors/**: Application error handling (fatal vs non-fatal errors)
- **uid/**: Connection ID generation for logging (uses external vendor package)
Key Patterns
- **Security First**: All file operations use `filepath.IsLocal()` and path cleaning to prevent directory traversal
- **Error Handling**: Uses structured errors via `lib/apperrors` package distinguishing fatal from non-fatal errors
- **Logging**: Structured logging with configurable levels via internal logging package
- **Testing**: Table-driven tests with parallel execution, heavy focus on security edge cases
Request Flow
1. TLS connection established on port 1965
2. Read up to 1KB request (Gemini spec limit)
3. Parse and normalize Gemini URL
4. Validate path security (prevent traversal)
5. Serve file or directory index with appropriate MIME type
6. Send response with proper Gemini status codes
Configuration
Server configured via CLI flags:
- `--listen`: Server address (default: localhost:1965)
- `--root-path`: Directory to serve files from
- `--dir-indexing`: Enable directory browsing (default: false)
- `--log-level`: Logging verbosity (debug, info, warn, error; default: info)
- `--response-timeout`: Response timeout in seconds (default: 30)
- `--tls-cert`: TLS certificate file path (default: certs/server.crt)
- `--tls-key`: TLS key file path (default: certs/server.key)
- `--max-response-size`: Maximum response size in bytes (default: 5242880)
Testing Strategy
- **server/server_test.go**: Path security and file serving tests
- **gemini/url_test.go**: URL parsing and normalization tests
- Focus on security edge cases (Unicode, traversal attempts, malformed URLs)
- Use parallel test execution for performance
Security Considerations
- All connections require TLS certificates (stored in certs/)
- Path traversal protection is critical - test thoroughly when modifying file serving logic
- Request size limited to 1KB per Gemini specification
- Input validation on all URLs and paths

View File

@@ -1,87 +0,0 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project Overview
Gemserve is a simple Gemini protocol server written in Go that serves static files over TLS-encrypted connections. The Gemini protocol is a lightweight, privacy-focused alternative to HTTP designed for serving text-based content.
## Development Commands
```bash
# Build, test, and format everything
make
# Run tests only
make test
# Build binary to ./dist/gemserve
make build
# Format code with gofumpt and gci
make fmt
# Run golangci-lint
make lint
# Run linter with auto-fix
make lintfix
# Clean build artifacts
make clean
# Run the server (after building)
./dist/gemserve
# Generate TLS certificates for development
certs/generate.sh
```
## Architecture
### Core Components
- **main.go**: Entry point with TLS server setup, signal handling, and graceful shutdown
- **server/**: Request processing, file serving, and Gemini protocol response handling
- **gemini/**: Gemini protocol implementation (URL parsing, status codes, path normalization)
- **config/**: CLI-based configuration system
- **uid/**: Connection ID generation for logging
### Key Patterns
- **Security First**: All file operations use `filepath.IsLocal()` and path cleaning to prevent directory traversal
- **Error Handling**: Uses structured errors with `xerrors` package for consistent error propagation
- **Logging**: Structured logging with configurable levels via internal logging package
- **Testing**: Table-driven tests with parallel execution, heavy focus on security edge cases
### Request Flow
1. TLS connection established on port 1965
2. Read up to 1KB request (Gemini spec limit)
3. Parse and normalize Gemini URL
4. Validate path security (prevent traversal)
5. Serve file or directory index with appropriate MIME type
6. Send response with proper Gemini status codes
## Configuration
Server configured via CLI flags:
- `--listen`: Server address (default: localhost:1965)
- `--root-path`: Directory to serve files from
- `--dir-indexing`: Enable directory browsing
- `--log-level`: Logging verbosity
- `--response-timeout`: Response timeout in seconds
## Testing Strategy
- **server/server_test.go**: Path security and file serving tests
- **gemini/url_test.go**: URL parsing and normalization tests
- Focus on security edge cases (Unicode, traversal attempts, malformed URLs)
- Use parallel test execution for performance
## Security Considerations
- All connections require TLS certificates (stored in certs/)
- Path traversal protection is critical - test thoroughly when modifying file serving logic
- Request size limited to 1KB per Gemini specification
- Input validation on all URLs and paths

1
CLAUDE.md Symbolic link
View File

@@ -0,0 +1 @@
./AGENTS.md

View File

@@ -34,7 +34,9 @@ lintfix: fmt
build: clean
mkdir -p ./dist
go build -mod=vendor -o ./dist/gemserve ./main.go
go build -mod=vendor -o ./dist/gemserve ./cmd/gemserve/gemserve.go
go build -mod=vendor -o ./dist/gemget ./cmd/gemget/gemget.go
go build -mod=vendor -o ./dist/gembench ./cmd/gembench/gembench.go
build-docker: build
docker build -t gemserve .

180
cmd/gembench/gembench.go Normal file
View File

@@ -0,0 +1,180 @@
package main
import (
"context"
"crypto/tls"
"flag"
"fmt"
"io"
"log/slog"
"net/url"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"
"gemserve/lib/logging"
)
func main() {
// Parse command-line flags
insecure := flag.Bool("insecure", false, "Skip TLS certificate verification")
totalConnections := flag.Int("total-connections", 250, "Total connections to make")
parallelism := flag.Int("parallelism", 10, "How many connections to run in parallel")
flag.Parse()
// Get the URL from arguments
args := flag.Args()
if len(args) != 1 {
fmt.Fprintf(os.Stderr, "Usage: gemget [--insecure] <gemini-url>\n")
os.Exit(1)
}
logging.SetupLogging(slog.LevelInfo)
logger := logging.Logger
ctx := logging.WithLogger(context.Background(), logger)
geminiURL := args[0]
host := validateUrl(geminiURL)
if host == "" {
logger.Error("Invalid URL.")
os.Exit(1)
}
start := time.Now()
err := benchmark(ctx, geminiURL, host, *insecure, *totalConnections, *parallelism)
if err != nil {
logger.Error(err.Error())
os.Exit(1)
}
end := time.Now()
tookMs := end.Sub(start).Milliseconds()
logger.Info("End.", "ms", tookMs)
}
var wg sync.WaitGroup
type ctxKey int
const ctxKeyJobIndex ctxKey = 1
func benchmark(ctx context.Context, u string, h string, insecure bool, totalConnections int, parallelism int) error {
logger := logging.FromContext(ctx)
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
// Root context, used to cancel
// connections and graceful shutdown.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// Semaphore to limit concurrency.
// Goroutines put value to channel (acquire slot)
// and consume value from channel (release slot).
semaphore := make(chan struct{}, parallelism)
loop:
for i := 0; i < totalConnections; i++ {
select {
case <-signals:
logger.Warn("Received SIGINT or SIGTERM signal, shutting down gracefully")
cancel()
break loop
case semaphore <- struct{}{}: // Acquire slot
wg.Add(1)
go func(jobIndex int) {
defer func() {
<-semaphore // Release slot
wg.Done()
}()
ctxWithValue := context.WithValue(ctx, ctxKeyJobIndex, jobIndex)
ctxWithTimeout, cancel := context.WithTimeout(ctxWithValue, 60*time.Second)
defer cancel()
err := connect(ctxWithTimeout, u, h, insecure)
if err != nil {
logger.Warn(fmt.Sprintf("%d error: %v", jobIndex, err))
}
}(i)
}
}
wg.Wait()
return nil
}
func validateUrl(u string) string {
// Parse the URL
parsedURL, err := url.Parse(u)
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing URL: %v\n", err)
os.Exit(1)
}
// Ensure it's a gemini URL
if parsedURL.Scheme != "gemini" {
fmt.Fprintf(os.Stderr, "Error: URL must use gemini:// scheme\n")
os.Exit(1)
}
// Get host and port
host := parsedURL.Host
if !strings.Contains(host, ":") {
host = host + ":1965" // Default Gemini port
}
return host
}
func connect(ctx context.Context, url string, host string, insecure bool) error {
logger := logging.FromContext(ctx)
tlsConfig := &tls.Config{
InsecureSkipVerify: insecure,
MinVersion: tls.VersionTLS12,
}
// Context checkpoint
if ctx.Err() != nil {
return nil
}
// Connect to the server
conn, err := tls.Dial("tcp", host, tlsConfig)
if err != nil {
return err
}
// Set connection deadline based on context
if deadline, ok := ctx.Deadline(); ok {
_ = conn.SetDeadline(deadline)
}
defer func() {
_ = conn.Close()
}()
// Context checkpoint
if ctx.Err() != nil {
return nil
}
// Send the request (URL + CRLF)
request := url + "\r\n"
_, err = conn.Write([]byte(request))
if err != nil {
return err
}
// Context checkpoint
if ctx.Err() != nil {
return nil
}
// Read and dump response
_, err = io.Copy(io.Discard, conn)
if err != nil {
return err
}
jobIndex := ctx.Value(ctxKeyJobIndex)
logger.Debug(fmt.Sprintf("%d done", jobIndex))
return nil
}

86
cmd/gemget/gemget.go Normal file
View File

@@ -0,0 +1,86 @@
package main
import (
"crypto/tls"
"flag"
"fmt"
"io"
"net/url"
"os"
"strings"
)
func main() {
// Parse command-line flags
insecure := flag.Bool("insecure", false, "Skip TLS certificate verification")
flag.Parse()
// Get the URL from arguments
args := flag.Args()
if len(args) != 1 {
fmt.Fprintf(os.Stderr, "Usage: gemget [--insecure] <gemini-url>\n")
os.Exit(1)
}
geminiURL := args[0]
host := validateUrl(geminiURL)
connect(geminiURL, host, *insecure)
}
func validateUrl(u string) string {
// Parse the URL
parsedURL, err := url.Parse(u)
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing URL: %v\n", err)
os.Exit(1)
}
// Ensure it's a gemini URL
if parsedURL.Scheme != "gemini" {
fmt.Fprintf(os.Stderr, "Error: URL must use gemini:// scheme\n")
os.Exit(1)
}
// Get host and port
host := parsedURL.Host
if !strings.Contains(host, ":") {
host = host + ":1965" // Default Gemini port
}
return host
}
func connect(url string, host string, insecure bool) {
// Configure TLS
tlsConfig := &tls.Config{
InsecureSkipVerify: insecure,
MinVersion: tls.VersionTLS12,
}
// Connect to the server
conn, err := tls.Dial("tcp", host, tlsConfig)
if err != nil {
fmt.Fprintf(os.Stderr, "Error connecting to server: %v\n", err)
os.Exit(1)
}
defer func() {
_ = conn.Close()
}()
// Send the request (URL + CRLF)
request := url + "\r\n"
_, err = conn.Write([]byte(request))
if err != nil {
fmt.Fprintf(os.Stderr, "Error sending request: %v\n", err)
os.Exit(1)
}
// Read and print the response to stdout
_, err = io.Copy(os.Stdout, conn)
if err != nil {
fmt.Fprintf(os.Stderr, "Error reading response: %v\n", err)
os.Exit(1)
}
}

162
cmd/gemserve/gemserve.go Normal file
View File

@@ -0,0 +1,162 @@
package main
import (
"context"
"crypto/tls"
"fmt"
"net"
"os"
"os/signal"
"sync"
"syscall"
"time"
"gemserve/lib/apperrors"
"gemserve/lib/logging"
"gemserve/config"
"gemserve/server"
"git.antanst.com/antanst/uid"
)
func main() {
config.CONFIG = *config.GetConfig()
logging.SetupLogging(config.CONFIG.LogLevel)
logger := logging.Logger
ctx := logging.WithLogger(context.Background(), logger)
err := runApp(ctx)
if err != nil {
logger.Error("Fatal Error", "err", err)
panic(fmt.Sprintf("Fatal Error: %v", err))
}
os.Exit(0)
}
func runApp(ctx context.Context) error {
logger := logging.FromContext(ctx)
logger.Info("Starting up. Press Ctrl+C to exit")
listenAddr := config.CONFIG.ListenAddr
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
// Only this file should send to this channel.
// All external functions should return errors.
fatalErrors := make(chan error)
// Root context, used to cancel
// connections and graceful shutdown.
serverCtx, cancel := context.WithCancel(ctx)
defer cancel()
// WaitGroup to track active connections
// in order to be able to wait until
// they are properly dropped
var wg sync.WaitGroup
// Spawn server on the background.
// Returned errors are considered fatal.
go func() {
err := startServer(serverCtx, listenAddr, &wg, fatalErrors)
if err != nil {
fatalErrors <- apperrors.NewFatalError(fmt.Errorf("server startup failed: %w", err))
}
}()
for {
select {
case <-signals:
logger.Warn("Received SIGINT or SIGTERM signal, shutting down gracefully")
cancel()
wg.Wait()
return nil
case fatalError := <-fatalErrors:
cancel()
wg.Wait()
return fatalError
}
}
}
func startServer(ctx context.Context, listenAddr string, wg *sync.WaitGroup, fatalErrors chan<- error) (err error) {
logger := logging.FromContext(ctx)
cert, err := tls.LoadX509KeyPair(config.CONFIG.TLSCert, config.CONFIG.TLSKey)
if err != nil {
return apperrors.NewFatalError(fmt.Errorf("failed to load TLS certificate/key: %w", err))
}
logger.Debug("Using TLS cert", "path", config.CONFIG.TLSCert)
logger.Debug("Using TLS key", "path", config.CONFIG.TLSKey)
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12,
}
listener, err := tls.Listen("tcp", listenAddr, tlsConfig)
if err != nil {
return apperrors.NewFatalError(err)
}
defer func(listener net.Listener) {
_ = listener.Close()
}(listener)
// If context is cancelled, close listener
// to unblock Accept() inside main loop.
go func() {
<-ctx.Done()
_ = listener.Close()
}()
logger.Info("Server listening", "address", listenAddr)
for {
conn, err := listener.Accept()
if err != nil {
if ctx.Err() != nil {
return nil
} // ctx cancellation
logger.Info("Failed to accept connection: %v", "error", err)
continue
}
// At this point we have a new connection.
wg.Add(1)
go func() {
defer wg.Done()
// Type assert the connection to TLS connection
tlsConn, ok := conn.(*tls.Conn)
if !ok {
logger.Error("Connection is not a TLS connection")
_ = conn.Close()
return
}
remoteAddr := conn.RemoteAddr().String()
connId := uid.UID()
// Create cancellable connection context
// with connection ID.
connLogger := logging.WithAttr(logger, "id", connId)
connLogger = logging.WithAttr(connLogger, "remoteAddr", remoteAddr)
connCtx := context.WithValue(ctx, server.CtxConnIdKey, connId)
connCtx = context.WithValue(connCtx, logging.CtxLoggerKey, connLogger)
connCtx, cancel := context.WithTimeout(connCtx, time.Duration(config.CONFIG.ResponseTimeout)*time.Second)
defer cancel()
err := server.HandleConnection(connCtx, tlsConn)
if err != nil {
if apperrors.IsFatal(err) {
fatalErrors <- err
return
}
connLogger.Info("Connection failed", "error", err)
}
}()
}
}

View File

@@ -17,6 +17,7 @@ type Config struct {
ListenAddr string // Address to listen on
TLSCert string // TLS certificate file
TLSKey string // TLS key file
MaxResponseSize int // Max response size in bytes
}
var CONFIG Config //nolint:gochecknoglobals
@@ -47,6 +48,7 @@ func GetConfig() *Config {
listen := flag.String("listen", "localhost:1965", "Address to listen on")
tlsCert := flag.String("tls-cert", "certs/server.crt", "TLS certificate file")
tlsKey := flag.String("tls-key", "certs/server.key", "TLS key file")
maxResponseSize := flag.Int("max-response-size", 5_242_880, "Max response size in bytes")
flag.Parse()
@@ -71,5 +73,6 @@ func GetConfig() *Config {
ListenAddr: *listen,
TLSCert: *tlsCert,
TLSKey: *tlsKey,
MaxResponseSize: *maxResponseSize,
}
}

View File

@@ -6,7 +6,6 @@ const (
// Input group
StatusInputExpected = 10
StatusInputExpectedSensitive = 11
StatusSuccess = 20
// Redirect group

View File

@@ -6,14 +6,12 @@ import (
"os"
"path/filepath"
"gemserve/config"
"github.com/lmittmann/tint"
)
type contextKey int
type contextKey string
const loggerKey contextKey = 0
const CtxLoggerKey contextKey = "logger"
var (
programLevel *slog.LevelVar = new(slog.LevelVar) // Info by default
@@ -21,19 +19,23 @@ var (
)
func WithLogger(ctx context.Context, logger *slog.Logger) context.Context {
return context.WithValue(ctx, loggerKey, logger)
return context.WithValue(ctx, CtxLoggerKey, logger)
}
func WithAttr(logger *slog.Logger, attrName string, attrValue interface{}) *slog.Logger {
return logger.With(attrName, attrValue)
}
func FromContext(ctx context.Context) *slog.Logger {
if logger, ok := ctx.Value(loggerKey).(*slog.Logger); ok {
if logger, ok := ctx.Value(CtxLoggerKey).(*slog.Logger); ok {
return logger
}
// Return default logger instead of panicking
return slog.Default()
}
func SetupLogging() {
programLevel.Set(config.CONFIG.LogLevel)
func SetupLogging(logLevel slog.Level) {
programLevel.Set(logLevel)
// With coloring (uses external package)
opts := &tint.Options{
AddSource: true,

273
main.go
View File

@@ -1,273 +0,0 @@
package main
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"
"gemserve/lib/apperrors"
"gemserve/lib/logging"
"gemserve/config"
"gemserve/gemini"
"gemserve/server"
"git.antanst.com/antanst/uid"
)
func main() {
config.CONFIG = *config.GetConfig()
logging.SetupLogging()
logger := logging.Logger
ctx := logging.WithLogger(context.Background(), logger)
err := runApp(ctx)
if err != nil {
logger.Error(fmt.Sprintf("Fatal Error: %v", err))
panic(fmt.Sprintf("Fatal Error: %v", err))
}
os.Exit(0)
}
func runApp(ctx context.Context) error {
logger := logging.FromContext(ctx)
logger.Info("Starting up. Press Ctrl+C to exit")
listenAddr := config.CONFIG.ListenAddr
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
fatalErrors := make(chan error)
// Root server context, used to cancel
// connections and graceful shutdown.
serverCtx, cancel := context.WithCancel(ctx)
defer cancel()
// WaitGroup to track active connections
var wg sync.WaitGroup
// Spawn server on the background.
// Returned errors are considered fatal.
go func() {
err := startServer(serverCtx, listenAddr, &wg, fatalErrors)
if err != nil {
fatalErrors <- apperrors.NewFatalError(fmt.Errorf("server startup failed: %w", err))
}
}()
for {
select {
case <-signals:
logger.Warn("Received SIGINT or SIGTERM signal, shutting down gracefully")
cancel()
wg.Wait()
return nil
case fatalError := <-fatalErrors:
cancel()
wg.Wait()
return fatalError
}
}
}
func startServer(ctx context.Context, listenAddr string, wg *sync.WaitGroup, fatalErrors chan<- error) (err error) {
logger := logging.FromContext(ctx)
cert, err := tls.LoadX509KeyPair(config.CONFIG.TLSCert, config.CONFIG.TLSKey)
if err != nil {
return apperrors.NewFatalError(fmt.Errorf("failed to load TLS certificate/key: %w", err))
}
logger.Debug("Using TLS cert", "path", config.CONFIG.TLSCert)
logger.Debug("Using TLS key", "path", config.CONFIG.TLSKey)
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12,
}
listener, err := tls.Listen("tcp", listenAddr, tlsConfig)
if err != nil {
return apperrors.NewFatalError(err)
}
defer func(listener net.Listener) {
_ = listener.Close()
}(listener)
// If context is cancelled, close listener
// to unblock Accept() inside main loop.
go func() {
<-ctx.Done()
_ = listener.Close()
}()
logger.Info("Server listening", "address", listenAddr)
for {
conn, err := listener.Accept()
if err != nil {
if ctx.Err() != nil {
return nil
} // ctx cancellation
logger.Info("Failed to accept connection: %v", "error", err)
continue
}
wg.Add(1)
go func() {
defer wg.Done()
// Type assert the connection to TLS connection
tlsConn, ok := conn.(*tls.Conn)
if !ok {
logger.Error("Connection is not a TLS connection")
_ = conn.Close()
return
}
remoteAddr := conn.RemoteAddr().String()
connId := uid.UID()
// Create a timeout context for this connection
connCtx, cancel := context.WithTimeout(ctx, time.Duration(config.CONFIG.ResponseTimeout)*time.Second)
defer cancel()
err := handleConnection(connCtx, tlsConn, connId, remoteAddr)
if err != nil {
if apperrors.IsFatal(err) {
fatalErrors <- err
return
}
logger.Info("Connection failed", "id", connId, "remoteAddr", remoteAddr, "error", err)
}
}()
}
}
func closeConnection(conn *tls.Conn) error {
err := conn.CloseWrite()
if err != nil {
return apperrors.NewNetworkError(fmt.Errorf("failed to close TLS connection: %w", err))
}
err = conn.Close()
if err != nil {
return apperrors.NewNetworkError(fmt.Errorf("failed to close connection: %w", err))
}
return nil
}
func handleConnection(ctx context.Context, conn *tls.Conn, connId string, remoteAddr string) (err error) {
logger := logging.FromContext(ctx)
start := time.Now()
var outputBytes []byte
// Set connection deadline based on context
if deadline, ok := ctx.Deadline(); ok {
_ = conn.SetDeadline(deadline)
}
defer func(conn *tls.Conn) {
end := time.Now()
tookMs := end.Sub(start).Milliseconds()
var responseHeader string
// On non-errors, just log response and close connection.
if err == nil {
// Log non-erroneous responses
if i := bytes.Index(outputBytes, []byte{'\r'}); i >= 0 {
responseHeader = string(outputBytes[:i])
}
logger.Info("Response", "connId", connId, "remoteAddr", remoteAddr, "responseHeader", responseHeader, "ms", tookMs)
_ = closeConnection(conn)
return
}
// Handle context cancellation/timeout
if errors.Is(err, context.DeadlineExceeded) {
logger.Info("Connection timeout", "id", connId, "remoteAddr", remoteAddr, "ms", tookMs)
responseHeader = fmt.Sprintf("%d Request timeout", gemini.StatusCGIError)
_, _ = conn.Write([]byte(responseHeader + "\r\n"))
_ = closeConnection(conn)
return
}
if errors.Is(err, context.Canceled) {
logger.Info("Connection cancelled", "id", connId, "remoteAddr", remoteAddr, "ms", tookMs)
_ = closeConnection(conn)
return
}
var code int
var responseMsg string
if apperrors.IsFatal(err) {
_ = closeConnection(conn)
return
}
if apperrors.IsGeminiError(err) {
code = apperrors.GetStatusCode(err)
responseMsg = "server error"
} else {
code = gemini.StatusPermanentFailure
responseMsg = "server error"
}
responseHeader = fmt.Sprintf("%d %s", code, responseMsg)
_, _ = conn.Write([]byte(responseHeader + "\r\n"))
_ = closeConnection(conn)
}(conn)
// Check context before starting
if err := ctx.Err(); err != nil {
return err
}
// Gemini is supposed to have a 1kb limit
// on input requests.
buffer := make([]byte, 1025)
n, err := conn.Read(buffer)
if err != nil && err != io.EOF {
return apperrors.NewGeminiError(fmt.Errorf("failed to read connection data: %w", err), gemini.StatusBadRequest)
}
if n == 0 {
return apperrors.NewGeminiError(fmt.Errorf("client did not send data"), gemini.StatusBadRequest)
}
if n > 1024 {
return apperrors.NewGeminiError(fmt.Errorf("client request size %d > 1024 bytes", n), gemini.StatusBadRequest)
}
// Check context after read
if err := ctx.Err(); err != nil {
return err
}
dataBytes := buffer[:n]
dataString := string(dataBytes)
logger.Info("Request", "id", connId, "remoteAddr", remoteAddr, "data", strings.TrimSpace(dataString), "size", len(dataBytes))
outputBytes, err = server.GenerateResponse(ctx, conn, connId, dataString)
if err != nil {
return err
}
// Check context before write
if err := ctx.Err(); err != nil {
return err
}
_, err = conn.Write(outputBytes)
if err != nil {
return err
}
return nil
}

View File

@@ -1,9 +1,12 @@
package server
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/url"
"os"
@@ -11,6 +14,7 @@ import (
"path/filepath"
"strconv"
"strings"
"time"
"unicode/utf8"
"gemserve/lib/apperrors"
@@ -22,11 +26,27 @@ import (
"github.com/gabriel-vasile/mimetype"
)
type contextKey string
const CtxConnIdKey contextKey = "connId"
type ServerConfig interface {
DirIndexingEnabled() bool
RootPath() string
}
func CloseConnection(conn *tls.Conn) error {
err := conn.CloseWrite()
if err != nil {
return apperrors.NewNetworkError(fmt.Errorf("failed to close TLS connection: %w", err))
}
err = conn.Close()
if err != nil {
return apperrors.NewNetworkError(fmt.Errorf("failed to close connection: %w", err))
}
return nil
}
func checkRequestURL(url *gemini.URL) error {
if !utf8.ValidString(url.String()) {
return apperrors.NewGeminiError(fmt.Errorf("invalid URL"), gemini.StatusBadRequest)
@@ -45,13 +65,19 @@ func checkRequestURL(url *gemini.URL) error {
return apperrors.NewGeminiError(fmt.Errorf("invalid server listen port: %w", err), gemini.StatusBadRequest)
}
if url.Port != listenPort {
return apperrors.NewGeminiError(fmt.Errorf("failed to parse URL: %w", err), gemini.StatusProxyRequestRefused)
return apperrors.NewGeminiError(fmt.Errorf("port mismatch"), gemini.StatusProxyRequestRefused)
}
return nil
}
func GenerateResponse(ctx context.Context, conn *tls.Conn, connId string, input string) ([]byte, error) {
func GenerateResponse(ctx context.Context, conn *tls.Conn, input string) ([]byte, error) {
logger := logging.FromContext(ctx)
connId := ctx.Value(CtxConnIdKey).(string)
if err := ctx.Err(); err != nil {
return nil, err
}
trimmedInput := strings.TrimSpace(input)
// url will have a cleaned and normalized path after this
url, err := gemini.ParseURL(trimmedInput, "", true)
@@ -80,12 +106,15 @@ func GenerateResponse(ctx context.Context, conn *tls.Conn, connId string, input
// Handle directory.
if info.IsDir() {
return generateResponseDir(localPath)
return generateResponseDir(ctx, localPath)
}
return generateResponseFile(localPath)
return generateResponseFile(ctx, localPath)
}
func generateResponseFile(localPath string) ([]byte, error) {
func generateResponseFile(ctx context.Context, localPath string) ([]byte, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
data, err := os.ReadFile(localPath)
if err != nil {
return nil, apperrors.NewGeminiError(fmt.Errorf("failed to access path: %w", err), gemini.StatusNotFound)
@@ -102,7 +131,10 @@ func generateResponseFile(localPath string) ([]byte, error) {
return response, nil
}
func generateResponseDir(localPath string) (output []byte, err error) {
func generateResponseDir(ctx context.Context, localPath string) (output []byte, err error) {
if err := ctx.Err(); err != nil {
return nil, err
}
entries, err := os.ReadDir(localPath)
if err != nil {
return nil, apperrors.NewGeminiError(fmt.Errorf("failed to access path: %w", err), gemini.StatusNotFound)
@@ -127,7 +159,7 @@ func generateResponseDir(localPath string) (output []byte, err error) {
return response, nil
}
filePath := filepath.Join(localPath, "index.gmi")
return generateResponseFile(filePath)
return generateResponseFile(ctx, filePath)
}
func calculateLocalPath(input string, basePath string) (string, error) {
@@ -155,3 +187,110 @@ func calculateLocalPath(input string, basePath string) (string, error) {
filePath = path.Join(basePath, localPath)
return filePath, nil
}
func HandleConnection(ctx context.Context, conn *tls.Conn) (err error) {
logger := logging.FromContext(ctx)
start := time.Now()
var outputBytes []byte
// Set connection deadline based on context
if deadline, ok := ctx.Deadline(); ok {
_ = conn.SetDeadline(deadline)
}
defer func(conn *tls.Conn) {
end := time.Now()
tookMs := end.Sub(start).Milliseconds()
var responseHeader string
// On non-errors, just log response and close connection.
if err == nil {
// Log non-erroneous responses
if i := bytes.Index(outputBytes, []byte{'\r'}); i >= 0 {
responseHeader = string(outputBytes[:i])
}
logger.Debug("Response", "responseHeader", responseHeader, "ms", tookMs)
_ = CloseConnection(conn)
return
}
// Handle context cancellation/timeout
if errors.Is(err, context.DeadlineExceeded) {
logger.Info("Connection timeout", "ms", tookMs)
responseHeader = fmt.Sprintf("%d Request timeout", gemini.StatusCGIError)
_, _ = conn.Write([]byte(responseHeader + "\r\n"))
_ = CloseConnection(conn)
return
}
if errors.Is(err, context.Canceled) {
logger.Info("Connection cancelled", "ms", tookMs)
_ = CloseConnection(conn)
return
}
var code int
var responseMsg string
if apperrors.IsFatal(err) {
_ = CloseConnection(conn)
return
}
if apperrors.IsGeminiError(err) {
code = apperrors.GetStatusCode(err)
responseMsg = "server error"
} else {
code = gemini.StatusPermanentFailure
responseMsg = "server error"
}
responseHeader = fmt.Sprintf("%d %s", code, responseMsg)
_, _ = conn.Write([]byte(responseHeader + "\r\n"))
_ = CloseConnection(conn)
}(conn)
// Check context before starting
if err := ctx.Err(); err != nil {
return err
}
// Gemini is supposed to have a 1kb limit
// on input requests.
buffer := make([]byte, 1025)
n, err := conn.Read(buffer)
if err != nil && err != io.EOF {
return apperrors.NewGeminiError(fmt.Errorf("failed to read connection data: %w", err), gemini.StatusBadRequest)
}
if n == 0 {
return apperrors.NewGeminiError(fmt.Errorf("client did not send data"), gemini.StatusBadRequest)
}
if n > 1024 {
return apperrors.NewGeminiError(fmt.Errorf("client request size %d > 1024 bytes", n), gemini.StatusBadRequest)
}
// Check context after read
if err := ctx.Err(); err != nil {
return err
}
dataBytes := buffer[:n]
dataString := string(dataBytes)
logger.Info("Request", "data", strings.TrimSpace(dataString), "size", len(dataBytes))
outputBytes, err = GenerateResponse(ctx, conn, dataString)
if len(outputBytes) > config.CONFIG.MaxResponseSize {
return apperrors.NewGeminiError(fmt.Errorf("max response size reached"), gemini.StatusTemporaryFailure)
}
if err != nil {
return err
}
// Check context before write
if err := ctx.Err(); err != nil {
return err
}
_, err = conn.Write(outputBytes)
if err != nil {
return err
}
return nil
}