diff --git a/bin/gemget/main.go b/bin/gemget/main.go deleted file mode 100644 index 129e739..0000000 --- a/bin/gemget/main.go +++ /dev/null @@ -1,47 +0,0 @@ -package main - -import ( - "encoding/json" - "fmt" - "os" - - "gemini-grc/common/snapshot" - _url "gemini-grc/common/url" - "gemini-grc/config" - "gemini-grc/gemini" - "gemini-grc/gopher" - "gemini-grc/logging" - "github.com/antanst/go_errors" -) - -func main() { - config.CONFIG = *config.GetConfig() - err := runApp() - if err != nil { - fmt.Printf("%v\n", err) - logging.LogError("%v", err) - os.Exit(1) - } -} - -func runApp() error { - if len(os.Args) != 2 { - return go_errors.NewError(fmt.Errorf("missing URL to visit")) - } - url := os.Args[1] - var s *snapshot.Snapshot - var err error - if _url.IsGeminiUrl(url) { - s, err = gemini.Visit(url) - } else if _url.IsGopherURL(url) { - s, err = gopher.Visit(url) - } else { - return go_errors.NewFatalError(fmt.Errorf("not a Gemini or Gopher URL")) - } - if err != nil { - return err - } - _json, _ := json.MarshalIndent(s, "", " ") - fmt.Printf("%s\n", _json) - return err -} diff --git a/bin/normalizeSnapshot/main.go b/bin/normalizeSnapshot/main.go deleted file mode 100644 index 56de12d..0000000 --- a/bin/normalizeSnapshot/main.go +++ /dev/null @@ -1,118 +0,0 @@ -package main - -import ( - "fmt" - "os" - - "gemini-grc/common/snapshot" - "gemini-grc/common/url" - main2 "gemini-grc/db" - _ "github.com/jackc/pgx/v5/stdlib" // PGX driver for PostgreSQL - "github.com/jmoiron/sqlx" -) - -// Populates the `host` field -func main() { - db := connectToDB() - count := 0 - - for { - tx := db.MustBegin() - query := ` - SELECT * FROM snapshots - ORDER BY id - LIMIT 10000 OFFSET $1 - ` - var snapshots []snapshot.Snapshot - err := tx.Select(&snapshots, query, count) - if err != nil { - printErrorAndExit(tx, err) - } - if len(snapshots) == 0 { - fmt.Println("Done!") - return - } - for _, s := range snapshots { - count++ - escaped := url.EscapeURL(s.URL.String()) - 
normalizedGeminiURL, err := url.ParseURL(escaped, "", true) - if err != nil { - fmt.Println(s.URL.String()) - fmt.Println(escaped) - printErrorAndExit(tx, err) - } - normalizedURLString := normalizedGeminiURL.String() - // If URL is already normalized, skip snapshot - if normalizedURLString == s.URL.String() { - // fmt.Printf("[%5d] Skipping %d %s\n", count, s.ID, s.URL.String()) - continue - } - // If a snapshot already exists with the normalized - // URL, delete the current snapshot and leave the other. - var ss []snapshot.Snapshot - err = tx.Select(&ss, "SELECT * FROM snapshots WHERE URL=$1", normalizedURLString) - if err != nil { - printErrorAndExit(tx, err) - } - if len(ss) > 0 { - tx.MustExec("DELETE FROM snapshots WHERE id=$1", s.ID) - fmt.Printf("%d Deleting %d %s\n", count, s.ID, s.URL.String()) - //err = tx.Commit() - //if err != nil { - // printErrorAndExit(tx, err) - //} - //return - continue - } - // fmt.Printf("%s =>\n%s\n", s.URL.String(), normalizedURLString) - // At this point we just update the snapshot, - // and the normalized URL will be saved. 
- fmt.Printf("%d Updating %d %s => %s\n", count, s.ID, s.URL.String(), normalizedURLString) - // Saves the snapshot with the normalized URL - tx.MustExec("DELETE FROM snapshots WHERE id=$1", s.ID) - s.URL = *normalizedGeminiURL - err = main2.OverwriteSnapshot(tx, &s) - if err != nil { - printErrorAndExit(tx, err) - } - //err = tx.Commit() - //if err != nil { - // printErrorAndExit(tx, err) - //} - //return - } - err = tx.Commit() - if err != nil { - printErrorAndExit(tx, err) - } - } -} - -func printErrorAndExit(tx *sqlx.Tx, err error) { - _ = tx.Rollback() - panic(err) -} - -func connectToDB() *sqlx.DB { - connStr := fmt.Sprintf("postgres://%s:%s@%s:%s/%s", - os.Getenv("PG_USER"), - os.Getenv("PG_PASSWORD"), - os.Getenv("PG_HOST"), - os.Getenv("PG_PORT"), - os.Getenv("PG_DATABASE"), - ) - - // Create a connection pool - db, err := sqlx.Open("pgx", connStr) - if err != nil { - panic(fmt.Sprintf("Unable to connect to database with URL %s: %v\n", connStr, err)) - } - db.SetMaxOpenConns(20) - err = db.Ping() - if err != nil { - panic(fmt.Sprintf("Unable to ping database: %v\n", err)) - } - - fmt.Println("Connected to database") - return db -} diff --git a/common/blackList/blacklist.go b/common/blackList/blacklist.go index 43894c9..febb081 100644 --- a/common/blackList/blacklist.go +++ b/common/blackList/blacklist.go @@ -8,45 +8,63 @@ import ( "gemini-grc/config" "gemini-grc/logging" - "github.com/antanst/go_errors" + "git.antanst.com/antanst/xerrors" ) -var Blacklist []regexp.Regexp //nolint:gochecknoglobals +var blacklist []regexp.Regexp //nolint:gochecknoglobals -func LoadBlacklist() error { - if config.CONFIG.BlacklistPath == "" { - return nil - } - if Blacklist == nil { - data, err := os.ReadFile(config.CONFIG.BlacklistPath) - if err != nil { - Blacklist = []regexp.Regexp{} - return go_errors.NewError(fmt.Errorf("could not load Blacklist file: %w", err)) - } +func Initialize() error { + var err error - lines := strings.Split(string(data), "\n") - - for _, line := 
range lines { - if line == "" || strings.HasPrefix(line, "#") { - continue - } - regex, err := regexp.Compile(line) - if err != nil { - return go_errors.NewError(fmt.Errorf("could not compile Blacklist line %s: %w", line, err)) - } - Blacklist = append(Blacklist, *regex) - - } - - if len(lines) > 0 { - logging.LogInfo("Loaded %d blacklist entries", len(Blacklist)) + // Initialize blacklist + if config.CONFIG.BlacklistPath != "" { + if err = loadBlacklist(config.CONFIG.BlacklistPath); err != nil { + return err } } + return nil } +func loadBlacklist(filePath string) error { + if blacklist != nil { + return nil + } + + data, err := os.ReadFile(filePath) + if err != nil { + blacklist = []regexp.Regexp{} + return xerrors.NewError(fmt.Errorf("could not load blacklist file: %w", err), 0, "", true) + } + + lines := strings.Split(string(data), "\n") + blacklist = []regexp.Regexp{} + + for _, line := range lines { + if line == "" || strings.HasPrefix(line, "#") { + continue + } + regex, err := regexp.Compile(line) + if err != nil { + return xerrors.NewError(fmt.Errorf("could not compile blacklist line %s: %w", line, err), 0, "", true) + } + blacklist = append(blacklist, *regex) + } + + if len(blacklist) > 0 { + logging.LogInfo("Loaded %d blacklist entries", len(blacklist)) + } + + return nil +} + +func Shutdown() error { + return nil +} + +// IsBlacklisted checks if the URL matches any blacklist pattern func IsBlacklisted(u string) bool { - for _, v := range Blacklist { + for _, v := range blacklist { if v.MatchString(u) { return true } diff --git a/common/blackList/blacklist_test.go b/common/blackList/blacklist_test.go index 55a6252..6f44799 100644 --- a/common/blackList/blacklist_test.go +++ b/common/blackList/blacklist_test.go @@ -3,16 +3,17 @@ package blackList import ( "os" "regexp" + "strings" "testing" "gemini-grc/config" ) func TestIsBlacklisted(t *testing.T) { - // Save original blacklist to restore after test - originalBlacklist := Blacklist + // Save original 
blacklist and whitelist to restore after test + originalBlacklist := blacklist defer func() { - Blacklist = originalBlacklist + blacklist = originalBlacklist }() tests := []struct { @@ -24,7 +25,7 @@ func TestIsBlacklisted(t *testing.T) { { name: "empty blacklist", setup: func() { - Blacklist = []regexp.Regexp{} + blacklist = []regexp.Regexp{} }, url: "https://example.com", expected: false, @@ -33,7 +34,7 @@ func TestIsBlacklisted(t *testing.T) { name: "exact hostname match", setup: func() { regex, _ := regexp.Compile(`example\.com`) - Blacklist = []regexp.Regexp{*regex} + blacklist = []regexp.Regexp{*regex} }, url: "example.com", expected: true, @@ -42,7 +43,7 @@ func TestIsBlacklisted(t *testing.T) { name: "hostname in URL match", setup: func() { regex, _ := regexp.Compile(`example\.com`) - Blacklist = []regexp.Regexp{*regex} + blacklist = []regexp.Regexp{*regex} }, url: "https://example.com/path", expected: true, @@ -51,7 +52,7 @@ func TestIsBlacklisted(t *testing.T) { name: "partial hostname match", setup: func() { regex, _ := regexp.Compile(`example\.com`) - Blacklist = []regexp.Regexp{*regex} + blacklist = []regexp.Regexp{*regex} }, url: "https://safe-example.com", expected: true, @@ -60,7 +61,7 @@ func TestIsBlacklisted(t *testing.T) { name: "full URL match", setup: func() { regex, _ := regexp.Compile(`https://example\.com/bad-path`) - Blacklist = []regexp.Regexp{*regex} + blacklist = []regexp.Regexp{*regex} }, url: "https://example.com/bad-path", expected: true, @@ -69,7 +70,7 @@ func TestIsBlacklisted(t *testing.T) { name: "path match", setup: func() { regex, _ := regexp.Compile("/malicious-path") - Blacklist = []regexp.Regexp{*regex} + blacklist = []regexp.Regexp{*regex} }, url: "https://example.com/malicious-path", expected: true, @@ -78,7 +79,7 @@ func TestIsBlacklisted(t *testing.T) { name: "subdomain match with word boundary", setup: func() { regex, _ := regexp.Compile(`bad\.example\.com`) - Blacklist = []regexp.Regexp{*regex} + blacklist = 
[]regexp.Regexp{*regex} }, url: "https://bad.example.com/path", expected: true, @@ -89,7 +90,7 @@ func TestIsBlacklisted(t *testing.T) { regex1, _ := regexp.Compile(`badsite\.com`) regex2, _ := regexp.Compile(`malicious\.org`) regex3, _ := regexp.Compile(`example\.com/sensitive`) - Blacklist = []regexp.Regexp{*regex1, *regex2, *regex3} + blacklist = []regexp.Regexp{*regex1, *regex2, *regex3} }, url: "https://example.com/sensitive/data", expected: true, @@ -100,7 +101,7 @@ func TestIsBlacklisted(t *testing.T) { regex1, _ := regexp.Compile(`badsite\.com`) regex2, _ := regexp.Compile(`malicious\.org`) regex3, _ := regexp.Compile(`example\.com/sensitive`) - Blacklist = []regexp.Regexp{*regex1, *regex2, *regex3} + blacklist = []regexp.Regexp{*regex1, *regex2, *regex3} }, url: "https://example.com/safe/data", expected: false, @@ -109,7 +110,7 @@ func TestIsBlacklisted(t *testing.T) { name: "pattern with wildcard", setup: func() { regex, _ := regexp.Compile(`.*\.evil\.com`) - Blacklist = []regexp.Regexp{*regex} + blacklist = []regexp.Regexp{*regex} }, url: "https://subdomain.evil.com/path", expected: true, @@ -118,7 +119,7 @@ func TestIsBlacklisted(t *testing.T) { name: "pattern with special characters", setup: func() { regex, _ := regexp.Compile(`example\.com/path\?id=[0-9]+`) - Blacklist = []regexp.Regexp{*regex} + blacklist = []regexp.Regexp{*regex} }, url: "https://example.com/path?id=12345", expected: true, @@ -127,7 +128,7 @@ func TestIsBlacklisted(t *testing.T) { name: "unicode character support", setup: func() { regex, _ := regexp.Compile(`example\.com/[\p{L}]+`) - Blacklist = []regexp.Regexp{*regex} + blacklist = []regexp.Regexp{*regex} }, url: "https://example.com/café", expected: true, @@ -145,12 +146,88 @@ func TestIsBlacklisted(t *testing.T) { } } -func TestLoadBlacklist(t *testing.T) { - // Save original blacklist to restore after test - originalBlacklist := Blacklist +// TestBlacklistLoading tests that the blacklist loading logic works with a mock blacklist 
file +func TestBlacklistLoading(t *testing.T) { + // Save original blacklist and config + originalBlacklist := blacklist originalConfigPath := config.CONFIG.BlacklistPath defer func() { - Blacklist = originalBlacklist + blacklist = originalBlacklist + config.CONFIG.BlacklistPath = originalConfigPath + }() + + // Create a temporary blacklist file with known patterns + tmpFile, err := os.CreateTemp("", "mock-blacklist-*.txt") + if err != nil { + t.Fatalf("Failed to create temporary file: %v", err) + } + defer os.Remove(tmpFile.Name()) + + // Write some test patterns to the mock blacklist file + mockBlacklistContent := `# Mock blacklist file for testing +/git/ +/.git/ +/cgit/ +gemini://git\..*$ +gemini://.*/git/.* +gopher://.*/git/.* +.*/(commit|blob|tree)/.* +.*/[0-9a-f]{7,40}$ +` + if err := os.WriteFile(tmpFile.Name(), []byte(mockBlacklistContent), 0o644); err != nil { + t.Fatalf("Failed to write to temporary file: %v", err) + } + + // Configure and load the mock blacklist + blacklist = nil + config.CONFIG.BlacklistPath = tmpFile.Name() + err = Initialize() + if err != nil { + t.Fatalf("Failed to load mock blacklist: %v", err) + } + + // Count the number of non-comment, non-empty lines to verify loading + lineCount := 0 + for _, line := range strings.Split(mockBlacklistContent, "\n") { + if line != "" && !strings.HasPrefix(line, "#") { + lineCount++ + } + } + + if len(blacklist) != lineCount { + t.Errorf("Expected %d patterns to be loaded, got %d", lineCount, len(blacklist)) + } + + // Verify some sample URLs against our known patterns + testURLs := []struct { + url string + expected bool + desc string + }{ + {"gemini://example.com/git/repo", true, "git repository"}, + {"gemini://git.example.com", true, "git subdomain"}, + {"gemini://example.com/cgit/repo", true, "cgit repository"}, + {"gemini://example.com/repo/commit/abc123", true, "git commit"}, + {"gemini://example.com/123abc7", true, "commit hash at path end"}, + {"gopher://example.com/1/git/repo", true, 
"gopher git repository"}, + {"gemini://example.com/normal/page.gmi", false, "normal gemini page"}, + {"gemini://example.com/project/123abc", false, "hash not at path end"}, + } + + for _, tt := range testURLs { + result := IsBlacklisted(tt.url) + if result != tt.expected { + t.Errorf("With mock blacklist, IsBlacklisted(%q) = %v, want %v", tt.url, result, tt.expected) + } + } +} + +func TestLoadBlacklist(t *testing.T) { + // Save original blacklist to restore after test + originalBlacklist := blacklist + originalConfigPath := config.CONFIG.BlacklistPath + defer func() { + blacklist = originalBlacklist config.CONFIG.BlacklistPath = originalConfigPath }() @@ -161,7 +238,7 @@ func TestLoadBlacklist(t *testing.T) { } defer os.Remove(tmpFile.Name()) - // Test cases for LoadBlacklist + // Test cases for Initialize tests := []struct { name string blacklistLines []string @@ -202,7 +279,7 @@ func TestLoadBlacklist(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Reset blacklist - Blacklist = nil + blacklist = nil // Set config path config.CONFIG.BlacklistPath = tt.configPath @@ -219,29 +296,186 @@ func TestLoadBlacklist(t *testing.T) { } // Call the function - err := LoadBlacklist() + err := Initialize() // Check results if (err != nil) != tt.wantErr { - t.Errorf("LoadBlacklist() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("Initialize() error = %v, wantErr %v", err, tt.wantErr) return } - if !tt.wantErr && len(Blacklist) != tt.expectedLen { - t.Errorf("LoadBlacklist() loaded %d entries, want %d", len(Blacklist), tt.expectedLen) + if !tt.wantErr && len(blacklist) != tt.expectedLen { + t.Errorf("Initialize() loaded %d entries, want %d", len(blacklist), tt.expectedLen) + } + }) + } +} + +// TestGitPatterns tests the blacklist patterns specifically for Git repositories +func TestGitPatterns(t *testing.T) { + // Save original blacklist to restore after test + originalBlacklist := blacklist + defer func() { + blacklist = originalBlacklist 
+ }() + + // Create patterns similar to those in the blacklist.txt file + patterns := []string{ + "/git/", + "/.git/", + "/cgit/", + "/gitweb/", + "/gitea/", + "/scm/", + ".*/(commit|blob|tree|tag|diff|blame|log|raw)/.*", + ".*/(commits|objects|refs|branches|tags)/.*", + ".*/[0-9a-f]{7,40}$", + "gemini://git\\..*$", + "gemini://.*/git/.*", + "gemini://.*\\.git/.*", + "gopher://.*/git/.*", + } + + // Compile and set up the patterns + blacklist = []regexp.Regexp{} + for _, pattern := range patterns { + regex, err := regexp.Compile(pattern) + if err != nil { + t.Fatalf("Failed to compile pattern %q: %v", pattern, err) + } + blacklist = append(blacklist, *regex) + } + + // Test URLs against git-related patterns + tests := []struct { + url string + expected bool + desc string + }{ + // Git paths + {"gemini://example.com/git/", true, "basic git path"}, + {"gemini://example.com/.git/", true, "hidden git path"}, + {"gemini://example.com/cgit/", true, "cgit path"}, + {"gemini://example.com/gitweb/", true, "gitweb path"}, + {"gemini://example.com/gitea/", true, "gitea path"}, + {"gemini://example.com/scm/", true, "scm path"}, + + // Git operations + {"gemini://example.com/repo/commit/abc123", true, "commit path"}, + {"gemini://example.com/repo/blob/main/README.md", true, "blob path"}, + {"gemini://example.com/repo/tree/master", true, "tree path"}, + {"gemini://example.com/repo/tag/v1.0", true, "tag path"}, + + // Git internals + {"gemini://example.com/repo/commits/", true, "commits path"}, + {"gemini://example.com/repo/objects/", true, "objects path"}, + {"gemini://example.com/repo/refs/heads/main", true, "refs path"}, + + // Git hashes + {"gemini://example.com/commit/a1b2c3d", true, "short hash"}, + {"gemini://example.com/commit/a1b2c3d4e5f6a7b8c9d0e1f2a3b4c5d6e7f8a9b0", true, "long hash"}, + + // Git domains + {"gemini://git.example.com/", true, "git subdomain"}, + {"gemini://example.com/git/repo", true, "git directory"}, + {"gemini://example.com/project.git/", true, "git 
extension"}, + + // Gopher protocol + {"gopher://example.com/1/git/repo", true, "gopher git path"}, + + // Non-matching URLs + {"gemini://example.com/project/", false, "regular project path"}, + {"gemini://example.com/blog/", false, "blog path"}, + {"gemini://example.com/git-guide.gmi", false, "hyphenated word with git"}, + {"gemini://example.com/digital/", false, "word containing 'git'"}, + {"gemini://example.com/ab12cd3", true, "short hex string matches commit hash pattern"}, + {"gemini://example.com/ab12cdz", false, "alphanumeric string with non-hex chars won't match commit hash"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + result := IsBlacklisted(tt.url) + if result != tt.expected { + t.Errorf("IsBlacklisted(%q) = %v, want %v", tt.url, result, tt.expected) + } + }) + } +} + +// TestGeminiGopherPatterns tests the blacklist patterns specific to Gemini and Gopher protocols +func TestGeminiGopherPatterns(t *testing.T) { + // Save original blacklist to restore after test + originalBlacklist := blacklist + defer func() { + blacklist = originalBlacklist + }() + + // Create patterns for Gemini and Gopher + patterns := []string{ + "gemini://badhost\\.com", + "gemini://.*/cgi-bin/", + "gemini://.*/private/", + "gemini://.*\\.evil\\..*", + "gopher://badhost\\.org", + "gopher://.*/I/onlyfans/", + "gopher://.*/[0-9]/(cgi|bin)/", + } + + // Compile and set up the patterns + blacklist = []regexp.Regexp{} + for _, pattern := range patterns { + regex, err := regexp.Compile(pattern) + if err != nil { + t.Fatalf("Failed to compile pattern %q: %v", pattern, err) + } + blacklist = append(blacklist, *regex) + } + + // Test URLs against Gemini and Gopher patterns + tests := []struct { + url string + expected bool + desc string + }{ + // Gemini URLs + {"gemini://badhost.com/", true, "blacklisted gemini host"}, + {"gemini://badhost.com/page.gmi", true, "blacklisted gemini host with path"}, + {"gemini://example.com/cgi-bin/script.cgi", true, "gemini 
cgi-bin path"}, + {"gemini://example.com/private/docs", true, "gemini private path"}, + {"gemini://subdomain.evil.org", true, "gemini evil domain pattern"}, + {"gemini://example.com/public/docs", false, "safe gemini path"}, + {"gemini://goodhost.com/", false, "safe gemini host"}, + + // Gopher URLs + {"gopher://badhost.org/1/menu", true, "blacklisted gopher host"}, + {"gopher://example.org/I/onlyfans/image", true, "gopher onlyfans path"}, + {"gopher://example.org/1/cgi/script", true, "gopher cgi path"}, + {"gopher://example.org/1/bin/executable", true, "gopher bin path"}, + {"gopher://example.org/0/text", false, "safe gopher text"}, + {"gopher://goodhost.org/1/menu", false, "safe gopher host"}, + + // Protocol distinction + {"https://badhost.com/", false, "blacklisted host but wrong protocol"}, + {"http://example.com/cgi-bin/script.cgi", false, "bad path but wrong protocol"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + result := IsBlacklisted(tt.url) + if result != tt.expected { + t.Errorf("IsBlacklisted(%q) = %v, want %v", tt.url, result, tt.expected) } }) } } -// TestIsBlacklistedIntegration tests the integration between LoadBlacklist and IsBlacklisted func TestIsBlacklistedIntegration(t *testing.T) { // Save original blacklist to restore after test - originalBlacklist := Blacklist - originalConfigPath := config.CONFIG.BlacklistPath + originalBlacklist := blacklist + originalBlacklistPath := config.CONFIG.BlacklistPath defer func() { - Blacklist = originalBlacklist - config.CONFIG.BlacklistPath = originalConfigPath + blacklist = originalBlacklist + config.CONFIG.BlacklistPath = originalBlacklistPath }() // Create a temporary blacklist file for testing @@ -264,12 +498,12 @@ malicious\.org } // Set up the test - Blacklist = nil + blacklist = nil config.CONFIG.BlacklistPath = tmpFile.Name() // Load the blacklist - if err := LoadBlacklist(); err != nil { - t.Fatalf("LoadBlacklist() failed: %v", err) + if err := Initialize(); err != nil 
{ + t.Fatalf("Initialize() failed: %v", err) } // Test URLs against the loaded blacklist diff --git a/common/linkList/linkList.go b/common/linkList/linkList.go index 12615e9..c5e786a 100644 --- a/common/linkList/linkList.go +++ b/common/linkList/linkList.go @@ -10,8 +10,15 @@ import ( type LinkList []url.URL -func (l *LinkList) Value() (driver.Value, error) { - return json.Marshal(l) +func (l LinkList) Value() (driver.Value, error) { + if len(l) == 0 { + return nil, nil + } + data, err := json.Marshal(l) + if err != nil { + return nil, err + } + return data, nil } func (l *LinkList) Scan(value interface{}) error { @@ -19,7 +26,7 @@ func (l *LinkList) Scan(value interface{}) error { *l = nil return nil } - b, ok := value.([]byte) // Type assertion! Converts to []byte + b, ok := value.([]byte) if !ok { return fmt.Errorf("failed to scan LinkList: expected []byte, got %T", value) } diff --git a/common/shared.go b/common/shared.go index 2ecb642..d190f49 100644 --- a/common/shared.go +++ b/common/shared.go @@ -1,11 +1,13 @@ package common +import "os" + +// FatalErrorsChan accepts errors from workers. +// In case of fatal error, gracefully +// exits the application. var ( - StatusChan chan WorkerStatus - // ErrorsChan accepts errors from workers. - // In case of fatal error, gracefully - // exits the application. 
- ErrorsChan chan error + FatalErrorsChan chan error + SignalsChan chan os.Signal ) const VERSION string = "0.0.1" diff --git a/common/snapshot/snapshot.go b/common/snapshot/snapshot.go index 7fa999c..09d1db9 100644 --- a/common/snapshot/snapshot.go +++ b/common/snapshot/snapshot.go @@ -5,12 +5,12 @@ import ( "gemini-grc/common/linkList" commonUrl "gemini-grc/common/url" - "github.com/antanst/go_errors" + "git.antanst.com/antanst/xerrors" "github.com/guregu/null/v5" ) type Snapshot struct { - ID int `db:"ID" json:"ID,omitempty"` + ID int `db:"id" json:"ID,omitempty"` URL commonUrl.URL `db:"url" json:"url,omitempty"` Host string `db:"host" json:"host,omitempty"` Timestamp null.Time `db:"timestamp" json:"timestamp,omitempty"` @@ -27,7 +27,7 @@ type Snapshot struct { func SnapshotFromURL(u string, normalize bool) (*Snapshot, error) { url, err := commonUrl.ParseURL(u, "", normalize) if err != nil { - return nil, go_errors.NewError(err) + return nil, xerrors.NewError(err, 0, "", false) } newSnapshot := Snapshot{ URL: *url, diff --git a/common/url/url.go b/common/url/url.go index 4a3b0fb..66ee4bf 100644 --- a/common/url/url.go +++ b/common/url/url.go @@ -9,7 +9,7 @@ import ( "strconv" "strings" - "github.com/antanst/go_errors" + "git.antanst.com/antanst/xerrors" ) type URL struct { @@ -29,7 +29,7 @@ func (u *URL) Scan(value interface{}) error { } b, ok := value.(string) if !ok { - return go_errors.NewFatalError(fmt.Errorf("database scan error: expected string, got %T", value)) + return xerrors.NewError(fmt.Errorf("database scan error: expected string, got %T", value), 0, "", true) } parsedURL, err := ParseURL(b, "", false) if err != nil { @@ -82,7 +82,7 @@ func ParseURL(input string, descr string, normalize bool) (*URL, error) { } else { u, err = url.Parse(input) if err != nil { - return nil, go_errors.NewError(fmt.Errorf("error parsing URL: %w: %s", err, input)) + return nil, xerrors.NewError(fmt.Errorf("error parsing URL: %w: %s", err, input), 0, "", false) } } 
protocol := u.Scheme @@ -99,7 +99,7 @@ func ParseURL(input string, descr string, normalize bool) (*URL, error) { } port, err := strconv.Atoi(strPort) if err != nil { - return nil, go_errors.NewError(fmt.Errorf("error parsing URL: %w: %s", err, input)) + return nil, xerrors.NewError(fmt.Errorf("error parsing URL: %w: %s", err, input), 0, "", false) } full := fmt.Sprintf("%s://%s:%d%s", protocol, hostname, port, urlPath) // full field should also contain query params and url fragments @@ -145,13 +145,13 @@ func NormalizeURL(rawURL string) (*url.URL, error) { // Parse the URL u, err := url.Parse(rawURL) if err != nil { - return nil, go_errors.NewError(fmt.Errorf("error normalizing URL: %w: %s", err, rawURL)) + return nil, xerrors.NewError(fmt.Errorf("error normalizing URL: %w: %s", err, rawURL), 0, "", false) } if u.Scheme == "" { - return nil, go_errors.NewError(fmt.Errorf("error normalizing URL: No scheme: %s", rawURL)) + return nil, xerrors.NewError(fmt.Errorf("error normalizing URL: No scheme: %s", rawURL), 0, "", false) } if u.Host == "" { - return nil, go_errors.NewError(fmt.Errorf("error normalizing URL: No host: %s", rawURL)) + return nil, xerrors.NewError(fmt.Errorf("error normalizing URL: No host: %s", rawURL), 0, "", false) } // Convert scheme to lowercase @@ -275,7 +275,7 @@ func ExtractRedirectTargetFromHeader(currentURL URL, input string) (*URL, error) re := regexp.MustCompile(pattern) matches := re.FindStringSubmatch(input) if len(matches) < 2 { - return nil, go_errors.NewError(fmt.Errorf("error extracting redirect target from string %s", input)) + return nil, xerrors.NewError(fmt.Errorf("error extracting redirect target from string %s", input), 0, "", false) } newURL, err := DeriveAbsoluteURL(currentURL, matches[1]) if err != nil { diff --git a/common/worker.go b/common/worker.go index 20a65d5..89463f2 100644 --- a/common/worker.go +++ b/common/worker.go @@ -1,245 +1,271 @@ package common import ( + "context" + "database/sql" + "errors" "fmt" "time" 
"gemini-grc/common/blackList" - errors2 "gemini-grc/common/errors" + "gemini-grc/common/contextlog" + commonErrors "gemini-grc/common/errors" "gemini-grc/common/snapshot" url2 "gemini-grc/common/url" - _db "gemini-grc/db" + "gemini-grc/common/whiteList" + "gemini-grc/config" + "gemini-grc/contextutil" + gemdb "gemini-grc/db" "gemini-grc/gemini" "gemini-grc/gopher" "gemini-grc/hostPool" "gemini-grc/logging" - "github.com/antanst/go_errors" + "gemini-grc/robotsMatch" + "git.antanst.com/antanst/xerrors" "github.com/guregu/null/v5" "github.com/jmoiron/sqlx" ) -func CrawlOneURL(db *sqlx.DB, url *string) error { - parsedURL, err := url2.ParseURL(*url, "", true) +func RunWorkerWithTx(workerID int, job string) { + // Extract host from URL for the context. + parsedURL, err := url2.ParseURL(job, "", true) if err != nil { - return err + logging.LogInfo("Failed to parse job URL: %s Error: %s", job, err) + return } + host := parsedURL.Hostname - if !url2.IsGeminiUrl(parsedURL.String()) && !url2.IsGopherURL(parsedURL.String()) { - return go_errors.NewError(fmt.Errorf("error parsing URL: not a Gemini or Gopher URL: %s", parsedURL.String())) - } + // Create a new worker context + baseCtx := context.Background() + ctx, cancel := contextutil.NewRequestContext(baseCtx, job, host, workerID) + defer cancel() // Ensure the context is cancelled when we're done + ctx = contextutil.ContextWithComponent(ctx, "worker") + contextlog.LogDebugWithContext(ctx, logging.GetSlogger(), "Starting worker for URL %s", job) - tx, err := db.Beginx() + // Create a new db transaction + tx, err := gemdb.Database.NewTx(ctx) if err != nil { - return go_errors.NewFatalError(err) - } - - err = _db.InsertURL(tx, parsedURL.Full) - if err != nil { - return err - } - - err = workOnUrl(0, tx, parsedURL.Full) - if err != nil { - return err - } - - err = tx.Commit() - if err != nil { - //if _db.IsDeadlockError(err) { - // logging.LogError("Deadlock detected. 
Rolling back") - // time.Sleep(time.Duration(10) * time.Second) - // err := tx.Rollback() - // return go_errors.NewFatalError(err) - //} - return go_errors.NewFatalError(err) - } - logging.LogInfo("Done") - return nil -} - -func SpawnWorkers(numOfWorkers int, db *sqlx.DB) { - logging.LogInfo("Spawning %d workers", numOfWorkers) - go PrintWorkerStatus(numOfWorkers, StatusChan) - - for i := range numOfWorkers { - go func(i int) { - UpdateWorkerStatus(i, "Waiting to start") - // Jitter to avoid starting everything at the same time - time.Sleep(time.Duration(i+2) * time.Second) - for { - // TODO: Use cancellable context with tx, logger & worker ID. - // ctx := context.WithCancel() - // ctx = context.WithValue(ctx, common.CtxKeyLogger, &RequestLogger{r: r}) - RunWorkerWithTx(i, db) - } - }(i) - } -} - -func RunWorkerWithTx(workerID int, db *sqlx.DB) { - defer func() { - UpdateWorkerStatus(workerID, "Done") - }() - - tx, err := db.Beginx() - if err != nil { - ErrorsChan <- err + FatalErrorsChan <- err return } - err = runWorker(workerID, tx) + err = runWorker(ctx, tx, []string{job}) if err != nil { - // TODO: Rollback in this case? - ErrorsChan <- err - return - } - - logging.LogDebug("[%3d] Committing transaction", workerID) - err = tx.Commit() - // On deadlock errors, rollback and return, otherwise panic. - if err != nil { - logging.LogError("[%3d] Failed to commit transaction: %w", workerID, err) - if _db.IsDeadlockError(err) { - logging.LogError("[%3d] Deadlock detected. 
Rolling back", workerID) - time.Sleep(time.Duration(10) * time.Second) - err := tx.Rollback() - if err != nil { - panic(fmt.Sprintf("[%3d] Failed to roll back transaction: %v", workerID, err)) + // Handle context cancellation and timeout errors gracefully, instead of treating them as fatal + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + contextlog.LogDebugWithContext(ctx, logging.GetSlogger(), "Worker timed out or canceled: %v", err) + rollbackErr := SafeRollback(ctx, tx) + if rollbackErr != nil { + FatalErrorsChan <- rollbackErr + return } return } - panic(fmt.Sprintf("[%3d] Failed to commit transaction: %v", workerID, err)) + // For other errors, we treat them as fatal. + contextlog.LogErrorWithContext(ctx, logging.GetSlogger(), "Worker failed: %v", err) + rollbackErr := SafeRollback(ctx, tx) + if rollbackErr != nil { + FatalErrorsChan <- rollbackErr + } + FatalErrorsChan <- err + return } - logging.LogDebug("[%3d] Worker done!", workerID) + + contextlog.LogDebugWithContext(ctx, logging.GetSlogger(), "Committing transaction") + err = tx.Commit() + if err != nil && !errors.Is(err, sql.ErrTxDone) { + contextlog.LogErrorWithContext(ctx, logging.GetSlogger(), "Failed to commit transaction: %v", err) + if rollbackErr := SafeRollback(ctx, tx); rollbackErr != nil { + FatalErrorsChan <- err + return + } + } + contextlog.LogDebugWithContext(ctx, logging.GetSlogger(), "Worker done!") } -func runWorker(workerID int, tx *sqlx.Tx) error { - var urls []string - var err error - - UpdateWorkerStatus(workerID, "Getting URLs from DB") - urls, err = _db.GetRandomUrls(tx) - // urls, err = _db.GetRandomUrlsWithBasePath(tx) - if err != nil { - return err - } else if len(urls) == 0 { - logging.LogInfo("[%3d] No URLs to visit, sleeping...", workerID) - UpdateWorkerStatus(workerID, "No URLs to visit, sleeping...") - time.Sleep(1 * time.Minute) - return nil - } - - // Start visiting URLs. 
- total := len(urls) - for i, u := range urls { - logging.LogInfo("[%3d] Starting %d/%d %s", workerID, i+1, total, u) - UpdateWorkerStatus(workerID, fmt.Sprintf("Starting %d/%d %s", i+1, total, u)) - err := workOnUrl(workerID, tx, u) - if err != nil { - return err +// SafeRollback attempts to roll back a transaction, +// handling the case if the tx was already finalized. +func SafeRollback(ctx context.Context, tx *sqlx.Tx) error { + rollbackErr := tx.Rollback() + if rollbackErr != nil { + // Check if it's the standard "transaction already finalized" error + if errors.Is(rollbackErr, sql.ErrTxDone) { + contextlog.LogWarnWithContext(ctx, logging.GetSlogger(), "Rollback failed because transaction is already finalized") + return nil } - logging.LogDebug("[%3d] Done %d/%d.", workerID, i+1, total) - UpdateWorkerStatus(workerID, fmt.Sprintf("Done %d/%d %s", i+1, total, u)) + // Only panic for other types of rollback failures + contextlog.LogErrorWithContext(ctx, logging.GetSlogger(), "Failed to rollback transaction: %v", rollbackErr) + return xerrors.NewError(fmt.Errorf("failed to rollback transaction: %w", rollbackErr), 0, "", true) } return nil } -// workOnUrl visits a URL and stores the result. +func runWorker(ctx context.Context, tx *sqlx.Tx, urls []string) error { + total := len(urls) + for i, u := range urls { + contextlog.LogDebugWithContext(ctx, logging.GetSlogger(), "Starting %d/%d %s", i+1, total, u) + urlCtx, cancelFunc := context.WithCancel(ctx) + err := WorkOnUrl(urlCtx, tx, u) + cancelFunc() + if err != nil { + return err + } + contextlog.LogDebugWithContext(ctx, logging.GetSlogger(), "Done %d/%d.", i+1, total) + } + return nil +} + +// WorkOnUrl visits a URL and stores the result. // unexpected errors are returned. // expected errors are stored within the snapshot. 
-func workOnUrl(workerID int, tx *sqlx.Tx, url string) (err error) { - s, err := snapshot.SnapshotFromURL(url, false) +func WorkOnUrl(ctx context.Context, tx *sqlx.Tx, url string) (err error) { + // Create a context specifically for this URL with "url" component + urlCtx := contextutil.ContextWithComponent(ctx, "url") + + contextlog.LogDebugWithContext(urlCtx, logging.GetSlogger(), "Processing URL: %s", url) + + s, err := snapshot.SnapshotFromURL(url, true) if err != nil { + contextlog.LogErrorWithContext(urlCtx, logging.GetSlogger(), "Failed to parse URL: %v", err) return err } isGemini := url2.IsGeminiUrl(s.URL.String()) isGopher := url2.IsGopherURL(s.URL.String()) + if !isGemini && !isGopher { - return go_errors.NewError(fmt.Errorf("not a Gopher or Gemini URL: %s", s.URL.String())) + return xerrors.NewError(fmt.Errorf("not a Gopher or Gemini URL: %s", s.URL.String()), 0, "", false) } - if blackList.IsBlacklisted(s.URL.String()) { - logging.LogInfo("[%3d] URL matches blacklist, ignoring", workerID) - s.Error = null.StringFrom(errors2.ErrBlacklistMatch.Error()) - return saveSnapshotAndRemoveURL(tx, s) + if isGopher && !config.CONFIG.GopherEnable { + contextlog.LogDebugWithContext(urlCtx, logging.GetSlogger(), "Skipping gopher URL (disabled in config)") + return nil } - if isGemini { + if url != s.URL.Full { + err = gemdb.Database.NormalizeURL(ctx, tx, url, s.URL.Full) + if err != nil { + return err + } + contextlog.LogDebugWithContext(urlCtx, logging.GetSlogger(), "Normalized URL: %s → %s", url, s.URL.Full) + url = s.URL.Full + } + + // Check if URL is whitelisted + isUrlWhitelisted := whiteList.IsWhitelisted(s.URL.String()) + if isUrlWhitelisted { + contextlog.LogInfoWithContext(urlCtx, logging.GetSlogger(), "URL matches whitelist, forcing crawl %s", url) + } + + // Only check blacklist if URL is not whitelisted + if !isUrlWhitelisted && blackList.IsBlacklisted(s.URL.String()) { + contextlog.LogInfoWithContext(urlCtx, logging.GetSlogger(), "URL matches blacklist, 
ignoring %s", url) + s.Error = null.StringFrom(commonErrors.ErrBlacklistMatch.Error()) + return saveSnapshotAndRemoveURL(ctx, tx, s) + } + + // Only check robots.txt if URL is not whitelisted and is a Gemini URL + var robotMatch bool + if !isUrlWhitelisted && isGemini { // If URL matches a robots.txt disallow line, // add it as an error and remove url - robotMatch, err := gemini.RobotMatch(s.URL.String()) + robotMatch, err = robotsMatch.RobotMatch(urlCtx, s.URL.String()) if err != nil { - // robotMatch returns only network errors! - // we stop because we don't want to hit - // the server with another request on this case. + if commonErrors.IsHostError(err) { + return removeURL(ctx, tx, s.URL.String()) + } return err } if robotMatch { - logging.LogInfo("[%3d] URL matches robots.txt, ignoring", workerID) - s.Error = null.StringFrom(errors2.ErrRobotsMatch.Error()) - return saveSnapshotAndRemoveURL(tx, s) + contextlog.LogInfoWithContext(urlCtx, logging.GetSlogger(), "URL matches robots.txt, skipping") + s.Error = null.StringFrom(commonErrors.ErrRobotsMatch.Error()) + return saveSnapshotAndRemoveURL(ctx, tx, s) } } - logging.LogDebug("[%3d] Adding to pool %s", workerID, s.URL.String()) - UpdateWorkerStatus(workerID, fmt.Sprintf("Adding to pool %s", s.URL.String())) - hostPool.AddHostToHostPool(s.Host) - defer func(s string) { - hostPool.RemoveHostFromPool(s) - }(s.Host) + contextlog.LogDebugWithContext(urlCtx, logging.GetSlogger(), "Adding to host pool") + err = hostPool.AddHostToHostPool(urlCtx, s.Host) + if err != nil { + contextlog.LogErrorWithContext(urlCtx, logging.GetSlogger(), "Failed to add host to pool: %v", err) + return err + } - logging.LogDebug("[%3d] Visiting %s", workerID, s.URL.String()) - UpdateWorkerStatus(workerID, fmt.Sprintf("Visiting %s", s.URL.String())) + defer func(ctx context.Context, host string) { + hostPool.RemoveHostFromPool(ctx, host) + }(urlCtx, s.Host) + contextlog.LogDebugWithContext(urlCtx, logging.GetSlogger(), "Visiting %s", 
s.URL.String()) + + // Use context-aware visits for both protocols if isGopher { - s, err = gopher.Visit(s.URL.String()) + // Use the context-aware version for Gopher visits + s, err = gopher.VisitWithContext(urlCtx, s.URL.String()) } else { - s, err = gemini.Visit(s.URL.String()) + // Use the context-aware version for Gemini visits + s, err = gemini.Visit(urlCtx, s.URL.String()) } if err != nil { + contextlog.LogErrorWithContext(urlCtx, logging.GetSlogger(), "Error visiting URL: %v", err) return err } + if s == nil { + contextlog.LogDebugWithContext(urlCtx, logging.GetSlogger(), "No snapshot returned") + return nil + } // Handle Gemini redirection. if isGemini && s.ResponseCode.ValueOrZero() >= 30 && s.ResponseCode.ValueOrZero() < 40 { - err = handleRedirection(workerID, tx, s) + err = handleRedirection(urlCtx, tx, s) if err != nil { return fmt.Errorf("error while handling redirection: %s", err) } } - // Store links + // Check if content is identical to previous snapshot and we should skip further processing + if config.CONFIG.SkipIdenticalContent { + identical, err := gemdb.Database.IsContentIdentical(ctx, tx, s) + if err != nil { + return err + } + if identical { + contextlog.LogDebugWithContext(urlCtx, logging.GetSlogger(), "Content identical to existing snapshot, skipping") + return removeURL(ctx, tx, s.URL.String()) + } + } + + // Process and store links since content has changed if len(s.Links.ValueOrZero()) > 0 { - logging.LogDebug("[%3d] Found %d links", workerID, len(s.Links.ValueOrZero())) - err = storeLinks(tx, s) + contextlog.LogDebugWithContext(urlCtx, logging.GetSlogger(), "Found %d links", len(s.Links.ValueOrZero())) + err = storeLinks(ctx, tx, s) if err != nil { return err } } - logging.LogInfo("[%3d] %2d %s", workerID, s.ResponseCode.ValueOrZero(), s.URL.String()) - return saveSnapshotAndRemoveURL(tx, s) + // Save the snapshot and remove the URL from the queue + contextlog.LogInfoWithContext(urlCtx, logging.GetSlogger(), "%2d %s", 
s.ResponseCode.ValueOrZero(), s.URL.String()) + return saveSnapshotAndRemoveURL(ctx, tx, s) } -func storeLinks(tx *sqlx.Tx, s *snapshot.Snapshot) error { +// storeLinks checks and stores the snapshot links in the database. +func storeLinks(ctx context.Context, tx *sqlx.Tx, s *snapshot.Snapshot) error { if s.Links.Valid { //nolint:nestif for _, link := range s.Links.ValueOrZero() { if shouldPersistURL(&link) { - visited, err := haveWeVisitedURL(tx, link.Full) + visited, err := haveWeVisitedURL(ctx, tx, link.Full) if err != nil { return err } if !visited { - err := _db.InsertURL(tx, link.Full) + err := gemdb.Database.InsertURL(ctx, tx, link.Full) if err != nil { return err } } else { - logging.LogDebug("Link already persisted: %s", link.Full) + contextlog.LogDebugWithContext(ctx, logging.GetSlogger(), "Link already persisted: %s", link.Full) } } } @@ -247,74 +273,117 @@ func storeLinks(tx *sqlx.Tx, s *snapshot.Snapshot) error { return nil } -func saveSnapshotAndRemoveURL(tx *sqlx.Tx, s *snapshot.Snapshot) error { - err := _db.OverwriteSnapshot(tx, s) - if err != nil { - return err - } - err = _db.DeleteURL(tx, s.URL.String()) - if err != nil { - return err - } - return nil +// Context-aware version of removeURL +func removeURL(ctx context.Context, tx *sqlx.Tx, url string) error { + return gemdb.Database.DeleteURL(ctx, tx, url) } -// shouldPersistURL returns true if we -// should save the URL in the _db. -// Only gemini:// urls are saved. +// Context-aware version of saveSnapshotAndRemoveURL +func saveSnapshotAndRemoveURL(ctx context.Context, tx *sqlx.Tx, s *snapshot.Snapshot) error { + err := gemdb.Database.SaveSnapshot(ctx, tx, s) + if err != nil { + return err + } + return gemdb.Database.DeleteURL(ctx, tx, s.URL.String()) +} + +// shouldPersistURL returns true if the given URL is a +// non-blacklisted Gemini or Gopher URL. 
func shouldPersistURL(u *url2.URL) bool { - return url2.IsGeminiUrl(u.String()) || url2.IsGopherURL(u.String()) + if blackList.IsBlacklisted(u.String()) { + return false + } + if config.CONFIG.GopherEnable && url2.IsGopherURL(u.String()) { + return true + } + return url2.IsGeminiUrl(u.String()) } -func haveWeVisitedURL(tx *sqlx.Tx, u string) (bool, error) { +func haveWeVisitedURL(ctx context.Context, tx *sqlx.Tx, u string) (bool, error) { var result []bool - err := tx.Select(&result, `SELECT TRUE FROM urls WHERE url=$1`, u) + + // Check if the context is cancelled + if err := ctx.Err(); err != nil { + return false, err + } + + // Check the urls table which holds the crawl queue. + err := tx.SelectContext(ctx, &result, `SELECT TRUE FROM urls WHERE url=$1`, u) if err != nil { - return false, go_errors.NewFatalError(fmt.Errorf("database error: %w", err)) + return false, xerrors.NewError(fmt.Errorf("database error: %w", err), 0, "", true) } if len(result) > 0 { - return result[0], nil + return false, nil } - err = tx.Select(&result, `SELECT TRUE FROM snapshots WHERE snapshots.url=$1`, u) - if err != nil { - return false, go_errors.NewFatalError(fmt.Errorf("database error: %w", err)) - } - if len(result) > 0 { - return result[0], nil + + // If we're skipping URLs based on recent updates, check if this URL has been + // crawled within the specified number of days + if config.CONFIG.SkipIfUpdatedDays > 0 { + var recentSnapshots []bool + cutoffDate := time.Now().AddDate(0, 0, -config.CONFIG.SkipIfUpdatedDays) + + // Check if the context is cancelled + if err := ctx.Err(); err != nil { + return false, err + } + + err = tx.SelectContext(ctx, &recentSnapshots, ` + SELECT TRUE FROM snapshots + WHERE snapshots.url=$1 + AND timestamp > $2 + LIMIT 1`, u, cutoffDate) + if err != nil { + return false, xerrors.NewError(fmt.Errorf("database error checking recent snapshots: %w", err), 0, "", true) + } + + if len(recentSnapshots) > 0 { + contextlog.LogDebugWithContext(ctx, 
logging.GetSlogger(), "Skipping URL %s (updated within last %d days)", u, config.CONFIG.SkipIfUpdatedDays) + return true, nil + } } + return false, nil } -// handleRedirection saves redirection URL. -func handleRedirection(workerID int, tx *sqlx.Tx, s *snapshot.Snapshot) error { - newURL, err := url2.ExtractRedirectTargetFromHeader(s.URL, s.Error.ValueOrZero()) +func handleRedirection(ctx context.Context, tx *sqlx.Tx, s *snapshot.Snapshot) error { + // Create a context specifically for redirection handling + redirectCtx := contextutil.ContextWithComponent(ctx, "redirect") + + // Use the redirectCtx for all operations + newURL, err := url2.ExtractRedirectTargetFromHeader(s.URL, s.Header.ValueOrZero()) + if err != nil { + contextlog.LogErrorWithContext(redirectCtx, logging.GetSlogger(), "Failed to extract redirect target: %v", err) + return err + } + contextlog.LogDebugWithContext(redirectCtx, logging.GetSlogger(), "Page redirects to %s", newURL) + + haveWeVisited, err := haveWeVisitedURL(redirectCtx, tx, newURL.String()) if err != nil { return err } - logging.LogDebug("[%3d] Page redirects to %s", workerID, newURL) - haveWeVisited, _ := haveWeVisitedURL(tx, newURL.String()) if shouldPersistURL(newURL) && !haveWeVisited { - err = _db.InsertURL(tx, newURL.Full) + err = gemdb.Database.InsertURL(redirectCtx, tx, newURL.Full) if err != nil { + contextlog.LogErrorWithContext(redirectCtx, logging.GetSlogger(), "Failed to insert redirect URL: %v", err) return err } - logging.LogDebug("[%3d] Saved redirection URL %s", workerID, newURL.String()) + contextlog.LogDebugWithContext(redirectCtx, logging.GetSlogger(), "Saved redirection URL %s", newURL.String()) } return nil } -func GetSnapshotFromURL(tx *sqlx.Tx, url string) ([]snapshot.Snapshot, error) { - query := ` - SELECT * - FROM snapshots - WHERE url=$1 - LIMIT 1 - ` - var snapshots []snapshot.Snapshot - err := tx.Select(&snapshots, query, url) - if err != nil { - return nil, err - } - return snapshots, nil -} +//func 
GetSnapshotFromURL(tx *sqlx.Tx, url string) ([]snapshot.Snapshot, error) { +// query := ` +// SELECT * +// FROM snapshots +// WHERE url=$1 +// LIMIT 1 +// ` +// var snapshots []snapshot.Snapshot +// err := tx.Select(&snapshots, query, url) +// if err != nil { +// return nil, err +// } +// return snapshots, nil +//} diff --git a/config/config.go b/config/config.go index f08b09c..2f4e75b 100644 --- a/config/config.go +++ b/config/config.go @@ -1,166 +1,93 @@ package config import ( + "flag" "fmt" + "log/slog" "os" - "strconv" - - "github.com/rs/zerolog" -) - -// Environment variable names. -const ( - EnvLogLevel = "LOG_LEVEL" - EnvNumWorkers = "NUM_OF_WORKERS" - EnvWorkerBatchSize = "WORKER_BATCH_SIZE" - EnvMaxResponseSize = "MAX_RESPONSE_SIZE" - EnvResponseTimeout = "RESPONSE_TIMEOUT" - EnvPanicOnUnexpectedError = "PANIC_ON_UNEXPECTED_ERROR" - EnvBlacklistPath = "BLACKLIST_PATH" - EnvDryRun = "DRY_RUN" - EnvPrintWorkerStatus = "PRINT_WORKER_STATUS" ) // Config holds the application configuration loaded from environment variables. type Config struct { - LogLevel zerolog.Level // Logging level (debug, info, warn, error) - MaxResponseSize int // Maximum size of response in bytes - NumOfWorkers int // Number of concurrent workers - ResponseTimeout int // Timeout for responses in seconds - WorkerBatchSize int // Batch size for worker processing - PanicOnUnexpectedError bool // Panic on unexpected errors when visiting a URL - BlacklistPath string // File that has blacklisted strings of "host:port" - DryRun bool // If false, don't write to disk - PrintWorkerStatus bool // If false, don't print worker status table + PgURL string + LogLevel slog.Level // Logging level (debug, info, warn, error) + MaxResponseSize int // Maximum size of response in bytes + MaxDbConnections int // Maximum number of database connections. 
+ NumOfWorkers int // Number of concurrent workers + ResponseTimeout int // Timeout for responses in seconds + BlacklistPath string // File that has blacklisted strings of "host:port" + WhitelistPath string // File with URLs that should always be crawled regardless of blacklist + DryRun bool // If false, don't write to disk + GopherEnable bool // Enable Gopher crawling + SeedUrlPath string // Add URLs from file to queue + SkipIdenticalContent bool // When true, skip storing snapshots with identical content + SkipIfUpdatedDays int // Skip re-crawling URLs updated within this many days (0 to disable, default 0) } var CONFIG Config //nolint:gochecknoglobals -// parsePositiveInt parses and validates positive integer values. -func parsePositiveInt(param, value string) (int, error) { - val, err := strconv.Atoi(value) - if err != nil { - return 0, ValidationError{ - Param: param, - Value: value, - Reason: "must be a valid integer", - } - } - if val <= 0 { - return 0, ValidationError{ - Param: param, - Value: value, - Reason: "must be positive", - } - } - return val, nil -} - -func parseBool(param, value string) (bool, error) { - val, err := strconv.ParseBool(value) - if err != nil { - return false, ValidationError{ - Param: param, - Value: value, - Reason: "cannot be converted to boolean", - } - } - return val, nil -} - -// GetConfig loads and validates configuration from environment variables -func GetConfig() *Config { +// Initialize loads and validates configuration from environment variables +func Initialize() *Config { config := &Config{} - // Map of environment variables to their parsing functions - parsers := map[string]func(string) error{ - EnvLogLevel: func(v string) error { - level, err := zerolog.ParseLevel(v) - if err != nil { - return ValidationError{ - Param: EnvLogLevel, - Value: v, - Reason: "must be one of: debug, info, warn, error", - } - } - config.LogLevel = level - return nil - }, - EnvNumWorkers: func(v string) error { - val, err := 
parsePositiveInt(EnvNumWorkers, v) - if err != nil { - return err - } - config.NumOfWorkers = val - return nil - }, - EnvWorkerBatchSize: func(v string) error { - val, err := parsePositiveInt(EnvWorkerBatchSize, v) - if err != nil { - return err - } - config.WorkerBatchSize = val - return nil - }, - EnvMaxResponseSize: func(v string) error { - val, err := parsePositiveInt(EnvMaxResponseSize, v) - if err != nil { - return err - } - config.MaxResponseSize = val - return nil - }, - EnvResponseTimeout: func(v string) error { - val, err := parsePositiveInt(EnvResponseTimeout, v) - if err != nil { - return err - } - config.ResponseTimeout = val - return nil - }, - EnvPanicOnUnexpectedError: func(v string) error { - val, err := parseBool(EnvPanicOnUnexpectedError, v) - if err != nil { - return err - } - config.PanicOnUnexpectedError = val - return nil - }, - EnvBlacklistPath: func(v string) error { - config.BlacklistPath = v - return nil - }, - EnvDryRun: func(v string) error { - val, err := parseBool(EnvDryRun, v) - if err != nil { - return err - } - config.DryRun = val - return nil - }, - EnvPrintWorkerStatus: func(v string) error { - val, err := parseBool(EnvPrintWorkerStatus, v) - if err != nil { - return err - } - config.PrintWorkerStatus = val - return nil - }, - } + loglevel := flag.String("log-level", "info", "Logging level (debug, info, warn, error)") + pgURL := flag.String("pgurl", "", "Postgres URL") + dryRun := flag.Bool("dry-run", false, "Dry run mode (default false)") + gopherEnable := flag.Bool("gopher", false, "Enable crawling of Gopher holes (default false)") + maxDbConnections := flag.Int("max-db-connections", 100, "Maximum number of database connections (default 100)") + numOfWorkers := flag.Int("workers", 1, "Number of concurrent workers (default 1)") + maxResponseSize := flag.Int("max-response-size", 1024*1024, "Maximum size of response in bytes (default 1MB)") + responseTimeout := flag.Int("response-timeout", 10, "Timeout for network responses in 
 seconds (default 10)") + blacklistPath := flag.String("blacklist-path", "", "File that has blacklist regexes") + skipIdenticalContent := flag.Bool("skip-identical-content", true, "Skip storing snapshots with identical content (default true)") + skipIfUpdatedDays := flag.Int("skip-if-updated-days", 60, "Skip re-crawling URLs updated within this many days (0 to disable, default 60)") + whitelistPath := flag.String("whitelist-path", "", "File with URLs that should always be crawled regardless of blacklist") + seedUrlPath := flag.String("seed-url-path", "", "File with seed URLs that should be added to the queue immediately") - // Process each environment variable - for envVar, parser := range parsers { - value, ok := os.LookupEnv(envVar) - if !ok { - fmt.Fprintf(os.Stderr, "Missing required environment variable: %s\n", envVar) - os.Exit(1) - } + flag.Parse() - if err := parser(value); err != nil { - fmt.Fprintf(os.Stderr, "Configuration error: %v\n", err) - os.Exit(1) - } + config.PgURL = *pgURL + config.DryRun = *dryRun + config.GopherEnable = *gopherEnable + config.NumOfWorkers = *numOfWorkers + config.MaxResponseSize = *maxResponseSize + config.ResponseTimeout = *responseTimeout + config.BlacklistPath = *blacklistPath + config.WhitelistPath = *whitelistPath + config.SeedUrlPath = *seedUrlPath + config.MaxDbConnections = *maxDbConnections + config.SkipIdenticalContent = *skipIdenticalContent + config.SkipIfUpdatedDays = *skipIfUpdatedDays + + level, err := ParseSlogLevel(*loglevel) + if err != nil { + _, _ = fmt.Fprint(os.Stderr, err.Error()) + os.Exit(-1) } + config.LogLevel = level return config } + +// ParseSlogLevel converts a string level to slog.Level +func ParseSlogLevel(levelStr string) (slog.Level, error) { + switch levelStr { + case "debug": + return slog.LevelDebug, nil + case "info": + return slog.LevelInfo, nil + case "warn": + return slog.LevelWarn, nil + case "error": + return slog.LevelError, nil + default: + return slog.LevelInfo, 
fmt.Errorf("invalid log level: %s", levelStr) + } +} + +// Convert method for backward compatibility with existing codebase +// This can be removed once all references to Convert() are updated +func (c *Config) Convert() *Config { + // Just return the config itself as it now directly contains slog.Level + return c +} diff --git a/gemini/geminiLinks.go b/gemini/geminiLinks.go index c77fd84..862eff4 100644 --- a/gemini/geminiLinks.go +++ b/gemini/geminiLinks.go @@ -9,7 +9,7 @@ import ( url2 "gemini-grc/common/url" "gemini-grc/logging" "gemini-grc/util" - "github.com/antanst/go_errors" + "git.antanst.com/antanst/xerrors" ) func GetPageLinks(currentURL url2.URL, gemtext string) linkList.LinkList { @@ -37,14 +37,14 @@ func ParseGeminiLinkLine(linkLine string, currentURL string) (*url2.URL, error) // Check: currentURL is parseable baseURL, err := url.Parse(currentURL) if err != nil { - return nil, go_errors.NewError(fmt.Errorf("error parsing link line: %w input '%s'", err, linkLine)) + return nil, xerrors.NewError(fmt.Errorf("error parsing link line: %w input '%s'", err, linkLine), 0, "", false) } // Extract the actual URL and the description re := regexp.MustCompile(`^=>[ \t]+(\S+)([ \t]+.*)?`) matches := re.FindStringSubmatch(linkLine) if len(matches) == 0 { - return nil, go_errors.NewError(fmt.Errorf("error parsing link line: no regexp match for line %s", linkLine)) + return nil, xerrors.NewError(fmt.Errorf("error parsing link line: no regexp match for line %s", linkLine), 0, "", false) } originalURLStr := matches[1] @@ -52,7 +52,7 @@ func ParseGeminiLinkLine(linkLine string, currentURL string) (*url2.URL, error) // Check: Unescape the URL if escaped _, err = url.QueryUnescape(originalURLStr) if err != nil { - return nil, go_errors.NewError(fmt.Errorf("error parsing link line: %w input '%s'", err, linkLine)) + return nil, xerrors.NewError(fmt.Errorf("error parsing link line: %w input '%s'", err, linkLine), 0, "", false) } description := "" @@ -63,7 +63,7 @@ func 
ParseGeminiLinkLine(linkLine string, currentURL string) (*url2.URL, error) // Parse the URL from the link line parsedURL, err := url.Parse(originalURLStr) if err != nil { - return nil, go_errors.NewError(fmt.Errorf("error parsing link line: %w input '%s'", err, linkLine)) + return nil, xerrors.NewError(fmt.Errorf("error parsing link line: %w input '%s'", err, linkLine), 0, "", false) } // If link URL is relative, resolve full URL @@ -80,7 +80,7 @@ func ParseGeminiLinkLine(linkLine string, currentURL string) (*url2.URL, error) finalURL, err := url2.ParseURL(parsedURL.String(), description, true) if err != nil { - return nil, go_errors.NewError(fmt.Errorf("error parsing link line: %w input '%s'", err, linkLine)) + return nil, xerrors.NewError(fmt.Errorf("error parsing link line: %w input '%s'", err, linkLine), 0, "", false) } return finalURL, nil diff --git a/gemini/network.go b/gemini/network.go index bc1c674..be76ae2 100644 --- a/gemini/network.go +++ b/gemini/network.go @@ -1,166 +1,23 @@ package gemini import ( - "crypto/tls" "fmt" - "io" - "net" - stdurl "net/url" "regexp" "slices" "strconv" "strings" - "time" - errors2 "gemini-grc/common/errors" + commonErrors "gemini-grc/common/errors" "gemini-grc/common/snapshot" - _url "gemini-grc/common/url" - "gemini-grc/config" - "gemini-grc/logging" - "github.com/antanst/go_errors" "github.com/guregu/null/v5" ) -// Visit given URL, using the Gemini protocol. -// Mutates given Snapshot with the data. -// In case of error, we store the error string -// inside snapshot and return the error. -func Visit(url string) (s *snapshot.Snapshot, err error) { - s, err = snapshot.SnapshotFromURL(url, true) - if err != nil { - return nil, err - } - - defer func() { - if err != nil { - // GeminiError and HostError should - // be stored in the snapshot. Other - // errors are returned. 
- if errors2.IsHostError(err) { - s.Error = null.StringFrom(err.Error()) - err = nil - } else if IsGeminiError(err) { - s.Error = null.StringFrom(err.Error()) - s.Header = null.StringFrom(go_errors.Unwrap(err).(*GeminiError).Header) - s.ResponseCode = null.IntFrom(int64(go_errors.Unwrap(err).(*GeminiError).Code)) - err = nil - } else { - s = nil - } - } - }() - - data, err := ConnectAndGetData(s.URL.String()) - if err != nil { - return s, err - } - - s, err = processData(*s, data) - if err != nil { - return s, err - } - - if isGeminiCapsule(s) { - links := GetPageLinks(s.URL, s.GemText.String) - if len(links) > 0 { - logging.LogDebug("Found %d links", len(links)) - s.Links = null.ValueFrom(links) - } - } - return s, nil -} - -func ConnectAndGetData(url string) ([]byte, error) { - parsedURL, err := stdurl.Parse(url) - if err != nil { - return nil, go_errors.NewError(err) - } - hostname := parsedURL.Hostname() - port := parsedURL.Port() - if port == "" { - port = "1965" - } - host := fmt.Sprintf("%s:%s", hostname, port) - timeoutDuration := time.Duration(config.CONFIG.ResponseTimeout) * time.Second - // Establish the underlying TCP connection. - dialer := &net.Dialer{ - Timeout: timeoutDuration, - } - conn, err := dialer.Dial("tcp", host) - if err != nil { - return nil, errors2.NewHostError(err) - } - // Make sure we always close the connection. - defer func() { - _ = conn.Close() - }() - - // Set read and write timeouts on the TCP connection. - err = conn.SetReadDeadline(time.Now().Add(timeoutDuration)) - if err != nil { - return nil, errors2.NewHostError(err) - } - err = conn.SetWriteDeadline(time.Now().Add(timeoutDuration)) - if err != nil { - return nil, errors2.NewHostError(err) - } - - // Perform the TLS handshake - tlsConfig := &tls.Config{ - InsecureSkipVerify: true, //nolint:gosec // Accept all TLS certs, even if insecure. 
- ServerName: parsedURL.Hostname(), // SNI says we should not include port in hostname - // MinVersion: tls.VersionTLS12, // Use a minimum TLS version. Warning breaks a lot of sites. - } - tlsConn := tls.Client(conn, tlsConfig) - err = tlsConn.SetReadDeadline(time.Now().Add(timeoutDuration)) - if err != nil { - return nil, errors2.NewHostError(err) - } - err = tlsConn.SetWriteDeadline(time.Now().Add(timeoutDuration)) - if err != nil { - return nil, errors2.NewHostError(err) - } - err = tlsConn.Handshake() - if err != nil { - return nil, errors2.NewHostError(err) - } - - // We read `buf`-sized chunks and add data to `data`. - buf := make([]byte, 4096) - var data []byte - - // Send Gemini request to trigger server response. - // Fix for stupid server bug: - // Some servers return 'Header: 53 No proxying to other hosts or ports!' - // when the port is 1965 and is still specified explicitly in the URL. - url2, _ := _url.ParseURL(url, "", true) - _, err = tlsConn.Write([]byte(fmt.Sprintf("%s\r\n", url2.StringNoDefaultPort()))) - if err != nil { - return nil, errors2.NewHostError(err) - } - // Read response bytes in len(buf) byte chunks - for { - n, err := tlsConn.Read(buf) - if n > 0 { - data = append(data, buf[:n]...) - } - if len(data) > config.CONFIG.MaxResponseSize { - return nil, errors2.NewHostError(err) - } - if err != nil { - if go_errors.Is(err, io.EOF) { - break - } - return nil, errors2.NewHostError(err) - } - } - return data, nil -} - -func processData(s snapshot.Snapshot, data []byte) (*snapshot.Snapshot, error) { +// ProcessData processes the raw data from a Gemini response and populates the Snapshot. +// This function is exported for use by the robotsMatch package. 
+func ProcessData(s snapshot.Snapshot, data []byte) (*snapshot.Snapshot, error) { header, body, err := getHeadersAndData(data) if err != nil { - return nil, err + return &s, err } code, mimeType, lang := getMimeTypeAndLang(header) @@ -198,7 +55,7 @@ func processData(s snapshot.Snapshot, data []byte) (*snapshot.Snapshot, error) { func getHeadersAndData(data []byte) (string, []byte, error) { firstLineEnds := slices.Index(data, '\n') if firstLineEnds == -1 { - return "", nil, errors2.NewHostError(fmt.Errorf("error parsing header")) + return "", nil, commonErrors.NewHostError(fmt.Errorf("error parsing header")) } firstLine := string(data[:firstLineEnds]) rest := data[firstLineEnds+1:] @@ -252,4 +109,4 @@ func getMimeTypeAndLang(headers string) (int, string, string) { func isGeminiCapsule(s *snapshot.Snapshot) bool { return !s.Error.Valid && s.MimeType.Valid && s.MimeType.String == "text/gemini" -} \ No newline at end of file +} diff --git a/gemini/network_test.go b/gemini/network_test.go index 6c0656a..94d0a9b 100644 --- a/gemini/network_test.go +++ b/gemini/network_test.go @@ -135,7 +135,7 @@ func TestProcessData(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := snapshot.Snapshot{} - result, err := processData(s, test.inputData) + result, err := ProcessData(s, test.inputData) if test.expectedError && err == nil { t.Errorf("Expected error, got nil") @@ -175,192 +175,3 @@ func TestProcessData(t *testing.T) { }) } } - -//// Mock Gemini server for testing ConnectAndGetData -//func mockGeminiServer(response string, delay time.Duration, closeConnection bool) net.Listener { -// listener, err := net.Listen("tcp", "127.0.0.1:0") // Bind to a random available port -// if err != nil { -// panic(fmt.Sprintf("Failed to create mock server: %v", err)) -// } -// -// go func() { -// conn, err := listener.Accept() -// if err != nil { -// if !closeConnection { // Don't panic if we closed the connection on purpose -// panic(fmt.Sprintf("Failed to 
accept connection: %v", err)) -// } -// return -// } -// defer conn.Close() -// -// time.Sleep(delay) // Simulate network latency -// -// _, err = conn.Write([]byte(response)) -// if err != nil && !closeConnection { -// panic(fmt.Sprintf("Failed to write response: %v", err)) -// } -// }() -// -// return listener -//} - -// func TestConnectAndGetData(t *testing.T) { -// config.CONFIG = config.ConfigStruct{ -// ResponseTimeout: 5, -// MaxResponseSize: 1024 * 1024, -// } -// tests := []struct { -// name string -// serverResponse string -// serverDelay time.Duration -// expectedData []byte -// expectedError bool -// closeConnection bool -// }{ -// { -// name: "Successful response", -// serverResponse: "20 text/gemini\r\n# Hello", -// expectedData: []byte("20 text/gemini\r\n# Hello"), -// expectedError: false, -// }, -// { -// name: "Server error", -// serverResponse: "50 Server error\r\n", -// expectedData: []byte("50 Server error\r\n"), -// expectedError: false, -// }, -// { -// name: "Timeout", -// serverDelay: 6 * time.Second, // Longer than the timeout -// expectedError: true, -// }, -// { -// name: "Server closes connection", -// closeConnection: true, -// expectedError: true, -// }, -// } - -// for _, test := range tests { -// t.Run(test.name, func(t *testing.T) { -// listener := mockGeminiServer(test.serverResponse, test.serverDelay, test.closeConnection) -// defer func() { -// test.closeConnection = true // Prevent panic in mock server -// listener.Close() -// }() -// addr := listener.Addr().String() -// data, err := ConnectAndGetData(fmt.Sprintf("gemini://%s/", addr)) - -// if test.expectedError && err == nil { -// t.Errorf("Expected error, got nil") -// } - -// if !test.expectedError && err != nil { -// t.Errorf("Unexpected error: %v", err) -// } - -// if !slices.Equal(data, test.expectedData) { -// t.Errorf("Expected data '%s', got '%s'", test.expectedData, data) -// } -// }) -// } -// } - -// func TestVisit(t *testing.T) { -// config.CONFIG = 
config.ConfigStruct{ -// ResponseTimeout: 5, -// MaxResponseSize: 1024 * 1024, -// } -// tests := []struct { -// name string -// serverResponse string -// expectedCode int -// expectedMime string -// expectedError bool -// expectedLinks []string -// }{ -// { -// name: "Successful response", -// serverResponse: "20 text/gemini\r\n# Hello\n=> /link1 Link 1\n=> /link2 Link 2", -// expectedCode: 20, -// expectedMime: "text/gemini", -// expectedError: false, -// expectedLinks: []string{"gemini://127.0.0.1:1965/link1", "gemini://127.0.0.1:1965/link2"}, -// }, -// { -// name: "Server error", -// serverResponse: "50 Server error\r\n", -// expectedCode: 50, -// expectedMime: "Server error", -// expectedError: false, -// expectedLinks: []string{}, -// }, -// } - -// for _, test := range tests { -// t.Run(test.name, func(t *testing.T) { -// listener := mockGeminiServer(test.serverResponse, 0, false) -// defer listener.Close() -// addr := listener.Addr().String() -// snapshot, err := Visit(fmt.Sprintf("gemini://%s/", addr)) - -// if test.expectedError && err == nil { -// t.Errorf("Expected error, got nil") -// } - -// if !test.expectedError && err != nil { -// t.Errorf("Unexpected error: %v", err) -// } - -// if snapshot.ResponseCode.ValueOrZero() != int64(test.expectedCode) { -// t.Errorf("Expected code %d, got %d", test.expectedCode, snapshot.ResponseCode.ValueOrZero()) -// } - -// if snapshot.MimeType.ValueOrZero() != test.expectedMime { -// t.Errorf("Expected mimeType '%s', got '%s'", test.expectedMime, snapshot.MimeType.ValueOrZero()) -// } - -// if test.expectedLinks != nil { -// links, _ := snapshot.Links.Value() - -// if len(links) != len(test.expectedLinks) { -// t.Errorf("Expected %d links, got %d", len(test.expectedLinks), len(links)) -// } -// for i, link := range links { -// if link != test.expectedLinks[i] { -// t.Errorf("Expected link '%s', got '%s'", test.expectedLinks[i], link) -// } -// } -// } -// }) -// } -// } - -func TestVisit_InvalidURL(t *testing.T) { - 
t.Parallel() - _, err := Visit("invalid-url") - if err == nil { - t.Errorf("Expected error for invalid URL, got nil") - } -} - -//func TestVisit_GeminiError(t *testing.T) { -// listener := mockGeminiServer("51 Not Found\r\n", 0, false) -// defer listener.Close() -// addr := listener.Addr().String() -// -// s, err := Visit(fmt.Sprintf("gemini://%s/", addr)) -// if err != nil { -// t.Errorf("Unexpected error: %v", err) -// } -// -// expectedError := "51 Not Found" -// if s.Error.ValueOrZero() != expectedError { -// t.Errorf("Expected error in snapshot: %v, got %v", expectedError, s.Error) -// } -// -// expectedCode := 51 -// if s.ResponseCode.ValueOrZero() != int64(expectedCode) { -// t.Errorf("Expected code %d, got %d", expectedCode, s.ResponseCode.ValueOrZero()) -// } -//} diff --git a/gemini/processing.go b/gemini/processing.go index 0d86758..771c841 100644 --- a/gemini/processing.go +++ b/gemini/processing.go @@ -7,7 +7,7 @@ import ( "io" "unicode/utf8" - "github.com/antanst/go_errors" + "git.antanst.com/antanst/xerrors" "golang.org/x/text/encoding/charmap" "golang.org/x/text/encoding/japanese" "golang.org/x/text/encoding/korean" @@ -25,7 +25,7 @@ func BytesToValidUTF8(input []byte) (string, error) { } const maxSize = 10 * 1024 * 1024 // 10MB if len(input) > maxSize { - return "", go_errors.NewError(fmt.Errorf("%w: %d bytes (max %d)", ErrInputTooLarge, len(input), maxSize)) + return "", xerrors.NewError(fmt.Errorf("%w: %d bytes (max %d)", ErrInputTooLarge, len(input), maxSize), 0, "", false) } // remove NULL byte 0x00 (ReplaceAll accepts slices) inputNoNull := bytes.ReplaceAll(input, []byte{byte(0)}, []byte{}) @@ -56,5 +56,5 @@ func BytesToValidUTF8(input []byte) (string, error) { } } - return "", go_errors.NewError(fmt.Errorf("%w (tried %d encodings): %w", ErrUTF8Conversion, len(encodings), lastErr)) + return "", xerrors.NewError(fmt.Errorf("%w (tried %d encodings): %w", ErrUTF8Conversion, len(encodings), lastErr), 0, "", false) } diff --git 
a/gemini/robotmatch.go b/gemini/robotmatch.go deleted file mode 100644 index 07819b7..0000000 --- a/gemini/robotmatch.go +++ /dev/null @@ -1,95 +0,0 @@ -package gemini - -import ( - "fmt" - "strings" - "sync" - - "gemini-grc/common/snapshot" - geminiUrl "gemini-grc/common/url" - "gemini-grc/logging" -) - -// RobotsCache is a map of blocked URLs -// key: URL -// value: []string list of disallowed URLs -// If a key has no blocked URLs, an empty -// list is stored for caching. -var RobotsCache sync.Map //nolint:gochecknoglobals - -func populateRobotsCache(key string) (entries []string, _err error) { - // We either store an empty list when - // no rules, or a list of disallowed URLs. - // This applies even if we have an error - // finding/downloading robots.txt - defer func() { - RobotsCache.Store(key, entries) - }() - url := fmt.Sprintf("gemini://%s/robots.txt", key) - robotsContent, err := ConnectAndGetData(url) - if err != nil { - return []string{}, err - } - s, err := snapshot.SnapshotFromURL(url, true) - if err != nil { - return []string{}, nil - } - s, err = processData(*s, robotsContent) - if err != nil { - logging.LogDebug("robots.txt error %s", err) - return []string{}, nil - } - if s.ResponseCode.ValueOrZero() != 20 { - logging.LogDebug("robots.txt error code %d, ignoring", s.ResponseCode.ValueOrZero()) - return []string{}, nil - } - // Some return text/plain, others text/gemini. - // According to spec, the first is correct, - // however let's be lenient - var data string - switch { - case s.MimeType.ValueOrZero() == "text/plain": - data = string(s.Data.ValueOrZero()) - case s.MimeType.ValueOrZero() == "text/gemini": - data = s.GemText.ValueOrZero() - default: - return []string{}, nil - } - entries = ParseRobotsTxt(data, key) - return entries, nil -} - -// RobotMatch checks if the snapshot URL matches -// a robots.txt allow rule. 
-func RobotMatch(u string) (bool, error) { - url, err := geminiUrl.ParseURL(u, "", true) - if err != nil { - return false, err - } - key := strings.ToLower(fmt.Sprintf("%s:%d", url.Hostname, url.Port)) - var disallowedURLs []string - cacheEntries, ok := RobotsCache.Load(key) - if !ok { - // First time check, populate robot cache - disallowedURLs, err := populateRobotsCache(key) - if err != nil { - return false, err - } - if len(disallowedURLs) > 0 { - logging.LogDebug("Added to robots.txt cache: %v => %v", key, disallowedURLs) - } - } else { - disallowedURLs, _ = cacheEntries.([]string) - } - return isURLblocked(disallowedURLs, url.Full), nil -} - -func isURLblocked(disallowedURLs []string, input string) bool { - for _, url := range disallowedURLs { - if strings.HasPrefix(strings.ToLower(input), url) { - logging.LogDebug("robots.txt match: %s matches %s", input, url) - return true - } - } - return false -} diff --git a/gemini/robots.go b/gemini/robots.go deleted file mode 100644 index 0653b62..0000000 --- a/gemini/robots.go +++ /dev/null @@ -1,31 +0,0 @@ -package gemini - -import ( - "fmt" - "strings" -) - -// ParseRobotsTxt takes robots.txt content and a host, and -// returns a list of full URLs that shouldn't -// be visited. -// TODO Also take into account the user agent? 
-// Check gemini://geminiprotocol.net/docs/companion/robots.gmi -func ParseRobotsTxt(content string, host string) []string { - var disallowedPaths []string - for _, line := range strings.Split(content, "\n") { - line = strings.TrimSpace(line) - line = strings.ToLower(line) - if strings.HasPrefix(line, "disallow:") { - parts := strings.SplitN(line, ":", 2) - if len(parts) == 2 { - path := strings.TrimSpace(parts[1]) - if path != "" { - // Construct full Gemini URL - disallowedPaths = append(disallowedPaths, - fmt.Sprintf("gemini://%s%s", host, path)) - } - } - } - } - return disallowedPaths -} diff --git a/gemini/robots_test.go b/gemini/robots_test.go deleted file mode 100644 index e73e7b5..0000000 --- a/gemini/robots_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package gemini - -import ( - "reflect" - "testing" -) - -func TestParseRobotsTxt(t *testing.T) { - t.Parallel() - input := `User-agent: * -Disallow: /cgi-bin/wp.cgi/view -Disallow: /cgi-bin/wp.cgi/media -User-agent: googlebot -Disallow: /admin/` - - expected := []string{ - "gemini://example.com/cgi-bin/wp.cgi/view", - "gemini://example.com/cgi-bin/wp.cgi/media", - "gemini://example.com/admin/", - } - - result := ParseRobotsTxt(input, "example.com") - - if !reflect.DeepEqual(result, expected) { - t.Errorf("ParseRobotsTxt() = %v, want %v", result, expected) - } -} - -func TestParseRobotsTxtEmpty(t *testing.T) { - t.Parallel() - input := `` - - result := ParseRobotsTxt(input, "example.com") - - if len(result) != 0 { - t.Errorf("ParseRobotsTxt() = %v, want empty []string", result) - } -} - -func TestIsURLblocked(t *testing.T) { - t.Parallel() - disallowedURLs := []string{ - "gemini://example.com/cgi-bin/wp.cgi/view", - "gemini://example.com/cgi-bin/wp.cgi/media", - "gemini://example.com/admin/", - } - url := "gemini://example.com/admin/index.html" - if !isURLblocked(disallowedURLs, url) { - t.Errorf("Expected %s to be blocked", url) - } - url = "gemini://example1.com/admin/index.html" - if 
isURLblocked(disallowedURLs, url) { - t.Errorf("expected %s to not be blocked", url) - } -} diff --git a/gopher/network.go b/gopher/network.go index 4e33857..3268656 100644 --- a/gopher/network.go +++ b/gopher/network.go @@ -1,6 +1,7 @@ package gopher import ( + "errors" "fmt" "io" "net" @@ -8,17 +9,11 @@ import ( "regexp" "strings" "time" - "unicode/utf8" - errors2 "gemini-grc/common/errors" - "gemini-grc/common/linkList" - "gemini-grc/common/snapshot" - "gemini-grc/common/text" - _url "gemini-grc/common/url" + commonErrors "gemini-grc/common/errors" "gemini-grc/config" "gemini-grc/logging" - "github.com/antanst/go_errors" - "github.com/guregu/null/v5" + "git.antanst.com/antanst/xerrors" ) // References: @@ -62,64 +57,10 @@ import ( // The original Gopher protocol only specified types 0-9, `+`, `g`, `I`, and `T`. // The others were added by various implementations and extensions over time. -// Error methodology: -// HostError for DNS/network errors -// GopherError for network/gopher errors -// NewError for other errors -// NewFatalError for other fatal errors - -func Visit(url string) (*snapshot.Snapshot, error) { - s, err := snapshot.SnapshotFromURL(url, false) - if err != nil { - return nil, err - } - - data, err := connectAndGetData(url) - if err != nil { - logging.LogDebug("Error: %s", err.Error()) - if IsGopherError(err) || errors2.IsHostError(err) { - s.Error = null.StringFrom(err.Error()) - return s, nil - } - return nil, err - } - - isValidUTF8 := utf8.ValidString(string(data)) - if isValidUTF8 { - s.GemText = null.StringFrom(text.RemoveNullChars(string(data))) - } else { - s.Data = null.ValueFrom(data) - } - - if !isValidUTF8 { - return s, nil - } - - responseError := checkForError(string(data)) - if responseError != nil { - s.Error = null.StringFrom(responseError.Error()) - return s, nil - } - - links := getGopherPageLinks(string(data)) - linkURLs := linkList.LinkList(make([]_url.URL, len(links))) - for i, link := range links { - linkURL, err := 
_url.ParseURL(link, "", true) - if err == nil { - linkURLs[i] = *linkURL - } - } - if len(links) != 0 { - s.Links = null.ValueFrom(linkURLs) - } - - return s, nil -} - func connectAndGetData(url string) ([]byte, error) { parsedURL, err := stdurl.Parse(url) if err != nil { - return nil, go_errors.NewError(err) + return nil, xerrors.NewError(fmt.Errorf("error parsing URL: %w", err), 0, "", false) } hostname := parsedURL.Hostname() @@ -136,7 +77,7 @@ func connectAndGetData(url string) ([]byte, error) { logging.LogDebug("Dialing %s", host) conn, err := dialer.Dial("tcp", host) if err != nil { - return nil, errors2.NewHostError(err) + return nil, commonErrors.NewHostError(err) } // Make sure we always close the connection. defer func() { @@ -146,11 +87,11 @@ func connectAndGetData(url string) ([]byte, error) { // Set read and write timeouts on the TCP connection. err = conn.SetReadDeadline(time.Now().Add(timeoutDuration)) if err != nil { - return nil, errors2.NewHostError(err) + return nil, commonErrors.NewHostError(err) } err = conn.SetWriteDeadline(time.Now().Add(timeoutDuration)) if err != nil { - return nil, errors2.NewHostError(err) + return nil, commonErrors.NewHostError(err) } // We read `buf`-sized chunks and add data to `data`. @@ -161,7 +102,7 @@ func connectAndGetData(url string) ([]byte, error) { payload := constructPayloadFromPath(parsedURL.Path) _, err = conn.Write([]byte(fmt.Sprintf("%s\r\n", payload))) if err != nil { - return nil, errors2.NewHostError(err) + return nil, commonErrors.NewHostError(err) } // Read response bytes in len(buf) byte chunks for { @@ -170,13 +111,13 @@ func connectAndGetData(url string) ([]byte, error) { data = append(data, buf[:n]...) 
} if err != nil { - if go_errors.Is(err, io.EOF) { + if errors.Is(err, io.EOF) { break } - return nil, errors2.NewHostError(err) + return nil, commonErrors.NewHostError(err) } if len(data) > config.CONFIG.MaxResponseSize { - return nil, errors2.NewHostError(fmt.Errorf("response exceeded max")) + return nil, commonErrors.NewHostError(fmt.Errorf("response exceeded max")) } } logging.LogDebug("Got %d bytes", len(data)) diff --git a/gopher/network_test.go b/gopher/network_test.go index 8bfe32e..8d9a45a 100644 --- a/gopher/network_test.go +++ b/gopher/network_test.go @@ -288,7 +288,7 @@ func TestConnectAndGetDataTimeout(t *testing.T) { // Check if the error is due to timeout if err == nil { t.Error("Expected an error due to timeout, but got no error") - } else if !errors.IsHostError(err) { + } else if !commonErrors.IsHostError(err) { t.Errorf("Expected a HostError, but got: %v", err) } else { // Here you might want to check if the specific error message contains 'timeout' diff --git a/main.go b/main.go deleted file mode 100644 index 939ebf8..0000000 --- a/main.go +++ /dev/null @@ -1,82 +0,0 @@ -package main - -import ( - "fmt" - "os" - "os/signal" - "syscall" - - "gemini-grc/common" - "gemini-grc/common/blackList" - "gemini-grc/config" - "gemini-grc/db" - "gemini-grc/logging" - "github.com/antanst/go_errors" - "github.com/jmoiron/sqlx" - "github.com/rs/zerolog" - zlog "github.com/rs/zerolog/log" -) - -func main() { - config.CONFIG = *config.GetConfig() - zerolog.TimeFieldFormat = zerolog.TimeFormatUnix - zerolog.SetGlobalLevel(config.CONFIG.LogLevel) - zlog.Logger = zlog.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: "[2006-01-02 15:04:05]"}) - err := runApp() - if err != nil { - var asErr *go_errors.Error - if go_errors.As(err, &asErr) { - logging.LogError("Unexpected error: %v", err) - _, _ = fmt.Fprintf(os.Stderr, "Unexpected error: %v", err) - } else { - logging.LogError("Unexpected error: %v", err) - } - os.Exit(1) - } -} - -func runApp() (err error) { - 
logging.LogInfo("gemcrawl %s starting up. Press Ctrl+C to exit", common.VERSION) - signals := make(chan os.Signal, 1) - signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) - - _db, err := db.ConnectToDB() - if err != nil { - return err - } - - defer func(db *sqlx.DB) { - _ = db.Close() - }(_db) - - err = blackList.LoadBlacklist() - if err != nil { - return err - } - - common.StatusChan = make(chan common.WorkerStatus, config.CONFIG.NumOfWorkers) - common.ErrorsChan = make(chan error, config.CONFIG.NumOfWorkers) - - // If there's an argument, visit this - // URL only and don't spawn other workers - if len(os.Args) > 1 { - url := os.Args[1] - err = common.CrawlOneURL(_db, &url) - return err - } - - go common.SpawnWorkers(config.CONFIG.NumOfWorkers, _db) - - for { - select { - case <-signals: - logging.LogWarn("Received SIGINT or SIGTERM signal, exiting") - return nil - case err := <-common.ErrorsChan: - if go_errors.IsFatal(err) { - return err - } - logging.LogError("%s", fmt.Sprintf("%v", err)) - } - } -} diff --git a/uid/uid.go b/uid/uid.go deleted file mode 100644 index b98e342..0000000 --- a/uid/uid.go +++ /dev/null @@ -1,14 +0,0 @@ -package uid - -import ( - nanoid "github.com/matoous/go-nanoid/v2" -) - -func UID() string { - // No 'o','O' and 'l' - id, err := nanoid.Generate("abcdefghijkmnpqrstuvwxyzABCDEFGHIJKLMNPQRSTUVWXYZ0123456789", 20) - if err != nil { - panic(err) - } - return id -} diff --git a/util/util.go b/util/util.go index b5efb82..4b33647 100644 --- a/util/util.go +++ b/util/util.go @@ -6,14 +6,8 @@ import ( "fmt" "math/big" "regexp" - "runtime/debug" ) -func PrintStackAndPanic(err error) { - fmt.Printf("PANIC Error %s Stack trace:\n%s", err, debug.Stack()) - panic("PANIC") -} - // SecureRandomInt returns a cryptographically secure random integer in the range [0,max). // Panics if max <= 0 or if there's an error reading from the system's secure // random number generator. 
@@ -24,14 +18,14 @@ func SecureRandomInt(max int) int { // Generate random number n, err := rand.Int(rand.Reader, maxBig) if err != nil { - PrintStackAndPanic(fmt.Errorf("could not generate a random integer between 0 and %d", max)) + panic(fmt.Errorf("could not generate a random integer between 0 and %d", max)) } // Convert back to int return int(n.Int64()) } -func PrettyJson(data string) string { +func PrettifyJson(data string) string { marshalled, _ := json.MarshalIndent(data, "", " ") return fmt.Sprintf("%s\n", marshalled) } @@ -42,3 +36,27 @@ func GetLinesMatchingRegex(input string, pattern string) []string { matches := re.FindAllString(input, -1) return matches } + +// Filter applies a predicate function to each element in a slice and returns a new slice +// containing only the elements for which the predicate returns true. +// Type parameter T allows this function to work with slices of any type. +func Filter[T any](slice []T, f func(T) bool) []T { + filtered := make([]T, 0) + for _, v := range slice { + if f(v) { + filtered = append(filtered, v) + } + } + return filtered +} + +// Map applies a function to each element in a slice and returns a new slice +// containing the results. +// Type parameters T and R allow this function to work with different input and output types. +func Map[T any, R any](slice []T, f func(T) R) []R { + result := make([]R, len(slice)) + for i, v := range slice { + result[i] = f(v) + } + return result +}