fix: Refactor ConnectAndGetData function to return GeminiPageData struct
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
|||||||
"gemini-grc/config"
|
"gemini-grc/config"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
go_url "net/url"
|
||||||
"regexp"
|
"regexp"
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -14,6 +15,14 @@ import (
|
|||||||
"github.com/guregu/null/v5"
|
"github.com/guregu/null/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type GeminiPageData struct {
|
||||||
|
ResponseCode int
|
||||||
|
MimeType string
|
||||||
|
Lang string
|
||||||
|
GemText string
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
// Resolve the URL hostname and
|
// Resolve the URL hostname and
|
||||||
// check if we already have an open
|
// check if we already have an open
|
||||||
// connection to this host.
|
// connection to this host.
|
||||||
@@ -31,51 +40,53 @@ func getHostIPAddresses(hostname string) ([]string, error) {
|
|||||||
return addrs, nil
|
return addrs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect to given URL, using the Gemini protocol.
|
func ConnectAndGetData(url string) ([]byte, error) {
|
||||||
// Return a Snapshot with the data or the error.
|
parsedUrl, err := go_url.Parse(url)
|
||||||
// Any errors are stored within the snapshot.
|
if err != nil {
|
||||||
func Visit(s *Snapshot) {
|
return nil, fmt.Errorf("Could not parse URL, error %w", err)
|
||||||
|
}
|
||||||
|
host := parsedUrl.Host
|
||||||
|
port := parsedUrl.Port()
|
||||||
|
if port == "" {
|
||||||
|
port = "1965"
|
||||||
|
host = fmt.Sprintf("%s:%s", host, port)
|
||||||
|
}
|
||||||
// Establish the underlying TCP connection.
|
// Establish the underlying TCP connection.
|
||||||
host := fmt.Sprintf("%s:%d", s.Host, s.URL.Port)
|
|
||||||
dialer := &net.Dialer{
|
dialer := &net.Dialer{
|
||||||
Timeout: time.Duration(config.CONFIG.ResponseTimeout) * time.Second, // Set the overall connection timeout
|
Timeout: time.Duration(config.CONFIG.ResponseTimeout) * time.Second,
|
||||||
KeepAlive: 30 * time.Second,
|
KeepAlive: 10 * time.Second,
|
||||||
}
|
}
|
||||||
conn, err := dialer.Dial("tcp", host)
|
conn, err := dialer.Dial("tcp", host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.Error = null.StringFrom(fmt.Sprintf("TCP connection failed: %v", err))
|
return nil, fmt.Errorf("TCP connection failed: %w", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
// Make sure we always close the connection.
|
// Make sure we always close the connection.
|
||||||
defer func() {
|
defer func() {
|
||||||
err := conn.Close()
|
err := conn.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.Error = null.StringFrom(fmt.Sprintf("Error closing connection: %s", err))
|
// Do nothing! Connection will timeout eventually if still open somehow.
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Set read and write timeouts on the TCP connection.
|
// Set read and write timeouts on the TCP connection.
|
||||||
err = conn.SetReadDeadline(time.Now().Add(time.Duration(config.CONFIG.ResponseTimeout) * time.Second))
|
err = conn.SetReadDeadline(time.Now().Add(time.Duration(config.CONFIG.ResponseTimeout) * time.Second))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.Error = null.StringFrom(fmt.Sprintf("Error setting connection deadline: %s", err))
|
return nil, fmt.Errorf("Error setting connection deadline: %w", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
err = conn.SetWriteDeadline(time.Now().Add(time.Duration(config.CONFIG.ResponseTimeout) * time.Second))
|
err = conn.SetWriteDeadline(time.Now().Add(time.Duration(config.CONFIG.ResponseTimeout) * time.Second))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.Error = null.StringFrom(fmt.Sprintf("Error setting connection deadline: %s", err))
|
return nil, fmt.Errorf("Error setting connection deadline: %w", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Perform the TLS handshake
|
// Perform the TLS handshake
|
||||||
tlsConfig := &tls.Config{
|
tlsConfig := &tls.Config{
|
||||||
InsecureSkipVerify: true, // Accept all TLS certs, even if insecure.
|
InsecureSkipVerify: true, // Accept all TLS certs, even if insecure.
|
||||||
ServerName: s.URL.Hostname, // SNI
|
ServerName: parsedUrl.Host, // SNI
|
||||||
// MinVersion: tls.VersionTLS12, // Use a minimum TLS version. Warning breaks a lot of sites.
|
// MinVersion: tls.VersionTLS12, // Use a minimum TLS version. Warning breaks a lot of sites.
|
||||||
}
|
}
|
||||||
tlsConn := tls.Client(conn, tlsConfig)
|
tlsConn := tls.Client(conn, tlsConfig)
|
||||||
if err := tlsConn.Handshake(); err != nil {
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
s.Error = null.StringFrom(fmt.Sprintf("TLS handshake error: %v", err))
|
return nil, fmt.Errorf("TLS handshake error: %w", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// We read `buf`-sized chunks and add data to `data`.
|
// We read `buf`-sized chunks and add data to `data`.
|
||||||
@@ -83,10 +94,9 @@ func Visit(s *Snapshot) {
|
|||||||
var data []byte
|
var data []byte
|
||||||
|
|
||||||
// Send Gemini request to trigger server response.
|
// Send Gemini request to trigger server response.
|
||||||
_, err = tlsConn.Write([]byte(fmt.Sprintf("%s\r\n", s.URL.String())))
|
_, err = tlsConn.Write([]byte(fmt.Sprintf("%s\r\n", url)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.Error = null.StringFrom(fmt.Sprintf("Error sending network request: %s", err))
|
return nil, fmt.Errorf("Error sending network request: %w", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
// Read response bytes in len(buf) byte chunks
|
// Read response bytes in len(buf) byte chunks
|
||||||
for {
|
for {
|
||||||
@@ -96,21 +106,40 @@ func Visit(s *Snapshot) {
|
|||||||
}
|
}
|
||||||
if len(data) > config.CONFIG.MaxResponseSize {
|
if len(data) > config.CONFIG.MaxResponseSize {
|
||||||
data = []byte{}
|
data = []byte{}
|
||||||
s.Error = null.StringFrom(fmt.Sprintf("Response size exceeded maximum of %d bytes", config.CONFIG.MaxResponseSize))
|
return nil, fmt.Errorf("Response size exceeded maximum of %d bytes", config.CONFIG.MaxResponseSize)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
break
|
break
|
||||||
} else {
|
} else {
|
||||||
s.Error = null.StringFrom(fmt.Sprintf("Network error: %s", err))
|
return nil, fmt.Errorf("Network error: %s", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Great, response data received.
|
return data, nil
|
||||||
err = processResponse(s, data)
|
}
|
||||||
|
|
||||||
|
// Connect to given URL, using the Gemini protocol.
|
||||||
|
// Mutate given Snapshot with the data or the error.
|
||||||
|
func Visit(s *Snapshot) {
|
||||||
|
data, err := ConnectAndGetData(s.URL.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.Error = null.StringFrom(err.Error())
|
s.Error = null.StringFrom(err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pageData, err := processData(data)
|
||||||
|
if err != nil {
|
||||||
|
s.Error = null.StringFrom(err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.ResponseCode = null.IntFrom(int64(pageData.ResponseCode))
|
||||||
|
s.MimeType = null.StringFrom(pageData.MimeType)
|
||||||
|
s.Lang = null.StringFrom(pageData.Lang)
|
||||||
|
if pageData.GemText != "" {
|
||||||
|
s.GemText = null.StringFrom(string(pageData.GemText))
|
||||||
|
}
|
||||||
|
if pageData.Data != nil {
|
||||||
|
s.Data = null.ValueFrom(pageData.Data)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -118,31 +147,33 @@ func Visit(s *Snapshot) {
|
|||||||
// Update given snapshot with the
|
// Update given snapshot with the
|
||||||
// Gemini header data: response code,
|
// Gemini header data: response code,
|
||||||
// mime type and lang (optional)
|
// mime type and lang (optional)
|
||||||
func processResponse(snapshot *Snapshot, data []byte) error {
|
func processData(data []byte) (*GeminiPageData, error) {
|
||||||
headers, body, err := getHeadersAndData(data)
|
headers, body, err := getHeadersAndData(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
code, mimeType, lang := getMimeTypeAndLang(headers)
|
code, mimeType, lang := getMimeTypeAndLang(headers)
|
||||||
geminiError := checkGeminiStatusCode(code)
|
geminiError := checkGeminiStatusCode(code)
|
||||||
if geminiError != nil {
|
if geminiError != nil {
|
||||||
return geminiError
|
return nil, geminiError
|
||||||
|
}
|
||||||
|
pageData := GeminiPageData{
|
||||||
|
ResponseCode: code,
|
||||||
|
MimeType: mimeType,
|
||||||
|
Lang: lang,
|
||||||
}
|
}
|
||||||
snapshot.ResponseCode = null.IntFrom(int64(code))
|
|
||||||
snapshot.MimeType = null.StringFrom(mimeType)
|
|
||||||
snapshot.Lang = null.StringFrom(lang)
|
|
||||||
// If we've got a Gemini document, populate
|
// If we've got a Gemini document, populate
|
||||||
// `GemText` field, otherwise raw data goes to `Data`.
|
// `GemText` field, otherwise raw data goes to `Data`.
|
||||||
if mimeType == "text/gemini" {
|
if mimeType == "text/gemini" {
|
||||||
validBody, err := EnsureValidUTF8(body)
|
validBody, err := EnsureValidUTF8(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("UTF-8 error: %w", err)
|
return nil, fmt.Errorf("UTF-8 error: %w", err)
|
||||||
}
|
}
|
||||||
snapshot.GemText = null.StringFrom(string(validBody))
|
pageData.GemText = validBody
|
||||||
} else {
|
} else {
|
||||||
snapshot.Data = null.ValueFrom(body)
|
pageData.Data = body
|
||||||
}
|
}
|
||||||
return nil
|
return &pageData, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Checks for a Gemini header, which is
|
// Checks for a Gemini header, which is
|
||||||
|
|||||||
Reference in New Issue
Block a user