diff --git a/ent/agents_create.go b/ent/agents_create.go index 1f528e0..c7df18a 100644 --- a/ent/agents_create.go +++ b/ent/agents_create.go @@ -9,6 +9,7 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" + "github.com/google/uuid" "github.com/shinobistack/gokakashi/ent/agents" "github.com/shinobistack/gokakashi/ent/agenttasks" ) @@ -41,14 +42,14 @@ func (ac *AgentsCreate) SetID(i int) *AgentsCreate { } // AddAgentTaskIDs adds the "agent_tasks" edge to the AgentTasks entity by IDs. -func (ac *AgentsCreate) AddAgentTaskIDs(ids ...int) *AgentsCreate { +func (ac *AgentsCreate) AddAgentTaskIDs(ids ...uuid.UUID) *AgentsCreate { ac.mutation.AddAgentTaskIDs(ids...) return ac } // AddAgentTasks adds the "agent_tasks" edges to the AgentTasks entity. func (ac *AgentsCreate) AddAgentTasks(a ...*AgentTasks) *AgentsCreate { - ids := make([]int, len(a)) + ids := make([]uuid.UUID, len(a)) for i := range a { ids[i] = a[i].ID } @@ -145,7 +146,7 @@ func (ac *AgentsCreate) createSpec() (*Agents, *sqlgraph.CreateSpec) { Columns: []string{agents.AgentTasksColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeInt), + IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeUUID), }, } for _, k := range nodes { diff --git a/ent/agents_update.go b/ent/agents_update.go index 2094b8b..9499f26 100644 --- a/ent/agents_update.go +++ b/ent/agents_update.go @@ -10,6 +10,7 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" + "github.com/google/uuid" "github.com/shinobistack/gokakashi/ent/agents" "github.com/shinobistack/gokakashi/ent/agenttasks" "github.com/shinobistack/gokakashi/ent/predicate" @@ -43,14 +44,14 @@ func (au *AgentsUpdate) SetNillableStatus(s *string) *AgentsUpdate { } // AddAgentTaskIDs adds the "agent_tasks" edge to the AgentTasks entity by IDs. -func (au *AgentsUpdate) AddAgentTaskIDs(ids ...int) *AgentsUpdate { +func (au *AgentsUpdate) AddAgentTaskIDs(ids ...uuid.UUID) *AgentsUpdate { au.mutation.AddAgentTaskIDs(ids...) return au } // AddAgentTasks adds the "agent_tasks" edges to the AgentTasks entity. func (au *AgentsUpdate) AddAgentTasks(a ...*AgentTasks) *AgentsUpdate { - ids := make([]int, len(a)) + ids := make([]uuid.UUID, len(a)) for i := range a { ids[i] = a[i].ID } @@ -69,14 +70,14 @@ func (au *AgentsUpdate) ClearAgentTasks() *AgentsUpdate { } // RemoveAgentTaskIDs removes the "agent_tasks" edge to AgentTasks entities by IDs. -func (au *AgentsUpdate) RemoveAgentTaskIDs(ids ...int) *AgentsUpdate { +func (au *AgentsUpdate) RemoveAgentTaskIDs(ids ...uuid.UUID) *AgentsUpdate { au.mutation.RemoveAgentTaskIDs(ids...) return au } // RemoveAgentTasks removes "agent_tasks" edges to AgentTasks entities. func (au *AgentsUpdate) RemoveAgentTasks(a ...*AgentTasks) *AgentsUpdate { - ids := make([]int, len(a)) + ids := make([]uuid.UUID, len(a)) for i := range a { ids[i] = a[i].ID } @@ -130,7 +131,7 @@ func (au *AgentsUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{agents.AgentTasksColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeInt), + IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeUUID), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -143,7 +144,7 @@ func (au *AgentsUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{agents.AgentTasksColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeInt), + IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeUUID), }, } for _, k := range nodes { @@ -159,7 +160,7 @@ func (au *AgentsUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{agents.AgentTasksColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeInt), + IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeUUID), }, } for _, k := range nodes { @@ -202,14 +203,14 @@ func (auo *AgentsUpdateOne) SetNillableStatus(s *string) *AgentsUpdateOne { } // AddAgentTaskIDs adds the "agent_tasks" edge to the AgentTasks entity by IDs. -func (auo *AgentsUpdateOne) AddAgentTaskIDs(ids ...int) *AgentsUpdateOne { +func (auo *AgentsUpdateOne) AddAgentTaskIDs(ids ...uuid.UUID) *AgentsUpdateOne { auo.mutation.AddAgentTaskIDs(ids...) return auo } // AddAgentTasks adds the "agent_tasks" edges to the AgentTasks entity. func (auo *AgentsUpdateOne) AddAgentTasks(a ...*AgentTasks) *AgentsUpdateOne { - ids := make([]int, len(a)) + ids := make([]uuid.UUID, len(a)) for i := range a { ids[i] = a[i].ID } @@ -228,14 +229,14 @@ func (auo *AgentsUpdateOne) ClearAgentTasks() *AgentsUpdateOne { } // RemoveAgentTaskIDs removes the "agent_tasks" edge to AgentTasks entities by IDs. -func (auo *AgentsUpdateOne) RemoveAgentTaskIDs(ids ...int) *AgentsUpdateOne { +func (auo *AgentsUpdateOne) RemoveAgentTaskIDs(ids ...uuid.UUID) *AgentsUpdateOne { auo.mutation.RemoveAgentTaskIDs(ids...) return auo } // RemoveAgentTasks removes "agent_tasks" edges to AgentTasks entities. func (auo *AgentsUpdateOne) RemoveAgentTasks(a ...*AgentTasks) *AgentsUpdateOne { - ids := make([]int, len(a)) + ids := make([]uuid.UUID, len(a)) for i := range a { ids[i] = a[i].ID } @@ -319,7 +320,7 @@ func (auo *AgentsUpdateOne) sqlSave(ctx context.Context) (_node *Agents, err err Columns: []string{agents.AgentTasksColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeInt), + IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeUUID), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -332,7 +333,7 @@ func (auo *AgentsUpdateOne) sqlSave(ctx context.Context) (_node *Agents, err err Columns: []string{agents.AgentTasksColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeInt), + IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeUUID), }, } for _, k := range nodes { @@ -348,7 +349,7 @@ func (auo *AgentsUpdateOne) sqlSave(ctx context.Context) (_node *Agents, err err Columns: []string{agents.AgentTasksColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeInt), + IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeUUID), }, } for _, k := range nodes { diff --git a/ent/agenttasks.go b/ent/agenttasks.go index c2d7c5b..0a61247 100644 --- a/ent/agenttasks.go +++ b/ent/agenttasks.go @@ -20,7 +20,7 @@ type AgentTasks struct { config `json:"-"` // ID of the ent. // Primary key, unique identifier. - ID int `json:"id,omitempty"` + ID uuid.UUID `json:"id,omitempty"` // Foreign key to Agents.ID. AgentID int `json:"agent_id,omitempty"` // Foreign key to Scans.ID. @@ -73,13 +73,13 @@ func (*AgentTasks) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case agenttasks.FieldID, agenttasks.FieldAgentID: + case agenttasks.FieldAgentID: values[i] = new(sql.NullInt64) case agenttasks.FieldStatus: values[i] = new(sql.NullString) case agenttasks.FieldCreatedAt: values[i] = new(sql.NullTime) - case agenttasks.FieldScanID: + case agenttasks.FieldID, agenttasks.FieldScanID: values[i] = new(uuid.UUID) default: values[i] = new(sql.UnknownType) @@ -97,11 +97,11 @@ func (at *AgentTasks) assignValues(columns []string, values []any) error { for i := range columns { switch columns[i] { case agenttasks.FieldID: - value, ok := values[i].(*sql.NullInt64) - if !ok { - return fmt.Errorf("unexpected type %T for field id", value) + if value, ok := values[i].(*uuid.UUID); !ok { + return fmt.Errorf("unexpected type %T for field id", values[i]) + } else if value != nil { + at.ID = *value } - at.ID = int(value.Int64) case agenttasks.FieldAgentID: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for field agent_id", values[i]) diff --git a/ent/agenttasks/agenttasks.go b/ent/agenttasks/agenttasks.go index 4c1b4fe..c2d3551 100644 --- a/ent/agenttasks/agenttasks.go +++ b/ent/agenttasks/agenttasks.go @@ -7,6 +7,7 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/google/uuid" ) const ( @@ -68,6 +69,8 @@ var ( DefaultStatus string // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time + // DefaultID holds the default value on creation for the "id" field. + DefaultID func() uuid.UUID ) // OrderOption defines the ordering options for the AgentTasks queries. diff --git a/ent/agenttasks/where.go b/ent/agenttasks/where.go index 847f4e0..65aa2c7 100644 --- a/ent/agenttasks/where.go +++ b/ent/agenttasks/where.go @@ -12,47 +12,47 @@ import ( ) // ID filters vertices based on their ID field. -func ID(id int) predicate.AgentTasks { +func ID(id uuid.UUID) predicate.AgentTasks { return predicate.AgentTasks(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. -func IDEQ(id int) predicate.AgentTasks { +func IDEQ(id uuid.UUID) predicate.AgentTasks { return predicate.AgentTasks(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. -func IDNEQ(id int) predicate.AgentTasks { +func IDNEQ(id uuid.UUID) predicate.AgentTasks { return predicate.AgentTasks(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. -func IDIn(ids ...int) predicate.AgentTasks { +func IDIn(ids ...uuid.UUID) predicate.AgentTasks { return predicate.AgentTasks(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. -func IDNotIn(ids ...int) predicate.AgentTasks { +func IDNotIn(ids ...uuid.UUID) predicate.AgentTasks { return predicate.AgentTasks(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. -func IDGT(id int) predicate.AgentTasks { +func IDGT(id uuid.UUID) predicate.AgentTasks { return predicate.AgentTasks(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. -func IDGTE(id int) predicate.AgentTasks { +func IDGTE(id uuid.UUID) predicate.AgentTasks { return predicate.AgentTasks(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. -func IDLT(id int) predicate.AgentTasks { +func IDLT(id uuid.UUID) predicate.AgentTasks { return predicate.AgentTasks(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. -func IDLTE(id int) predicate.AgentTasks { +func IDLTE(id uuid.UUID) predicate.AgentTasks { return predicate.AgentTasks(sql.FieldLTE(FieldID, id)) } diff --git a/ent/agenttasks_create.go b/ent/agenttasks_create.go index a35f118..e5c5f00 100644 --- a/ent/agenttasks_create.go +++ b/ent/agenttasks_create.go @@ -64,8 +64,16 @@ func (atc *AgentTasksCreate) SetNillableCreatedAt(t *time.Time) *AgentTasksCreat } // SetID sets the "id" field. -func (atc *AgentTasksCreate) SetID(i int) *AgentTasksCreate { - atc.mutation.SetID(i) +func (atc *AgentTasksCreate) SetID(u uuid.UUID) *AgentTasksCreate { + atc.mutation.SetID(u) + return atc +} + +// SetNillableID sets the "id" field if the given value is not nil. +func (atc *AgentTasksCreate) SetNillableID(u *uuid.UUID) *AgentTasksCreate { + if u != nil { + atc.SetID(*u) + } return atc } @@ -122,6 +130,10 @@ func (atc *AgentTasksCreate) defaults() { v := agenttasks.DefaultCreatedAt() atc.mutation.SetCreatedAt(v) } + if _, ok := atc.mutation.ID(); !ok { + v := agenttasks.DefaultID() + atc.mutation.SetID(v) + } } // check runs all checks and user-defined validators on the builder. @@ -158,9 +170,12 @@ func (atc *AgentTasksCreate) sqlSave(ctx context.Context) (*AgentTasks, error) { } return nil, err } - if _spec.ID.Value != _node.ID { - id := _spec.ID.Value.(int64) - _node.ID = int(id) + if _spec.ID.Value != nil { + if id, ok := _spec.ID.Value.(*uuid.UUID); ok { + _node.ID = *id + } else if err := _node.ID.Scan(_spec.ID.Value); err != nil { + return nil, err + } } atc.mutation.id = &_node.ID atc.mutation.done = true @@ -170,11 +185,11 @@ func (atc *AgentTasksCreate) sqlSave(ctx context.Context) (*AgentTasks, error) { func (atc *AgentTasksCreate) createSpec() (*AgentTasks, *sqlgraph.CreateSpec) { var ( _node = &AgentTasks{config: atc.config} - _spec = sqlgraph.NewCreateSpec(agenttasks.Table, sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeInt)) + _spec = sqlgraph.NewCreateSpec(agenttasks.Table, sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeUUID)) ) if id, ok := atc.mutation.ID(); ok { _node.ID = id - _spec.ID.Value = id + _spec.ID.Value = &id } if value, ok := atc.mutation.Status(); ok { _spec.SetField(agenttasks.FieldStatus, field.TypeString, value) @@ -266,10 +281,6 @@ func (atcb *AgentTasksCreateBulk) Save(ctx context.Context) ([]*AgentTasks, erro return nil, err } mutation.id = &nodes[i].ID - if specs[i].ID.Value != nil && nodes[i].ID == 0 { - id := specs[i].ID.Value.(int64) - nodes[i].ID = int(id) - } mutation.done = true return nodes[i], nil }) diff --git a/ent/agenttasks_delete.go b/ent/agenttasks_delete.go index 740ff0a..f925752 100644 --- a/ent/agenttasks_delete.go +++ b/ent/agenttasks_delete.go @@ -40,7 +40,7 @@ func (atd *AgentTasksDelete) ExecX(ctx context.Context) int { } func (atd *AgentTasksDelete) sqlExec(ctx context.Context) (int, error) { - _spec := sqlgraph.NewDeleteSpec(agenttasks.Table, sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeInt)) + _spec := sqlgraph.NewDeleteSpec(agenttasks.Table, sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeUUID)) if ps := atd.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { diff --git a/ent/agenttasks_query.go b/ent/agenttasks_query.go index 8ecdecd..2e6f57f 100644 --- a/ent/agenttasks_query.go +++ b/ent/agenttasks_query.go @@ -131,8 +131,8 @@ func (atq *AgentTasksQuery) FirstX(ctx context.Context) *AgentTasks { // FirstID returns the first AgentTasks ID from the query. // Returns a *NotFoundError when no AgentTasks ID was found. -func (atq *AgentTasksQuery) FirstID(ctx context.Context) (id int, err error) { - var ids []int +func (atq *AgentTasksQuery) FirstID(ctx context.Context) (id uuid.UUID, err error) { + var ids []uuid.UUID if ids, err = atq.Limit(1).IDs(setContextOp(ctx, atq.ctx, ent.OpQueryFirstID)); err != nil { return } @@ -144,7 +144,7 @@ func (atq *AgentTasksQuery) FirstID(ctx context.Context) (id int, err error) { } // FirstIDX is like FirstID, but panics if an error occurs. -func (atq *AgentTasksQuery) FirstIDX(ctx context.Context) int { +func (atq *AgentTasksQuery) FirstIDX(ctx context.Context) uuid.UUID { id, err := atq.FirstID(ctx) if err != nil && !IsNotFound(err) { panic(err) @@ -182,8 +182,8 @@ func (atq *AgentTasksQuery) OnlyX(ctx context.Context) *AgentTasks { // OnlyID is like Only, but returns the only AgentTasks ID in the query. // Returns a *NotSingularError when more than one AgentTasks ID is found. // Returns a *NotFoundError when no entities are found. -func (atq *AgentTasksQuery) OnlyID(ctx context.Context) (id int, err error) { - var ids []int +func (atq *AgentTasksQuery) OnlyID(ctx context.Context) (id uuid.UUID, err error) { + var ids []uuid.UUID if ids, err = atq.Limit(2).IDs(setContextOp(ctx, atq.ctx, ent.OpQueryOnlyID)); err != nil { return } @@ -199,7 +199,7 @@ func (atq *AgentTasksQuery) OnlyID(ctx context.Context) (id int, err error) { } // OnlyIDX is like OnlyID, but panics if an error occurs. -func (atq *AgentTasksQuery) OnlyIDX(ctx context.Context) int { +func (atq *AgentTasksQuery) OnlyIDX(ctx context.Context) uuid.UUID { id, err := atq.OnlyID(ctx) if err != nil { panic(err) @@ -227,7 +227,7 @@ func (atq *AgentTasksQuery) AllX(ctx context.Context) []*AgentTasks { } // IDs executes the query and returns a list of AgentTasks IDs. -func (atq *AgentTasksQuery) IDs(ctx context.Context) (ids []int, err error) { +func (atq *AgentTasksQuery) IDs(ctx context.Context) (ids []uuid.UUID, err error) { if atq.ctx.Unique == nil && atq.path != nil { atq.Unique(true) } @@ -239,7 +239,7 @@ func (atq *AgentTasksQuery) IDs(ctx context.Context) (ids []int, err error) { } // IDsX is like IDs, but panics if an error occurs. -func (atq *AgentTasksQuery) IDsX(ctx context.Context) []int { +func (atq *AgentTasksQuery) IDsX(ctx context.Context) []uuid.UUID { ids, err := atq.IDs(ctx) if err != nil { panic(err) @@ -514,7 +514,7 @@ func (atq *AgentTasksQuery) sqlCount(ctx context.Context) (int, error) { } func (atq *AgentTasksQuery) querySpec() *sqlgraph.QuerySpec { - _spec := sqlgraph.NewQuerySpec(agenttasks.Table, agenttasks.Columns, sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeInt)) + _spec := sqlgraph.NewQuerySpec(agenttasks.Table, agenttasks.Columns, sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeUUID)) _spec.From = atq.sql if unique := atq.ctx.Unique; unique != nil { _spec.Unique = *unique diff --git a/ent/agenttasks_update.go b/ent/agenttasks_update.go index c2637f2..912a47a 100644 --- a/ent/agenttasks_update.go +++ b/ent/agenttasks_update.go @@ -141,7 +141,7 @@ func (atu *AgentTasksUpdate) sqlSave(ctx context.Context) (n int, err error) { if err := atu.check(); err != nil { return n, err } - _spec := sqlgraph.NewUpdateSpec(agenttasks.Table, agenttasks.Columns, sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeInt)) + _spec := sqlgraph.NewUpdateSpec(agenttasks.Table, agenttasks.Columns, sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeUUID)) if ps := atu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -354,7 +354,7 @@ func (atuo *AgentTasksUpdateOne) sqlSave(ctx context.Context) (_node *AgentTasks if err := atuo.check(); err != nil { return _node, err } - _spec := sqlgraph.NewUpdateSpec(agenttasks.Table, agenttasks.Columns, sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeInt)) + _spec := sqlgraph.NewUpdateSpec(agenttasks.Table, agenttasks.Columns, sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeUUID)) id, ok := atuo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AgentTasks.id" for update`)} diff --git a/ent/client.go b/ent/client.go index 8e3a044..48fb702 100644 --- a/ent/client.go +++ b/ent/client.go @@ -325,7 +325,7 @@ func (c *AgentTasksClient) UpdateOne(at *AgentTasks) *AgentTasksUpdateOne { } // UpdateOneID returns an update builder for the given id. -func (c *AgentTasksClient) UpdateOneID(id int) *AgentTasksUpdateOne { +func (c *AgentTasksClient) UpdateOneID(id uuid.UUID) *AgentTasksUpdateOne { mutation := newAgentTasksMutation(c.config, OpUpdateOne, withAgentTasksID(id)) return &AgentTasksUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} } @@ -342,7 +342,7 @@ func (c *AgentTasksClient) DeleteOne(at *AgentTasks) *AgentTasksDeleteOne { } // DeleteOneID returns a builder for deleting the given entity by its id. -func (c *AgentTasksClient) DeleteOneID(id int) *AgentTasksDeleteOne { +func (c *AgentTasksClient) DeleteOneID(id uuid.UUID) *AgentTasksDeleteOne { builder := c.Delete().Where(agenttasks.ID(id)) builder.mutation.id = &id builder.mutation.op = OpDeleteOne @@ -359,12 +359,12 @@ func (c *AgentTasksClient) Query() *AgentTasksQuery { } // Get returns a AgentTasks entity by its id. -func (c *AgentTasksClient) Get(ctx context.Context, id int) (*AgentTasks, error) { +func (c *AgentTasksClient) Get(ctx context.Context, id uuid.UUID) (*AgentTasks, error) { return c.Query().Where(agenttasks.ID(id)).Only(ctx) } // GetX is like Get, but panics if an error occurs. -func (c *AgentTasksClient) GetX(ctx context.Context, id int) *AgentTasks { +func (c *AgentTasksClient) GetX(ctx context.Context, id uuid.UUID) *AgentTasks { obj, err := c.Get(ctx, id) if err != nil { panic(err) diff --git a/ent/migrate/schema.go b/ent/migrate/schema.go index f76cf35..da1c44a 100644 --- a/ent/migrate/schema.go +++ b/ent/migrate/schema.go @@ -10,7 +10,7 @@ import ( var ( // AgentTasksColumns holds the columns for the "agent_tasks" table. AgentTasksColumns = []*schema.Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "id", Type: field.TypeUUID, Unique: true}, {Name: "status", Type: field.TypeString, Default: "pending"}, {Name: "created_at", Type: field.TypeTime}, {Name: "agent_id", Type: field.TypeInt}, diff --git a/ent/mutation.go b/ent/mutation.go index 69331b6..7910624 100644 --- a/ent/mutation.go +++ b/ent/mutation.go @@ -48,7 +48,7 @@ type AgentTasksMutation struct { config op Op typ string - id *int + id *uuid.UUID status *string created_at *time.Time clearedFields map[string]struct{} @@ -81,7 +81,7 @@ func newAgentTasksMutation(c config, op Op, opts ...agenttasksOption) *AgentTask } // withAgentTasksID sets the ID field of the mutation. -func withAgentTasksID(id int) agenttasksOption { +func withAgentTasksID(id uuid.UUID) agenttasksOption { return func(m *AgentTasksMutation) { var ( err error @@ -133,13 +133,13 @@ func (m AgentTasksMutation) Tx() (*Tx, error) { // SetID sets the value of the id field. Note that this // operation is only accepted on creation of AgentTasks entities. -func (m *AgentTasksMutation) SetID(id int) { +func (m *AgentTasksMutation) SetID(id uuid.UUID) { m.id = &id } // ID returns the ID value in the mutation. Note that the ID is only available // if it was provided to the builder or after it was returned from the database. -func (m *AgentTasksMutation) ID() (id int, exists bool) { +func (m *AgentTasksMutation) ID() (id uuid.UUID, exists bool) { if m.id == nil { return } @@ -150,12 +150,12 @@ func (m *AgentTasksMutation) ID() (id int, exists bool) { // That means, if the mutation is applied within a transaction with an isolation level such // as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated // or updated by the mutation. -func (m *AgentTasksMutation) IDs(ctx context.Context) ([]int, error) { +func (m *AgentTasksMutation) IDs(ctx context.Context) ([]uuid.UUID, error) { switch { case m.op.Is(OpUpdateOne | OpDeleteOne): id, exists := m.ID() if exists { - return []int{id}, nil + return []uuid.UUID{id}, nil } fallthrough case m.op.Is(OpUpdate | OpDelete): @@ -648,8 +648,8 @@ type AgentsMutation struct { id *int status *string clearedFields map[string]struct{} - agent_tasks map[int]struct{} - removedagent_tasks map[int]struct{} + agent_tasks map[uuid.UUID]struct{} + removedagent_tasks map[uuid.UUID]struct{} clearedagent_tasks bool done bool oldValue func(context.Context) (*Agents, error) @@ -797,9 +797,9 @@ func (m *AgentsMutation) ResetStatus() { } // AddAgentTaskIDs adds the "agent_tasks" edge to the AgentTasks entity by ids. -func (m *AgentsMutation) AddAgentTaskIDs(ids ...int) { +func (m *AgentsMutation) AddAgentTaskIDs(ids ...uuid.UUID) { if m.agent_tasks == nil { - m.agent_tasks = make(map[int]struct{}) + m.agent_tasks = make(map[uuid.UUID]struct{}) } for i := range ids { m.agent_tasks[ids[i]] = struct{}{} @@ -817,9 +817,9 @@ func (m *AgentsMutation) AgentTasksCleared() bool { } // RemoveAgentTaskIDs removes the "agent_tasks" edge to the AgentTasks entity by IDs. -func (m *AgentsMutation) RemoveAgentTaskIDs(ids ...int) { +func (m *AgentsMutation) RemoveAgentTaskIDs(ids ...uuid.UUID) { if m.removedagent_tasks == nil { - m.removedagent_tasks = make(map[int]struct{}) + m.removedagent_tasks = make(map[uuid.UUID]struct{}) } for i := range ids { delete(m.agent_tasks, ids[i]) @@ -828,7 +828,7 @@ func (m *AgentsMutation) RemoveAgentTaskIDs(ids ...int) { } // RemovedAgentTasks returns the removed IDs of the "agent_tasks" edge to the AgentTasks entity. -func (m *AgentsMutation) RemovedAgentTasksIDs() (ids []int) { +func (m *AgentsMutation) RemovedAgentTasksIDs() (ids []uuid.UUID) { for id := range m.removedagent_tasks { ids = append(ids, id) } @@ -836,7 +836,7 @@ func (m *AgentsMutation) RemovedAgentTasksIDs() (ids []int) { } // AgentTasksIDs returns the "agent_tasks" edge IDs in the mutation. -func (m *AgentsMutation) AgentTasksIDs() (ids []int) { +func (m *AgentsMutation) AgentTasksIDs() (ids []uuid.UUID) { for id := range m.agent_tasks { ids = append(ids, id) } @@ -3706,8 +3706,8 @@ type ScansMutation struct { scan_labels map[int]struct{} removedscan_labels map[int]struct{} clearedscan_labels bool - agent_tasks map[int]struct{} - removedagent_tasks map[int]struct{} + agent_tasks map[uuid.UUID]struct{} + removedagent_tasks map[uuid.UUID]struct{} clearedagent_tasks bool done bool oldValue func(context.Context) (*Scans, error) @@ -4106,9 +4106,9 @@ func (m *ScansMutation) ResetScanLabels() { } // AddAgentTaskIDs adds the "agent_tasks" edge to the AgentTasks entity by ids. -func (m *ScansMutation) AddAgentTaskIDs(ids ...int) { +func (m *ScansMutation) AddAgentTaskIDs(ids ...uuid.UUID) { if m.agent_tasks == nil { - m.agent_tasks = make(map[int]struct{}) + m.agent_tasks = make(map[uuid.UUID]struct{}) } for i := range ids { m.agent_tasks[ids[i]] = struct{}{} @@ -4126,9 +4126,9 @@ func (m *ScansMutation) AgentTasksCleared() bool { } // RemoveAgentTaskIDs removes the "agent_tasks" edge to the AgentTasks entity by IDs. -func (m *ScansMutation) RemoveAgentTaskIDs(ids ...int) { +func (m *ScansMutation) RemoveAgentTaskIDs(ids ...uuid.UUID) { if m.removedagent_tasks == nil { - m.removedagent_tasks = make(map[int]struct{}) + m.removedagent_tasks = make(map[uuid.UUID]struct{}) } for i := range ids { delete(m.agent_tasks, ids[i]) @@ -4137,7 +4137,7 @@ func (m *ScansMutation) RemoveAgentTaskIDs(ids ...int) { } // RemovedAgentTasks returns the removed IDs of the "agent_tasks" edge to the AgentTasks entity. -func (m *ScansMutation) RemovedAgentTasksIDs() (ids []int) { +func (m *ScansMutation) RemovedAgentTasksIDs() (ids []uuid.UUID) { for id := range m.removedagent_tasks { ids = append(ids, id) } @@ -4145,7 +4145,7 @@ func (m *ScansMutation) RemovedAgentTasksIDs() (ids []int) { } // AgentTasksIDs returns the "agent_tasks" edge IDs in the mutation. -func (m *ScansMutation) AgentTasksIDs() (ids []int) { +func (m *ScansMutation) AgentTasksIDs() (ids []uuid.UUID) { for id := range m.agent_tasks { ids = append(ids, id) } diff --git a/ent/runtime.go b/ent/runtime.go index f1daf13..62a7064 100644 --- a/ent/runtime.go +++ b/ent/runtime.go @@ -31,6 +31,10 @@ func init() { agenttasksDescCreatedAt := agenttasksFields[4].Descriptor() // agenttasks.DefaultCreatedAt holds the default value on creation for the created_at field. agenttasks.DefaultCreatedAt = agenttasksDescCreatedAt.Default.(func() time.Time) + // agenttasksDescID is the schema descriptor for id field. + agenttasksDescID := agenttasksFields[0].Descriptor() + // agenttasks.DefaultID holds the default value on creation for the id field. + agenttasks.DefaultID = agenttasksDescID.Default.(func() uuid.UUID) agentsFields := schema.Agents{}.Fields() _ = agentsFields // agentsDescStatus is the schema descriptor for status field. diff --git a/ent/scans_create.go b/ent/scans_create.go index a030898..cbd2fd5 100644 --- a/ent/scans_create.go +++ b/ent/scans_create.go @@ -113,14 +113,14 @@ func (sc *ScansCreate) AddScanLabels(s ...*ScanLabels) *ScansCreate { } // AddAgentTaskIDs adds the "agent_tasks" edge to the AgentTasks entity by IDs. -func (sc *ScansCreate) AddAgentTaskIDs(ids ...int) *ScansCreate { +func (sc *ScansCreate) AddAgentTaskIDs(ids ...uuid.UUID) *ScansCreate { sc.mutation.AddAgentTaskIDs(ids...) return sc } // AddAgentTasks adds the "agent_tasks" edges to the AgentTasks entity. func (sc *ScansCreate) AddAgentTasks(a ...*AgentTasks) *ScansCreate { - ids := make([]int, len(a)) + ids := make([]uuid.UUID, len(a)) for i := range a { ids[i] = a[i].ID } @@ -278,7 +278,7 @@ func (sc *ScansCreate) createSpec() (*Scans, *sqlgraph.CreateSpec) { Columns: []string{scans.AgentTasksColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeInt), + IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeUUID), }, } for _, k := range nodes { diff --git a/ent/scans_update.go b/ent/scans_update.go index bdc2a95..3e339b8 100644 --- a/ent/scans_update.go +++ b/ent/scans_update.go @@ -135,14 +135,14 @@ func (su *ScansUpdate) AddScanLabels(s ...*ScanLabels) *ScansUpdate { } // AddAgentTaskIDs adds the "agent_tasks" edge to the AgentTasks entity by IDs. -func (su *ScansUpdate) AddAgentTaskIDs(ids ...int) *ScansUpdate { +func (su *ScansUpdate) AddAgentTaskIDs(ids ...uuid.UUID) *ScansUpdate { su.mutation.AddAgentTaskIDs(ids...) return su } // AddAgentTasks adds the "agent_tasks" edges to the AgentTasks entity. func (su *ScansUpdate) AddAgentTasks(a ...*AgentTasks) *ScansUpdate { - ids := make([]int, len(a)) + ids := make([]uuid.UUID, len(a)) for i := range a { ids[i] = a[i].ID } @@ -188,14 +188,14 @@ func (su *ScansUpdate) ClearAgentTasks() *ScansUpdate { } // RemoveAgentTaskIDs removes the "agent_tasks" edge to AgentTasks entities by IDs. -func (su *ScansUpdate) RemoveAgentTaskIDs(ids ...int) *ScansUpdate { +func (su *ScansUpdate) RemoveAgentTaskIDs(ids ...uuid.UUID) *ScansUpdate { su.mutation.RemoveAgentTaskIDs(ids...) return su } // RemoveAgentTasks removes "agent_tasks" edges to AgentTasks entities. func (su *ScansUpdate) RemoveAgentTasks(a ...*AgentTasks) *ScansUpdate { - ids := make([]int, len(a)) + ids := make([]uuid.UUID, len(a)) for i := range a { ids[i] = a[i].ID } @@ -349,7 +349,7 @@ func (su *ScansUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{scans.AgentTasksColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeInt), + IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeUUID), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -362,7 +362,7 @@ func (su *ScansUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{scans.AgentTasksColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeInt), + IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeUUID), }, } for _, k := range nodes { @@ -378,7 +378,7 @@ func (su *ScansUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{scans.AgentTasksColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeInt), + IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeUUID), }, } for _, k := range nodes { @@ -509,14 +509,14 @@ func (suo *ScansUpdateOne) AddScanLabels(s ...*ScanLabels) *ScansUpdateOne { } // AddAgentTaskIDs adds the "agent_tasks" edge to the AgentTasks entity by IDs. -func (suo *ScansUpdateOne) AddAgentTaskIDs(ids ...int) *ScansUpdateOne { +func (suo *ScansUpdateOne) AddAgentTaskIDs(ids ...uuid.UUID) *ScansUpdateOne { suo.mutation.AddAgentTaskIDs(ids...) return suo } // AddAgentTasks adds the "agent_tasks" edges to the AgentTasks entity. func (suo *ScansUpdateOne) AddAgentTasks(a ...*AgentTasks) *ScansUpdateOne { - ids := make([]int, len(a)) + ids := make([]uuid.UUID, len(a)) for i := range a { ids[i] = a[i].ID } @@ -562,14 +562,14 @@ func (suo *ScansUpdateOne) ClearAgentTasks() *ScansUpdateOne { } // RemoveAgentTaskIDs removes the "agent_tasks" edge to AgentTasks entities by IDs. -func (suo *ScansUpdateOne) RemoveAgentTaskIDs(ids ...int) *ScansUpdateOne { +func (suo *ScansUpdateOne) RemoveAgentTaskIDs(ids ...uuid.UUID) *ScansUpdateOne { suo.mutation.RemoveAgentTaskIDs(ids...) return suo } // RemoveAgentTasks removes "agent_tasks" edges to AgentTasks entities. func (suo *ScansUpdateOne) RemoveAgentTasks(a ...*AgentTasks) *ScansUpdateOne { - ids := make([]int, len(a)) + ids := make([]uuid.UUID, len(a)) for i := range a { ids[i] = a[i].ID } @@ -753,7 +753,7 @@ func (suo *ScansUpdateOne) sqlSave(ctx context.Context) (_node *Scans, err error Columns: []string{scans.AgentTasksColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeInt), + IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeUUID), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -766,7 +766,7 @@ func (suo *ScansUpdateOne) sqlSave(ctx context.Context) (_node *Scans, err error Columns: []string{scans.AgentTasksColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeInt), + IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeUUID), }, } for _, k := range nodes { @@ -782,7 +782,7 @@ func (suo *ScansUpdateOne) sqlSave(ctx context.Context) (_node *Scans, err error Columns: []string{scans.AgentTasksColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeInt), + IDSpec: sqlgraph.NewFieldSpec(agenttasks.FieldID, field.TypeUUID), }, } for _, k := range nodes { diff --git a/ent/schema/agenttasks.go b/ent/schema/agenttasks.go index 606239d..2b59ca6 100644 --- a/ent/schema/agenttasks.go +++ b/ent/schema/agenttasks.go @@ -16,7 +16,8 @@ type AgentTasks struct { // Fields of the AgentTasks. func (AgentTasks) Fields() []ent.Field { return []ent.Field{ - field.Int("id"). + field.UUID("id", uuid.UUID{}). + Default(uuid.New). Unique(). Comment("Primary key, unique identifier."), field.Int("agent_id"). diff --git a/internal/restapi/v1/agents/delete_test.go b/internal/restapi/v1/agents/delete_test.go new file mode 100644 index 0000000..e199daf --- /dev/null +++ b/internal/restapi/v1/agents/delete_test.go @@ -0,0 +1,42 @@ +package agents_test + +import ( + "context" + "github.com/shinobistack/gokakashi/internal/restapi/v1/agents" + "testing" + + _ "github.com/mattn/go-sqlite3" + "github.com/shinobistack/gokakashi/ent/enttest" + "github.com/stretchr/testify/assert" +) + +func TestDeleteAgent_Valid(t *testing.T) { + client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") + defer client.Close() + + agent := client.Agents.Create(). + SetStatus("connected"). + SaveX(context.Background()) + + req := agents.DeleteAgentRequest{ID: agent.ID} + res := &agents.DeleteAgentResponse{} + + err := agents.DeleteAgent(client)(context.Background(), req, res) + + assert.NoError(t, err) + assert.Equal(t, agent.ID, res.ID) + assert.Equal(t, "deleted", res.Status) +} + +func TestDeleteAgent_NotFound(t *testing.T) { + client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") + defer client.Close() + + req := agents.DeleteAgentRequest{ID: 9999} // Non-existent ID + res := &agents.DeleteAgentResponse{} + + err := agents.DeleteAgent(client)(context.Background(), req, res) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "agent not found") +} diff --git a/internal/restapi/v1/agents/get.go b/internal/restapi/v1/agents/get.go index 2752544..84fc858 100644 --- a/internal/restapi/v1/agents/get.go +++ b/internal/restapi/v1/agents/get.go @@ -11,6 +11,8 @@ type GetAgentRequest struct { ID int `path:"id"` } +type ListAgentsRequest struct{} + type GetAgentResponse struct { ID int `json:"id"` Status string `json:"status"` @@ -20,8 +22,8 @@ type ListAgentsResponse struct { Agents []GetAgentResponse `json:"agents"` } -func ListAgents(client *ent.Client) func(ctx context.Context, req interface{}, res *[]GetAgentResponse) error { - return func(ctx context.Context, req interface{}, res *[]GetAgentResponse) error { +func ListAgents(client *ent.Client) func(ctx context.Context, req ListAgentsRequest, res *[]GetAgentResponse) error { + return func(ctx context.Context, req ListAgentsRequest, res *[]GetAgentResponse) error { agentsList, err := client.Agents.Query().All(ctx) if err != nil { return status.Wrap(err, status.Internal) diff --git a/internal/restapi/v1/agents/get_test.go b/internal/restapi/v1/agents/get_test.go new file mode 100644 index 0000000..22a9990 --- /dev/null +++ b/internal/restapi/v1/agents/get_test.go @@ -0,0 +1,64 @@ +package agents_test + +import ( + "context" + "github.com/shinobistack/gokakashi/internal/restapi/v1/agents" + "testing" + + _ "github.com/mattn/go-sqlite3" + "github.com/shinobistack/gokakashi/ent/enttest" + "github.com/stretchr/testify/assert" +) + +func TestGetAgent_Valid(t *testing.T) { + client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") + defer client.Close() + + agent := client.Agents.Create(). + SetStatus("connected"). + SaveX(context.Background()) + + req := agents.GetAgentRequest{ID: agent.ID} + res := &agents.GetAgentResponse{} + + err := agents.GetAgent(client)(context.Background(), req, res) + + assert.NoError(t, err) + assert.Equal(t, agent.Status, res.Status) + assert.Equal(t, agent.ID, res.ID) +} + +func TestGetAgent_NotFound(t *testing.T) { + client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") + defer client.Close() + + req := agents.GetAgentRequest{ID: 9999} // Non-existent ID + res := &agents.GetAgentResponse{} + + err := agents.GetAgent(client)(context.Background(), req, res) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "agent not found") +} + +func TestListAgents_Valid(t *testing.T) { + client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") + defer client.Close() + + client.Agents.Create(). + SetStatus("connected"). + SaveX(context.Background()) + + client.Agents.Create(). + SetStatus("in_progress"). + SaveX(context.Background()) + + req := agents.ListAgentsRequest{} + res := []agents.GetAgentResponse{} + + err := agents.ListAgents(client)(context.Background(), req, &res) + + assert.NoError(t, err) + assert.Equal(t, 2, len(res)) + +} diff --git a/internal/restapi/v1/agents/post_test.go b/internal/restapi/v1/agents/post_test.go new file mode 100644 index 0000000..3b2c676 --- /dev/null +++ b/internal/restapi/v1/agents/post_test.go @@ -0,0 +1,43 @@ +package agents_test + +import ( + "context" + "github.com/shinobistack/gokakashi/internal/restapi/v1/agents" + "testing" + + _ "github.com/mattn/go-sqlite3" + "github.com/shinobistack/gokakashi/ent/enttest" + + "github.com/stretchr/testify/assert" +) + +func TestCreateAgent_Valid(t *testing.T) { + client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") + defer client.Close() + + req := agents.CreateAgentRequest{ + Status: "connected", + } + res := &agents.CreateAgentResponse{} + + err := agents.CreateAgent(client)(context.Background(), req, res) + + assert.NoError(t, err) + assert.Equal(t, req.Status, res.Status) + assert.NotZero(t, res.ID) +} + +//func TestCreateAgent_MissingFields(t *testing.T) { +// client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") +// defer client.Close() +// +// req := agents.CreateAgentRequest{ +// Status: "", // Missing Status +// } +// res := &agents.CreateAgentResponse{} +// +// err := agents.CreateAgent(client)(context.Background(), req, res) +// +// assert.Error(t, err) +// assert.Contains(t, err.Error(), "missing required fields") +//} diff --git a/internal/restapi/v1/agents/put_test.go b/internal/restapi/v1/agents/put_test.go new file mode 100644 index 0000000..15194e4 --- /dev/null +++ b/internal/restapi/v1/agents/put_test.go @@ -0,0 +1,47 @@ +package agents_test + +import ( + "context" + "github.com/shinobistack/gokakashi/internal/restapi/v1/agents" + "testing" + + _ "github.com/mattn/go-sqlite3" + "github.com/shinobistack/gokakashi/ent/enttest" + "github.com/stretchr/testify/assert" +) + +func TestUpdateAgent_Valid(t *testing.T) { + client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") + defer client.Close() + + agent := client.Agents.Create(). + SetStatus("connected"). + SaveX(context.Background()) + + req := agents.UpdateAgentRequest{ + ID: agent.ID, + Status: "in_progress", + } + res := &agents.UpdateAgentResponse{} + + err := agents.UpdateAgent(client)(context.Background(), req, res) + + assert.NoError(t, err) + assert.Equal(t, req.Status, res.Status) +} + +func TestUpdateAgent_NotFound(t *testing.T) { + client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") + defer client.Close() + + req := agents.UpdateAgentRequest{ + ID: 9999, // Non-existent ID + Status: "in_progress", + } + res := &agents.UpdateAgentResponse{} + + err := agents.UpdateAgent(client)(context.Background(), req, res) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "agent not found") +} diff --git a/internal/restapi/v1/agenttasks/delete.go b/internal/restapi/v1/agenttasks/delete.go index e9fb8be..8170328 100644 --- a/internal/restapi/v1/agenttasks/delete.go +++ b/internal/restapi/v1/agenttasks/delete.go @@ -3,22 +3,24 @@ package agenttasks import ( "context" "errors" + "github.com/google/uuid" "github.com/shinobistack/gokakashi/ent" "github.com/swaggest/usecase/status" ) type DeleteAgentTaskRequest struct { - ID int `path:"id"` + ID uuid.UUID `path:"id"` + AgentID int `path:"agent_id"` } type DeleteAgentTaskResponse struct { - ID int `json:"id"` - Status string `json:"status"` + ID uuid.UUID `json:"id"` + Status string `json:"status"` } func DeleteAgentTask(client *ent.Client) func(ctx context.Context, req DeleteAgentTaskRequest, res *DeleteAgentTaskResponse) error { return func(ctx context.Context, req DeleteAgentTaskRequest, res *DeleteAgentTaskResponse) error { - if req.ID <= 0 { + if req.ID == uuid.Nil { return status.Wrap(errors.New("invalid ID"), status.InvalidArgument) } diff --git a/internal/restapi/v1/agenttasks/delete_test.go b/internal/restapi/v1/agenttasks/delete_test.go new file mode 100644 index 0000000..02ebb32 --- /dev/null +++ b/internal/restapi/v1/agenttasks/delete_test.go @@ -0,0 +1,61 @@ +package agenttasks_test + +import ( + "context" + "github.com/shinobistack/gokakashi/ent/schema" + "github.com/shinobistack/gokakashi/internal/restapi/v1/agenttasks" + "testing" + + "github.com/google/uuid" + _ "github.com/mattn/go-sqlite3" + "github.com/shinobistack/gokakashi/ent/enttest" + "github.com/stretchr/testify/assert" +) + +func TestDeleteAgentTask_Valid(t *testing.T) { + client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") + defer client.Close() + + policy := client.Policies.Create(). + SetName("to-be-deleted-test-policy"). + SetImage(schema.Image{Registry: "example-registry", Name: "example-name", Tags: []string{"v1.0"}}). + SaveX(context.Background()) + + scan := client.Scans.Create(). + SetPolicyID(policy.ID). + SetImage("example-image:latest"). + SetStatus("scan_pending"). + SaveX(context.Background()) + + agent := client.Agents.Create(). + SetStatus("connected"). + SaveX(context.Background()) + + task := client.AgentTasks.Create(). + SetAgentID(agent.ID). + SetScanID(scan.ID). + SetStatus("pending"). + SaveX(context.Background()) + + req := agenttasks.DeleteAgentTaskRequest{ID: task.ID} + res := &agenttasks.DeleteAgentTaskResponse{} + + err := agenttasks.DeleteAgentTask(client)(context.Background(), req, res) + + assert.NoError(t, err) + assert.Equal(t, task.ID, res.ID) + assert.Equal(t, "deleted", res.Status) +} + +func TestDeleteAgentTask_NotFound(t *testing.T) { + client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") + defer client.Close() + + req := agenttasks.DeleteAgentTaskRequest{ID: uuid.New()} // Non-existent ID + res := &agenttasks.DeleteAgentTaskResponse{} + + err := agenttasks.DeleteAgentTask(client)(context.Background(), req, res) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "agent task not found") +} diff --git a/internal/restapi/v1/agenttasks/get.go b/internal/restapi/v1/agenttasks/get.go index 0d79a1c..fb1c83a 100644 --- a/internal/restapi/v1/agenttasks/get.go +++ b/internal/restapi/v1/agenttasks/get.go @@ -10,23 +10,26 @@ import ( ) type GetAgentTaskRequest struct { - ID int `path:"id"` + ID uuid.UUID `path:"id"` + AgentID int `path:"agent_id"` } type GetAgentTaskResponse struct { - ID int `json:"id"` + ID uuid.UUID `json:"id"` AgentID int `json:"agent_id"` ScanID uuid.UUID `json:"scan_id"` Status string `json:"status"` CreatedAt time.Time `json:"created_at"` } +type ListAgentTasksRequest struct { +} type ListAgentTasksResponse struct { AgentTasks []GetAgentTaskResponse `json:"agent_tasks"` } -func ListAgentTasks(client *ent.Client) func(ctx context.Context, req interface{}, res *[]GetAgentTaskResponse) error { - return func(ctx context.Context, req interface{}, res *[]GetAgentTaskResponse) error { +func ListAgentTasks(client *ent.Client) func(ctx context.Context, req ListAgentTasksRequest, res *[]GetAgentTaskResponse) error { + return func(ctx context.Context, req ListAgentTasksRequest, res *[]GetAgentTaskResponse) error { tasks, err := client.AgentTasks.Query().All(ctx) if err != nil { return status.Wrap(err, status.Internal) @@ -48,7 +51,7 @@ func ListAgentTasks(client *ent.Client) func(ctx context.Context, req interface{ func GetAgentTask(client *ent.Client) func(ctx context.Context, req GetAgentTaskRequest, res *GetAgentTaskResponse) error { return func(ctx context.Context, req GetAgentTaskRequest, res *GetAgentTaskResponse) error { - if req.ID <= 0 { + if req.ID == uuid.Nil { return status.Wrap(errors.New("invalid ID"), status.InvalidArgument) } diff --git a/internal/restapi/v1/agenttasks/get_test.go b/internal/restapi/v1/agenttasks/get_test.go new file mode 100644 index 0000000..5ebffa1 --- /dev/null +++ b/internal/restapi/v1/agenttasks/get_test.go @@ -0,0 +1,103 @@ +package agenttasks_test + +import ( + "context" + "github.com/shinobistack/gokakashi/ent/schema" + "github.com/shinobistack/gokakashi/internal/restapi/v1/agenttasks" + "testing" + + "github.com/google/uuid" + _ "github.com/mattn/go-sqlite3" + "github.com/shinobistack/gokakashi/ent/enttest" + "github.com/stretchr/testify/assert" +) + +func TestGetAgentTask_Valid(t *testing.T) { + client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") + defer client.Close() + + policy := client.Policies.Create(). + SetName("to-be-deleted-test-policy"). + SetImage(schema.Image{Registry: "example-registry", Name: "example-name", Tags: []string{"v1.0"}}). + SaveX(context.Background()) + + scan := client.Scans.Create(). + SetPolicyID(policy.ID). + SetImage("example-image:latest"). + SetStatus("scan_pending"). + SaveX(context.Background()) + + agent := client.Agents.Create(). + SetStatus("connected"). + SaveX(context.Background()) + + task := client.AgentTasks.Create(). + SetAgentID(agent.ID). + SetScanID(scan.ID). + SetStatus("pending"). + SaveX(context.Background()) + + req := agenttasks.GetAgentTaskRequest{ID: task.ID} + res := &agenttasks.GetAgentTaskResponse{} + + err := agenttasks.GetAgentTask(client)(context.Background(), req, res) + + assert.NoError(t, err) + assert.Equal(t, task.Status, res.Status) + assert.Equal(t, task.ID, res.ID) + assert.Equal(t, task.AgentID, res.AgentID) + assert.Equal(t, task.ScanID, res.ScanID) +} + +func TestGetAgentTask_NotFound(t *testing.T) { + client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") + defer client.Close() + + req := agenttasks.GetAgentTaskRequest{ID: uuid.New()} // Non-existent ID + res := &agenttasks.GetAgentTaskResponse{} + + err := agenttasks.GetAgentTask(client)(context.Background(), req, res) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "agent task not found") +} + +func TestListAgentTasks_Valid(t *testing.T) { + client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") + defer client.Close() + + policy := client.Policies.Create(). + SetName("to-be-deleted-test-policy"). + SetImage(schema.Image{Registry: "example-registry", Name: "example-name", Tags: []string{"v1.0"}}). + SaveX(context.Background()) + + scan := client.Scans.Create(). + SetPolicyID(policy.ID). + SetImage("example-image:latest"). + SetStatus("scan_pending"). + SaveX(context.Background()) + + agent := client.Agents.Create(). + SetStatus("connected"). + SaveX(context.Background()) + + client.AgentTasks.Create(). + SetAgentID(agent.ID). + SetScanID(scan.ID). + SetStatus("pending"). + SaveX(context.Background()) + client.AgentTasks.Create(). + SetAgentID(agent.ID). + SetScanID(scan.ID). + SetStatus("pending"). + SaveX(context.Background()) + + req := agenttasks.ListAgentTasksRequest{} + res := []agenttasks.GetAgentTaskResponse{} + + err := agenttasks.ListAgentTasks(client)(context.Background(), req, &res) + + assert.NoError(t, err) + assert.Equal(t, 2, len(res)) + +} diff --git a/internal/restapi/v1/agenttasks/post.go b/internal/restapi/v1/agenttasks/post.go index 784ad1e..9900108 100644 --- a/internal/restapi/v1/agenttasks/post.go +++ b/internal/restapi/v1/agenttasks/post.go @@ -3,22 +3,26 @@ package agenttasks import ( "context" "errors" + "fmt" "github.com/google/uuid" "github.com/shinobistack/gokakashi/ent" + "github.com/shinobistack/gokakashi/ent/agents" + "github.com/shinobistack/gokakashi/ent/agenttasks" + "github.com/shinobistack/gokakashi/ent/scans" "github.com/swaggest/usecase/status" "time" ) type CreateAgentTaskRequest struct { - AgentID int `json:"agent_id"` + AgentID int `path:"agent_id"` ScanID uuid.UUID `json:"scan_id"` Status string `json:"status"` CreatedAt time.Time `json:"created_at,omitempty"` } type CreateAgentTaskResponse struct { - ID int `json:"id"` - Status string `json:"status"` + ID uuid.UUID `json:"id"` + Status string `json:"status"` } func CreateAgentTask(client *ent.Client) func(ctx context.Context, req CreateAgentTaskRequest, res *CreateAgentTaskResponse) error { @@ -27,6 +31,36 @@ func CreateAgentTask(client *ent.Client) func(ctx context.Context, req CreateAge return status.Wrap(errors.New("missing required fields"), status.InvalidArgument) } + // Check if the agent exists + agentExists, err := client.Agents.Query().Where(agents.ID(req.AgentID)).Exist(ctx) + if err != nil { + return status.Wrap(fmt.Errorf("failed to check agent existence: %w", err), status.Internal) + } + if !agentExists { + return status.Wrap(errors.New("agent not found"), status.NotFound) + } + + // Check if the scan exists + scanExists, err := client.Scans.Query().Where(scans.ID(req.ScanID)).Exist(ctx) + if err != nil { + return status.Wrap(fmt.Errorf("failed to check scan existence: %w", err), status.Internal) + } + if !scanExists { + return status.Wrap(errors.New("scan not found"), status.NotFound) + } + + // Ensure the same scan ID isn't already assigned to another task + existingTask, err := client.AgentTasks.Query(). + Where(agenttasks.ScanID(req.ScanID)). + First(ctx) + if err != nil && !ent.IsNotFound(err) { + return status.Wrap(fmt.Errorf("failed to check for existing tasks: %w", err), status.Internal) + } + if existingTask != nil { + return status.Wrap(errors.New("scan ID is already assigned to another agent"), status.InvalidArgument) + } + + // Create the agent task task, err := client.AgentTasks.Create(). SetAgentID(req.AgentID). SetScanID(req.ScanID). @@ -35,9 +69,14 @@ func CreateAgentTask(client *ent.Client) func(ctx context.Context, req CreateAge Save(ctx) if err != nil { - return status.Wrap(err, status.Internal) + // Handle foreign key constraint errors + if ent.IsConstraintError(err) { + return status.Wrap(errors.New("constraint violation: ensure valid agent ID and scan ID"), status.InvalidArgument) + } + return status.Wrap(fmt.Errorf("failed to create agent task: %w", err), status.Internal) } + // Populate the response res.ID = task.ID res.Status = task.Status return nil diff --git a/internal/restapi/v1/agenttasks/post_test.go b/internal/restapi/v1/agenttasks/post_test.go new file mode 100644 index 0000000..194ac88 --- /dev/null +++ b/internal/restapi/v1/agenttasks/post_test.go @@ -0,0 +1,64 @@ +package agenttasks_test + +import ( + "context" + "github.com/shinobistack/gokakashi/ent/schema" + "github.com/shinobistack/gokakashi/internal/restapi/v1/agenttasks" + "testing" + + "github.com/google/uuid" + _ "github.com/mattn/go-sqlite3" + "github.com/shinobistack/gokakashi/ent/enttest" + "github.com/stretchr/testify/assert" +) + +func TestCreateAgentTask_Valid(t *testing.T) { + client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") + defer client.Close() + + policy := client.Policies.Create(). + SetName("to-be-deleted-test-policy"). + SetImage(schema.Image{Registry: "example-registry", Name: "example-name", Tags: []string{"v1.0"}}). + SaveX(context.Background()) + + scan := client.Scans.Create(). + SetPolicyID(policy.ID). + SetImage("example-image:latest"). + SetStatus("scan_pending"). + SaveX(context.Background()) + + agent := client.Agents.Create(). + SetStatus("connected"). + SaveX(context.Background()) + + // Create test request + req := agenttasks.CreateAgentTaskRequest{ + AgentID: agent.ID, + ScanID: scan.ID, + Status: "pending", + } + res := &agenttasks.CreateAgentTaskResponse{} + + err := agenttasks.CreateAgentTask(client)(context.Background(), req, res) + + assert.NoError(t, err) + assert.Equal(t, req.Status, res.Status) + assert.NotZero(t, res.ID) +} + +func TestCreateAgentTask_MissingFields(t *testing.T) { + client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") + defer client.Close() + + req := agenttasks.CreateAgentTaskRequest{ + AgentID: 0, // Missing AgentID + ScanID: uuid.Nil, + Status: "", + } + res := &agenttasks.CreateAgentTaskResponse{} + + err := agenttasks.CreateAgentTask(client)(context.Background(), req, res) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing required fields") +} diff --git a/internal/restapi/v1/agenttasks/put.go b/internal/restapi/v1/agenttasks/put.go index 082598d..3d249fe 100644 --- a/internal/restapi/v1/agenttasks/put.go +++ b/internal/restapi/v1/agenttasks/put.go @@ -3,29 +3,34 @@ package agenttasks import ( "context" "errors" + "github.com/google/uuid" "github.com/shinobistack/gokakashi/ent" + "github.com/shinobistack/gokakashi/ent/agenttasks" "github.com/swaggest/usecase/status" ) type UpdateAgentTaskRequest struct { - ID int `path:"id"` - Status string `json:"status"` + ID uuid.UUID `path:"id"` + AgentID int `path:"agent_id"` + Status string `json:"status"` + // Todo: Should the created_AT be updated to time.now whenever an update call is made? } type UpdateAgentTaskResponse struct { - ID int `json:"id"` - Status string `json:"status"` + ID uuid.UUID `json:"id"` + Status string `json:"status"` } func UpdateAgentTask(client *ent.Client) func(ctx context.Context, req UpdateAgentTaskRequest, res *UpdateAgentTaskResponse) error { return func(ctx context.Context, req UpdateAgentTaskRequest, res *UpdateAgentTaskResponse) error { - if req.ID <= 0 || req.Status == "" { + if req.ID == uuid.Nil || req.Status == "" { return status.Wrap(errors.New("invalid ID or Status"), status.InvalidArgument) } - task, err := client.AgentTasks.UpdateOneID(req.ID). - SetStatus(req.Status). - Save(ctx) + // Fetch the task and validate the AgentID + task, err := client.AgentTasks.Query(). + Where(agenttasks.ID(req.ID)). + Only(ctx) if err != nil { if ent.IsNotFound(err) { return status.Wrap(errors.New("agent task not found"), status.NotFound) @@ -33,8 +38,20 @@ func UpdateAgentTask(client *ent.Client) func(ctx context.Context, req UpdateAge return status.Wrap(err, status.Internal) } + if task.AgentID != req.AgentID { + return status.Wrap(errors.New("agent ID mismatch"), status.InvalidArgument) + } + + task, err = client.AgentTasks.UpdateOneID(req.ID). + SetStatus(req.Status). + Save(ctx) + if err != nil { + return status.Wrap(err, status.Internal) + } + res.ID = task.ID res.Status = task.Status return nil + } } diff --git a/internal/restapi/v1/agenttasks/put_test.go b/internal/restapi/v1/agenttasks/put_test.go new file mode 100644 index 0000000..3eb9bac --- /dev/null +++ b/internal/restapi/v1/agenttasks/put_test.go @@ -0,0 +1,67 @@ +package agenttasks_test + +import ( + "context" + "github.com/shinobistack/gokakashi/ent/schema" + "github.com/shinobistack/gokakashi/internal/restapi/v1/agenttasks" + "testing" + + "github.com/google/uuid" + _ "github.com/mattn/go-sqlite3" + "github.com/shinobistack/gokakashi/ent/enttest" + "github.com/stretchr/testify/assert" +) + +func TestUpdateAgentTask_Valid(t *testing.T) { + client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") + defer client.Close() + + policy := client.Policies.Create(). + SetName("to-be-deleted-test-policy"). + SetImage(schema.Image{Registry: "example-registry", Name: "example-name", Tags: []string{"v1.0"}}). + SaveX(context.Background()) + + scan := client.Scans.Create(). + SetPolicyID(policy.ID). + SetImage("example-image:latest"). + SetStatus("scan_pending"). + SaveX(context.Background()) + + agent := client.Agents.Create(). + SetStatus("connected"). + SaveX(context.Background()) + + task := client.AgentTasks.Create(). + SetAgentID(agent.ID). + SetScanID(scan.ID). + SetStatus("pending"). + SaveX(context.Background()) + + req := agenttasks.UpdateAgentTaskRequest{ + ID: task.ID, + AgentID: agent.ID, + Status: "in_progress", + } + res := &agenttasks.UpdateAgentTaskResponse{} + + err := agenttasks.UpdateAgentTask(client)(context.Background(), req, res) + + assert.NoError(t, err) + assert.Equal(t, req.Status, res.Status) +} + +func TestUpdateAgentTask_NotFound(t *testing.T) { + client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") + defer client.Close() + + req := agenttasks.UpdateAgentTaskRequest{ + ID: uuid.New(), // Non-existent ID + Status: "in_progress", + } + res := &agenttasks.UpdateAgentTaskResponse{} + + err := agenttasks.UpdateAgentTask(client)(context.Background(), req, res) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "agent task not found") +} diff --git a/internal/restapi/v1/server.go b/internal/restapi/v1/server.go index d271b4a..d4cb80a 100644 --- a/internal/restapi/v1/server.go +++ b/internal/restapi/v1/server.go @@ -94,7 +94,7 @@ func (srv *Server) Service() *web.Service { apiV1.Delete("/agents/{id}", usecase.NewInteractor(agents1.DeleteAgent(srv.DB))) apiV1.Post("/agents/{agent_id}/tasks", usecase.NewInteractor(agenttasks1.CreateAgentTask(srv.DB))) - apiV1.Get("/agents/{agent_id}/tasks", usecase.NewInteractor(agenttasks1.ListAgentTasks(srv.DB))) + apiV1.Get("/agents/tasks", usecase.NewInteractor(agenttasks1.ListAgentTasks(srv.DB))) apiV1.Get("/agents/{agent_id}/tasks/{id}", usecase.NewInteractor(agenttasks1.GetAgentTask(srv.DB))) apiV1.Put("/agents/{agent_id}/tasks/{id}", usecase.NewInteractor(agenttasks1.UpdateAgentTask(srv.DB))) apiV1.Delete("/agents/{agent_id}/tasks/{id}", usecase.NewInteractor(agenttasks1.DeleteAgentTask(srv.DB)))