This commit is contained in:
antanst
2025-10-10 15:20:45 +03:00
parent 3a5835fc42
commit d336bdffba
10 changed files with 494 additions and 374 deletions

View File

@@ -1,9 +1,12 @@
package server
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/url"
"os"
@@ -11,6 +14,7 @@ import (
"path/filepath"
"strconv"
"strings"
"time"
"unicode/utf8"
"gemserve/lib/apperrors"
@@ -22,11 +26,27 @@ import (
"github.com/gabriel-vasile/mimetype"
)
type contextKey string
const CtxConnIdKey contextKey = "connId"
type ServerConfig interface {
DirIndexingEnabled() bool
RootPath() string
}
func CloseConnection(conn *tls.Conn) error {
err := conn.CloseWrite()
if err != nil {
return apperrors.NewNetworkError(fmt.Errorf("failed to close TLS connection: %w", err))
}
err = conn.Close()
if err != nil {
return apperrors.NewNetworkError(fmt.Errorf("failed to close connection: %w", err))
}
return nil
}
func checkRequestURL(url *gemini.URL) error {
if !utf8.ValidString(url.String()) {
return apperrors.NewGeminiError(fmt.Errorf("invalid URL"), gemini.StatusBadRequest)
@@ -45,13 +65,19 @@ func checkRequestURL(url *gemini.URL) error {
return apperrors.NewGeminiError(fmt.Errorf("invalid server listen port: %w", err), gemini.StatusBadRequest)
}
if url.Port != listenPort {
return apperrors.NewGeminiError(fmt.Errorf("failed to parse URL: %w", err), gemini.StatusProxyRequestRefused)
return apperrors.NewGeminiError(fmt.Errorf("port mismatch"), gemini.StatusProxyRequestRefused)
}
return nil
}
func GenerateResponse(ctx context.Context, conn *tls.Conn, connId string, input string) ([]byte, error) {
func GenerateResponse(ctx context.Context, conn *tls.Conn, input string) ([]byte, error) {
logger := logging.FromContext(ctx)
connId := ctx.Value(CtxConnIdKey).(string)
if err := ctx.Err(); err != nil {
return nil, err
}
trimmedInput := strings.TrimSpace(input)
// url will have a cleaned and normalized path after this
url, err := gemini.ParseURL(trimmedInput, "", true)
@@ -80,12 +106,15 @@ func GenerateResponse(ctx context.Context, conn *tls.Conn, connId string, input
// Handle directory.
if info.IsDir() {
return generateResponseDir(localPath)
return generateResponseDir(ctx, localPath)
}
return generateResponseFile(localPath)
return generateResponseFile(ctx, localPath)
}
func generateResponseFile(localPath string) ([]byte, error) {
func generateResponseFile(ctx context.Context, localPath string) ([]byte, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
data, err := os.ReadFile(localPath)
if err != nil {
return nil, apperrors.NewGeminiError(fmt.Errorf("failed to access path: %w", err), gemini.StatusNotFound)
@@ -102,7 +131,10 @@ func generateResponseFile(localPath string) ([]byte, error) {
return response, nil
}
func generateResponseDir(localPath string) (output []byte, err error) {
func generateResponseDir(ctx context.Context, localPath string) (output []byte, err error) {
if err := ctx.Err(); err != nil {
return nil, err
}
entries, err := os.ReadDir(localPath)
if err != nil {
return nil, apperrors.NewGeminiError(fmt.Errorf("failed to access path: %w", err), gemini.StatusNotFound)
@@ -127,7 +159,7 @@ func generateResponseDir(localPath string) (output []byte, err error) {
return response, nil
}
filePath := filepath.Join(localPath, "index.gmi")
return generateResponseFile(filePath)
return generateResponseFile(ctx, filePath)
}
func calculateLocalPath(input string, basePath string) (string, error) {
@@ -155,3 +187,110 @@ func calculateLocalPath(input string, basePath string) (string, error) {
filePath = path.Join(basePath, localPath)
return filePath, nil
}
func HandleConnection(ctx context.Context, conn *tls.Conn) (err error) {
logger := logging.FromContext(ctx)
start := time.Now()
var outputBytes []byte
// Set connection deadline based on context
if deadline, ok := ctx.Deadline(); ok {
_ = conn.SetDeadline(deadline)
}
defer func(conn *tls.Conn) {
end := time.Now()
tookMs := end.Sub(start).Milliseconds()
var responseHeader string
// On non-errors, just log response and close connection.
if err == nil {
// Log non-erroneous responses
if i := bytes.Index(outputBytes, []byte{'\r'}); i >= 0 {
responseHeader = string(outputBytes[:i])
}
logger.Debug("Response", "responseHeader", responseHeader, "ms", tookMs)
_ = CloseConnection(conn)
return
}
// Handle context cancellation/timeout
if errors.Is(err, context.DeadlineExceeded) {
logger.Info("Connection timeout", "ms", tookMs)
responseHeader = fmt.Sprintf("%d Request timeout", gemini.StatusCGIError)
_, _ = conn.Write([]byte(responseHeader + "\r\n"))
_ = CloseConnection(conn)
return
}
if errors.Is(err, context.Canceled) {
logger.Info("Connection cancelled", "ms", tookMs)
_ = CloseConnection(conn)
return
}
var code int
var responseMsg string
if apperrors.IsFatal(err) {
_ = CloseConnection(conn)
return
}
if apperrors.IsGeminiError(err) {
code = apperrors.GetStatusCode(err)
responseMsg = "server error"
} else {
code = gemini.StatusPermanentFailure
responseMsg = "server error"
}
responseHeader = fmt.Sprintf("%d %s", code, responseMsg)
_, _ = conn.Write([]byte(responseHeader + "\r\n"))
_ = CloseConnection(conn)
}(conn)
// Check context before starting
if err := ctx.Err(); err != nil {
return err
}
// Gemini is supposed to have a 1kb limit
// on input requests.
buffer := make([]byte, 1025)
n, err := conn.Read(buffer)
if err != nil && err != io.EOF {
return apperrors.NewGeminiError(fmt.Errorf("failed to read connection data: %w", err), gemini.StatusBadRequest)
}
if n == 0 {
return apperrors.NewGeminiError(fmt.Errorf("client did not send data"), gemini.StatusBadRequest)
}
if n > 1024 {
return apperrors.NewGeminiError(fmt.Errorf("client request size %d > 1024 bytes", n), gemini.StatusBadRequest)
}
// Check context after read
if err := ctx.Err(); err != nil {
return err
}
dataBytes := buffer[:n]
dataString := string(dataBytes)
logger.Info("Request", "data", strings.TrimSpace(dataString), "size", len(dataBytes))
outputBytes, err = GenerateResponse(ctx, conn, dataString)
if len(outputBytes) > config.CONFIG.MaxResponseSize {
return apperrors.NewGeminiError(fmt.Errorf("max response size reached"), gemini.StatusTemporaryFailure)
}
if err != nil {
return err
}
// Check context before write
if err := ctx.Err(); err != nil {
return err
}
_, err = conn.Write(outputBytes)
if err != nil {
return err
}
return nil
}