package main // Benchmarks a Gemini server. import ( "context" "crypto/tls" "flag" "fmt" "io" "log/slog" "net/url" "os" "os/signal" "strings" "sync" "syscall" "time" "gemserve/lib/logging" ) func main() { // Parse command-line flags insecure := flag.Bool("insecure", false, "Skip TLS certificate verification") totalConnections := flag.Int("total-connections", 250, "Total connections to make") parallelism := flag.Int("parallelism", 10, "How many connections to run in parallel") flag.Parse() // Get the URL from arguments args := flag.Args() if len(args) != 1 { fmt.Fprintf(os.Stderr, "Usage: gemget [--insecure] \n") os.Exit(1) } logging.SetupLogging(slog.LevelInfo) logger := logging.Logger ctx := logging.WithLogger(context.Background(), logger) geminiURL := args[0] host := validateUrl(geminiURL) if host == "" { logger.Error("Invalid URL.") os.Exit(1) } start := time.Now() err := benchmark(ctx, geminiURL, host, *insecure, *totalConnections, *parallelism) if err != nil { logger.Error(err.Error()) os.Exit(1) } end := time.Now() tookMs := end.Sub(start).Milliseconds() logger.Info("End.", "ms", tookMs) } var wg sync.WaitGroup type ctxKey int const ctxKeyJobIndex ctxKey = 1 func benchmark(ctx context.Context, u string, h string, insecure bool, totalConnections int, parallelism int) error { logger := logging.FromContext(ctx) signals := make(chan os.Signal, 1) signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) // Root context, used to cancel // connections and graceful shutdown. ctx, cancel := context.WithCancel(ctx) defer cancel() // Semaphore to limit concurrency. // Goroutines put value to channel (acquire slot) // and consume value from channel (release slot). semaphore := make(chan struct{}, parallelism) loop: for i := 0; i < totalConnections; i++ { select { case <-signals: logger.Warn("Received SIGINT or SIGTERM signal, shutting down gracefully") cancel() break loop case semaphore <- struct{}{}: // Acquire slot wg.Add(1) go func(jobIndex int) { defer func() { <-semaphore // Release slot wg.Done() }() ctxWithValue := context.WithValue(ctx, ctxKeyJobIndex, jobIndex) ctxWithTimeout, cancel := context.WithTimeout(ctxWithValue, 60*time.Second) defer cancel() err := connect(ctxWithTimeout, u, h, insecure) if err != nil { logger.Warn(fmt.Sprintf("%d error: %v", jobIndex, err)) } }(i) } } wg.Wait() return nil } func validateUrl(u string) string { // Parse the URL parsedURL, err := url.Parse(u) if err != nil { fmt.Fprintf(os.Stderr, "Error parsing URL: %v\n", err) os.Exit(1) } // Ensure it's a gemini URL if parsedURL.Scheme != "gemini" { fmt.Fprintf(os.Stderr, "Error: URL must use gemini:// scheme\n") os.Exit(1) } // Get host and port host := parsedURL.Host if !strings.Contains(host, ":") { host = host + ":1965" // Default Gemini port } return host } func connect(ctx context.Context, url string, host string, insecure bool) error { logger := logging.FromContext(ctx) tlsConfig := &tls.Config{ InsecureSkipVerify: insecure, MinVersion: tls.VersionTLS12, } // Context checkpoint if ctx.Err() != nil { return nil } // Connect to the server conn, err := tls.Dial("tcp", host, tlsConfig) if err != nil { return err } // Set connection deadline based on context if deadline, ok := ctx.Deadline(); ok { _ = conn.SetDeadline(deadline) } defer func() { _ = conn.Close() }() // Context checkpoint if ctx.Err() != nil { return nil } // Send the request (URL + CRLF) request := url + "\r\n" _, err = conn.Write([]byte(request)) if err != nil { return err } // Context checkpoint if ctx.Err() != nil { return nil } // Read and dump response _, err = io.Copy(io.Discard, conn) if err != nil { return err } jobIndex := ctx.Value(ctxKeyJobIndex) logger.Debug(fmt.Sprintf("%d done", jobIndex)) return nil }