diff --git a/cmd/server.go b/cmd/server.go index 26bc067..c4da66d 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -2,6 +2,7 @@ package cmd import ( "fmt" + "github.com/shinobistack/gokakashi/internal/assigner" "github.com/shinobistack/gokakashi/internal/db" "log" "os" @@ -103,7 +104,9 @@ func handleConfigV1() { // Populate the database db.PopulateDatabase(configDB, cfg) - // log.Println("Shutting down goKakashi gracefully...") + // ToDo: To be go routine who independently and routinely checks and assigns scans in agentTasks table + go assigner.StartAssigner(cfg.Site.Host, cfg.Site.Port, cfg.Site.APIToken, 1*time.Minute) + } func handleConfigV0() { diff --git a/internal/assigner/assigner.go b/internal/assigner/assigner.go new file mode 100644 index 0000000..cf1544a --- /dev/null +++ b/internal/assigner/assigner.go @@ -0,0 +1,202 @@ +package assigner + +import ( + "bytes" + "encoding/json" + "fmt" + "github.com/google/uuid" + "github.com/shinobistack/gokakashi/internal/restapi/v1/agents" + "github.com/shinobistack/gokakashi/internal/restapi/v1/agenttasks" + "github.com/shinobistack/gokakashi/internal/restapi/v1/scans" + "log" + "net/http" + "net/url" + "strings" + "time" +) + +func normalizeServer(server string) string { + if !strings.HasPrefix(server, "http://") && !strings.HasPrefix(server, "https://") { + server = "http://" + server // Default to HTTP + } + return server +} + +func constructURL(server string, port int, path string) string { + base := normalizeServer(server) + u, err := url.Parse(base) + if err != nil { + log.Fatalf("Invalid server URL: %s", base) + } + if u.Port() == "" { + u.Host = fmt.Sprintf("%s:%d", u.Host, port) + } + u.Path = path + return u.String() +} + +func StartAssigner(server string, port int, token string, interval time.Duration) { + log.Println("Starting the periodic task assigner...") + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for range ticker.C { + AssignTasks(server, port, token) + } + +} + +func AssignTasks(server string, port int, token string) { + log.Println("Assigner now begins assigning your scans") + // Step 1: Fetch scans needing assignment + pendingScans, err := fetchPendingScans(server, port, token, "scan_pending") + if err != nil { + log.Printf("Error fetching pending scans: %v", err) + return + } + + if len(pendingScans) == 0 { + log.Println("No pending scans to assign.") + return + } + + // Step 2: Fetch available agents + availableAgents, err := fetchAvailableAgents(server, port, token, "connected") + if err != nil { + log.Printf("Error fetching available agents: %v", err) + return + } + + if len(availableAgents) == 0 { + log.Println("No agents available for assignment.") + return + } + + // log.Printf("Agents are available: %v", availableAgents) + + // Step 3: Assign scans to agents + // ToDo: to explore task assignment for better efficiency + for i, scan := range pendingScans { + // Check if scan is already assigned + if isScanAssigned(server, port, token, scan.ID) { + log.Printf("Scan ID %s is already assigned. Skipping.", scan.ID) + continue + } + + // Select agent using round-robin + agent := availableAgents[i%len(availableAgents)] + if err := createAgentTask(server, port, token, agent.ID, scan.ID); err != nil { + log.Printf("Failed to assign scan %s to agent %d: %v", scan.ID, agent.ID, err) + } else { + log.Printf("Successfully assigned scan %s to agent %d", scan.ID, agent.ID) + } + + } +} + +func fetchPendingScans(server string, port int, token, status string) ([]scans.GetScanResponse, error) { + url := constructURL(server, port, "/api/v1/scans") + fmt.Sprintf("?status=%s", status) + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request for pending scans: %w", err) + } + req.Header.Set("Authorization", "Bearer "+token) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("error making request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("server responded with status: %d", resp.StatusCode) + } + + var scans []scans.GetScanResponse + if err := json.NewDecoder(resp.Body).Decode(&scans); err != nil { + return nil, fmt.Errorf("failed to decode scans response: %w", err) + } + + return scans, nil +} + +func fetchAvailableAgents(server string, port int, token, status string) ([]agents.GetAgentResponse, error) { + url := constructURL(server, port, "/api/v1/agents") + fmt.Sprintf("?status=%s", status) + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+token) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("server responded with status: %d", resp.StatusCode) + } + + var agents []agents.GetAgentResponse + if err := json.NewDecoder(resp.Body).Decode(&agents); err != nil { + return nil, err + } + + return agents, nil +} + +func isScanAssigned(server string, port int, token string, scanID uuid.UUID) bool { + url := constructURL(server, port, "/api/v1/agents/tasks") + fmt.Sprintf("?scan_id=%s", scanID) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + log.Printf("Error checking scan assignment: %v", err) + return false + } + + req.Header.Set("Authorization", "Bearer "+token) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + log.Printf("Error checking scan assignment: %v", err) + return false + } + defer resp.Body.Close() + + return resp.StatusCode == http.StatusOK +} + +func createAgentTask(server string, port int, token string, agentID int, scanID uuid.UUID) error { + url := constructURL(server, port, fmt.Sprintf("/api/v1/agents/%d/tasks", agentID)) + + reqBody := agenttasks.CreateAgentTaskRequest{ + AgentID: agentID, + ScanID: scanID, + Status: "pending", + CreatedAt: time.Now(), + } + + reqBodyJSON, _ := json.Marshal(reqBody) + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(reqBodyJSON)) + if err != nil { + return err + } + + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + return fmt.Errorf("server responded with status: %d", resp.StatusCode) + } + + return nil +} diff --git a/internal/assigner/assigner_test.go b/internal/assigner/assigner_test.go new file mode 100644 index 0000000..3d26c52 --- /dev/null +++ b/internal/assigner/assigner_test.go @@ -0,0 +1,129 @@ +package assigner_test + +import ( + "encoding/json" + "github.com/google/uuid" + "github.com/shinobistack/gokakashi/internal/assigner" + "net/http" + "net/http/httptest" + "testing" +) + +type MockScan struct { + ID uuid.UUID `json:"id"` + Status string `json:"status"` +} + +type MockAgent struct { + ID int `json:"id"` + Status string `json:"status"` +} + +func TestAssignTasks(t *testing.T) { + // Mock data + mockScans := []MockScan{ + {ID: uuid.New(), Status: "scan_pending"}, + {ID: uuid.New(), Status: "scan_pending"}, + {ID: uuid.New(), Status: "scan_pending"}, + {ID: uuid.New(), Status: "scan_pending"}, + {ID: uuid.New(), Status: "scan_pending"}, + {ID: uuid.New(), Status: "scan_pending"}, + {ID: uuid.New(), Status: "scan_pending"}, + {ID: uuid.New(), Status: "scan_pending"}, + } + mockAgents := []MockAgent{ + {ID: 1, Status: "connected"}, + {ID: 2, Status: "connected"}, + } + + // Mock server + scanHandler := func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/v1/scans" && r.URL.Query().Get("status") == "scan_pending" { + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(mockScans); err != nil { + http.Error(w, "Failed to encode mock scans", http.StatusInternalServerError) + return + } + } else if r.URL.Path == "/api/v1/agents" && r.URL.Query().Get("status") == "connected" { + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(mockAgents); err != nil { + http.Error(w, "Failed to encode mock agents", http.StatusInternalServerError) + return + } + } else if r.URL.Path == "/api/v1/agents/tasks" { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusNotFound) + } + } + + mockServer := httptest.NewServer(http.HandlerFunc(scanHandler)) + defer mockServer.Close() + + // Run the assigner logic + assigner.AssignTasks(mockServer.URL, 0, "mock-token") + + t.Log("Ensure tasks are assigned to agents in round-robin fashion.") +} + +func TestAssignTasksWithNoAgents(t *testing.T) { + mockScans := []MockScan{ + {ID: uuid.New(), Status: "scan_pending"}, + } + + scanHandler := func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/v1/scans" && r.URL.Query().Get("status") == "scan_pending" { + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(mockScans); err != nil { + http.Error(w, "Failed to encode mock scans", http.StatusInternalServerError) + return + } + } else if r.URL.Path == "/api/v1/agents" { + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode([]MockAgent{}); err != nil { + http.Error(w, "Failed to encode mock agents", http.StatusInternalServerError) + return + } // No agents available + } else { + w.WriteHeader(http.StatusNotFound) + } + } + + mockServer := httptest.NewServer(http.HandlerFunc(scanHandler)) + defer mockServer.Close() + + assigner.AssignTasks(mockServer.URL, 0, "mock-token") + + t.Log("Ensure no assignments are made when no agents are available.") +} + +func TestAssignTasksWithNoScans(t *testing.T) { + mockAgents := []MockAgent{ + {ID: 1, Status: "connected"}, + } + + scanHandler := func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/v1/scans" { + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode([]MockScan{}); err != nil { + http.Error(w, "Failed to encode mock scans", http.StatusInternalServerError) + return + } // No scans available + } else if r.URL.Path == "/api/v1/agents" && r.URL.Query().Get("status") == "connected" { + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(mockAgents); err != nil { + http.Error(w, "Failed to encode mock agents", http.StatusInternalServerError) + return + } + } else { + w.WriteHeader(http.StatusNotFound) + } + } + + mockServer := httptest.NewServer(http.HandlerFunc(scanHandler)) + defer mockServer.Close() + + assigner.AssignTasks(mockServer.URL, 0, "mock-token") + + t.Log("Ensure no assignments are made when no scans are pending.") +} diff --git a/internal/restapi/v1/agents/get.go b/internal/restapi/v1/agents/get.go index abc6369..61f8516 100644 --- a/internal/restapi/v1/agents/get.go +++ b/internal/restapi/v1/agents/get.go @@ -4,6 +4,7 @@ import ( "context" "errors" "github.com/shinobistack/gokakashi/ent" + "github.com/shinobistack/gokakashi/ent/agents" "github.com/swaggest/usecase/status" "time" ) @@ -24,6 +25,7 @@ type ListAgentsResponse struct { } type PollAgentsRequest struct { + Status string `query:"status"` } type PollAgentsResponse struct { @@ -78,7 +80,13 @@ func GetAgent(client *ent.Client) func(ctx context.Context, req GetAgentRequest, func PollAgents(client *ent.Client) func(ctx context.Context, req PollAgentsRequest, res *[]PollAgentsResponse) error { return func(ctx context.Context, req PollAgentsRequest, res *[]PollAgentsResponse) error { - agentsList, err := client.Agents.Query().All(ctx) + query := client.Agents.Query() + + if req.Status != "" { + query = query.Where(agents.Status(req.Status)) + } + + agentsList, err := query.All(ctx) if err != nil { return status.Wrap(err, status.Internal) } diff --git a/internal/restapi/v1/agenttasks/get.go b/internal/restapi/v1/agenttasks/get.go index d4085aa..138ed86 100644 --- a/internal/restapi/v1/agenttasks/get.go +++ b/internal/restapi/v1/agenttasks/get.go @@ -23,8 +23,19 @@ type GetAgentTaskResponse struct { CreatedAt time.Time `json:"created_at"` } type ListAgentTasksRequest struct { - AgentID int `path:"agent_id"` - Status string `query:"status"` + AgentID *int `path:"agent_id"` + ScanID *uuid.UUID `query:"scan_id"` + Status string `query:"status"` +} + +type ListAgentTasksQueryRequest struct { + AgentID *int `query:"agent_id"` + ScanID *uuid.UUID `query:"scan_id"` + Status string `query:"status"` +} + +type ListAgentTasksByScanIDRequest struct { + ScanID uuid.UUID `query:"scan_id"` } func GetAgentTask(client *ent.Client) func(ctx context.Context, req GetAgentTaskRequest, res *GetAgentTaskResponse) error { @@ -52,13 +63,13 @@ func GetAgentTask(client *ent.Client) func(ctx context.Context, req GetAgentTask func ListAgentTasksByAgentID(client *ent.Client) func(ctx context.Context, req ListAgentTasksRequest, res *[]GetAgentTaskResponse) error { return func(ctx context.Context, req ListAgentTasksRequest, res *[]GetAgentTaskResponse) error { - if req.AgentID <= 0 { + if req.AgentID == nil || *req.AgentID <= 0 { return status.Wrap(errors.New("invalid agent ID"), status.InvalidArgument) } // Query builder query := client.AgentTasks.Query(). - Where(agenttasks.AgentID(req.AgentID)). + Where(agenttasks.AgentID(*req.AgentID)). Order(ent.Asc(agenttasks.FieldCreatedAt)) // Order by created_at ASC // Filter by status if provided @@ -84,3 +95,43 @@ func ListAgentTasksByAgentID(client *ent.Client) func(ctx context.Context, req L return nil } } + +func ListAgentTasks(client *ent.Client) func(ctx context.Context, req ListAgentTasksQueryRequest, res *[]GetAgentTaskResponse) error { + return func(ctx context.Context, req ListAgentTasksQueryRequest, res *[]GetAgentTaskResponse) error { + query := client.AgentTasks.Query() + + // Filter by agent ID if provided + if req.AgentID != nil { + query = query.Where(agenttasks.AgentID(*req.AgentID)) + } + + // Filter by scan ID if provided + if req.ScanID != nil && *req.ScanID != uuid.Nil { + query = query.Where(agenttasks.ScanID(*req.ScanID)) + } + + // Filter by status if provided + if req.Status != "" { + query = query.Where(agenttasks.Status(req.Status)) + } + + // Execute query + tasks, err := query.Order(ent.Asc(agenttasks.FieldCreatedAt)).All(ctx) + if err != nil { + return status.Wrap(err, status.Internal) + } + + // Populate response + *res = make([]GetAgentTaskResponse, len(tasks)) + for i, task := range tasks { + (*res)[i] = GetAgentTaskResponse{ + ID: task.ID, + AgentID: task.AgentID, + ScanID: task.ScanID, + Status: task.Status, + CreatedAt: task.CreatedAt, + } + } + return nil + } +} diff --git a/internal/restapi/v1/agenttasks/get_test.go b/internal/restapi/v1/agenttasks/get_test.go index d815d62..ab37574 100644 --- a/internal/restapi/v1/agenttasks/get_test.go +++ b/internal/restapi/v1/agenttasks/get_test.go @@ -111,7 +111,7 @@ func TestListAgentTasks_Valid(t *testing.T) { SetStatus("pending"). SaveX(context.Background()) - req := agenttasks.ListAgentTasksRequest{agent.ID, ""} + req := agenttasks.ListAgentTasksRequest{AgentID: intPtr(agent.ID), Status: ""} res := []agenttasks.GetAgentTaskResponse{} err := agenttasks.ListAgentTasksByAgentID(client)(context.Background(), req, &res) @@ -163,7 +163,7 @@ func TestListAgentTasks_OrderedByCreatedAt(t *testing.T) { SetCreatedAt(time.Now()). // Newer task SaveX(context.Background()) - req := agenttasks.ListAgentTasksRequest{AgentID: agent.ID} + req := agenttasks.ListAgentTasksRequest{AgentID: intPtr(agent.ID)} res := []agenttasks.GetAgentTaskResponse{} err := agenttasks.ListAgentTasksByAgentID(client)(context.Background(), req, &res) diff --git a/internal/restapi/v1/scans/get.go b/internal/restapi/v1/scans/get.go index cc8876f..af012d7 100644 --- a/internal/restapi/v1/scans/get.go +++ b/internal/restapi/v1/scans/get.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/google/uuid" "github.com/shinobistack/gokakashi/ent" + "github.com/shinobistack/gokakashi/ent/scans" "github.com/shinobistack/gokakashi/ent/schema" "github.com/swaggest/usecase/status" ) @@ -22,7 +23,9 @@ type GetScanResponse struct { Report json.RawMessage `json:"report,omitempty"` } -type ListScanRequest struct{} +type ListScanRequest struct { + Status string `query:"status"` +} type GetScanRequest struct { ID uuid.UUID `path:"id"` @@ -30,13 +33,19 @@ type GetScanRequest struct { func ListScans(client *ent.Client) func(ctx context.Context, req ListScanRequest, res *[]GetScanResponse) error { return func(ctx context.Context, req ListScanRequest, res *[]GetScanResponse) error { - scans, err := client.Scans.Query().All(ctx) + query := client.Scans.Query() + + if req.Status != "" { + query = query.Where(scans.Status(req.Status)) + } + + scanResults, err := query.All(ctx) if err != nil { return status.Wrap(errors.New("failed to fetch scan details"), status.Internal) } - *res = make([]GetScanResponse, len(scans)) - for i, scan := range scans { + *res = make([]GetScanResponse, len(scanResults)) + for i, scan := range scanResults { (*res)[i] = GetScanResponse{ ID: scan.ID, PolicyID: scan.PolicyID, diff --git a/internal/restapi/v1/scans/get_test.go b/internal/restapi/v1/scans/get_test.go index 1a3743c..1ba64f7 100644 --- a/internal/restapi/v1/scans/get_test.go +++ b/internal/restapi/v1/scans/get_test.go @@ -43,7 +43,7 @@ func TestListScans_Valid(t *testing.T) { SetIntegrationID(integrations.ID). SaveX(context.Background()) - req := scans.ListScanRequest{} + req := scans.ListScanRequest{""} res := []scans.GetScanResponse{} err := scans.ListScans(client)(context.Background(), req, &res) diff --git a/internal/restapi/v1/server.go b/internal/restapi/v1/server.go index 7ee25fb..36225b5 100644 --- a/internal/restapi/v1/server.go +++ b/internal/restapi/v1/server.go @@ -92,6 +92,7 @@ func (srv *Server) Service() *web.Service { apiV1.Delete("/agents/{id}", usecase.NewInteractor(agents1.DeleteAgent(srv.DB))) apiV1.Post("/agents/{agent_id}/tasks", usecase.NewInteractor(agenttasks1.CreateAgentTask(srv.DB))) + apiV1.Get("/agents/tasks", usecase.NewInteractor(agenttasks1.ListAgentTasks(srv.DB))) apiV1.Get("/agents/{agent_id}/tasks", usecase.NewInteractor(agenttasks1.ListAgentTasksByAgentID(srv.DB))) apiV1.Get("/agents/{agent_id}/tasks/{id}", usecase.NewInteractor(agenttasks1.GetAgentTask(srv.DB))) apiV1.Put("/agents/{agent_id}/tasks/{id}", usecase.NewInteractor(agenttasks1.UpdateAgentTask(srv.DB)))