From 4bceb75695520d9cb9451f045afe99efb1b02c64 Mon Sep 17 00:00:00 2001 From: antanst Date: Wed, 26 Feb 2025 10:34:25 +0200 Subject: [PATCH] Reorganize code for more granular imports --- bin/normalizeSnapshot/main.go | 14 +- common/gemini_url_test.go | 252 ---------------- common/linkList/linkList.go | 27 ++ common/shared.go | 13 + common/snapshot.go | 56 ---- common/snapshot/snapshot.go | 38 +++ common/{gemini_url.go => url/url.go} | 120 ++++++-- common/url/url_test.go | 420 +++++++++++++++++++++++++++ common/worker.go | 320 ++++++++++++++++++++ common/workerStatus.go | 36 ++- db/db.go | 160 ++++------ db/db_queries.go | 52 +--- gemini/files.go | 6 +- gemini/gemini.go | 49 ---- gemini/geminiLinks.go | 29 +- gemini/geminiLinks_test.go | 39 ++- gemini/gemini_test.go | 69 ----- gemini/ip-address-pool.go | 54 ---- gemini/network.go | 229 ++++++++------- gemini/network_test.go | 400 +++++++++++++++++++++---- gemini/processing.go | 2 +- gemini/robotmatch.go | 52 ++-- gemini/worker.go | 344 ---------------------- 23 files changed, 1549 insertions(+), 1232 deletions(-) delete mode 100644 common/gemini_url_test.go create mode 100644 common/linkList/linkList.go create mode 100644 common/shared.go delete mode 100644 common/snapshot.go create mode 100644 common/snapshot/snapshot.go rename common/{gemini_url.go => url/url.go} (61%) create mode 100644 common/url/url_test.go create mode 100644 common/worker.go delete mode 100644 gemini/gemini.go delete mode 100644 gemini/gemini_test.go delete mode 100644 gemini/ip-address-pool.go delete mode 100644 gemini/worker.go diff --git a/bin/normalizeSnapshot/main.go b/bin/normalizeSnapshot/main.go index fb6fea7..56de12d 100644 --- a/bin/normalizeSnapshot/main.go +++ b/bin/normalizeSnapshot/main.go @@ -4,7 +4,9 @@ import ( "fmt" "os" - "gemini-grc/gemini" + "gemini-grc/common/snapshot" + "gemini-grc/common/url" + main2 "gemini-grc/db" _ "github.com/jackc/pgx/v5/stdlib" // PGX driver for PostgreSQL "github.com/jmoiron/sqlx" ) @@ -21,7 +23,7 @@ func main() { ORDER BY id LIMIT 10000 OFFSET $1 ` - var snapshots []gemini.Snapshot + var snapshots []snapshot.Snapshot err := tx.Select(&snapshots, query, count) if err != nil { printErrorAndExit(tx, err) @@ -32,8 +34,8 @@ func main() { } for _, s := range snapshots { count++ - escaped := gemini.EscapeURL(s.URL.String()) - normalizedGeminiURL, err := gemini.ParseURL(escaped, "") + escaped := url.EscapeURL(s.URL.String()) + normalizedGeminiURL, err := url.ParseURL(escaped, "", true) if err != nil { fmt.Println(s.URL.String()) fmt.Println(escaped) @@ -47,7 +49,7 @@ func main() { } // If a snapshot already exists with the normalized // URL, delete the current snapshot and leave the other. - var ss []gemini.Snapshot + var ss []snapshot.Snapshot err = tx.Select(&ss, "SELECT * FROM snapshots WHERE URL=$1", normalizedURLString) if err != nil { printErrorAndExit(tx, err) @@ -69,7 +71,7 @@ func main() { // Saves the snapshot with the normalized URL tx.MustExec("DELETE FROM snapshots WHERE id=$1", s.ID) s.URL = *normalizedGeminiURL - err = gemini.UpsertSnapshot(0, tx, &s) + err = main2.OverwriteSnapshot(tx, &s) if err != nil { printErrorAndExit(tx, err) } diff --git a/common/gemini_url_test.go b/common/gemini_url_test.go deleted file mode 100644 index 821946a..0000000 --- a/common/gemini_url_test.go +++ /dev/null @@ -1,252 +0,0 @@ -package common_test - -import ( - "reflect" - "testing" - - "gemini-grc/common" -) - -func TestParseURL(t *testing.T) { - t.Parallel() - input := "gemini://caolan.uk/cgi-bin/weather.py/wxfcs/3162" - parsed, err := common.ParseURL(input, "", true) - value, _ := parsed.Value() - if err != nil || !(value == "gemini://caolan.uk:1965/cgi-bin/weather.py/wxfcs/3162") { - t.Errorf("fail: %s", parsed) - } -} - -func TestDeriveAbsoluteURL_abs_url_input(t *testing.T) { - t.Parallel() - currentURL := common.URL{ - Protocol: "gemini", - Hostname: "smol.gr", - Port: 1965, - Path: "/a/b", - Descr: "Nothing", - Full: "gemini://smol.gr:1965/a/b", - } - input := "gemini://a.b/c" - output, err := common.DeriveAbsoluteURL(currentURL, input) - if err != nil { - t.Errorf("fail: %v", err) - } - expected := &common.URL{ - Protocol: "gemini", - Hostname: "a.b", - Port: 1965, - Path: "/c", - Descr: "", - Full: "gemini://a.b:1965/c", - } - pass := reflect.DeepEqual(output, expected) - if !pass { - t.Errorf("fail: %#v != %#v", output, expected) - } -} - -func TestDeriveAbsoluteURL_abs_path_input(t *testing.T) { - t.Parallel() - currentURL := common.URL{ - Protocol: "gemini", - Hostname: "smol.gr", - Port: 1965, - Path: "/a/b", - Descr: "Nothing", - Full: "gemini://smol.gr:1965/a/b", - } - input := "/c" - output, err := common.DeriveAbsoluteURL(currentURL, input) - if err != nil { - t.Errorf("fail: %v", err) - } - expected := &common.URL{ - Protocol: "gemini", - Hostname: "smol.gr", - Port: 1965, - Path: "/c", - Descr: "", - Full: "gemini://smol.gr:1965/c", - } - pass := reflect.DeepEqual(output, expected) - if !pass { - t.Errorf("fail: %#v != %#v", output, expected) - } -} - -func TestDeriveAbsoluteURL_rel_path_input(t *testing.T) { - t.Parallel() - currentURL := common.URL{ - Protocol: "gemini", - Hostname: "smol.gr", - Port: 1965, - Path: "/a/b", - Descr: "Nothing", - Full: "gemini://smol.gr:1965/a/b", - } - input := "c/d" - output, err := common.DeriveAbsoluteURL(currentURL, input) - if err != nil { - t.Errorf("fail: %v", err) - } - expected := &common.URL{ - Protocol: "gemini", - Hostname: "smol.gr", - Port: 1965, - Path: "/a/b/c/d", - Descr: "", - Full: "gemini://smol.gr:1965/a/b/c/d", - } - pass := reflect.DeepEqual(output, expected) - if !pass { - t.Errorf("fail: %#v != %#v", output, expected) - } -} - -func TestNormalizeURLSlash(t *testing.T) { - t.Parallel() - input := "gemini://uscoffings.net/retro-computing/magazines/" - normalized, _ := common.NormalizeURL(input) - output := normalized.String() - expected := input - pass := reflect.DeepEqual(output, expected) - if !pass { - t.Errorf("fail: %#v != %#v", output, expected) - } -} - -func TestNormalizeURLNoSlash(t *testing.T) { - t.Parallel() - input := "gemini://uscoffings.net/retro-computing/magazines" - normalized, _ := common.NormalizeURL(input) - output := normalized.String() - expected := input - pass := reflect.DeepEqual(output, expected) - if !pass { - t.Errorf("fail: %#v != %#v", output, expected) - } -} - -func TestNormalizeMultiSlash(t *testing.T) { - t.Parallel() - input := "gemini://uscoffings.net/retro-computing/////////a///magazines" - normalized, _ := common.NormalizeURL(input) - output := normalized.String() - expected := "gemini://uscoffings.net/retro-computing/a/magazines" - pass := reflect.DeepEqual(output, expected) - if !pass { - t.Errorf("fail: %#v != %#v", output, expected) - } -} - -func TestNormalizeTrailingSlash(t *testing.T) { - t.Parallel() - input := "gemini://uscoffings.net/" - normalized, _ := common.NormalizeURL(input) - output := normalized.String() - expected := "gemini://uscoffings.net/" - pass := reflect.DeepEqual(output, expected) - if !pass { - t.Errorf("fail: %#v != %#v", output, expected) - } -} - -func TestNormalizeNoTrailingSlash(t *testing.T) { - t.Parallel() - input := "gemini://uscoffings.net" - normalized, _ := common.NormalizeURL(input) - output := normalized.String() - expected := "gemini://uscoffings.net" - pass := reflect.DeepEqual(output, expected) - if !pass { - t.Errorf("fail: %#v != %#v", output, expected) - } -} - -func TestNormalizeTrailingSlashPath(t *testing.T) { - t.Parallel() - input := "gemini://uscoffings.net/a/" - normalized, _ := common.NormalizeURL(input) - output := normalized.String() - expected := "gemini://uscoffings.net/a/" - pass := reflect.DeepEqual(output, expected) - if !pass { - t.Errorf("fail: %#v != %#v", output, expected) - } -} - -func TestNormalizeNoTrailingSlashPath(t *testing.T) { - t.Parallel() - input := "gemini://uscoffings.net/a" - normalized, _ := common.NormalizeURL(input) - output := normalized.String() - expected := "gemini://uscoffings.net/a" - pass := reflect.DeepEqual(output, expected) - if !pass { - t.Errorf("fail: %#v != %#v", output, expected) - } -} - -func TestNormalizeDot(t *testing.T) { - t.Parallel() - input := "gemini://uscoffings.net/retro-computing/./././////a///magazines" - normalized, _ := common.NormalizeURL(input) - output := normalized.String() - expected := "gemini://uscoffings.net/retro-computing/a/magazines" - pass := reflect.DeepEqual(output, expected) - if !pass { - t.Errorf("fail: %#v != %#v", output, expected) - } -} - -func TestNormalizePort(t *testing.T) { - t.Parallel() - input := "gemini://uscoffings.net:1965/a" - normalized, _ := common.NormalizeURL(input) - output := normalized.String() - expected := "gemini://uscoffings.net/a" - pass := reflect.DeepEqual(output, expected) - if !pass { - t.Errorf("fail: %#v != %#v", output, expected) - } -} - -func TestNormalizeURL(t *testing.T) { - t.Parallel() - input := "gemini://chat.gemini.lehmann.cx:11965/" - normalized, _ := common.NormalizeURL(input) - output := normalized.String() - expected := "gemini://chat.gemini.lehmann.cx:11965/" - pass := reflect.DeepEqual(output, expected) - if !pass { - t.Errorf("fail: %#v != %#v", output, expected) - } - - input = "gemini://chat.gemini.lehmann.cx:11965/index?a=1&b=c" - normalized, _ = common.NormalizeURL(input) - output = normalized.String() - expected = "gemini://chat.gemini.lehmann.cx:11965/index?a=1&b=c" - pass = reflect.DeepEqual(output, expected) - if !pass { - t.Errorf("fail: %#v != %#v", output, expected) - } - - input = "gemini://chat.gemini.lehmann.cx:11965/index#1" - normalized, _ = common.NormalizeURL(input) - output = normalized.String() - expected = "gemini://chat.gemini.lehmann.cx:11965/index#1" - pass = reflect.DeepEqual(output, expected) - if !pass { - t.Errorf("fail: %#v != %#v", output, expected) - } - - input = "gemini://gemi.dev/cgi-bin/xkcd.cgi?1494" - normalized, _ = common.NormalizeURL(input) - output = normalized.String() - expected = "gemini://gemi.dev/cgi-bin/xkcd.cgi?1494" - pass = reflect.DeepEqual(output, expected) - if !pass { - t.Errorf("fail: %#v != %#v", output, expected) - } -} diff --git a/common/linkList/linkList.go b/common/linkList/linkList.go new file mode 100644 index 0000000..12615e9 --- /dev/null +++ b/common/linkList/linkList.go @@ -0,0 +1,27 @@ +package linkList + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + + "gemini-grc/common/url" +) + +type LinkList []url.URL + +func (l *LinkList) Value() (driver.Value, error) { + return json.Marshal(l) +} + +func (l *LinkList) Scan(value interface{}) error { + if value == nil { + *l = nil + return nil + } + b, ok := value.([]byte) // Type assertion! Converts to []byte + if !ok { + return fmt.Errorf("failed to scan LinkList: expected []byte, got %T", value) + } + return json.Unmarshal(b, l) +} diff --git a/common/shared.go b/common/shared.go new file mode 100644 index 0000000..2ecb642 --- /dev/null +++ b/common/shared.go @@ -0,0 +1,13 @@ +package common + +var ( + StatusChan chan WorkerStatus + // ErrorsChan accepts errors from workers. + // In case of fatal error, gracefully + // exits the application. + ErrorsChan chan error +) + +const VERSION string = "0.0.1" + +const CtxKeyLogger string = "CtxKeyLogger" diff --git a/common/snapshot.go b/common/snapshot.go deleted file mode 100644 index 810fb15..0000000 --- a/common/snapshot.go +++ /dev/null @@ -1,56 +0,0 @@ -package common - -import ( - "database/sql/driver" - "encoding/json" - "fmt" - "time" - - "github.com/guregu/null/v5" -) - -type LinkList []URL - -func (l *LinkList) Value() (driver.Value, error) { - return json.Marshal(l) -} - -func (l *LinkList) Scan(value interface{}) error { - if value == nil { - *l = nil - return nil - } - b, ok := value.([]byte) // Type assertion! Converts to []byte - if !ok { - return fmt.Errorf("failed to scan LinkList: expected []byte, got %T", value) - } - return json.Unmarshal(b, l) -} - -type Snapshot struct { - ID int `db:"id" json:"id,omitempty"` - URL URL `db:"url" json:"url,omitempty"` - Host string `db:"host" json:"host,omitempty"` - Timestamp null.Time `db:"timestamp" json:"timestamp,omitempty"` - MimeType null.String `db:"mimetype" json:"mimetype,omitempty"` - Data null.Value[[]byte] `db:"data" json:"data,omitempty"` // For non text/gemini files. - GemText null.String `db:"gemtext" json:"gemtext,omitempty"` // For text/gemini files. - Header null.String `db:"header" json:"header,omitempty"` // Response header. - Links null.Value[LinkList] `db:"links" json:"links,omitempty"` - Lang null.String `db:"lang" json:"lang,omitempty"` - ResponseCode null.Int `db:"response_code" json:"code,omitempty"` // Gemini response status code. - Error null.String `db:"error" json:"error,omitempty"` // On network errors only -} - -func SnapshotFromURL(u string) *Snapshot { - url, err := ParseURL(u, "") - if err != nil { - return nil - } - newSnapshot := Snapshot{ - URL: *url, - Host: url.Hostname, - Timestamp: null.TimeFrom(time.Now()), - } - return &newSnapshot -} diff --git a/common/snapshot/snapshot.go b/common/snapshot/snapshot.go new file mode 100644 index 0000000..c55248d --- /dev/null +++ b/common/snapshot/snapshot.go @@ -0,0 +1,38 @@ +package snapshot + +import ( + "time" + + "gemini-grc/common/linkList" + commonUrl "gemini-grc/common/url" + "gemini-grc/errors" + "github.com/guregu/null/v5" +) + +type Snapshot struct { + ID int `db:"ID" json:"ID,omitempty"` + URL commonUrl.URL `db:"url" json:"url,omitempty"` + Host string `db:"host" json:"host,omitempty"` + Timestamp null.Time `db:"timestamp" json:"timestamp,omitempty"` + MimeType null.String `db:"mimetype" json:"mimetype,omitempty"` + Data null.Value[[]byte] `db:"data" json:"data,omitempty"` // For non text/gemini files. + GemText null.String `db:"gemtext" json:"gemtext,omitempty"` // For text/gemini files. + Header null.String `db:"header" json:"header,omitempty"` // Response header. + Links null.Value[linkList.LinkList] `db:"links" json:"links,omitempty"` + Lang null.String `db:"lang" json:"lang,omitempty"` + ResponseCode null.Int `db:"response_code" json:"code,omitempty"` // Gemini response Status code. + Error null.String `db:"error" json:"error,omitempty"` // On network errors only +} + +func SnapshotFromURL(u string, normalize bool) (*Snapshot, error) { + url, err := commonUrl.ParseURL(u, "", normalize) + if err != nil { + return nil, errors.NewError(err) + } + newSnapshot := Snapshot{ + URL: *url, + Host: url.Hostname, + Timestamp: null.TimeFrom(time.Now()), + } + return &newSnapshot, nil +} diff --git a/common/gemini_url.go b/common/url/url.go similarity index 61% rename from common/gemini_url.go rename to common/url/url.go index 74ec472..57bc4b2 100644 --- a/common/gemini_url.go +++ b/common/url/url.go @@ -1,12 +1,15 @@ -package common +package url import ( "database/sql/driver" "fmt" "net/url" "path" + "regexp" "strconv" "strings" + + "gemini-grc/errors" ) type URL struct { @@ -26,11 +29,10 @@ func (u *URL) Scan(value interface{}) error { } b, ok := value.(string) if !ok { - return fmt.Errorf("%w: expected string, got %T", ErrDatabaseScan, value) + return errors.NewFatalError(fmt.Errorf("database scan error: expected string, got %T", value)) } parsedURL, err := ParseURL(b, "", false) if err != nil { - err = fmt.Errorf("%w: failed to scan GeminiUrl %s: %v", ErrDatabaseScan, b, err) return err } *u = *parsedURL @@ -42,8 +44,14 @@ func (u URL) String() string { } func (u URL) StringNoDefaultPort() string { - if u.Port == 1965 { - return fmt.Sprintf("%s://%s%s", u.Protocol, u.Hostname, u.Path) + if IsGeminiUrl(u.String()) { + if u.Port == 1965 { + return fmt.Sprintf("%s://%s%s", u.Protocol, u.Hostname, u.Path) + } + } else { + if u.Port == 70 { + return fmt.Sprintf("%s://%s%s", u.Protocol, u.Hostname, u.Path) + } } return u.Full } @@ -55,30 +63,43 @@ func (u URL) Value() (driver.Value, error) { return u.Full, nil } +func IsGeminiUrl(url string) bool { + return strings.HasPrefix(url, "gemini://") +} + +func IsGopherURL(s string) bool { + return strings.HasPrefix(s, "gopher://") +} + func ParseURL(input string, descr string, normalize bool) (*URL, error) { var u *url.URL var err error if normalize { u, err = NormalizeURL(input) + if err != nil { + return nil, err + } } else { u, err = url.Parse(input) - } - if err != nil { - return nil, fmt.Errorf("%w: Input %s URL Parse Error: %w", ErrURLParse, input, err) - } - if u.Scheme != "gemini" { - return nil, fmt.Errorf("%w: URL scheme '%s' is not supported", ErrURLNotGemini, u.Scheme) + if err != nil { + return nil, errors.NewError(fmt.Errorf("error parsing URL: %w: %s", err, input)) + } } protocol := u.Scheme hostname := u.Hostname() strPort := u.Port() + // urlPath := u.EscapedPath() urlPath := u.Path if strPort == "" { - strPort = "1965" + if u.Scheme == "gemini" { + strPort = "1965" // default Gemini port + } else { + strPort = "70" // default Gopher port + } } port, err := strconv.Atoi(strPort) if err != nil { - return nil, fmt.Errorf("%w: Input %s GeminiError %w", ErrURLParse, input, err) + return nil, errors.NewError(fmt.Errorf("error parsing URL: %w: %s", err, input)) } full := fmt.Sprintf("%s://%s:%d%s", protocol, hostname, port, urlPath) // full field should also contain query params and url fragments @@ -113,7 +134,7 @@ func DeriveAbsoluteURL(currentURL URL, input string) (*URL, error) { return ParseURL(strURL, "", true) } -// NormalizeURL takes a URL string and returns a normalized version. +// NormalizeURL takes a URL string and returns a normalized version // Normalized meaning: // - Path normalization (removing redundant slashes, . and .. segments) // - Proper escaping of special characters @@ -124,7 +145,13 @@ func NormalizeURL(rawURL string) (*url.URL, error) { // Parse the URL u, err := url.Parse(rawURL) if err != nil { - return nil, fmt.Errorf("%w: %w", ErrURLParse, err) + return nil, errors.NewError(fmt.Errorf("error normalizing URL: %w: %s", err, rawURL)) + } + if u.Scheme == "" { + return nil, errors.NewError(fmt.Errorf("error normalizing URL: No scheme: %s", rawURL)) + } + if u.Host == "" { + return nil, errors.NewError(fmt.Errorf("error normalizing URL: No host: %s", rawURL)) } // Convert scheme to lowercase @@ -135,7 +162,7 @@ func NormalizeURL(rawURL string) (*url.URL, error) { u.Host = strings.ToLower(u.Host) } - // Remove default ports + // remove default ports if u.Port() != "" { switch { case u.Scheme == "http" && u.Port() == "80": @@ -144,6 +171,8 @@ func NormalizeURL(rawURL string) (*url.URL, error) { u.Host = u.Hostname() case u.Scheme == "gemini" && u.Port() == "1965": u.Host = u.Hostname() + case u.Scheme == "gopher" && u.Port() == "70": + u.Host = u.Hostname() } } @@ -152,7 +181,7 @@ func NormalizeURL(rawURL string) (*url.URL, error) { // Check if there was a trailing slash before cleaning hadTrailingSlash := strings.HasSuffix(u.Path, "/") - u.Path = path.Clean(u.Path) + u.Path = path.Clean(u.EscapedPath()) // If path was "/", path.Clean() will return "." if u.Path == "." { u.Path = "/" @@ -162,20 +191,25 @@ func NormalizeURL(rawURL string) (*url.URL, error) { } } - // Properly escape the path - // First split on '/' to avoid escaping them + // Properly escape the path, but only for unescaped parts parts := strings.Split(u.Path, "/") for i, part := range parts { - parts[i] = url.PathEscape(part) + // Try to unescape to check if it's already escaped + unescaped, err := url.PathUnescape(part) + if err != nil || unescaped == part { + // Part is not escaped, so escape it + parts[i] = url.PathEscape(part) + } + // If already escaped, leave as is } u.Path = strings.Join(parts, "/") - // Remove trailing fragment if empty + // remove trailing fragment if empty if u.Fragment == "" { u.Fragment = "" } - // Remove trailing query if empty + // remove trailing query if empty if u.RawQuery == "" { u.RawQuery = "" } @@ -188,7 +222,7 @@ func EscapeURL(input string) string { if strings.Contains(input, "%") && !strings.Contains(input, "% ") { return input } - // Split URL into parts (protocol, host, path) + // Split URL into parts (protocol, host, p) parts := strings.SplitN(input, "://", 2) if len(parts) != 2 { return input @@ -202,18 +236,50 @@ func EscapeURL(input string) string { return input } - // Split host and path + // Split host and p parts = strings.SplitN(remainder, "/", 2) host := parts[0] if len(parts) == 1 { return protocol + "://" + host } - path := parts[1] - // Escape the path portion - escapedPath := url.PathEscape(path) + escapedPath := url.PathEscape(parts[1]) // Reconstruct the URL return protocol + "://" + host + "/" + escapedPath } + +// TrimTrailingPathSlash trims trailing slash and handles empty path +func TrimTrailingPathSlash(path string) string { + // Handle empty path (e.g., "http://example.com" -> treat as root) + if path == "" { + return "/" + } + + // Trim trailing slash while preserving root slash + path = strings.TrimSuffix(path, "/") + if path == "" { // This happens if path was just "/" + return "/" + } + return path +} + +// ExtractRedirectTargetFromHeader returns the redirection +// URL by parsing the header (or error message) +func ExtractRedirectTargetFromHeader(currentURL URL, input string) (*URL, error) { + // \d+ - matches one or more digits + // \s+ - matches one or more whitespace + // ([^\r]+) - captures everything until it hits a \r (or end of string) + pattern := `\d+\s+([^\r]+)` + re := regexp.MustCompile(pattern) + matches := re.FindStringSubmatch(input) + if len(matches) < 2 { + return nil, errors.NewError(fmt.Errorf("error extracting redirect target from string %s", input)) + } + newURL, err := DeriveAbsoluteURL(currentURL, matches[1]) + if err != nil { + return nil, err + } + return newURL, nil +} diff --git a/common/url/url_test.go b/common/url/url_test.go new file mode 100644 index 0000000..5f4373d --- /dev/null +++ b/common/url/url_test.go @@ -0,0 +1,420 @@ +package url + +import ( + "reflect" + "testing" +) + +func TestURLOperations(t *testing.T) { + t.Parallel() + + t.Run("ParseURL", func(t *testing.T) { + t.Parallel() + tests := []struct { + name string + input string + base string + absolute bool + want string + wantErr bool + }{ + { + name: "parse CGI URL", + input: "gemini://caolan.uk/cgi-bin/weather.py/wxfcs/3162", + base: "", + absolute: true, + want: "gemini://caolan.uk:1965/cgi-bin/weather.py/wxfcs/3162", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + parsed, err := ParseURL(tt.input, tt.base, tt.absolute) + if (err != nil) != tt.wantErr { + t.Errorf("ParseURL() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + value, _ := parsed.Value() + if value != tt.want { + t.Errorf("ParseURL() = %v, want %v", value, tt.want) + } + } + }) + } + }) + + t.Run("DeriveAbsoluteURL", func(t *testing.T) { + t.Parallel() + + baseURL := URL{ + Protocol: "gemini", + Hostname: "smol.gr", + Port: 1965, + Path: "/a/b", + Descr: "Nothing", + Full: "gemini://smol.gr:1965/a/b", + } + + tests := []struct { + name string + current URL + input string + expected *URL + }{ + { + name: "absolute URL input", + current: baseURL, + input: "gemini://a.b/c", + expected: &URL{ + Protocol: "gemini", + Hostname: "a.b", + Port: 1965, + Path: "/c", + Full: "gemini://a.b:1965/c", + }, + }, + { + name: "absolute path input", + current: baseURL, + input: "/c", + expected: &URL{ + Protocol: "gemini", + Hostname: "smol.gr", + Port: 1965, + Path: "/c", + Full: "gemini://smol.gr:1965/c", + }, + }, + { + name: "relative path input", + current: baseURL, + input: "c/d", + expected: &URL{ + Protocol: "gemini", + Hostname: "smol.gr", + Port: 1965, + Path: "/a/b/c/d", + Full: "gemini://smol.gr:1965/a/b/c/d", + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + output, err := DeriveAbsoluteURL(tt.current, tt.input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !reflect.DeepEqual(output, tt.expected) { + t.Errorf("got %#v, want %#v", output, tt.expected) + } + }) + } + }) + + t.Run("NormalizeURL", func(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected string + }{ + { + name: "with trailing slash", + input: "gemini://uscoffings.net/retro-computing/magazines/", + expected: "gemini://uscoffings.net/retro-computing/magazines/", + }, + { + name: "without trailing slash", + input: "gemini://uscoffings.net/retro-computing/magazines", + expected: "gemini://uscoffings.net/retro-computing/magazines", + }, + { + name: "multiple slashes", + input: "gemini://uscoffings.net/retro-computing/////////a///magazines", + expected: "gemini://uscoffings.net/retro-computing/a/magazines", + }, + { + name: "root with trailing slash", + input: "gemini://uscoffings.net/", + expected: "gemini://uscoffings.net/", + }, + { + name: "root without trailing slash", + input: "gemini://uscoffings.net", + expected: "gemini://uscoffings.net", + }, + { + name: "path with trailing slash", + input: "gemini://uscoffings.net/a/", + expected: "gemini://uscoffings.net/a/", + }, + { + name: "path without trailing slash", + input: "gemini://uscoffings.net/a", + expected: "gemini://uscoffings.net/a", + }, + { + name: "with dot segments", + input: "gemini://uscoffings.net/retro-computing/./././////a///magazines", + expected: "gemini://uscoffings.net/retro-computing/a/magazines", + }, + { + name: "with default port", + input: "gemini://uscoffings.net:1965/a", + expected: "gemini://uscoffings.net/a", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + normalized, err := NormalizeURL(tt.input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + output := normalized.String() + if output != tt.expected { + t.Errorf("got %#v, want %#v", output, tt.expected) + } + }) + } + }) +} + +func TestNormalizeURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected string + }{ + { + name: "URL with non-default port", + input: "gemini://chat.gemini.lehmann.cx:11965/", + expected: "gemini://chat.gemini.lehmann.cx:11965/", + }, + { + name: "URL with query parameters", + input: "gemini://chat.gemini.lehmann.cx:11965/index?a=1&b=c", + expected: "gemini://chat.gemini.lehmann.cx:11965/index?a=1&b=c", + }, + { + name: "URL with fragment", + input: "gemini://chat.gemini.lehmann.cx:11965/index#1", + expected: "gemini://chat.gemini.lehmann.cx:11965/index#1", + }, + { + name: "URL with CGI script and query", + input: "gemini://gemi.dev/cgi-bin/xkcd.cgi?1494", + expected: "gemini://gemi.dev/cgi-bin/xkcd.cgi?1494", + }, + } + + for _, tt := range tests { + tt := tt // capture range variable for parallel testing + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + normalized, err := NormalizeURL(tt.input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + output := normalized.String() + if output != tt.expected { + t.Errorf("got %#v, want %#v", output, tt.expected) + } + }) + } +} + +func TestNormalizePath(t *testing.T) { + t.Parallel() + tests := []struct { + name string + input string // URL string to parse + expected string // Expected normalized path + }{ + // Basic cases + { + name: "empty_path", + input: "http://example.com", + expected: "", + }, + { + name: "root_path", + input: "http://example.com/", + expected: "/", + }, + { + name: "single_trailing_slash", + input: "http://example.com/test/", + expected: "/test/", + }, + { + name: "no_trailing_slash", + input: "http://example.com/test", + expected: "/test", + }, + + // Edge cases with slashes + { + name: "multiple_trailing_slashes", + input: "http://example.com/test//", + expected: "/test/", + }, + { + name: "multiple_consecutive_slashes", + input: "http://example.com//test//", + expected: "/test/", + }, + { + name: "only_slashes", + input: "http://example.com////", + expected: "/", + }, + + // Encoded characters + { + name: "encoded_spaces", + input: "http://example.com/foo%20bar/", + expected: "/foo%20bar/", + }, + { + name: "encoded_special_chars", + input: "http://example.com/foo%2Fbar/", + expected: "/foo%2Fbar/", + }, + + // Query parameters and fragments + { + name: "with_query_parameters", + input: "http://example.com/path?query=param", + expected: "/path", + }, + { + name: "with_fragment", + input: "http://example.com/path#fragment", + expected: "/path", + }, + { + name: "with_both_query_and_fragment", + input: "http://example.com/path?query=param#fragment", + expected: "/path", + }, + + // Unicode paths + { + name: "unicode_characters", + input: "http://example.com/über/path/", + expected: "/%C3%BCber/path/", + }, + { + name: "unicode_encoded", + input: "http://example.com/%C3%BCber/path/", + expected: "/%C3%BCber/path/", + }, + + // Weird but valid cases + { + name: "dot_in_path", + input: "http://example.com/./path/", + expected: "/path/", + }, + { + name: "double_dot_in_path", + input: "http://example.com/../path/", + expected: "/path/", + }, + { + name: "mixed_case", + input: "http://example.com/PaTh/", + expected: "/PaTh/", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + u, err := ParseURL(tt.input, "", true) + if err != nil { + t.Fatalf("Failed to parse URL %q: %v", tt.input, err) + } + + result := u.Path + if result != tt.expected { + t.Errorf("Input: %s\nExpected: %q\nGot: %q", + u.Path, tt.expected, result) + } + }) + } +} + +func TestExtractRedirectTargetFullURL(t *testing.T) { + t.Parallel() + currentURL, _ := ParseURL("gemini://smol.gr", "", true) + input := "redirect: 31 gemini://target.gr" + result, err := ExtractRedirectTargetFromHeader(*currentURL, input) + expected := "gemini://target.gr:1965" + if err != nil || (result.String() != expected) { + t.Errorf("fail: Expected %s got %s", expected, result) + } +} + +func TestExtractRedirectTargetFullURLSlash(t *testing.T) { + t.Parallel() + currentURL, _ := ParseURL("gemini://smol.gr", "", true) + input := "redirect: 31 gemini://target.gr/" + result, err := ExtractRedirectTargetFromHeader(*currentURL, input) + expected := "gemini://target.gr:1965/" + if err != nil || (result.String() != expected) { + t.Errorf("fail: Expected %s got %s", expected, result) + } +} + +func TestExtractRedirectTargetRelativeURL(t *testing.T) { + t.Parallel() + currentURL, _ := ParseURL("gemini://smol.gr", "", true) + input := "redirect: 31 /a/b" + result, err := ExtractRedirectTargetFromHeader(*currentURL, input) + if err != nil || (result.String() != "gemini://smol.gr:1965/a/b") { + t.Errorf("fail: %s", result) + } +} + +func TestExtractRedirectTargetRelativeURL2(t *testing.T) { + t.Parallel() + currentURL, _ := ParseURL("gemini://nox.im:1965", "", true) + input := "redirect: 31 ./" + result, err := ExtractRedirectTargetFromHeader(*currentURL, input) + if err != nil || (result.String() != "gemini://nox.im:1965/") { + t.Errorf("fail: %s", result) + } +} + +func TestExtractRedirectTargetRelativeURL3(t *testing.T) { + t.Parallel() + currentURL, _ := ParseURL("gemini://status.zvava.org:1965", "", true) + input := "redirect: 31 index.gmi" + result, err := ExtractRedirectTargetFromHeader(*currentURL, input) + if err != nil || (result.String() != "gemini://status.zvava.org:1965/index.gmi") { + t.Errorf("fail: %s", result) + } +} + +func TestExtractRedirectTargetWrong(t *testing.T) { + t.Parallel() + currentURL, _ := ParseURL("gemini://smol.gr", "", true) + input := "redirect: 31" + result, err := ExtractRedirectTargetFromHeader(*currentURL, input) + if result != nil || err == nil { + t.Errorf("fail: result should be nil, err is %s", err) + } +} diff --git a/common/worker.go b/common/worker.go new file mode 100644 index 0000000..f974266 --- /dev/null +++ b/common/worker.go @@ -0,0 +1,320 @@ +package common + +import ( + "fmt" + "time" + + "gemini-grc/common/blackList" + errors2 "gemini-grc/common/errors" + "gemini-grc/common/snapshot" + url2 "gemini-grc/common/url" + _db "gemini-grc/db" + "gemini-grc/errors" + "gemini-grc/gemini" + "gemini-grc/gopher" + "gemini-grc/hostPool" + "gemini-grc/logging" + "github.com/guregu/null/v5" + "github.com/jmoiron/sqlx" +) + +func CrawlOneURL(db *sqlx.DB, url *string) error { + parsedURL, err := url2.ParseURL(*url, "", true) + if err != nil { + return err + } + + if !url2.IsGeminiUrl(parsedURL.String()) && !url2.IsGopherURL(parsedURL.String()) { + return errors.NewError(fmt.Errorf("error parsing URL: not a Gemini or Gopher URL: %s", parsedURL.String())) + } + + tx, err := db.Beginx() + if err != nil { + return errors.NewFatalError(err) + } + + err = _db.InsertURL(tx, parsedURL.Full) + if err != nil { + return err + } + + err = workOnUrl(0, tx, parsedURL.Full) + if err != nil { + return err + } + + err = tx.Commit() + if err != nil { + //if _db.IsDeadlockError(err) { + // logging.LogError("Deadlock detected. Rolling back") + // time.Sleep(time.Duration(10) * time.Second) + // err := tx.Rollback() + // return errors.NewFatalError(err) + //} + return errors.NewFatalError(err) + } + logging.LogInfo("Done") + return nil +} + +func SpawnWorkers(numOfWorkers int, db *sqlx.DB) { + logging.LogInfo("Spawning %d workers", numOfWorkers) + go PrintWorkerStatus(numOfWorkers, StatusChan) + + for i := range numOfWorkers { + go func(i int) { + UpdateWorkerStatus(i, "Waiting to start") + // Jitter to avoid starting everything at the same time + time.Sleep(time.Duration(i+2) * time.Second) + for { + // TODO: Use cancellable context with tx, logger & worker ID. + // ctx := context.WithCancel() + // ctx = context.WithValue(ctx, common.CtxKeyLogger, &RequestLogger{r: r}) + RunWorkerWithTx(i, db) + } + }(i) + } +} + +func RunWorkerWithTx(workerID int, db *sqlx.DB) { + defer func() { + UpdateWorkerStatus(workerID, "Done") + }() + + tx, err := db.Beginx() + if err != nil { + ErrorsChan <- err + return + } + + err = runWorker(workerID, tx) + if err != nil { + // TODO: Rollback in this case? + ErrorsChan <- err + return + } + + logging.LogDebug("[%3d] Committing transaction", workerID) + err = tx.Commit() + // On deadlock errors, rollback and return, otherwise panic. + if err != nil { + logging.LogError("[%3d] Failed to commit transaction: %w", workerID, err) + if _db.IsDeadlockError(err) { + logging.LogError("[%3d] Deadlock detected. Rolling back", workerID) + time.Sleep(time.Duration(10) * time.Second) + err := tx.Rollback() + if err != nil { + panic(fmt.Sprintf("[%3d] Failed to roll back transaction: %v", workerID, err)) + } + return + } + panic(fmt.Sprintf("[%3d] Failed to commit transaction: %v", workerID, err)) + } + logging.LogDebug("[%3d] Worker done!", workerID) +} + +func runWorker(workerID int, tx *sqlx.Tx) error { + var urls []string + var err error + + UpdateWorkerStatus(workerID, "Getting URLs from DB") + urls, err = _db.GetRandomUrls(tx) + // urls, err = _db.GetRandomUrlsWithBasePath(tx) + if err != nil { + return err + } else if len(urls) == 0 { + logging.LogInfo("[%3d] No URLs to visit, sleeping...", workerID) + UpdateWorkerStatus(workerID, "No URLs to visit, sleeping...") + time.Sleep(1 * time.Minute) + return nil + } + + // Start visiting URLs. + total := len(urls) + for i, u := range urls { + logging.LogInfo("[%3d] Starting %d/%d %s", workerID, i+1, total, u) + UpdateWorkerStatus(workerID, fmt.Sprintf("Starting %d/%d %s", i+1, total, u)) + err := workOnUrl(workerID, tx, u) + if err != nil { + return err + } + logging.LogDebug("[%3d] Done %d/%d.", workerID, i+1, total) + UpdateWorkerStatus(workerID, fmt.Sprintf("Done %d/%d %s", i+1, total, u)) + } + return nil +} + +// workOnUrl visits a URL and stores the result. +// unexpected errors are returned. +// expected errors are stored within the snapshot. +func workOnUrl(workerID int, tx *sqlx.Tx, url string) (err error) { + s, err := snapshot.SnapshotFromURL(url, false) + if err != nil { + return err + } + + isGemini := url2.IsGeminiUrl(s.URL.String()) + isGopher := url2.IsGopherURL(s.URL.String()) + if !isGemini && !isGopher { + return errors.NewError(fmt.Errorf("not a Gopher or Gemini URL: %s", s.URL.String())) + } + + if blackList.IsBlacklisted(s.URL.String()) { + logging.LogInfo("[%3d] URL matches blacklist, ignoring", workerID) + s.Error = null.StringFrom(errors2.ErrBlacklistMatch.Error()) + return saveSnapshotAndRemoveURL(tx, s) + } + + if isGemini { + // If URL matches a robots.txt disallow line, + // add it as an error and remove url + robotMatch, err := gemini.RobotMatch(s.URL.String()) + if err != nil { + // robotMatch returns only network errors! + // we stop because we don't want to hit + // the server with another request on this case. + return err + } + if robotMatch { + logging.LogInfo("[%3d] URL matches robots.txt, ignoring", workerID) + s.Error = null.StringFrom(errors2.ErrRobotsMatch.Error()) + return saveSnapshotAndRemoveURL(tx, s) + } + } + + logging.LogDebug("[%3d] Adding to pool %s", workerID, s.URL.String()) + UpdateWorkerStatus(workerID, fmt.Sprintf("Adding to pool %s", s.URL.String())) + hostPool.AddHostToHostPool(s.Host) + defer func(s string) { + hostPool.RemoveHostFromPool(s) + }(s.Host) + + logging.LogDebug("[%3d] Visiting %s", workerID, s.URL.String()) + UpdateWorkerStatus(workerID, fmt.Sprintf("Visiting %s", s.URL.String())) + + if isGopher { + s, err = gopher.Visit(s.URL.String()) + } else { + s, err = gemini.Visit(s.URL.String()) + } + + if err != nil { + return err + } + + // Handle Gemini redirection. + if isGemini && + s.ResponseCode.ValueOrZero() >= 30 && + s.ResponseCode.ValueOrZero() < 40 { + err = handleRedirection(workerID, tx, s) + if err != nil { + return fmt.Errorf("error while handling redirection: %s", err) + } + } + + // Store links + if len(s.Links.ValueOrZero()) > 0 { + logging.LogDebug("[%3d] Found %d links", workerID, len(s.Links.ValueOrZero())) + err = storeLinks(tx, s) + if err != nil { + return err + } + } + + logging.LogInfo("[%3d] %2d %s", workerID, s.ResponseCode.ValueOrZero(), s.URL.String()) + return saveSnapshotAndRemoveURL(tx, s) +} + +func storeLinks(tx *sqlx.Tx, s *snapshot.Snapshot) error { + if s.Links.Valid { //nolint:nestif + for _, link := range s.Links.ValueOrZero() { + if shouldPersistURL(&link) { + visited, err := haveWeVisitedURL(tx, link.Full) + if err != nil { + return err + } + if !visited { + err := _db.InsertURL(tx, link.Full) + if err != nil { + return err + } + } else { + logging.LogDebug("Link already persisted: %s", link.Full) + } + } + } + } + return nil +} + +func saveSnapshotAndRemoveURL(tx *sqlx.Tx, s *snapshot.Snapshot) error { + err := _db.OverwriteSnapshot(tx, s) + if err != nil { + return err + } + err = _db.DeleteURL(tx, s.URL.String()) + if err != nil { + return err + } + return nil +} + +// shouldPersistURL returns true if we +// should save the URL in the _db. +// Only gemini:// urls are saved. +func shouldPersistURL(u *url2.URL) bool { + return url2.IsGeminiUrl(u.String()) || url2.IsGopherURL(u.String()) +} + +func haveWeVisitedURL(tx *sqlx.Tx, u string) (bool, error) { + var result []bool + err := tx.Select(&result, `SELECT TRUE FROM urls WHERE url=$1`, u) + if err != nil { + return false, errors.NewFatalError(fmt.Errorf("database error: %w", err)) + } + if len(result) > 0 { + return result[0], nil + } + err = tx.Select(&result, `SELECT TRUE FROM snapshots WHERE snapshots.url=$1`, u) + if err != nil { + return false, errors.NewFatalError(fmt.Errorf("database error: %w", err)) + } + if len(result) > 0 { + return result[0], nil + } + return false, nil +} + +// handleRedirection saves redirection URL. +func handleRedirection(workerID int, tx *sqlx.Tx, s *snapshot.Snapshot) error { + newURL, err := url2.ExtractRedirectTargetFromHeader(s.URL, s.Error.ValueOrZero()) + if err != nil { + return err + } + logging.LogDebug("[%3d] Page redirects to %s", workerID, newURL) + + haveWeVisited, _ := haveWeVisitedURL(tx, newURL.String()) + if shouldPersistURL(newURL) && !haveWeVisited { + err = _db.InsertURL(tx, newURL.Full) + if err != nil { + return err + } + logging.LogDebug("[%3d] Saved redirection URL %s", workerID, newURL.String()) + } + return nil +} + +func GetSnapshotFromURL(tx *sqlx.Tx, url string) ([]snapshot.Snapshot, error) { + query := ` + SELECT * + FROM snapshots + WHERE url=$1 + LIMIT 1 + ` + var snapshots []snapshot.Snapshot + err := tx.Select(&snapshots, query, url) + if err != nil { + return nil, err + } + return snapshots, nil +} diff --git a/common/workerStatus.go b/common/workerStatus.go index d77b618..2c3f6ef 100644 --- a/common/workerStatus.go +++ b/common/workerStatus.go @@ -1,19 +1,35 @@ -package gemini +package common import ( "fmt" "strings" + + "gemini-grc/config" ) type WorkerStatus struct { - id int - status string + ID int + Status string } -var statusChan chan WorkerStatus +func UpdateWorkerStatus(workerID int, status string) { + if !config.GetConfig().PrintWorkerStatus { + return + } + if config.CONFIG.NumOfWorkers > 1 { + StatusChan <- WorkerStatus{ + ID: workerID, + Status: status, + } + } +} func PrintWorkerStatus(totalWorkers int, statusChan chan WorkerStatus) { - // Create a slice to store current status of each worker + if !config.GetConfig().PrintWorkerStatus { + return + } + + // Create a slice to store current Status of each worker statuses := make([]string, totalWorkers) // Initialize empty statuses @@ -32,14 +48,14 @@ func PrintWorkerStatus(totalWorkers int, statusChan chan WorkerStatus) { } fmt.Print(output.String()) - // Continuously receive status updates + // Continuously receive Status updates for update := range statusChan { - if update.id >= totalWorkers { + if update.ID >= totalWorkers { continue } - // Update the status - statuses[update.id] = update.status + // Update the Status + statuses[update.ID] = update.Status // Build the complete output string output.Reset() @@ -48,7 +64,7 @@ func PrintWorkerStatus(totalWorkers int, statusChan chan WorkerStatus) { output.WriteString(fmt.Sprintf("[%2d] %.100s\n", i, status)) } - // Print the entire status + // Print the entire Status fmt.Print(output.String()) } } diff --git a/db/db.go b/db/db.go index 7885b4d..97337c2 100644 --- a/db/db.go +++ b/db/db.go @@ -2,20 +2,22 @@ package db import ( "encoding/json" - "errors" "fmt" - "gemini-grc/common" "os" "strconv" + "time" + "gemini-grc/common/snapshot" + commonUrl "gemini-grc/common/url" "gemini-grc/config" + "gemini-grc/errors" "gemini-grc/logging" _ "github.com/jackc/pgx/v5/stdlib" // PGX driver for PostgreSQL "github.com/jmoiron/sqlx" "github.com/lib/pq" ) -func ConnectToDB() *sqlx.DB { +func ConnectToDB() (*sqlx.DB, error) { connStr := fmt.Sprintf("postgres://%s:%s@%s:%s/%s", //nolint:nosprintfhostport os.Getenv("PG_USER"), os.Getenv("PG_PASSWORD"), @@ -27,25 +29,26 @@ func ConnectToDB() *sqlx.DB { // Create a connection pool db, err := sqlx.Open("pgx", connStr) if err != nil { - panic(fmt.Sprintf("Unable to connect to database with URL %s: %v\n", connStr, err)) + return nil, errors.NewFatalError(fmt.Errorf("unable to connect to database with URL %s: %w", connStr, err)) } // TODO move PG_MAX_OPEN_CONNECTIONS to config env variables maxConnections, err := strconv.Atoi(os.Getenv("PG_MAX_OPEN_CONNECTIONS")) if err != nil { - panic(fmt.Sprintf("Unable to set max DB connections: %s\n", err)) + return nil, errors.NewFatalError(fmt.Errorf("unable to set DB max connections: %w", err)) } db.SetMaxOpenConns(maxConnections) err = db.Ping() if err != nil { - panic(fmt.Sprintf("Unable to ping database: %v\n", err)) + return nil, errors.NewFatalError(fmt.Errorf("unable to ping database: %w", err)) } logging.LogDebug("Connected to database") - return db + return db, nil } -// IsDeadlockError checks if the error is a PostgreSQL deadlock error +// IsDeadlockError checks if the error is a PostgreSQL deadlock error. func IsDeadlockError(err error) bool { + err = errors.Unwrap(err) var pqErr *pq.Error if errors.As(err, &pqErr) { return pqErr.Code == "40P01" // PostgreSQL deadlock error code @@ -53,134 +56,85 @@ func IsDeadlockError(err error) bool { return false } -func GetURLsToVisit(tx *sqlx.Tx) ([]string, error) { +func GetRandomUrls(tx *sqlx.Tx) ([]string, error) { var urls []string - err := tx.Select(&urls, SQL_SELECT_RANDOM_URLS_UNIQUE_HOSTS, config.CONFIG.WorkerBatchSize) + err := tx.Select(&urls, SQL_SELECT_RANDOM_URLS, config.CONFIG.WorkerBatchSize) if err != nil { - return nil, fmt.Errorf("%w: %w", common.ErrDatabase, err) + return nil, errors.NewFatalError(err) + } + return urls, nil +} + +func GetRandomUrlsWithBasePath(tx *sqlx.Tx) ([]string, error) { + SqlQuery := `SELECT url FROM snapshots WHERE url ~ '^[^:]+://[^/]+/?$' ORDER BY RANDOM() LIMIT $1` + var urls []string + err := tx.Select(&urls, SqlQuery, config.CONFIG.WorkerBatchSize) + if err != nil { + return nil, errors.NewFatalError(err) } return urls, nil } func InsertURL(tx *sqlx.Tx, url string) error { + logging.LogDebug("Inserting URL %s", url) query := SQL_INSERT_URL - _, err := tx.NamedExec(query, url) + normalizedURL, err := commonUrl.ParseURL(url, "", true) if err != nil { - return fmt.Errorf("%w inserting URL: %w", common.ErrDatabase, err) + return err + } + a := struct { + Url string + Host string + Timestamp time.Time + }{ + Url: normalizedURL.Full, + Host: normalizedURL.Hostname, + Timestamp: time.Now(), + } + _, err = tx.NamedExec(query, a) + if err != nil { + return errors.NewFatalError(fmt.Errorf("cannot insert URL: database error %w URL %s", err, url)) } return nil } -func SaveSnapshotIfNew(tx *sqlx.Tx, s *common.Snapshot) error { +func DeleteURL(tx *sqlx.Tx, url string) error { + logging.LogDebug("Deleting URL %s", url) + query := SQL_DELETE_URL + _, err := tx.Exec(query, url) + if err != nil { + return errors.NewFatalError(fmt.Errorf("cannot delete URL: database error %w URL %s", err, url)) + } + return nil +} + +func OverwriteSnapshot(tx *sqlx.Tx, s *snapshot.Snapshot) (err error) { if config.CONFIG.DryRun { marshalled, err := json.MarshalIndent(s, "", " ") if err != nil { - panic(fmt.Sprintf("JSON serialization error for %v", s)) + return errors.NewFatalError(fmt.Errorf("JSON serialization error for %v", s)) } - logging.LogDebug("Would insert (if new) snapshot %s", marshalled) + logging.LogDebug("Would upsert snapshot %s", marshalled) return nil } - query := SQL_INSERT_SNAPSHOT_IF_NEW - _, err := tx.NamedExec(query, s) - if err != nil { - return fmt.Errorf("[%s] GeminiError inserting snapshot: %w", s.URL, err) - } - return nil -} - -func OverwriteSnapshot(workedID int, tx *sqlx.Tx, s *common.Snapshot) (err error) { - // if config.CONFIG.DryRun { - //marshalled, err := json.MarshalIndent(s, "", " ") - //if err != nil { - // panic(fmt.Sprintf("JSON serialization error for %v", s)) - //} - //logging.LogDebug("[%d] Would upsert snapshot %s", workedID, marshalled) - // return nil - // } query := SQL_UPSERT_SNAPSHOT rows, err := tx.NamedQuery(query, s) if err != nil { - return fmt.Errorf("[%d] %w while upserting snapshot: %w", workedID, common.ErrDatabase, err) + return errors.NewFatalError(fmt.Errorf("cannot overwrite snapshot: %w", err)) } defer func() { _err := rows.Close() - if _err != nil { - err = fmt.Errorf("[%d] %w error closing rows: %w", workedID, common.ErrDatabase, _err) + if err == nil && _err != nil { + err = errors.NewFatalError(fmt.Errorf("cannot overwrite snapshot: error closing rows: %w", err)) } }() if rows.Next() { var returnedID int err = rows.Scan(&returnedID) if err != nil { - return fmt.Errorf("[%d] %w error scanning returned id: %w", workedID, common.ErrDatabase, err) + return errors.NewFatalError(fmt.Errorf("cannot overwrite snapshot: error scanning rows: %w", err)) } s.ID = returnedID - // logging.LogDebug("[%d] Upserted snapshot with ID %d", workedID, returnedID) - } - return nil -} - -func UpdateSnapshot(workedID int, tx *sqlx.Tx, s *common.Snapshot) (err error) { - // if config.CONFIG.DryRun { - //marshalled, err := json.MarshalIndent(s, "", " ") - //if err != nil { - // panic(fmt.Sprintf("JSON serialization error for %v", s)) - //} - //logging.LogDebug("[%d] Would upsert snapshot %s", workedID, marshalled) - // return nil - // } - query := SQL_UPDATE_SNAPSHOT - rows, err := tx.NamedQuery(query, s) - if err != nil { - return fmt.Errorf("[%d] %w while updating snapshot: %w", workedID, common.ErrDatabase, err) - } - defer func() { - _err := rows.Close() - if _err != nil { - err = fmt.Errorf("[%d] %w error closing rows: %w", workedID, common.ErrDatabase, _err) - } - }() - if rows.Next() { - var returnedID int - err = rows.Scan(&returnedID) - if err != nil { - return fmt.Errorf("[%d] %w error scanning returned id: %w", workedID, common.ErrDatabase, err) - } - s.ID = returnedID - // logging.LogDebug("[%d] Updated snapshot with ID %d", workedID, returnedID) - } - return nil -} - -func SaveLinksToDBinBatches(tx *sqlx.Tx, snapshots []*common.Snapshot) error { - if config.CONFIG.DryRun { - return nil - } - const batchSize = 5000 - query := SQL_INSERT_SNAPSHOT_IF_NEW - for i := 0; i < len(snapshots); i += batchSize { - end := i + batchSize - if end > len(snapshots) { - end = len(snapshots) - } - batch := snapshots[i:end] - _, err := tx.NamedExec(query, batch) - if err != nil { - return fmt.Errorf("%w: While saving links in batches: %w", common.ErrDatabase, err) - } - } - return nil -} - -func SaveLinksToDB(tx *sqlx.Tx, snapshots []*common.Snapshot) error { - if config.CONFIG.DryRun { - return nil - } - query := SQL_INSERT_SNAPSHOT_IF_NEW - _, err := tx.NamedExec(query, snapshots) - if err != nil { - logging.LogError("GeminiError batch inserting snapshots: %w", err) - return fmt.Errorf("DB error: %w", err) } return nil } diff --git a/db/db_queries.go b/db/db_queries.go index bb2d238..4e3c1a6 100644 --- a/db/db_queries.go +++ b/db/db_queries.go @@ -1,53 +1,24 @@ package db const ( - SQL_SELECT_RANDOM_UNVISITED_SNAPSHOTS = ` -SELECT * -FROM snapshots -WHERE response_code IS NULL - AND error IS NULL -ORDER BY RANDOM() -FOR UPDATE SKIP LOCKED -LIMIT $1 - ` SQL_SELECT_RANDOM_URLS_UNIQUE_HOSTS = ` SELECT url FROM urls u WHERE u.id IN ( - SELECT MIN(id) - FROM urls - GROUP BY host + SELECT id FROM ( + SELECT id, ROW_NUMBER() OVER (PARTITION BY host ORDER BY id) as rn + FROM urls + ) t + WHERE rn <= 3 ) -LIMIT $1 -` - SQL_SELECT_RANDOM_UNVISITED_SNAPSHOTS_UNIQUE_HOSTS = ` -SELECT * -FROM snapshots s -WHERE response_code IS NULL - AND error IS NULL - AND s.id IN ( - SELECT MIN(id) - FROM snapshots - WHERE response_code IS NULL - AND error IS NULL - GROUP BY host - ) ORDER BY RANDOM() FOR UPDATE SKIP LOCKED LIMIT $1 -` - SQL_SELECT_UNVISITED_SNAPSHOTS_UNIQUE_HOSTS = ` -SELECT * -FROM snapshots s -WHERE response_code IS NULL - AND error IS NULL - AND s.id IN ( - SELECT MIN(id) - FROM snapshots - WHERE response_code IS NULL - AND error IS NULL - GROUP BY host - ) + ` + SQL_SELECT_RANDOM_URLS = ` +SELECT url +FROM urls u +ORDER BY RANDOM() FOR UPDATE SKIP LOCKED LIMIT $1 ` @@ -90,4 +61,7 @@ RETURNING id VALUES (:url, :host, :timestamp) ON CONFLICT (url) DO NOTHING ` + SQL_DELETE_URL = ` + DELETE FROM urls WHERE url=$1 + ` ) diff --git a/gemini/files.go b/gemini/files.go index 27a9fc0..995b5fc 100644 --- a/gemini/files.go +++ b/gemini/files.go @@ -2,13 +2,13 @@ package gemini import ( "fmt" - "gemini-grc/common" "net/url" "os" "path" "path/filepath" "strings" + "gemini-grc/common/snapshot" "gemini-grc/logging" ) @@ -64,7 +64,7 @@ func calcFilePath(rootPath, urlPath string) (string, error) { return finalPath, nil } -func SaveToFile(rootPath string, s *common.Snapshot, done chan struct{}) { +func SaveToFile(rootPath string, s *snapshot.Snapshot, done chan struct{}) { parentPath := path.Join(rootPath, s.URL.Hostname) urlPath := s.URL.Path // If path is empty, add `index.gmi` as the file to save @@ -105,7 +105,7 @@ func ReadLines(path string) []string { panic(fmt.Sprintf("Failed to read file: %s", err)) } lines := strings.Split(string(data), "\n") - // Remove last line if empty + // remove last line if empty // (happens when file ends with '\n') if lines[len(lines)-1] == "" { lines = lines[:len(lines)-1] diff --git a/gemini/gemini.go b/gemini/gemini.go deleted file mode 100644 index 023e0f0..0000000 --- a/gemini/gemini.go +++ /dev/null @@ -1,49 +0,0 @@ -package gemini - -import ( - "fmt" - "regexp" - "strconv" - - "gemini-grc/common" -) - -// ParseFirstTwoDigits takes a string and returns the first one or two digits as an int. -// If no valid digits are found, it returns an error. -func ParseFirstTwoDigits(input string) (int, error) { - // Define the regular expression pattern to match one or two leading digits - re := regexp.MustCompile(`^(\d{1,2})`) - - // Find the first match in the string - matches := re.FindStringSubmatch(input) - if len(matches) == 0 { - return 0, fmt.Errorf("%w", common.ErrGeminiResponseHeader) - } - - // Parse the captured match as an integer - snapshot, err := strconv.Atoi(matches[1]) - if err != nil { - return 0, fmt.Errorf("%w: %w", common.ErrTextParse, err) - } - - return snapshot, nil -} - -// extractRedirectTarget returns the redirection -// URL by parsing the header (or error message) -func extractRedirectTarget(currentURL common.URL, input string) (*common.URL, error) { - // \d+ - matches one or more digits - // \s+ - matches one or more whitespace - // ([^\r]+) - captures everything until it hits a \r (or end of string) - pattern := `\d+\s+([^\r]+)` - re := regexp.MustCompile(pattern) - matches := re.FindStringSubmatch(input) - if len(matches) < 2 { - return nil, fmt.Errorf("%w: %s", common.ErrGeminiRedirect, input) - } - newURL, err := common.DeriveAbsoluteURL(currentURL, matches[1]) - if err != nil { - return nil, fmt.Errorf("%w: %w: %s", common.ErrGeminiRedirect, err, input) - } - return newURL, nil -} diff --git a/gemini/geminiLinks.go b/gemini/geminiLinks.go index bfc33b3..a3762a5 100644 --- a/gemini/geminiLinks.go +++ b/gemini/geminiLinks.go @@ -5,22 +5,24 @@ import ( "net/url" "regexp" - "gemini-grc/common" + "gemini-grc/common/linkList" + url2 "gemini-grc/common/url" + "gemini-grc/errors" "gemini-grc/logging" "gemini-grc/util" ) -func GetPageLinks(currentURL common.URL, gemtext string) common.LinkList { +func GetPageLinks(currentURL url2.URL, gemtext string) linkList.LinkList { linkLines := util.GetLinesMatchingRegex(gemtext, `(?m)^=>[ \t]+.*`) if len(linkLines) == 0 { return nil } - var linkURLs common.LinkList + var linkURLs linkList.LinkList // Normalize URLs in links for _, line := range linkLines { linkUrl, err := ParseGeminiLinkLine(line, currentURL.String()) if err != nil { - logging.LogDebug("%s: %s", common.ErrGeminiLinkLineParse, err) + logging.LogDebug("error parsing gemini link line: %s", err) continue } linkURLs = append(linkURLs, *linkUrl) @@ -31,19 +33,18 @@ func GetPageLinks(currentURL common.URL, gemtext string) common.LinkList { // ParseGeminiLinkLine takes a single link line and the current URL, // return the URL converted to an absolute URL // and its description. -func ParseGeminiLinkLine(linkLine string, currentURL string) (*common.URL, error) { +func ParseGeminiLinkLine(linkLine string, currentURL string) (*url2.URL, error) { // Check: currentURL is parseable baseURL, err := url.Parse(currentURL) if err != nil { - return nil, fmt.Errorf("%w: %w", common.ErrURLParse, err) + return nil, errors.NewError(fmt.Errorf("error parsing link line: %w input '%s'", err, linkLine)) } // Extract the actual URL and the description re := regexp.MustCompile(`^=>[ \t]+(\S+)([ \t]+.*)?`) matches := re.FindStringSubmatch(linkLine) if len(matches) == 0 { - // If the line doesn't match the expected format, return it unchanged - return nil, fmt.Errorf("%w could not parse gemini link %s", common.ErrGeminiLinkLineParse, linkLine) + return nil, errors.NewError(fmt.Errorf("error parsing link line: no regexp match for line %s", linkLine)) } originalURLStr := matches[1] @@ -51,7 +52,7 @@ func ParseGeminiLinkLine(linkLine string, currentURL string) (*common.URL, error // Check: Unescape the URL if escaped _, err = url.QueryUnescape(originalURLStr) if err != nil { - return nil, fmt.Errorf("%w: %w", common.ErrURLDecode, err) + return nil, errors.NewError(fmt.Errorf("error parsing link line: %w input '%s'", err, linkLine)) } description := "" @@ -62,8 +63,7 @@ func ParseGeminiLinkLine(linkLine string, currentURL string) (*common.URL, error // Parse the URL from the link line parsedURL, err := url.Parse(originalURLStr) if err != nil { - // If URL parsing fails, return an error - return nil, fmt.Errorf("%w: %w", common.ErrURLParse, err) + return nil, errors.NewError(fmt.Errorf("error parsing link line: %w input '%s'", err, linkLine)) } // If link URL is relative, resolve full URL @@ -71,17 +71,16 @@ func ParseGeminiLinkLine(linkLine string, currentURL string) (*common.URL, error parsedURL = baseURL.ResolveReference(parsedURL) } - // Remove usual first space from URL description: + // remove usual first space from URL description: // => URL description // ^^^^^^^^^^^^ if len(description) > 0 && description[0] == ' ' { description = description[1:] } - finalURL, err := common.ParseURL(parsedURL.String(), description, true) + finalURL, err := url2.ParseURL(parsedURL.String(), description, true) if err != nil { - // If URL parsing fails, return an error - return nil, fmt.Errorf("%w: %w", common.ErrURLParse, err) + return nil, errors.NewError(fmt.Errorf("error parsing link line: %w input '%s'", err, linkLine)) } return finalURL, nil diff --git a/gemini/geminiLinks_test.go b/gemini/geminiLinks_test.go index 76bc8c9..5973cf3 100644 --- a/gemini/geminiLinks_test.go +++ b/gemini/geminiLinks_test.go @@ -1,18 +1,18 @@ package gemini import ( - "errors" "reflect" + "strings" "testing" - "gemini-grc/common" + "gemini-grc/common/url" ) type TestData struct { currentURL string link string - value *common.URL - error error + value *url.URL + error string } var data = []TestData{ @@ -20,12 +20,12 @@ var data = []TestData{ currentURL: "https://gemini.com/", link: "https://gemini.com/", value: nil, - error: common.ErrGeminiLinkLineParse, + error: "error parsing link line", }, { currentURL: "gemini://gemi.dev/cgi-bin/xkcd/", link: "=> archive/ Complete Archive", - value: &common.URL{ + value: &url.URL{ Protocol: "gemini", Hostname: "gemi.dev", Port: 1965, @@ -33,12 +33,12 @@ var data = []TestData{ Descr: "Complete Archive", Full: "gemini://gemi.dev:1965/cgi-bin/xkcd/archive/", }, - error: nil, + error: "", }, { currentURL: "gemini://gemi.dev/cgi-bin/xkcd/", link: "=> /cgi-bin/xkcd.cgi?a=5&b=6 Example", - value: &common.URL{ + value: &url.URL{ Protocol: "gemini", Hostname: "gemi.dev", Port: 1965, @@ -46,12 +46,12 @@ var data = []TestData{ Descr: "Example", Full: "gemini://gemi.dev:1965/cgi-bin/xkcd.cgi?a=5&b=6", }, - error: nil, + error: "", }, { currentURL: "gemini://gemi.dev/cgi-bin/xkcd/", link: "=> /cgi-bin/xkcd.cgi?1494 XKCD 1494: Insurance", - value: &common.URL{ + value: &url.URL{ Protocol: "gemini", Hostname: "gemi.dev", Port: 1965, @@ -59,12 +59,12 @@ var data = []TestData{ Descr: "XKCD 1494: Insurance", Full: "gemini://gemi.dev:1965/cgi-bin/xkcd.cgi?1494", }, - error: nil, + error: "", }, { currentURL: "gemini://gemi.dev/cgi-bin/xkcd/", link: "=> /cgi-bin/xkcd.cgi?1494#f XKCD 1494: Insurance", - value: &common.URL{ + value: &url.URL{ Protocol: "gemini", Hostname: "gemi.dev", Port: 1965, @@ -72,12 +72,12 @@ var data = []TestData{ Descr: "XKCD 1494: Insurance", Full: "gemini://gemi.dev:1965/cgi-bin/xkcd.cgi?1494#f", }, - error: nil, + error: "", }, { currentURL: "gemini://gemi.dev/cgi-bin/xkcd/", link: "=> /cgi-bin/xkcd.cgi?c=5#d XKCD 1494: Insurance", - value: &common.URL{ + value: &url.URL{ Protocol: "gemini", Hostname: "gemi.dev", Port: 1965, @@ -85,12 +85,12 @@ var data = []TestData{ Descr: "XKCD 1494: Insurance", Full: "gemini://gemi.dev:1965/cgi-bin/xkcd.cgi?c=5#d", }, - error: nil, + error: "", }, { currentURL: "gemini://a.b/c#d", link: "=> /d/e#f", - value: &common.URL{ + value: &url.URL{ Protocol: "gemini", Hostname: "a.b", Port: 1965, @@ -98,7 +98,7 @@ var data = []TestData{ Descr: "", Full: "gemini://a.b:1965/d/e#f", }, - error: nil, + error: "", }, } @@ -110,13 +110,10 @@ func Test(t *testing.T) { if expected.value != nil { t.Errorf("data[%d]: Expected value %v, got %v", i, nil, expected.value) } - if !errors.Is(err, common.ErrGeminiLinkLineParse) { + if !strings.HasPrefix(err.Error(), expected.error) { t.Errorf("data[%d]: expected error %v, got %v", i, expected.error, err) } } else { - if expected.error != nil { - t.Errorf("data[%d]: Expected error %v, got %v", i, nil, expected.error) - } if !(reflect.DeepEqual(result, expected.value)) { t.Errorf("data[%d]: expected %#v, got %#v", i, expected.value, result) } diff --git a/gemini/gemini_test.go b/gemini/gemini_test.go deleted file mode 100644 index 13aa3c3..0000000 --- a/gemini/gemini_test.go +++ /dev/null @@ -1,69 +0,0 @@ -package gemini - -import ( - "testing" - - "gemini-grc/common" -) - -func TestExtractRedirectTargetFullURL(t *testing.T) { - t.Parallel() - currentURL, _ := common.ParseURL("gemini://smol.gr", "") - input := "redirect: 31 gemini://target.gr" - result, err := extractRedirectTarget(*currentURL, input) - expected := "gemini://target.gr:1965" - if err != nil || (result.String() != expected) { - t.Errorf("fail: Expected %s got %s", expected, result) - } -} - -func TestExtractRedirectTargetFullURLSlash(t *testing.T) { - t.Parallel() - currentURL, _ := common.ParseURL("gemini://smol.gr", "") - input := "redirect: 31 gemini://target.gr/" - result, err := extractRedirectTarget(*currentURL, input) - expected := "gemini://target.gr:1965/" - if err != nil || (result.String() != expected) { - t.Errorf("fail: Expected %s got %s", expected, result) - } -} - -func TestExtractRedirectTargetRelativeURL(t *testing.T) { - t.Parallel() - currentURL, _ := common.ParseURL("gemini://smol.gr", "") - input := "redirect: 31 /a/b" - result, err := extractRedirectTarget(*currentURL, input) - if err != nil || (result.String() != "gemini://smol.gr:1965/a/b") { - t.Errorf("fail: %s", result) - } -} - -func TestExtractRedirectTargetRelativeURL2(t *testing.T) { - t.Parallel() - currentURL, _ := common.ParseURL("gemini://nox.im:1965", "") - input := "redirect: 31 ./" - result, err := extractRedirectTarget(*currentURL, input) - if err != nil || (result.String() != "gemini://nox.im:1965/") { - t.Errorf("fail: %s", result) - } -} - -func TestExtractRedirectTargetRelativeURL3(t *testing.T) { - t.Parallel() - currentURL, _ := common.ParseURL("gemini://status.zvava.org:1965", "") - input := "redirect: 31 index.gmi" - result, err := extractRedirectTarget(*currentURL, input) - if err != nil || (result.String() != "gemini://status.zvava.org:1965/index.gmi") { - t.Errorf("fail: %s", result) - } -} - -func TestExtractRedirectTargetWrong(t *testing.T) { - t.Parallel() - currentURL, _ := common.ParseURL("gemini://smol.gr", "") - input := "redirect: 31" - result, err := extractRedirectTarget(*currentURL, input) - if result != nil || err == nil { - t.Errorf("fail: result should be nil, err is %s", err) - } -} diff --git a/gemini/ip-address-pool.go b/gemini/ip-address-pool.go deleted file mode 100644 index b87bec6..0000000 --- a/gemini/ip-address-pool.go +++ /dev/null @@ -1,54 +0,0 @@ -package gemini - -import "sync" - -// Used to limit requests per -// IP address. Maps IP address -// to number of active connections. -type IpAddressPool struct { - IPs map[string]int - Lock sync.RWMutex -} - -func (p *IpAddressPool) Set(key string, value int) { - p.Lock.Lock() // Lock for writing - defer p.Lock.Unlock() // Ensure mutex is unlocked after the write - p.IPs[key] = value -} - -func (p *IpAddressPool) Get(key string) int { - p.Lock.RLock() // Lock for reading - defer p.Lock.RUnlock() // Ensure mutex is unlocked after reading - if value, ok := p.IPs[key]; !ok { - return 0 - } else { - return value - } -} - -func (p *IpAddressPool) Delete(key string) { - p.Lock.Lock() - defer p.Lock.Unlock() - delete(p.IPs, key) -} - -func (p *IpAddressPool) Incr(key string) { - p.Lock.Lock() - defer p.Lock.Unlock() - if _, ok := p.IPs[key]; !ok { - p.IPs[key] = 1 - } else { - p.IPs[key] = p.IPs[key] + 1 - } -} - -func (p *IpAddressPool) Decr(key string) { - p.Lock.Lock() - defer p.Lock.Unlock() - if val, ok := p.IPs[key]; ok { - p.IPs[key] = val - 1 - if p.IPs[key] == 0 { - delete(p.IPs, key) - } - } -} diff --git a/gemini/network.go b/gemini/network.go index a9cbb77..79a0c38 100644 --- a/gemini/network.go +++ b/gemini/network.go @@ -2,44 +2,78 @@ package gemini import ( "crypto/tls" - "errors" "fmt" "io" "net" - gourl "net/url" + stdurl "net/url" "regexp" "slices" "strconv" "strings" "time" - "gemini-grc/common" + errors2 "gemini-grc/common/errors" + "gemini-grc/common/snapshot" + _url "gemini-grc/common/url" "gemini-grc/config" + "gemini-grc/errors" "gemini-grc/logging" "github.com/guregu/null/v5" ) -type PageData struct { - ResponseCode int - ResponseHeader string - MimeType string - Lang string - GemText string - Data []byte -} - -func getHostIPAddresses(hostname string) ([]string, error) { - addrs, err := net.LookupHost(hostname) +// Visit given URL, using the Gemini protocol. +// Mutates given Snapshot with the data. +// In case of error, we store the error string +// inside snapshot and return the error. +func Visit(url string) (s *snapshot.Snapshot, err error) { + s, err = snapshot.SnapshotFromURL(url, true) if err != nil { - return nil, fmt.Errorf("%w:%w", common.ErrNetworkDNS, err) + return nil, err } - return addrs, nil + + defer func() { + if err != nil { + // GeminiError and HostError should + // be stored in the snapshot. Other + // errors are returned. + if errors2.IsHostError(err) { + s.Error = null.StringFrom(err.Error()) + err = nil + } else if IsGeminiError(err) { + s.Error = null.StringFrom(err.Error()) + s.Header = null.StringFrom(errors.Unwrap(err).(*GeminiError).Header) + s.ResponseCode = null.IntFrom(int64(errors.Unwrap(err).(*GeminiError).Code)) + err = nil + } else { + s = nil + } + } + }() + + data, err := ConnectAndGetData(s.URL.String()) + if err != nil { + return s, err + } + + s, err = processData(*s, data) + if err != nil { + return s, err + } + + if isGeminiCapsule(s) { + links := GetPageLinks(s.URL, s.GemText.String) + if len(links) > 0 { + logging.LogDebug("Found %d links", len(links)) + s.Links = null.ValueFrom(links) + } + } + return s, nil } func ConnectAndGetData(url string) ([]byte, error) { - parsedURL, err := gourl.Parse(url) + parsedURL, err := stdurl.Parse(url) if err != nil { - return nil, fmt.Errorf("%w: %w", common.ErrURLParse, err) + return nil, errors.NewError(err) } hostname := parsedURL.Hostname() port := parsedURL.Port() @@ -47,29 +81,28 @@ func ConnectAndGetData(url string) ([]byte, error) { port = "1965" } host := fmt.Sprintf("%s:%s", hostname, port) + timeoutDuration := time.Duration(config.CONFIG.ResponseTimeout) * time.Second // Establish the underlying TCP connection. dialer := &net.Dialer{ - Timeout: time.Duration(config.CONFIG.ResponseTimeout) * time.Second, + Timeout: timeoutDuration, } conn, err := dialer.Dial("tcp", host) if err != nil { - return nil, fmt.Errorf("%w: %w", common.ErrNetwork, err) + return nil, errors2.NewHostError(err) } // Make sure we always close the connection. defer func() { - // No need to handle error: - // Connection will time out eventually if still open somehow. _ = conn.Close() }() // Set read and write timeouts on the TCP connection. - err = conn.SetReadDeadline(time.Now().Add(time.Duration(config.CONFIG.ResponseTimeout) * time.Second)) + err = conn.SetReadDeadline(time.Now().Add(timeoutDuration)) if err != nil { - return nil, fmt.Errorf("%w: %w", common.ErrNetworkSetConnectionDeadline, err) + return nil, errors2.NewHostError(err) } - err = conn.SetWriteDeadline(time.Now().Add(time.Duration(config.CONFIG.ResponseTimeout) * time.Second)) + err = conn.SetWriteDeadline(time.Now().Add(timeoutDuration)) if err != nil { - return nil, fmt.Errorf("%w: %w", common.ErrNetworkSetConnectionDeadline, err) + return nil, errors2.NewHostError(err) } // Perform the TLS handshake @@ -79,8 +112,17 @@ func ConnectAndGetData(url string) ([]byte, error) { // MinVersion: tls.VersionTLS12, // Use a minimum TLS version. Warning breaks a lot of sites. } tlsConn := tls.Client(conn, tlsConfig) - if err := tlsConn.Handshake(); err != nil { - return nil, fmt.Errorf("%w: %w", common.ErrNetworkTLS, err) + err = tlsConn.SetReadDeadline(time.Now().Add(timeoutDuration)) + if err != nil { + return nil, errors2.NewHostError(err) + } + err = tlsConn.SetWriteDeadline(time.Now().Add(timeoutDuration)) + if err != nil { + return nil, errors2.NewHostError(err) + } + err = tlsConn.Handshake() + if err != nil { + return nil, errors2.NewHostError(err) } // We read `buf`-sized chunks and add data to `data`. @@ -91,10 +133,10 @@ func ConnectAndGetData(url string) ([]byte, error) { // Fix for stupid server bug: // Some servers return 'Header: 53 No proxying to other hosts or ports!' // when the port is 1965 and is still specified explicitly in the URL. - _url, _ := common.ParseURL(url, "") - _, err = tlsConn.Write([]byte(fmt.Sprintf("%s\r\n", _url.StringNoDefaultPort()))) + url2, _ := _url.ParseURL(url, "", true) + _, err = tlsConn.Write([]byte(fmt.Sprintf("%s\r\n", url2.StringNoDefaultPort()))) if err != nil { - return nil, fmt.Errorf("%w: %w", common.ErrNetworkCannotWrite, err) + return nil, errors2.NewHostError(err) } // Read response bytes in len(buf) byte chunks for { @@ -103,90 +145,50 @@ func ConnectAndGetData(url string) ([]byte, error) { data = append(data, buf[:n]...) } if len(data) > config.CONFIG.MaxResponseSize { - return nil, fmt.Errorf("%w: %v", common.ErrNetworkResponseSizeExceededMax, config.CONFIG.MaxResponseSize) + return nil, errors2.NewHostError(err) } if err != nil { if errors.Is(err, io.EOF) { break } - return nil, fmt.Errorf("%w: %w", common.ErrNetwork, err) + return nil, errors2.NewHostError(err) } } return data, nil } -// Visit given URL, using the Gemini protocol. -// Mutates given Snapshot with the data. -// In case of error, we store the error string -// inside snapshot and return the error. -func Visit(s *common.Snapshot) (err error) { - // Don't forget to also store error - // response code (if we have one) - // and header - defer func() { - if err != nil { - s.Error = null.StringFrom(err.Error()) - if errors.As(err, new(*common.GeminiError)) { - s.Header = null.StringFrom(err.(*common.GeminiError).Header) - s.ResponseCode = null.IntFrom(int64(err.(*common.GeminiError).Code)) - } - } - }() - s.Timestamp = null.TimeFrom(time.Now()) - data, err := ConnectAndGetData(s.URL.String()) - if err != nil { - return err - } - pageData, err := processData(data) - if err != nil { - return err - } - s.Header = null.StringFrom(pageData.ResponseHeader) - s.ResponseCode = null.IntFrom(int64(pageData.ResponseCode)) - s.MimeType = null.StringFrom(pageData.MimeType) - s.Lang = null.StringFrom(pageData.Lang) - if pageData.GemText != "" { - s.GemText = null.StringFrom(pageData.GemText) - } - if pageData.Data != nil { - s.Data = null.ValueFrom(pageData.Data) - } - return nil -} - -// processData returne results from -// parsing Gemini header data: -// Code, mime type and lang (optional) -// Returns error if header was invalid -func processData(data []byte) (*PageData, error) { +func processData(s snapshot.Snapshot, data []byte) (*snapshot.Snapshot, error) { header, body, err := getHeadersAndData(data) if err != nil { return nil, err } code, mimeType, lang := getMimeTypeAndLang(header) - logging.LogDebug("Header: %s", strings.TrimSpace(header)) - if code != 20 { - return nil, common.NewErrGeminiStatusCode(code, header) + + if code != 0 { + s.ResponseCode = null.IntFrom(int64(code)) + } + if header != "" { + s.Header = null.StringFrom(header) + } + if mimeType != "" { + s.MimeType = null.StringFrom(mimeType) + } + if lang != "" { + s.Lang = null.StringFrom(lang) } - pageData := PageData{ - ResponseCode: code, - ResponseHeader: header, - MimeType: mimeType, - Lang: lang, - } // If we've got a Gemini document, populate // `GemText` field, otherwise raw data goes to `Data`. if mimeType == "text/gemini" { validBody, err := BytesToValidUTF8(body) if err != nil { - return nil, fmt.Errorf("%w: %w", common.ErrUTF8Parse, err) + return nil, errors.NewError(err) } - pageData.GemText = validBody + s.GemText = null.StringFrom(validBody) } else { - pageData.Data = body + s.Data = null.ValueFrom(body) } - return &pageData, nil + return &s, nil } // Checks for a Gemini header, which is @@ -196,29 +198,42 @@ func processData(data []byte) (*PageData, error) { func getHeadersAndData(data []byte) (string, []byte, error) { firstLineEnds := slices.Index(data, '\n') if firstLineEnds == -1 { - return "", nil, common.ErrGeminiResponseHeader + return "", nil, errors2.NewHostError(fmt.Errorf("error parsing header")) } firstLine := string(data[:firstLineEnds]) rest := data[firstLineEnds+1:] - return firstLine, rest, nil + return strings.TrimSpace(firstLine), rest, nil } -// Parses code, mime type and language -// from a Gemini header. -// Examples: -// `20 text/gemini lang=en` (code, mimetype, lang) -// `20 text/gemini` (code, mimetype) -// `31 gemini://redirected.to/other/site` (code) +// getMimeTypeAndLang Parses code, mime type and language +// given a Gemini header. func getMimeTypeAndLang(headers string) (int, string, string) { - // Regex that parses code, mimetype & optional charset/lang parameters - re := regexp.MustCompile(`^(\d+)\s+([a-zA-Z0-9/\-+]+)(?:[;\s]+(?:(?:charset|lang)=([a-zA-Z0-9-]+)))?\s*$`) + // First try to match the full format: " [charset=] [lang=]" + // The regex looks for: + // - A number (\d+) + // - Followed by whitespace and a mimetype ([a-zA-Z0-9/\-+]+) + // - Optionally followed by charset and/or lang parameters in any order + // - Only capturing the lang value, ignoring charset + re := regexp.MustCompile(`^(\d+)\s+([a-zA-Z0-9/\-+]+)(?:(?:[\s;]+(?:charset=[^;\s]+|lang=([a-zA-Z0-9-]+)))*)\s*$`) matches := re.FindStringSubmatch(headers) if matches == nil || len(matches) <= 1 { - // Try to get code at least - re := regexp.MustCompile(`^(\d+)\s+`) + // If full format doesn't match, try to match redirect format: " " + // This handles cases like "31 gemini://example.com" + re := regexp.MustCompile(`^(\d+)\s+(.+)$`) matches := re.FindStringSubmatch(headers) if matches == nil || len(matches) <= 1 { - return 0, "", "" + // If redirect format doesn't match, try to match just a status code + // This handles cases like "99" + re := regexp.MustCompile(`^(\d+)\s*$`) + matches := re.FindStringSubmatch(headers) + if matches == nil || len(matches) <= 1 { + return 0, "", "" + } + code, err := strconv.Atoi(matches[1]) + if err != nil { + return 0, "", "" + } + return code, "", "" } code, err := strconv.Atoi(matches[1]) if err != nil { @@ -231,6 +246,10 @@ func getMimeTypeAndLang(headers string) (int, string, string) { return 0, "", "" } mimeType := matches[2] - param := matches[3] // This will capture either charset or lang value - return code, mimeType, param + lang := matches[3] // Will be empty string if no lang parameter was found + return code, mimeType, lang +} + +func isGeminiCapsule(s *snapshot.Snapshot) bool { + return !s.Error.Valid && s.MimeType.Valid && s.MimeType.String == "text/gemini" } diff --git a/gemini/network_test.go b/gemini/network_test.go index 81202db..6c0656a 100644 --- a/gemini/network_test.go +++ b/gemini/network_test.go @@ -1,78 +1,366 @@ package gemini import ( + "slices" + "strings" "testing" + + "gemini-grc/common/snapshot" ) -// Test for input: `20 text/gemini` -func TestGetMimeTypeAndLang1(t *testing.T) { +func TestGetHeadersAndData(t *testing.T) { t.Parallel() - code, mimeType, lang := getMimeTypeAndLang("20 text/gemini") - if code != 20 || mimeType != "text/gemini" || lang != "" { - t.Errorf("Expected (20, 'text/gemini', ''), got (%d, '%s', '%s')", code, mimeType, lang) + tests := []struct { + input []byte + header string + body []byte + expectError bool + }{ + {[]byte("20 text/gemini\r\nThis is the body"), "20 text/gemini", []byte("This is the body"), false}, + {[]byte("20 text/gemini\nThis is the body"), "20 text/gemini", []byte("This is the body"), false}, + {[]byte("53 No proxying!\r\n"), "53 No proxying!", []byte(""), false}, + {[]byte("No header"), "", nil, true}, + } + + for _, test := range tests { + header, body, err := getHeadersAndData(test.input) + + if test.expectError && err == nil { + t.Errorf("Expected error, got nil for input: %s", test.input) + } + + if !test.expectError && err != nil { + t.Errorf("Unexpected error for input '%s': %v", test.input, err) + } + + if header != test.header { + t.Errorf("Expected header '%s', got '%s' for input: %s", test.header, header, test.input) + } + + if !slices.Equal(body, test.body) { + t.Errorf("Expected body '%s', got '%s' for input: %s", test.body, string(body), test.input) + } } } -func TestGetMimeTypeAndLang11(t *testing.T) { +func TestGetMimeTypeAndLang(t *testing.T) { t.Parallel() - code, mimeType, lang := getMimeTypeAndLang("20 text/gemini\n") - if code != 20 || mimeType != "text/gemini" || lang != "" { - t.Errorf("Expected (20, 'text/gemini', ''), got (%d, '%s', '%s')", code, mimeType, lang) + tests := []struct { + header string + code int + mimeType string + lang string + }{ + {"20 text/gemini lang=en", 20, "text/gemini", "en"}, + {"20 text/gemini", 20, "text/gemini", ""}, + {"31 gemini://redirected.to/other/site", 31, "", ""}, + {"20 text/plain;charset=utf-8", 20, "text/plain", ""}, + {"20 text/plain;lang=el-GR", 20, "text/plain", "el-GR"}, + {"20 text/gemini;lang=en-US;charset=utf-8", 20, "text/gemini", "en-US"}, // charset should be ignored + {"Invalid header", 0, "", ""}, + {"99", 99, "", ""}, + } + + for _, test := range tests { + code, mimeType, lang := getMimeTypeAndLang(test.header) + + if code != test.code { + t.Errorf("Expected code %d, got %d for header: %s", test.code, code, test.header) + } + + if mimeType != test.mimeType { + t.Errorf("Expected mimeType '%s', got '%s' for header: %s", test.mimeType, mimeType, test.header) + } + + if lang != test.lang { + t.Errorf("Expected lang '%s', got '%s' for header: %s", test.lang, lang, test.header) + } } } -func TestGetMimeTypeAndLang12(t *testing.T) { +func TestProcessData(t *testing.T) { t.Parallel() - code, mimeType, lang := getMimeTypeAndLang("20 text/plain; charset=utf-8") - if code != 20 || mimeType != "text/plain" || lang != "utf-8" { - t.Errorf("Expected (20, 'text/plain', ''), got (%d, '%s', '%s')", code, mimeType, lang) + tests := []struct { + name string + inputData []byte + expectedCode int + expectedMime string + expectedLang string + expectedData []byte + expectedError bool + }{ + { + name: "Gemini document", + inputData: []byte("20 text/gemini\r\n# Hello\nWorld"), + expectedCode: 20, + expectedMime: "text/gemini", + expectedLang: "", + expectedData: []byte("# Hello\nWorld"), + expectedError: false, + }, + { + name: "Gemini document with language", + inputData: []byte("20 text/gemini lang=en\r\n# Hello\nWorld"), + expectedCode: 20, + expectedMime: "text/gemini", + expectedLang: "en", + expectedData: []byte("# Hello\nWorld"), + expectedError: false, + }, + { + name: "Non-Gemini document", + inputData: []byte("20 text/html\r\n

Hello

"), + expectedCode: 20, + expectedMime: "text/html", + expectedLang: "", + expectedData: []byte("

Hello

"), + expectedError: false, + }, + { + name: "Error header", + inputData: []byte("53 No proxying!\r\n"), + expectedCode: 53, + expectedMime: "", + expectedLang: "", + expectedData: []byte(""), + expectedError: false, + }, + { + name: "Invalid header", + inputData: []byte("Invalid header"), + expectedError: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := snapshot.Snapshot{} + result, err := processData(s, test.inputData) + + if test.expectedError && err == nil { + t.Errorf("Expected error, got nil") + return + } + + if !test.expectedError && err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if test.expectedError { + return + } + + if int(result.ResponseCode.ValueOrZero()) != test.expectedCode { + t.Errorf("Expected code %d, got %d", test.expectedCode, int(result.ResponseCode.ValueOrZero())) + } + + if result.MimeType.ValueOrZero() != test.expectedMime { + t.Errorf("Expected mimeType '%s', got '%s'", test.expectedMime, result.MimeType.ValueOrZero()) + } + + if result.Lang.ValueOrZero() != test.expectedLang { + t.Errorf("Expected lang '%s', got '%s'", test.expectedLang, result.Lang.ValueOrZero()) + } + + if test.expectedMime == "text/gemini" { + if !strings.Contains(result.GemText.String, string(test.expectedData)) { + t.Errorf("Expected GemText '%s', got '%s'", test.expectedData, result.GemText.String) + } + } else { + if !slices.Equal(result.Data.ValueOrZero(), test.expectedData) { + t.Errorf("Expected data '%s', got '%s'", test.expectedData, result.Data.ValueOrZero()) + } + } + }) } } -func TestGetMimeTypeAndLang13(t *testing.T) { +//// Mock Gemini server for testing ConnectAndGetData +//func mockGeminiServer(response string, delay time.Duration, closeConnection bool) net.Listener { +// listener, err := net.Listen("tcp", "127.0.0.1:0") // Bind to a random available port +// if err != nil { +// panic(fmt.Sprintf("Failed to create mock server: %v", err)) +// } +// +// go func() { +// conn, err := listener.Accept() +// if err != nil { +// if !closeConnection { // Don't panic if we closed the connection on purpose +// panic(fmt.Sprintf("Failed to accept connection: %v", err)) +// } +// return +// } +// defer conn.Close() +// +// time.Sleep(delay) // Simulate network latency +// +// _, err = conn.Write([]byte(response)) +// if err != nil && !closeConnection { +// panic(fmt.Sprintf("Failed to write response: %v", err)) +// } +// }() +// +// return listener +//} + +// func TestConnectAndGetData(t *testing.T) { +// config.CONFIG = config.ConfigStruct{ +// ResponseTimeout: 5, +// MaxResponseSize: 1024 * 1024, +// } +// tests := []struct { +// name string +// serverResponse string +// serverDelay time.Duration +// expectedData []byte +// expectedError bool +// closeConnection bool +// }{ +// { +// name: "Successful response", +// serverResponse: "20 text/gemini\r\n# Hello", +// expectedData: []byte("20 text/gemini\r\n# Hello"), +// expectedError: false, +// }, +// { +// name: "Server error", +// serverResponse: "50 Server error\r\n", +// expectedData: []byte("50 Server error\r\n"), +// expectedError: false, +// }, +// { +// name: "Timeout", +// serverDelay: 6 * time.Second, // Longer than the timeout +// expectedError: true, +// }, +// { +// name: "Server closes connection", +// closeConnection: true, +// expectedError: true, +// }, +// } + +// for _, test := range tests { +// t.Run(test.name, func(t *testing.T) { +// listener := mockGeminiServer(test.serverResponse, test.serverDelay, test.closeConnection) +// defer func() { +// test.closeConnection = true // Prevent panic in mock server +// listener.Close() +// }() +// addr := listener.Addr().String() +// data, err := ConnectAndGetData(fmt.Sprintf("gemini://%s/", addr)) + +// if test.expectedError && err == nil { +// t.Errorf("Expected error, got nil") +// } + +// if !test.expectedError && err != nil { +// t.Errorf("Unexpected error: %v", err) +// } + +// if !slices.Equal(data, test.expectedData) { +// t.Errorf("Expected data '%s', got '%s'", test.expectedData, data) +// } +// }) +// } +// } + +// func TestVisit(t *testing.T) { +// config.CONFIG = config.ConfigStruct{ +// ResponseTimeout: 5, +// MaxResponseSize: 1024 * 1024, +// } +// tests := []struct { +// name string +// serverResponse string +// expectedCode int +// expectedMime string +// expectedError bool +// expectedLinks []string +// }{ +// { +// name: "Successful response", +// serverResponse: "20 text/gemini\r\n# Hello\n=> /link1 Link 1\n=> /link2 Link 2", +// expectedCode: 20, +// expectedMime: "text/gemini", +// expectedError: false, +// expectedLinks: []string{"gemini://127.0.0.1:1965/link1", "gemini://127.0.0.1:1965/link2"}, +// }, +// { +// name: "Server error", +// serverResponse: "50 Server error\r\n", +// expectedCode: 50, +// expectedMime: "Server error", +// expectedError: false, +// expectedLinks: []string{}, +// }, +// } + +// for _, test := range tests { +// t.Run(test.name, func(t *testing.T) { +// listener := mockGeminiServer(test.serverResponse, 0, false) +// defer listener.Close() +// addr := listener.Addr().String() +// snapshot, err := Visit(fmt.Sprintf("gemini://%s/", addr)) + +// if test.expectedError && err == nil { +// t.Errorf("Expected error, got nil") +// } + +// if !test.expectedError && err != nil { +// t.Errorf("Unexpected error: %v", err) +// } + +// if snapshot.ResponseCode.ValueOrZero() != int64(test.expectedCode) { +// t.Errorf("Expected code %d, got %d", test.expectedCode, snapshot.ResponseCode.ValueOrZero()) +// } + +// if snapshot.MimeType.ValueOrZero() != test.expectedMime { +// t.Errorf("Expected mimeType '%s', got '%s'", test.expectedMime, snapshot.MimeType.ValueOrZero()) +// } + +// if test.expectedLinks != nil { +// links, _ := snapshot.Links.Value() + +// if len(links) != len(test.expectedLinks) { +// t.Errorf("Expected %d links, got %d", len(test.expectedLinks), len(links)) +// } +// for i, link := range links { +// if link != test.expectedLinks[i] { +// t.Errorf("Expected link '%s', got '%s'", test.expectedLinks[i], link) +// } +// } +// } +// }) +// } +// } + +func TestVisit_InvalidURL(t *testing.T) { t.Parallel() - code, mimeType, lang := getMimeTypeAndLang("20 text/gemini; charset=utf-8") - if code != 20 || mimeType != "text/gemini" || lang != "utf-8" { - t.Errorf("Expected (20, 'text/plain', ''), got (%d, '%s', '%s')", code, mimeType, lang) + _, err := Visit("invalid-url") + if err == nil { + t.Errorf("Expected error for invalid URL, got nil") } } -func TestGetTypeAndLang2(t *testing.T) { - t.Parallel() - code, mimeType, lang := getMimeTypeAndLang("20 text/gemini charset=en") - if code != 20 || mimeType != "text/gemini" || lang != "en" { - t.Errorf("Expected (20, 'text/gemini', 'en'), got (%d, '%s', '%s')", code, mimeType, lang) - } -} - -func TestGetTypeAndLang21(t *testing.T) { - t.Parallel() - code, mimeType, lang := getMimeTypeAndLang("20 text/gemini lang=en") - if code != 20 || mimeType != "text/gemini" || lang != "en" { - t.Errorf("Expected (20, 'text/gemini', 'en'), got (%d, '%s', '%s')", code, mimeType, lang) - } -} - -func TestGetMimeTypeAndLang3(t *testing.T) { - t.Parallel() - code, mimeType, lang := getMimeTypeAndLang("31 gemini://redirect.to/page") - if code != 31 || mimeType != "" || lang != "" { - t.Errorf("Expected (20, '', ''), got (%d, '%s', '%s')", code, mimeType, lang) - } -} - -func TestGetMimeTypeAndLang4(t *testing.T) { - t.Parallel() - code, mimeType, lang := getMimeTypeAndLang("aaafdasdasd") - if code != 0 || mimeType != "" || lang != "" { - t.Errorf("Expected (0, '', ''), got (%d, '%s', '%s')", code, mimeType, lang) - } -} - -func TestGetMimeTypeAndLang5(t *testing.T) { - t.Parallel() - code, mimeType, lang := getMimeTypeAndLang("") - if code != 0 || mimeType != "" || lang != "" { - t.Errorf("Expected (0, '', ''), got (%d, '%s', '%s')", code, mimeType, lang) - } -} +//func TestVisit_GeminiError(t *testing.T) { +// listener := mockGeminiServer("51 Not Found\r\n", 0, false) +// defer listener.Close() +// addr := listener.Addr().String() +// +// s, err := Visit(fmt.Sprintf("gemini://%s/", addr)) +// if err != nil { +// t.Errorf("Unexpected error: %v", err) +// } +// +// expectedError := "51 Not Found" +// if s.Error.ValueOrZero() != expectedError { +// t.Errorf("Expected error in snapshot: %v, got %v", expectedError, s.Error) +// } +// +// expectedCode := 51 +// if s.ResponseCode.ValueOrZero() != int64(expectedCode) { +// t.Errorf("Expected code %d, got %d", expectedCode, s.ResponseCode.ValueOrZero()) +// } +//} diff --git a/gemini/processing.go b/gemini/processing.go index 0afdac3..d631778 100644 --- a/gemini/processing.go +++ b/gemini/processing.go @@ -26,7 +26,7 @@ func BytesToValidUTF8(input []byte) (string, error) { if len(input) > maxSize { return "", fmt.Errorf("%w: %d bytes (max %d)", ErrInputTooLarge, len(input), maxSize) } - // Remove NULL byte 0x00 (ReplaceAll accepts slices) + // remove NULL byte 0x00 (ReplaceAll accepts slices) inputNoNull := bytes.ReplaceAll(input, []byte{byte(0)}, []byte{}) if utf8.Valid(inputNoNull) { return string(inputNoNull), nil diff --git a/gemini/robotmatch.go b/gemini/robotmatch.go index 52967f4..07819b7 100644 --- a/gemini/robotmatch.go +++ b/gemini/robotmatch.go @@ -2,10 +2,11 @@ package gemini import ( "fmt" - "gemini-grc/common" "strings" "sync" + "gemini-grc/common/snapshot" + geminiUrl "gemini-grc/common/url" "gemini-grc/logging" ) @@ -16,7 +17,7 @@ import ( // list is stored for caching. var RobotsCache sync.Map //nolint:gochecknoglobals -func populateBlacklist(key string) (entries []string) { +func populateRobotsCache(key string) (entries []string, _err error) { // We either store an empty list when // no rules, or a list of disallowed URLs. // This applies even if we have an error @@ -27,53 +28,60 @@ func populateBlacklist(key string) (entries []string) { url := fmt.Sprintf("gemini://%s/robots.txt", key) robotsContent, err := ConnectAndGetData(url) if err != nil { - logging.LogDebug("robots.txt error %s", err) - return []string{} + return []string{}, err } - robotsData, err := processData(robotsContent) + s, err := snapshot.SnapshotFromURL(url, true) + if err != nil { + return []string{}, nil + } + s, err = processData(*s, robotsContent) if err != nil { logging.LogDebug("robots.txt error %s", err) - return []string{} + return []string{}, nil } - if robotsData.ResponseCode != 20 { - logging.LogDebug("robots.txt error code %d, ignoring", robotsData.ResponseCode) - return []string{} + if s.ResponseCode.ValueOrZero() != 20 { + logging.LogDebug("robots.txt error code %d, ignoring", s.ResponseCode.ValueOrZero()) + return []string{}, nil } // Some return text/plain, others text/gemini. // According to spec, the first is correct, // however let's be lenient var data string switch { - case robotsData.MimeType == "text/plain": - data = string(robotsData.Data) - case robotsData.MimeType == "text/gemini": - data = robotsData.GemText + case s.MimeType.ValueOrZero() == "text/plain": + data = string(s.Data.ValueOrZero()) + case s.MimeType.ValueOrZero() == "text/gemini": + data = s.GemText.ValueOrZero() default: - return []string{} + return []string{}, nil } entries = ParseRobotsTxt(data, key) - return entries + return entries, nil } // RobotMatch checks if the snapshot URL matches // a robots.txt allow rule. -func RobotMatch(u string) bool { - url, err := common.ParseURL(u, "") +func RobotMatch(u string) (bool, error) { + url, err := geminiUrl.ParseURL(u, "", true) if err != nil { - return false + return false, err } key := strings.ToLower(fmt.Sprintf("%s:%d", url.Hostname, url.Port)) - logging.LogDebug("Checking robots.txt cache for %s", key) var disallowedURLs []string cacheEntries, ok := RobotsCache.Load(key) if !ok { // First time check, populate robot cache - disallowedURLs = populateBlacklist(key) - logging.LogDebug("Added to robots.txt cache: %v => %v", key, disallowedURLs) + disallowedURLs, err := populateRobotsCache(key) + if err != nil { + return false, err + } + if len(disallowedURLs) > 0 { + logging.LogDebug("Added to robots.txt cache: %v => %v", key, disallowedURLs) + } } else { disallowedURLs, _ = cacheEntries.([]string) } - return isURLblocked(disallowedURLs, url.Full) + return isURLblocked(disallowedURLs, url.Full), nil } func isURLblocked(disallowedURLs []string, input string) bool { diff --git a/gemini/worker.go b/gemini/worker.go deleted file mode 100644 index 8c84997..0000000 --- a/gemini/worker.go +++ /dev/null @@ -1,344 +0,0 @@ -package gemini - -import ( - "errors" - "fmt" - "gemini-grc/common" - _db "gemini-grc/db" - "strings" - "time" - - "gemini-grc/logging" - "gemini-grc/util" - "github.com/guregu/null/v5" - "github.com/jmoiron/sqlx" -) - -func SpawnWorkers(numOfWorkers int, db *sqlx.DB) { - logging.LogInfo("Spawning %d workers", numOfWorkers) - statusChan = make(chan WorkerStatus, numOfWorkers) - go PrintWorkerStatus(numOfWorkers, statusChan) - - for i := range numOfWorkers { - go func(i int) { - // Jitter to avoid starting everything at the same time - time.Sleep(time.Duration(util.SecureRandomInt(10)) * time.Second) - for { - RunWorkerWithTx(i, db, nil) - } - }(i) - } -} - -func RunWorkerWithTx(workerID int, db *sqlx.DB, url *string) { - statusChan <- WorkerStatus{ - id: workerID, - status: "Starting up", - } - defer func() { - statusChan <- WorkerStatus{ - id: workerID, - status: "Done", - } - }() - tx, err := db.Beginx() - if err != nil { - panic(fmt.Sprintf("Failed to begin transaction: %v", err)) - } - runWorker(workerID, tx, url) - logging.LogDebug("[%d] Committing transaction", workerID) - err = tx.Commit() - // On deadlock errors, rollback and return, otherwise panic. - if err != nil { - logging.LogError("[%d] Failed to commit transaction: %w", workerID, err) - if _db.IsDeadlockError(err) { - logging.LogError("[%d] Deadlock detected. Rolling back", workerID) - time.Sleep(time.Duration(10) * time.Second) - err := tx.Rollback() - if err != nil { - panic(fmt.Sprintf("[%d] Failed to roll back transaction: %v", workerID, err)) - } - return - } - panic(fmt.Sprintf("[%d] Failed to commit transaction: %v", workerID, err)) - } - logging.LogDebug("[%d] Worker done!", workerID) -} - -func runWorker(workerID int, tx *sqlx.Tx, url *string) { - var urls []string - var err error - - // If not given a specific URL, - // get some random ones to visit from db. - if url == nil { - statusChan <- WorkerStatus{ - id: workerID, - status: "Getting URLs", - } - urls, err = _db.GetURLsToVisit(tx) - if err != nil { - logging.LogError("[%d] GeminiError retrieving snapshot: %w", workerID, err) - panic("This should never happen") - } else if len(urls) == 0 { - logging.LogInfo("[%d] No URLs to visit.", workerID) - time.Sleep(1 * time.Minute) - return - } - } else { - geminiURL, err := common.ParseURL(*url, "") - if err != nil { - logging.LogError("Invalid URL given: %s", *url) - return - } - urls = []string{geminiURL.String()} - } - - // Start visiting URLs. - total := len(urls) - for i, u := range urls { - logging.LogDebug("[%d] Starting %d/%d %s", workerID, i+1, total, u) - // We differentiate between errors: - // Unexpected errors are the ones returned from the following function. - // If an error is unexpected (which should never happen) we panic. - // Expected errors are stored as strings within the snapshot. - err := workOnUrl(workerID, tx, u) - if err != nil { - logging.LogError("[%d] Unexpected GeminiError %w while visiting %s", workerID, err, u) - util.PrintStackAndPanic(err) - } - logging.LogDebug("[%d] Done %d/%d.", workerID, i+1, total) - } -} - -// workOnUrl visits a URL and stores the result. -// unexpected errors are returned. -// expected errors are stored within the snapshot. -func workOnUrl(workerID int, tx *sqlx.Tx, url string) (err error) { - if url == "" { - return fmt.Errorf("nil URL given") - } - - if IsBlacklisted(url) { - logging.LogDebug("[%d] URL matches Blacklist, ignoring %s", workerID, url) - return nil - } - - s := common.SnapshotFromURL(url) - - // If URL matches a robots.txt disallow line, - // add it as an error so next time it won't be - // crawled. - if RobotMatch(url) { - s.Error = null.StringFrom(common.ErrGeminiRobotsDisallowed.Error()) - err = _db.OverwriteSnapshot(workerID, tx, s) - if err != nil { - return fmt.Errorf("[%d] %w", workerID, err) - } - return nil - } - - // Resolve IP address via DNS - statusChan <- WorkerStatus{ - id: workerID, - status: fmt.Sprintf("Resolving %s", url), - } - IPs, err := getHostIPAddresses(s.Host) - if err != nil { - s.Error = null.StringFrom(err.Error()) - err = _db.OverwriteSnapshot(workerID, tx, s) - if err != nil { - return fmt.Errorf("[%d] %w", workerID, err) - } - return nil - } - - for { - count := 1 - if isAnotherWorkerVisitingHost(workerID, IPs) { - logging.LogDebug("[%d] Another worker is visiting this host, waiting", workerID) - statusChan <- WorkerStatus{ - id: workerID, - status: fmt.Sprintf("Waiting to grab lock for host %s", s.Host), - } - time.Sleep(2 * time.Second) // Avoid flood-retrying - count++ - if count == 3 { - return - } - } else { - break - } - } - - statusChan <- WorkerStatus{ - id: workerID, - status: fmt.Sprintf("Adding to pool %s", url), - } - AddIPsToPool(IPs) - // After finishing, remove the host IPs from - // the connections pool, with a small delay - // to avoid potentially hitting the same IP quickly. - defer func() { - go func() { - time.Sleep(1 * time.Second) - statusChan <- WorkerStatus{ - id: workerID, - status: fmt.Sprintf("Removing from pool %s", url), - } - RemoveIPsFromPool(IPs) - }() - }() - - statusChan <- WorkerStatus{ - id: workerID, - status: fmt.Sprintf("Visiting %s", url), - } - - err = Visit(s) - if err != nil { - if !common.IsKnownError(err) { - logging.LogError("[%d] Unknown error visiting %s: %w", workerID, url, err) - return err - } - s.Error = null.StringFrom(err.Error()) - // Check if error is redirection, and handle it - if errors.As(err, new(*common.GeminiError)) && - err.(*common.GeminiError).Msg == "redirect" { - err = handleRedirection(workerID, tx, s) - if err != nil { - if common.IsKnownError(err) { - s.Error = null.StringFrom(err.Error()) - } else { - return err - } - } - } - } - // If this is a gemini page, parse possible links inside - if !s.Error.Valid && s.MimeType.Valid && s.MimeType.String == "text/gemini" { - links := GetPageLinks(s.URL, s.GemText.String) - if len(links) > 0 { - logging.LogDebug("[%d] Found %d links", workerID, len(links)) - s.Links = null.ValueFrom(links) - err = storeLinks(tx, s) - if err != nil { - return err - } - } - } else { - logging.LogDebug("[%d] Not text/gemini, so not looking for page links", workerID) - } - - err = _db.OverwriteSnapshot(workerID, tx, s) - logging.LogInfo("[%3d] %2d %s", workerID, s.ResponseCode.ValueOrZero(), s.URL.String()) - if err != nil { - return err - } - - return nil -} - -func isAnotherWorkerVisitingHost(workerID int, IPs []string) bool { - IPPool.Lock.RLock() - defer func() { - IPPool.Lock.RUnlock() - }() - logging.LogDebug("[%d] Checking pool for IPs", workerID) - for _, ip := range IPs { - _, ok := IPPool.IPs[ip] - if ok { - return true - } - } - return false -} - -func storeLinks(tx *sqlx.Tx, s *common.Snapshot) error { - if s.Links.Valid { - var batchSnapshots []*common.Snapshot - for _, link := range s.Links.ValueOrZero() { - if shouldPersistURL(&link) { - newSnapshot := &common.Snapshot{ - URL: link, - Host: link.Hostname, - Timestamp: null.TimeFrom(time.Now()), - } - batchSnapshots = append(batchSnapshots, newSnapshot) - } - } - - if len(batchSnapshots) > 0 { - err := _db.SaveLinksToDBinBatches(tx, batchSnapshots) - if err != nil { - return err - } - } - } - return nil -} - -// shouldPersistURL returns true if we -// should save the URL in the _db. -// Only gemini:// urls are saved. -func shouldPersistURL(u *common.URL) bool { - return strings.HasPrefix(u.String(), "gemini://") -} - -func haveWeVisitedURL(tx *sqlx.Tx, u *common.URL) (bool, error) { - var result bool - err := tx.Select(&result, `SELECT TRUE FROM urls WHERE url=$1`, u.String()) - if err != nil { - return false, fmt.Errorf("%w: %w", common.ErrDatabase, err) - } - if result { - return result, nil - } - err = tx.Select(&result, `SELECT TRUE FROM snapshots WHERE snapshot.url=$1`, u.String()) - if err != nil { - return false, fmt.Errorf("%w: %w", common.ErrDatabase, err) - } - return result, nil -} - -// handleRedirection saves redirect URL as new snapshot -func handleRedirection(workerID int, tx *sqlx.Tx, s *common.Snapshot) error { - newURL, err := extractRedirectTarget(s.URL, s.Error.ValueOrZero()) - if err != nil { - if errors.Is(err, common.ErrGeminiRedirect) { - logging.LogDebug("[%d] %s", workerID, err) - } - return err - } - logging.LogDebug("[%d] Page redirects to %s", workerID, newURL) - // Insert fresh snapshot with new URL - if shouldPersistURL(newURL) { - snapshot := &common.Snapshot{ - // UID: uid.UID(), - URL: *newURL, - Host: newURL.Hostname, - Timestamp: null.TimeFrom(time.Now()), - } - logging.LogDebug("[%d] Saving redirection URL %s", workerID, snapshot.URL.String()) - err = _db.SaveSnapshotIfNew(tx, snapshot) - if err != nil { - return err - } - } - return nil -} - -func GetSnapshotFromURL(tx *sqlx.Tx, url string) ([]common.Snapshot, error) { - query := ` - SELECT * - FROM snapshots - WHERE url=$1 - LIMIT 1 - ` - var snapshots []common.Snapshot - err := tx.Select(&snapshots, query, url) - if err != nil { - return nil, err - } - return snapshots, nil -}