@@ -6,22 +6,27 @@ package cmd
6
6
7
7
import (
8
8
"bytes"
9
+ "context"
9
10
"encoding/json"
10
11
"fmt"
12
+ "log"
13
+ "net/http"
14
+ "os"
15
+ "time"
16
+
11
17
"github.com/google/uuid"
18
+ "github.com/shinobistack/gokakashi/internal/http/client"
12
19
"github.com/shinobistack/gokakashi/internal/restapi/v1/agents"
13
20
"github.com/shinobistack/gokakashi/internal/restapi/v1/agenttasks"
14
21
"github.com/shinobistack/gokakashi/internal/restapi/v1/integrations"
15
22
"github.com/shinobistack/gokakashi/internal/restapi/v1/scans"
16
23
"github.com/shinobistack/gokakashi/pkg/registry/v1"
17
24
"github.com/shinobistack/gokakashi/pkg/scanner/v1"
18
25
"github.com/spf13/cobra"
19
- "log"
20
- "net/http"
21
- "os"
22
- "time"
23
26
)
24
27
28
+ type httpClientKey struct {}
29
+
25
30
var agentCmd = & cobra.Command {
26
31
Use : "agent" ,
27
32
Short : "Manage agents for GoKakashi" ,
@@ -30,7 +35,31 @@ var agentCmd = &cobra.Command{
30
35
var agentStartCmd = & cobra.Command {
31
36
Use : "start" ,
32
37
Short : "Register an agent and start polling for tasks" ,
33
- Run : agentRegister ,
38
+ PersistentPreRun : func (cmd * cobra.Command , args []string ) {
39
+ if token == "" {
40
+ log .Fatalf ("Error: missing required flag --token" )
41
+ }
42
+ headers := make (map [string ]string )
43
+ cfClientID := os .Getenv ("CF_ACCESS_CLIENT_ID" )
44
+ cfClientSecret := os .Getenv ("CF_ACCESS_CLIENT_SECRET" )
45
+ if cfClientID != "" && cfClientSecret != "" {
46
+ headers ["CF-Access-Client-Id" ] = cfClientID
47
+ headers ["CF-Access-Client-Secret" ] = cfClientSecret
48
+ } else if cfClientSecret != "" {
49
+ fmt .Println ("Warning: ignoring CF_ACCESS_CLIENT_SECRET because CF_ACCESS_CLIENT_ID is not set" )
50
+ } else if cfClientID != "" {
51
+ fmt .Println ("Warning: ignoring CF_ACCESS_CLIENT_ID because CF_ACCESS_CLIENT_SECRET is not set" )
52
+ }
53
+
54
+ httpClient := client .New (
55
+ client .WithToken (token ),
56
+ client .WithHeaders (headers ),
57
+ )
58
+
59
+ ctx := context .WithValue (context .Background (), httpClientKey {}, httpClient )
60
+ cmd .SetContext (ctx )
61
+ },
62
+ Run : agentRegister ,
34
63
}
35
64
36
65
var (
@@ -51,18 +80,18 @@ func agentRegister(cmd *cobra.Command, args []string) {
51
80
// log.Printf("Server: %s, Token: %s, Workspace: %s", server, token, workspace)
52
81
53
82
// Register the agent
54
- agentID , err := registerAgent (server , token , workspace , name )
83
+ agentID , err := registerAgent (cmd . Context (), server , token , workspace , name )
55
84
if err != nil {
56
85
log .Fatalf ("Failed to register the agent: %v" , err )
57
86
}
58
87
59
88
log .Printf ("Agent registered successfully! Agent ID: %d" , agentID )
60
89
61
90
// Start polling for tasks
62
- pollTasks (server , token , agentID , workspace )
91
+ pollTasks (cmd . Context (), server , token , agentID , workspace )
63
92
}
64
93
65
- func registerAgent (server , token , workspace , name string ) (int , error ) {
94
+ func registerAgent (ctx context. Context , server , token , workspace , name string ) (int , error ) {
66
95
reqBody := agents.RegisterAgentRequest {
67
96
Server : server ,
68
97
Token : token ,
@@ -75,10 +104,9 @@ func registerAgent(server, token, workspace, name string) (int, error) {
75
104
if err != nil {
76
105
return 0 , fmt .Errorf ("failed to create registration request: %w" , err )
77
106
}
78
- req .Header .Set ("Authorization" , "Bearer " + token )
79
107
req .Header .Set ("Content-Type" , "application/json" )
80
108
81
- resp , err := http . DefaultClient .Do (req )
109
+ resp , err := ctx . Value ( httpClientKey {}).( * client. Client ) .Do (req )
82
110
if err != nil {
83
111
return 0 , fmt .Errorf ("failed to send registration request: %w" , err )
84
112
}
@@ -95,10 +123,10 @@ func registerAgent(server, token, workspace, name string) (int, error) {
95
123
return response .ID , nil
96
124
}
97
125
98
- func pollTasks (server , token string , agentID int , workspace string ) {
126
+ func pollTasks (ctx context. Context , server , token string , agentID int , workspace string ) {
99
127
for {
100
128
// Process only tasks with status "pending" in the order returned (created_at ASC)
101
- tasks , err := fetchTasks (server , token , agentID , "pending" )
129
+ tasks , err := fetchTasks (ctx , server , token , agentID , "pending" )
102
130
if err != nil {
103
131
log .Printf ("Error fetching tasks: %v" , err )
104
132
time .Sleep (10 * time .Second )
@@ -113,12 +141,12 @@ func pollTasks(server, token string, agentID int, workspace string) {
113
141
114
142
for _ , task := range tasks {
115
143
// Update task status to "in_progress"
116
- err := updateAgentTaskStatus (server , token , task .ID , agentID , "in_progress" )
144
+ err := updateAgentTaskStatus (ctx , server , token , task .ID , agentID , "in_progress" )
117
145
if err != nil {
118
146
log .Printf ("Failed to update agent_task status to 'in_progress': %v" , err )
119
147
return
120
148
}
121
- processTask (server , token , task , workspace , agentID )
149
+ processTask (ctx , server , token , task , workspace , agentID )
122
150
continue
123
151
}
124
152
// Todo: Polling interval time decide
@@ -127,15 +155,13 @@ func pollTasks(server, token string, agentID int, workspace string) {
127
155
}
128
156
}
129
157
130
- func fetchTasks (server , token string , agentID int , status string ) ([]agenttasks.GetAgentTaskResponse , error ) {
158
+ func fetchTasks (ctx context. Context , server , token string , agentID int , status string ) ([]agenttasks.GetAgentTaskResponse , error ) {
131
159
req , err := http .NewRequest ("GET" , fmt .Sprintf ("%s/api/v1/agents/%d/tasks?status=%s" , server , agentID , status ), nil )
132
160
if err != nil {
133
161
return nil , fmt .Errorf ("failed to create task polling request: %w" , err )
134
162
}
135
163
136
- req .Header .Set ("Authorization" , "Bearer " + token )
137
-
138
- resp , err := http .DefaultClient .Do (req )
164
+ resp , err := ctx .Value (httpClientKey {}).(* client.Client ).Do (req )
139
165
if err != nil {
140
166
return nil , fmt .Errorf ("failed to send task polling request: %w" , err )
141
167
}
@@ -153,7 +179,7 @@ func fetchTasks(server, token string, agentID int, status string) ([]agenttasks.
153
179
return tasks , nil
154
180
}
155
181
156
- func updateAgentTaskStatus (server , token string , taskID uuid.UUID , agentID int , status string ) error {
182
+ func updateAgentTaskStatus (ctx context. Context , server , token string , taskID uuid.UUID , agentID int , status string ) error {
157
183
reqBody := agenttasks.UpdateAgentTaskRequest {
158
184
ID : taskID ,
159
185
AgentID : intPtr (agentID ),
@@ -166,10 +192,9 @@ func updateAgentTaskStatus(server, token string, taskID uuid.UUID, agentID int,
166
192
if err != nil {
167
193
return fmt .Errorf ("failed to create task status update request: %w" , err )
168
194
}
169
- req .Header .Set ("Authorization" , "Bearer " + token )
170
195
req .Header .Set ("Content-Type" , "application/json" )
171
196
172
- resp , err := http . DefaultClient .Do (req )
197
+ resp , err := ctx . Value ( httpClientKey {}).( * client. Client ) .Do (req )
173
198
if err != nil {
174
199
return fmt .Errorf ("failed to update task status: %w" , err )
175
200
}
@@ -182,16 +207,16 @@ func updateAgentTaskStatus(server, token string, taskID uuid.UUID, agentID int,
182
207
return nil
183
208
}
184
209
185
- func processTask (server , token string , task agenttasks.GetAgentTaskResponse , workspace string , agentID int ) {
210
+ func processTask (ctx context. Context , server , token string , task agenttasks.GetAgentTaskResponse , workspace string , agentID int ) {
186
211
// Step 1: Fetch scan details
187
- scan , err := fetchScan (server , token , task .ScanID )
212
+ scan , err := fetchScan (ctx , server , token , task .ScanID )
188
213
if err != nil {
189
214
log .Printf ("Failed to fetch scan details: %v" , err )
190
215
return
191
216
}
192
217
193
218
// Step 2: Fetch integration details
194
- integration , err := fetchIntegration (server , token , scan .IntegrationID )
219
+ integration , err := fetchIntegration (ctx , server , token , scan .IntegrationID )
195
220
if err != nil {
196
221
log .Printf ("Failed to fetch integration details: %v" , err )
197
222
return
@@ -203,7 +228,7 @@ func processTask(server, token string, task agenttasks.GetAgentTaskResponse, wor
203
228
return
204
229
}
205
230
206
- err = updateScanStatus (server , token , scan .ID , "scan_in_progress" )
231
+ err = updateScanStatus (ctx , server , token , scan .ID , "scan_in_progress" )
207
232
if err != nil {
208
233
log .Printf ("Failed to update scan status to 'scan_in_progress': %v" , err )
209
234
}
@@ -212,16 +237,16 @@ func processTask(server, token string, task agenttasks.GetAgentTaskResponse, wor
212
237
reportPath , err := performScan (scan .Image , scan .Scanner )
213
238
if err != nil {
214
239
log .Printf ("Failed to perform scan: %v" , err )
215
- if err := updateScanStatus (server , token , scan .ID , "error" ); err != nil {
240
+ if err := updateScanStatus (ctx , server , token , scan .ID , "error" ); err != nil {
216
241
log .Printf ("Failed to update scan status to 'error': %v" , err )
217
242
}
218
243
return
219
244
}
220
245
221
246
// Step 5: Upload the scan report
222
- if err := uploadReport (server , token , scan .ID , reportPath ); err != nil {
247
+ if err := uploadReport (ctx , server , token , scan .ID , reportPath ); err != nil {
223
248
log .Printf ("Failed to upload scan report: %v" , err )
224
- if err := updateScanStatus (server , token , scan .ID , "error" ); err != nil {
249
+ if err := updateScanStatus (ctx , server , token , scan .ID , "error" ); err != nil {
225
250
log .Printf ("Failed to update scan status to 'error': %v" , err )
226
251
}
227
252
return
@@ -231,24 +256,24 @@ func processTask(server, token string, task agenttasks.GetAgentTaskResponse, wor
231
256
// Todo: if exists update the status to notify_pending else complete
232
257
if scan .Notify == nil || len (* scan .Notify ) == 0 {
233
258
log .Printf ("No notify specified for scan ID: %s" , scan .ID )
234
- if err := updateScanStatus (server , token , scan .ID , "success" ); err != nil {
259
+ if err := updateScanStatus (ctx , server , token , scan .ID , "success" ); err != nil {
235
260
log .Printf ("Failed to update scan status to 'success': %v" , err )
236
261
}
237
262
} else {
238
- err = updateScanStatus (server , token , scan .ID , "notify_pending" )
263
+ err = updateScanStatus (ctx , server , token , scan .ID , "notify_pending" )
239
264
if err != nil {
240
265
log .Printf ("Failed to update scan status to 'scan_in_progress': %v" , err )
241
266
}
242
267
}
243
268
244
- if err := updateAgentTaskStatus (server , token , task .ID , agentID , "complete" ); err != nil {
269
+ if err := updateAgentTaskStatus (ctx , server , token , task .ID , agentID , "complete" ); err != nil {
245
270
log .Printf ("Failed to update agent_task status to 'complete': %v" , err )
246
271
}
247
272
248
273
log .Printf ("AgentTaskID completed successfully: %v" , task .ID )
249
274
}
250
275
251
- func updateScanStatus (server , token string , scanID uuid.UUID , status string ) error {
276
+ func updateScanStatus (ctx context. Context , server , token string , scanID uuid.UUID , status string ) error {
252
277
reqBody := scans.UpdateScanRequest {
253
278
ID : scanID ,
254
279
Status : strPtr (status ),
@@ -259,10 +284,9 @@ func updateScanStatus(server, token string, scanID uuid.UUID, status string) err
259
284
if err != nil {
260
285
return fmt .Errorf ("failed to create scan status update request: %w" , err )
261
286
}
262
- req .Header .Set ("Authorization" , "Bearer " + token )
263
287
req .Header .Set ("Content-Type" , "application/json" )
264
288
265
- resp , err := http . DefaultClient .Do (req )
289
+ resp , err := ctx . Value ( httpClientKey {}).( * client. Client ) .Do (req )
266
290
if err != nil {
267
291
return fmt .Errorf ("failed to update scan status: %w" , err )
268
292
}
@@ -275,14 +299,13 @@ func updateScanStatus(server, token string, scanID uuid.UUID, status string) err
275
299
return nil
276
300
}
277
301
278
- func fetchScan (server , token string , scanID uuid.UUID ) (* scans.GetScanResponse , error ) {
302
+ func fetchScan (ctx context. Context , server , token string , scanID uuid.UUID ) (* scans.GetScanResponse , error ) {
279
303
req , err := http .NewRequest ("GET" , fmt .Sprintf ("%s/api/v1/scans/%s" , server , scanID ), nil )
280
304
if err != nil {
281
305
return nil , fmt .Errorf ("failed to create scan request: %w" , err )
282
306
}
283
- req .Header .Set ("Authorization" , "Bearer " + token )
284
307
285
- resp , err := http . DefaultClient .Do (req )
308
+ resp , err := ctx . Value ( httpClientKey {}).( * client. Client ) .Do (req )
286
309
if err != nil {
287
310
return nil , fmt .Errorf ("failed to fetch scan details: %w" , err )
288
311
}
@@ -300,14 +323,13 @@ func fetchScan(server, token string, scanID uuid.UUID) (*scans.GetScanResponse,
300
323
return & scan , nil
301
324
}
302
325
303
- func fetchIntegration (server , token string , integrationID uuid.UUID ) (* integrations.GetIntegrationResponse , error ) {
326
+ func fetchIntegration (ctx context. Context , server , token string , integrationID uuid.UUID ) (* integrations.GetIntegrationResponse , error ) {
304
327
req , err := http .NewRequest ("GET" , fmt .Sprintf ("%s/api/v1/integrations/%s" , server , integrationID ), nil )
305
328
if err != nil {
306
329
return nil , fmt .Errorf ("failed to create integration fetch request: %w" , err )
307
330
}
308
- req .Header .Set ("Authorization" , "Bearer " + token )
309
331
310
- resp , err := http . DefaultClient .Do (req )
332
+ resp , err := ctx . Value ( httpClientKey {}).( * client. Client ) .Do (req )
311
333
if err != nil {
312
334
return nil , fmt .Errorf ("failed to fetch integration details: %w" , err )
313
335
}
@@ -363,7 +385,7 @@ func performScan(image, scannerType string) (string, error) {
363
385
return reportPath , nil
364
386
}
365
387
366
- func uploadReport (server , token string , scanID uuid.UUID , reportPath string ) error {
388
+ func uploadReport (ctx context. Context , server , token string , scanID uuid.UUID , reportPath string ) error {
367
389
report , err := os .ReadFile (reportPath )
368
390
if err != nil {
369
391
return fmt .Errorf ("failed to read report file: %w" , err )
@@ -383,10 +405,9 @@ func uploadReport(server, token string, scanID uuid.UUID, reportPath string) err
383
405
return fmt .Errorf ("failed to create report upload request: %w" , err )
384
406
}
385
407
386
- req .Header .Set ("Authorization" , "Bearer " + token )
387
408
req .Header .Set ("Content-Type" , "application/json" )
388
409
389
- resp , err := http . DefaultClient .Do (req )
410
+ resp , err := ctx . Value ( httpClientKey {}).( * client. Client ) .Do (req )
390
411
if err != nil {
391
412
return fmt .Errorf ("failed to upload scan report: %w" , err )
392
413
}
0 commit comments