diff --git a/config/config.go b/config/config.go index bc651eb..24879e9 100644 --- a/config/config.go +++ b/config/config.go @@ -39,13 +39,21 @@ func (c *Config) String() string { var CONFIG Config // parsePositiveInt parses and validates positive integer values -func parsePositiveInt(value string) (int, error) { +func parsePositiveInt(param, value string) (int, error) { val, err := strconv.Atoi(value) if err != nil { - return 0, err + return 0, &ValidationError{ + Param: param, + Value: value, + Reason: "must be a valid integer", + } } if val <= 0 { - return 0, fmt.Errorf("value must be positive") + return 0, &ValidationError{ + Param: param, + Value: value, + Reason: "must be positive", + } } return val, nil } @@ -59,46 +67,53 @@ func GetConfig() *Config { EnvLogLevel: func(v string) error { level, err := zerolog.ParseLevel(v) if err != nil { - return fmt.Errorf("invalid log level: %w", err) + return &ValidationError{ + Param: EnvLogLevel, + Value: v, + Reason: "must be one of: debug, info, warn, error", + } } config.LogLevel = level return nil }, EnvRootPath: func(v string) error { if _, err := os.Stat(v); err != nil { - return fmt.Errorf("invalid root path: %w", err) + return &ConfigError{ + Param: EnvRootPath, + Err: err, + } } config.RootPath = v return nil }, EnvNumWorkers: func(v string) error { - val, err := parsePositiveInt(v) + val, err := parsePositiveInt(EnvNumWorkers, v) if err != nil { - return fmt.Errorf("invalid number of workers: %w", err) + return err } config.NumOfWorkers = val return nil }, EnvWorkerBatchSize: func(v string) error { - val, err := parsePositiveInt(v) + val, err := parsePositiveInt(EnvWorkerBatchSize, v) if err != nil { - return fmt.Errorf("invalid worker batch size: %w", err) + return err } config.WorkerBatchSize = val return nil }, EnvMaxResponseSize: func(v string) error { - val, err := parsePositiveInt(v) + val, err := parsePositiveInt(EnvMaxResponseSize, v) if err != nil { - return fmt.Errorf("invalid max response size: %w", err) + return err } config.MaxResponseSize = val return nil }, EnvResponseTimeout: func(v string) error { - val, err := parsePositiveInt(v) + val, err := parsePositiveInt(EnvResponseTimeout, v) if err != nil { - return fmt.Errorf("invalid response timeout: %w", err) + return err } config.ResponseTimeout = val return nil diff --git a/config/config_test.go b/config/config_test.go index ace8aca..541f1c3 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -38,22 +38,56 @@ func TestGetConfig(t *testing.T) { func TestParsePositiveInt(t *testing.T) { tests := []struct { - name string - input string - want int - wantErr bool + name string + param string + input string + want int + wantErr bool + errType interface{} + errMessage string }{ - {"valid positive", "42", 42, false}, - {"zero", "0", 0, true}, - {"negative", "-1", 0, true}, - {"invalid", "abc", 0, true}, + { + name: "valid positive", + param: "TEST_PARAM", + input: "42", + want: 42, + wantErr: false, + }, + { + name: "zero", + param: "TEST_PARAM", + input: "0", + wantErr: true, + errType: &ValidationError{}, + errMessage: "invalid value '0' for TEST_PARAM: must be positive", + }, + { + name: "negative", + param: "TEST_PARAM", + input: "-1", + wantErr: true, + errType: &ValidationError{}, + errMessage: "invalid value '-1' for TEST_PARAM: must be positive", + }, + { + name: "invalid", + param: "TEST_PARAM", + input: "abc", + wantErr: true, + errType: &ValidationError{}, + errMessage: "invalid value 'abc' for TEST_PARAM: must be a valid integer", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := parsePositiveInt(tt.input) + got, err := parsePositiveInt(tt.param, tt.input) if tt.wantErr { assert.Error(t, err) + assert.IsType(t, tt.errType, err) + if tt.errMessage != "" { + assert.Equal(t, tt.errMessage, err.Error()) + } } else { assert.NoError(t, err) assert.Equal(t, tt.want, got) @@ -61,3 +95,52 @@ func TestParsePositiveInt(t *testing.T) { }) } } + +func TestConfigValidation(t *testing.T) { + tests := []struct { + name string + envVars map[string]string + wantErr bool + errMessage string + }{ + { + name: "invalid log level", + envVars: map[string]string{ + EnvLogLevel: "invalid", + }, + wantErr: true, + errMessage: "invalid value 'invalid' for LOG_LEVEL: must be one of: debug, info, warn, error", + }, + { + name: "invalid worker count", + envVars: map[string]string{ + EnvLogLevel: "debug", + EnvRootPath: ".", + EnvNumWorkers: "-1", + }, + wantErr: true, + errMessage: "invalid value '-1' for NUM_OF_WORKERS: must be positive", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Clear environment + os.Clearenv() + + // Set required environment variables + for k, v := range tt.envVars { + os.Setenv(k, v) + } + + // Defer cleanup + defer os.Clearenv() + + if tt.wantErr { + assert.PanicsWithError(t, tt.errMessage, func() { + GetConfig() + }) + } + }) + } +} diff --git a/config/errors.go b/config/errors.go new file mode 100644 index 0000000..9a1d6dd --- /dev/null +++ b/config/errors.go @@ -0,0 +1,28 @@ +package config + +import "fmt" + +// ConfigError represents a configuration error +type ConfigError struct { + Param string + Err error +} + +func (e *ConfigError) Error() string { + return fmt.Sprintf("configuration error for %s: %v", e.Param, e.Err) +} + +func (e *ConfigError) Unwrap() error { + return e.Err +} + +// ValidationError represents a validation error +type ValidationError struct { + Param string + Value string + Reason string +} + +func (e *ValidationError) Error() string { + return fmt.Sprintf("invalid value '%s' for %s: %s", e.Value, e.Param, e.Reason) +}