Files
gemini-grc/gemini/db.go

177 lines
4.7 KiB
Go

package gemini
import (
"encoding/json"
"errors"
"fmt"
"os"
"strconv"
"gemini-grc/config"
"gemini-grc/logging"
_ "github.com/jackc/pgx/v5/stdlib" // PGX driver for PostgreSQL
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
)
func ConnectToDB() *sqlx.DB {
connStr := fmt.Sprintf("postgres://%s:%s@%s:%s/%s", //nolint:nosprintfhostport
os.Getenv("PG_USER"),
os.Getenv("PG_PASSWORD"),
os.Getenv("PG_HOST"),
os.Getenv("PG_PORT"),
os.Getenv("PG_DATABASE"),
)
// Create a connection pool
db, err := sqlx.Open("pgx", connStr)
if err != nil {
panic(fmt.Sprintf("Unable to connect to database with URL %s: %v\n", connStr, err))
}
// TODO move PG_MAX_OPEN_CONNECTIONS to config env variables
maxConnections, err := strconv.Atoi(os.Getenv("PG_MAX_OPEN_CONNECTIONS"))
if err != nil {
panic(fmt.Sprintf("Unable to set max DB connections: %s\n", err))
}
db.SetMaxOpenConns(maxConnections)
err = db.Ping()
if err != nil {
panic(fmt.Sprintf("Unable to ping database: %v\n", err))
}
logging.LogDebug("Connected to database")
return db
}
// isDeadlockError checks if the error is a PostgreSQL deadlock error
func isDeadlockError(err error) bool {
var pqErr *pq.Error
if errors.As(err, &pqErr) {
return pqErr.Code == "40P01" // PostgreSQL deadlock error code
}
return false
}
func GetSnapshotsToVisit(tx *sqlx.Tx) ([]Snapshot, error) {
var snapshots []Snapshot
err := tx.Select(&snapshots, SQL_SELECT_UNVISITED_SNAPSHOTS_UNIQUE_HOSTS, config.CONFIG.WorkerBatchSize)
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrDatabase, err)
}
return snapshots, nil
}
func SaveSnapshotIfNew(tx *sqlx.Tx, s *Snapshot) error {
if config.CONFIG.DryRun {
marshalled, err := json.MarshalIndent(s, "", " ")
if err != nil {
panic(fmt.Sprintf("JSON serialization error for %v", s))
}
logging.LogDebug("Would insert (if new) snapshot %s", marshalled)
return nil
}
query := SQL_INSERT_SNAPSHOT_IF_NEW
_, err := tx.NamedExec(query, s)
if err != nil {
return fmt.Errorf("[%s] GeminiError inserting snapshot: %w", s.URL, err)
}
return nil
}
func UpsertSnapshot(workedID int, tx *sqlx.Tx, s *Snapshot) (err error) {
// if config.CONFIG.DryRun {
//marshalled, err := json.MarshalIndent(s, "", " ")
//if err != nil {
// panic(fmt.Sprintf("JSON serialization error for %v", s))
//}
//logging.LogDebug("[%d] Would upsert snapshot %s", workedID, marshalled)
// return nil
// }
query := SQL_UPSERT_SNAPSHOT
rows, err := tx.NamedQuery(query, s)
if err != nil {
return fmt.Errorf("[%d] %w while upserting snapshot: %w", workedID, ErrDatabase, err)
}
defer func() {
_err := rows.Close()
if _err != nil {
err = fmt.Errorf("[%d] %w error closing rows: %w", workedID, ErrDatabase, _err)
}
}()
if rows.Next() {
var returnedID int
err = rows.Scan(&returnedID)
if err != nil {
return fmt.Errorf("[%d] %w error scanning returned id: %w", workedID, ErrDatabase, err)
}
s.ID = returnedID
// logging.LogDebug("[%d] Upserted snapshot with ID %d", workedID, returnedID)
}
return nil
}
func UpdateSnapshot(workedID int, tx *sqlx.Tx, s *Snapshot) (err error) {
// if config.CONFIG.DryRun {
//marshalled, err := json.MarshalIndent(s, "", " ")
//if err != nil {
// panic(fmt.Sprintf("JSON serialization error for %v", s))
//}
//logging.LogDebug("[%d] Would upsert snapshot %s", workedID, marshalled)
// return nil
// }
query := SQL_UPDATE_SNAPSHOT
rows, err := tx.NamedQuery(query, s)
if err != nil {
return fmt.Errorf("[%d] %w while updating snapshot: %w", workedID, ErrDatabase, err)
}
defer func() {
_err := rows.Close()
if _err != nil {
err = fmt.Errorf("[%d] %w error closing rows: %w", workedID, ErrDatabase, _err)
}
}()
if rows.Next() {
var returnedID int
err = rows.Scan(&returnedID)
if err != nil {
return fmt.Errorf("[%d] %w error scanning returned id: %w", workedID, ErrDatabase, err)
}
s.ID = returnedID
// logging.LogDebug("[%d] Updated snapshot with ID %d", workedID, returnedID)
}
return nil
}
func SaveLinksToDBinBatches(tx *sqlx.Tx, snapshots []*Snapshot) error {
if config.CONFIG.DryRun {
return nil
}
const batchSize = 5000
query := SQL_INSERT_SNAPSHOT_IF_NEW
for i := 0; i < len(snapshots); i += batchSize {
end := i + batchSize
if end > len(snapshots) {
end = len(snapshots)
}
batch := snapshots[i:end]
_, err := tx.NamedExec(query, batch)
if err != nil {
return fmt.Errorf("%w: While saving links in batches: %w", ErrDatabase, err)
}
}
return nil
}
func SaveLinksToDB(tx *sqlx.Tx, snapshots []*Snapshot) error {
if config.CONFIG.DryRun {
return nil
}
query := SQL_INSERT_SNAPSHOT_IF_NEW
_, err := tx.NamedExec(query, snapshots)
if err != nil {
logging.LogError("GeminiError batch inserting snapshots: %w", err)
return fmt.Errorf("DB error: %w", err)
}
return nil
}