Skip to content

Commit 10ffe11

Browse files
populate CF Access headers if env is set (#111)
* create internal http/client package * use http client in cmd package * pass context everywhere * mod download instead of tidy * lint * Update agent.go
1 parent a066920 commit 10ffe11

File tree

5 files changed

+156
-44
lines changed

5 files changed

+156
-44
lines changed

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ WORKDIR /app
2929
COPY go.mod go.sum ./
3030

3131
# Download dependencies
32-
RUN go mod tidy
32+
RUN go mod download
3333

3434
# Copy the source code
3535
COPY . .

cmd/agent.go

Lines changed: 64 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,27 @@ package cmd
66

77
import (
88
"bytes"
9+
"context"
910
"encoding/json"
1011
"fmt"
12+
"log"
13+
"net/http"
14+
"os"
15+
"time"
16+
1117
"github.com/google/uuid"
18+
"github.com/shinobistack/gokakashi/internal/http/client"
1219
"github.com/shinobistack/gokakashi/internal/restapi/v1/agents"
1320
"github.com/shinobistack/gokakashi/internal/restapi/v1/agenttasks"
1421
"github.com/shinobistack/gokakashi/internal/restapi/v1/integrations"
1522
"github.com/shinobistack/gokakashi/internal/restapi/v1/scans"
1623
"github.com/shinobistack/gokakashi/pkg/registry/v1"
1724
"github.com/shinobistack/gokakashi/pkg/scanner/v1"
1825
"github.com/spf13/cobra"
19-
"log"
20-
"net/http"
21-
"os"
22-
"time"
2326
)
2427

28+
type httpClientKey struct{}
29+
2530
var agentCmd = &cobra.Command{
2631
Use: "agent",
2732
Short: "Manage agents for GoKakashi",
@@ -30,7 +35,31 @@ var agentCmd = &cobra.Command{
3035
var agentStartCmd = &cobra.Command{
3136
Use: "start",
3237
Short: "Register an agent and start polling for tasks",
33-
Run: agentRegister,
38+
PersistentPreRun: func(cmd *cobra.Command, args []string) {
39+
if token == "" {
40+
log.Fatalf("Error: missing required flag --token")
41+
}
42+
headers := make(map[string]string)
43+
cfClientID := os.Getenv("CF_ACCESS_CLIENT_ID")
44+
cfClientSecret := os.Getenv("CF_ACCESS_CLIENT_SECRET")
45+
if cfClientID != "" && cfClientSecret != "" {
46+
headers["CF-Access-Client-Id"] = cfClientID
47+
headers["CF-Access-Client-Secret"] = cfClientSecret
48+
} else if cfClientSecret != "" {
49+
fmt.Println("Warning: ignoring CF_ACCESS_CLIENT_SECRET because CF_ACCESS_CLIENT_ID is not set")
50+
} else if cfClientID != "" {
51+
fmt.Println("Warning: ignoring CF_ACCESS_CLIENT_ID because CF_ACCESS_CLIENT_SECRET is not set")
52+
}
53+
54+
httpClient := client.New(
55+
client.WithToken(token),
56+
client.WithHeaders(headers),
57+
)
58+
59+
ctx := context.WithValue(context.Background(), httpClientKey{}, httpClient)
60+
cmd.SetContext(ctx)
61+
},
62+
Run: agentRegister,
3463
}
3564

3665
var (
@@ -51,18 +80,18 @@ func agentRegister(cmd *cobra.Command, args []string) {
5180
// log.Printf("Server: %s, Token: %s, Workspace: %s", server, token, workspace)
5281

5382
// Register the agent
54-
agentID, err := registerAgent(server, token, workspace, name)
83+
agentID, err := registerAgent(cmd.Context(), server, token, workspace, name)
5584
if err != nil {
5685
log.Fatalf("Failed to register the agent: %v", err)
5786
}
5887

5988
log.Printf("Agent registered successfully! Agent ID: %d", agentID)
6089

6190
// Start polling for tasks
62-
pollTasks(server, token, agentID, workspace)
91+
pollTasks(cmd.Context(), server, token, agentID, workspace)
6392
}
6493

65-
func registerAgent(server, token, workspace, name string) (int, error) {
94+
func registerAgent(ctx context.Context, server, token, workspace, name string) (int, error) {
6695
reqBody := agents.RegisterAgentRequest{
6796
Server: server,
6897
Token: token,
@@ -75,10 +104,9 @@ func registerAgent(server, token, workspace, name string) (int, error) {
75104
if err != nil {
76105
return 0, fmt.Errorf("failed to create registration request: %w", err)
77106
}
78-
req.Header.Set("Authorization", "Bearer "+token)
79107
req.Header.Set("Content-Type", "application/json")
80108

81-
resp, err := http.DefaultClient.Do(req)
109+
resp, err := ctx.Value(httpClientKey{}).(*client.Client).Do(req)
82110
if err != nil {
83111
return 0, fmt.Errorf("failed to send registration request: %w", err)
84112
}
@@ -95,10 +123,10 @@ func registerAgent(server, token, workspace, name string) (int, error) {
95123
return response.ID, nil
96124
}
97125

98-
func pollTasks(server, token string, agentID int, workspace string) {
126+
func pollTasks(ctx context.Context, server, token string, agentID int, workspace string) {
99127
for {
100128
// Process only tasks with status "pending" in the order returned (created_at ASC)
101-
tasks, err := fetchTasks(server, token, agentID, "pending")
129+
tasks, err := fetchTasks(ctx, server, token, agentID, "pending")
102130
if err != nil {
103131
log.Printf("Error fetching tasks: %v", err)
104132
time.Sleep(10 * time.Second)
@@ -113,12 +141,12 @@ func pollTasks(server, token string, agentID int, workspace string) {
113141

114142
for _, task := range tasks {
115143
// Update task status to "in_progress"
116-
err := updateAgentTaskStatus(server, token, task.ID, agentID, "in_progress")
144+
err := updateAgentTaskStatus(ctx, server, token, task.ID, agentID, "in_progress")
117145
if err != nil {
118146
log.Printf("Failed to update agent_task status to 'in_progress': %v", err)
119147
return
120148
}
121-
processTask(server, token, task, workspace, agentID)
149+
processTask(ctx, server, token, task, workspace, agentID)
122150
continue
123151
}
124152
// Todo: Polling interval time decide
@@ -127,15 +155,13 @@ func pollTasks(server, token string, agentID int, workspace string) {
127155
}
128156
}
129157

130-
func fetchTasks(server, token string, agentID int, status string) ([]agenttasks.GetAgentTaskResponse, error) {
158+
func fetchTasks(ctx context.Context, server, token string, agentID int, status string) ([]agenttasks.GetAgentTaskResponse, error) {
131159
req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/v1/agents/%d/tasks?status=%s", server, agentID, status), nil)
132160
if err != nil {
133161
return nil, fmt.Errorf("failed to create task polling request: %w", err)
134162
}
135163

136-
req.Header.Set("Authorization", "Bearer "+token)
137-
138-
resp, err := http.DefaultClient.Do(req)
164+
resp, err := ctx.Value(httpClientKey{}).(*client.Client).Do(req)
139165
if err != nil {
140166
return nil, fmt.Errorf("failed to send task polling request: %w", err)
141167
}
@@ -153,7 +179,7 @@ func fetchTasks(server, token string, agentID int, status string) ([]agenttasks.
153179
return tasks, nil
154180
}
155181

156-
func updateAgentTaskStatus(server, token string, taskID uuid.UUID, agentID int, status string) error {
182+
func updateAgentTaskStatus(ctx context.Context, server, token string, taskID uuid.UUID, agentID int, status string) error {
157183
reqBody := agenttasks.UpdateAgentTaskRequest{
158184
ID: taskID,
159185
AgentID: intPtr(agentID),
@@ -166,10 +192,9 @@ func updateAgentTaskStatus(server, token string, taskID uuid.UUID, agentID int,
166192
if err != nil {
167193
return fmt.Errorf("failed to create task status update request: %w", err)
168194
}
169-
req.Header.Set("Authorization", "Bearer "+token)
170195
req.Header.Set("Content-Type", "application/json")
171196

172-
resp, err := http.DefaultClient.Do(req)
197+
resp, err := ctx.Value(httpClientKey{}).(*client.Client).Do(req)
173198
if err != nil {
174199
return fmt.Errorf("failed to update task status: %w", err)
175200
}
@@ -182,16 +207,16 @@ func updateAgentTaskStatus(server, token string, taskID uuid.UUID, agentID int,
182207
return nil
183208
}
184209

185-
func processTask(server, token string, task agenttasks.GetAgentTaskResponse, workspace string, agentID int) {
210+
func processTask(ctx context.Context, server, token string, task agenttasks.GetAgentTaskResponse, workspace string, agentID int) {
186211
// Step 1: Fetch scan details
187-
scan, err := fetchScan(server, token, task.ScanID)
212+
scan, err := fetchScan(ctx, server, token, task.ScanID)
188213
if err != nil {
189214
log.Printf("Failed to fetch scan details: %v", err)
190215
return
191216
}
192217

193218
// Step 2: Fetch integration details
194-
integration, err := fetchIntegration(server, token, scan.IntegrationID)
219+
integration, err := fetchIntegration(ctx, server, token, scan.IntegrationID)
195220
if err != nil {
196221
log.Printf("Failed to fetch integration details: %v", err)
197222
return
@@ -203,7 +228,7 @@ func processTask(server, token string, task agenttasks.GetAgentTaskResponse, wor
203228
return
204229
}
205230

206-
err = updateScanStatus(server, token, scan.ID, "scan_in_progress")
231+
err = updateScanStatus(ctx, server, token, scan.ID, "scan_in_progress")
207232
if err != nil {
208233
log.Printf("Failed to update scan status to 'scan_in_progress': %v", err)
209234
}
@@ -212,16 +237,16 @@ func processTask(server, token string, task agenttasks.GetAgentTaskResponse, wor
212237
reportPath, err := performScan(scan.Image, scan.Scanner)
213238
if err != nil {
214239
log.Printf("Failed to perform scan: %v", err)
215-
if err := updateScanStatus(server, token, scan.ID, "error"); err != nil {
240+
if err := updateScanStatus(ctx, server, token, scan.ID, "error"); err != nil {
216241
log.Printf("Failed to update scan status to 'error': %v", err)
217242
}
218243
return
219244
}
220245

221246
// Step 5: Upload the scan report
222-
if err := uploadReport(server, token, scan.ID, reportPath); err != nil {
247+
if err := uploadReport(ctx, server, token, scan.ID, reportPath); err != nil {
223248
log.Printf("Failed to upload scan report: %v", err)
224-
if err := updateScanStatus(server, token, scan.ID, "error"); err != nil {
249+
if err := updateScanStatus(ctx, server, token, scan.ID, "error"); err != nil {
225250
log.Printf("Failed to update scan status to 'error': %v", err)
226251
}
227252
return
@@ -231,24 +256,24 @@ func processTask(server, token string, task agenttasks.GetAgentTaskResponse, wor
231256
// Todo: if exists update the status to notify_pending else complete
232257
if scan.Notify == nil || len(*scan.Notify) == 0 {
233258
log.Printf("No notify specified for scan ID: %s", scan.ID)
234-
if err := updateScanStatus(server, token, scan.ID, "success"); err != nil {
259+
if err := updateScanStatus(ctx, server, token, scan.ID, "success"); err != nil {
235260
log.Printf("Failed to update scan status to 'success': %v", err)
236261
}
237262
} else {
238-
err = updateScanStatus(server, token, scan.ID, "notify_pending")
263+
err = updateScanStatus(ctx, server, token, scan.ID, "notify_pending")
239264
if err != nil {
240265
log.Printf("Failed to update scan status to 'scan_in_progress': %v", err)
241266
}
242267
}
243268

244-
if err := updateAgentTaskStatus(server, token, task.ID, agentID, "complete"); err != nil {
269+
if err := updateAgentTaskStatus(ctx, server, token, task.ID, agentID, "complete"); err != nil {
245270
log.Printf("Failed to update agent_task status to 'complete': %v", err)
246271
}
247272

248273
log.Printf("AgentTaskID completed successfully: %v", task.ID)
249274
}
250275

251-
func updateScanStatus(server, token string, scanID uuid.UUID, status string) error {
276+
func updateScanStatus(ctx context.Context, server, token string, scanID uuid.UUID, status string) error {
252277
reqBody := scans.UpdateScanRequest{
253278
ID: scanID,
254279
Status: strPtr(status),
@@ -259,10 +284,9 @@ func updateScanStatus(server, token string, scanID uuid.UUID, status string) err
259284
if err != nil {
260285
return fmt.Errorf("failed to create scan status update request: %w", err)
261286
}
262-
req.Header.Set("Authorization", "Bearer "+token)
263287
req.Header.Set("Content-Type", "application/json")
264288

265-
resp, err := http.DefaultClient.Do(req)
289+
resp, err := ctx.Value(httpClientKey{}).(*client.Client).Do(req)
266290
if err != nil {
267291
return fmt.Errorf("failed to update scan status: %w", err)
268292
}
@@ -275,14 +299,13 @@ func updateScanStatus(server, token string, scanID uuid.UUID, status string) err
275299
return nil
276300
}
277301

278-
func fetchScan(server, token string, scanID uuid.UUID) (*scans.GetScanResponse, error) {
302+
func fetchScan(ctx context.Context, server, token string, scanID uuid.UUID) (*scans.GetScanResponse, error) {
279303
req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/v1/scans/%s", server, scanID), nil)
280304
if err != nil {
281305
return nil, fmt.Errorf("failed to create scan request: %w", err)
282306
}
283-
req.Header.Set("Authorization", "Bearer "+token)
284307

285-
resp, err := http.DefaultClient.Do(req)
308+
resp, err := ctx.Value(httpClientKey{}).(*client.Client).Do(req)
286309
if err != nil {
287310
return nil, fmt.Errorf("failed to fetch scan details: %w", err)
288311
}
@@ -300,14 +323,13 @@ func fetchScan(server, token string, scanID uuid.UUID) (*scans.GetScanResponse,
300323
return &scan, nil
301324
}
302325

303-
func fetchIntegration(server, token string, integrationID uuid.UUID) (*integrations.GetIntegrationResponse, error) {
326+
func fetchIntegration(ctx context.Context, server, token string, integrationID uuid.UUID) (*integrations.GetIntegrationResponse, error) {
304327
req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/v1/integrations/%s", server, integrationID), nil)
305328
if err != nil {
306329
return nil, fmt.Errorf("failed to create integration fetch request: %w", err)
307330
}
308-
req.Header.Set("Authorization", "Bearer "+token)
309331

310-
resp, err := http.DefaultClient.Do(req)
332+
resp, err := ctx.Value(httpClientKey{}).(*client.Client).Do(req)
311333
if err != nil {
312334
return nil, fmt.Errorf("failed to fetch integration details: %w", err)
313335
}
@@ -363,7 +385,7 @@ func performScan(image, scannerType string) (string, error) {
363385
return reportPath, nil
364386
}
365387

366-
func uploadReport(server, token string, scanID uuid.UUID, reportPath string) error {
388+
func uploadReport(ctx context.Context, server, token string, scanID uuid.UUID, reportPath string) error {
367389
report, err := os.ReadFile(reportPath)
368390
if err != nil {
369391
return fmt.Errorf("failed to read report file: %w", err)
@@ -383,10 +405,9 @@ func uploadReport(server, token string, scanID uuid.UUID, reportPath string) err
383405
return fmt.Errorf("failed to create report upload request: %w", err)
384406
}
385407

386-
req.Header.Set("Authorization", "Bearer "+token)
387408
req.Header.Set("Content-Type", "application/json")
388409

389-
resp, err := http.DefaultClient.Do(req)
410+
resp, err := ctx.Value(httpClientKey{}).(*client.Client).Do(req)
390411
if err != nil {
391412
return fmt.Errorf("failed to upload scan report: %w", err)
392413
}

internal/http/client/client.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package client
2+
3+
import (
4+
"fmt"
5+
"net/http"
6+
)
7+
8+
type Client struct {
9+
client *http.Client
10+
token string
11+
headers map[string]string
12+
}
13+
14+
func New(opts ...Option) *Client {
15+
c := &Client{
16+
client: http.DefaultClient,
17+
}
18+
19+
for _, o := range opts {
20+
o(c)
21+
}
22+
23+
return c
24+
}
25+
26+
func (c *Client) Do(req *http.Request) (*http.Response, error) {
27+
if c.token != "" {
28+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.token))
29+
}
30+
31+
for k, v := range c.headers {
32+
req.Header.Set(k, v)
33+
}
34+
35+
return c.client.Do(req)
36+
}

internal/http/client/client_opts.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package client
2+
3+
import "net/http"
4+
5+
type Option func(*Client)
6+
7+
func WithHTTPClient(client *http.Client) Option {
8+
return func(c *Client) {
9+
c.client = client
10+
}
11+
}
12+
13+
func WithToken(token string) Option {
14+
return func(c *Client) {
15+
c.token = token
16+
}
17+
}
18+
19+
func WithHeaders(headers map[string]string) Option {
20+
return func(c *Client) {
21+
c.headers = headers
22+
}
23+
}

0 commit comments

Comments
 (0)