From b52df073e95619058ea4d8b723eea632a13e4dd7 Mon Sep 17 00:00:00 2001 From: antanst Date: Fri, 27 Dec 2024 12:09:55 +0200 Subject: [PATCH] Add first version of gemini-grc. --- bin/normalizeSnapshot/main.go | 116 +++++++++++ config/config.go | 156 ++++++++++++++ config/errors.go | 14 ++ gemini/blacklist.go | 50 +++++ gemini/connectionPool.go | 28 +++ gemini/db.go | 176 ++++++++++++++++ gemini/db_queries.go | 78 +++++++ gemini/errors.go | 100 +++++++++ gemini/errors_test.go | 24 +++ gemini/files.go | 113 +++++++++++ gemini/gemini.go | 139 +++++++++++++ gemini/gemini_test.go | 65 ++++++ gemini/gemini_url.go | 229 +++++++++++++++++++++ gemini/gemini_url_test.go | 223 ++++++++++++++++++++ gemini/ip-address-pool.go | 54 +++++ gemini/network.go | 244 ++++++++++++++++++++++ gemini/network_test.go | 78 +++++++ gemini/processing.go | 59 ++++++ gemini/processing_test.go | 14 ++ gemini/robotmatch.go | 82 ++++++++ gemini/robots.go | 31 +++ gemini/robots_test.go | 55 +++++ gemini/snapshot.go | 42 ++++ gemini/worker.go | 368 ++++++++++++++++++++++++++++++++++ go.mod | 24 +++ go.sum | 59 ++++++ logging/logging.go | 23 +++ main.go | 60 ++++++ uid/uid.go | 14 ++ util/util.go | 36 ++++ 30 files changed, 2754 insertions(+) create mode 100644 bin/normalizeSnapshot/main.go create mode 100644 config/config.go create mode 100644 config/errors.go create mode 100644 gemini/blacklist.go create mode 100644 gemini/connectionPool.go create mode 100644 gemini/db.go create mode 100644 gemini/db_queries.go create mode 100644 gemini/errors.go create mode 100644 gemini/errors_test.go create mode 100644 gemini/files.go create mode 100644 gemini/gemini.go create mode 100644 gemini/gemini_test.go create mode 100644 gemini/gemini_url.go create mode 100644 gemini/gemini_url_test.go create mode 100644 gemini/ip-address-pool.go create mode 100644 gemini/network.go create mode 100644 gemini/network_test.go create mode 100644 gemini/processing.go create mode 100644 gemini/processing_test.go create mode 100644 gemini/robotmatch.go create mode 100644 gemini/robots.go create mode 100644 gemini/robots_test.go create mode 100644 gemini/snapshot.go create mode 100644 gemini/worker.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 logging/logging.go create mode 100644 main.go create mode 100644 uid/uid.go create mode 100644 util/util.go diff --git a/bin/normalizeSnapshot/main.go b/bin/normalizeSnapshot/main.go new file mode 100644 index 0000000..fb6fea7 --- /dev/null +++ b/bin/normalizeSnapshot/main.go @@ -0,0 +1,116 @@ +package main + +import ( + "fmt" + "os" + + "gemini-grc/gemini" + _ "github.com/jackc/pgx/v5/stdlib" // PGX driver for PostgreSQL + "github.com/jmoiron/sqlx" +) + +// Populates the `host` field +func main() { + db := connectToDB() + count := 0 + + for { + tx := db.MustBegin() + query := ` + SELECT * FROM snapshots + ORDER BY id + LIMIT 10000 OFFSET $1 + ` + var snapshots []gemini.Snapshot + err := tx.Select(&snapshots, query, count) + if err != nil { + printErrorAndExit(tx, err) + } + if len(snapshots) == 0 { + fmt.Println("Done!") + return + } + for _, s := range snapshots { + count++ + escaped := gemini.EscapeURL(s.URL.String()) + normalizedGeminiURL, err := gemini.ParseURL(escaped, "") + if err != nil { + fmt.Println(s.URL.String()) + fmt.Println(escaped) + printErrorAndExit(tx, err) + } + normalizedURLString := normalizedGeminiURL.String() + // If URL is already normalized, skip snapshot + if normalizedURLString == s.URL.String() { + // fmt.Printf("[%5d] Skipping %d %s\n", count, s.ID, 
s.URL.String()) + continue + } + // If a snapshot already exists with the normalized + // URL, delete the current snapshot and leave the other. + var ss []gemini.Snapshot + err = tx.Select(&ss, "SELECT * FROM snapshots WHERE URL=$1", normalizedURLString) + if err != nil { + printErrorAndExit(tx, err) + } + if len(ss) > 0 { + tx.MustExec("DELETE FROM snapshots WHERE id=$1", s.ID) + fmt.Printf("%d Deleting %d %s\n", count, s.ID, s.URL.String()) + //err = tx.Commit() + //if err != nil { + // printErrorAndExit(tx, err) + //} + //return + continue + } + // fmt.Printf("%s =>\n%s\n", s.URL.String(), normalizedURLString) + // At this point we just update the snapshot, + // and the normalized URL will be saved. + fmt.Printf("%d Updating %d %s => %s\n", count, s.ID, s.URL.String(), normalizedURLString) + // Saves the snapshot with the normalized URL + tx.MustExec("DELETE FROM snapshots WHERE id=$1", s.ID) + s.URL = *normalizedGeminiURL + err = gemini.UpsertSnapshot(0, tx, &s) + if err != nil { + printErrorAndExit(tx, err) + } + //err = tx.Commit() + //if err != nil { + // printErrorAndExit(tx, err) + //} + //return + } + err = tx.Commit() + if err != nil { + printErrorAndExit(tx, err) + } + } +} + +func printErrorAndExit(tx *sqlx.Tx, err error) { + _ = tx.Rollback() + panic(err) +} + +func connectToDB() *sqlx.DB { + connStr := fmt.Sprintf("postgres://%s:%s@%s:%s/%s", + os.Getenv("PG_USER"), + os.Getenv("PG_PASSWORD"), + os.Getenv("PG_HOST"), + os.Getenv("PG_PORT"), + os.Getenv("PG_DATABASE"), + ) + + // Create a connection pool + db, err := sqlx.Open("pgx", connStr) + if err != nil { + panic(fmt.Sprintf("Unable to connect to database with URL %s: %v\n", connStr, err)) + } + db.SetMaxOpenConns(20) + err = db.Ping() + if err != nil { + panic(fmt.Sprintf("Unable to ping database: %v\n", err)) + } + + fmt.Println("Connected to database") + return db +} diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..cc2d430 --- /dev/null +++ b/config/config.go @@ -0,0 +1,156 @@ +package config + +import ( + "fmt" + "os" + "strconv" + + "github.com/rs/zerolog" +) + +// Environment variable names. +const ( + EnvLogLevel = "LOG_LEVEL" + EnvNumWorkers = "NUM_OF_WORKERS" + EnvWorkerBatchSize = "WORKER_BATCH_SIZE" + EnvMaxResponseSize = "MAX_RESPONSE_SIZE" + EnvResponseTimeout = "RESPONSE_TIMEOUT" + EnvPanicOnUnexpectedError = "PANIC_ON_UNEXPECTED_ERROR" + EnvBlacklistPath = "BLACKLIST_PATH" + EnvDryRun = "DRY_RUN" +) + +// Config holds the application configuration loaded from environment variables. +type Config struct { + LogLevel zerolog.Level // Logging level (debug, info, warn, error) + MaxResponseSize int // Maximum size of response in bytes + NumOfWorkers int // Number of concurrent workers + ResponseTimeout int // Timeout for responses in seconds + WorkerBatchSize int // Batch size for worker processing + PanicOnUnexpectedError bool // Panic on unexpected errors when visiting a URL + BlacklistPath string // File that has blacklisted strings of "host:port" + DryRun bool // If false, don't write to disk +} + +var CONFIG Config //nolint:gochecknoglobals + +// parsePositiveInt parses and validates positive integer values. 
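+//
+// For example (values illustrative), showing accepted and rejected inputs:
+//
+//	v, err := parsePositiveInt(EnvNumWorkers, "8")   // v == 8, err == nil
+//	_, err = parsePositiveInt(EnvNumWorkers, "0")    // ValidationError: "must be positive"
+//	_, err = parsePositiveInt(EnvNumWorkers, "four") // ValidationError: "must be a valid integer"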
+func parsePositiveInt(param, value string) (int, error) { + val, err := strconv.Atoi(value) + if err != nil { + return 0, ValidationError{ + Param: param, + Value: value, + Reason: "must be a valid integer", + } + } + if val <= 0 { + return 0, ValidationError{ + Param: param, + Value: value, + Reason: "must be positive", + } + } + return val, nil +} + +func parseBool(param, value string) (bool, error) { + val, err := strconv.ParseBool(value) + if err != nil { + return false, ValidationError{ + Param: param, + Value: value, + Reason: "cannot be converted to boolean", + } + } + return val, nil +} + +// GetConfig loads and validates configuration from environment variables +func GetConfig() *Config { + config := &Config{} + + // Map of environment variables to their parsing functions + parsers := map[string]func(string) error{ + EnvLogLevel: func(v string) error { + level, err := zerolog.ParseLevel(v) + if err != nil { + return ValidationError{ + Param: EnvLogLevel, + Value: v, + Reason: "must be one of: debug, info, warn, error", + } + } + config.LogLevel = level + return nil + }, + EnvNumWorkers: func(v string) error { + val, err := parsePositiveInt(EnvNumWorkers, v) + if err != nil { + return err + } + config.NumOfWorkers = val + return nil + }, + EnvWorkerBatchSize: func(v string) error { + val, err := parsePositiveInt(EnvWorkerBatchSize, v) + if err != nil { + return err + } + config.WorkerBatchSize = val + return nil + }, + EnvMaxResponseSize: func(v string) error { + val, err := parsePositiveInt(EnvMaxResponseSize, v) + if err != nil { + return err + } + config.MaxResponseSize = val + return nil + }, + EnvResponseTimeout: func(v string) error { + val, err := parsePositiveInt(EnvResponseTimeout, v) + if err != nil { + return err + } + config.ResponseTimeout = val + return nil + }, + EnvPanicOnUnexpectedError: func(v string) error { + val, err := parseBool(EnvPanicOnUnexpectedError, v) + if err != nil { + return err + } + config.PanicOnUnexpectedError = val + return nil + }, + EnvBlacklistPath: func(v string) error { + config.BlacklistPath = v + return nil + }, + EnvDryRun: func(v string) error { + val, err := parseBool(EnvDryRun, v) + if err != nil { + return err + } + config.DryRun = val + return nil + }, + } + + // Process each environment variable + for envVar, parser := range parsers { + value, ok := os.LookupEnv(envVar) + if !ok { + fmt.Fprintf(os.Stderr, "Missing required environment variable: %s\n", envVar) + os.Exit(1) + } + + if err := parser(value); err != nil { + fmt.Fprintf(os.Stderr, "Configuration error: %v\n", err) + os.Exit(1) + } + } + + return config +} diff --git a/config/errors.go b/config/errors.go new file mode 100644 index 0000000..60482d7 --- /dev/null +++ b/config/errors.go @@ -0,0 +1,14 @@ +package config + +import "fmt" + +// ValidationError represents a config validation error +type ValidationError struct { + Param string + Value string + Reason string +} + +func (e ValidationError) Error() string { + return fmt.Sprintf("invalid value '%s' for %s: %s", e.Value, e.Param, e.Reason) +} diff --git a/gemini/blacklist.go b/gemini/blacklist.go new file mode 100644 index 0000000..6b4148c --- /dev/null +++ b/gemini/blacklist.go @@ -0,0 +1,50 @@ +package gemini + +import ( + "fmt" + "os" + "strings" + + "gemini-grc/config" + "gemini-grc/logging" +) + +var Blacklist *[]string //nolint:gochecknoglobals + +func LoadBlacklist() { + if Blacklist == nil { + data, err := os.ReadFile(config.CONFIG.BlacklistPath) + if err != nil { + Blacklist = &[]string{} + 
logging.LogWarn("Could not load Blacklist file: %v", err) + return + } + lines := strings.Split(string(data), "\n") + + // Ignore lines starting with '#' (comments) + filteredLines := func() []string { + out := make([]string, 0, len(lines)) + for _, line := range lines { + if !strings.HasPrefix(line, "#") { + out = append(out, line) + } + } + return out + }() + + if len(lines) > 0 { + Blacklist = &filteredLines + logging.LogInfo("Blacklist has %d entries", len(*Blacklist)) + } + } +} + +func IsBlacklisted(url URL) bool { + hostWithPort := fmt.Sprintf("%s:%d", url.Hostname, url.Port) + for _, v := range *Blacklist { + if v == url.Hostname || v == hostWithPort { + return true + } + } + return false +} diff --git a/gemini/connectionPool.go b/gemini/connectionPool.go new file mode 100644 index 0000000..72149c7 --- /dev/null +++ b/gemini/connectionPool.go @@ -0,0 +1,28 @@ +package gemini + +import ( + "gemini-grc/logging" +) + +var IPPool = IpAddressPool{IPs: make(map[string]int)} + +func AddIPsToPool(ips []string) { + IPPool.Lock.Lock() + for _, ip := range ips { + logging.LogDebug("Adding %s to pool", ip) + IPPool.IPs[ip] = 1 + } + IPPool.Lock.Unlock() +} + +func RemoveIPsFromPool(IPs []string) { + IPPool.Lock.Lock() + for _, ip := range IPs { + _, ok := IPPool.IPs[ip] + if ok { + logging.LogDebug("Removing %s from pool", ip) + delete(IPPool.IPs, ip) + } + } + IPPool.Lock.Unlock() +} diff --git a/gemini/db.go b/gemini/db.go new file mode 100644 index 0000000..d35a3f5 --- /dev/null +++ b/gemini/db.go @@ -0,0 +1,176 @@ +package gemini + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "strconv" + + "gemini-grc/config" + "gemini-grc/logging" + _ "github.com/jackc/pgx/v5/stdlib" // PGX driver for PostgreSQL + "github.com/jmoiron/sqlx" + "github.com/lib/pq" +) + +func ConnectToDB() *sqlx.DB { + connStr := fmt.Sprintf("postgres://%s:%s@%s:%s/%s", //nolint:nosprintfhostport + os.Getenv("PG_USER"), + os.Getenv("PG_PASSWORD"), + os.Getenv("PG_HOST"), + os.Getenv("PG_PORT"), + os.Getenv("PG_DATABASE"), + ) + + // Create a connection pool + db, err := sqlx.Open("pgx", connStr) + if err != nil { + panic(fmt.Sprintf("Unable to connect to database with URL %s: %v\n", connStr, err)) + } + // TODO move PG_MAX_OPEN_CONNECTIONS to config env variables + maxConnections, err := strconv.Atoi(os.Getenv("PG_MAX_OPEN_CONNECTIONS")) + if err != nil { + panic(fmt.Sprintf("Unable to set max DB connections: %s\n", err)) + } + db.SetMaxOpenConns(maxConnections) + err = db.Ping() + if err != nil { + panic(fmt.Sprintf("Unable to ping database: %v\n", err)) + } + + logging.LogDebug("Connected to database") + return db +} + +// isDeadlockError checks if the error is a PostgreSQL deadlock error +func isDeadlockError(err error) bool { + var pqErr *pq.Error + if errors.As(err, &pqErr) { + return pqErr.Code == "40P01" // PostgreSQL deadlock error code + } + return false +} + +func GetSnapshotsToVisit(tx *sqlx.Tx) ([]Snapshot, error) { + var snapshots []Snapshot + err := tx.Select(&snapshots, SQL_SELECT_UNVISITED_SNAPSHOTS_UNIQUE_HOSTS, config.CONFIG.WorkerBatchSize) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrDatabase, err) + } + return snapshots, nil +} + +func SaveSnapshotIfNew(tx *sqlx.Tx, s *Snapshot) error { + if config.CONFIG.DryRun { + marshalled, err := json.MarshalIndent(s, "", " ") + if err != nil { + panic(fmt.Sprintf("JSON serialization error for %v", s)) + } + logging.LogDebug("Would insert (if new) snapshot %s", marshalled) + return nil + } + query := SQL_INSERT_SNAPSHOT_IF_NEW + _, err := 
tx.NamedExec(query, s) + if err != nil { + return fmt.Errorf("[%s] GeminiError inserting snapshot: %w", s.URL, err) + } + return nil +} + +func UpsertSnapshot(workedID int, tx *sqlx.Tx, s *Snapshot) (err error) { + // if config.CONFIG.DryRun { + //marshalled, err := json.MarshalIndent(s, "", " ") + //if err != nil { + // panic(fmt.Sprintf("JSON serialization error for %v", s)) + //} + //logging.LogDebug("[%d] Would upsert snapshot %s", workedID, marshalled) + // return nil + // } + query := SQL_UPSERT_SNAPSHOT + rows, err := tx.NamedQuery(query, s) + if err != nil { + return fmt.Errorf("[%d] %w while upserting snapshot: %w", workedID, ErrDatabase, err) + } + defer func() { + _err := rows.Close() + if _err != nil { + err = fmt.Errorf("[%d] %w error closing rows: %w", workedID, ErrDatabase, _err) + } + }() + if rows.Next() { + var returnedID int + err = rows.Scan(&returnedID) + if err != nil { + return fmt.Errorf("[%d] %w error scanning returned id: %w", workedID, ErrDatabase, err) + } + s.ID = returnedID + // logging.LogDebug("[%d] Upserted snapshot with ID %d", workedID, returnedID) + } + return nil +} + +func UpdateSnapshot(workedID int, tx *sqlx.Tx, s *Snapshot) (err error) { + // if config.CONFIG.DryRun { + //marshalled, err := json.MarshalIndent(s, "", " ") + //if err != nil { + // panic(fmt.Sprintf("JSON serialization error for %v", s)) + //} + //logging.LogDebug("[%d] Would upsert snapshot %s", workedID, marshalled) + // return nil + // } + query := SQL_UPDATE_SNAPSHOT + rows, err := tx.NamedQuery(query, s) + if err != nil { + return fmt.Errorf("[%d] %w while updating snapshot: %w", workedID, ErrDatabase, err) + } + defer func() { + _err := rows.Close() + if _err != nil { + err = fmt.Errorf("[%d] %w error closing rows: %w", workedID, ErrDatabase, _err) + } + }() + if rows.Next() { + var returnedID int + err = rows.Scan(&returnedID) + if err != nil { + return fmt.Errorf("[%d] %w error scanning returned id: %w", workedID, ErrDatabase, err) + } + s.ID = returnedID + // logging.LogDebug("[%d] Updated snapshot with ID %d", workedID, returnedID) + } + return nil +} + +func SaveLinksToDBinBatches(tx *sqlx.Tx, snapshots []*Snapshot) error { + if config.CONFIG.DryRun { + return nil + } + const batchSize = 5000 + query := SQL_INSERT_SNAPSHOT_IF_NEW + for i := 0; i < len(snapshots); i += batchSize { + end := i + batchSize + if end > len(snapshots) { + end = len(snapshots) + } + batch := snapshots[i:end] + _, err := tx.NamedExec(query, batch) + if err != nil { + return fmt.Errorf("%w: While saving links in batches: %w", ErrDatabase, err) + } + } + return nil +} + +func SaveLinksToDB(tx *sqlx.Tx, snapshots []*Snapshot) error { + if config.CONFIG.DryRun { + return nil + } + query := SQL_INSERT_SNAPSHOT_IF_NEW + _, err := tx.NamedExec(query, snapshots) + if err != nil { + logging.LogError("GeminiError batch inserting snapshots: %w", err) + return fmt.Errorf("DB error: %w", err) + } + return nil +} diff --git a/gemini/db_queries.go b/gemini/db_queries.go new file mode 100644 index 0000000..4ef07a5 --- /dev/null +++ b/gemini/db_queries.go @@ -0,0 +1,78 @@ +package gemini + +const ( + SQL_SELECT_RANDOM_UNVISITED_SNAPSHOTS = ` +SELECT * +FROM snapshots +WHERE response_code IS NULL + AND error IS NULL +ORDER BY RANDOM() +FOR UPDATE SKIP LOCKED +LIMIT $1 + ` + SQL_SELECT_RANDOM_UNVISITED_SNAPSHOTS_UNIQUE_HOSTS = ` +SELECT * +FROM snapshots s +WHERE response_code IS NULL + AND error IS NULL + AND s.id IN ( + SELECT MIN(id) + FROM snapshots + WHERE response_code IS NULL + AND error IS NULL + GROUP BY host 
+ ) +ORDER BY RANDOM() +FOR UPDATE SKIP LOCKED +LIMIT $1 +` + SQL_SELECT_UNVISITED_SNAPSHOTS_UNIQUE_HOSTS = ` +SELECT * +FROM snapshots s +WHERE response_code IS NULL + AND error IS NULL + AND s.id IN ( + SELECT MIN(id) + FROM snapshots + WHERE response_code IS NULL + AND error IS NULL + GROUP BY host + ) +FOR UPDATE SKIP LOCKED +LIMIT $1 +` + SQL_INSERT_SNAPSHOT_IF_NEW = ` + INSERT INTO snapshots (url, host, timestamp, mimetype, data, gemtext, links, lang, response_code, error) + VALUES (:url, :host, :timestamp, :mimetype, :data, :gemtext, :links, :lang, :response_code, :error) + ON CONFLICT (url) DO NOTHING + ` + SQL_UPSERT_SNAPSHOT = `INSERT INTO snapshots (url, host, timestamp, mimetype, data, gemtext, links, lang, response_code, error) + VALUES (:url, :host, :timestamp, :mimetype, :data, :gemtext, :links, :lang, :response_code, :error) + ON CONFLICT (url) DO UPDATE SET + url = EXCLUDED.url, + host = EXCLUDED.host, + timestamp = EXCLUDED.timestamp, + mimetype = EXCLUDED.mimetype, + data = EXCLUDED.data, + gemtext = EXCLUDED.gemtext, + links = EXCLUDED.links, + lang = EXCLUDED.lang, + response_code = EXCLUDED.response_code, + error = EXCLUDED.error + RETURNING id +` + SQL_UPDATE_SNAPSHOT = `UPDATE snapshots +SET url = :url, +host = :host, +timestamp = :timestamp, +mimetype = :mimetype, +data = :data, +gemtext = :gemtext, +links = :links, +lang = :lang, +response_code = :response_code, +error = :error +WHERE id = :id +RETURNING id +` +) diff --git a/gemini/errors.go b/gemini/errors.go new file mode 100644 index 0000000..5ae014c --- /dev/null +++ b/gemini/errors.go @@ -0,0 +1,100 @@ +package gemini + +import ( + "errors" + "fmt" +) + +type GeminiError struct { + Msg string + Code int + Header string +} + +func (e *GeminiError) Error() string { + return fmt.Sprintf("%s: %s", e.Msg, e.Header) +} + +func NewErrGeminiStatusCode(code int, header string) error { + var msg string + switch { + case code >= 10 && code < 20: + msg = "needs input" + case code >= 30 && code < 40: + msg = "redirect" + case code >= 40 && code < 50: + msg = "bad request" + case code >= 50 && code < 60: + msg = "server error" + case code >= 60 && code < 70: + msg = "TLS error" + default: + msg = "unexpected status code" + } + return &GeminiError{ + Msg: msg, + Code: code, + Header: header, + } +} + +var ( + ErrGeminiRobotsParse = errors.New("gemini robots.txt parse error") + ErrGeminiRobotsDisallowed = errors.New("gemini robots.txt disallowed") + ErrGeminiResponseHeader = errors.New("gemini response header error") + ErrGeminiRedirect = errors.New("gemini redirection error") + ErrGeminiLinkLineParse = errors.New("gemini link line parse error") + + ErrURLParse = errors.New("URL parse error") + ErrURLNotGemini = errors.New("not a Gemini URL") + ErrURLDecode = errors.New("URL decode error") + ErrUTF8Parse = errors.New("UTF-8 parse error") + ErrTextParse = errors.New("text parse error") + + ErrNetwork = errors.New("network error") + ErrNetworkDNS = errors.New("network DNS error") + ErrNetworkTLS = errors.New("network TLS error") + ErrNetworkSetConnectionDeadline = errors.New("network error - cannot set connection deadline") + ErrNetworkCannotWrite = errors.New("network error - cannot write") + ErrNetworkResponseSizeExceededMax = errors.New("network error - response size exceeded maximum size") + + ErrDatabase = errors.New("database error") +) + +// We could have used a map for speed, but +// we would lose ability to check wrapped +// errors via errors.Is(). 
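+//
+// A minimal sketch of how these sentinels are meant to be classified,
+// assuming they are wrapped with %w as the rest of this package does:
+//
+//	err := fmt.Errorf("visiting page: %w", ErrNetworkTLS)
+//	IsKnownError(err)           // true: errors.Is finds the wrapped sentinel
+//	errors.Is(err, ErrDatabase) // false
+//	IsKnownError(NewErrGeminiStatusCode(51, "51 not found")) // true, via errors.As on *GeminiError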
+ +var errGemini *GeminiError + +var knownErrors = []error{ //nolint:gochecknoglobals + errGemini, + ErrGeminiLinkLineParse, + ErrGeminiRobotsParse, + ErrGeminiRobotsDisallowed, + ErrGeminiResponseHeader, + ErrGeminiRedirect, + + ErrURLParse, + ErrURLDecode, + ErrUTF8Parse, + ErrTextParse, + + ErrNetwork, + ErrNetworkDNS, + ErrNetworkTLS, + ErrNetworkSetConnectionDeadline, + ErrNetworkCannotWrite, + ErrNetworkResponseSizeExceededMax, + + ErrDatabase, +} + +func IsKnownError(err error) bool { + for _, known := range knownErrors { + if errors.Is(err, known) { + return true + } + } + return errors.As(err, new(*GeminiError)) +} diff --git a/gemini/errors_test.go b/gemini/errors_test.go new file mode 100644 index 0000000..e9ba76f --- /dev/null +++ b/gemini/errors_test.go @@ -0,0 +1,24 @@ +package gemini + +import ( + "errors" + "fmt" + "testing" +) + +func TestErrGemini(t *testing.T) { + t.Parallel() + err := NewErrGeminiStatusCode(50, "50 server error") + if !errors.As(err, new(*GeminiError)) { + t.Errorf("TestErrGemini fail") + } +} + +func TestErrGeminiWrapped(t *testing.T) { + t.Parallel() + err := NewErrGeminiStatusCode(50, "50 server error") + errWrapped := fmt.Errorf("%w wrapped", err) + if !errors.As(errWrapped, new(*GeminiError)) { + t.Errorf("TestErrGeminiWrapped fail") + } +} diff --git a/gemini/files.go b/gemini/files.go new file mode 100644 index 0000000..84b3012 --- /dev/null +++ b/gemini/files.go @@ -0,0 +1,113 @@ +package gemini + +import ( + "fmt" + "net/url" + "os" + "path" + "path/filepath" + "strings" + + "gemini-grc/logging" +) + +// sanitizePath encodes invalid filesystem characters using URL encoding. +// Example: +// /example/path/to/page?query=param&another=value +// would become +// example/path/to/page%3Fquery%3Dparam%26another%3Dvalue +func sanitizePath(p string) string { + // Split the path into its components + components := strings.Split(p, "/") + + // Encode each component separately + for i, component := range components { + // Decode any existing percent-encoded characters + decodedComponent, err := url.PathUnescape(component) + if err != nil { + decodedComponent = component // Fallback to original if unescape fails + } + + // Encode the component to escape invalid filesystem characters + encodedComponent := url.QueryEscape(decodedComponent) + + // Replace '+' (from QueryEscape) with '%20' to handle spaces correctly + encodedComponent = strings.ReplaceAll(encodedComponent, "+", "%20") + + components[i] = encodedComponent + } + + // Rejoin the components into a sanitized path + safe := filepath.Join(components...) + + return safe +} + +// getFilePath constructs a safe file path from the root path and URL path. +// It URL-encodes invalid filesystem characters to ensure the path is valid. +func calcFilePath(rootPath, urlPath string) (string, error) { + // Normalize the URL path + cleanPath := filepath.Clean(urlPath) + + // Safe check to prevent directory traversal + if strings.Contains(cleanPath, "..") { + return "", fmt.Errorf("Invalid URL path: contains directory traversal") + } + + // Sanitize the path by encoding invalid characters + safePath := sanitizePath(cleanPath) + + // Join the root path and the sanitized URL path + finalPath := filepath.Join(rootPath, safePath) + + return finalPath, nil +} + +func SaveToFile(rootPath string, s *Snapshot, done chan struct{}) { + parentPath := path.Join(rootPath, s.URL.Hostname) + urlPath := s.URL.Path + // If path is empty, add `index.gmi` as the file to save + if urlPath == "" || urlPath == "." 
{ + urlPath = "index.gmi" + } + // If path ends with '/' then add index.gmi for the + // directory to be created. + if strings.HasSuffix(urlPath, "/") { + urlPath = strings.Join([]string{urlPath, "index.gmi"}, "") + } + + finalPath, err := calcFilePath(parentPath, urlPath) + if err != nil { + logging.LogError("GeminiError saving %s: %w", s.URL, err) + return + } + // Ensure the directory exists + dir := filepath.Dir(finalPath) + if err := os.MkdirAll(dir, os.ModePerm); err != nil { + logging.LogError("Failed to create directory: %w", err) + return + } + if s.MimeType.Valid && s.MimeType.String == "text/gemini" { + err = os.WriteFile(finalPath, (*s).Data.V, 0o666) + } else { + err = os.WriteFile(finalPath, []byte((*s).GemText.String), 0o666) + } + if err != nil { + logging.LogError("GeminiError saving %s: %w", s.URL.Full, err) + } + close(done) +} + +func ReadLines(path string) []string { + data, err := os.ReadFile(path) + if err != nil { + panic(fmt.Sprintf("Failed to read file: %s", err)) + } + lines := strings.Split(string(data), "\n") + // Remove last line if empty + // (happens when file ends with '\n') + if lines[len(lines)-1] == "" { + lines = lines[:len(lines)-1] + } + return lines +} diff --git a/gemini/gemini.go b/gemini/gemini.go new file mode 100644 index 0000000..ab10222 --- /dev/null +++ b/gemini/gemini.go @@ -0,0 +1,139 @@ +package gemini + +import ( + "fmt" + "net/url" + "regexp" + "strconv" + + "gemini-grc/logging" +) + +func GetPageLinks(currentURL URL, gemtext string) LinkList { + // Grab link lines + linkLines := ExtractLinkLines(gemtext) + if len(linkLines) == 0 { + return nil + } + var linkURLs LinkList + // Normalize URLs in links, and store them in snapshot + for _, line := range linkLines { + linkURL, err := NormalizeLink(line, currentURL.String()) + if err != nil { + logging.LogDebug("%s: %s", ErrGeminiLinkLineParse, err) + continue + } + linkURLs = append(linkURLs, *linkURL) + } + return linkURLs +} + +// ExtractLinkLines takes a Gemtext document as a string and returns all lines that are link lines +func ExtractLinkLines(gemtext string) []string { + // Define the regular expression pattern to match link lines + re := regexp.MustCompile(`(?m)^=>[ \t]+.*`) + + // Find all matches using the regular expression + matches := re.FindAllString(gemtext, -1) + + return matches +} + +// NormalizeLink takes a single link line and the current URL, +// return the URL converted to an absolute URL +// and its description. 
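+//
+// For example (illustrative), with currentURL "gemini://example.org/blog/":
+//
+//	"=> post1.gmi First post"       -> Full "gemini://example.org:1965/blog/post1.gmi", Descr "First post"
+//	"=> /about.gmi About"           -> Full "gemini://example.org:1965/about.gmi", Descr "About"
+//	"=> gemini://other.net/ Mirror" -> Full "gemini://other.net:1965/", Descr "Mirror"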
+func NormalizeLink(linkLine string, currentURL string) (*URL, error) { + // Parse the current URL + baseURL, err := url.Parse(currentURL) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrURLParse, err) + } + + // Regular expression to extract the URL part from a link line + re := regexp.MustCompile(`^=>[ \t]+(\S+)([ \t]+.*)?`) + + // Use regex to extract the URL and the rest of the line + matches := re.FindStringSubmatch(linkLine) + if len(matches) == 0 { + // If the line doesn't match the expected format, return it unchanged + return nil, fmt.Errorf("%w for link line %s", ErrGeminiLinkLineParse, linkLine) + } + + originalURLStr := matches[1] + _, err = url.QueryUnescape(originalURLStr) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrURLDecode, err) + } + + restOfLine := "" + if len(matches) > 2 { + restOfLine = matches[2] + } + + // Parse the URL from the link line + parsedURL, err := url.Parse(originalURLStr) + if err != nil { + // If URL parsing fails, return an error + return nil, fmt.Errorf("%w: %w", ErrURLParse, err) + } + + // Resolve relative URLs against the base URL + if !parsedURL.IsAbs() { + parsedURL = baseURL.ResolveReference(parsedURL) + } + + // Remove usual first space from URL description: + // => URL description + // ^^^^^^^^^^^^ + if len(restOfLine) > 0 && restOfLine[0] == ' ' { + restOfLine = restOfLine[1:] + } + + finalURL, err := ParseURL(parsedURL.String(), restOfLine) + if err != nil { + // If URL parsing fails, return an error + return nil, fmt.Errorf("%w: %w", ErrURLParse, err) + } + + return finalURL, nil +} + +// ParseFirstTwoDigits takes a string and returns the first one or two digits as an int. +// If no valid digits are found, it returns an error. +func ParseFirstTwoDigits(input string) (int, error) { + // Define the regular expression pattern to match one or two leading digits + re := regexp.MustCompile(`^(\d{1,2})`) + + // Find the first match in the string + matches := re.FindStringSubmatch(input) + if len(matches) == 0 { + return 0, fmt.Errorf("%w", ErrGeminiResponseHeader) + } + + // Parse the captured match as an integer + snapshot, err := strconv.Atoi(matches[1]) + if err != nil { + return 0, fmt.Errorf("%w: %w", ErrTextParse, err) + } + + return snapshot, nil +} + +// extractRedirectTarget returns the redirection +// URL by parsing the header (or error message) +func extractRedirectTarget(currentURL URL, input string) (*URL, error) { + // \d+ - matches one or more digits + // \s+ - matches one or more whitespace + // ([^\r]+) - captures everything until it hits a \r (or end of string) + pattern := `\d+\s+([^\r]+)` + re := regexp.MustCompile(pattern) + matches := re.FindStringSubmatch(input) + if len(matches) < 2 { + return nil, fmt.Errorf("%w: %s", ErrGeminiRedirect, input) + } + newURL, err := DeriveAbsoluteURL(currentURL, matches[1]) + if err != nil { + return nil, fmt.Errorf("%w: %w: %s", ErrGeminiRedirect, err, input) + } + return newURL, nil +} diff --git a/gemini/gemini_test.go b/gemini/gemini_test.go new file mode 100644 index 0000000..df4a6c5 --- /dev/null +++ b/gemini/gemini_test.go @@ -0,0 +1,65 @@ +package gemini + +import "testing" + +func TestExtractRedirectTargetFullURL(t *testing.T) { + t.Parallel() + currentURL, _ := ParseURL("gemini://smol.gr", "") + input := "redirect: 31 gemini://target.gr" + result, err := extractRedirectTarget(*currentURL, input) + expected := "gemini://target.gr:1965" + if err != nil || (result.String() != expected) { + t.Errorf("fail: Expected %s got %s", expected, result) + } +} + +func 
TestExtractRedirectTargetFullURLSlash(t *testing.T) { + t.Parallel() + currentURL, _ := ParseURL("gemini://smol.gr", "") + input := "redirect: 31 gemini://target.gr/" + result, err := extractRedirectTarget(*currentURL, input) + expected := "gemini://target.gr:1965/" + if err != nil || (result.String() != expected) { + t.Errorf("fail: Expected %s got %s", expected, result) + } +} + +func TestExtractRedirectTargetRelativeURL(t *testing.T) { + t.Parallel() + currentURL, _ := ParseURL("gemini://smol.gr", "") + input := "redirect: 31 /a/b" + result, err := extractRedirectTarget(*currentURL, input) + if err != nil || (result.String() != "gemini://smol.gr:1965/a/b") { + t.Errorf("fail: %s", result) + } +} + +func TestExtractRedirectTargetRelativeURL2(t *testing.T) { + t.Parallel() + currentURL, _ := ParseURL("gemini://nox.im:1965", "") + input := "redirect: 31 ./" + result, err := extractRedirectTarget(*currentURL, input) + if err != nil || (result.String() != "gemini://nox.im:1965/") { + t.Errorf("fail: %s", result) + } +} + +func TestExtractRedirectTargetRelativeURL3(t *testing.T) { + t.Parallel() + currentURL, _ := ParseURL("gemini://status.zvava.org:1965", "") + input := "redirect: 31 index.gmi" + result, err := extractRedirectTarget(*currentURL, input) + if err != nil || (result.String() != "gemini://status.zvava.org:1965/index.gmi") { + t.Errorf("fail: %s", result) + } +} + +func TestExtractRedirectTargetWrong(t *testing.T) { + t.Parallel() + currentURL, _ := ParseURL("gemini://smol.gr", "") + input := "redirect: 31" + result, err := extractRedirectTarget(*currentURL, input) + if result != nil || err == nil { + t.Errorf("fail: result should be nil, err is %s", err) + } +} diff --git a/gemini/gemini_url.go b/gemini/gemini_url.go new file mode 100644 index 0000000..0ab39d9 --- /dev/null +++ b/gemini/gemini_url.go @@ -0,0 +1,229 @@ +package gemini + +import ( + "database/sql/driver" + "fmt" + "net/url" + "path" + "strconv" + "strings" +) + +type URL struct { + Protocol string `json:"protocol,omitempty"` + Hostname string `json:"hostname,omitempty"` + Port int `json:"port,omitempty"` + Path string `json:"path,omitempty"` + Descr string `json:"descr,omitempty"` + Full string `json:"full,omitempty"` +} + +func (u *URL) Scan(value interface{}) error { + if value == nil { + // Clear the fields in the current GeminiUrl object (not the pointer itself) + *u = URL{} + return nil + } + b, ok := value.(string) + if !ok { + return fmt.Errorf("failed to scan GeminiUrl: expected string, got %T", value) + } + parsedURL, err := ParseURLNoNormalize(b, "") + if err != nil { + err = fmt.Errorf("failed to scan GeminiUrl %s: %v", b, err) + return err + } + *u = *parsedURL + return nil +} + +func (u URL) String() string { + return u.Full +} + +func (u URL) StringNoDefaultPort() string { + if u.Port == 1965 { + return fmt.Sprintf("%s://%s%s", u.Protocol, u.Hostname, u.Path) + } + return u.Full +} + +func (u URL) Value() (driver.Value, error) { + if u.Full == "" { + return nil, nil + } + return u.Full, nil +} + +func ParseURLNoNormalize(input string, descr string) (*URL, error) { + u, err := url.Parse(input) + if err != nil { + return nil, fmt.Errorf("%w: Input %s URL Parse Error: %w", ErrURLParse, input, err) + } + if u.Scheme != "gemini" { + return nil, fmt.Errorf("%w: URL scheme '%s' is not supported", ErrURLNotGemini, u.Scheme) + } + protocol := u.Scheme + hostname := u.Hostname() + strPort := u.Port() + urlPath := u.Path + if strPort == "" { + strPort = "1965" + } + port, err := strconv.Atoi(strPort) + if err != 
nil { + return nil, fmt.Errorf("%w: Input %s GeminiError %w", ErrURLParse, input, err) + } + full := fmt.Sprintf("%s://%s:%d%s", protocol, hostname, port, urlPath) + return &URL{Protocol: protocol, Hostname: hostname, Port: port, Path: urlPath, Descr: descr, Full: full}, nil +} + +func ParseURL(input string, descr string) (*URL, error) { + u, err := NormalizeURL(input) + if err != nil { + return nil, fmt.Errorf("%w: Input %s URL Parse Error: %w", ErrURLParse, input, err) + } + if u.Scheme != "gemini" { + return nil, fmt.Errorf("%w: URL scheme '%s' is not supported", ErrURLNotGemini, u.Scheme) + } + protocol := u.Scheme + hostname := u.Hostname() + strPort := u.Port() + urlPath := u.Path + if strPort == "" { + strPort = "1965" + } + port, err := strconv.Atoi(strPort) + if err != nil { + return nil, fmt.Errorf("%w: Input %s GeminiError %w", ErrURLParse, input, err) + } + full := fmt.Sprintf("%s://%s:%d%s", protocol, hostname, port, urlPath) + return &URL{Protocol: protocol, Hostname: hostname, Port: port, Path: urlPath, Descr: descr, Full: full}, nil +} + +// DeriveAbsoluteURL converts a (possibly) relative +// URL to an absolute one. Used primarily to calculate +// the full redirection URL target from a response header. +func DeriveAbsoluteURL(currentURL URL, input string) (*URL, error) { + // If target URL is absolute, return just it + if strings.Contains(input, "://") { + return ParseURL(input, "") + } + // input is a relative path. Clean it and construct absolute. + var newPath string + // Handle weird cases found in the wild + if strings.HasPrefix(input, "/") { + newPath = path.Clean(input) + } else if input == "./" || input == "." { + newPath = path.Join(currentURL.Path, "/") + } else { + newPath = path.Join(currentURL.Path, "/", path.Clean(input)) + } + strURL := fmt.Sprintf("%s://%s:%d%s", currentURL.Protocol, currentURL.Hostname, currentURL.Port, newPath) + return ParseURL(strURL, "") +} + +// NormalizeURL takes a URL string and returns a normalized version. +// Normalized meaning: +// - Path normalization (removing redundant slashes, . and .. segments) +// - Proper escaping of special characters +// - Lowercase scheme and host +// - Removal of default ports +// - Empty path becomes "/" +func NormalizeURL(rawURL string) (*url.URL, error) { + // Parse the URL + u, err := url.Parse(rawURL) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrURLParse, err) + } + + // Convert scheme to lowercase + u.Scheme = strings.ToLower(u.Scheme) + + // Convert hostname to lowercase + if u.Host != "" { + u.Host = strings.ToLower(u.Host) + } + + // Remove default ports + if u.Port() != "" { + switch { + case u.Scheme == "http" && u.Port() == "80": + u.Host = u.Hostname() + case u.Scheme == "https" && u.Port() == "443": + u.Host = u.Hostname() + case u.Scheme == "gemini" && u.Port() == "1965": + u.Host = u.Hostname() + } + } + + // Handle path normalization while preserving trailing slash + if u.Path != "" { + // Check if there was a trailing slash before cleaning + hadTrailingSlash := strings.HasSuffix(u.Path, "/") + + u.Path = path.Clean(u.Path) + // If path was "/", path.Clean() will return "." + if u.Path == "." 
{ + u.Path = "/" + } else if hadTrailingSlash && u.Path != "/" { + // Restore trailing slash if it existed and path isn't just "/" + u.Path += "/" + } + } + + // Properly escape the path + // First split on '/' to avoid escaping them + parts := strings.Split(u.Path, "/") + for i, part := range parts { + parts[i] = url.PathEscape(part) + } + u.Path = strings.Join(parts, "/") + + // Remove trailing fragment if empty + if u.Fragment == "" { + u.Fragment = "" + } + + // Remove trailing query if empty + if u.RawQuery == "" { + u.RawQuery = "" + } + + return u, nil +} + +func EscapeURL(input string) string { + // Only escape if not already escaped + if strings.Contains(input, "%") && !strings.Contains(input, "% ") { + return input + } + // Split URL into parts (protocol, host, path) + parts := strings.SplitN(input, "://", 2) + if len(parts) != 2 { + return input + } + + protocol := parts[0] + remainder := parts[1] + + // If URL ends with just a slash, return as is + if strings.HasSuffix(remainder, "/") && !strings.Contains(remainder[:len(remainder)-1], "/") { + return input + } + + // Split host and path + parts = strings.SplitN(remainder, "/", 2) + host := parts[0] + if len(parts) == 1 { + return protocol + "://" + host + } + + path := parts[1] + + // Escape the path portion + escapedPath := url.PathEscape(path) + + // Reconstruct the URL + return protocol + "://" + host + "/" + escapedPath +} diff --git a/gemini/gemini_url_test.go b/gemini/gemini_url_test.go new file mode 100644 index 0000000..36c4662 --- /dev/null +++ b/gemini/gemini_url_test.go @@ -0,0 +1,223 @@ +package gemini + +import ( + "reflect" + "testing" +) + +func TestParseURL(t *testing.T) { + t.Parallel() + input := "gemini://caolan.uk/cgi-bin/weather.py/wxfcs/3162" + parsed, err := ParseURL(input, "") + value, _ := parsed.Value() + if err != nil || !(value == "gemini://caolan.uk:1965/cgi-bin/weather.py/wxfcs/3162") { + t.Errorf("fail: %s", parsed) + } +} + +func TestDeriveAbsoluteURL_abs_url_input(t *testing.T) { + t.Parallel() + currentURL := URL{ + Protocol: "gemini", + Hostname: "smol.gr", + Port: 1965, + Path: "/a/b", + Descr: "Nothing", + Full: "gemini://smol.gr:1965/a/b", + } + input := "gemini://a.b/c" + output, err := DeriveAbsoluteURL(currentURL, input) + if err != nil { + t.Errorf("fail: %v", err) + } + expected := &URL{ + Protocol: "gemini", + Hostname: "a.b", + Port: 1965, + Path: "/c", + Descr: "", + Full: "gemini://a.b:1965/c", + } + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} + +func TestDeriveAbsoluteURL_abs_path_input(t *testing.T) { + t.Parallel() + currentURL := URL{ + Protocol: "gemini", + Hostname: "smol.gr", + Port: 1965, + Path: "/a/b", + Descr: "Nothing", + Full: "gemini://smol.gr:1965/a/b", + } + input := "/c" + output, err := DeriveAbsoluteURL(currentURL, input) + if err != nil { + t.Errorf("fail: %v", err) + } + expected := &URL{ + Protocol: "gemini", + Hostname: "smol.gr", + Port: 1965, + Path: "/c", + Descr: "", + Full: "gemini://smol.gr:1965/c", + } + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} + +func TestDeriveAbsoluteURL_rel_path_input(t *testing.T) { + t.Parallel() + currentURL := URL{ + Protocol: "gemini", + Hostname: "smol.gr", + Port: 1965, + Path: "/a/b", + Descr: "Nothing", + Full: "gemini://smol.gr:1965/a/b", + } + input := "c/d" + output, err := DeriveAbsoluteURL(currentURL, input) + if err != nil { + t.Errorf("fail: %v", err) + } + expected := &URL{ + 
Protocol: "gemini", + Hostname: "smol.gr", + Port: 1965, + Path: "/a/b/c/d", + Descr: "", + Full: "gemini://smol.gr:1965/a/b/c/d", + } + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} + +func TestNormalizeURLSlash(t *testing.T) { + t.Parallel() + input := "gemini://uscoffings.net/retro-computing/magazines/" + normalized, _ := NormalizeURL(input) + output := normalized.String() + expected := input + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} + +func TestNormalizeURLNoSlash(t *testing.T) { + t.Parallel() + input := "gemini://uscoffings.net/retro-computing/magazines" + normalized, _ := NormalizeURL(input) + output := normalized.String() + expected := input + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} + +func TestNormalizeMultiSlash(t *testing.T) { + t.Parallel() + input := "gemini://uscoffings.net/retro-computing/////////a///magazines" + normalized, _ := NormalizeURL(input) + output := normalized.String() + expected := "gemini://uscoffings.net/retro-computing/a/magazines" + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} + +func TestNormalizeTrailingSlash(t *testing.T) { + t.Parallel() + input := "gemini://uscoffings.net/" + normalized, _ := NormalizeURL(input) + output := normalized.String() + expected := "gemini://uscoffings.net/" + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} + +func TestNormalizeNoTrailingSlash(t *testing.T) { + t.Parallel() + input := "gemini://uscoffings.net" + normalized, _ := NormalizeURL(input) + output := normalized.String() + expected := "gemini://uscoffings.net" + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} + +func TestNormalizeTrailingSlashPath(t *testing.T) { + t.Parallel() + input := "gemini://uscoffings.net/a/" + normalized, _ := NormalizeURL(input) + output := normalized.String() + expected := "gemini://uscoffings.net/a/" + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} + +func TestNormalizeNoTrailingSlashPath(t *testing.T) { + t.Parallel() + input := "gemini://uscoffings.net/a" + normalized, _ := NormalizeURL(input) + output := normalized.String() + expected := "gemini://uscoffings.net/a" + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} + +func TestNormalizeDot(t *testing.T) { + t.Parallel() + input := "gemini://uscoffings.net/retro-computing/./././////a///magazines" + normalized, _ := NormalizeURL(input) + output := normalized.String() + expected := "gemini://uscoffings.net/retro-computing/a/magazines" + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} + +func TestNormalizePort(t *testing.T) { + t.Parallel() + input := "gemini://uscoffings.net:1965/a" + normalized, _ := NormalizeURL(input) + output := normalized.String() + expected := "gemini://uscoffings.net/a" + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} + +func TestNormalizeURL(t *testing.T) { + t.Parallel() + input := "gemini://chat.gemini.lehmann.cx:11965/" + normalized, _ := NormalizeURL(input) + output := normalized.String() + expected := 
"gemini://chat.gemini.lehmann.cx:11965/" + pass := reflect.DeepEqual(output, expected) + if !pass { + t.Errorf("fail: %#v != %#v", output, expected) + } +} diff --git a/gemini/ip-address-pool.go b/gemini/ip-address-pool.go new file mode 100644 index 0000000..b87bec6 --- /dev/null +++ b/gemini/ip-address-pool.go @@ -0,0 +1,54 @@ +package gemini + +import "sync" + +// Used to limit requests per +// IP address. Maps IP address +// to number of active connections. +type IpAddressPool struct { + IPs map[string]int + Lock sync.RWMutex +} + +func (p *IpAddressPool) Set(key string, value int) { + p.Lock.Lock() // Lock for writing + defer p.Lock.Unlock() // Ensure mutex is unlocked after the write + p.IPs[key] = value +} + +func (p *IpAddressPool) Get(key string) int { + p.Lock.RLock() // Lock for reading + defer p.Lock.RUnlock() // Ensure mutex is unlocked after reading + if value, ok := p.IPs[key]; !ok { + return 0 + } else { + return value + } +} + +func (p *IpAddressPool) Delete(key string) { + p.Lock.Lock() + defer p.Lock.Unlock() + delete(p.IPs, key) +} + +func (p *IpAddressPool) Incr(key string) { + p.Lock.Lock() + defer p.Lock.Unlock() + if _, ok := p.IPs[key]; !ok { + p.IPs[key] = 1 + } else { + p.IPs[key] = p.IPs[key] + 1 + } +} + +func (p *IpAddressPool) Decr(key string) { + p.Lock.Lock() + defer p.Lock.Unlock() + if val, ok := p.IPs[key]; ok { + p.IPs[key] = val - 1 + if p.IPs[key] == 0 { + delete(p.IPs, key) + } + } +} diff --git a/gemini/network.go b/gemini/network.go new file mode 100644 index 0000000..74fde37 --- /dev/null +++ b/gemini/network.go @@ -0,0 +1,244 @@ +package gemini + +import ( + "crypto/tls" + "errors" + "fmt" + "io" + "net" + gourl "net/url" + "regexp" + "slices" + "strconv" + "strings" + "time" + + "gemini-grc/config" + "gemini-grc/logging" + "github.com/guregu/null/v5" +) + +type PageData struct { + ResponseCode int + ResponseHeader string + MimeType string + Lang string + GemText string + Data []byte +} + +// Resolve the URL hostname and +// check if we already have an open +// connection to this host. +// If we can connect, return a list +// of the resolved IPs. +func getHostIPAddresses(hostname string) ([]string, error) { + addrs, err := net.LookupHost(hostname) + if err != nil { + return nil, fmt.Errorf("%w:%w", ErrNetworkDNS, err) + } + IPPool.Lock.RLock() + defer func() { + IPPool.Lock.RUnlock() + }() + return addrs, nil +} + +func ConnectAndGetData(url string) ([]byte, error) { + parsedURL, err := gourl.Parse(url) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrURLParse, err) + } + hostname := parsedURL.Hostname() + port := parsedURL.Port() + if port == "" { + port = "1965" + } + host := fmt.Sprintf("%s:%s", hostname, port) + // Establish the underlying TCP connection. + dialer := &net.Dialer{ + Timeout: time.Duration(config.CONFIG.ResponseTimeout) * time.Second, + } + conn, err := dialer.Dial("tcp", host) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrNetwork, err) + } + // Make sure we always close the connection. + defer func() { + // No need to handle error: + // Connection will time out eventually if still open somehow. + _ = conn.Close() + }() + + // Set read and write timeouts on the TCP connection. 
+ err = conn.SetReadDeadline(time.Now().Add(time.Duration(config.CONFIG.ResponseTimeout) * time.Second)) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrNetworkSetConnectionDeadline, err) + } + err = conn.SetWriteDeadline(time.Now().Add(time.Duration(config.CONFIG.ResponseTimeout) * time.Second)) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrNetworkSetConnectionDeadline, err) + } + + // Perform the TLS handshake + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, //nolint:gosec // Accept all TLS certs, even if insecure. + ServerName: parsedURL.Hostname(), // SNI says we should not include port in hostname + // MinVersion: tls.VersionTLS12, // Use a minimum TLS version. Warning breaks a lot of sites. + } + tlsConn := tls.Client(conn, tlsConfig) + if err := tlsConn.Handshake(); err != nil { + return nil, fmt.Errorf("%w: %w", ErrNetworkTLS, err) + } + + // We read `buf`-sized chunks and add data to `data`. + buf := make([]byte, 4096) + var data []byte + + // Send Gemini request to trigger server response. + // Fix for stupid server bug: + // Some servers return 'Header: 53 No proxying to other hosts or ports!' + // when the port is 1965 and is still specified explicitly in the URL. + _url, _ := ParseURL(url, "") + _, err = tlsConn.Write([]byte(fmt.Sprintf("%s\r\n", _url.StringNoDefaultPort()))) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrNetworkCannotWrite, err) + } + // Read response bytes in len(buf) byte chunks + for { + n, err := tlsConn.Read(buf) + if n > 0 { + data = append(data, buf[:n]...) + } + if len(data) > config.CONFIG.MaxResponseSize { + return nil, fmt.Errorf("%w: %v", ErrNetworkResponseSizeExceededMax, config.CONFIG.MaxResponseSize) + } + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return nil, fmt.Errorf("%w: %w", ErrNetwork, err) + } + } + return data, nil +} + +// Visit given URL, using the Gemini protocol. +// Mutates given Snapshot with the data. +// In case of error, we store the error string +// inside snapshot and return the error. 
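+//
+// A minimal usage sketch (illustrative):
+//
+//	u, _ := ParseURL("gemini://example.org/", "")
+//	s := Snapshot{URL: *u, Host: u.Hostname}
+//	if err := Visit(&s); err != nil {
+//		// s.Error (and, for Gemini status errors, s.ResponseCode and
+//		// s.Header) have already been filled in by the deferred handler.
+//	}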
+func Visit(s *Snapshot) (err error) { + // Don't forget to also store error + // response code (if we have one) + // and header + defer func() { + if err != nil { + s.Error = null.StringFrom(err.Error()) + if errors.As(err, new(*GeminiError)) { + s.Header = null.StringFrom(err.(*GeminiError).Header) + s.ResponseCode = null.IntFrom(int64(err.(*GeminiError).Code)) + } + } + }() + s.Timestamp = null.TimeFrom(time.Now()) + data, err := ConnectAndGetData(s.URL.String()) + if err != nil { + return err + } + pageData, err := processData(data) + if err != nil { + return err + } + s.Header = null.StringFrom(pageData.ResponseHeader) + s.ResponseCode = null.IntFrom(int64(pageData.ResponseCode)) + s.MimeType = null.StringFrom(pageData.MimeType) + s.Lang = null.StringFrom(pageData.Lang) + if pageData.GemText != "" { + s.GemText = null.StringFrom(pageData.GemText) + } + if pageData.Data != nil { + s.Data = null.ValueFrom(pageData.Data) + } + return nil +} + +// processData returne results from +// parsing Gemini header data: +// Code, mime type and lang (optional) +// Returns error if header was invalid +func processData(data []byte) (*PageData, error) { + header, body, err := getHeadersAndData(data) + if err != nil { + return nil, err + } + code, mimeType, lang := getMimeTypeAndLang(header) + logging.LogDebug("Header: %s", strings.TrimSpace(header)) + if code != 20 { + return nil, NewErrGeminiStatusCode(code, header) + } + + pageData := PageData{ + ResponseCode: code, + ResponseHeader: header, + MimeType: mimeType, + Lang: lang, + } + // If we've got a Gemini document, populate + // `GemText` field, otherwise raw data goes to `Data`. + if mimeType == "text/gemini" { + validBody, err := BytesToValidUTF8(body) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrUTF8Parse, err) + } + pageData.GemText = validBody + } else { + pageData.Data = body + } + return &pageData, nil +} + +// Checks for a Gemini header, which is +// basically the first line of the response +// and should contain the response code, +// mimeType and language. +func getHeadersAndData(data []byte) (string, []byte, error) { + firstLineEnds := slices.Index(data, '\n') + if firstLineEnds == -1 { + return "", nil, ErrGeminiResponseHeader + } + firstLine := string(data[:firstLineEnds]) + rest := data[firstLineEnds+1:] + return firstLine, rest, nil +} + +// Parses code, mime type and language +// from a Gemini header. 
+// Examples: +// `20 text/gemini lang=en` (code, mimetype, lang) +// `20 text/gemini` (code, mimetype) +// `31 gemini://redirected.to/other/site` (code) +func getMimeTypeAndLang(headers string) (int, string, string) { + // Regex that parses code, mimetype & optional charset/lang parameters + re := regexp.MustCompile(`^(\d+)\s+([a-zA-Z0-9/\-+]+)(?:[;\s]+(?:(?:charset|lang)=([a-zA-Z0-9-]+)))?\s*$`) + matches := re.FindStringSubmatch(headers) + if matches == nil || len(matches) <= 1 { + // Try to get code at least + re := regexp.MustCompile(`^(\d+)\s+`) + matches := re.FindStringSubmatch(headers) + if matches == nil || len(matches) <= 1 { + return 0, "", "" + } + code, err := strconv.Atoi(matches[1]) + if err != nil { + return 0, "", "" + } + return code, "", "" + } + code, err := strconv.Atoi(matches[1]) + if err != nil { + return 0, "", "" + } + mimeType := matches[2] + param := matches[3] // This will capture either charset or lang value + return code, mimeType, param +} diff --git a/gemini/network_test.go b/gemini/network_test.go new file mode 100644 index 0000000..81202db --- /dev/null +++ b/gemini/network_test.go @@ -0,0 +1,78 @@ +package gemini + +import ( + "testing" +) + +// Test for input: `20 text/gemini` +func TestGetMimeTypeAndLang1(t *testing.T) { + t.Parallel() + code, mimeType, lang := getMimeTypeAndLang("20 text/gemini") + if code != 20 || mimeType != "text/gemini" || lang != "" { + t.Errorf("Expected (20, 'text/gemini', ''), got (%d, '%s', '%s')", code, mimeType, lang) + } +} + +func TestGetMimeTypeAndLang11(t *testing.T) { + t.Parallel() + code, mimeType, lang := getMimeTypeAndLang("20 text/gemini\n") + if code != 20 || mimeType != "text/gemini" || lang != "" { + t.Errorf("Expected (20, 'text/gemini', ''), got (%d, '%s', '%s')", code, mimeType, lang) + } +} + +func TestGetMimeTypeAndLang12(t *testing.T) { + t.Parallel() + code, mimeType, lang := getMimeTypeAndLang("20 text/plain; charset=utf-8") + if code != 20 || mimeType != "text/plain" || lang != "utf-8" { + t.Errorf("Expected (20, 'text/plain', ''), got (%d, '%s', '%s')", code, mimeType, lang) + } +} + +func TestGetMimeTypeAndLang13(t *testing.T) { + t.Parallel() + code, mimeType, lang := getMimeTypeAndLang("20 text/gemini; charset=utf-8") + if code != 20 || mimeType != "text/gemini" || lang != "utf-8" { + t.Errorf("Expected (20, 'text/plain', ''), got (%d, '%s', '%s')", code, mimeType, lang) + } +} + +func TestGetTypeAndLang2(t *testing.T) { + t.Parallel() + code, mimeType, lang := getMimeTypeAndLang("20 text/gemini charset=en") + if code != 20 || mimeType != "text/gemini" || lang != "en" { + t.Errorf("Expected (20, 'text/gemini', 'en'), got (%d, '%s', '%s')", code, mimeType, lang) + } +} + +func TestGetTypeAndLang21(t *testing.T) { + t.Parallel() + code, mimeType, lang := getMimeTypeAndLang("20 text/gemini lang=en") + if code != 20 || mimeType != "text/gemini" || lang != "en" { + t.Errorf("Expected (20, 'text/gemini', 'en'), got (%d, '%s', '%s')", code, mimeType, lang) + } +} + +func TestGetMimeTypeAndLang3(t *testing.T) { + t.Parallel() + code, mimeType, lang := getMimeTypeAndLang("31 gemini://redirect.to/page") + if code != 31 || mimeType != "" || lang != "" { + t.Errorf("Expected (20, '', ''), got (%d, '%s', '%s')", code, mimeType, lang) + } +} + +func TestGetMimeTypeAndLang4(t *testing.T) { + t.Parallel() + code, mimeType, lang := getMimeTypeAndLang("aaafdasdasd") + if code != 0 || mimeType != "" || lang != "" { + t.Errorf("Expected (0, '', ''), got (%d, '%s', '%s')", code, mimeType, lang) + } +} + +func 
TestGetMimeTypeAndLang5(t *testing.T) { + t.Parallel() + code, mimeType, lang := getMimeTypeAndLang("") + if code != 0 || mimeType != "" || lang != "" { + t.Errorf("Expected (0, '', ''), got (%d, '%s', '%s')", code, mimeType, lang) + } +} diff --git a/gemini/processing.go b/gemini/processing.go new file mode 100644 index 0000000..0afdac3 --- /dev/null +++ b/gemini/processing.go @@ -0,0 +1,59 @@ +package gemini + +import ( + "bytes" + "errors" + "fmt" + "io" + "unicode/utf8" + + "golang.org/x/text/encoding/charmap" + "golang.org/x/text/encoding/japanese" + "golang.org/x/text/encoding/korean" + "golang.org/x/text/transform" +) + +var ( + ErrInputTooLarge = errors.New("input too large") + ErrUTF8Conversion = errors.New("UTF-8 conversion error") +) + +func BytesToValidUTF8(input []byte) (string, error) { + if len(input) == 0 { + return "", nil + } + const maxSize = 10 * 1024 * 1024 // 10MB + if len(input) > maxSize { + return "", fmt.Errorf("%w: %d bytes (max %d)", ErrInputTooLarge, len(input), maxSize) + } + // Remove NULL byte 0x00 (ReplaceAll accepts slices) + inputNoNull := bytes.ReplaceAll(input, []byte{byte(0)}, []byte{}) + if utf8.Valid(inputNoNull) { + return string(inputNoNull), nil + } + encodings := []transform.Transformer{ + charmap.ISO8859_1.NewDecoder(), + charmap.ISO8859_7.NewDecoder(), + charmap.Windows1250.NewDecoder(), // Central European + charmap.Windows1251.NewDecoder(), // Cyrillic + charmap.Windows1252.NewDecoder(), + charmap.Windows1256.NewDecoder(), // Arabic + japanese.EUCJP.NewDecoder(), // Japanese + korean.EUCKR.NewDecoder(), // Korean + } + // First successful conversion wins. + var lastErr error + for _, encoding := range encodings { + reader := transform.NewReader(bytes.NewReader(inputNoNull), encoding) + result, err := io.ReadAll(reader) + if err != nil { + lastErr = err + continue + } + if utf8.Valid(result) { + return string(result), nil + } + } + + return "", fmt.Errorf("%w (tried %d encodings): %w", ErrUTF8Conversion, len(encodings), lastErr) +} diff --git a/gemini/processing_test.go b/gemini/processing_test.go new file mode 100644 index 0000000..986323d --- /dev/null +++ b/gemini/processing_test.go @@ -0,0 +1,14 @@ +package gemini + +import "testing" + +// Make sure NULL bytes are removed +func TestEnsureValidUTF8(t *testing.T) { + t.Parallel() + // Create a string with a null byte + strWithNull := "Hello" + string('\x00') + "world" + result, _ := BytesToValidUTF8([]byte(strWithNull)) + if result != "Helloworld" { + t.Errorf("Expected string without NULL byte, got %s", result) + } +} diff --git a/gemini/robotmatch.go b/gemini/robotmatch.go new file mode 100644 index 0000000..b786204 --- /dev/null +++ b/gemini/robotmatch.go @@ -0,0 +1,82 @@ +package gemini + +import ( + "fmt" + "strings" + "sync" + + "gemini-grc/logging" +) + +// RobotsCache is a map of blocked URLs +// key: URL +// value: []string list of disallowed URLs +// If a key has no blocked URLs, an empty +// list is stored for caching. +var RobotsCache sync.Map //nolint:gochecknoglobals + +func populateBlacklist(key string) (entries []string) { + // We either store an empty list when + // no rules, or a list of disallowed URLs. 
+	// This applies even if we have an error
+	// finding/downloading robots.txt
+	defer func() {
+		RobotsCache.Store(key, entries)
+	}()
+	url := fmt.Sprintf("gemini://%s/robots.txt", key)
+	robotsContent, err := ConnectAndGetData(url)
+	if err != nil {
+		logging.LogDebug("robots.txt error %s", err)
+		return []string{}
+	}
+	robotsData, err := processData(robotsContent)
+	if err != nil {
+		logging.LogDebug("robots.txt error %s", err)
+		return []string{}
+	}
+	if robotsData.ResponseCode != 20 {
+		logging.LogDebug("robots.txt error code %d, ignoring", robotsData.ResponseCode)
+		return []string{}
+	}
+	// Some return text/plain, others text/gemini.
+	// According to spec, the first is correct,
+	// however let's be lenient.
+	var data string
+	switch {
+	case robotsData.MimeType == "text/plain":
+		data = string(robotsData.Data)
+	case robotsData.MimeType == "text/gemini":
+		data = robotsData.GemText
+	default:
+		return []string{}
+	}
+	entries = ParseRobotsTxt(data, key)
+	return entries
+}
+
+// RobotMatch checks if the snapshot URL matches
+// a robots.txt disallow rule.
+func RobotMatch(url URL) bool {
+	key := strings.ToLower(fmt.Sprintf("%s:%d", url.Hostname, url.Port))
+	logging.LogDebug("Checking robots.txt cache for %s", key)
+	var disallowedURLs []string
+	cacheEntries, ok := RobotsCache.Load(key)
+	if !ok {
+		// First time check, populate robot cache
+		disallowedURLs = populateBlacklist(key)
+		logging.LogDebug("Added to robots.txt cache: %v => %v", key, disallowedURLs)
+	} else {
+		disallowedURLs, _ = cacheEntries.([]string)
+	}
+	return isURLblocked(disallowedURLs, url.Full)
+}
+
+func isURLblocked(disallowedURLs []string, input string) bool {
+	for _, url := range disallowedURLs {
+		if strings.HasPrefix(strings.ToLower(input), url) {
+			logging.LogDebug("robots.txt match: %s matches %s", input, url)
+			return true
+		}
+	}
+	return false
+}
diff --git a/gemini/robots.go b/gemini/robots.go
new file mode 100644
index 0000000..0653b62
--- /dev/null
+++ b/gemini/robots.go
@@ -0,0 +1,31 @@
+package gemini
+
+import (
+	"fmt"
+	"strings"
+)
+
+// ParseRobotsTxt takes robots.txt content and a host, and
+// returns a list of full URLs that shouldn't
+// be visited.
+// TODO Also take into account the user agent?
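+// For example, given host "example.com" and the line "Disallow: /cgi-bin/",
+// the returned list contains "gemini://example.com/cgi-bin/".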
+// Check gemini://geminiprotocol.net/docs/companion/robots.gmi
+func ParseRobotsTxt(content string, host string) []string {
+	var disallowedPaths []string
+	for _, line := range strings.Split(content, "\n") {
+		line = strings.TrimSpace(line)
+		line = strings.ToLower(line)
+		if strings.HasPrefix(line, "disallow:") {
+			parts := strings.SplitN(line, ":", 2)
+			if len(parts) == 2 {
+				path := strings.TrimSpace(parts[1])
+				if path != "" {
+					// Construct full Gemini URL
+					disallowedPaths = append(disallowedPaths,
+						fmt.Sprintf("gemini://%s%s", host, path))
+				}
+			}
+		}
+	}
+	return disallowedPaths
+}
diff --git a/gemini/robots_test.go b/gemini/robots_test.go
new file mode 100644
index 0000000..e73e7b5
--- /dev/null
+++ b/gemini/robots_test.go
@@ -0,0 +1,55 @@
+package gemini
+
+import (
+	"reflect"
+	"testing"
+)
+
+func TestParseRobotsTxt(t *testing.T) {
+	t.Parallel()
+	input := `User-agent: *
+Disallow: /cgi-bin/wp.cgi/view
+Disallow: /cgi-bin/wp.cgi/media
+User-agent: googlebot
+Disallow: /admin/`
+
+	expected := []string{
+		"gemini://example.com/cgi-bin/wp.cgi/view",
+		"gemini://example.com/cgi-bin/wp.cgi/media",
+		"gemini://example.com/admin/",
+	}
+
+	result := ParseRobotsTxt(input, "example.com")
+
+	if !reflect.DeepEqual(result, expected) {
+		t.Errorf("ParseRobotsTxt() = %v, want %v", result, expected)
+	}
+}
+
+func TestParseRobotsTxtEmpty(t *testing.T) {
+	t.Parallel()
+	input := ``
+
+	result := ParseRobotsTxt(input, "example.com")
+
+	if len(result) != 0 {
+		t.Errorf("ParseRobotsTxt() = %v, want empty []string", result)
+	}
+}
+
+func TestIsURLblocked(t *testing.T) {
+	t.Parallel()
+	disallowedURLs := []string{
+		"gemini://example.com/cgi-bin/wp.cgi/view",
+		"gemini://example.com/cgi-bin/wp.cgi/media",
+		"gemini://example.com/admin/",
+	}
+	url := "gemini://example.com/admin/index.html"
+	if !isURLblocked(disallowedURLs, url) {
+		t.Errorf("Expected %s to be blocked", url)
+	}
+	url = "gemini://example1.com/admin/index.html"
+	if isURLblocked(disallowedURLs, url) {
+		t.Errorf("expected %s to not be blocked", url)
+	}
+}
diff --git a/gemini/snapshot.go b/gemini/snapshot.go
new file mode 100644
index 0000000..bbd1a75
--- /dev/null
+++ b/gemini/snapshot.go
@@ -0,0 +1,42 @@
+package gemini
+
+import (
+	"database/sql/driver"
+	"encoding/json"
+	"fmt"
+
+	"github.com/guregu/null/v5"
+)
+
+type LinkList []URL
+
+func (l *LinkList) Value() (driver.Value, error) {
+	return json.Marshal(l)
+}
+
+func (l *LinkList) Scan(value interface{}) error {
+	if value == nil {
+		*l = nil
+		return nil
+	}
+	b, ok := value.([]byte) // Type assertion! Converts to []byte
+	if !ok {
+		return fmt.Errorf("failed to scan LinkList: expected []byte, got %T", value)
+	}
+	return json.Unmarshal(b, l)
+}
+
+type Snapshot struct {
+	ID           int                  `db:"id" json:"id,omitempty"`
+	URL          URL                  `db:"url" json:"url,omitempty"`
+	Host         string               `db:"host" json:"host,omitempty"`
+	Timestamp    null.Time            `db:"timestamp" json:"timestamp,omitempty"`
+	MimeType     null.String          `db:"mimetype" json:"mimetype,omitempty"`
+	Data         null.Value[[]byte]   `db:"data" json:"data,omitempty"` // For non text/gemini files.
+	GemText      null.String          `db:"gemtext" json:"gemtext,omitempty"` // For text/gemini files.
+	Header       null.String          `db:"header" json:"header,omitempty"` // Response header.
+	Links        null.Value[LinkList] `db:"links" json:"links,omitempty"`
+	Lang         null.String          `db:"lang" json:"lang,omitempty"`
+	ResponseCode null.Int             `db:"response_code" json:"code,omitempty"` // Gemini response status code.
+	Error        null.String          `db:"error" json:"error,omitempty"` // On network errors only
+}
diff --git a/gemini/worker.go b/gemini/worker.go
new file mode 100644
index 0000000..626d5c9
--- /dev/null
+++ b/gemini/worker.go
@@ -0,0 +1,368 @@
+package gemini
+
+import (
+	"errors"
+	"fmt"
+	"strings"
+	"time"
+
+	"gemini-grc/logging"
+	"gemini-grc/util"
+	"github.com/guregu/null/v5"
+	"github.com/jmoiron/sqlx"
+)
+
+type WorkerStatus struct {
+	id     int
+	status string
+}
+
+func PrintWorkerStatus(totalWorkers int, statusChan chan WorkerStatus) {
+	// Create a slice to store current status of each worker
+	statuses := make([]string, totalWorkers)
+
+	// Initialize empty statuses
+	for i := range statuses {
+		statuses[i] = ""
+	}
+
+	// Initial print
+	var output strings.Builder
+	// \033[H moves the cursor to the top left corner of the screen
+	// (ie, the first column of the first row in the screen).
+	// \033[J clears the part of the screen from the cursor to the end of the screen.
+	output.WriteString("\033[H\033[J") // Clear screen and move cursor to top
+	for i := range statuses {
+		output.WriteString(fmt.Sprintf("[%2d] \n", i))
+	}
+	fmt.Print(output.String())
+
+	// Continuously receive status updates
+	for update := range statusChan {
+		if update.id >= totalWorkers {
+			continue
+		}
+
+		// Update the status
+		statuses[update.id] = update.status
+
+		// Build the complete output string
+		output.Reset()
+		output.WriteString("\033[H\033[J") // Clear screen and move cursor to top
+		for i, status := range statuses {
+			output.WriteString(fmt.Sprintf("[%2d] %.100s\n", i, status))
+		}
+
+		// Print the entire status
+		fmt.Print(output.String())
+	}
+}
+
+var statusChan chan WorkerStatus
+
+func SpawnWorkers(numOfWorkers int, db *sqlx.DB) {
+	logging.LogInfo("Spawning %d workers", numOfWorkers)
+	statusChan = make(chan WorkerStatus, numOfWorkers)
+	go PrintWorkerStatus(numOfWorkers, statusChan)
+
+	for i := range numOfWorkers {
+		go func(i int) {
+			// Jitter to avoid starting everything at the same time
+			time.Sleep(time.Duration(util.SecureRandomInt(10)) * time.Second)
+			for {
+				RunWorkerWithTx(i, db, nil)
+			}
+		}(i)
+	}
+}
+
+func RunWorkerWithTx(workerID int, db *sqlx.DB, url *string) {
+	statusChan <- WorkerStatus{
+		id:     workerID,
+		status: "Starting up",
+	}
+	defer func() {
+		statusChan <- WorkerStatus{
+			id:     workerID,
+			status: "Done",
+		}
+	}()
+	tx, err := db.Beginx()
+	if err != nil {
+		panic(fmt.Sprintf("Failed to begin transaction: %v", err))
+	}
+	runWorker(workerID, tx, url)
+	logging.LogDebug("[%d] Committing transaction", workerID)
+	err = tx.Commit()
+	// On deadlock errors, rollback and return, otherwise panic.
+	if err != nil {
+		logging.LogError("[%d] Failed to commit transaction: %w", workerID, err)
+		if isDeadlockError(err) {
+			logging.LogError("[%d] Deadlock detected. Rolling back", workerID)
+			time.Sleep(time.Duration(10) * time.Second)
+			err := tx.Rollback()
+			if err != nil {
+				panic(fmt.Sprintf("[%d] Failed to roll back transaction: %v", workerID, err))
+			}
+			return
+		}
+		panic(fmt.Sprintf("[%d] Failed to commit transaction: %v", workerID, err))
+	}
+	logging.LogDebug("[%d] Worker done!", workerID)
+}
+
+func runWorker(workerID int, tx *sqlx.Tx, url *string) {
+	var snapshots []Snapshot
+	var err error
+
+	// If not given a specific URL,
+	// get some random ones to visit from DB.
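+	// (A non-nil url is the single-URL mode used by main.go when a
+	// URL argument is given on the command line.)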
+	if url == nil {
+		statusChan <- WorkerStatus{
+			id:     workerID,
+			status: "Getting snapshots",
+		}
+		snapshots, err = GetSnapshotsToVisit(tx)
+		if err != nil {
+			logging.LogError("[%d] GeminiError retrieving snapshot: %w", workerID, err)
+			panic("This should never happen")
+		} else if len(snapshots) == 0 {
+			logging.LogInfo("[%d] No snapshots to visit.", workerID)
+			time.Sleep(1 * time.Minute)
+			return
+		}
+	} else {
+		snapshotURL, err := ParseURL(*url, "")
+		if err != nil {
+			logging.LogError("Invalid URL given: %s", *url)
+			return
+		}
+		snapshots = []Snapshot{{
+			// UID: uid.UID(),
+			URL:       *snapshotURL,
+			Host:      snapshotURL.Hostname,
+			Timestamp: null.TimeFrom(time.Now()),
+		}}
+	}
+
+	total := len(snapshots)
+	for i, s := range snapshots {
+		logging.LogDebug("[%d] Snapshot %d/%d: %s", workerID, i+1, total, s.URL.String())
+	}
+	// Start visiting URLs.
+	for i, s := range snapshots {
+		logging.LogDebug("[%d] Starting %d/%d %s", workerID, i+1, total, s.URL.String())
+		// We differentiate between errors:
+		// Unexpected errors are the ones returned from the following function.
+		// If an error is unexpected (which should never happen) we panic.
+		// Expected errors are stored as strings within the snapshot,
+		// so that they can also be stored in DB.
+		err := workOnSnapshot(workerID, tx, &s)
+		if err != nil {
+			logging.LogError("[%d] [%s] Unexpected GeminiError %w", workerID, s.URL.String(), err)
+			util.PrintStackAndPanic(err)
+		}
+		if s.Error.Valid {
+			logging.LogDebug("[%d] Error: %v", workerID, s.Error.String)
+		}
+		logging.LogDebug("[%d] Done %d/%d.", workerID, i+1, total)
+	}
+}
+
+// workOnSnapshot visits a URL and stores the result.
+// Unexpected errors are returned.
+// Expected errors are stored within the snapshot.
+func workOnSnapshot(workerID int, tx *sqlx.Tx, s *Snapshot) (err error) {
+	if IsBlacklisted(s.URL) {
+		logging.LogDebug("[%d] URL matches Blacklist, ignoring %s", workerID, s.URL.String())
+		return nil
+	}
+
+	// If URL matches a robots.txt disallow line,
+	// add it as an error so next time it won't be
+	// crawled.
+	if RobotMatch(s.URL) {
+		s.Error = null.StringFrom(ErrGeminiRobotsDisallowed.Error())
+		err = UpsertSnapshot(workerID, tx, s)
+		if err != nil {
+			return fmt.Errorf("[%d] %w", workerID, err)
+		}
+		return nil
+	}
+
+	// Resolve IP address via DNS
+	IPs, err := getHostIPAddresses(s.Host)
+	if err != nil {
+		s.Error = null.StringFrom(err.Error())
+		err = UpsertSnapshot(workerID, tx, s)
+		if err != nil {
+			return fmt.Errorf("[%d] %w", workerID, err)
+		}
+		return nil
+	}
+
+	// Initialize the retry counter outside the loop so the
+	// give-up check below can actually trigger.
+	count := 1
+	for {
+		if isAnotherWorkerVisitingHost(workerID, IPs) {
+			logging.LogDebug("[%d] Another worker is visiting this host, waiting", workerID)
+			statusChan <- WorkerStatus{
+				id:     workerID,
+				status: fmt.Sprintf("Waiting to grab lock for host %s", s.Host),
+			}
+			time.Sleep(1 * time.Second) // Avoid flood-retrying
+			count++
+			if count == 3 {
+				return
+			}
+		} else {
+			break
+		}
+	}
+
+	AddIPsToPool(IPs)
+	// After finishing, remove the host IPs from
+	// the connections pool, with a small delay
+	// to avoid potentially hitting the same IP quickly.
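+	// The defer below fires even if the visit fails or we return
+	// early, so the IPs are always removed from the pool.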
+	defer func() {
+		go func() {
+			time.Sleep(1 * time.Second)
+			RemoveIPsFromPool(IPs)
+		}()
+	}()
+
+	statusChan <- WorkerStatus{
+		id:     workerID,
+		status: fmt.Sprintf("Visiting %s", s.URL.String()),
+	}
+
+	err = Visit(s)
+	if err != nil {
+		if !IsKnownError(err) {
+			logging.LogError("[%d] Unknown error visiting %s: %w", workerID, s.URL.String(), err)
+			return err
+		}
+		s.Error = null.StringFrom(err.Error())
+		// Check if error is redirection, and handle it
+		var gemErr *GeminiError
+		if errors.As(err, &gemErr) && gemErr.Msg == "redirect" {
+			err = handleRedirection(workerID, tx, s)
+			if err != nil {
+				if IsKnownError(err) {
+					s.Error = null.StringFrom(err.Error())
+				} else {
+					return err
+				}
+			}
+		}
+	}
+	// If this is a gemini page, parse possible links inside
+	if !s.Error.Valid && s.MimeType.Valid && s.MimeType.String == "text/gemini" {
+		links := GetPageLinks(s.URL, s.GemText.String)
+		if len(links) > 0 {
+			logging.LogDebug("[%d] Found %d links", workerID, len(links))
+			s.Links = null.ValueFrom(links)
+			err = storeLinks(tx, s)
+			if err != nil {
+				return err
+			}
+		}
+	} else {
+		logging.LogDebug("[%d] Not text/gemini, so not looking for page links", workerID)
+	}
+
+	err = UpsertSnapshot(workerID, tx, s)
+	logging.LogInfo("[%3d] %2d %s", workerID, s.ResponseCode.ValueOrZero(), s.URL.String())
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func isAnotherWorkerVisitingHost(workerID int, IPs []string) bool {
+	IPPool.Lock.RLock()
+	defer func() {
+		IPPool.Lock.RUnlock()
+	}()
+	logging.LogDebug("[%d] Checking pool for IPs", workerID)
+	for _, ip := range IPs {
+		_, ok := IPPool.IPs[ip]
+		if ok {
+			return true
+		}
+	}
+	return false
+}
+
+func storeLinks(tx *sqlx.Tx, s *Snapshot) error {
+	if s.Links.Valid {
+		var batchSnapshots []*Snapshot
+		for _, link := range s.Links.ValueOrZero() {
+			if shouldPersistURL(&link) {
+				newSnapshot := &Snapshot{
+					URL:       link,
+					Host:      link.Hostname,
+					Timestamp: null.TimeFrom(time.Now()),
+				}
+				batchSnapshots = append(batchSnapshots, newSnapshot)
+			}
+		}
+
+		if len(batchSnapshots) > 0 {
+			err := SaveLinksToDBinBatches(tx, batchSnapshots)
+			if err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}
+
+// shouldPersistURL returns true if we
+// should save the URL in the DB.
+// Only gemini:// urls are saved.
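+// For example, "gemini://example.com/page" is persisted, while
+// http://, gopher:// or mailto: links are skipped.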
+func shouldPersistURL(u *URL) bool {
+	return strings.HasPrefix(u.String(), "gemini://")
+}
+
+// handleRedirection saves redirect URL as new snapshot
+func handleRedirection(workerID int, tx *sqlx.Tx, s *Snapshot) error {
+	newURL, err := extractRedirectTarget(s.URL, s.Error.ValueOrZero())
+	if err != nil {
+		if errors.Is(err, ErrGeminiRedirect) {
+			logging.LogDebug("[%d] %s", workerID, err)
+		}
+		return err
+	}
+	logging.LogDebug("[%d] Page redirects to %s", workerID, newURL)
+	// Insert fresh snapshot with new URL
+	if shouldPersistURL(newURL) {
+		snapshot := &Snapshot{
+			// UID: uid.UID(),
+			URL:       *newURL,
+			Host:      newURL.Hostname,
+			Timestamp: null.TimeFrom(time.Now()),
+		}
+		logging.LogDebug("[%d] Saving redirection URL %s", workerID, snapshot.URL.String())
+		err = SaveSnapshotIfNew(tx, snapshot)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func GetSnapshotFromURL(tx *sqlx.Tx, url string) ([]Snapshot, error) {
+	query := `
+		SELECT *
+		FROM snapshots
+		WHERE url=$1
+		LIMIT 1
+	`
+	var snapshots []Snapshot
+	err := tx.Select(&snapshots, query, url)
+	if err != nil {
+		return nil, err
+	}
+	return snapshots, nil
+}
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..89a5188
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,24 @@
+module gemini-grc
+
+go 1.23.1
+
+require (
+	github.com/guregu/null/v5 v5.0.0
+	github.com/jackc/pgx/v5 v5.7.1
+	github.com/jmoiron/sqlx v1.4.0
+	github.com/lib/pq v1.10.9
+	github.com/matoous/go-nanoid/v2 v2.1.0
+	github.com/rs/zerolog v1.33.0
+	golang.org/x/text v0.19.0
+)
+
+require (
+	github.com/jackc/pgpassfile v1.0.0 // indirect
+	github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
+	github.com/jackc/puddle/v2 v2.2.2 // indirect
+	github.com/mattn/go-colorable v0.1.13 // indirect
+	github.com/mattn/go-isatty v0.0.20 // indirect
+	golang.org/x/crypto v0.27.0 // indirect
+	golang.org/x/sync v0.8.0 // indirect
+	golang.org/x/sys v0.25.0 // indirect
+)
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..6deda2a
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,59 @@
+filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
+filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
+github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
+github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
+github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
+github.com/guregu/null/v5 v5.0.0 h1:PRxjqyOekS11W+w/7Vfz6jgJE/BCwELWtgvOJzddimw=
+github.com/guregu/null/v5 v5.0.0/go.mod h1:SjupzNy+sCPtwQTKWhUCqjhVCO69hpsl2QsZrWHjlwU=
+github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
+github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
+github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
+github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
+github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs=
+github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA=
+github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
+github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
+github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
+github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
+github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
+github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
+github.com/matoous/go-nanoid/v2 v2.1.0 h1:P64+dmq21hhWdtvZfEAofnvJULaRR1Yib0+PnU669bE=
+github.com/matoous/go-nanoid/v2 v2.1.0/go.mod h1:KlbGNQ+FhrUNIHUxZdL63t7tl4LaPkZNpUULS8H4uVM=
+github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
+github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
+github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
+github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
+github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
+github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
+github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
+github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
+github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
+github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
+github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A=
+golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
+golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
+golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
+golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
+golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/logging/logging.go b/logging/logging.go
new file mode 100644
index 0000000..3b8ec62
--- /dev/null
+++ b/logging/logging.go
@@ -0,0 +1,23 @@
+package logging
+
+import (
+	"fmt"
+
+	zlog "github.com/rs/zerolog/log"
+)
+
+func LogDebug(format string, args ...interface{}) {
+	zlog.Debug().Msg(fmt.Sprintf(format, args...))
+}
+
+func LogInfo(format string, args ...interface{}) {
+	zlog.Info().Msg(fmt.Sprintf(format, args...))
+}
+
+func LogWarn(format string, args ...interface{}) {
+	zlog.Warn().Msg(fmt.Sprintf(format, args...))
+}
+
+func LogError(format string, args ...interface{}) {
+	zlog.Error().Err(fmt.Errorf(format, args...)).Msg("")
+}
diff --git a/main.go b/main.go
new file mode 100644
index 0000000..2bbaab9
--- /dev/null
+++ b/main.go
@@ -0,0 +1,60 @@
+package main
+
+import (
+	"os"
+	"os/signal"
+	"syscall"
+
+	"gemini-grc/config"
+	"gemini-grc/gemini"
+	"gemini-grc/logging"
+	"github.com/jmoiron/sqlx"
+	"github.com/rs/zerolog"
+	zlog "github.com/rs/zerolog/log"
+)
+
+func main() {
+	config.CONFIG = *config.GetConfig()
+	zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
+	zerolog.SetGlobalLevel(config.CONFIG.LogLevel)
+	zlog.Logger = zlog.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: "[2006-01-02 15:04:05]"})
+	if err := runApp(); err != nil {
+		logging.LogError("Application error: %w", err)
+		os.Exit(1)
+	}
+}
+
+func runApp() error {
+	logging.LogInfo("Starting up. Press Ctrl+C to exit")
+	signals := make(chan os.Signal, 1)
+	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
+
+	db := gemini.ConnectToDB()
+
+	// !!! DANGER !!!
+	// Removes all rows and adds some seed URLs.
+	// populateDB(db)
+
+	defer func(db *sqlx.DB) {
+		err := db.Close()
+		if err != nil {
+			// TODO properly log & handle error
+			panic(err)
+		}
+	}(db)
+
+	gemini.LoadBlacklist()
+
+	// If there's an argument, visit this
+	// URL only and don't spawn other workers
+	if len(os.Args) > 1 {
+		url := os.Args[1]
+		go gemini.RunWorkerWithTx(0, db, &url)
+	} else {
+		go gemini.SpawnWorkers(config.CONFIG.NumOfWorkers, db)
+	}
+
+	<-signals
+	logging.LogWarn("Received SIGINT or SIGTERM signal, exiting")
+	return nil
+}
diff --git a/uid/uid.go b/uid/uid.go
new file mode 100644
index 0000000..b98e342
--- /dev/null
+++ b/uid/uid.go
@@ -0,0 +1,14 @@
+package uid
+
+import (
+	nanoid "github.com/matoous/go-nanoid/v2"
+)
+
+func UID() string {
+	// No 'o','O' and 'l'
+	id, err := nanoid.Generate("abcdefghijkmnpqrstuvwxyzABCDEFGHIJKLMNPQRSTUVWXYZ0123456789", 20)
+	if err != nil {
+		panic(err)
+	}
+	return id
+}
diff --git a/util/util.go b/util/util.go
new file mode 100644
index 0000000..dbf36c8
--- /dev/null
+++ b/util/util.go
@@ -0,0 +1,36 @@
+package util
+
+import (
+	"crypto/rand"
+	"encoding/json"
+	"fmt"
+	"math/big"
+	"runtime/debug"
+)
+
+func PrintStackAndPanic(err error) {
+	fmt.Printf("Error %s Stack trace:\n%s", err, debug.Stack())
+	panic("PANIC")
+}
+
+// SecureRandomInt returns a cryptographically secure random integer in the range [0,max).
+// Panics if max <= 0 or if there's an error reading from the system's secure
+// random number generator.
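+// For example, SecureRandomInt(10) returns a value between 0 and 9 inclusive.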
+func SecureRandomInt(max int) int {
+	// Convert max to *big.Int for crypto/rand operations
+	maxBig := big.NewInt(int64(max))
+
+	// Generate random number
+	n, err := rand.Int(rand.Reader, maxBig)
+	if err != nil {
+		PrintStackAndPanic(fmt.Errorf("could not generate a random integer between 0 and %d", max))
+	}
+
+	// Convert back to int
+	return int(n.Int64())
+}
+
+// PrettyJson re-indents a JSON document for readability.
+// If the input is not valid JSON, it is returned unchanged.
+func PrettyJson(data string) string {
+	var parsed interface{}
+	if err := json.Unmarshal([]byte(data), &parsed); err != nil {
+		return data
+	}
+	// Re-marshal the parsed value (not the raw string), so the result
+	// is indented JSON instead of a quoted string literal.
+	marshalled, err := json.MarshalIndent(parsed, "", " ")
+	if err != nil {
+		return data
+	}
+	return fmt.Sprintf("%s\n", marshalled)
+}