Refactor error handling and logging system
- Replace custom errors package with xerrors for structured error handling - Remove local logging wrapper and use git.antanst.com/antanst/logging - Add proper error codes and user messages in server responses - Improve connection handling with better error categorization - Update certificate path to use local certs/ directory - Add request size validation (1024 byte limit) - Remove panic-on-error configuration option - Enhance error logging with connection IDs and remote addresses
This commit is contained in:
@@ -8,7 +8,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gemserve/errors"
|
"git.antanst.com/antanst/xerrors"
|
||||||
)
|
)
|
||||||
|
|
||||||
type URL struct {
|
type URL struct {
|
||||||
@@ -28,7 +28,7 @@ func (u *URL) Scan(value interface{}) error {
|
|||||||
}
|
}
|
||||||
b, ok := value.(string)
|
b, ok := value.(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return errors.NewFatalError(fmt.Errorf("database scan error: expected string, got %T", value))
|
return xerrors.NewError(fmt.Errorf("database scan error: expected string, got %T", value), 0, "Database scan error", true)
|
||||||
}
|
}
|
||||||
parsedURL, err := ParseURL(b, "", false)
|
parsedURL, err := ParseURL(b, "", false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -67,12 +67,10 @@ func ParseURL(input string, descr string, normalize bool) (*URL, error) {
|
|||||||
} else {
|
} else {
|
||||||
u, err = url.Parse(input)
|
u, err = url.Parse(input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.NewError(fmt.Errorf("error parsing URL: %w: %s", err, input))
|
return nil, xerrors.NewError(fmt.Errorf("error parsing URL: %w: %s", err, input), 0, "URL parse error", false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if u.Scheme != "gemini" {
|
|
||||||
return nil, errors.NewError(fmt.Errorf("error parsing URL: not a gemini URL: %s", input))
|
|
||||||
}
|
|
||||||
protocol := u.Scheme
|
protocol := u.Scheme
|
||||||
hostname := u.Hostname()
|
hostname := u.Hostname()
|
||||||
strPort := u.Port()
|
strPort := u.Port()
|
||||||
@@ -82,7 +80,7 @@ func ParseURL(input string, descr string, normalize bool) (*URL, error) {
|
|||||||
}
|
}
|
||||||
port, err := strconv.Atoi(strPort)
|
port, err := strconv.Atoi(strPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.NewError(fmt.Errorf("error parsing URL: %w: %s", err, input))
|
return nil, xerrors.NewError(fmt.Errorf("error parsing URL: %w: %s", err, input), 0, "URL parse error", false)
|
||||||
}
|
}
|
||||||
full := fmt.Sprintf("%s://%s:%d%s", protocol, hostname, port, urlPath)
|
full := fmt.Sprintf("%s://%s:%d%s", protocol, hostname, port, urlPath)
|
||||||
// full field should also contain query params and url fragments
|
// full field should also contain query params and url fragments
|
||||||
@@ -128,13 +126,13 @@ func NormalizeURL(rawURL string) (*url.URL, error) {
|
|||||||
// Parse the URL
|
// Parse the URL
|
||||||
u, err := url.Parse(rawURL)
|
u, err := url.Parse(rawURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.NewError(fmt.Errorf("error normalizing URL: %w: %s", err, rawURL))
|
return nil, xerrors.NewError(fmt.Errorf("error normalizing URL: %w: %s", err, rawURL), 0, "URL normalization error", false)
|
||||||
}
|
}
|
||||||
if u.Scheme == "" {
|
if u.Scheme == "" {
|
||||||
return nil, errors.NewError(fmt.Errorf("error normalizing URL: No scheme: %s", rawURL))
|
return nil, xerrors.NewError(fmt.Errorf("error normalizing URL: No scheme: %s", rawURL), 0, "Missing URL scheme", false)
|
||||||
}
|
}
|
||||||
if u.Host == "" {
|
if u.Host == "" {
|
||||||
return nil, errors.NewError(fmt.Errorf("error normalizing URL: No host: %s", rawURL))
|
return nil, xerrors.NewError(fmt.Errorf("error normalizing URL: No host: %s", rawURL), 0, "Missing URL host", false)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert scheme to lowercase
|
// Convert scheme to lowercase
|
||||||
|
|||||||
114
errors/errors.go
114
errors/errors.go
@@ -1,114 +0,0 @@
|
|||||||
package errors
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
type fatal interface {
|
|
||||||
Fatal() bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func IsFatal(err error) bool {
|
|
||||||
te, ok := errors.Unwrap(err).(fatal)
|
|
||||||
return ok && te.Fatal()
|
|
||||||
}
|
|
||||||
|
|
||||||
func As(err error, target any) bool {
|
|
||||||
return errors.As(err, target)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Is(err, target error) bool {
|
|
||||||
return errors.Is(err, target)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Unwrap(err error) error {
|
|
||||||
return errors.Unwrap(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Error struct {
|
|
||||||
Err error
|
|
||||||
Stack string
|
|
||||||
fatal bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *Error) Error() string {
|
|
||||||
var sb strings.Builder
|
|
||||||
sb.WriteString(fmt.Sprintf("%v\n", e.Err))
|
|
||||||
return sb.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *Error) ErrorWithStack() string {
|
|
||||||
var sb strings.Builder
|
|
||||||
sb.WriteString(fmt.Sprintf("%v\n", e.Err))
|
|
||||||
sb.WriteString(fmt.Sprintf("Stack Trace:\n%s", e.Stack))
|
|
||||||
return sb.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *Error) Fatal() bool {
|
|
||||||
return e.fatal
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *Error) Unwrap() error {
|
|
||||||
return e.Err
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewError(err error) error {
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if it's already of our own
|
|
||||||
// Error type, so we don't add stack twice.
|
|
||||||
var asError *Error
|
|
||||||
if errors.As(err, &asError) {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the stack trace
|
|
||||||
var stack strings.Builder
|
|
||||||
buf := make([]uintptr, 50)
|
|
||||||
n := runtime.Callers(2, buf)
|
|
||||||
frames := runtime.CallersFrames(buf[:n])
|
|
||||||
|
|
||||||
// Format the stack trace
|
|
||||||
for {
|
|
||||||
frame, more := frames.Next()
|
|
||||||
// Skip runtime and standard library frames
|
|
||||||
if !strings.Contains(frame.File, "runtime/") {
|
|
||||||
stack.WriteString(fmt.Sprintf("\t%s:%d - %s\n", frame.File, frame.Line, frame.Function))
|
|
||||||
}
|
|
||||||
if !more {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Error{
|
|
||||||
Err: err,
|
|
||||||
Stack: stack.String(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewFatalError(err error) error {
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if it's already of our own
|
|
||||||
// Error type.
|
|
||||||
var asError *Error
|
|
||||||
if errors.As(err, &asError) {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err2 := NewError(err)
|
|
||||||
err2.(*Error).fatal = true
|
|
||||||
return err2
|
|
||||||
}
|
|
||||||
|
|
||||||
var ConnectionError error = fmt.Errorf("connection error")
|
|
||||||
|
|
||||||
func NewConnectionError(err error) error {
|
|
||||||
return fmt.Errorf("%w: %w", ConnectionError, err)
|
|
||||||
}
|
|
||||||
@@ -1,71 +0,0 @@
|
|||||||
package errors
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
type CustomError struct {
|
|
||||||
Err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *CustomError) Error() string { return e.Err.Error() }
|
|
||||||
|
|
||||||
func IsCustomError(err error) bool {
|
|
||||||
var asError *CustomError
|
|
||||||
return errors.As(err, &asError)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWrapping(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
originalErr := errors.New("original error")
|
|
||||||
err1 := NewError(originalErr)
|
|
||||||
if !errors.Is(err1, originalErr) {
|
|
||||||
t.Errorf("original error is not wrapped")
|
|
||||||
}
|
|
||||||
if !Is(err1, originalErr) {
|
|
||||||
t.Errorf("original error is not wrapped")
|
|
||||||
}
|
|
||||||
unwrappedErr := errors.Unwrap(err1)
|
|
||||||
if !errors.Is(unwrappedErr, originalErr) {
|
|
||||||
t.Errorf("original error is not wrapped")
|
|
||||||
}
|
|
||||||
if !Is(unwrappedErr, originalErr) {
|
|
||||||
t.Errorf("original error is not wrapped")
|
|
||||||
}
|
|
||||||
unwrappedErr = Unwrap(err1)
|
|
||||||
if !errors.Is(unwrappedErr, originalErr) {
|
|
||||||
t.Errorf("original error is not wrapped")
|
|
||||||
}
|
|
||||||
if !Is(unwrappedErr, originalErr) {
|
|
||||||
t.Errorf("original error is not wrapped")
|
|
||||||
}
|
|
||||||
wrappedErr := fmt.Errorf("wrapped: %w", originalErr)
|
|
||||||
if !errors.Is(wrappedErr, originalErr) {
|
|
||||||
t.Errorf("original error is not wrapped")
|
|
||||||
}
|
|
||||||
if !Is(wrappedErr, originalErr) {
|
|
||||||
t.Errorf("original error is not wrapped")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewError(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
originalErr := &CustomError{errors.New("err1")}
|
|
||||||
if !IsCustomError(originalErr) {
|
|
||||||
t.Errorf("TestNewError fail #1")
|
|
||||||
}
|
|
||||||
err1 := NewError(originalErr)
|
|
||||||
if !IsCustomError(err1) {
|
|
||||||
t.Errorf("TestNewError fail #2")
|
|
||||||
}
|
|
||||||
wrappedErr1 := fmt.Errorf("wrapped %w", err1)
|
|
||||||
if !IsCustomError(wrappedErr1) {
|
|
||||||
t.Errorf("TestNewError fail #3")
|
|
||||||
}
|
|
||||||
unwrappedErr1 := Unwrap(wrappedErr1)
|
|
||||||
if !IsCustomError(unwrappedErr1) {
|
|
||||||
t.Errorf("TestNewError fail #4")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
package logging
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
zlog "github.com/rs/zerolog/log"
|
|
||||||
)
|
|
||||||
|
|
||||||
func LogDebug(format string, args ...interface{}) {
|
|
||||||
zlog.Debug().Msg(fmt.Sprintf(format, args...))
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogInfo(format string, args ...interface{}) {
|
|
||||||
zlog.Info().Msg(fmt.Sprintf(format, args...))
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogWarn(format string, args ...interface{}) {
|
|
||||||
zlog.Warn().Msg(fmt.Sprintf(format, args...))
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogError(format string, args ...interface{}) {
|
|
||||||
zlog.Error().Err(fmt.Errorf(format, args...)).Msg("")
|
|
||||||
}
|
|
||||||
115
main.go
115
main.go
@@ -3,6 +3,7 @@ package main
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@@ -13,46 +14,40 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gemserve/config"
|
"gemserve/config"
|
||||||
"gemserve/errors"
|
|
||||||
"gemserve/logging"
|
|
||||||
"gemserve/server"
|
"gemserve/server"
|
||||||
"gemserve/uid"
|
"gemserve/uid"
|
||||||
"github.com/rs/zerolog"
|
logging "git.antanst.com/antanst/logging"
|
||||||
zlog "github.com/rs/zerolog/log"
|
"git.antanst.com/antanst/xerrors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var fatalErrors chan error
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
config.CONFIG = *config.GetConfig()
|
config.CONFIG = *config.GetConfig()
|
||||||
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
|
|
||||||
zerolog.SetGlobalLevel(config.CONFIG.LogLevel)
|
logging.InitSlogger(config.CONFIG.LogLevel)
|
||||||
zlog.Logger = zlog.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: "[2006-01-02 15:04:05]"})
|
|
||||||
err := runApp()
|
err := runApp()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("%v\n", err)
|
panic(fmt.Sprintf("Fatal Error: %v", err))
|
||||||
logging.LogError("%v", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func runApp() error {
|
func runApp() error {
|
||||||
logging.LogInfo("Starting up. Press Ctrl+C to exit")
|
logging.LogInfo("Starting up. Press Ctrl+C to exit")
|
||||||
|
|
||||||
var listenHost string
|
listenHost := config.CONFIG.Listen
|
||||||
if len(os.Args) != 2 {
|
|
||||||
listenHost = "0.0.0.0:1965"
|
|
||||||
} else {
|
|
||||||
listenHost = os.Args[1]
|
|
||||||
}
|
|
||||||
|
|
||||||
signals := make(chan os.Signal, 1)
|
signals := make(chan os.Signal, 1)
|
||||||
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
serverErrors := make(chan error)
|
fatalErrors = make(chan error)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
err := startServer(listenHost)
|
err := startServer(listenHost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
serverErrors <- errors.NewFatalError(err)
|
fatalErrors <- xerrors.NewError(err, 0, "Server startup failed", true)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -61,16 +56,16 @@ func runApp() error {
|
|||||||
case <-signals:
|
case <-signals:
|
||||||
logging.LogWarn("Received SIGINT or SIGTERM signal, exiting")
|
logging.LogWarn("Received SIGINT or SIGTERM signal, exiting")
|
||||||
return nil
|
return nil
|
||||||
case serverError := <-serverErrors:
|
case fatalError := <-fatalErrors:
|
||||||
return errors.NewFatalError(serverError)
|
return xerrors.NewError(fatalError, 0, "Server error", true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func startServer(listenHost string) (err error) {
|
func startServer(listenHost string) (err error) {
|
||||||
cert, err := tls.LoadX509KeyPair("/certs/cert", "/certs/key")
|
cert, err := tls.LoadX509KeyPair("certs/server.crt", "certs/server.key")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.NewFatalError(fmt.Errorf("failed to load certificate: %w", err))
|
return xerrors.NewError(fmt.Errorf("failed to load certificate: %w", err), 0, "Certificate loading failed", true)
|
||||||
}
|
}
|
||||||
|
|
||||||
tlsConfig := &tls.Config{
|
tlsConfig := &tls.Config{
|
||||||
@@ -80,7 +75,7 @@ func startServer(listenHost string) (err error) {
|
|||||||
|
|
||||||
listener, err := tls.Listen("tcp", listenHost, tlsConfig)
|
listener, err := tls.Listen("tcp", listenHost, tlsConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.NewFatalError(fmt.Errorf("failed to create listener: %w", err))
|
return xerrors.NewError(fmt.Errorf("failed to create listener: %w", err), 0, "Listener creation failed", true)
|
||||||
}
|
}
|
||||||
defer func(listener net.Listener) {
|
defer func(listener net.Listener) {
|
||||||
// If we've got an error closing the
|
// If we've got an error closing the
|
||||||
@@ -88,7 +83,7 @@ func startServer(listenHost string) (err error) {
|
|||||||
// the original error (if not nil)
|
// the original error (if not nil)
|
||||||
errClose := listener.Close()
|
errClose := listener.Close()
|
||||||
if errClose != nil && err == nil {
|
if errClose != nil && err == nil {
|
||||||
err = errors.NewFatalError(err)
|
err = xerrors.NewError(err, 0, "Listener close failed", true)
|
||||||
}
|
}
|
||||||
}(listener)
|
}(listener)
|
||||||
|
|
||||||
@@ -102,16 +97,16 @@ func startServer(listenHost string) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
err := handleConnection(conn.(*tls.Conn))
|
remoteAddr := conn.RemoteAddr().String()
|
||||||
|
connId := uid.UID()
|
||||||
|
err := handleConnection(conn.(*tls.Conn), connId, remoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var asErr *errors.Error
|
var asErr *xerrors.XError
|
||||||
if errors.As(err, &asErr) {
|
if errors.As(err, &asErr) && asErr.IsFatal {
|
||||||
logging.LogError("Unexpected error: %v", err.(*errors.Error).ErrorWithStack())
|
fatalErrors <- asErr
|
||||||
|
return
|
||||||
} else {
|
} else {
|
||||||
logging.LogError("Unexpected error: %v", err)
|
logging.LogWarn("%s %s Connection failed: %d %s (%v)", connId, remoteAddr, asErr.Code, asErr.UserMsg, err)
|
||||||
}
|
|
||||||
if config.CONFIG.PanicOnUnexpectedError {
|
|
||||||
panic("Encountered unexpected error")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -121,56 +116,68 @@ func startServer(listenHost string) (err error) {
|
|||||||
func closeConnection(conn *tls.Conn) error {
|
func closeConnection(conn *tls.Conn) error {
|
||||||
err := conn.CloseWrite()
|
err := conn.CloseWrite()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.NewConnectionError(fmt.Errorf("failed to close TLS connection: %w", err))
|
return xerrors.NewError(fmt.Errorf("failed to close TLS connection: %w", err), 50, "Connection close failed", false)
|
||||||
}
|
}
|
||||||
err = conn.Close()
|
err = conn.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.NewConnectionError(fmt.Errorf("failed to close connection: %w", err))
|
return xerrors.NewError(fmt.Errorf("failed to close connection: %w", err), 50, "Connection close failed", false)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleConnection(conn *tls.Conn) (err error) {
|
func handleConnection(conn *tls.Conn, connId string, remoteAddr string) (err error) {
|
||||||
remoteAddr := conn.RemoteAddr().String()
|
|
||||||
connId := uid.UID()
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
var outputBytes []byte
|
var outputBytes []byte
|
||||||
|
|
||||||
defer func(conn *tls.Conn) {
|
defer func(conn *tls.Conn) {
|
||||||
// Three possible cases here:
|
|
||||||
// - We don't have an error
|
|
||||||
// - We have a ConnectionError, which we don't propagate up
|
|
||||||
// - We have an unexpected error.
|
|
||||||
end := time.Now()
|
end := time.Now()
|
||||||
tookMs := end.Sub(start).Milliseconds()
|
tookMs := end.Sub(start).Milliseconds()
|
||||||
var responseHeader string
|
var responseHeader string
|
||||||
if err != nil {
|
|
||||||
_, _ = conn.Write([]byte("50 server error"))
|
// On non-errors, just log response and close connection.
|
||||||
responseHeader = "50 server error"
|
if err == nil {
|
||||||
// We don't propagate connection errors up.
|
// Log non-erroneous responses
|
||||||
if errors.Is(err, errors.ConnectionError) {
|
|
||||||
logging.LogInfo("%s %s %v", connId, remoteAddr, err)
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if i := bytes.Index(outputBytes, []byte{'\r'}); i >= 0 {
|
if i := bytes.Index(outputBytes, []byte{'\r'}); i >= 0 {
|
||||||
responseHeader = string(outputBytes[:i])
|
responseHeader = string(outputBytes[:i])
|
||||||
}
|
}
|
||||||
|
logging.LogInfo("%s %s response %s (%dms)", connId, remoteAddr, responseHeader, tookMs)
|
||||||
|
_ = closeConnection(conn)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
logging.LogInfo("%s %s response %s (%dms)", connId, remoteAddr, responseHeader, tookMs)
|
|
||||||
|
var code int
|
||||||
|
var responseMsg string
|
||||||
|
var xErr *xerrors.XError
|
||||||
|
if errors.As(err, &xErr) {
|
||||||
|
// On fatal errors, immediatelly return the error.
|
||||||
|
if xErr.IsFatal {
|
||||||
|
_ = closeConnection(conn)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
code = xErr.Code
|
||||||
|
responseMsg = xErr.UserMsg
|
||||||
|
} else {
|
||||||
|
code = 50
|
||||||
|
responseMsg = "server error"
|
||||||
|
}
|
||||||
|
responseHeader = fmt.Sprintf("%d %s", code, responseMsg)
|
||||||
|
_, _ = conn.Write([]byte(responseHeader))
|
||||||
_ = closeConnection(conn)
|
_ = closeConnection(conn)
|
||||||
}(conn)
|
}(conn)
|
||||||
|
|
||||||
// Gemini is supposed to have a 1kb limit
|
// Gemini is supposed to have a 1kb limit
|
||||||
// on input requests.
|
// on input requests.
|
||||||
buffer := make([]byte, 1024)
|
buffer := make([]byte, 1025)
|
||||||
|
|
||||||
n, err := conn.Read(buffer)
|
n, err := conn.Read(buffer)
|
||||||
if err != nil && err != io.EOF {
|
if err != nil && err != io.EOF {
|
||||||
return errors.NewConnectionError(fmt.Errorf("failed to read connection data: %w", err))
|
return xerrors.NewError(fmt.Errorf("failed to read connection data: %w", err), 59, "Connection read failed", false)
|
||||||
}
|
}
|
||||||
if n == 0 {
|
if n == 0 {
|
||||||
return errors.NewConnectionError(fmt.Errorf("client did not send data"))
|
return xerrors.NewError(fmt.Errorf("client did not send data"), 59, "No data received", false)
|
||||||
|
}
|
||||||
|
if n > 1024 {
|
||||||
|
return xerrors.NewError(fmt.Errorf("client request size %d > 1024 bytes", n), 59, "Request too large", false)
|
||||||
}
|
}
|
||||||
|
|
||||||
dataBytes := buffer[:n]
|
dataBytes := buffer[:n]
|
||||||
|
|||||||
@@ -2,16 +2,19 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gemserve/common"
|
"gemserve/common"
|
||||||
"gemserve/config"
|
"gemserve/config"
|
||||||
"gemserve/errors"
|
logging "git.antanst.com/antanst/logging"
|
||||||
"gemserve/logging"
|
"git.antanst.com/antanst/xerrors"
|
||||||
"github.com/gabriel-vasile/mimetype"
|
"github.com/gabriel-vasile/mimetype"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -20,18 +23,43 @@ type ServerConfig interface {
|
|||||||
RootPath() string
|
RootPath() string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func checkRequestURL(url *common.URL) error {
|
||||||
|
if url.Protocol != "gemini" {
|
||||||
|
return xerrors.NewError(fmt.Errorf("invalid URL"), 53, "URL Protocol not Gemini, proxying refused", false)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, portStr, err := net.SplitHostPort(config.CONFIG.Listen)
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.NewError(fmt.Errorf("failed to parse listen address: %w", err), 50, "Server configuration error", false)
|
||||||
|
}
|
||||||
|
listenPort, err := strconv.Atoi(portStr)
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.NewError(fmt.Errorf("failed to parse listen port: %w", err), 50, "Server configuration error", false)
|
||||||
|
}
|
||||||
|
if url.Port != listenPort {
|
||||||
|
return xerrors.NewError(fmt.Errorf("failed to parse URL: %w", err), 53, "invalid URL port, proxying refused", false)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func GenerateResponse(conn *tls.Conn, connId string, input string) ([]byte, error) {
|
func GenerateResponse(conn *tls.Conn, connId string, input string) ([]byte, error) {
|
||||||
trimmedInput := strings.TrimSpace(input)
|
trimmedInput := strings.TrimSpace(input)
|
||||||
// url will have a cleaned and normalized path after this
|
// url will have a cleaned and normalized path after this
|
||||||
url, err := common.ParseURL(trimmedInput, "", true)
|
url, err := common.ParseURL(trimmedInput, "", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.NewConnectionError(fmt.Errorf("failed to parse URL: %w", err))
|
return nil, xerrors.NewError(fmt.Errorf("failed to parse URL: %w", err), 59, "Invalid URL", false)
|
||||||
}
|
}
|
||||||
logging.LogDebug("%s %s normalized URL path: %s", connId, conn.RemoteAddr(), url.Path)
|
logging.LogDebug("%s %s normalized URL path: %s", connId, conn.RemoteAddr(), url.Path)
|
||||||
|
|
||||||
|
err = checkRequestURL(url)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
serverRootPath := config.CONFIG.RootPath
|
serverRootPath := config.CONFIG.RootPath
|
||||||
localPath, err := calculateLocalPath(url.Path, serverRootPath)
|
localPath, err := calculateLocalPath(url.Path, serverRootPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.NewConnectionError(err)
|
return nil, xerrors.NewError(err, 59, "Invalid path", false)
|
||||||
}
|
}
|
||||||
logging.LogDebug("%s %s request file path: %s", connId, conn.RemoteAddr(), localPath)
|
logging.LogDebug("%s %s request file path: %s", connId, conn.RemoteAddr(), localPath)
|
||||||
|
|
||||||
@@ -40,7 +68,7 @@ func GenerateResponse(conn *tls.Conn, connId string, input string) ([]byte, erro
|
|||||||
if errors.Is(err, os.ErrNotExist) || errors.Is(err, os.ErrPermission) {
|
if errors.Is(err, os.ErrNotExist) || errors.Is(err, os.ErrPermission) {
|
||||||
return []byte("51 not found\r\n"), nil
|
return []byte("51 not found\r\n"), nil
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return nil, errors.NewConnectionError(fmt.Errorf("%s %s failed to access path: %w", connId, conn.RemoteAddr(), err))
|
return nil, xerrors.NewError(fmt.Errorf("%s %s failed to access path: %w", connId, conn.RemoteAddr(), err), 0, "Path access failed", false)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle directory.
|
// Handle directory.
|
||||||
@@ -55,7 +83,7 @@ func generateResponseFile(conn *tls.Conn, connId string, url *common.URL, localP
|
|||||||
if errors.Is(err, os.ErrNotExist) || errors.Is(err, os.ErrPermission) {
|
if errors.Is(err, os.ErrNotExist) || errors.Is(err, os.ErrPermission) {
|
||||||
return []byte("51 not found\r\n"), nil
|
return []byte("51 not found\r\n"), nil
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return nil, errors.NewConnectionError(fmt.Errorf("%s %s failed to read file: %w", connId, conn.RemoteAddr(), err))
|
return nil, xerrors.NewError(fmt.Errorf("%s %s failed to read file: %w", connId, conn.RemoteAddr(), err), 0, "File read failed", false)
|
||||||
}
|
}
|
||||||
|
|
||||||
var mimeType string
|
var mimeType string
|
||||||
@@ -64,7 +92,7 @@ func generateResponseFile(conn *tls.Conn, connId string, url *common.URL, localP
|
|||||||
} else {
|
} else {
|
||||||
mimeType = mimetype.Detect(data).String()
|
mimeType = mimetype.Detect(data).String()
|
||||||
}
|
}
|
||||||
headerBytes := []byte(fmt.Sprintf("20 %s\r\n", mimeType))
|
headerBytes := []byte(fmt.Sprintf("20 %s; lang=en\r\n", mimeType))
|
||||||
response := append(headerBytes, data...)
|
response := append(headerBytes, data...)
|
||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
@@ -72,7 +100,7 @@ func generateResponseFile(conn *tls.Conn, connId string, url *common.URL, localP
|
|||||||
func generateResponseDir(conn *tls.Conn, connId string, url *common.URL, localPath string) (output []byte, err error) {
|
func generateResponseDir(conn *tls.Conn, connId string, url *common.URL, localPath string) (output []byte, err error) {
|
||||||
entries, err := os.ReadDir(localPath)
|
entries, err := os.ReadDir(localPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.NewConnectionError(fmt.Errorf("%s %s failed to read directory: %w", connId, conn.RemoteAddr(), err))
|
return nil, xerrors.NewError(fmt.Errorf("%s %s failed to read directory: %w", connId, conn.RemoteAddr(), err), 0, "Directory read failed", false)
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.CONFIG.DirIndexingEnabled {
|
if config.CONFIG.DirIndexingEnabled {
|
||||||
@@ -87,7 +115,7 @@ func generateResponseDir(conn *tls.Conn, connId string, url *common.URL, localPa
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
data := []byte(strings.Join(contents, ""))
|
data := []byte(strings.Join(contents, ""))
|
||||||
headerBytes := []byte("20 text/gemini;\r\n")
|
headerBytes := []byte("20 text/gemini; lang=en\r\n")
|
||||||
response := append(headerBytes, data...)
|
response := append(headerBytes, data...)
|
||||||
return response, nil
|
return response, nil
|
||||||
} else {
|
} else {
|
||||||
@@ -100,7 +128,7 @@ func generateResponseDir(conn *tls.Conn, connId string, url *common.URL, localPa
|
|||||||
func calculateLocalPath(input string, basePath string) (string, error) {
|
func calculateLocalPath(input string, basePath string) (string, error) {
|
||||||
// Check for invalid characters early
|
// Check for invalid characters early
|
||||||
if strings.ContainsAny(input, "\\") {
|
if strings.ContainsAny(input, "\\") {
|
||||||
return "", errors.NewError(fmt.Errorf("invalid characters in path: %s", input))
|
return "", xerrors.NewError(fmt.Errorf("invalid characters in path: %s", input), 0, "Invalid path characters", false)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If IsLocal(path) returns true, then Join(base, path)
|
// If IsLocal(path) returns true, then Join(base, path)
|
||||||
@@ -116,7 +144,7 @@ func calculateLocalPath(input string, basePath string) (string, error) {
|
|||||||
|
|
||||||
localPath, err := filepath.Localize(filePath)
|
localPath, err := filepath.Localize(filePath)
|
||||||
if err != nil || !filepath.IsLocal(localPath) {
|
if err != nil || !filepath.IsLocal(localPath) {
|
||||||
return "", errors.NewError(fmt.Errorf("could not construct local path from %s: %s", input, err))
|
return "", xerrors.NewError(fmt.Errorf("could not construct local path from %s: %s", input, err), 0, "Invalid local path", false)
|
||||||
}
|
}
|
||||||
|
|
||||||
filePath = path.Join(basePath, localPath)
|
filePath = path.Join(basePath, localPath)
|
||||||
|
|||||||
Reference in New Issue
Block a user