diff --git a/common/gemini_url.go b/common/gemini_url.go index 64e8bad..dc064f6 100644 --- a/common/gemini_url.go +++ b/common/gemini_url.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "gemserve/errors" + "git.antanst.com/antanst/xerrors" ) type URL struct { @@ -28,7 +28,7 @@ func (u *URL) Scan(value interface{}) error { } b, ok := value.(string) if !ok { - return errors.NewFatalError(fmt.Errorf("database scan error: expected string, got %T", value)) + return xerrors.NewError(fmt.Errorf("database scan error: expected string, got %T", value), 0, "Database scan error", true) } parsedURL, err := ParseURL(b, "", false) if err != nil { @@ -67,12 +67,10 @@ func ParseURL(input string, descr string, normalize bool) (*URL, error) { } else { u, err = url.Parse(input) if err != nil { - return nil, errors.NewError(fmt.Errorf("error parsing URL: %w: %s", err, input)) + return nil, xerrors.NewError(fmt.Errorf("error parsing URL: %w: %s", err, input), 0, "URL parse error", false) } } - if u.Scheme != "gemini" { - return nil, errors.NewError(fmt.Errorf("error parsing URL: not a gemini URL: %s", input)) - } + protocol := u.Scheme hostname := u.Hostname() strPort := u.Port() @@ -82,7 +80,7 @@ func ParseURL(input string, descr string, normalize bool) (*URL, error) { } port, err := strconv.Atoi(strPort) if err != nil { - return nil, errors.NewError(fmt.Errorf("error parsing URL: %w: %s", err, input)) + return nil, xerrors.NewError(fmt.Errorf("error parsing URL: %w: %s", err, input), 0, "URL parse error", false) } full := fmt.Sprintf("%s://%s:%d%s", protocol, hostname, port, urlPath) // full field should also contain query params and url fragments @@ -128,13 +126,13 @@ func NormalizeURL(rawURL string) (*url.URL, error) { // Parse the URL u, err := url.Parse(rawURL) if err != nil { - return nil, errors.NewError(fmt.Errorf("error normalizing URL: %w: %s", err, rawURL)) + return nil, xerrors.NewError(fmt.Errorf("error normalizing URL: %w: %s", err, rawURL), 0, "URL normalization error", false) } if u.Scheme == "" { - return nil, errors.NewError(fmt.Errorf("error normalizing URL: No scheme: %s", rawURL)) + return nil, xerrors.NewError(fmt.Errorf("error normalizing URL: No scheme: %s", rawURL), 0, "Missing URL scheme", false) } if u.Host == "" { - return nil, errors.NewError(fmt.Errorf("error normalizing URL: No host: %s", rawURL)) + return nil, xerrors.NewError(fmt.Errorf("error normalizing URL: No host: %s", rawURL), 0, "Missing URL host", false) } // Convert scheme to lowercase diff --git a/errors/errors.go b/errors/errors.go deleted file mode 100644 index 6bd39ea..0000000 --- a/errors/errors.go +++ /dev/null @@ -1,114 +0,0 @@ -package errors - -import ( - "errors" - "fmt" - "runtime" - "strings" -) - -type fatal interface { - Fatal() bool -} - -func IsFatal(err error) bool { - te, ok := errors.Unwrap(err).(fatal) - return ok && te.Fatal() -} - -func As(err error, target any) bool { - return errors.As(err, target) -} - -func Is(err, target error) bool { - return errors.Is(err, target) -} - -func Unwrap(err error) error { - return errors.Unwrap(err) -} - -type Error struct { - Err error - Stack string - fatal bool -} - -func (e *Error) Error() string { - var sb strings.Builder - sb.WriteString(fmt.Sprintf("%v\n", e.Err)) - return sb.String() -} - -func (e *Error) ErrorWithStack() string { - var sb strings.Builder - sb.WriteString(fmt.Sprintf("%v\n", e.Err)) - sb.WriteString(fmt.Sprintf("Stack Trace:\n%s", e.Stack)) - return sb.String() -} - -func (e *Error) Fatal() bool { - return e.fatal -} - -func (e *Error) Unwrap() error { - return e.Err -} - -func NewError(err error) error { - if err == nil { - return nil - } - - // Check if it's already of our own - // Error type, so we don't add stack twice. - var asError *Error - if errors.As(err, &asError) { - return err - } - - // Get the stack trace - var stack strings.Builder - buf := make([]uintptr, 50) - n := runtime.Callers(2, buf) - frames := runtime.CallersFrames(buf[:n]) - - // Format the stack trace - for { - frame, more := frames.Next() - // Skip runtime and standard library frames - if !strings.Contains(frame.File, "runtime/") { - stack.WriteString(fmt.Sprintf("\t%s:%d - %s\n", frame.File, frame.Line, frame.Function)) - } - if !more { - break - } - } - - return &Error{ - Err: err, - Stack: stack.String(), - } -} - -func NewFatalError(err error) error { - if err == nil { - return nil - } - - // Check if it's already of our own - // Error type. - var asError *Error - if errors.As(err, &asError) { - return err - } - err2 := NewError(err) - err2.(*Error).fatal = true - return err2 -} - -var ConnectionError error = fmt.Errorf("connection error") - -func NewConnectionError(err error) error { - return fmt.Errorf("%w: %w", ConnectionError, err) -} diff --git a/errors/errors_test.go b/errors/errors_test.go deleted file mode 100644 index 30bde0e..0000000 --- a/errors/errors_test.go +++ /dev/null @@ -1,71 +0,0 @@ -package errors - -import ( - "errors" - "fmt" - "testing" -) - -type CustomError struct { - Err error -} - -func (e *CustomError) Error() string { return e.Err.Error() } - -func IsCustomError(err error) bool { - var asError *CustomError - return errors.As(err, &asError) -} - -func TestWrapping(t *testing.T) { - t.Parallel() - originalErr := errors.New("original error") - err1 := NewError(originalErr) - if !errors.Is(err1, originalErr) { - t.Errorf("original error is not wrapped") - } - if !Is(err1, originalErr) { - t.Errorf("original error is not wrapped") - } - unwrappedErr := errors.Unwrap(err1) - if !errors.Is(unwrappedErr, originalErr) { - t.Errorf("original error is not wrapped") - } - if !Is(unwrappedErr, originalErr) { - t.Errorf("original error is not wrapped") - } - unwrappedErr = Unwrap(err1) - if !errors.Is(unwrappedErr, originalErr) { - t.Errorf("original error is not wrapped") - } - if !Is(unwrappedErr, originalErr) { - t.Errorf("original error is not wrapped") - } - wrappedErr := fmt.Errorf("wrapped: %w", originalErr) - if !errors.Is(wrappedErr, originalErr) { - t.Errorf("original error is not wrapped") - } - if !Is(wrappedErr, originalErr) { - t.Errorf("original error is not wrapped") - } -} - -func TestNewError(t *testing.T) { - t.Parallel() - originalErr := &CustomError{errors.New("err1")} - if !IsCustomError(originalErr) { - t.Errorf("TestNewError fail #1") - } - err1 := NewError(originalErr) - if !IsCustomError(err1) { - t.Errorf("TestNewError fail #2") - } - wrappedErr1 := fmt.Errorf("wrapped %w", err1) - if !IsCustomError(wrappedErr1) { - t.Errorf("TestNewError fail #3") - } - unwrappedErr1 := Unwrap(wrappedErr1) - if !IsCustomError(unwrappedErr1) { - t.Errorf("TestNewError fail #4") - } -} diff --git a/logging/logging.go b/logging/logging.go deleted file mode 100644 index 3b8ec62..0000000 --- a/logging/logging.go +++ /dev/null @@ -1,23 +0,0 @@ -package logging - -import ( - "fmt" - - zlog "github.com/rs/zerolog/log" -) - -func LogDebug(format string, args ...interface{}) { - zlog.Debug().Msg(fmt.Sprintf(format, args...)) -} - -func LogInfo(format string, args ...interface{}) { - zlog.Info().Msg(fmt.Sprintf(format, args...)) -} - -func LogWarn(format string, args ...interface{}) { - zlog.Warn().Msg(fmt.Sprintf(format, args...)) -} - -func LogError(format string, args ...interface{}) { - zlog.Error().Err(fmt.Errorf(format, args...)).Msg("") -} diff --git a/main.go b/main.go index 3f29e2a..d320b80 100644 --- a/main.go +++ b/main.go @@ -3,6 +3,7 @@ package main import ( "bytes" "crypto/tls" + "errors" "fmt" "io" "net" @@ -13,46 +14,40 @@ import ( "time" "gemserve/config" - "gemserve/errors" - "gemserve/logging" "gemserve/server" "gemserve/uid" - "github.com/rs/zerolog" - zlog "github.com/rs/zerolog/log" + logging "git.antanst.com/antanst/logging" + "git.antanst.com/antanst/xerrors" ) +var fatalErrors chan error + func main() { config.CONFIG = *config.GetConfig() - zerolog.TimeFieldFormat = zerolog.TimeFormatUnix - zerolog.SetGlobalLevel(config.CONFIG.LogLevel) - zlog.Logger = zlog.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: "[2006-01-02 15:04:05]"}) + + logging.InitSlogger(config.CONFIG.LogLevel) + err := runApp() if err != nil { - fmt.Printf("%v\n", err) - logging.LogError("%v", err) - os.Exit(1) + panic(fmt.Sprintf("Fatal Error: %v", err)) } + os.Exit(0) } func runApp() error { logging.LogInfo("Starting up. Press Ctrl+C to exit") - var listenHost string - if len(os.Args) != 2 { - listenHost = "0.0.0.0:1965" - } else { - listenHost = os.Args[1] - } + listenHost := config.CONFIG.Listen signals := make(chan os.Signal, 1) signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) - serverErrors := make(chan error) + fatalErrors = make(chan error) go func() { err := startServer(listenHost) if err != nil { - serverErrors <- errors.NewFatalError(err) + fatalErrors <- xerrors.NewError(err, 0, "Server startup failed", true) } }() @@ -61,16 +56,16 @@ func runApp() error { case <-signals: logging.LogWarn("Received SIGINT or SIGTERM signal, exiting") return nil - case serverError := <-serverErrors: - return errors.NewFatalError(serverError) + case fatalError := <-fatalErrors: + return xerrors.NewError(fatalError, 0, "Server error", true) } } } func startServer(listenHost string) (err error) { - cert, err := tls.LoadX509KeyPair("/certs/cert", "/certs/key") + cert, err := tls.LoadX509KeyPair("certs/server.crt", "certs/server.key") if err != nil { - return errors.NewFatalError(fmt.Errorf("failed to load certificate: %w", err)) + return xerrors.NewError(fmt.Errorf("failed to load certificate: %w", err), 0, "Certificate loading failed", true) } tlsConfig := &tls.Config{ @@ -80,7 +75,7 @@ func startServer(listenHost string) (err error) { listener, err := tls.Listen("tcp", listenHost, tlsConfig) if err != nil { - return errors.NewFatalError(fmt.Errorf("failed to create listener: %w", err)) + return xerrors.NewError(fmt.Errorf("failed to create listener: %w", err), 0, "Listener creation failed", true) } defer func(listener net.Listener) { // If we've got an error closing the @@ -88,7 +83,7 @@ func startServer(listenHost string) (err error) { // the original error (if not nil) errClose := listener.Close() if errClose != nil && err == nil { - err = errors.NewFatalError(err) + err = xerrors.NewError(err, 0, "Listener close failed", true) } }(listener) @@ -102,16 +97,16 @@ func startServer(listenHost string) (err error) { } go func() { - err := handleConnection(conn.(*tls.Conn)) + remoteAddr := conn.RemoteAddr().String() + connId := uid.UID() + err := handleConnection(conn.(*tls.Conn), connId, remoteAddr) if err != nil { - var asErr *errors.Error - if errors.As(err, &asErr) { - logging.LogError("Unexpected error: %v", err.(*errors.Error).ErrorWithStack()) + var asErr *xerrors.XError + if errors.As(err, &asErr) && asErr.IsFatal { + fatalErrors <- asErr + return } else { - logging.LogError("Unexpected error: %v", err) - } - if config.CONFIG.PanicOnUnexpectedError { - panic("Encountered unexpected error") + logging.LogWarn("%s %s Connection failed: %d %s (%v)", connId, remoteAddr, asErr.Code, asErr.UserMsg, err) } } }() @@ -121,56 +116,68 @@ func startServer(listenHost string) (err error) { func closeConnection(conn *tls.Conn) error { err := conn.CloseWrite() if err != nil { - return errors.NewConnectionError(fmt.Errorf("failed to close TLS connection: %w", err)) + return xerrors.NewError(fmt.Errorf("failed to close TLS connection: %w", err), 50, "Connection close failed", false) } err = conn.Close() if err != nil { - return errors.NewConnectionError(fmt.Errorf("failed to close connection: %w", err)) + return xerrors.NewError(fmt.Errorf("failed to close connection: %w", err), 50, "Connection close failed", false) } return nil } -func handleConnection(conn *tls.Conn) (err error) { - remoteAddr := conn.RemoteAddr().String() - connId := uid.UID() +func handleConnection(conn *tls.Conn, connId string, remoteAddr string) (err error) { start := time.Now() var outputBytes []byte defer func(conn *tls.Conn) { - // Three possible cases here: - // - We don't have an error - // - We have a ConnectionError, which we don't propagate up - // - We have an unexpected error. end := time.Now() tookMs := end.Sub(start).Milliseconds() var responseHeader string - if err != nil { - _, _ = conn.Write([]byte("50 server error")) - responseHeader = "50 server error" - // We don't propagate connection errors up. - if errors.Is(err, errors.ConnectionError) { - logging.LogInfo("%s %s %v", connId, remoteAddr, err) - err = nil - } - } else { + + // On non-errors, just log response and close connection. + if err == nil { + // Log non-erroneous responses if i := bytes.Index(outputBytes, []byte{'\r'}); i >= 0 { responseHeader = string(outputBytes[:i]) } + logging.LogInfo("%s %s response %s (%dms)", connId, remoteAddr, responseHeader, tookMs) + _ = closeConnection(conn) + return } - logging.LogInfo("%s %s response %s (%dms)", connId, remoteAddr, responseHeader, tookMs) + + var code int + var responseMsg string + var xErr *xerrors.XError + if errors.As(err, &xErr) { + // On fatal errors, immediatelly return the error. + if xErr.IsFatal { + _ = closeConnection(conn) + return + } + code = xErr.Code + responseMsg = xErr.UserMsg + } else { + code = 50 + responseMsg = "server error" + } + responseHeader = fmt.Sprintf("%d %s", code, responseMsg) + _, _ = conn.Write([]byte(responseHeader)) _ = closeConnection(conn) }(conn) // Gemini is supposed to have a 1kb limit // on input requests. - buffer := make([]byte, 1024) + buffer := make([]byte, 1025) n, err := conn.Read(buffer) if err != nil && err != io.EOF { - return errors.NewConnectionError(fmt.Errorf("failed to read connection data: %w", err)) + return xerrors.NewError(fmt.Errorf("failed to read connection data: %w", err), 59, "Connection read failed", false) } if n == 0 { - return errors.NewConnectionError(fmt.Errorf("client did not send data")) + return xerrors.NewError(fmt.Errorf("client did not send data"), 59, "No data received", false) + } + if n > 1024 { + return xerrors.NewError(fmt.Errorf("client request size %d > 1024 bytes", n), 59, "Request too large", false) } dataBytes := buffer[:n] diff --git a/server/server.go b/server/server.go index 4b51368..e457a67 100644 --- a/server/server.go +++ b/server/server.go @@ -2,16 +2,19 @@ package server import ( "crypto/tls" + "errors" "fmt" + "net" "os" "path" "path/filepath" + "strconv" "strings" "gemserve/common" "gemserve/config" - "gemserve/errors" - "gemserve/logging" + logging "git.antanst.com/antanst/logging" + "git.antanst.com/antanst/xerrors" "github.com/gabriel-vasile/mimetype" ) @@ -20,18 +23,43 @@ type ServerConfig interface { RootPath() string } +func checkRequestURL(url *common.URL) error { + if url.Protocol != "gemini" { + return xerrors.NewError(fmt.Errorf("invalid URL"), 53, "URL Protocol not Gemini, proxying refused", false) + } + + _, portStr, err := net.SplitHostPort(config.CONFIG.Listen) + if err != nil { + return xerrors.NewError(fmt.Errorf("failed to parse listen address: %w", err), 50, "Server configuration error", false) + } + listenPort, err := strconv.Atoi(portStr) + if err != nil { + return xerrors.NewError(fmt.Errorf("failed to parse listen port: %w", err), 50, "Server configuration error", false) + } + if url.Port != listenPort { + return xerrors.NewError(fmt.Errorf("failed to parse URL: %w", err), 53, "invalid URL port, proxying refused", false) + } + return nil +} + func GenerateResponse(conn *tls.Conn, connId string, input string) ([]byte, error) { trimmedInput := strings.TrimSpace(input) // url will have a cleaned and normalized path after this url, err := common.ParseURL(trimmedInput, "", true) if err != nil { - return nil, errors.NewConnectionError(fmt.Errorf("failed to parse URL: %w", err)) + return nil, xerrors.NewError(fmt.Errorf("failed to parse URL: %w", err), 59, "Invalid URL", false) } logging.LogDebug("%s %s normalized URL path: %s", connId, conn.RemoteAddr(), url.Path) + + err = checkRequestURL(url) + if err != nil { + return nil, err + } + serverRootPath := config.CONFIG.RootPath localPath, err := calculateLocalPath(url.Path, serverRootPath) if err != nil { - return nil, errors.NewConnectionError(err) + return nil, xerrors.NewError(err, 59, "Invalid path", false) } logging.LogDebug("%s %s request file path: %s", connId, conn.RemoteAddr(), localPath) @@ -40,7 +68,7 @@ func GenerateResponse(conn *tls.Conn, connId string, input string) ([]byte, erro if errors.Is(err, os.ErrNotExist) || errors.Is(err, os.ErrPermission) { return []byte("51 not found\r\n"), nil } else if err != nil { - return nil, errors.NewConnectionError(fmt.Errorf("%s %s failed to access path: %w", connId, conn.RemoteAddr(), err)) + return nil, xerrors.NewError(fmt.Errorf("%s %s failed to access path: %w", connId, conn.RemoteAddr(), err), 0, "Path access failed", false) } // Handle directory. @@ -55,7 +83,7 @@ func generateResponseFile(conn *tls.Conn, connId string, url *common.URL, localP if errors.Is(err, os.ErrNotExist) || errors.Is(err, os.ErrPermission) { return []byte("51 not found\r\n"), nil } else if err != nil { - return nil, errors.NewConnectionError(fmt.Errorf("%s %s failed to read file: %w", connId, conn.RemoteAddr(), err)) + return nil, xerrors.NewError(fmt.Errorf("%s %s failed to read file: %w", connId, conn.RemoteAddr(), err), 0, "File read failed", false) } var mimeType string @@ -64,7 +92,7 @@ func generateResponseFile(conn *tls.Conn, connId string, url *common.URL, localP } else { mimeType = mimetype.Detect(data).String() } - headerBytes := []byte(fmt.Sprintf("20 %s\r\n", mimeType)) + headerBytes := []byte(fmt.Sprintf("20 %s; lang=en\r\n", mimeType)) response := append(headerBytes, data...) return response, nil } @@ -72,7 +100,7 @@ func generateResponseFile(conn *tls.Conn, connId string, url *common.URL, localP func generateResponseDir(conn *tls.Conn, connId string, url *common.URL, localPath string) (output []byte, err error) { entries, err := os.ReadDir(localPath) if err != nil { - return nil, errors.NewConnectionError(fmt.Errorf("%s %s failed to read directory: %w", connId, conn.RemoteAddr(), err)) + return nil, xerrors.NewError(fmt.Errorf("%s %s failed to read directory: %w", connId, conn.RemoteAddr(), err), 0, "Directory read failed", false) } if config.CONFIG.DirIndexingEnabled { @@ -87,7 +115,7 @@ func generateResponseDir(conn *tls.Conn, connId string, url *common.URL, localPa } } data := []byte(strings.Join(contents, "")) - headerBytes := []byte("20 text/gemini;\r\n") + headerBytes := []byte("20 text/gemini; lang=en\r\n") response := append(headerBytes, data...) return response, nil } else { @@ -100,7 +128,7 @@ func generateResponseDir(conn *tls.Conn, connId string, url *common.URL, localPa func calculateLocalPath(input string, basePath string) (string, error) { // Check for invalid characters early if strings.ContainsAny(input, "\\") { - return "", errors.NewError(fmt.Errorf("invalid characters in path: %s", input)) + return "", xerrors.NewError(fmt.Errorf("invalid characters in path: %s", input), 0, "Invalid path characters", false) } // If IsLocal(path) returns true, then Join(base, path) @@ -116,7 +144,7 @@ func calculateLocalPath(input string, basePath string) (string, error) { localPath, err := filepath.Localize(filePath) if err != nil || !filepath.IsLocal(localPath) { - return "", errors.NewError(fmt.Errorf("could not construct local path from %s: %s", input, err)) + return "", xerrors.NewError(fmt.Errorf("could not construct local path from %s: %s", input, err), 0, "Invalid local path", false) } filePath = path.Join(basePath, localPath)