Files
gemserve/server/server.go
antanst d336bdffba .
2025-10-10 15:20:45 +03:00

297 lines
8.2 KiB
Go

package server
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/url"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"time"
"unicode/utf8"
"gemserve/lib/apperrors"
"gemserve/lib/logging"
"gemserve/config"
"gemserve/gemini"
"github.com/gabriel-vasile/mimetype"
)
type (
	// contextKey is the unexported type used for keys this package stores in
	// a context.Context. A dedicated type keeps these keys distinct from
	// plain string keys set by other packages.
	contextKey string

	// ServerConfig exposes the server settings used when serving content:
	// the root directory to serve from and whether directory indexing is on.
	ServerConfig interface {
		RootPath() string
		DirIndexingEnabled() bool
	}
)

// CtxConnIdKey is the context key under which the connection identifier
// (a string) is stored for log correlation.
const CtxConnIdKey = contextKey("connId")
func CloseConnection(conn *tls.Conn) error {
err := conn.CloseWrite()
if err != nil {
return apperrors.NewNetworkError(fmt.Errorf("failed to close TLS connection: %w", err))
}
err = conn.Close()
if err != nil {
return apperrors.NewNetworkError(fmt.Errorf("failed to close connection: %w", err))
}
return nil
}
// checkRequestURL decides whether this server is willing to answer a request
// for reqURL. It returns nil when the request may proceed, otherwise a Gemini
// error:
//   - StatusBadRequest when the URL text is not valid UTF-8;
//   - StatusProxyRequestRefused when the scheme is not "gemini" or the port
//     differs from the one in config.CONFIG.ListenAddr;
//   - StatusBadRequest when that listen address or its port cannot be parsed.
//
// NOTE(review): the last case is a server misconfiguration rather than a
// client mistake; consider whether StatusBadRequest is the right status.
func checkRequestURL(reqURL *gemini.URL) error {
	switch {
	case !utf8.ValidString(reqURL.String()):
		return apperrors.NewGeminiError(fmt.Errorf("invalid URL"), gemini.StatusBadRequest)
	case reqURL.Protocol != "gemini":
		return apperrors.NewGeminiError(fmt.Errorf("invalid URL"), gemini.StatusProxyRequestRefused)
	}

	// The only port we answer for is the one we are configured to listen on.
	_, port, err := net.SplitHostPort(config.CONFIG.ListenAddr)
	if err != nil {
		return apperrors.NewGeminiError(fmt.Errorf("failed to parse server listen address: %w", err), gemini.StatusBadRequest)
	}
	expectedPort, err := strconv.Atoi(port)
	if err != nil {
		return apperrors.NewGeminiError(fmt.Errorf("invalid server listen port: %w", err), gemini.StatusBadRequest)
	}
	if reqURL.Port == expectedPort {
		return nil
	}
	return apperrors.NewGeminiError(fmt.Errorf("port mismatch"), gemini.StatusProxyRequestRefused)
}
// GenerateResponse builds the complete response (status line plus body) for
// a single Gemini request.
//
// input is the raw request read from the client; surrounding whitespace,
// including the trailing CRLF, is trimmed before parsing. The URL is checked
// with checkRequestURL and mapped under config.CONFIG.RootPath with
// calculateLocalPath. The resulting path is then served as a directory
// listing/index or as a file. conn is used only to log the remote address.
//
// Errors are Gemini errors carrying the status to report: StatusBadRequest
// for unparseable URLs or unsafe paths, and StatusNotFound when the path
// cannot be stat'ed. Errors from checkRequestURL and the generators are
// passed through. If ctx is already done, ctx.Err() is returned.
func GenerateResponse(ctx context.Context, conn *tls.Conn, input string) ([]byte, error) {
	logger := logging.FromContext(ctx)
	// The connection id is only used for log correlation; a context without
	// one (or with a non-string value) must not panic the handler.
	connId, _ := ctx.Value(CtxConnIdKey).(string)
	if err := ctx.Err(); err != nil {
		return nil, err
	}
	// reqURL will have a cleaned and normalized path after this. It is not
	// called "url" to avoid shadowing the net/url package.
	reqURL, err := gemini.ParseURL(strings.TrimSpace(input), "", true)
	if err != nil {
		return nil, apperrors.NewGeminiError(fmt.Errorf("failed to parse URL: %w", err), gemini.StatusBadRequest)
	}
	logger.Debug("normalized URL path", "id", connId, "remoteAddr", conn.RemoteAddr(), "path", reqURL.Path)
	if err := checkRequestURL(reqURL); err != nil {
		return nil, err
	}
	localPath, err := calculateLocalPath(reqURL.Path, config.CONFIG.RootPath)
	if err != nil {
		return nil, apperrors.NewGeminiError(err, gemini.StatusBadRequest)
	}
	logger.Debug("request path", "id", connId, "remoteAddr", conn.RemoteAddr(), "local path", localPath)
	// Get file/directory information.
	info, err := os.Stat(localPath)
	if err != nil {
		return nil, apperrors.NewGeminiError(fmt.Errorf("failed to access path: %w", err), gemini.StatusNotFound)
	}
	if info.IsDir() {
		return generateResponseDir(ctx, localPath)
	}
	return generateResponseFile(ctx, localPath)
}
func generateResponseFile(ctx context.Context, localPath string) ([]byte, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
data, err := os.ReadFile(localPath)
if err != nil {
return nil, apperrors.NewGeminiError(fmt.Errorf("failed to access path: %w", err), gemini.StatusNotFound)
}
var mimeType string
if path.Ext(localPath) == ".gmi" {
mimeType = "text/gemini"
} else {
mimeType = mimetype.Detect(data).String()
}
headerBytes := []byte(fmt.Sprintf("%d %s; lang=en\r\n", gemini.StatusSuccess, mimeType))
response := append(headerBytes, data...)
return response, nil
}
// generateResponseDir builds the response for a request that resolved to the
// directory localPath.
//
// When config.CONFIG.DirIndexingEnabled is set, it returns a generated
// text/gemini listing. The listing contains a "../" link followed by one link
// per directory entry; names are URL path-escaped and directories get a
// trailing slash. An unreadable directory is reported as StatusNotFound.
// When indexing is disabled, the directory's index.gmi is served through
// generateResponseFile. ctx.Err() is returned if ctx is already done.
//
// NOTE(review): the listing uses relative links, which only resolve inside
// this directory when the request URL ends in "/"; confirm how the caller's
// normalized path treats trailing slashes.
func generateResponseDir(ctx context.Context, localPath string) (output []byte, err error) {
	if err := ctx.Err(); err != nil {
		return nil, err
	}
	// Without indexing only index.gmi is served, so there is no need to read
	// (or be able to read) the directory listing.
	if !config.CONFIG.DirIndexingEnabled {
		return generateResponseFile(ctx, filepath.Join(localPath, "index.gmi"))
	}
	entries, err := os.ReadDir(localPath)
	if err != nil {
		return nil, apperrors.NewGeminiError(fmt.Errorf("failed to access path: %w", err), gemini.StatusNotFound)
	}
	var buf bytes.Buffer
	fmt.Fprintf(&buf, "%d text/gemini; lang=en\r\n", gemini.StatusSuccess)
	buf.WriteString("Directory index:\n\n")
	buf.WriteString("=> ../\n")
	for _, entry := range entries {
		buf.WriteString("=> ")
		// URL-encode entry names so spaces, newlines and other special
		// characters cannot break the link line.
		buf.WriteString(url.PathEscape(entry.Name()))
		if entry.IsDir() {
			buf.WriteByte('/')
		}
		buf.WriteByte('\n')
	}
	return buf.Bytes(), nil
}
// calculateLocalPath maps the request URL path input onto a filesystem path
// under basePath.
//
// One leading and one trailing "/" are stripped from input, and an empty path
// maps to basePath itself. A StatusBadRequest Gemini error is returned for
// paths that contain a backslash. The same error is returned for paths that
// filepath.Localize or filepath.IsLocal reject (for example, paths with ".."
// elements).
//
// NOTE(review): the containment check is lexical only; symlinks under
// basePath are still followed when the returned path is opened.
func calculateLocalPath(input string, basePath string) (string, error) {
	// Check for invalid characters early.
	if strings.ContainsRune(input, '\\') {
		return "", apperrors.NewGeminiError(fmt.Errorf("invalid characters in path: %s", input), gemini.StatusBadRequest)
	}
	filePath := strings.TrimPrefix(input, "/")
	if filePath == "" {
		filePath = "."
	}
	filePath = strings.TrimSuffix(filePath, "/")
	localPath, err := filepath.Localize(filePath)
	if err != nil {
		return "", apperrors.NewGeminiError(fmt.Errorf("could not construct local path from %s: %w", input, err), gemini.StatusBadRequest)
	}
	// If IsLocal(path) returns true, then Join(base, path) will always produce
	// a path contained within base, and Clean(path) will always produce an
	// unrooted path with no ".." path elements.
	if !filepath.IsLocal(localPath) {
		// err is nil here, so it must not be formatted into the message
		// (it would print as "%!s(<nil>)").
		return "", apperrors.NewGeminiError(fmt.Errorf("could not construct local path from %s: not a local path", input), gemini.StatusBadRequest)
	}
	return path.Join(basePath, localPath), nil
}
// HandleConnection serves a single Gemini request on conn and always closes
// the connection before returning.
//
// It reads one request (see readGeminiRequest), builds the response with
// GenerateResponse and writes it back. If ctx has a deadline it is applied to
// the connection. Responses larger than config.CONFIG.MaxResponseSize are
// replaced by a StatusTemporaryFailure error.
//
// Failures are handled by the deferred function:
//   - a context timeout sends "<StatusCGIError> Request timeout";
//   - cancellation and fatal errors close the connection without a reply;
//   - Gemini errors reply with their own status code;
//   - anything else replies with StatusPermanentFailure.
//
// The returned error is the one that caused the failure, or nil on success.
func HandleConnection(ctx context.Context, conn *tls.Conn) (err error) {
	logger := logging.FromContext(ctx)
	start := time.Now()
	var outputBytes []byte
	// Bound all I/O on the connection by the context deadline, if any.
	if deadline, ok := ctx.Deadline(); ok {
		_ = conn.SetDeadline(deadline)
	}
	defer func(conn *tls.Conn) {
		tookMs := time.Since(start).Milliseconds()
		var responseHeader string
		// On success the response has already been written: log its status
		// line and close the connection.
		if err == nil {
			if i := bytes.Index(outputBytes, []byte{'\r'}); i >= 0 {
				responseHeader = string(outputBytes[:i])
			}
			logger.Debug("Response", "responseHeader", responseHeader, "ms", tookMs)
			_ = CloseConnection(conn)
			return
		}
		if errors.Is(err, context.DeadlineExceeded) {
			logger.Info("Connection timeout", "ms", tookMs)
			// NOTE(review): the connection deadline was set to the context
			// deadline, so this write is likely to fail; it is best-effort.
			responseHeader = fmt.Sprintf("%d Request timeout", gemini.StatusCGIError)
			_, _ = conn.Write([]byte(responseHeader + "\r\n"))
			_ = CloseConnection(conn)
			return
		}
		if errors.Is(err, context.Canceled) {
			logger.Info("Connection cancelled", "ms", tookMs)
			_ = CloseConnection(conn)
			return
		}
		// Fatal errors close the connection without sending a response.
		if apperrors.IsFatal(err) {
			logger.Info("Fatal error", "error", err, "ms", tookMs)
			_ = CloseConnection(conn)
			return
		}
		var code int
		if apperrors.IsGeminiError(err) {
			code = apperrors.GetStatusCode(err)
		} else {
			code = gemini.StatusPermanentFailure
		}
		// The client only ever sees a generic message, so record the real
		// cause here.
		logger.Info("Request failed", "error", err, "code", code, "ms", tookMs)
		responseHeader = fmt.Sprintf("%d server error", code)
		_, _ = conn.Write([]byte(responseHeader + "\r\n"))
		_ = CloseConnection(conn)
	}(conn)

	// Check context before starting.
	if err := ctx.Err(); err != nil {
		return err
	}
	dataBytes, err := readGeminiRequest(conn)
	if err != nil {
		return err
	}
	// Check context after read.
	if err := ctx.Err(); err != nil {
		return err
	}
	dataString := string(dataBytes)
	logger.Info("Request", "data", strings.TrimSpace(dataString), "size", len(dataBytes))
	outputBytes, err = GenerateResponse(ctx, conn, dataString)
	if err != nil {
		return err
	}
	if len(outputBytes) > config.CONFIG.MaxResponseSize {
		return apperrors.NewGeminiError(fmt.Errorf("max response size reached"), gemini.StatusTemporaryFailure)
	}
	// Check context before write.
	if err := ctx.Err(); err != nil {
		return err
	}
	if _, err := conn.Write(outputBytes); err != nil {
		return err
	}
	return nil
}

// readGeminiRequest reads one Gemini request from r.
//
// A request is a URL of at most 1024 bytes followed by CRLF, and it may
// arrive split across several reads (TLS records / TCP segments). Reading
// continues until one of these happens:
//   - CRLF is seen: the request up to and including CRLF is returned, and
//     anything after it is discarded;
//   - the client signals EOF: whatever arrived is accepted, even without CRLF;
//   - more bytes than a valid request can hold have arrived.
//
// Empty, over-long and unreadable requests yield StatusBadRequest errors.
func readGeminiRequest(r io.Reader) ([]byte, error) {
	// Maximum URL length plus the terminating CRLF.
	const maxRequestLen = 1024 + len("\r\n")
	crlf := []byte("\r\n")
	// One spare byte beyond the limit lets an over-long request be detected.
	buf := make([]byte, maxRequestLen+1)
	total := 0
	for total < len(buf) {
		n, err := r.Read(buf[total:])
		// Resume the CRLF search one byte early in case the previous read
		// ended with the CR.
		from := max(total-1, 0)
		total += n
		if i := bytes.Index(buf[from:total], crlf); i >= 0 {
			end := from + i + len(crlf)
			if end > maxRequestLen {
				return nil, apperrors.NewGeminiError(fmt.Errorf("client request size %d > %d bytes", end, maxRequestLen), gemini.StatusBadRequest)
			}
			return buf[:end], nil
		}
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			return nil, apperrors.NewGeminiError(fmt.Errorf("failed to read connection data: %w", err), gemini.StatusBadRequest)
		}
	}
	if total == 0 {
		return nil, apperrors.NewGeminiError(fmt.Errorf("client did not send data"), gemini.StatusBadRequest)
	}
	if total > maxRequestLen {
		return nil, apperrors.NewGeminiError(fmt.Errorf("client request size %d > %d bytes", total, maxRequestLen), gemini.StatusBadRequest)
	}
	// The client closed its side without sending CRLF; accept what arrived.
	return buf[:total], nil
}