diff --git a/cmd/prometheus-cve-exporter/main.go b/cmd/prometheus-cve-exporter/main.go index 2100bc8..ddbb6f9 100644 --- a/cmd/prometheus-cve-exporter/main.go +++ b/cmd/prometheus-cve-exporter/main.go @@ -1,9 +1,15 @@ package main import ( + "context" + "crypto/tls" "fmt" "log" "net/http" + "os" + "os/signal" + "syscall" + "time" "zops.top/prometheus-cve-exporter/config" "zops.top/prometheus-cve-exporter/internal/exporter" @@ -12,23 +18,106 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" ) -func main() { - cfg, err := config.Load() - if err != nil { - log.Fatalf("Failed to load configuration: %v", err) - } +type UpdateMetricsFunc func(*config.Config) - go exporter.UpdateMetrics(cfg) - startServer(cfg) +type Server struct { + cfg *config.Config + logger *log.Logger + mux *http.ServeMux + server *http.Server + updateMetrics UpdateMetricsFunc } -func startServer(cfg *config.Config) { - http.Handle("/metrics", promhttp.HandlerFor( +func NewServer(cfg *config.Config, logger *log.Logger, updateMetrics UpdateMetricsFunc) *Server { + return &Server{ + cfg: cfg, + logger: logger, + mux: http.NewServeMux(), + updateMetrics: updateMetrics, + } +} + +func (s *Server) SetupRouter() { + s.mux.Handle("/metrics", promhttp.HandlerFor( prometheus.DefaultGatherer, promhttp.HandlerOpts{ EnableOpenMetrics: true, }, )) - fmt.Printf("Starting server on :%d\n", cfg.Port) - log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", cfg.Port), nil)) + + s.mux.HandleFunc("/", s.homeHandler) +} + +func (s *Server) homeHandler(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "text/html; charset=utf-8") + fmt.Fprint(w, `Go to metrics`) +} + +func (s *Server) Start() { + s.server = &http.Server{ + Addr: fmt.Sprintf(":%d", s.cfg.Port), + Handler: s.mux, + ReadTimeout: 5 * time.Second, + WriteTimeout: 10 * time.Second, + IdleTimeout: 120 * time.Second, + } + + if s.cfg.UseTLS { + s.server.TLSConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + PreferServerCipherSuites: true, + } + } + + go func() { + var err error + s.logger.Printf("Starting server on :%d\n", s.cfg.Port) + if s.cfg.UseTLS { + s.logger.Println("TLS enabled") + err = s.server.ListenAndServeTLS(s.cfg.TLSCert, s.cfg.TLSKey) + } else { + s.logger.Println("TLS disabled") + err = s.server.ListenAndServe() + } + if err != nil && err != http.ErrServerClosed { + s.logger.Fatalf("Could not listen on %d: %v\n", s.cfg.Port, err) + } + }() +} + +func (s *Server) GracefulShutdown() { + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + sig := <-quit + s.logger.Printf("Received signal: %v. Initiating shutdown...", sig) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + s.server.SetKeepAlivesEnabled(false) + if err := s.server.Shutdown(ctx); err != nil { + s.logger.Fatalf("Server forced to shutdown: %v", err) + } + + s.logger.Println("Server exiting") +} + +func main() { + logger := log.New(os.Stdout, "", log.LstdFlags) + + cfg, err := config.Load() + if err != nil { + logger.Fatalf("Failed to load configuration: %v", err) + } + + server := NewServer(cfg, logger, exporter.UpdateMetrics) + go server.updateMetrics(cfg) + + server.SetupRouter() + server.Start() + server.GracefulShutdown() } diff --git a/cmd/prometheus-cve-exporter/main_test.go b/cmd/prometheus-cve-exporter/main_test.go new file mode 100644 index 0000000..310be37 --- /dev/null +++ b/cmd/prometheus-cve-exporter/main_test.go @@ -0,0 +1,209 @@ +package main + +import ( + "context" + "fmt" + "log" + "net/http" + "net/http/httptest" + "os" + "sync" + "syscall" + "testing" + "time" + + "zops.top/prometheus-cve-exporter/config" +) + +func mockUpdateMetrics(*config.Config) {} + +func TestNewServer(t *testing.T) { + cfg := &config.Config{} + logger := log.New(os.Stdout, "", log.LstdFlags) + server := NewServer(cfg, logger, mockUpdateMetrics) + + if server.cfg != cfg { + t.Errorf("Expected cfg to be %v, got %v", cfg, server.cfg) + } + if server.logger != logger { + t.Errorf("Expected logger to be %v, got %v", logger, server.logger) + } + if server.mux == nil { + t.Error("Expected mux to be initialized") + } + if server.updateMetrics == nil { + t.Error("Expected updateMetrics to be initialized") + } +} + +func TestSetupRouter(t *testing.T) { + server := NewServer(&config.Config{}, log.New(os.Stdout, "", log.LstdFlags), mockUpdateMetrics) + server.SetupRouter() + + testCases := []struct { + path string + expectedCode int + }{ + {"/metrics", http.StatusOK}, + {"/", http.StatusOK}, + {"/nonexistent", http.StatusNotFound}, + } + + for _, tc := range testCases { + req, err := http.NewRequest("GET", tc.path, nil) + if err != nil { + t.Fatalf("Could not create request: %v", err) + } + + rr := httptest.NewRecorder() + server.mux.ServeHTTP(rr, req) + + if rr.Code != tc.expectedCode { + t.Errorf("handler returned wrong status code for %s: got %v want %v", + tc.path, rr.Code, tc.expectedCode) + } + } +} + +func TestHomeHandler(t *testing.T) { + server := NewServer(&config.Config{}, log.New(os.Stdout, "", log.LstdFlags), mockUpdateMetrics) + + req, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + handler := http.HandlerFunc(server.homeHandler) + handler.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + + expected := `Go to metrics` + if rr.Body.String() != expected { + t.Errorf("handler returned unexpected body: got %v want %v", + rr.Body.String(), expected) + } +} + +func TestStart(t *testing.T) { + cfg := &config.Config{Port: 20000} + logger := log.New(os.Stdout, "", log.LstdFlags) + s := NewServer(cfg, logger, mockUpdateMetrics) + s.SetupRouter() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + s.Start() + }() + + // Give some time for the server to start + time.Sleep(100 * time.Millisecond) + + resp, err := http.Get(fmt.Sprintf("http://localhost:%d", cfg.Port)) + if err != nil { + t.Fatalf("Could not send GET request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status OK; got %v", resp.Status) + } + + // Shutdown the server + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := s.server.Shutdown(ctx); err != nil { + s.logger.Fatalf("Server forced to shutdown: %v", err) + } + + wg.Wait() +} + +func TestUpdateMetricsExecution(t *testing.T) { + updateMetricsCalled := false + mockUpdateMetrics := func(*config.Config) { + updateMetricsCalled = true + } + + cfg := &config.Config{} + logger := log.New(os.Stdout, "", log.LstdFlags) + server := NewServer(cfg, logger, mockUpdateMetrics) + + server.updateMetrics(cfg) + + if !updateMetricsCalled { + t.Error("Expected UpdateMetrics to be called") + } +} + +func TestMainIntegration(t *testing.T) { + // Backup original os.Args + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + + // Set up a test config file + testConfigPath := "test_config.json" + testPort := 20001 + testConfigContent := []byte(fmt.Sprintf(`{ + "nvd_feed_url": "https://test.nvd.feed.url", + "update_interval": "2h", + "port": %d, + "severity": ["HIGH", "CRITICAL"], + "package_file": "", + "use_tls": false + }`, testPort)) + err := os.WriteFile(testConfigPath, testConfigContent, 0644) + if err != nil { + t.Fatalf("Failed to create test config file: %v", err) + } + defer os.Remove(testConfigPath) + + // Set the command-line argument to use our test config + os.Args = []string{"cmd", "-config", testConfigPath} + + // Run main in a goroutine + go func() { + main() + }() + + // Give some time for the server to start + time.Sleep(100 * time.Millisecond) + + // Test if the server is running + resp, err := http.Get(fmt.Sprintf("http://localhost:%d", testPort)) + if err != nil { + t.Fatalf("Could not send GET request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status OK; got %v", resp.Status) + } + + // Send shutdown signal + err = syscall.Kill(syscall.Getpid(), syscall.SIGINT) + if err != nil { + t.Fatalf("Failed to send SIGINT signal: %v", err) + } + + // Give some time for the server to shut down + time.Sleep(500 * time.Millisecond) + + // Verify that the server has shut down + _, err = http.Get(fmt.Sprintf("http://localhost:%d", testPort)) + if err == nil { + t.Error("Expected an error when connecting to a shutdown server") + } +} + +func TestMain(m *testing.M) { + // Run tests + code := m.Run() + os.Exit(code) +} diff --git a/config.json b/config.json index f7218b9..a8f1a3a 100644 --- a/config.json +++ b/config.json @@ -8,5 +8,7 @@ "MEDIUM", "HIGH", "CRITICAL" - ] + ], + "tls_cert": "server.crt", + "tls_key": "server.key" } diff --git a/config/config.go b/config/config.go index 6f454ab..b22e9e9 100644 --- a/config/config.go +++ b/config/config.go @@ -15,6 +15,9 @@ const ( defaultPort = 10250 defaultSeverity = "CRITICAL" defaultPackageFile = "" + defaultUseTLS = false + defaultTLSCert = "" + defaultTLSKey = "" ) type Config struct { @@ -23,6 +26,9 @@ type Config struct { Port int `json:"port"` Severity []string `json:"severity"` PackageFile string `json:"package_file,omitempty"` + UseTLS bool `json:"use_tls,omitempty"` + TLSCert string `json:"tls_cert,omitempty"` + TLSKey string `json:"tls_key,omitempty"` } type configHelper struct { @@ -31,6 +37,8 @@ type configHelper struct { Port int `json:"port"` Severity []string `json:"severity"` PackageFile string `json:"package_file,omitempty"` + TLSCert string `json:"tls_cert,omitempty"` + TLSKey string `json:"tls_key,omitempty"` } func NewConfig() *Config { @@ -40,6 +48,9 @@ func NewConfig() *Config { Port: defaultPort, Severity: []string{defaultSeverity}, PackageFile: defaultPackageFile, + UseTLS: defaultUseTLS, + TLSCert: defaultTLSCert, + TLSKey: defaultTLSKey, } } @@ -59,6 +70,10 @@ func Load() (*Config, error) { return nil, err } + if cfg.TLSCert != defaultTLSCert && cfg.TLSKey != defaultTLSKey { + cfg.UseTLS = true + } + fmt.Print(prettyfyCfg(cfg)) return cfg, nil } @@ -71,6 +86,8 @@ func parseFlags(cfg *Config) string { flag.DurationVar(&cfg.UpdateInterval, "update-interval", defaultUpdateInterval, "Update interval duration") flag.IntVar(&cfg.Port, "port", defaultPort, "Port to run the server on") flag.StringVar(&cfg.PackageFile, "package-file", defaultPackageFile, "Path to file containing packages and versions") + flag.StringVar(&cfg.TLSCert, "tls-cert", defaultTLSCert, "Path to TLS certificate file") + flag.StringVar(&cfg.TLSKey, "tls-key", defaultTLSKey, "Path to TLS key file") var severity string flag.StringVar(&severity, "severity", defaultSeverity, "Comma separated list of severity levels for vulnerabilities") @@ -95,6 +112,8 @@ func loadConfigFile(cfg *Config, filename string) error { cfg.Port = helper.Port cfg.Severity = toUppercase(helper.Severity) cfg.PackageFile = helper.PackageFile + cfg.TLSCert = helper.TLSCert + cfg.TLSKey = helper.TLSKey duration, err := time.ParseDuration(helper.UpdateInterval) if err != nil { @@ -124,6 +143,8 @@ func overrideWithEnv(cfg *Config) { }, "PCE_SEVERITY": func(value string) { cfg.Severity = parseSeverity(value) }, "PCE_PACKAGE_FILE": func(value string) { cfg.PackageFile = value }, + "PCE_TLS_CERT": func(value string) { cfg.TLSCert = value }, + "PCE_TLS_KEY": func(value string) { cfg.TLSKey = value }, } for envVar, action := range envVars { @@ -139,6 +160,14 @@ func validateConfig(cfg *Config) error { return fmt.Errorf("the file %s does not exist", cfg.PackageFile) } } + if cfg.TLSCert != defaultTLSCert || cfg.TLSKey != defaultTLSKey { + if _, err := os.Stat(cfg.TLSCert); os.IsNotExist(err) { + return fmt.Errorf("the TLS certificate file %s does not exist", cfg.TLSCert) + } + if _, err := os.Stat(cfg.TLSKey); os.IsNotExist(err) { + return fmt.Errorf("the TLS key file %s does not exist", cfg.TLSKey) + } + } return nil } @@ -162,6 +191,15 @@ func parseIntEnv(value string) (int, error) { return result, err } +func parseBoolEnv(value string) bool { + switch value { + case "true": + return true + default: + return false + } +} + func prettyfyCfg(cfg *Config) string { var output strings.Builder @@ -173,6 +211,11 @@ func prettyfyCfg(cfg *Config) string { if cfg.PackageFile != "" { output.WriteString(fmt.Sprintf(" Package file: %s\n", cfg.PackageFile)) } + output.WriteString(fmt.Sprintf(" Use TLS: %v\n", cfg.UseTLS)) + if cfg.UseTLS { + output.WriteString(fmt.Sprintf(" TLS Certificate: %s\n", cfg.TLSCert)) + output.WriteString(fmt.Sprintf(" TLS Key: %s\n", cfg.TLSKey)) + } return output.String() } diff --git a/config/config_test.go b/config/config_test.go index 1ae9124..f40baca 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -34,7 +34,7 @@ func TestLoadConfigFile(t *testing.T) { "nvd_feed_url": "https://example.com/feed.json", "update_interval": "48h", "port": 9090, - "severity": ["HIGH", "MEDIUM"], + "severity": ["HIGH", "MEDIUM", "low"], "package_file": "/tmp/package.txt" }` @@ -57,7 +57,7 @@ func TestLoadConfigFile(t *testing.T) { if cfg.Port != 9090 { t.Errorf("Expected Port 9090, got %d", cfg.Port) } - expectedSeverity := []string{"HIGH", "MEDIUM"} + expectedSeverity := []string{"HIGH", "MEDIUM", "LOW"} for i, s := range expectedSeverity { if cfg.Severity[i] != s { t.Errorf("Expected Severity %v, got %v", expectedSeverity, cfg.Severity) @@ -74,14 +74,18 @@ func TestOverrideWithEnv(t *testing.T) { os.Setenv("PCE_NVD_JSON_GZ_FEED_URL", "https://env.com/feed.json") os.Setenv("PCE_UPDATE_INTERVAL", "72h") os.Setenv("PCE_PORT", "8080") - os.Setenv("PCE_SEVERITY", "LOW,INFO") + os.Setenv("PCE_SEVERITY", "LOW,INFO,critical") os.Setenv("PCE_PACKAGE_FILE", "/env/package.txt") + os.Setenv("PCE_TLS_CERT", "/env/tls.crt") + os.Setenv("PCE_TLS_KEY", "/env/tls.key") defer func() { os.Unsetenv("PCE_NVD_JSON_GZ_FEED_URL") os.Unsetenv("PCE_UPDATE_INTERVAL") os.Unsetenv("PCE_PORT") os.Unsetenv("PCE_SEVERITY") os.Unsetenv("PCE_PACKAGE_FILE") + os.Unsetenv("PCE_TLS_CERT") + os.Unsetenv("PCE_TLS_KEY") }() overrideWithEnv(cfg) @@ -95,7 +99,7 @@ func TestOverrideWithEnv(t *testing.T) { if cfg.Port != 8080 { t.Errorf("Expected Port 8080, got %d", cfg.Port) } - expectedSeverity := []string{"LOW", "INFO"} + expectedSeverity := []string{"LOW", "INFO", "CRITICAL"} for i, s := range expectedSeverity { if cfg.Severity[i] != s { t.Errorf("Expected Severity %v, got %v", expectedSeverity, cfg.Severity) @@ -104,10 +108,16 @@ func TestOverrideWithEnv(t *testing.T) { if cfg.PackageFile != "/env/package.txt" { t.Errorf("Expected PackageFile /env/package.txt, got %s", cfg.PackageFile) } + if cfg.TLSCert != "/env/tls.crt" { + t.Errorf("Expected TLSCert /env/tls.crt, got %s", cfg.TLSCert) + } + if cfg.TLSKey != "/env/tls.key" { + t.Errorf("Expected TLSKey /env/tls.key, got %s", cfg.TLSKey) + } } func TestParseSeverity(t *testing.T) { - severity := "CRITICAL,HIGH,MEDIUM,LOW" + severity := "CRITICAL,HIGH,medium,LOW" expected := []string{"CRITICAL", "HIGH", "MEDIUM", "LOW"} result := parseSeverity(severity) @@ -136,6 +146,87 @@ func TestParseIntEnv(t *testing.T) { } } +func TestParseBoolEnv(t *testing.T) { + tests := []struct { + name string + value string + expected bool + expectErr bool + }{ + { + name: "valid true value", + value: "true", + expected: true, + expectErr: false, + }, + { + name: "valid false value", + value: "false", + expected: false, + expectErr: false, + }, + { + name: "invalid value", + value: "invalid", + expected: false, + expectErr: false, + }, + { + name: "numeric value", + value: "123", + expected: false, + expectErr: false, + }, + { + name: "yes value", + value: "yes", + expected: false, + expectErr: false, + }, + { + name: "no value", + value: "no", + expected: false, + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseBoolEnv(tt.value) + + if result != tt.expected { + t.Errorf("parseBoolEnv(%q) = %v; want %v", tt.value, result, tt.expected) + } + }) + } +} + +func TestValidateConfig(t *testing.T) { + cfg := NewConfig() + + cfg.PackageFile = "/nonexistent/file" + err := validateConfig(cfg) + if err == nil { + t.Errorf("Expected error for nonexistent package file, got none") + } + + cfg.PackageFile = "" + cfg.TLSCert = "/nonexistent/tls.crt" + cfg.TLSKey = "/nonexistent/tls.key" + err = validateConfig(cfg) + if err == nil { + t.Errorf("Expected error for missing TLS cert and key, got none") + } + + cfg.TLSCert = "" + cfg.TLSKey = "" + err = validateConfig(cfg) + if err != nil { + t.Errorf("Expected no error for valid config, got %v", err) + } +} + func TestPrettyfyCfg(t *testing.T) { cfg := NewConfig() cfg.NVDFeedURL = "https://example.com/feed.json" @@ -143,6 +234,9 @@ func TestPrettyfyCfg(t *testing.T) { cfg.Port = 9090 cfg.Severity = []string{"HIGH", "MEDIUM"} cfg.PackageFile = "/tmp/package.txt" + cfg.UseTLS = true + cfg.TLSCert = "/tmp/tls.crt" + cfg.TLSKey = "/tmp/tls.key" output := prettyfyCfg(cfg) expectedStrings := []string{ @@ -152,6 +246,9 @@ func TestPrettyfyCfg(t *testing.T) { " Severity Levels: [HIGH MEDIUM]", " Port: 9090", " Package file: /tmp/package.txt", + " Use TLS: true", + " TLS Certificate: /tmp/tls.crt", + " TLS Key: /tmp/tls.key", } for _, expected := range expectedStrings { @@ -160,3 +257,28 @@ func TestPrettyfyCfg(t *testing.T) { } } } + +func TestLoad(t *testing.T) { + // Ensure no flags are set during the test + os.Args = []string{"cmd"} + cfg, err := Load() + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + if cfg.NVDFeedURL != defaultNVDFeedURL { + t.Errorf("Expected NVDFeedURL %s, got %s", defaultNVDFeedURL, cfg.NVDFeedURL) + } + if cfg.UpdateInterval != defaultUpdateInterval { + t.Errorf("Expected UpdateInterval %v, got %v", defaultUpdateInterval, cfg.UpdateInterval) + } + if cfg.Port != defaultPort { + t.Errorf("Expected Port %d, got %d", defaultPort, cfg.Port) + } + if len(cfg.Severity) != 1 || cfg.Severity[0] != defaultSeverity { + t.Errorf("Expected Severity %v, got %v", []string{defaultSeverity}, cfg.Severity) + } + if cfg.PackageFile != "" { + t.Errorf("Expected PackageFile to be empty, got %s", cfg.PackageFile) + } +}