Skip to content

Commit 78e1bdd

Browse files
committed
server: allow TLS
1 parent eb97b56 commit 78e1bdd

File tree

2 files changed

+309
-11
lines changed

2 files changed

+309
-11
lines changed

cmd/prometheus-cve-exporter/main.go

Lines changed: 100 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
package main
22

33
import (
4+
"context"
5+
"crypto/tls"
46
"fmt"
57
"log"
68
"net/http"
9+
"os"
10+
"os/signal"
11+
"syscall"
12+
"time"
713

814
"zops.top/prometheus-cve-exporter/config"
915
"zops.top/prometheus-cve-exporter/internal/exporter"
@@ -12,23 +18,106 @@ import (
1218
"github.com/prometheus/client_golang/prometheus/promhttp"
1319
)
1420

15-
func main() {
16-
cfg, err := config.Load()
17-
if err != nil {
18-
log.Fatalf("Failed to load configuration: %v", err)
19-
}
21+
type UpdateMetricsFunc func(*config.Config)
2022

21-
go exporter.UpdateMetrics(cfg)
22-
startServer(cfg)
23+
type Server struct {
24+
cfg *config.Config
25+
logger *log.Logger
26+
mux *http.ServeMux
27+
server *http.Server
28+
updateMetrics UpdateMetricsFunc
2329
}
2430

25-
func startServer(cfg *config.Config) {
26-
http.Handle("/metrics", promhttp.HandlerFor(
31+
func NewServer(cfg *config.Config, logger *log.Logger, updateMetrics UpdateMetricsFunc) *Server {
32+
return &Server{
33+
cfg: cfg,
34+
logger: logger,
35+
mux: http.NewServeMux(),
36+
updateMetrics: updateMetrics,
37+
}
38+
}
39+
40+
func (s *Server) SetupRouter() {
41+
s.mux.Handle("/metrics", promhttp.HandlerFor(
2742
prometheus.DefaultGatherer,
2843
promhttp.HandlerOpts{
2944
EnableOpenMetrics: true,
3045
},
3146
))
32-
fmt.Printf("Starting server on :%d\n", cfg.Port)
33-
log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", cfg.Port), nil))
47+
48+
s.mux.HandleFunc("/", s.homeHandler)
49+
}
50+
51+
func (s *Server) homeHandler(w http.ResponseWriter, r *http.Request) {
52+
if r.URL.Path != "/" {
53+
http.NotFound(w, r)
54+
return
55+
}
56+
w.Header().Set("Content-Type", "text/html; charset=utf-8")
57+
fmt.Fprint(w, `<a href="/metrics">Go to metrics</a>`)
58+
}
59+
60+
func (s *Server) Start() {
61+
s.server = &http.Server{
62+
Addr: fmt.Sprintf(":%d", s.cfg.Port),
63+
Handler: s.mux,
64+
ReadTimeout: 5 * time.Second,
65+
WriteTimeout: 10 * time.Second,
66+
IdleTimeout: 120 * time.Second,
67+
}
68+
69+
if s.cfg.UseTLS {
70+
s.server.TLSConfig = &tls.Config{
71+
MinVersion: tls.VersionTLS12,
72+
PreferServerCipherSuites: true,
73+
}
74+
}
75+
76+
go func() {
77+
var err error
78+
s.logger.Printf("Starting server on :%d\n", s.cfg.Port)
79+
if s.cfg.UseTLS {
80+
s.logger.Println("TLS enabled")
81+
err = s.server.ListenAndServeTLS(s.cfg.TLSCert, s.cfg.TLSKey)
82+
} else {
83+
s.logger.Println("TLS disabled")
84+
err = s.server.ListenAndServe()
85+
}
86+
if err != nil && err != http.ErrServerClosed {
87+
s.logger.Fatalf("Could not listen on %d: %v\n", s.cfg.Port, err)
88+
}
89+
}()
90+
}
91+
92+
func (s *Server) GracefulShutdown() {
93+
quit := make(chan os.Signal, 1)
94+
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
95+
sig := <-quit
96+
s.logger.Printf("Received signal: %v. Initiating shutdown...", sig)
97+
98+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
99+
defer cancel()
100+
101+
s.server.SetKeepAlivesEnabled(false)
102+
if err := s.server.Shutdown(ctx); err != nil {
103+
s.logger.Fatalf("Server forced to shutdown: %v", err)
104+
}
105+
106+
s.logger.Println("Server exiting")
107+
}
108+
109+
func main() {
110+
logger := log.New(os.Stdout, "", log.LstdFlags)
111+
112+
cfg, err := config.Load()
113+
if err != nil {
114+
logger.Fatalf("Failed to load configuration: %v", err)
115+
}
116+
117+
server := NewServer(cfg, logger, exporter.UpdateMetrics)
118+
go server.updateMetrics(cfg)
119+
120+
server.SetupRouter()
121+
server.Start()
122+
server.GracefulShutdown()
34123
}
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"log"
7+
"net/http"
8+
"net/http/httptest"
9+
"os"
10+
"sync"
11+
"syscall"
12+
"testing"
13+
"time"
14+
15+
"zops.top/prometheus-cve-exporter/config"
16+
)
17+
18+
func mockUpdateMetrics(*config.Config) {}
19+
20+
func TestNewServer(t *testing.T) {
21+
cfg := &config.Config{}
22+
logger := log.New(os.Stdout, "", log.LstdFlags)
23+
server := NewServer(cfg, logger, mockUpdateMetrics)
24+
25+
if server.cfg != cfg {
26+
t.Errorf("Expected cfg to be %v, got %v", cfg, server.cfg)
27+
}
28+
if server.logger != logger {
29+
t.Errorf("Expected logger to be %v, got %v", logger, server.logger)
30+
}
31+
if server.mux == nil {
32+
t.Error("Expected mux to be initialized")
33+
}
34+
if server.updateMetrics == nil {
35+
t.Error("Expected updateMetrics to be initialized")
36+
}
37+
}
38+
39+
func TestSetupRouter(t *testing.T) {
40+
server := NewServer(&config.Config{}, log.New(os.Stdout, "", log.LstdFlags), mockUpdateMetrics)
41+
server.SetupRouter()
42+
43+
testCases := []struct {
44+
path string
45+
expectedCode int
46+
}{
47+
{"/metrics", http.StatusOK},
48+
{"/", http.StatusOK},
49+
{"/nonexistent", http.StatusNotFound},
50+
}
51+
52+
for _, tc := range testCases {
53+
req, err := http.NewRequest("GET", tc.path, nil)
54+
if err != nil {
55+
t.Fatalf("Could not create request: %v", err)
56+
}
57+
58+
rr := httptest.NewRecorder()
59+
server.mux.ServeHTTP(rr, req)
60+
61+
if rr.Code != tc.expectedCode {
62+
t.Errorf("handler returned wrong status code for %s: got %v want %v",
63+
tc.path, rr.Code, tc.expectedCode)
64+
}
65+
}
66+
}
67+
68+
func TestHomeHandler(t *testing.T) {
69+
server := NewServer(&config.Config{}, log.New(os.Stdout, "", log.LstdFlags), mockUpdateMetrics)
70+
71+
req, err := http.NewRequest("GET", "/", nil)
72+
if err != nil {
73+
t.Fatal(err)
74+
}
75+
76+
rr := httptest.NewRecorder()
77+
handler := http.HandlerFunc(server.homeHandler)
78+
handler.ServeHTTP(rr, req)
79+
80+
if status := rr.Code; status != http.StatusOK {
81+
t.Errorf("handler returned wrong status code: got %v want %v",
82+
status, http.StatusOK)
83+
}
84+
85+
expected := `<a href="/metrics">Go to metrics</a>`
86+
if rr.Body.String() != expected {
87+
t.Errorf("handler returned unexpected body: got %v want %v",
88+
rr.Body.String(), expected)
89+
}
90+
}
91+
92+
func TestStart(t *testing.T) {
93+
cfg := &config.Config{Port: 20000}
94+
logger := log.New(os.Stdout, "", log.LstdFlags)
95+
s := NewServer(cfg, logger, mockUpdateMetrics)
96+
s.SetupRouter()
97+
98+
var wg sync.WaitGroup
99+
wg.Add(1)
100+
go func() {
101+
defer wg.Done()
102+
s.Start()
103+
}()
104+
105+
// Give some time for the server to start
106+
time.Sleep(100 * time.Millisecond)
107+
108+
resp, err := http.Get(fmt.Sprintf("http://localhost:%d", cfg.Port))
109+
if err != nil {
110+
t.Fatalf("Could not send GET request: %v", err)
111+
}
112+
defer resp.Body.Close()
113+
114+
if resp.StatusCode != http.StatusOK {
115+
t.Errorf("Expected status OK; got %v", resp.Status)
116+
}
117+
118+
// Shutdown the server
119+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
120+
defer cancel()
121+
if err := s.server.Shutdown(ctx); err != nil {
122+
s.logger.Fatalf("Server forced to shutdown: %v", err)
123+
}
124+
125+
wg.Wait()
126+
}
127+
128+
func TestUpdateMetricsExecution(t *testing.T) {
129+
updateMetricsCalled := false
130+
mockUpdateMetrics := func(*config.Config) {
131+
updateMetricsCalled = true
132+
}
133+
134+
cfg := &config.Config{}
135+
logger := log.New(os.Stdout, "", log.LstdFlags)
136+
server := NewServer(cfg, logger, mockUpdateMetrics)
137+
138+
server.updateMetrics(cfg)
139+
140+
if !updateMetricsCalled {
141+
t.Error("Expected UpdateMetrics to be called")
142+
}
143+
}
144+
145+
func TestMainIntegration(t *testing.T) {
146+
// Backup original os.Args
147+
oldArgs := os.Args
148+
defer func() { os.Args = oldArgs }()
149+
150+
// Set up a test config file
151+
testConfigPath := "test_config.json"
152+
testPort := 20001
153+
testConfigContent := []byte(fmt.Sprintf(`{
154+
"nvd_feed_url": "https://test.nvd.feed.url",
155+
"update_interval": "2h",
156+
"port": %d,
157+
"severity": ["HIGH", "CRITICAL"],
158+
"package_file": "",
159+
"use_tls": false
160+
}`, testPort))
161+
err := os.WriteFile(testConfigPath, testConfigContent, 0644)
162+
if err != nil {
163+
t.Fatalf("Failed to create test config file: %v", err)
164+
}
165+
defer os.Remove(testConfigPath)
166+
167+
// Set the command-line argument to use our test config
168+
os.Args = []string{"cmd", "-config", testConfigPath}
169+
170+
// Run main in a goroutine
171+
go func() {
172+
main()
173+
}()
174+
175+
// Give some time for the server to start
176+
time.Sleep(100 * time.Millisecond)
177+
178+
// Test if the server is running
179+
resp, err := http.Get(fmt.Sprintf("http://localhost:%d", testPort))
180+
if err != nil {
181+
t.Fatalf("Could not send GET request: %v", err)
182+
}
183+
defer resp.Body.Close()
184+
185+
if resp.StatusCode != http.StatusOK {
186+
t.Errorf("Expected status OK; got %v", resp.Status)
187+
}
188+
189+
// Send shutdown signal
190+
err = syscall.Kill(syscall.Getpid(), syscall.SIGINT)
191+
if err != nil {
192+
t.Fatalf("Failed to send SIGINT signal: %v", err)
193+
}
194+
195+
// Give some time for the server to shut down
196+
time.Sleep(500 * time.Millisecond)
197+
198+
// Verify that the server has shut down
199+
_, err = http.Get(fmt.Sprintf("http://localhost:%d", testPort))
200+
if err == nil {
201+
t.Error("Expected an error when connecting to a shutdown server")
202+
}
203+
}
204+
205+
func TestMain(m *testing.M) {
206+
// Run tests
207+
code := m.Run()
208+
os.Exit(code)
209+
}

0 commit comments

Comments
 (0)