.
This commit is contained in:
153
server/server.go
153
server/server.go
@@ -1,9 +1,12 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -11,6 +14,7 @@ import (
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"gemserve/lib/apperrors"
|
||||
@@ -22,11 +26,27 @@ import (
|
||||
"github.com/gabriel-vasile/mimetype"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const CtxConnIdKey contextKey = "connId"
|
||||
|
||||
type ServerConfig interface {
|
||||
DirIndexingEnabled() bool
|
||||
RootPath() string
|
||||
}
|
||||
|
||||
func CloseConnection(conn *tls.Conn) error {
|
||||
err := conn.CloseWrite()
|
||||
if err != nil {
|
||||
return apperrors.NewNetworkError(fmt.Errorf("failed to close TLS connection: %w", err))
|
||||
}
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
return apperrors.NewNetworkError(fmt.Errorf("failed to close connection: %w", err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkRequestURL(url *gemini.URL) error {
|
||||
if !utf8.ValidString(url.String()) {
|
||||
return apperrors.NewGeminiError(fmt.Errorf("invalid URL"), gemini.StatusBadRequest)
|
||||
@@ -45,13 +65,19 @@ func checkRequestURL(url *gemini.URL) error {
|
||||
return apperrors.NewGeminiError(fmt.Errorf("invalid server listen port: %w", err), gemini.StatusBadRequest)
|
||||
}
|
||||
if url.Port != listenPort {
|
||||
return apperrors.NewGeminiError(fmt.Errorf("failed to parse URL: %w", err), gemini.StatusProxyRequestRefused)
|
||||
return apperrors.NewGeminiError(fmt.Errorf("port mismatch"), gemini.StatusProxyRequestRefused)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GenerateResponse(ctx context.Context, conn *tls.Conn, connId string, input string) ([]byte, error) {
|
||||
func GenerateResponse(ctx context.Context, conn *tls.Conn, input string) ([]byte, error) {
|
||||
logger := logging.FromContext(ctx)
|
||||
connId := ctx.Value(CtxConnIdKey).(string)
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
trimmedInput := strings.TrimSpace(input)
|
||||
// url will have a cleaned and normalized path after this
|
||||
url, err := gemini.ParseURL(trimmedInput, "", true)
|
||||
@@ -80,12 +106,15 @@ func GenerateResponse(ctx context.Context, conn *tls.Conn, connId string, input
|
||||
|
||||
// Handle directory.
|
||||
if info.IsDir() {
|
||||
return generateResponseDir(localPath)
|
||||
return generateResponseDir(ctx, localPath)
|
||||
}
|
||||
return generateResponseFile(localPath)
|
||||
return generateResponseFile(ctx, localPath)
|
||||
}
|
||||
|
||||
func generateResponseFile(localPath string) ([]byte, error) {
|
||||
func generateResponseFile(ctx context.Context, localPath string) ([]byte, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data, err := os.ReadFile(localPath)
|
||||
if err != nil {
|
||||
return nil, apperrors.NewGeminiError(fmt.Errorf("failed to access path: %w", err), gemini.StatusNotFound)
|
||||
@@ -102,7 +131,10 @@ func generateResponseFile(localPath string) ([]byte, error) {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func generateResponseDir(localPath string) (output []byte, err error) {
|
||||
func generateResponseDir(ctx context.Context, localPath string) (output []byte, err error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entries, err := os.ReadDir(localPath)
|
||||
if err != nil {
|
||||
return nil, apperrors.NewGeminiError(fmt.Errorf("failed to access path: %w", err), gemini.StatusNotFound)
|
||||
@@ -127,7 +159,7 @@ func generateResponseDir(localPath string) (output []byte, err error) {
|
||||
return response, nil
|
||||
}
|
||||
filePath := filepath.Join(localPath, "index.gmi")
|
||||
return generateResponseFile(filePath)
|
||||
return generateResponseFile(ctx, filePath)
|
||||
}
|
||||
|
||||
func calculateLocalPath(input string, basePath string) (string, error) {
|
||||
@@ -155,3 +187,110 @@ func calculateLocalPath(input string, basePath string) (string, error) {
|
||||
filePath = path.Join(basePath, localPath)
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
func HandleConnection(ctx context.Context, conn *tls.Conn) (err error) {
|
||||
logger := logging.FromContext(ctx)
|
||||
start := time.Now()
|
||||
var outputBytes []byte
|
||||
|
||||
// Set connection deadline based on context
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
_ = conn.SetDeadline(deadline)
|
||||
}
|
||||
|
||||
defer func(conn *tls.Conn) {
|
||||
end := time.Now()
|
||||
tookMs := end.Sub(start).Milliseconds()
|
||||
var responseHeader string
|
||||
|
||||
// On non-errors, just log response and close connection.
|
||||
if err == nil {
|
||||
// Log non-erroneous responses
|
||||
if i := bytes.Index(outputBytes, []byte{'\r'}); i >= 0 {
|
||||
responseHeader = string(outputBytes[:i])
|
||||
}
|
||||
logger.Debug("Response", "responseHeader", responseHeader, "ms", tookMs)
|
||||
_ = CloseConnection(conn)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle context cancellation/timeout
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.Info("Connection timeout", "ms", tookMs)
|
||||
responseHeader = fmt.Sprintf("%d Request timeout", gemini.StatusCGIError)
|
||||
_, _ = conn.Write([]byte(responseHeader + "\r\n"))
|
||||
_ = CloseConnection(conn)
|
||||
return
|
||||
}
|
||||
if errors.Is(err, context.Canceled) {
|
||||
logger.Info("Connection cancelled", "ms", tookMs)
|
||||
_ = CloseConnection(conn)
|
||||
return
|
||||
}
|
||||
|
||||
var code int
|
||||
var responseMsg string
|
||||
if apperrors.IsFatal(err) {
|
||||
_ = CloseConnection(conn)
|
||||
return
|
||||
}
|
||||
if apperrors.IsGeminiError(err) {
|
||||
code = apperrors.GetStatusCode(err)
|
||||
responseMsg = "server error"
|
||||
} else {
|
||||
code = gemini.StatusPermanentFailure
|
||||
responseMsg = "server error"
|
||||
}
|
||||
responseHeader = fmt.Sprintf("%d %s", code, responseMsg)
|
||||
_, _ = conn.Write([]byte(responseHeader + "\r\n"))
|
||||
_ = CloseConnection(conn)
|
||||
}(conn)
|
||||
|
||||
// Check context before starting
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Gemini is supposed to have a 1kb limit
|
||||
// on input requests.
|
||||
buffer := make([]byte, 1025)
|
||||
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil && err != io.EOF {
|
||||
return apperrors.NewGeminiError(fmt.Errorf("failed to read connection data: %w", err), gemini.StatusBadRequest)
|
||||
}
|
||||
if n == 0 {
|
||||
return apperrors.NewGeminiError(fmt.Errorf("client did not send data"), gemini.StatusBadRequest)
|
||||
}
|
||||
if n > 1024 {
|
||||
return apperrors.NewGeminiError(fmt.Errorf("client request size %d > 1024 bytes", n), gemini.StatusBadRequest)
|
||||
}
|
||||
|
||||
// Check context after read
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dataBytes := buffer[:n]
|
||||
dataString := string(dataBytes)
|
||||
|
||||
logger.Info("Request", "data", strings.TrimSpace(dataString), "size", len(dataBytes))
|
||||
outputBytes, err = GenerateResponse(ctx, conn, dataString)
|
||||
if len(outputBytes) > config.CONFIG.MaxResponseSize {
|
||||
return apperrors.NewGeminiError(fmt.Errorf("max response size reached"), gemini.StatusTemporaryFailure)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check context before write
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = conn.Write(outputBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user