commit 3004bc9dd3471d6158232983f6e0e74a6eaf620a Author: antanst Date: Mon Feb 3 12:52:21 2025 +0200 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9c6d0be --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +**/.#* +**/*~ +/.idea +/run.sh diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..bcf2465 --- /dev/null +++ b/LICENSE @@ -0,0 +1,15 @@ +ISC License + +Copyright (c) Antanst 2025 + +Permission to use, copy, modify, and distribute this software for any +purpose with or without fee is hereby granted, provided that the above +copyright notice and this permission notice appear in all copies. + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..16f3d0a --- /dev/null +++ b/Makefile @@ -0,0 +1,47 @@ +SHELL := /bin/env oksh +export PATH := $(PATH) + +all: fmt lintfix tidy test clean build + +clean: + rm -f ./gemserve + +debug: + @echo "PATH: $(PATH)" + @echo "GOPATH: $(shell go env GOPATH)" + @which go + @which gofumpt + @which gci + @which golangci-lint + +# Test +test: + go test ./... + +tidy: + go mod tidy + +# Format code +fmt: + gofumpt -l -w . + gci write . + +# Run linter +lint: fmt + golangci-lint run + +# Run linter and fix +lintfix: fmt + golangci-lint run --fix + +build: + go build -o ./gemserve ./main.go + +show-updates: + go list -m -u all + +update: + go get -u all + +update-patch: + go get -u=patch all diff --git a/README.md b/README.md new file mode 100644 index 0000000..648f9ef --- /dev/null +++ b/README.md @@ -0,0 +1,33 @@ +``` + __ _ ___ _ __ ___ ___ ___ _ ____ _____ + / _` |/ _ | '_ ` _ \/ __|/ _ | '__\ \ / / _ \ +| (_| | __| | | | | \__ | __| | \ V | __/ + \__, |\___|_| |_| |_|___/\___|_| \_/ \___| + |___/ +``` + +Gemserve is a simple Gemini server written in Go. + +Run tests and build: + +```shell +make test #run tests only +make #run tests and build +``` + +Run: + +```shell +LOG_LEVEL=info \ +PANIC_ON_UNEXPECTED_ERROR=true \ +RESPONSE_TIMEOUT=10 \ #seconds +ROOT_PATH=./srv \ +DIR_INDEXING_ENABLED=false \ +./gemserve 0.0.0.0:1965 +``` + +You'll need TLS keys, you can use `certs/generate.sh` +for quick generation. + +## TODO +- [ ] Fix slowloris (proper response timeouts) diff --git a/certs/.gitignore b/certs/.gitignore new file mode 100644 index 0000000..07b2319 --- /dev/null +++ b/certs/.gitignore @@ -0,0 +1,2 @@ +ca* +server* diff --git a/certs/generate.sh b/certs/generate.sh new file mode 100755 index 0000000..b84d67b --- /dev/null +++ b/certs/generate.sh @@ -0,0 +1,20 @@ +#!/bin/sh +set -eu + +# Generate private key for CA +openssl genrsa -out ca.key 4096 + +# Generate CA certificate +openssl req -x509 -new -nodes -key ca.key -sha256 -days 3650 -out ca.crt \ + -subj "/C=US/ST=State/L=City/O=Organization/CN=My CA" + +# Generate private key for server +openssl genrsa -out server.key 2048 + +# Generate Certificate Signing Request (CSR) for server +openssl req -new -key server.key -out server.csr \ + -subj "/C=US/ST=State/L=City/O=Organization/CN=localhost" + +# Generate server certificate signed by our CA +openssl x509 -req -in server.csr -CA ca.crt -CAkey ca.key \ + -CAcreateserial -out server.crt -days 3650 -sha256 diff --git a/common/gemini_url.go b/common/gemini_url.go new file mode 100644 index 0000000..64e8bad --- /dev/null +++ b/common/gemini_url.go @@ -0,0 +1,242 @@ +package common + +import ( + "database/sql/driver" + "fmt" + "net/url" + "path" + "strconv" + "strings" + + "gemserve/errors" +) + +type URL struct { + Protocol string `json:"protocol,omitempty"` + Hostname string `json:"hostname,omitempty"` + Port int `json:"port,omitempty"` + Path string `json:"path,omitempty"` + Descr string `json:"descr,omitempty"` + Full string `json:"full,omitempty"` +} + +func (u *URL) Scan(value interface{}) error { + if value == nil { + // Clear the fields in the current GeminiUrl object (not the pointer itself) + *u = URL{} + return nil + } + b, ok := value.(string) + if !ok { + return errors.NewFatalError(fmt.Errorf("database scan error: expected string, got %T", value)) + } + parsedURL, err := ParseURL(b, "", false) + if err != nil { + return err + } + *u = *parsedURL + return nil +} + +func (u URL) String() string { + return u.Full +} + +func (u URL) StringNoDefaultPort() string { + if u.Port == 1965 { + return fmt.Sprintf("%s://%s%s", u.Protocol, u.Hostname, u.Path) + } + return u.Full +} + +func (u URL) Value() (driver.Value, error) { + if u.Full == "" { + return nil, nil + } + return u.Full, nil +} + +func ParseURL(input string, descr string, normalize bool) (*URL, error) { + var u *url.URL + var err error + if normalize { + u, err = NormalizeURL(input) + if err != nil { + return nil, err + } + } else { + u, err = url.Parse(input) + if err != nil { + return nil, errors.NewError(fmt.Errorf("error parsing URL: %w: %s", err, input)) + } + } + if u.Scheme != "gemini" { + return nil, errors.NewError(fmt.Errorf("error parsing URL: not a gemini URL: %s", input)) + } + protocol := u.Scheme + hostname := u.Hostname() + strPort := u.Port() + urlPath := u.EscapedPath() + if strPort == "" { + strPort = "1965" + } + port, err := strconv.Atoi(strPort) + if err != nil { + return nil, errors.NewError(fmt.Errorf("error parsing URL: %w: %s", err, input)) + } + full := fmt.Sprintf("%s://%s:%d%s", protocol, hostname, port, urlPath) + // full field should also contain query params and url fragments + if u.RawQuery != "" { + full += "?" + u.RawQuery + } + if u.Fragment != "" { + full += "#" + u.Fragment + } + return &URL{Protocol: protocol, Hostname: hostname, Port: port, Path: urlPath, Descr: descr, Full: full}, nil +} + +// DeriveAbsoluteURL converts a (possibly) relative +// URL to an absolute one. Used primarily to calculate +// the full redirection URL target from a response header. +func DeriveAbsoluteURL(currentURL URL, input string) (*URL, error) { + // If target URL is absolute, return just it + if strings.Contains(input, "://") { + return ParseURL(input, "", true) + } + // input is a relative path. Clean it and construct absolute. + var newPath string + // Handle weird cases found in the wild + if strings.HasPrefix(input, "/") { + newPath = path.Clean(input) + } else if input == "./" || input == "." { + newPath = path.Join(currentURL.Path, "/") + } else { + newPath = path.Join(currentURL.Path, "/", path.Clean(input)) + } + strURL := fmt.Sprintf("%s://%s:%d%s", currentURL.Protocol, currentURL.Hostname, currentURL.Port, newPath) + return ParseURL(strURL, "", true) +} + +// NormalizeURL takes a URL string and returns a normalized version +// Normalized meaning: +// - Path normalization (removing redundant slashes, . and .. segments) +// - Proper escaping of special characters +// - Lowercase scheme and host +// - Removal of default ports +// - Empty path becomes "/" +func NormalizeURL(rawURL string) (*url.URL, error) { + // Parse the URL + u, err := url.Parse(rawURL) + if err != nil { + return nil, errors.NewError(fmt.Errorf("error normalizing URL: %w: %s", err, rawURL)) + } + if u.Scheme == "" { + return nil, errors.NewError(fmt.Errorf("error normalizing URL: No scheme: %s", rawURL)) + } + if u.Host == "" { + return nil, errors.NewError(fmt.Errorf("error normalizing URL: No host: %s", rawURL)) + } + + // Convert scheme to lowercase + u.Scheme = strings.ToLower(u.Scheme) + + // Convert hostname to lowercase + if u.Host != "" { + u.Host = strings.ToLower(u.Host) + } + + // Remove default ports + if u.Port() != "" { + switch { + case u.Scheme == "http" && u.Port() == "80": + u.Host = u.Hostname() + case u.Scheme == "https" && u.Port() == "443": + u.Host = u.Hostname() + case u.Scheme == "gemini" && u.Port() == "1965": + u.Host = u.Hostname() + } + } + + // Handle path normalization while preserving trailing slash + if u.Path != "" { + // Check if there was a trailing slash before cleaning + hadTrailingSlash := strings.HasSuffix(u.Path, "/") + + u.Path = path.Clean(u.EscapedPath()) + // If path was "/", path.Clean() will return "." + if u.Path == "." { + u.Path = "/" + } else if hadTrailingSlash && u.Path != "/" { + // Restore trailing slash if it existed and path isn't just "/" + u.Path += "/" + } + } + + // Properly escape the path + // First split on '/' to avoid escaping them + parts := strings.Split(u.Path, "/") + for i, part := range parts { + parts[i] = url.PathEscape(part) + } + u.Path = strings.Join(parts, "/") + + // Remove trailing fragment if empty + if u.Fragment == "" { + u.Fragment = "" + } + + // Remove trailing query if empty + if u.RawQuery == "" { + u.RawQuery = "" + } + + return u, nil +} + +func EscapeURL(input string) string { + // Only escape if not already escaped + if strings.Contains(input, "%") && !strings.Contains(input, "% ") { + return input + } + // Split URL into parts (protocol, host, p) + parts := strings.SplitN(input, "://", 2) + if len(parts) != 2 { + return input + } + + protocol := parts[0] + remainder := parts[1] + + // If URL ends with just a slash, return as is + if strings.HasSuffix(remainder, "/") && !strings.Contains(remainder[:len(remainder)-1], "/") { + return input + } + + // Split host and p + parts = strings.SplitN(remainder, "/", 2) + host := parts[0] + if len(parts) == 1 { + return protocol + "://" + host + } + + // Escape the path portion + escapedPath := url.PathEscape(parts[1]) + + // Reconstruct the URL + return protocol + "://" + host + "/" + escapedPath +} + +// normalizePath trims trailing slash and handles empty path +func TrimTrailingPathSlash(path string) string { + // Handle empty path (e.g., "http://example.com" -> treat as root) + if path == "" { + return "/" + } + + // Trim trailing slash while preserving root slash + path = strings.TrimSuffix(path, "/") + if path == "" { // This happens if path was just "/" + return "/" + } + return path +} diff --git a/common/gemini_url_test.go b/common/gemini_url_test.go new file mode 100644 index 0000000..6b23d30 --- /dev/null +++ b/common/gemini_url_test.go @@ -0,0 +1,384 @@ +package common + +import ( + "net/url" + "reflect" + "testing" +) + +func TestParseURL(t *testing.T) { + t.Parallel() + input := "gemini://caolan.uk/cgi-bin/weather.py/wxfcs/3162" + parsed, err := ParseURL(input, "", true) + value, _ := parsed.Value() + if err != nil || !(value == "gemini://caolan.uk:1965/cgi-bin/weather.py/wxfcs/3162") { + t.Errorf("fail: %s", parsed) + } +} + +func TestDeriveAbsoluteURL_abs_url_input(t *testing.T) { + t.Parallel() + currentURL := URL{ + Protocol: "gemini", + Hostname: "smol.gr", + Port: 1965, + Path: "/a/b", + Descr: "Nothing", + Full: "gemini://smol.gr:1965/a/b", + } + input := "gemini://a.b/c" + output, err := DeriveAbsoluteURL(currentURL, input) + if err != nil { + t.Errorf("fail: %v", err) + } + expected := &URL{ + Protocol: "gemini", + Hostname: "a.b", + Port: 1965, + Path: "/c", + Descr: "", + Full: "gemini://a.b:1965/c", + } + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} + +func TestDeriveAbsoluteURL_abs_path_input(t *testing.T) { + t.Parallel() + currentURL := URL{ + Protocol: "gemini", + Hostname: "smol.gr", + Port: 1965, + Path: "/a/b", + Descr: "Nothing", + Full: "gemini://smol.gr:1965/a/b", + } + input := "/c" + output, err := DeriveAbsoluteURL(currentURL, input) + if err != nil { + t.Errorf("fail: %v", err) + } + expected := &URL{ + Protocol: "gemini", + Hostname: "smol.gr", + Port: 1965, + Path: "/c", + Descr: "", + Full: "gemini://smol.gr:1965/c", + } + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} + +func TestDeriveAbsoluteURL_rel_path_input(t *testing.T) { + t.Parallel() + currentURL := URL{ + Protocol: "gemini", + Hostname: "smol.gr", + Port: 1965, + Path: "/a/b", + Descr: "Nothing", + Full: "gemini://smol.gr:1965/a/b", + } + input := "c/d" + output, err := DeriveAbsoluteURL(currentURL, input) + if err != nil { + t.Errorf("fail: %v", err) + } + expected := &URL{ + Protocol: "gemini", + Hostname: "smol.gr", + Port: 1965, + Path: "/a/b/c/d", + Descr: "", + Full: "gemini://smol.gr:1965/a/b/c/d", + } + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} + +func TestNormalizeURLSlash(t *testing.T) { + t.Parallel() + input := "gemini://uscoffings.net/retro-computing/magazines/" + normalized, _ := NormalizeURL(input) + output := normalized.String() + expected := input + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} + +func TestNormalizeURLNoSlash(t *testing.T) { + t.Parallel() + input := "gemini://uscoffings.net/retro-computing/magazines" + normalized, _ := NormalizeURL(input) + output := normalized.String() + expected := input + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} + +func TestNormalizeMultiSlash(t *testing.T) { + t.Parallel() + input := "gemini://uscoffings.net/retro-computing/////////a///magazines" + normalized, _ := NormalizeURL(input) + output := normalized.String() + expected := "gemini://uscoffings.net/retro-computing/a/magazines" + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} + +func TestNormalizeTrailingSlash(t *testing.T) { + t.Parallel() + input := "gemini://uscoffings.net/" + normalized, _ := NormalizeURL(input) + output := normalized.String() + expected := "gemini://uscoffings.net/" + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} + +func TestNormalizeNoTrailingSlash(t *testing.T) { + t.Parallel() + input := "gemini://uscoffings.net" + normalized, _ := NormalizeURL(input) + output := normalized.String() + expected := "gemini://uscoffings.net" + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} + +func TestNormalizeTrailingSlashPath(t *testing.T) { + t.Parallel() + input := "gemini://uscoffings.net/a/" + normalized, _ := NormalizeURL(input) + output := normalized.String() + expected := "gemini://uscoffings.net/a/" + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} + +func TestNormalizeNoTrailingSlashPath(t *testing.T) { + t.Parallel() + input := "gemini://uscoffings.net/a" + normalized, _ := NormalizeURL(input) + output := normalized.String() + expected := "gemini://uscoffings.net/a" + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} + +func TestNormalizeDot(t *testing.T) { + t.Parallel() + input := "gemini://uscoffings.net/retro-computing/./././////a///magazines" + normalized, _ := NormalizeURL(input) + output := normalized.String() + expected := "gemini://uscoffings.net/retro-computing/a/magazines" + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} + +func TestNormalizePort(t *testing.T) { + t.Parallel() + input := "gemini://uscoffings.net:1965/a" + normalized, _ := NormalizeURL(input) + output := normalized.String() + expected := "gemini://uscoffings.net/a" + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} + +func TestNormalizeURL(t *testing.T) { + t.Parallel() + input := "gemini://chat.gemini.lehmann.cx:11965/" + normalized, _ := NormalizeURL(input) + output := normalized.String() + expected := "gemini://chat.gemini.lehmann.cx:11965/" + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } + + input = "gemini://chat.gemini.lehmann.cx:11965/index?a=1&b=c" + normalized, _ = NormalizeURL(input) + output = normalized.String() + expected = "gemini://chat.gemini.lehmann.cx:11965/index?a=1&b=c" + pass = reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } + + input = "gemini://chat.gemini.lehmann.cx:11965/index#1" + normalized, _ = NormalizeURL(input) + output = normalized.String() + expected = "gemini://chat.gemini.lehmann.cx:11965/index#1" + pass = reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } + + input = "gemini://gemi.dev/cgi-bin/xkcd.cgi?1494" + normalized, _ = NormalizeURL(input) + output = normalized.String() + expected = "gemini://gemi.dev/cgi-bin/xkcd.cgi?1494" + pass = reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} + +func TestNormalizePath(t *testing.T) { + t.Parallel() + tests := []struct { + name string + input string // URL string to parse + expected string // Expected normalized path + }{ + // Basic cases + { + name: "empty_path", + input: "http://example.com", + expected: "/", + }, + { + name: "root_path", + input: "http://example.com/", + expected: "/", + }, + { + name: "single_trailing_slash", + input: "http://example.com/test/", + expected: "/test", + }, + { + name: "no_trailing_slash", + input: "http://example.com/test", + expected: "/test", + }, + + // Edge cases with slashes + { + name: "multiple_trailing_slashes", + input: "http://example.com/test//", + expected: "/test/", + }, + { + name: "multiple_consecutive_slashes", + input: "http://example.com//test//", + expected: "//test/", + }, + { + name: "only_slashes", + input: "http://example.com////", + expected: "///", + }, + { + name: "single_slash", + input: "/", + expected: "/", + }, + + // Encoded characters + { + name: "encoded_spaces", + input: "http://example.com/foo%20bar/", + expected: "/foo%20bar", + }, + { + name: "encoded_special_chars", + input: "http://example.com/foo%2Fbar/", + expected: "/foo%2Fbar", + }, + + // Query parameters and fragments + { + name: "with_query_parameters", + input: "http://example.com/path?query=param", + expected: "/path", + }, + { + name: "with_fragment", + input: "http://example.com/path#fragment", + expected: "/path", + }, + { + name: "with_both_query_and_fragment", + input: "http://example.com/path?query=param#fragment", + expected: "/path", + }, + + // Relative URLs + { + name: "relative_path", + input: "/just/a/path/", + expected: "/just/a/path", + }, + + // Unicode paths + { + name: "unicode_characters", + input: "http://example.com/über/path/", + expected: "/%C3%BCber/path", + }, + { + name: "unicode_encoded", + input: "http://example.com/%C3%BCber/path/", + expected: "/%C3%BCber/path", + }, + + // Weird but valid cases + { + name: "dot_in_path", + input: "http://example.com/./path/", + expected: "/./path", + }, + { + name: "double_dot_in_path", + input: "http://example.com/../path/", + expected: "/../path", + }, + { + name: "mixed_case", + input: "http://example.com/PaTh/", + expected: "/PaTh", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u, err := url.Parse(tt.input) + if err != nil { + t.Fatalf("Failed to parse URL %q: %v", tt.input, err) + } + + result := TrimTrailingPathSlash(u.EscapedPath()) + if result != tt.expected { + t.Errorf("Input: %s\nExpected: %q\nGot: %q", + u.Path, tt.expected, result) + } + }) + } +} diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..32160dc --- /dev/null +++ b/config/config.go @@ -0,0 +1,126 @@ +package config + +import ( + "fmt" + "os" + "strconv" + + "github.com/rs/zerolog" +) + +// Environment variable names. +const ( + EnvLogLevel = "LOG_LEVEL" + EnvResponseTimeout = "RESPONSE_TIMEOUT" + EnvPanicOnUnexpectedError = "PANIC_ON_UNEXPECTED_ERROR" + EnvRootPath = "ROOT_PATH" + EnvDirIndexingEnabled = "DIR_INDEXING_ENABLED" +) + +// Config holds the application configuration loaded from environment variables. +type Config struct { + LogLevel zerolog.Level // Logging level (debug, info, warn, error) + ResponseTimeout int // Timeout for responses in seconds + PanicOnUnexpectedError bool // Panic on unexpected errors when visiting a URL + RootPath string // Path to serve files from + DirIndexingEnabled bool // Allow client to browse directories or not +} + +var CONFIG Config //nolint:gochecknoglobals + +// parsePositiveInt parses and validates positive integer values. +func parsePositiveInt(param, value string) (int, error) { + val, err := strconv.Atoi(value) + if err != nil { + return 0, ValidationError{ + Param: param, + Value: value, + Reason: "must be a valid integer", + } + } + if val <= 0 { + return 0, ValidationError{ + Param: param, + Value: value, + Reason: "must be positive", + } + } + return val, nil +} + +func parseBool(param, value string) (bool, error) { + val, err := strconv.ParseBool(value) + if err != nil { + return false, ValidationError{ + Param: param, + Value: value, + Reason: "cannot be converted to boolean", + } + } + return val, nil +} + +// GetConfig loads and validates configuration from environment variables +func GetConfig() *Config { + config := &Config{} + + // Map of environment variables to their parsing functions + parsers := map[string]func(string) error{ + EnvLogLevel: func(v string) error { + level, err := zerolog.ParseLevel(v) + if err != nil { + return ValidationError{ + Param: EnvLogLevel, + Value: v, + Reason: "must be one of: debug, info, warn, error", + } + } + config.LogLevel = level + return nil + }, + EnvResponseTimeout: func(v string) error { + val, err := parsePositiveInt(EnvResponseTimeout, v) + if err != nil { + return err + } + config.ResponseTimeout = val + return nil + }, + EnvPanicOnUnexpectedError: func(v string) error { + val, err := parseBool(EnvPanicOnUnexpectedError, v) + if err != nil { + return err + } + config.PanicOnUnexpectedError = val + return nil + }, + EnvRootPath: func(v string) error { + config.RootPath = v + return nil + }, + EnvDirIndexingEnabled: func(v string) error { + val, err := parseBool(EnvDirIndexingEnabled, v) + if err != nil { + return err + } + config.DirIndexingEnabled = val + return nil + }, + } + + // Process each environment variable + for envVar, parser := range parsers { + value, ok := os.LookupEnv(envVar) + if !ok { + _, _ = fmt.Fprintf(os.Stderr, "Missing required environment variable: %s\n", envVar) + os.Exit(1) + } + + if err := parser(value); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "Configuration error: %v\n", err) + os.Exit(1) + } + } + + return config +} diff --git a/config/errors.go b/config/errors.go new file mode 100644 index 0000000..60482d7 --- /dev/null +++ b/config/errors.go @@ -0,0 +1,14 @@ +package config + +import "fmt" + +// ValidationError represents a config validation error +type ValidationError struct { + Param string + Value string + Reason string +} + +func (e ValidationError) Error() string { + return fmt.Sprintf("invalid value '%s' for %s: %s", e.Value, e.Param, e.Reason) +} diff --git a/errors/errors.go b/errors/errors.go new file mode 100644 index 0000000..6bd39ea --- /dev/null +++ b/errors/errors.go @@ -0,0 +1,114 @@ +package errors + +import ( + "errors" + "fmt" + "runtime" + "strings" +) + +type fatal interface { + Fatal() bool +} + +func IsFatal(err error) bool { + te, ok := errors.Unwrap(err).(fatal) + return ok && te.Fatal() +} + +func As(err error, target any) bool { + return errors.As(err, target) +} + +func Is(err, target error) bool { + return errors.Is(err, target) +} + +func Unwrap(err error) error { + return errors.Unwrap(err) +} + +type Error struct { + Err error + Stack string + fatal bool +} + +func (e *Error) Error() string { + var sb strings.Builder + sb.WriteString(fmt.Sprintf("%v\n", e.Err)) + return sb.String() +} + +func (e *Error) ErrorWithStack() string { + var sb strings.Builder + sb.WriteString(fmt.Sprintf("%v\n", e.Err)) + sb.WriteString(fmt.Sprintf("Stack Trace:\n%s", e.Stack)) + return sb.String() +} + +func (e *Error) Fatal() bool { + return e.fatal +} + +func (e *Error) Unwrap() error { + return e.Err +} + +func NewError(err error) error { + if err == nil { + return nil + } + + // Check if it's already of our own + // Error type, so we don't add stack twice. + var asError *Error + if errors.As(err, &asError) { + return err + } + + // Get the stack trace + var stack strings.Builder + buf := make([]uintptr, 50) + n := runtime.Callers(2, buf) + frames := runtime.CallersFrames(buf[:n]) + + // Format the stack trace + for { + frame, more := frames.Next() + // Skip runtime and standard library frames + if !strings.Contains(frame.File, "runtime/") { + stack.WriteString(fmt.Sprintf("\t%s:%d - %s\n", frame.File, frame.Line, frame.Function)) + } + if !more { + break + } + } + + return &Error{ + Err: err, + Stack: stack.String(), + } +} + +func NewFatalError(err error) error { + if err == nil { + return nil + } + + // Check if it's already of our own + // Error type. + var asError *Error + if errors.As(err, &asError) { + return err + } + err2 := NewError(err) + err2.(*Error).fatal = true + return err2 +} + +var ConnectionError error = fmt.Errorf("connection error") + +func NewConnectionError(err error) error { + return fmt.Errorf("%w: %w", ConnectionError, err) +} diff --git a/errors/errors_test.go b/errors/errors_test.go new file mode 100644 index 0000000..30bde0e --- /dev/null +++ b/errors/errors_test.go @@ -0,0 +1,71 @@ +package errors + +import ( + "errors" + "fmt" + "testing" +) + +type CustomError struct { + Err error +} + +func (e *CustomError) Error() string { return e.Err.Error() } + +func IsCustomError(err error) bool { + var asError *CustomError + return errors.As(err, &asError) +} + +func TestWrapping(t *testing.T) { + t.Parallel() + originalErr := errors.New("original error") + err1 := NewError(originalErr) + if !errors.Is(err1, originalErr) { + t.Errorf("original error is not wrapped") + } + if !Is(err1, originalErr) { + t.Errorf("original error is not wrapped") + } + unwrappedErr := errors.Unwrap(err1) + if !errors.Is(unwrappedErr, originalErr) { + t.Errorf("original error is not wrapped") + } + if !Is(unwrappedErr, originalErr) { + t.Errorf("original error is not wrapped") + } + unwrappedErr = Unwrap(err1) + if !errors.Is(unwrappedErr, originalErr) { + t.Errorf("original error is not wrapped") + } + if !Is(unwrappedErr, originalErr) { + t.Errorf("original error is not wrapped") + } + wrappedErr := fmt.Errorf("wrapped: %w", originalErr) + if !errors.Is(wrappedErr, originalErr) { + t.Errorf("original error is not wrapped") + } + if !Is(wrappedErr, originalErr) { + t.Errorf("original error is not wrapped") + } +} + +func TestNewError(t *testing.T) { + t.Parallel() + originalErr := &CustomError{errors.New("err1")} + if !IsCustomError(originalErr) { + t.Errorf("TestNewError fail #1") + } + err1 := NewError(originalErr) + if !IsCustomError(err1) { + t.Errorf("TestNewError fail #2") + } + wrappedErr1 := fmt.Errorf("wrapped %w", err1) + if !IsCustomError(wrappedErr1) { + t.Errorf("TestNewError fail #3") + } + unwrappedErr1 := Unwrap(wrappedErr1) + if !IsCustomError(unwrappedErr1) { + t.Errorf("TestNewError fail #4") + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..ce1f750 --- /dev/null +++ b/go.mod @@ -0,0 +1,16 @@ +module gemserve + +go 1.23.4 + +require ( + github.com/gabriel-vasile/mimetype v1.4.8 + github.com/matoous/go-nanoid/v2 v2.1.0 + github.com/rs/zerolog v1.33.0 +) + +require ( + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + golang.org/x/net v0.33.0 // indirect + golang.org/x/sys v0.29.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..30aaf07 --- /dev/null +++ b/go.sum @@ -0,0 +1,30 @@ +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM= +github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/matoous/go-nanoid/v2 v2.1.0 h1:P64+dmq21hhWdtvZfEAofnvJULaRR1Yib0+PnU669bE= +github.com/matoous/go-nanoid/v2 v2.1.0/go.mod h1:KlbGNQ+FhrUNIHUxZdL63t7tl4LaPkZNpUULS8H4uVM= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= +github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= +golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/logging/logging.go b/logging/logging.go new file mode 100644 index 0000000..3b8ec62 --- /dev/null +++ b/logging/logging.go @@ -0,0 +1,23 @@ +package logging + +import ( + "fmt" + + zlog "github.com/rs/zerolog/log" +) + +func LogDebug(format string, args ...interface{}) { + zlog.Debug().Msg(fmt.Sprintf(format, args...)) +} + +func LogInfo(format string, args ...interface{}) { + zlog.Info().Msg(fmt.Sprintf(format, args...)) +} + +func LogWarn(format string, args ...interface{}) { + zlog.Warn().Msg(fmt.Sprintf(format, args...)) +} + +func LogError(format string, args ...interface{}) { + zlog.Error().Err(fmt.Errorf(format, args...)).Msg("") +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..415fe82 --- /dev/null +++ b/main.go @@ -0,0 +1,189 @@ +package main + +import ( + "bytes" + "crypto/tls" + "fmt" + "io" + "net" + "os" + "os/signal" + "strings" + "syscall" + "time" + + "gemserve/config" + "gemserve/errors" + "gemserve/logging" + "gemserve/server" + "gemserve/uid" + "github.com/rs/zerolog" + zlog "github.com/rs/zerolog/log" +) + +func main() { + config.CONFIG = *config.GetConfig() + zerolog.TimeFieldFormat = zerolog.TimeFormatUnix + zerolog.SetGlobalLevel(config.CONFIG.LogLevel) + zlog.Logger = zlog.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: "[2006-01-02 15:04:05]"}) + err := runApp() + if err != nil { + fmt.Printf("%v\n", err) + logging.LogError("%v", err) + os.Exit(1) + } +} + +func runApp() error { + logging.LogInfo("Starting up. Press Ctrl+C to exit") + + var listenHost string + if len(os.Args) != 2 { + listenHost = "0.0.0.0:1965" + } else { + listenHost = os.Args[1] + } + + signals := make(chan os.Signal, 1) + signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) + + serverErrors := make(chan error) + + go func() { + err := startServer(listenHost) + if err != nil { + serverErrors <- errors.NewFatalError(err) + } + }() + + for { + select { + case <-signals: + logging.LogWarn("Received SIGINT or SIGTERM signal, exiting") + return nil + case serverError := <-serverErrors: + return errors.NewFatalError(serverError) + } + } +} + +func startServer(listenHost string) (err error) { + cert, err := tls.LoadX509KeyPair("/certs/cert", "/certs/key") + if err != nil { + return errors.NewFatalError(fmt.Errorf("failed to load certificate: %w", err)) + } + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS12, + } + + listener, err := tls.Listen("tcp", listenHost, tlsConfig) + if err != nil { + return errors.NewFatalError(fmt.Errorf("failed to create listener: %w", err)) + } + defer func(listener net.Listener) { + // If we've got an error closing the + // listener, make sure we don't override + // the original error (if not nil) + errClose := listener.Close() + if errClose != nil && err == nil { + err = errors.NewFatalError(err) + } + }(listener) + + logging.LogInfo("Server listening on %s", listenHost) + + for { + conn, err := listener.Accept() + if err != nil { + logging.LogInfo("Failed to accept connection: %v", err) + continue + } + + go func() { + err := handleConnection(conn.(*tls.Conn)) + if err != nil { + var asErr *errors.Error + if errors.As(err, &asErr) { + logging.LogError("Unexpected error: %v %v", err, err.(*errors.Error).ErrorWithStack()) + } else { + logging.LogError("Unexpected error: %v", err) + } + if config.CONFIG.PanicOnUnexpectedError { + panic("Encountered unexpected error") + } + } + }() + } +} + +func closeConnection(conn *tls.Conn) error { + err := conn.CloseWrite() + if err != nil { + return errors.NewConnectionError(fmt.Errorf("failed to close TLS connection: %w", err)) + } + err = conn.Close() + if err != nil { + return errors.NewConnectionError(fmt.Errorf("failed to close connection: %w", err)) + } + return nil +} + +func handleConnection(conn *tls.Conn) (err error) { + remoteAddr := conn.RemoteAddr().String() + connId := uid.UID() + start := time.Now() + var outputBytes []byte + + defer func(conn *tls.Conn) { + // Three possible cases here: + // - We don't have an error + // - We have a ConnectionError, which we don't propagate up + // - We have an unexpected error. + end := time.Now() + tookMs := end.Sub(start).Milliseconds() + var responseHeader string + if err != nil { + _, _ = conn.Write([]byte("50 server error")) + responseHeader = "50 server error" + // We don't propagate connection errors up. + if errors.Is(err, errors.ConnectionError) { + logging.LogInfo("%s %s %v", connId, remoteAddr, err) + err = nil + } + } else { + if i := bytes.Index(outputBytes, []byte{'\r'}); i >= 0 { + responseHeader = string(outputBytes[:i]) + } + } + logging.LogInfo("%s %s response %s (%dms)", connId, remoteAddr, responseHeader, tookMs) + _ = closeConnection(conn) + }(conn) + + // Gemini is supposed to have a 1kb limit + // on input requests. + buffer := make([]byte, 1024) + + n, err := conn.Read(buffer) + if err != nil && err != io.EOF { + return errors.NewConnectionError(fmt.Errorf("failed to read connection data: %w", err)) + } + if n == 0 { + return errors.NewConnectionError(fmt.Errorf("client did not send data")) + } + + dataBytes := buffer[:n] + dataString := string(dataBytes) + + logging.LogInfo("%s %s request %s (%d bytes)", connId, remoteAddr, strings.TrimSpace(dataString), len(dataBytes)) + outputBytes, err = server.GenerateResponse(conn, connId, dataString) + if err != nil { + return err + } + _, err = conn.Write(outputBytes) + if err != nil { + return err + } + return nil +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..4b51368 --- /dev/null +++ b/server/server.go @@ -0,0 +1,124 @@ +package server + +import ( + "crypto/tls" + "fmt" + "os" + "path" + "path/filepath" + "strings" + + "gemserve/common" + "gemserve/config" + "gemserve/errors" + "gemserve/logging" + "github.com/gabriel-vasile/mimetype" +) + +type ServerConfig interface { + DirIndexingEnabled() bool + RootPath() string +} + +func GenerateResponse(conn *tls.Conn, connId string, input string) ([]byte, error) { + trimmedInput := strings.TrimSpace(input) + // url will have a cleaned and normalized path after this + url, err := common.ParseURL(trimmedInput, "", true) + if err != nil { + return nil, errors.NewConnectionError(fmt.Errorf("failed to parse URL: %w", err)) + } + logging.LogDebug("%s %s normalized URL path: %s", connId, conn.RemoteAddr(), url.Path) + serverRootPath := config.CONFIG.RootPath + localPath, err := calculateLocalPath(url.Path, serverRootPath) + if err != nil { + return nil, errors.NewConnectionError(err) + } + logging.LogDebug("%s %s request file path: %s", connId, conn.RemoteAddr(), localPath) + + // Get file/directory information + info, err := os.Stat(localPath) + if errors.Is(err, os.ErrNotExist) || errors.Is(err, os.ErrPermission) { + return []byte("51 not found\r\n"), nil + } else if err != nil { + return nil, errors.NewConnectionError(fmt.Errorf("%s %s failed to access path: %w", connId, conn.RemoteAddr(), err)) + } + + // Handle directory. + if info.IsDir() { + return generateResponseDir(conn, connId, url, localPath) + } + return generateResponseFile(conn, connId, url, localPath) +} + +func generateResponseFile(conn *tls.Conn, connId string, url *common.URL, localPath string) ([]byte, error) { + data, err := os.ReadFile(localPath) + if errors.Is(err, os.ErrNotExist) || errors.Is(err, os.ErrPermission) { + return []byte("51 not found\r\n"), nil + } else if err != nil { + return nil, errors.NewConnectionError(fmt.Errorf("%s %s failed to read file: %w", connId, conn.RemoteAddr(), err)) + } + + var mimeType string + if path.Ext(localPath) == ".gmi" { + mimeType = "text/gemini" + } else { + mimeType = mimetype.Detect(data).String() + } + headerBytes := []byte(fmt.Sprintf("20 %s\r\n", mimeType)) + response := append(headerBytes, data...) + return response, nil +} + +func generateResponseDir(conn *tls.Conn, connId string, url *common.URL, localPath string) (output []byte, err error) { + entries, err := os.ReadDir(localPath) + if err != nil { + return nil, errors.NewConnectionError(fmt.Errorf("%s %s failed to read directory: %w", connId, conn.RemoteAddr(), err)) + } + + if config.CONFIG.DirIndexingEnabled { + var contents []string + contents = append(contents, "Directory index:\n\n") + contents = append(contents, "=> ../\n") + for _, entry := range entries { + if entry.IsDir() { + contents = append(contents, fmt.Sprintf("=> %s/\n", entry.Name())) + } else { + contents = append(contents, fmt.Sprintf("=> %s\n", entry.Name())) + } + } + data := []byte(strings.Join(contents, "")) + headerBytes := []byte("20 text/gemini;\r\n") + response := append(headerBytes, data...) + return response, nil + } else { + filePath := path.Join(localPath, "index.gmi") + return generateResponseFile(conn, connId, url, filePath) + + } +} + +func calculateLocalPath(input string, basePath string) (string, error) { + // Check for invalid characters early + if strings.ContainsAny(input, "\\") { + return "", errors.NewError(fmt.Errorf("invalid characters in path: %s", input)) + } + + // If IsLocal(path) returns true, then Join(base, path) + // will always produce a path contained within base and + // Clean(path) will always produce an unrooted path with + // no ".." path elements. + filePath := input + filePath = strings.TrimPrefix(filePath, "/") + if filePath == "" { + filePath = "." + } + filePath = strings.TrimSuffix(filePath, "/") + + localPath, err := filepath.Localize(filePath) + if err != nil || !filepath.IsLocal(localPath) { + return "", errors.NewError(fmt.Errorf("could not construct local path from %s: %s", input, err)) + } + + filePath = path.Join(basePath, localPath) + return filePath, nil +} diff --git a/server/server_test.go b/server/server_test.go new file mode 100644 index 0000000..ccb3644 --- /dev/null +++ b/server/server_test.go @@ -0,0 +1,225 @@ +package server + +import ( + "path/filepath" + "strings" + "testing" +) + +func TestCalculateLocalPath(t *testing.T) { + tests := []struct { + name string + input string + basePath string + want string + expectError bool + }{ + // Basic path handling + { + name: "Simple valid path", + input: "folder/file.txt", + basePath: "/base", + want: "/base/folder/file.txt", + expectError: false, + }, + { + name: "Empty path", + input: "", + basePath: "/base", + want: "/base", + expectError: false, + }, + { + name: "Current directory", + input: ".", + basePath: "/base", + want: "/base", + expectError: false, + }, + + // Leading/trailing slash handling + { + name: "Path with leading slash", + input: "/folder/file.txt", + basePath: "/base", + want: "/base/folder/file.txt", + expectError: false, + }, + { + name: "Path with trailing slash", + input: "folder/", + basePath: "/base", + want: "/base/folder", + expectError: false, + }, + { + name: "Path with both leading and trailing slashes", + input: "/folder/", + basePath: "/base", + want: "/base/folder", + expectError: false, + }, + + // Path traversal attempts + { + name: "Simple path traversal attempt", + input: "../file.txt", + basePath: "/base", + want: "", + expectError: true, + }, + { + name: "Complex path traversal attempt", + input: "folder/../../../etc/passwd", + basePath: "/base", + want: "", + expectError: true, + }, + { + name: "Encoded path traversal attempt", + input: "folder/..%2F..%2F..%2Fetc%2Fpasswd", + basePath: "/base", + want: "/base/folder/..%2F..%2F..%2Fetc%2Fpasswd", + expectError: false, + }, + { + name: "Double dot hidden in path", + input: "folder/.../.../etc/passwd", + basePath: "/base", + want: "/base/folder/.../.../etc/passwd", + expectError: false, + }, + + // Edge cases + { + name: "Multiple sequential slashes", + input: "folder///subfolder////file.txt", + basePath: "/base", + want: "", + expectError: true, + }, + { + name: "Unicode characters in path", + input: "фольдер/файл.txt", + basePath: "/base", + want: "/base/фольдер/файл.txt", + expectError: false, + }, + { + name: "Path with spaces and special characters", + input: "my folder/my file!@#$%.txt", + basePath: "/base", + want: "/base/my folder/my file!@#$%.txt", + expectError: false, + }, + { + name: "Very long path", + input: "a/b/c/d/e/f/g/h/i/j/k/l/m/n/o/p/q/r/s/t/u/v/w/x/y/z/file.txt", + basePath: "/base", + want: "/base/a/b/c/d/e/f/g/h/i/j/k/l/m/n/o/p/q/r/s/t/u/v/w/x/y/z/file.txt", + expectError: false, + }, + + // Base path variations + { + name: "Empty base path", + input: "file.txt", + basePath: "", + want: "file.txt", + expectError: false, + }, + { + name: "Relative base path", + input: "file.txt", + basePath: "base/folder", + want: "base/folder/file.txt", + expectError: false, + }, + { + name: "Base path with trailing slash", + input: "file.txt", + basePath: "/base/", + want: "/base/file.txt", + expectError: false, + }, + + // Symbolic link-like paths (if supported) + { + name: "Path with symbolic link-like components", + input: "folder/symlink/../file.txt", + basePath: "/base", + want: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := calculateLocalPath(tt.input, tt.basePath) + + // Check error expectation + if (err != nil) != tt.expectError { + t.Errorf("calculateLocalPath() error = %v, expectError = %v", err, tt.expectError) + return + } + + // If we expect an error, don't check the returned path + if tt.expectError { + return + } + + // Check if the returned path matches expected + if got != tt.want { + t.Errorf("calculateLocalPath() = %v, want %v", got, tt.want) + } + + // Additional security checks for non-error cases + if !tt.expectError { + // Verify the returned path is within base path + if !isWithinBasePath(got, tt.basePath) { + t.Errorf("Result path %v escapes base path %v", got, tt.basePath) + } + + // Verify no '..' components in final path + if containsParentRef(got) { + t.Errorf("Result path %v contains parent references", got) + } + } + }) + } +} + +// Helper function to check if a path is contained within the base path +func isWithinBasePath(path, basePath string) bool { + if basePath == "" { + return true + } + + absBase, err := filepath.Abs(basePath) + if err != nil { + return false + } + + absPath, err := filepath.Abs(path) + if err != nil { + return false + } + + rel, err := filepath.Rel(absBase, absPath) + if err != nil { + return false + } + + return !strings.HasPrefix(rel, "..") +} + +// Helper function to check if a path contains parent directory references +func containsParentRef(path string) bool { + parts := strings.Split(filepath.Clean(path), string(filepath.Separator)) + for _, part := range parts { + if part == ".." { + return true + } + } + return false +} diff --git a/uid/uid.go b/uid/uid.go new file mode 100644 index 0000000..b98e342 --- /dev/null +++ b/uid/uid.go @@ -0,0 +1,14 @@ +package uid + +import ( + nanoid "github.com/matoous/go-nanoid/v2" +) + +func UID() string { + // No 'o','O' and 'l' + id, err := nanoid.Generate("abcdefghijkmnpqrstuvwxyzABCDEFGHIJKLMNPQRSTUVWXYZ0123456789", 20) + if err != nil { + panic(err) + } + return id +}