// Package server answers Gemini requests received over TLS connections by
// serving files and directory listings from the configured root path.
package server

import (
	"bytes"
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"net"
	"net/url"
	"os"
	"path"
	"path/filepath"
	"strconv"
	"strings"
	"time"
	"unicode/utf8"

	"gemserve/config"
	"gemserve/gemini"
	"gemserve/lib/apperrors"
	"gemserve/lib/logging"
	"github.com/gabriel-vasile/mimetype"
)

// contextKey is a private type for context keys so that keys defined in this
// package cannot collide with keys defined elsewhere.
type contextKey string

// CtxConnIdKey is the context key under which the connection id is stored.
// GenerateResponse reads it as a string for logging.
const CtxConnIdKey contextKey = "connId"

// ServerConfig exposes the directory-indexing flag and the root path.
// NOTE(review): nothing in this file references the interface; the handlers
// read config.CONFIG directly.
type ServerConfig interface {
	DirIndexingEnabled() bool
	RootPath() string
}

// CloseConnection shuts down the write side of conn and then closes it.
//
// Close is always attempted, even when CloseWrite fails (for example because
// the write deadline has passed or the handshake never completed), so the
// underlying socket is not leaked. When both calls fail, the CloseWrite
// failure is the one reported. Errors are wrapped as network errors.
func CloseConnection(conn *tls.Conn) error {
	closeWriteErr := conn.CloseWrite()
	closeErr := conn.Close()

	if closeWriteErr != nil {
		return apperrors.NewNetworkError(fmt.Errorf("failed to close TLS connection: %w", closeWriteErr))
	}
	if closeErr != nil {
		return apperrors.NewNetworkError(fmt.Errorf("failed to close connection: %w", closeErr))
	}
	return nil
}

// checkRequestURL validates a parsed request URL: it must be valid UTF-8, use
// the "gemini" protocol, and target the port this server listens on. Requests
// for another protocol or port are refused as proxy requests.
func checkRequestURL(u *gemini.URL) error {
	if !utf8.ValidString(u.String()) {
		return apperrors.NewGeminiError(fmt.Errorf("invalid URL"), gemini.StatusBadRequest)
	}
	if u.Protocol != "gemini" {
		return apperrors.NewGeminiError(fmt.Errorf("invalid URL"), gemini.StatusProxyRequestRefused)
	}

	// NOTE(review): a malformed listen address is a server-side configuration
	// problem, yet it is reported to the client as a bad request — confirm
	// this is intended.
	_, portStr, err := net.SplitHostPort(config.CONFIG.ListenAddr)
	if err != nil {
		return apperrors.NewGeminiError(fmt.Errorf("failed to parse server listen address: %w", err), gemini.StatusBadRequest)
	}
	listenPort, err := strconv.Atoi(portStr)
	if err != nil {
		return apperrors.NewGeminiError(fmt.Errorf("invalid server listen port: %w", err), gemini.StatusBadRequest)
	}
	if u.Port != listenPort {
		return apperrors.NewGeminiError(fmt.Errorf("port mismatch"), gemini.StatusProxyRequestRefused)
	}
	return nil
}

// GenerateResponse turns the raw request line in input into a complete Gemini
// response (header plus body).
//
// The request is trimmed, parsed and validated, mapped onto a path below the
// configured root, and then served as either a file or a directory. conn is
// only used to log the remote address. Request-level failures are returned as
// Gemini errors carrying the status code to send; a cancelled or expired ctx
// is returned as-is.
func GenerateResponse(ctx context.Context, conn *tls.Conn, input string) ([]byte, error) {
	logger := logging.FromContext(ctx)
	// The connection id is only used for logging, so a context without one
	// must not panic; the id is simply logged as empty.
	connID, _ := ctx.Value(CtxConnIdKey).(string)

	if err := ctx.Err(); err != nil {
		return nil, err
	}

	// reqURL will have a cleaned and normalized path after this.
	reqURL, err := gemini.ParseURL(strings.TrimSpace(input), "", true)
	if err != nil {
		return nil, apperrors.NewGeminiError(fmt.Errorf("failed to parse URL: %w", err), gemini.StatusBadRequest)
	}
	logger.Debug("normalized URL path", "id", connID, "remoteAddr", conn.RemoteAddr(), "path", reqURL.Path)

	if err := checkRequestURL(reqURL); err != nil {
		return nil, err
	}

	localPath, err := calculateLocalPath(reqURL.Path, config.CONFIG.RootPath)
	if err != nil {
		return nil, apperrors.NewGeminiError(err, gemini.StatusBadRequest)
	}
	logger.Debug("request path", "id", connID, "remoteAddr", conn.RemoteAddr(), "local path", localPath)

	// Get file/directory information.
	info, err := os.Stat(localPath)
	if err != nil {
		return nil, apperrors.NewGeminiError(fmt.Errorf("failed to access path: %w", err), gemini.StatusNotFound)
	}

	if info.IsDir() {
		return generateResponseDir(ctx, localPath)
	}
	return generateResponseFile(ctx, localPath)
}

// generateResponseFile reads the file at localPath and returns it prefixed
// with a success header. Files ending in ".gmi" are served as text/gemini;
// for everything else the MIME type is detected from the file contents.
// A read failure is reported as a not-found Gemini error.
func generateResponseFile(ctx context.Context, localPath string) ([]byte, error) {
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	data, err := os.ReadFile(localPath)
	if err != nil {
		return nil, apperrors.NewGeminiError(fmt.Errorf("failed to access path: %w", err), gemini.StatusNotFound)
	}

	mimeType := "text/gemini"
	if path.Ext(localPath) != ".gmi" {
		mimeType = mimetype.Detect(data).String()
	}

	header := []byte(fmt.Sprintf("%d %s; lang=en\r\n", gemini.StatusSuccess, mimeType))
	return append(header, data...), nil
}

// generateResponseDir builds the response for the directory at localPath.
//
// With directory indexing enabled, the response is a text/gemini listing with
// one link per entry (names are URL-path-escaped, directories get a trailing
// slash) preceded by a link to the parent. Otherwise the directory's
// index.gmi file is served. A directory that cannot be read is reported as a
// not-found Gemini error.
func generateResponseDir(ctx context.Context, localPath string) ([]byte, error) {
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	entries, err := os.ReadDir(localPath)
	if err != nil {
		return nil, apperrors.NewGeminiError(fmt.Errorf("failed to access path: %w", err), gemini.StatusNotFound)
	}

	if !config.CONFIG.DirIndexingEnabled {
		return generateResponseFile(ctx, filepath.Join(localPath, "index.gmi"))
	}

	var listing strings.Builder
	listing.WriteString("Directory index:\n\n")
	listing.WriteString("=> ../\n")
	for _, entry := range entries {
		listing.WriteString("=> ")
		// URL-encode entry names for safety.
		listing.WriteString(url.PathEscape(entry.Name()))
		if entry.IsDir() {
			listing.WriteString("/")
		}
		listing.WriteString("\n")
	}

	header := []byte(fmt.Sprintf("%d text/gemini; lang=en\r\n", gemini.StatusSuccess))
	return append(header, listing.String()...), nil
}

// calculateLocalPath maps the slash-separated request path input onto a
// filesystem path below basePath.
//
// Paths containing a backslash, and paths that cannot be expressed as a local
// (non-escaping) path, are rejected with a bad-request Gemini error. An empty
// request path maps to basePath itself.
func calculateLocalPath(input string, basePath string) (string, error) {
	// Check for invalid characters early.
	if strings.ContainsAny(input, "\\") {
		return "", apperrors.NewGeminiError(fmt.Errorf("invalid characters in path: %s", input), gemini.StatusBadRequest)
	}

	relPath := strings.TrimPrefix(input, "/")
	if relPath == "" {
		relPath = "."
	}
	relPath = strings.TrimSuffix(relPath, "/")

	// If IsLocal(path) returns true, then Join(base, path)
	// will always produce a path contained within base and
	// Clean(path) will always produce an unrooted path with
	// no ".." path elements.
	localPath, err := filepath.Localize(relPath)
	if err != nil {
		return "", apperrors.NewGeminiError(fmt.Errorf("could not construct local path from %s: %w", input, err), gemini.StatusBadRequest)
	}
	if !filepath.IsLocal(localPath) {
		return "", apperrors.NewGeminiError(fmt.Errorf("could not construct local path from %s: path is not local", input), gemini.StatusBadRequest)
	}

	// localPath uses the operating system's separators, so it must be joined
	// with filepath rather than the slash-only path package.
	return filepath.Join(basePath, localPath), nil
}

// HandleConnection serves a single Gemini request on conn and closes it.
//
// It reads one request of at most 1024 bytes, generates the response and
// writes it back. On return, a deferred step logs the outcome, sends an error
// header to the client where appropriate (nothing is sent for cancelled
// contexts or fatal errors), and always closes the connection. The returned
// error is the first failure encountered, or nil.
func HandleConnection(ctx context.Context, conn *tls.Conn) (err error) {
	logger := logging.FromContext(ctx)
	start := time.Now()
	var outputBytes []byte

	// Set the connection deadline based on the context. This is best effort:
	// a failure here surfaces later as a read or write error.
	if deadline, ok := ctx.Deadline(); ok {
		_ = conn.SetDeadline(deadline)
	}

	defer func() {
		// Whatever the outcome, the connection is closed last. The close
		// error is deliberately dropped: there is nobody left to report it to.
		defer func() { _ = CloseConnection(conn) }()

		tookMs := time.Since(start).Milliseconds()

		switch {
		case err == nil:
			// Log only the header line of successful responses.
			var responseHeader string
			if i := bytes.Index(outputBytes, []byte{'\r'}); i >= 0 {
				responseHeader = string(outputBytes[:i])
			}
			logger.Debug("Response", "responseHeader", responseHeader, "ms", tookMs)

		case errors.Is(err, context.DeadlineExceeded):
			logger.Info("Connection timeout", "ms", tookMs)
			responseHeader := fmt.Sprintf("%d Request timeout", gemini.StatusCGIError)
			// Best effort: the deadline has usually passed for writes too.
			_, _ = conn.Write([]byte(responseHeader + "\r\n"))

		case errors.Is(err, context.Canceled):
			logger.Info("Connection cancelled", "ms", tookMs)

		case apperrors.IsFatal(err):
			// Fatal errors get no response; the connection is just closed.

		default:
			var code int
			if apperrors.IsGeminiError(err) {
				code = apperrors.GetStatusCode(err)
			} else {
				code = gemini.StatusPermanentFailure
			}
			// The client only ever sees a generic message, never err itself.
			responseHeader := fmt.Sprintf("%d %s", code, "server error")
			_, _ = conn.Write([]byte(responseHeader + "\r\n"))
		}
	}()

	// Check context before starting.
	if err := ctx.Err(); err != nil {
		return err
	}

	// Gemini is supposed to have a 1kb limit on input requests. The buffer is
	// one byte larger so that an oversized request can be detected.
	buffer := make([]byte, 1025)
	n, err := conn.Read(buffer)
	if err != nil && !errors.Is(err, io.EOF) {
		return apperrors.NewGeminiError(fmt.Errorf("failed to read connection data: %w", err), gemini.StatusBadRequest)
	}
	if n == 0 {
		return apperrors.NewGeminiError(fmt.Errorf("client did not send data"), gemini.StatusBadRequest)
	}
	if n > 1024 {
		return apperrors.NewGeminiError(fmt.Errorf("client request size %d > 1024 bytes", n), gemini.StatusBadRequest)
	}

	// Check context after read.
	if err := ctx.Err(); err != nil {
		return err
	}

	dataBytes := buffer[:n]
	dataString := string(dataBytes)
	logger.Info("Request", "data", strings.TrimSpace(dataString), "size", len(dataBytes))

	outputBytes, err = GenerateResponse(ctx, conn, dataString)
	if err != nil {
		return err
	}
	if len(outputBytes) > config.CONFIG.MaxResponseSize {
		return apperrors.NewGeminiError(fmt.Errorf("max response size reached"), gemini.StatusTemporaryFailure)
	}

	// Check context before write.
	if err := ctx.Err(); err != nil {
		return err
	}

	if _, err = conn.Write(outputBytes); err != nil {
		return err
	}
	return nil
}