diff --git a/cmd/prometheus-cve-exporter/main.go b/cmd/prometheus-cve-exporter/main.go
index 2100bc8..ddbb6f9 100644
--- a/cmd/prometheus-cve-exporter/main.go
+++ b/cmd/prometheus-cve-exporter/main.go
@@ -1,9 +1,15 @@
package main
import (
+ "context"
+ "crypto/tls"
"fmt"
"log"
"net/http"
+ "os"
+ "os/signal"
+ "syscall"
+ "time"
"zops.top/prometheus-cve-exporter/config"
"zops.top/prometheus-cve-exporter/internal/exporter"
@@ -12,23 +18,106 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
)
-func main() {
- cfg, err := config.Load()
- if err != nil {
- log.Fatalf("Failed to load configuration: %v", err)
- }
+type UpdateMetricsFunc func(*config.Config)
- go exporter.UpdateMetrics(cfg)
- startServer(cfg)
+type Server struct {
+ cfg *config.Config
+ logger *log.Logger
+ mux *http.ServeMux
+ server *http.Server
+ updateMetrics UpdateMetricsFunc
}
-func startServer(cfg *config.Config) {
- http.Handle("/metrics", promhttp.HandlerFor(
+func NewServer(cfg *config.Config, logger *log.Logger, updateMetrics UpdateMetricsFunc) *Server {
+ return &Server{
+ cfg: cfg,
+ logger: logger,
+ mux: http.NewServeMux(),
+ updateMetrics: updateMetrics,
+ }
+}
+
+func (s *Server) SetupRouter() {
+ s.mux.Handle("/metrics", promhttp.HandlerFor(
prometheus.DefaultGatherer,
promhttp.HandlerOpts{
EnableOpenMetrics: true,
},
))
- fmt.Printf("Starting server on :%d\n", cfg.Port)
- log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", cfg.Port), nil))
+
+ s.mux.HandleFunc("/", s.homeHandler)
+}
+
+func (s *Server) homeHandler(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/" {
+ http.NotFound(w, r)
+ return
+ }
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
+ fmt.Fprint(w, `Go to metrics`)
+}
+
+func (s *Server) Start() {
+ s.server = &http.Server{
+ Addr: fmt.Sprintf(":%d", s.cfg.Port),
+ Handler: s.mux,
+ ReadTimeout: 5 * time.Second,
+ WriteTimeout: 10 * time.Second,
+ IdleTimeout: 120 * time.Second,
+ }
+
+ if s.cfg.UseTLS {
+ s.server.TLSConfig = &tls.Config{
+ MinVersion: tls.VersionTLS12,
+ PreferServerCipherSuites: true,
+ }
+ }
+
+ go func() {
+ var err error
+ s.logger.Printf("Starting server on :%d\n", s.cfg.Port)
+ if s.cfg.UseTLS {
+ s.logger.Println("TLS enabled")
+ err = s.server.ListenAndServeTLS(s.cfg.TLSCert, s.cfg.TLSKey)
+ } else {
+ s.logger.Println("TLS disabled")
+ err = s.server.ListenAndServe()
+ }
+ if err != nil && err != http.ErrServerClosed {
+ s.logger.Fatalf("Could not listen on %d: %v\n", s.cfg.Port, err)
+ }
+ }()
+}
+
+func (s *Server) GracefulShutdown() {
+ quit := make(chan os.Signal, 1)
+ signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
+ sig := <-quit
+ s.logger.Printf("Received signal: %v. Initiating shutdown...", sig)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ s.server.SetKeepAlivesEnabled(false)
+ if err := s.server.Shutdown(ctx); err != nil {
+ s.logger.Fatalf("Server forced to shutdown: %v", err)
+ }
+
+ s.logger.Println("Server exiting")
+}
+
+func main() {
+ logger := log.New(os.Stdout, "", log.LstdFlags)
+
+ cfg, err := config.Load()
+ if err != nil {
+ logger.Fatalf("Failed to load configuration: %v", err)
+ }
+
+ server := NewServer(cfg, logger, exporter.UpdateMetrics)
+ go server.updateMetrics(cfg)
+
+ server.SetupRouter()
+ server.Start()
+ server.GracefulShutdown()
}
diff --git a/cmd/prometheus-cve-exporter/main_test.go b/cmd/prometheus-cve-exporter/main_test.go
new file mode 100644
index 0000000..310be37
--- /dev/null
+++ b/cmd/prometheus-cve-exporter/main_test.go
@@ -0,0 +1,209 @@
+package main
+
+import (
+ "context"
+ "fmt"
+ "log"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "sync"
+ "syscall"
+ "testing"
+ "time"
+
+ "zops.top/prometheus-cve-exporter/config"
+)
+
+func mockUpdateMetrics(*config.Config) {}
+
+func TestNewServer(t *testing.T) {
+ cfg := &config.Config{}
+ logger := log.New(os.Stdout, "", log.LstdFlags)
+ server := NewServer(cfg, logger, mockUpdateMetrics)
+
+ if server.cfg != cfg {
+ t.Errorf("Expected cfg to be %v, got %v", cfg, server.cfg)
+ }
+ if server.logger != logger {
+ t.Errorf("Expected logger to be %v, got %v", logger, server.logger)
+ }
+ if server.mux == nil {
+ t.Error("Expected mux to be initialized")
+ }
+ if server.updateMetrics == nil {
+ t.Error("Expected updateMetrics to be initialized")
+ }
+}
+
+func TestSetupRouter(t *testing.T) {
+ server := NewServer(&config.Config{}, log.New(os.Stdout, "", log.LstdFlags), mockUpdateMetrics)
+ server.SetupRouter()
+
+ testCases := []struct {
+ path string
+ expectedCode int
+ }{
+ {"/metrics", http.StatusOK},
+ {"/", http.StatusOK},
+ {"/nonexistent", http.StatusNotFound},
+ }
+
+ for _, tc := range testCases {
+ req, err := http.NewRequest("GET", tc.path, nil)
+ if err != nil {
+ t.Fatalf("Could not create request: %v", err)
+ }
+
+ rr := httptest.NewRecorder()
+ server.mux.ServeHTTP(rr, req)
+
+ if rr.Code != tc.expectedCode {
+ t.Errorf("handler returned wrong status code for %s: got %v want %v",
+ tc.path, rr.Code, tc.expectedCode)
+ }
+ }
+}
+
+func TestHomeHandler(t *testing.T) {
+ server := NewServer(&config.Config{}, log.New(os.Stdout, "", log.LstdFlags), mockUpdateMetrics)
+
+ req, err := http.NewRequest("GET", "/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ rr := httptest.NewRecorder()
+ handler := http.HandlerFunc(server.homeHandler)
+ handler.ServeHTTP(rr, req)
+
+ if status := rr.Code; status != http.StatusOK {
+ t.Errorf("handler returned wrong status code: got %v want %v",
+ status, http.StatusOK)
+ }
+
+ expected := `Go to metrics`
+ if rr.Body.String() != expected {
+ t.Errorf("handler returned unexpected body: got %v want %v",
+ rr.Body.String(), expected)
+ }
+}
+
+func TestStart(t *testing.T) {
+ cfg := &config.Config{Port: 20000}
+ logger := log.New(os.Stdout, "", log.LstdFlags)
+ s := NewServer(cfg, logger, mockUpdateMetrics)
+ s.SetupRouter()
+
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ s.Start()
+ }()
+
+ // Give some time for the server to start
+ time.Sleep(100 * time.Millisecond)
+
+ resp, err := http.Get(fmt.Sprintf("http://localhost:%d", cfg.Port))
+ if err != nil {
+ t.Fatalf("Could not send GET request: %v", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ t.Errorf("Expected status OK; got %v", resp.Status)
+ }
+
+ // Shutdown the server
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ if err := s.server.Shutdown(ctx); err != nil {
+ s.logger.Fatalf("Server forced to shutdown: %v", err)
+ }
+
+ wg.Wait()
+}
+
+func TestUpdateMetricsExecution(t *testing.T) {
+ updateMetricsCalled := false
+ mockUpdateMetrics := func(*config.Config) {
+ updateMetricsCalled = true
+ }
+
+ cfg := &config.Config{}
+ logger := log.New(os.Stdout, "", log.LstdFlags)
+ server := NewServer(cfg, logger, mockUpdateMetrics)
+
+ server.updateMetrics(cfg)
+
+ if !updateMetricsCalled {
+ t.Error("Expected UpdateMetrics to be called")
+ }
+}
+
+func TestMainIntegration(t *testing.T) {
+ // Backup original os.Args
+ oldArgs := os.Args
+ defer func() { os.Args = oldArgs }()
+
+ // Set up a test config file
+ testConfigPath := "test_config.json"
+ testPort := 20001
+ testConfigContent := []byte(fmt.Sprintf(`{
+ "nvd_feed_url": "https://test.nvd.feed.url",
+ "update_interval": "2h",
+ "port": %d,
+ "severity": ["HIGH", "CRITICAL"],
+ "package_file": "",
+ "use_tls": false
+ }`, testPort))
+ err := os.WriteFile(testConfigPath, testConfigContent, 0644)
+ if err != nil {
+ t.Fatalf("Failed to create test config file: %v", err)
+ }
+ defer os.Remove(testConfigPath)
+
+ // Set the command-line argument to use our test config
+ os.Args = []string{"cmd", "-config", testConfigPath}
+
+ // Run main in a goroutine
+ go func() {
+ main()
+ }()
+
+ // Give some time for the server to start
+ time.Sleep(100 * time.Millisecond)
+
+ // Test if the server is running
+ resp, err := http.Get(fmt.Sprintf("http://localhost:%d", testPort))
+ if err != nil {
+ t.Fatalf("Could not send GET request: %v", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ t.Errorf("Expected status OK; got %v", resp.Status)
+ }
+
+ // Send shutdown signal
+ err = syscall.Kill(syscall.Getpid(), syscall.SIGINT)
+ if err != nil {
+ t.Fatalf("Failed to send SIGINT signal: %v", err)
+ }
+
+ // Give some time for the server to shut down
+ time.Sleep(500 * time.Millisecond)
+
+ // Verify that the server has shut down
+ _, err = http.Get(fmt.Sprintf("http://localhost:%d", testPort))
+ if err == nil {
+ t.Error("Expected an error when connecting to a shutdown server")
+ }
+}
+
+func TestMain(m *testing.M) {
+ // Run tests
+ code := m.Run()
+ os.Exit(code)
+}
diff --git a/config.json b/config.json
index f7218b9..a8f1a3a 100644
--- a/config.json
+++ b/config.json
@@ -8,5 +8,7 @@
"MEDIUM",
"HIGH",
"CRITICAL"
- ]
+ ],
+ "tls_cert": "server.crt",
+ "tls_key": "server.key"
}
diff --git a/config/config.go b/config/config.go
index 6f454ab..b22e9e9 100644
--- a/config/config.go
+++ b/config/config.go
@@ -15,6 +15,9 @@ const (
defaultPort = 10250
defaultSeverity = "CRITICAL"
defaultPackageFile = ""
+ defaultUseTLS = false
+ defaultTLSCert = ""
+ defaultTLSKey = ""
)
type Config struct {
@@ -23,6 +26,9 @@ type Config struct {
Port int `json:"port"`
Severity []string `json:"severity"`
PackageFile string `json:"package_file,omitempty"`
+ UseTLS bool `json:"use_tls,omitempty"`
+ TLSCert string `json:"tls_cert,omitempty"`
+ TLSKey string `json:"tls_key,omitempty"`
}
type configHelper struct {
@@ -31,6 +37,8 @@ type configHelper struct {
Port int `json:"port"`
Severity []string `json:"severity"`
PackageFile string `json:"package_file,omitempty"`
+ TLSCert string `json:"tls_cert,omitempty"`
+ TLSKey string `json:"tls_key,omitempty"`
}
func NewConfig() *Config {
@@ -40,6 +48,9 @@ func NewConfig() *Config {
Port: defaultPort,
Severity: []string{defaultSeverity},
PackageFile: defaultPackageFile,
+ UseTLS: defaultUseTLS,
+ TLSCert: defaultTLSCert,
+ TLSKey: defaultTLSKey,
}
}
@@ -59,6 +70,10 @@ func Load() (*Config, error) {
return nil, err
}
+ if cfg.TLSCert != defaultTLSCert && cfg.TLSKey != defaultTLSKey {
+ cfg.UseTLS = true
+ }
+
fmt.Print(prettyfyCfg(cfg))
return cfg, nil
}
@@ -71,6 +86,8 @@ func parseFlags(cfg *Config) string {
flag.DurationVar(&cfg.UpdateInterval, "update-interval", defaultUpdateInterval, "Update interval duration")
flag.IntVar(&cfg.Port, "port", defaultPort, "Port to run the server on")
flag.StringVar(&cfg.PackageFile, "package-file", defaultPackageFile, "Path to file containing packages and versions")
+ flag.StringVar(&cfg.TLSCert, "tls-cert", defaultTLSCert, "Path to TLS certificate file")
+ flag.StringVar(&cfg.TLSKey, "tls-key", defaultTLSKey, "Path to TLS key file")
var severity string
flag.StringVar(&severity, "severity", defaultSeverity, "Comma separated list of severity levels for vulnerabilities")
@@ -95,6 +112,8 @@ func loadConfigFile(cfg *Config, filename string) error {
cfg.Port = helper.Port
cfg.Severity = toUppercase(helper.Severity)
cfg.PackageFile = helper.PackageFile
+ cfg.TLSCert = helper.TLSCert
+ cfg.TLSKey = helper.TLSKey
duration, err := time.ParseDuration(helper.UpdateInterval)
if err != nil {
@@ -124,6 +143,8 @@ func overrideWithEnv(cfg *Config) {
},
"PCE_SEVERITY": func(value string) { cfg.Severity = parseSeverity(value) },
"PCE_PACKAGE_FILE": func(value string) { cfg.PackageFile = value },
+ "PCE_TLS_CERT": func(value string) { cfg.TLSCert = value },
+ "PCE_TLS_KEY": func(value string) { cfg.TLSKey = value },
}
for envVar, action := range envVars {
@@ -139,6 +160,14 @@ func validateConfig(cfg *Config) error {
return fmt.Errorf("the file %s does not exist", cfg.PackageFile)
}
}
+ if cfg.TLSCert != defaultTLSCert || cfg.TLSKey != defaultTLSKey {
+ if _, err := os.Stat(cfg.TLSCert); os.IsNotExist(err) {
+ return fmt.Errorf("the TLS certificate file %s does not exist", cfg.TLSCert)
+ }
+ if _, err := os.Stat(cfg.TLSKey); os.IsNotExist(err) {
+ return fmt.Errorf("the TLS key file %s does not exist", cfg.TLSKey)
+ }
+ }
return nil
}
@@ -162,6 +191,15 @@ func parseIntEnv(value string) (int, error) {
return result, err
}
+func parseBoolEnv(value string) bool {
+ switch value {
+ case "true":
+ return true
+ default:
+ return false
+ }
+}
+
func prettyfyCfg(cfg *Config) string {
var output strings.Builder
@@ -173,6 +211,11 @@ func prettyfyCfg(cfg *Config) string {
if cfg.PackageFile != "" {
output.WriteString(fmt.Sprintf(" Package file: %s\n", cfg.PackageFile))
}
+ output.WriteString(fmt.Sprintf(" Use TLS: %v\n", cfg.UseTLS))
+ if cfg.UseTLS {
+ output.WriteString(fmt.Sprintf(" TLS Certificate: %s\n", cfg.TLSCert))
+ output.WriteString(fmt.Sprintf(" TLS Key: %s\n", cfg.TLSKey))
+ }
return output.String()
}
diff --git a/config/config_test.go b/config/config_test.go
index 1ae9124..f40baca 100644
--- a/config/config_test.go
+++ b/config/config_test.go
@@ -34,7 +34,7 @@ func TestLoadConfigFile(t *testing.T) {
"nvd_feed_url": "https://example.com/feed.json",
"update_interval": "48h",
"port": 9090,
- "severity": ["HIGH", "MEDIUM"],
+ "severity": ["HIGH", "MEDIUM", "low"],
"package_file": "/tmp/package.txt"
}`
@@ -57,7 +57,7 @@ func TestLoadConfigFile(t *testing.T) {
if cfg.Port != 9090 {
t.Errorf("Expected Port 9090, got %d", cfg.Port)
}
- expectedSeverity := []string{"HIGH", "MEDIUM"}
+ expectedSeverity := []string{"HIGH", "MEDIUM", "LOW"}
for i, s := range expectedSeverity {
if cfg.Severity[i] != s {
t.Errorf("Expected Severity %v, got %v", expectedSeverity, cfg.Severity)
@@ -74,14 +74,18 @@ func TestOverrideWithEnv(t *testing.T) {
os.Setenv("PCE_NVD_JSON_GZ_FEED_URL", "https://env.com/feed.json")
os.Setenv("PCE_UPDATE_INTERVAL", "72h")
os.Setenv("PCE_PORT", "8080")
- os.Setenv("PCE_SEVERITY", "LOW,INFO")
+ os.Setenv("PCE_SEVERITY", "LOW,INFO,critical")
os.Setenv("PCE_PACKAGE_FILE", "/env/package.txt")
+ os.Setenv("PCE_TLS_CERT", "/env/tls.crt")
+ os.Setenv("PCE_TLS_KEY", "/env/tls.key")
defer func() {
os.Unsetenv("PCE_NVD_JSON_GZ_FEED_URL")
os.Unsetenv("PCE_UPDATE_INTERVAL")
os.Unsetenv("PCE_PORT")
os.Unsetenv("PCE_SEVERITY")
os.Unsetenv("PCE_PACKAGE_FILE")
+ os.Unsetenv("PCE_TLS_CERT")
+ os.Unsetenv("PCE_TLS_KEY")
}()
overrideWithEnv(cfg)
@@ -95,7 +99,7 @@ func TestOverrideWithEnv(t *testing.T) {
if cfg.Port != 8080 {
t.Errorf("Expected Port 8080, got %d", cfg.Port)
}
- expectedSeverity := []string{"LOW", "INFO"}
+ expectedSeverity := []string{"LOW", "INFO", "CRITICAL"}
for i, s := range expectedSeverity {
if cfg.Severity[i] != s {
t.Errorf("Expected Severity %v, got %v", expectedSeverity, cfg.Severity)
@@ -104,10 +108,16 @@ func TestOverrideWithEnv(t *testing.T) {
if cfg.PackageFile != "/env/package.txt" {
t.Errorf("Expected PackageFile /env/package.txt, got %s", cfg.PackageFile)
}
+ if cfg.TLSCert != "/env/tls.crt" {
+ t.Errorf("Expected TLSCert /env/tls.crt, got %s", cfg.TLSCert)
+ }
+ if cfg.TLSKey != "/env/tls.key" {
+ t.Errorf("Expected TLSKey /env/tls.key, got %s", cfg.TLSKey)
+ }
}
func TestParseSeverity(t *testing.T) {
- severity := "CRITICAL,HIGH,MEDIUM,LOW"
+ severity := "CRITICAL,HIGH,medium,LOW"
expected := []string{"CRITICAL", "HIGH", "MEDIUM", "LOW"}
result := parseSeverity(severity)
@@ -136,6 +146,87 @@ func TestParseIntEnv(t *testing.T) {
}
}
+func TestParseBoolEnv(t *testing.T) {
+ tests := []struct {
+ name string
+ value string
+ expected bool
+ expectErr bool
+ }{
+ {
+ name: "valid true value",
+ value: "true",
+ expected: true,
+ expectErr: false,
+ },
+ {
+ name: "valid false value",
+ value: "false",
+ expected: false,
+ expectErr: false,
+ },
+ {
+ name: "invalid value",
+ value: "invalid",
+ expected: false,
+ expectErr: false,
+ },
+ {
+ name: "numeric value",
+ value: "123",
+ expected: false,
+ expectErr: false,
+ },
+ {
+ name: "yes value",
+ value: "yes",
+ expected: false,
+ expectErr: false,
+ },
+ {
+ name: "no value",
+ value: "no",
+ expected: false,
+ expectErr: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := parseBoolEnv(tt.value)
+
+ if result != tt.expected {
+ t.Errorf("parseBoolEnv(%q) = %v; want %v", tt.value, result, tt.expected)
+ }
+ })
+ }
+}
+
+func TestValidateConfig(t *testing.T) {
+ cfg := NewConfig()
+
+ cfg.PackageFile = "/nonexistent/file"
+ err := validateConfig(cfg)
+ if err == nil {
+ t.Errorf("Expected error for nonexistent package file, got none")
+ }
+
+ cfg.PackageFile = ""
+ cfg.TLSCert = "/nonexistent/tls.crt"
+ cfg.TLSKey = "/nonexistent/tls.key"
+ err = validateConfig(cfg)
+ if err == nil {
+ t.Errorf("Expected error for missing TLS cert and key, got none")
+ }
+
+ cfg.TLSCert = ""
+ cfg.TLSKey = ""
+ err = validateConfig(cfg)
+ if err != nil {
+ t.Errorf("Expected no error for valid config, got %v", err)
+ }
+}
+
func TestPrettyfyCfg(t *testing.T) {
cfg := NewConfig()
cfg.NVDFeedURL = "https://example.com/feed.json"
@@ -143,6 +234,9 @@ func TestPrettyfyCfg(t *testing.T) {
cfg.Port = 9090
cfg.Severity = []string{"HIGH", "MEDIUM"}
cfg.PackageFile = "/tmp/package.txt"
+ cfg.UseTLS = true
+ cfg.TLSCert = "/tmp/tls.crt"
+ cfg.TLSKey = "/tmp/tls.key"
output := prettyfyCfg(cfg)
expectedStrings := []string{
@@ -152,6 +246,9 @@ func TestPrettyfyCfg(t *testing.T) {
" Severity Levels: [HIGH MEDIUM]",
" Port: 9090",
" Package file: /tmp/package.txt",
+ " Use TLS: true",
+ " TLS Certificate: /tmp/tls.crt",
+ " TLS Key: /tmp/tls.key",
}
for _, expected := range expectedStrings {
@@ -160,3 +257,28 @@ func TestPrettyfyCfg(t *testing.T) {
}
}
}
+
+func TestLoad(t *testing.T) {
+ // Ensure no flags are set during the test
+ os.Args = []string{"cmd"}
+ cfg, err := Load()
+ if err != nil {
+ t.Fatalf("Expected no error, got %v", err)
+ }
+
+ if cfg.NVDFeedURL != defaultNVDFeedURL {
+ t.Errorf("Expected NVDFeedURL %s, got %s", defaultNVDFeedURL, cfg.NVDFeedURL)
+ }
+ if cfg.UpdateInterval != defaultUpdateInterval {
+ t.Errorf("Expected UpdateInterval %v, got %v", defaultUpdateInterval, cfg.UpdateInterval)
+ }
+ if cfg.Port != defaultPort {
+ t.Errorf("Expected Port %d, got %d", defaultPort, cfg.Port)
+ }
+ if len(cfg.Severity) != 1 || cfg.Severity[0] != defaultSeverity {
+ t.Errorf("Expected Severity %v, got %v", []string{defaultSeverity}, cfg.Severity)
+ }
+ if cfg.PackageFile != "" {
+ t.Errorf("Expected PackageFile to be empty, got %s", cfg.PackageFile)
+ }
+}