fix: Refactor ConnectAndGetData function to return GeminiPageData struct

This commit is contained in:
2024-10-23 12:46:02 +03:00
committed by antanst (aider)
parent e51d84cad8
commit 62369d90ae

View File

@@ -6,6 +6,7 @@ import (
"gemini-grc/config" "gemini-grc/config"
"io" "io"
"net" "net"
go_url "net/url"
"regexp" "regexp"
"slices" "slices"
"strconv" "strconv"
@@ -14,6 +15,14 @@ import (
"github.com/guregu/null/v5" "github.com/guregu/null/v5"
) )
type GeminiPageData struct {
ResponseCode int
MimeType string
Lang string
GemText string
Data []byte
}
// Resolve the URL hostname and // Resolve the URL hostname and
// check if we already have an open // check if we already have an open
// connection to this host. // connection to this host.
@@ -31,51 +40,53 @@ func getHostIPAddresses(hostname string) ([]string, error) {
return addrs, nil return addrs, nil
} }
// Connect to given URL, using the Gemini protocol. func ConnectAndGetData(url string) ([]byte, error) {
// Return a Snapshot with the data or the error. parsedUrl, err := go_url.Parse(url)
// Any errors are stored within the snapshot. if err != nil {
func Visit(s *Snapshot) { return nil, fmt.Errorf("Could not parse URL, error %w", err)
}
host := parsedUrl.Host
port := parsedUrl.Port()
if port == "" {
port = "1965"
host = fmt.Sprintf("%s:%s", host, port)
}
// Establish the underlying TCP connection. // Establish the underlying TCP connection.
host := fmt.Sprintf("%s:%d", s.Host, s.URL.Port)
dialer := &net.Dialer{ dialer := &net.Dialer{
Timeout: time.Duration(config.CONFIG.ResponseTimeout) * time.Second, // Set the overall connection timeout Timeout: time.Duration(config.CONFIG.ResponseTimeout) * time.Second,
KeepAlive: 30 * time.Second, KeepAlive: 10 * time.Second,
} }
conn, err := dialer.Dial("tcp", host) conn, err := dialer.Dial("tcp", host)
if err != nil { if err != nil {
s.Error = null.StringFrom(fmt.Sprintf("TCP connection failed: %v", err)) return nil, fmt.Errorf("TCP connection failed: %w", err)
return
} }
// Make sure we always close the connection. // Make sure we always close the connection.
defer func() { defer func() {
err := conn.Close() err := conn.Close()
if err != nil { if err != nil {
s.Error = null.StringFrom(fmt.Sprintf("Error closing connection: %s", err)) // Do nothing! Connection will timeout eventually if still open somehow.
} }
}() }()
// Set read and write timeouts on the TCP connection. // Set read and write timeouts on the TCP connection.
err = conn.SetReadDeadline(time.Now().Add(time.Duration(config.CONFIG.ResponseTimeout) * time.Second)) err = conn.SetReadDeadline(time.Now().Add(time.Duration(config.CONFIG.ResponseTimeout) * time.Second))
if err != nil { if err != nil {
s.Error = null.StringFrom(fmt.Sprintf("Error setting connection deadline: %s", err)) return nil, fmt.Errorf("Error setting connection deadline: %w", err)
return
} }
err = conn.SetWriteDeadline(time.Now().Add(time.Duration(config.CONFIG.ResponseTimeout) * time.Second)) err = conn.SetWriteDeadline(time.Now().Add(time.Duration(config.CONFIG.ResponseTimeout) * time.Second))
if err != nil { if err != nil {
s.Error = null.StringFrom(fmt.Sprintf("Error setting connection deadline: %s", err)) return nil, fmt.Errorf("Error setting connection deadline: %w", err)
return
} }
// Perform the TLS handshake // Perform the TLS handshake
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
InsecureSkipVerify: true, // Accept all TLS certs, even if insecure. InsecureSkipVerify: true, // Accept all TLS certs, even if insecure.
ServerName: s.URL.Hostname, // SNI ServerName: parsedUrl.Host, // SNI
// MinVersion: tls.VersionTLS12, // Use a minimum TLS version. Warning breaks a lot of sites. // MinVersion: tls.VersionTLS12, // Use a minimum TLS version. Warning breaks a lot of sites.
} }
tlsConn := tls.Client(conn, tlsConfig) tlsConn := tls.Client(conn, tlsConfig)
if err := tlsConn.Handshake(); err != nil { if err := tlsConn.Handshake(); err != nil {
s.Error = null.StringFrom(fmt.Sprintf("TLS handshake error: %v", err)) return nil, fmt.Errorf("TLS handshake error: %w", err)
return
} }
// We read `buf`-sized chunks and add data to `data`. // We read `buf`-sized chunks and add data to `data`.
@@ -83,10 +94,9 @@ func Visit(s *Snapshot) {
var data []byte var data []byte
// Send Gemini request to trigger server response. // Send Gemini request to trigger server response.
_, err = tlsConn.Write([]byte(fmt.Sprintf("%s\r\n", s.URL.String()))) _, err = tlsConn.Write([]byte(fmt.Sprintf("%s\r\n", url)))
if err != nil { if err != nil {
s.Error = null.StringFrom(fmt.Sprintf("Error sending network request: %s", err)) return nil, fmt.Errorf("Error sending network request: %w", err)
return
} }
// Read response bytes in len(buf) byte chunks // Read response bytes in len(buf) byte chunks
for { for {
@@ -96,21 +106,40 @@ func Visit(s *Snapshot) {
} }
if len(data) > config.CONFIG.MaxResponseSize { if len(data) > config.CONFIG.MaxResponseSize {
data = []byte{} data = []byte{}
s.Error = null.StringFrom(fmt.Sprintf("Response size exceeded maximum of %d bytes", config.CONFIG.MaxResponseSize)) return nil, fmt.Errorf("Response size exceeded maximum of %d bytes", config.CONFIG.MaxResponseSize)
} }
if err != nil { if err != nil {
if err == io.EOF { if err == io.EOF {
break break
} else { } else {
s.Error = null.StringFrom(fmt.Sprintf("Network error: %s", err)) return nil, fmt.Errorf("Network error: %s", err)
return
} }
} }
} }
// Great, response data received. return data, nil
err = processResponse(s, data) }
// Connect to given URL, using the Gemini protocol.
// Mutate given Snapshot with the data or the error.
func Visit(s *Snapshot) {
data, err := ConnectAndGetData(s.URL.String())
if err != nil { if err != nil {
s.Error = null.StringFrom(err.Error()) s.Error = null.StringFrom(err.Error())
return
}
pageData, err := processData(data)
if err != nil {
s.Error = null.StringFrom(err.Error())
return
}
s.ResponseCode = null.IntFrom(int64(pageData.ResponseCode))
s.MimeType = null.StringFrom(pageData.MimeType)
s.Lang = null.StringFrom(pageData.Lang)
if pageData.GemText != "" {
s.GemText = null.StringFrom(string(pageData.GemText))
}
if pageData.Data != nil {
s.Data = null.ValueFrom(pageData.Data)
} }
return return
} }
@@ -118,31 +147,33 @@ func Visit(s *Snapshot) {
// Update given snapshot with the // Update given snapshot with the
// Gemini header data: response code, // Gemini header data: response code,
// mime type and lang (optional) // mime type and lang (optional)
func processResponse(snapshot *Snapshot, data []byte) error { func processData(data []byte) (*GeminiPageData, error) {
headers, body, err := getHeadersAndData(data) headers, body, err := getHeadersAndData(data)
if err != nil { if err != nil {
return err return nil, err
} }
code, mimeType, lang := getMimeTypeAndLang(headers) code, mimeType, lang := getMimeTypeAndLang(headers)
geminiError := checkGeminiStatusCode(code) geminiError := checkGeminiStatusCode(code)
if geminiError != nil { if geminiError != nil {
return geminiError return nil, geminiError
}
pageData := GeminiPageData{
ResponseCode: code,
MimeType: mimeType,
Lang: lang,
} }
snapshot.ResponseCode = null.IntFrom(int64(code))
snapshot.MimeType = null.StringFrom(mimeType)
snapshot.Lang = null.StringFrom(lang)
// If we've got a Gemini document, populate // If we've got a Gemini document, populate
// `GemText` field, otherwise raw data goes to `Data`. // `GemText` field, otherwise raw data goes to `Data`.
if mimeType == "text/gemini" { if mimeType == "text/gemini" {
validBody, err := EnsureValidUTF8(body) validBody, err := EnsureValidUTF8(body)
if err != nil { if err != nil {
return fmt.Errorf("UTF-8 error: %w", err) return nil, fmt.Errorf("UTF-8 error: %w", err)
} }
snapshot.GemText = null.StringFrom(string(validBody)) pageData.GemText = validBody
} else { } else {
snapshot.Data = null.ValueFrom(body) pageData.Data = body
} }
return nil return &pageData, nil
} }
// Checks for a Gemini header, which is // Checks for a Gemini header, which is