297 lines
8.2 KiB
Go
297 lines
8.2 KiB
Go
package server
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/url"
|
|
"os"
|
|
"path"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
"unicode/utf8"
|
|
|
|
"gemserve/lib/apperrors"
|
|
"gemserve/lib/logging"
|
|
|
|
"gemserve/config"
|
|
"gemserve/gemini"
|
|
|
|
"github.com/gabriel-vasile/mimetype"
|
|
)
|
|
|
|
// contextKey is an unexported type for context keys owned by this package.
// Values of this type never compare equal to keys of any other type, so they
// cannot collide with context keys defined by other packages.
type contextKey string

// CtxConnIdKey is the context key under which the connection identifier is
// stored. GenerateResponse reads the value stored under it as a string.
const CtxConnIdKey = contextKey("connId")

// ServerConfig is the set of server settings exposed as methods.
type ServerConfig interface {
	// RootPath returns the server's root path.
	RootPath() string

	// DirIndexingEnabled reports whether directory indexing is enabled.
	DirIndexingEnabled() bool
}
|
|
|
|
func CloseConnection(conn *tls.Conn) error {
|
|
err := conn.CloseWrite()
|
|
if err != nil {
|
|
return apperrors.NewNetworkError(fmt.Errorf("failed to close TLS connection: %w", err))
|
|
}
|
|
err = conn.Close()
|
|
if err != nil {
|
|
return apperrors.NewNetworkError(fmt.Errorf("failed to close connection: %w", err))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func checkRequestURL(url *gemini.URL) error {
|
|
if !utf8.ValidString(url.String()) {
|
|
return apperrors.NewGeminiError(fmt.Errorf("invalid URL"), gemini.StatusBadRequest)
|
|
}
|
|
|
|
if url.Protocol != "gemini" {
|
|
return apperrors.NewGeminiError(fmt.Errorf("invalid URL"), gemini.StatusProxyRequestRefused)
|
|
}
|
|
|
|
_, portStr, err := net.SplitHostPort(config.CONFIG.ListenAddr)
|
|
if err != nil {
|
|
return apperrors.NewGeminiError(fmt.Errorf("failed to parse server listen address: %w", err), gemini.StatusBadRequest)
|
|
}
|
|
listenPort, err := strconv.Atoi(portStr)
|
|
if err != nil {
|
|
return apperrors.NewGeminiError(fmt.Errorf("invalid server listen port: %w", err), gemini.StatusBadRequest)
|
|
}
|
|
if url.Port != listenPort {
|
|
return apperrors.NewGeminiError(fmt.Errorf("port mismatch"), gemini.StatusProxyRequestRefused)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func GenerateResponse(ctx context.Context, conn *tls.Conn, input string) ([]byte, error) {
|
|
logger := logging.FromContext(ctx)
|
|
connId := ctx.Value(CtxConnIdKey).(string)
|
|
|
|
if err := ctx.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
trimmedInput := strings.TrimSpace(input)
|
|
// url will have a cleaned and normalized path after this
|
|
url, err := gemini.ParseURL(trimmedInput, "", true)
|
|
if err != nil {
|
|
return nil, apperrors.NewGeminiError(fmt.Errorf("failed to parse URL: %w", err), gemini.StatusBadRequest)
|
|
}
|
|
logger.Debug("normalized URL path", "id", connId, "remoteAddr", conn.RemoteAddr(), "path", url.Path)
|
|
|
|
err = checkRequestURL(url)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
serverRootPath := config.CONFIG.RootPath
|
|
localPath, err := calculateLocalPath(url.Path, serverRootPath)
|
|
if err != nil {
|
|
return nil, apperrors.NewGeminiError(err, gemini.StatusBadRequest)
|
|
}
|
|
logger.Debug("request path", "id", connId, "remoteAddr", conn.RemoteAddr(), "local path", localPath)
|
|
|
|
// Get file/directory information
|
|
info, err := os.Stat(localPath)
|
|
if err != nil {
|
|
return nil, apperrors.NewGeminiError(fmt.Errorf("failed to access path: %w", err), gemini.StatusNotFound)
|
|
}
|
|
|
|
// Handle directory.
|
|
if info.IsDir() {
|
|
return generateResponseDir(ctx, localPath)
|
|
}
|
|
return generateResponseFile(ctx, localPath)
|
|
}
|
|
|
|
func generateResponseFile(ctx context.Context, localPath string) ([]byte, error) {
|
|
if err := ctx.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
data, err := os.ReadFile(localPath)
|
|
if err != nil {
|
|
return nil, apperrors.NewGeminiError(fmt.Errorf("failed to access path: %w", err), gemini.StatusNotFound)
|
|
}
|
|
|
|
var mimeType string
|
|
if path.Ext(localPath) == ".gmi" {
|
|
mimeType = "text/gemini"
|
|
} else {
|
|
mimeType = mimetype.Detect(data).String()
|
|
}
|
|
headerBytes := []byte(fmt.Sprintf("%d %s; lang=en\r\n", gemini.StatusSuccess, mimeType))
|
|
response := append(headerBytes, data...)
|
|
return response, nil
|
|
}
|
|
|
|
func generateResponseDir(ctx context.Context, localPath string) (output []byte, err error) {
|
|
if err := ctx.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
entries, err := os.ReadDir(localPath)
|
|
if err != nil {
|
|
return nil, apperrors.NewGeminiError(fmt.Errorf("failed to access path: %w", err), gemini.StatusNotFound)
|
|
}
|
|
|
|
if config.CONFIG.DirIndexingEnabled {
|
|
var contents []string
|
|
contents = append(contents, "Directory index:\n\n")
|
|
contents = append(contents, "=> ../\n")
|
|
for _, entry := range entries {
|
|
// URL-encode entry names for safety
|
|
safeName := url.PathEscape(entry.Name())
|
|
if entry.IsDir() {
|
|
contents = append(contents, fmt.Sprintf("=> %s/\n", safeName))
|
|
} else {
|
|
contents = append(contents, fmt.Sprintf("=> %s\n", safeName))
|
|
}
|
|
}
|
|
data := []byte(strings.Join(contents, ""))
|
|
headerBytes := []byte(fmt.Sprintf("%d text/gemini; lang=en\r\n", gemini.StatusSuccess))
|
|
response := append(headerBytes, data...)
|
|
return response, nil
|
|
}
|
|
filePath := filepath.Join(localPath, "index.gmi")
|
|
return generateResponseFile(ctx, filePath)
|
|
}
|
|
|
|
func calculateLocalPath(input string, basePath string) (string, error) {
|
|
// Check for invalid characters early
|
|
if strings.ContainsAny(input, "\\") {
|
|
return "", apperrors.NewGeminiError(fmt.Errorf("invalid characters in path: %s", input), gemini.StatusBadRequest)
|
|
}
|
|
|
|
// If IsLocal(path) returns true, then Join(base, path)
|
|
// will always produce a path contained within base and
|
|
// Clean(path) will always produce an unrooted path with
|
|
// no ".." path elements.
|
|
filePath := input
|
|
filePath = strings.TrimPrefix(filePath, "/")
|
|
if filePath == "" {
|
|
filePath = "."
|
|
}
|
|
filePath = strings.TrimSuffix(filePath, "/")
|
|
|
|
localPath, err := filepath.Localize(filePath)
|
|
if err != nil || !filepath.IsLocal(localPath) {
|
|
return "", apperrors.NewGeminiError(fmt.Errorf("could not construct local path from %s: %s", input, err), gemini.StatusBadRequest)
|
|
}
|
|
|
|
filePath = path.Join(basePath, localPath)
|
|
return filePath, nil
|
|
}
|
|
|
|
func HandleConnection(ctx context.Context, conn *tls.Conn) (err error) {
|
|
logger := logging.FromContext(ctx)
|
|
start := time.Now()
|
|
var outputBytes []byte
|
|
|
|
// Set connection deadline based on context
|
|
if deadline, ok := ctx.Deadline(); ok {
|
|
_ = conn.SetDeadline(deadline)
|
|
}
|
|
|
|
defer func(conn *tls.Conn) {
|
|
end := time.Now()
|
|
tookMs := end.Sub(start).Milliseconds()
|
|
var responseHeader string
|
|
|
|
// On non-errors, just log response and close connection.
|
|
if err == nil {
|
|
// Log non-erroneous responses
|
|
if i := bytes.Index(outputBytes, []byte{'\r'}); i >= 0 {
|
|
responseHeader = string(outputBytes[:i])
|
|
}
|
|
logger.Debug("Response", "responseHeader", responseHeader, "ms", tookMs)
|
|
_ = CloseConnection(conn)
|
|
return
|
|
}
|
|
|
|
// Handle context cancellation/timeout
|
|
if errors.Is(err, context.DeadlineExceeded) {
|
|
logger.Info("Connection timeout", "ms", tookMs)
|
|
responseHeader = fmt.Sprintf("%d Request timeout", gemini.StatusCGIError)
|
|
_, _ = conn.Write([]byte(responseHeader + "\r\n"))
|
|
_ = CloseConnection(conn)
|
|
return
|
|
}
|
|
if errors.Is(err, context.Canceled) {
|
|
logger.Info("Connection cancelled", "ms", tookMs)
|
|
_ = CloseConnection(conn)
|
|
return
|
|
}
|
|
|
|
var code int
|
|
var responseMsg string
|
|
if apperrors.IsFatal(err) {
|
|
_ = CloseConnection(conn)
|
|
return
|
|
}
|
|
if apperrors.IsGeminiError(err) {
|
|
code = apperrors.GetStatusCode(err)
|
|
responseMsg = "server error"
|
|
} else {
|
|
code = gemini.StatusPermanentFailure
|
|
responseMsg = "server error"
|
|
}
|
|
responseHeader = fmt.Sprintf("%d %s", code, responseMsg)
|
|
_, _ = conn.Write([]byte(responseHeader + "\r\n"))
|
|
_ = CloseConnection(conn)
|
|
}(conn)
|
|
|
|
// Check context before starting
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Gemini is supposed to have a 1kb limit
|
|
// on input requests.
|
|
buffer := make([]byte, 1025)
|
|
|
|
n, err := conn.Read(buffer)
|
|
if err != nil && err != io.EOF {
|
|
return apperrors.NewGeminiError(fmt.Errorf("failed to read connection data: %w", err), gemini.StatusBadRequest)
|
|
}
|
|
if n == 0 {
|
|
return apperrors.NewGeminiError(fmt.Errorf("client did not send data"), gemini.StatusBadRequest)
|
|
}
|
|
if n > 1024 {
|
|
return apperrors.NewGeminiError(fmt.Errorf("client request size %d > 1024 bytes", n), gemini.StatusBadRequest)
|
|
}
|
|
|
|
// Check context after read
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
dataBytes := buffer[:n]
|
|
dataString := string(dataBytes)
|
|
|
|
logger.Info("Request", "data", strings.TrimSpace(dataString), "size", len(dataBytes))
|
|
outputBytes, err = GenerateResponse(ctx, conn, dataString)
|
|
if len(outputBytes) > config.CONFIG.MaxResponseSize {
|
|
return apperrors.NewGeminiError(fmt.Errorf("max response size reached"), gemini.StatusTemporaryFailure)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Check context before write
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = conn.Write(outputBytes)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|