.
This commit is contained in:
187
main.go
187
main.go
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -10,68 +11,87 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"gemserve/lib/apperrors"
|
||||
"gemserve/lib/logging"
|
||||
|
||||
"gemserve/config"
|
||||
"gemserve/gemini"
|
||||
"gemserve/server"
|
||||
logging "git.antanst.com/antanst/logging"
|
||||
"git.antanst.com/antanst/uid"
|
||||
"git.antanst.com/antanst/xerrors"
|
||||
)
|
||||
|
||||
// This channel is for handling fatal errors
|
||||
// from anywhere in the app. The consumer
|
||||
// should log and panic.
|
||||
var fatalErrors chan error
|
||||
"git.antanst.com/antanst/uid"
|
||||
)
|
||||
|
||||
func main() {
|
||||
config.CONFIG = *config.GetConfig()
|
||||
|
||||
logging.InitSlogger(config.CONFIG.LogLevel)
|
||||
|
||||
err := runApp()
|
||||
logging.SetupLogging()
|
||||
logger := logging.Logger
|
||||
ctx := logging.WithLogger(context.Background(), logger)
|
||||
err := runApp(ctx)
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("Fatal Error: %v", err))
|
||||
panic(fmt.Sprintf("Fatal Error: %v", err))
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func runApp() error {
|
||||
logging.LogInfo("Starting up. Press Ctrl+C to exit")
|
||||
func runApp(ctx context.Context) error {
|
||||
logger := logging.FromContext(ctx)
|
||||
logger.Info("Starting up. Press Ctrl+C to exit")
|
||||
|
||||
listenAddr := config.CONFIG.ListenAddr
|
||||
|
||||
signals := make(chan os.Signal, 1)
|
||||
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
fatalErrors = make(chan error)
|
||||
fatalErrors := make(chan error)
|
||||
|
||||
// Root server context, used to cancel
|
||||
// connections and graceful shutdown.
|
||||
serverCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// WaitGroup to track active connections
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Spawn server on the background.
|
||||
// Returned errors are considered fatal.
|
||||
go func() {
|
||||
err := startServer(listenAddr)
|
||||
err := startServer(serverCtx, listenAddr, &wg, fatalErrors)
|
||||
if err != nil {
|
||||
fatalErrors <- xerrors.NewError(err, 0, "Server startup failed", true)
|
||||
fatalErrors <- apperrors.NewFatalError(fmt.Errorf("server startup failed: %w", err))
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-signals:
|
||||
logging.LogWarn("Received SIGINT or SIGTERM signal, exiting")
|
||||
logger.Warn("Received SIGINT or SIGTERM signal, shutting down gracefully")
|
||||
cancel()
|
||||
wg.Wait()
|
||||
return nil
|
||||
case fatalError := <-fatalErrors:
|
||||
return xerrors.NewError(fatalError, 0, "Server error", true)
|
||||
cancel()
|
||||
wg.Wait()
|
||||
return fatalError
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func startServer(listenAddr string) (err error) {
|
||||
cert, err := tls.LoadX509KeyPair("certs/server.crt", "certs/server.key")
|
||||
func startServer(ctx context.Context, listenAddr string, wg *sync.WaitGroup, fatalErrors chan<- error) (err error) {
|
||||
logger := logging.FromContext(ctx)
|
||||
|
||||
cert, err := tls.LoadX509KeyPair(config.CONFIG.TLSCert, config.CONFIG.TLSKey)
|
||||
if err != nil {
|
||||
return xerrors.NewError(fmt.Errorf("failed to load certificate: %w", err), 0, "Certificate loading failed", true)
|
||||
return apperrors.NewFatalError(fmt.Errorf("failed to load TLS certificate/key: %w", err))
|
||||
}
|
||||
|
||||
logger.Debug("Using TLS cert", "path", config.CONFIG.TLSCert)
|
||||
logger.Debug("Using TLS key", "path", config.CONFIG.TLSKey)
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
@@ -79,39 +99,58 @@ func startServer(listenAddr string) (err error) {
|
||||
|
||||
listener, err := tls.Listen("tcp", listenAddr, tlsConfig)
|
||||
if err != nil {
|
||||
return xerrors.NewError(fmt.Errorf("failed to create listener: %w", err), 0, "Listener creation failed", true)
|
||||
return apperrors.NewFatalError(err)
|
||||
}
|
||||
|
||||
defer func(listener net.Listener) {
|
||||
// If we've got an error closing the
|
||||
// listener, make sure we don't override
|
||||
// the original error (if not nil)
|
||||
errClose := listener.Close()
|
||||
if errClose != nil && err == nil {
|
||||
err = xerrors.NewError(err, 0, "Listener close failed", true)
|
||||
}
|
||||
_ = listener.Close()
|
||||
}(listener)
|
||||
|
||||
logging.LogInfo("Server listening on %s", listenAddr)
|
||||
// If context is cancelled, close listener
|
||||
// to unblock Accept() inside main loop.
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = listener.Close()
|
||||
}()
|
||||
|
||||
logger.Info("Server listening", "address", listenAddr)
|
||||
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
logging.LogInfo("Failed to accept connection: %v", err)
|
||||
if ctx.Err() != nil {
|
||||
return nil
|
||||
} // ctx cancellation
|
||||
logger.Info("Failed to accept connection: %v", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
// Type assert the connection to TLS connection
|
||||
tlsConn, ok := conn.(*tls.Conn)
|
||||
if !ok {
|
||||
logger.Error("Connection is not a TLS connection")
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
remoteAddr := conn.RemoteAddr().String()
|
||||
connId := uid.UID()
|
||||
err := handleConnection(conn.(*tls.Conn), connId, remoteAddr)
|
||||
|
||||
// Create a timeout context for this connection
|
||||
connCtx, cancel := context.WithTimeout(ctx, time.Duration(config.CONFIG.ResponseTimeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := handleConnection(connCtx, tlsConn, connId, remoteAddr)
|
||||
if err != nil {
|
||||
var asErr *xerrors.XError
|
||||
if errors.As(err, &asErr) && asErr.IsFatal {
|
||||
fatalErrors <- asErr
|
||||
if apperrors.IsFatal(err) {
|
||||
fatalErrors <- err
|
||||
return
|
||||
} else {
|
||||
logging.LogWarn("%s %s Connection failed: %v", connId, remoteAddr, err)
|
||||
}
|
||||
logger.Info("Connection failed", "id", connId, "remoteAddr", remoteAddr, "error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -120,19 +159,25 @@ func startServer(listenAddr string) (err error) {
|
||||
func closeConnection(conn *tls.Conn) error {
|
||||
err := conn.CloseWrite()
|
||||
if err != nil {
|
||||
return xerrors.NewError(fmt.Errorf("failed to close TLS connection: %w", err), 50, "Connection close failed", false)
|
||||
return apperrors.NewNetworkError(fmt.Errorf("failed to close TLS connection: %w", err))
|
||||
}
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
return xerrors.NewError(fmt.Errorf("failed to close connection: %w", err), 50, "Connection close failed", false)
|
||||
return apperrors.NewNetworkError(fmt.Errorf("failed to close connection: %w", err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleConnection(conn *tls.Conn, connId string, remoteAddr string) (err error) {
|
||||
func handleConnection(ctx context.Context, conn *tls.Conn, connId string, remoteAddr string) (err error) {
|
||||
logger := logging.FromContext(ctx)
|
||||
start := time.Now()
|
||||
var outputBytes []byte
|
||||
|
||||
// Set connection deadline based on context
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
_ = conn.SetDeadline(deadline)
|
||||
}
|
||||
|
||||
defer func(conn *tls.Conn) {
|
||||
end := time.Now()
|
||||
tookMs := end.Sub(start).Milliseconds()
|
||||
@@ -144,54 +189,82 @@ func handleConnection(conn *tls.Conn, connId string, remoteAddr string) (err err
|
||||
if i := bytes.Index(outputBytes, []byte{'\r'}); i >= 0 {
|
||||
responseHeader = string(outputBytes[:i])
|
||||
}
|
||||
logging.LogInfo("%s %s response %s (%dms)", connId, remoteAddr, responseHeader, tookMs)
|
||||
logger.Info("Response", "connId", connId, "remoteAddr", remoteAddr, "responseHeader", responseHeader, "ms", tookMs)
|
||||
_ = closeConnection(conn)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle context cancellation/timeout
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.Info("Connection timeout", "id", connId, "remoteAddr", remoteAddr, "ms", tookMs)
|
||||
responseHeader = fmt.Sprintf("%d Request timeout", gemini.StatusCGIError)
|
||||
_, _ = conn.Write([]byte(responseHeader + "\r\n"))
|
||||
_ = closeConnection(conn)
|
||||
return
|
||||
}
|
||||
if errors.Is(err, context.Canceled) {
|
||||
logger.Info("Connection cancelled", "id", connId, "remoteAddr", remoteAddr, "ms", tookMs)
|
||||
_ = closeConnection(conn)
|
||||
return
|
||||
}
|
||||
|
||||
var code int
|
||||
var responseMsg string
|
||||
var xErr *xerrors.XError
|
||||
if errors.As(err, &xErr) {
|
||||
// On fatal errors, immediatelly return the error.
|
||||
if xErr.IsFatal {
|
||||
_ = closeConnection(conn)
|
||||
return
|
||||
}
|
||||
code = xErr.Code
|
||||
responseMsg = xErr.UserMsg
|
||||
if apperrors.IsFatal(err) {
|
||||
_ = closeConnection(conn)
|
||||
return
|
||||
}
|
||||
if apperrors.IsGeminiError(err) {
|
||||
code = apperrors.GetStatusCode(err)
|
||||
responseMsg = "server error"
|
||||
} else {
|
||||
code = gemini.StatusPermanentFailure
|
||||
responseMsg = "server error"
|
||||
}
|
||||
responseHeader = fmt.Sprintf("%d %s", code, responseMsg)
|
||||
_, _ = conn.Write([]byte(responseHeader))
|
||||
_, _ = conn.Write([]byte(responseHeader + "\r\n"))
|
||||
_ = closeConnection(conn)
|
||||
}(conn)
|
||||
|
||||
// Check context before starting
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Gemini is supposed to have a 1kb limit
|
||||
// on input requests.
|
||||
buffer := make([]byte, 1025)
|
||||
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil && err != io.EOF {
|
||||
return xerrors.NewError(fmt.Errorf("failed to read connection data: %w", err), 59, "Connection read failed", false)
|
||||
return apperrors.NewGeminiError(fmt.Errorf("failed to read connection data: %w", err), gemini.StatusBadRequest)
|
||||
}
|
||||
if n == 0 {
|
||||
return xerrors.NewError(fmt.Errorf("client did not send data"), 59, "No data received", false)
|
||||
return apperrors.NewGeminiError(fmt.Errorf("client did not send data"), gemini.StatusBadRequest)
|
||||
}
|
||||
if n > 1024 {
|
||||
return xerrors.NewError(fmt.Errorf("client request size %d > 1024 bytes", n), 59, "Request too large", false)
|
||||
return apperrors.NewGeminiError(fmt.Errorf("client request size %d > 1024 bytes", n), gemini.StatusBadRequest)
|
||||
}
|
||||
|
||||
// Check context after read
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dataBytes := buffer[:n]
|
||||
dataString := string(dataBytes)
|
||||
|
||||
logging.LogInfo("%s %s request %s (%d bytes)", connId, remoteAddr, strings.TrimSpace(dataString), len(dataBytes))
|
||||
outputBytes, err = server.GenerateResponse(conn, connId, dataString)
|
||||
logger.Info("Request", "id", connId, "remoteAddr", remoteAddr, "data", strings.TrimSpace(dataString), "size", len(dataBytes))
|
||||
outputBytes, err = server.GenerateResponse(ctx, conn, connId, dataString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check context before write
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = conn.Write(outputBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
Reference in New Issue
Block a user