diff --git a/assetdb.go b/assetdb.go index 5daaa8b..689a450 100644 --- a/assetdb.go +++ b/assetdb.go @@ -19,11 +19,13 @@ type AssetDB struct { // New creates a new assetDB instance. // It initializes the asset database with the specified database type and DSN. -func New(dbType repository.DBType, dsn string) *AssetDB { - database := repository.New(dbType, dsn) - return &AssetDB{ - repository: database, +func New(dbtype, dsn string) *AssetDB { + if db, err := repository.New(dbtype, dsn); err == nil && db != nil { + return &AssetDB{ + repository: db, + } } + return nil } // Close will close the assetdb and return any errors. @@ -75,9 +77,9 @@ func (as *AssetDB) FindByContent(asset oam.Asset, since time.Time) ([]*types.Ent return as.repository.FindEntityByContent(asset, since) } -// FindById finds an entity in the database by the ID. +// FindEntityById finds an entity in the database by the ID. // It returns the matching entity and an error, if any. -func (as *AssetDB) FindById(id string) (*types.Entity, error) { +func (as *AssetDB) FindEntityById(id string) (*types.Entity, error) { return as.repository.FindEntityById(id) } @@ -89,10 +91,10 @@ func (as *AssetDB) FindByScope(constraints []oam.Asset, since time.Time) ([]*typ return as.repository.FindEntitiesByScope(constraints, since) } -// FindByType finds all entities in the database of the provided asset type and last seen after the since parameter. +// FindEntitiesByType finds all entities in the database of the provided asset type and last seen after the since parameter. // If since.IsZero(), the parameter will be ignored. // It returns the matching entities and an error, if any. -func (as *AssetDB) FindByType(atype oam.AssetType, since time.Time) ([]*types.Entity, error) { +func (as *AssetDB) FindEntitiesByType(atype oam.AssetType, since time.Time) ([]*types.Entity, error) { return as.repository.FindEntitiesByType(atype, since) } @@ -116,3 +118,47 @@ func (as *AssetDB) IncomingEdges(entity *types.Entity, since time.Time, labels . func (as *AssetDB) OutgoingEdges(entity *types.Entity, since time.Time, labels ...string) ([]*types.Edge, error) { return as.repository.OutgoingEdges(entity, since, labels...) } + +// CreateEntityTag creates a new entity tag in the database. +// It takes an oam.Property as input and persists it in the database. +// The entity tag is serialized to JSON and stored in the Content field of the EntityTag struct. +// Returns the created entity tag as a types.EntityTag or an error if the creation fails. +func (as *AssetDB) CreateEntityTag(entity *types.Entity, property oam.Property) (*types.EntityTag, error) { + return as.repository.CreateEntityTag(entity, property) +} + +// GetEntityTags finds all tags for the entity with the specified names and last seen after the since parameter. +// If since.IsZero(), the parameter will be ignored. +// If no names are specified, all tags for the specified entity are returned. +func (as *AssetDB) GetEntityTags(entity *types.Entity, since time.Time, names ...string) ([]*types.EntityTag, error) { + return as.repository.GetEntityTags(entity, since, names...) +} + +// DeleteEntityTag removes an entity tag in the database by its ID. +// It takes a string representing the entity tag ID and removes the corresponding tag from the database. +// Returns an error if the tag is not found. +func (as *AssetDB) DeleteEntityTag(id string) error { + return as.repository.DeleteEntityTag(id) +} + +// CreateEdgeTag creates a new edge tag in the database. +// It takes an oam.Property as input and persists it in the database. +// The edge tag is serialized to JSON and stored in the Content field of the EdgeTag struct. +// Returns the created edge tag as a types.EdgeTag or an error if the creation fails. +func (as *AssetDB) CreateEdgeTag(edge *types.Edge, property oam.Property) (*types.EdgeTag, error) { + return as.repository.CreateEdgeTag(edge, property) +} + +// GetEdgeTags finds all tags for the edge with the specified names and last seen after the since parameter. +// If since.IsZero(), the parameter will be ignored. +// If no names are specified, all tags for the specified edge are returned. +func (as *AssetDB) GetEdgeTags(edge *types.Edge, since time.Time, names ...string) ([]*types.EdgeTag, error) { + return as.repository.GetEdgeTags(edge, since, names...) +} + +// DeleteEdgeTag removes an edge tag in the database by its ID. +// It takes a string representing the edge tag ID and removes the corresponding tag from the database. +// Returns an error if the tag is not found. +func (as *AssetDB) DeleteEdgeTag(id string) error { + return as.repository.DeleteEdgeTag(id) +} diff --git a/assetdb_test.go b/assetdb_test.go index b8c4ddc..dd34d63 100644 --- a/assetdb_test.go +++ b/assetdb_test.go @@ -87,7 +87,7 @@ func TestAssetDB(t *testing.T) { } }) - t.Run("FindById", func(t *testing.T) { + t.Run("FindEntityById", func(t *testing.T) { testCases := []struct { description string id string @@ -107,7 +107,7 @@ func TestAssetDB(t *testing.T) { mockAssetDB.On("FindEntityById", tc.id).Return(tc.expected, tc.expectedError) - result, err := adb.FindById(tc.id) + result, err := adb.FindEntityById(tc.id) assert.Equal(t, tc.expected, result) assert.Equal(t, tc.expectedError, err) @@ -180,7 +180,7 @@ func TestAssetDB(t *testing.T) { } }) - t.Run("FindByType", func(t *testing.T) { + t.Run("FindEntitiesByType", func(t *testing.T) { testCases := []struct { description string atype oam.AssetType @@ -200,7 +200,7 @@ func TestAssetDB(t *testing.T) { mockAssetDB.On("FindEntitiesByType", tc.atype, start).Return(tc.expected, tc.expectedError) - result, err := adb.FindByType(tc.atype, start) + result, err := adb.FindEntitiesByType(tc.atype, start) assert.Equal(t, tc.expected, result) assert.Equal(t, tc.expectedError, err) diff --git a/go.mod b/go.mod index cf2ec55..5d089a8 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.23.1 require ( github.com/caffix/stringset v0.2.0 github.com/glebarez/sqlite v1.11.0 - github.com/owasp-amass/open-asset-model v0.11.0-dev + github.com/owasp-amass/open-asset-model v0.12.0 github.com/rubenv/sql-migrate v1.7.0 github.com/stretchr/testify v1.9.0 gorm.io/datatypes v1.2.4 diff --git a/go.sum b/go.sum index 99ed7dd..53d266a 100644 --- a/go.sum +++ b/go.sum @@ -50,8 +50,8 @@ github.com/microsoft/go-mssqldb v0.17.0 h1:Fto83dMZPnYv1Zwx5vHHxpNraeEaUlQ/hhHLg github.com/microsoft/go-mssqldb v0.17.0/go.mod h1:OkoNGhGEs8EZqchVTtochlXruEhEOaO4S0d2sB5aeGQ= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= -github.com/owasp-amass/open-asset-model v0.11.0-dev h1:lPN//kGiiEkjpL8sZgiwZPJisz7Oa6x9xgYM+rWINjA= -github.com/owasp-amass/open-asset-model v0.11.0-dev/go.mod h1:DOX+SiD6PZBroSMnsILAmpf0SHi6TVpqjV4uNfBeg7g= +github.com/owasp-amass/open-asset-model v0.12.0 h1:WBf0P82ONVJErGjIdZ9jkAzXssOT87dFfY6dKM93csc= +github.com/owasp-amass/open-asset-model v0.12.0/go.mod h1:DOX+SiD6PZBroSMnsILAmpf0SHi6TVpqjV4uNfBeg7g= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/poy/onpar v1.1.2 h1:QaNrNiZx0+Nar5dLgTVp5mXkyoVFIbepjyEoGSnhbAY= diff --git a/repository/repository.go b/repository/repository.go index ce90b32..cdd104e 100644 --- a/repository/repository.go +++ b/repository/repository.go @@ -5,14 +5,17 @@ package repository import ( + "errors" + "strings" "time" + "github.com/owasp-amass/asset-db/repository/sqlrepo" "github.com/owasp-amass/asset-db/types" oam "github.com/owasp-amass/open-asset-model" ) // Repository defines the methods for interacting with the asset database. -// It provides operations for creating, retrieving, and linking assets. +// It provides operations for creating, retrieving, tagging, and linking assets. type Repository interface { GetDBType() string CreateEntity(asset oam.Asset) (*types.Entity, error) @@ -34,3 +37,11 @@ type Repository interface { DeleteEdgeTag(id string) error Close() error } + +// New creates a new instance of the asset database repository. +func New(dbtype, dsn string) (Repository, error) { + if strings.EqualFold(dbtype, sqlrepo.Postgres) || strings.EqualFold(dbtype, sqlrepo.SQLite) { + return sqlrepo.New(dbtype, dsn) + } + return nil, errors.New("unknown DB type") +} diff --git a/repository/edge.go b/repository/sqlrepo/edge.go similarity index 99% rename from repository/edge.go rename to repository/sqlrepo/edge.go index 08c74c1..c7672aa 100644 --- a/repository/edge.go +++ b/repository/sqlrepo/edge.go @@ -2,7 +2,7 @@ // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. // SPDX-License-Identifier: Apache-2.0 -package repository +package sqlrepo import ( "errors" diff --git a/repository/edge_test.go b/repository/sqlrepo/edge_test.go similarity index 99% rename from repository/edge_test.go rename to repository/sqlrepo/edge_test.go index 5cc28c0..a0e0f10 100644 --- a/repository/edge_test.go +++ b/repository/sqlrepo/edge_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. // SPDX-License-Identifier: Apache-2.0 -package repository +package sqlrepo import ( "net/netip" diff --git a/repository/entity.go b/repository/sqlrepo/entity.go similarity index 93% rename from repository/entity.go rename to repository/sqlrepo/entity.go index 5d01609..8f58516 100644 --- a/repository/entity.go +++ b/repository/sqlrepo/entity.go @@ -2,7 +2,7 @@ // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. // SPDX-License-Identifier: Apache-2.0 -package repository +package sqlrepo import ( "errors" @@ -17,45 +17,39 @@ import ( "gorm.io/gorm/logger" ) -// DBType represents the type of the database. -type DBType string - const ( - // Postgres represents the PostgreSQL database type. - Postgres DBType = "postgres" - // SQLite represents the SQLite database type. - SQLite DBType = "sqlite" + Postgres string = "postgres" + SQLite string = "sqlite" ) // sqlRepository is a repository implementation using GORM as the underlying ORM. type sqlRepository struct { db *gorm.DB - dbType DBType + dbtype string } // New creates a new instance of the asset database repository. -func New(dbType DBType, dsn string) *sqlRepository { - db, err := newDatabase(dbType, dsn) +func New(dbtype, dsn string) (*sqlRepository, error) { + db, err := newDatabase(dbtype, dsn) if err != nil { - panic(err) + return nil, err } return &sqlRepository{ db: db, - dbType: dbType, - } + dbtype: dbtype, + }, nil } // newDatabase creates a new GORM database connection based on the provided database type and data source name (dsn). -func newDatabase(dbType DBType, dsn string) (*gorm.DB, error) { - switch dbType { +func newDatabase(dbtype, dsn string) (*gorm.DB, error) { + switch dbtype { case Postgres: return postgresDatabase(dsn) case SQLite: return sqliteDatabase(dsn) - default: - panic("Unknown db type") } + return nil, errors.New("unknown DB type") } // postgresDatabase creates a new PostgreSQL database connection using the provided data source name (dsn). @@ -78,7 +72,7 @@ func (sql *sqlRepository) Close() error { // GetDBType returns the type of the database. func (sql *sqlRepository) GetDBType() string { - return string(sql.dbType) + return string(sql.dbtype) } // CreateEntity creates a new entity in the database. diff --git a/repository/entity_test.go b/repository/sqlrepo/entity_test.go similarity index 99% rename from repository/entity_test.go rename to repository/sqlrepo/entity_test.go index b992227..9ef8737 100644 --- a/repository/entity_test.go +++ b/repository/sqlrepo/entity_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. // SPDX-License-Identifier: Apache-2.0 -package repository +package sqlrepo import ( "fmt" @@ -30,7 +30,7 @@ import ( var store *sqlRepository type testSetup struct { - name DBType + name string dsn string setup func(string) (*gorm.DB, error) teardown func(string) @@ -157,7 +157,7 @@ func TestMain(m *testing.M) { panic(err) } - store = New(w.name, w.dsn) + store, _ = New(w.name, w.dsn) exitCodes[i] = m.Run() if w.teardown != nil { w.teardown(w.dsn) @@ -359,7 +359,7 @@ func TestRepository(t *testing.T) { func TestGetDBType(t *testing.T) { sql := &sqlRepository{ - dbType: "postgres", + dbtype: "postgres", } expected := "postgres" diff --git a/repository/models.go b/repository/sqlrepo/models.go similarity index 99% rename from repository/models.go rename to repository/sqlrepo/models.go index 02f974b..380d631 100644 --- a/repository/models.go +++ b/repository/sqlrepo/models.go @@ -2,7 +2,7 @@ // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. // SPDX-License-Identifier: Apache-2.0 -package repository +package sqlrepo import ( "encoding/json" diff --git a/repository/models_test.go b/repository/sqlrepo/models_test.go similarity index 99% rename from repository/models_test.go rename to repository/sqlrepo/models_test.go index 97359cd..193de45 100644 --- a/repository/models_test.go +++ b/repository/sqlrepo/models_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. // SPDX-License-Identifier: Apache-2.0 -package repository +package sqlrepo import ( "net/netip" diff --git a/repository/sql_scope.go b/repository/sqlrepo/sql_scope.go similarity index 99% rename from repository/sql_scope.go rename to repository/sqlrepo/sql_scope.go index 6320923..06d64d2 100644 --- a/repository/sql_scope.go +++ b/repository/sqlrepo/sql_scope.go @@ -2,7 +2,7 @@ // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. // SPDX-License-Identifier: Apache-2.0 -package repository +package sqlrepo import ( "errors" diff --git a/repository/tag.go b/repository/sqlrepo/tag.go similarity index 99% rename from repository/tag.go rename to repository/sqlrepo/tag.go index e06a954..fc4807b 100644 --- a/repository/tag.go +++ b/repository/sqlrepo/tag.go @@ -2,7 +2,7 @@ // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. // SPDX-License-Identifier: Apache-2.0 -package repository +package sqlrepo import ( "strconv" diff --git a/repository/tag_test.go b/repository/sqlrepo/tag_test.go similarity index 99% rename from repository/tag_test.go rename to repository/sqlrepo/tag_test.go index 4691966..d19f78f 100644 --- a/repository/tag_test.go +++ b/repository/sqlrepo/tag_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. // SPDX-License-Identifier: Apache-2.0 -package repository +package sqlrepo import ( "testing"