diff --git a/pkg/common/filters.go b/pkg/common/filters.go
index 58ca7916..bfb5f950 100644
--- a/pkg/common/filters.go
+++ b/pkg/common/filters.go
@@ -21,5 +21,6 @@ type ComparisonOperator int
 
 const (
 	Equal ComparisonOperator = iota
+	IsNull
 	// Add more operators as needed, ie., gte, lte
 )
diff --git a/pkg/repositories/gormimpl/artifact_test.go b/pkg/repositories/gormimpl/artifact_test.go
index e4155d57..f77ba0e9 100644
--- a/pkg/repositories/gormimpl/artifact_test.go
+++ b/pkg/repositories/gormimpl/artifact_test.go
@@ -181,7 +181,7 @@ func TestGetArtifact(t *testing.T) {
 	GlobalMock.NewMock().WithQuery(
 		`SELECT * FROM "partitions" WHERE "partitions"."deleted_at" IS NULL AND (("artifact_id" IN (123))) ORDER BY partitions.created_at ASC,"partitions"."dataset_uuid" ASC`).WithReply(expectedPartitionResponse)
 	GlobalMock.NewMock().WithQuery(
-		`SELECT * FROM "tags" WHERE "tags"."deleted_at" IS NULL AND ((("artifact_id","dataset_uuid") IN ((123,test-uuid)))) ORDER BY "tags"."dataset_project" ASC`).WithReply(expectedTagResponse)
+		`SELECT * FROM "tags" WHERE "tags"."deleted_at" IS NULL AND ((("artifact_id","dataset_uuid") IN ((123,test-uuid)))) ORDER BY "tags"."tag_name" ASC`).WithReply(expectedTagResponse)
 	getInput := models.ArtifactKey{
 		DatasetProject: artifact.DatasetProject,
 		DatasetDomain:  artifact.DatasetDomain,
diff --git a/pkg/repositories/gormimpl/filter.go b/pkg/repositories/gormimpl/filter.go
index f636e3d1..c5309606 100644
--- a/pkg/repositories/gormimpl/filter.go
+++ b/pkg/repositories/gormimpl/filter.go
@@ -10,7 +10,8 @@ import (
 
 // String formats for various GORM expression queries
 const (
-	equalQuery = "%s.%s = ?"
+	equalQuery  = "%s.%s = ?"
+	isNullQuery = "%s.%s IS NULL"
 )
 
 type gormValueFilterImpl struct {
@@ -27,7 +28,13 @@ func (g *gormValueFilterImpl) GetDBQueryExpression(tableName string) (models.DBQ
 			Query: fmt.Sprintf(equalQuery, tableName, g.field),
 			Args:  g.value,
 		}, nil
+	case common.IsNull:
+		// IS NULL has no placeholder, so Args is deliberately left nil.
+		return models.DBQueryExpr{
+			Query: fmt.Sprintf(isNullQuery, tableName, g.field),
+		}, nil
 	}
+
 	return models.DBQueryExpr{}, errors.GetUnsupportedFilterExpressionErr(g.comparisonOperator)
 }
 
@@ -39,3 +46,11 @@ func NewGormValueFilter(comparisonOperator common.ComparisonOperator, field stri
 		value:              value,
 	}
 }
+
+// NewGormNullFilter builds a value filter that matches rows where field IS NULL.
+func NewGormNullFilter(field string) models.ModelValueFilter {
+	return &gormValueFilterImpl{
+		comparisonOperator: common.IsNull,
+		field:              field,
+	}
+}
diff --git a/pkg/repositories/gormimpl/list.go b/pkg/repositories/gormimpl/list.go
index 7e461df5..6cdcedf7 100644
--- a/pkg/repositories/gormimpl/list.go
+++ b/pkg/repositories/gormimpl/list.go
@@ -56,7 +56,12 @@ func applyListModelsInput(tx *gorm.DB, sourceEntity common.Entity, in models.Lis
 			if err != nil {
 				return nil, err
 			}
-			tx = tx.Where(dbQueryExpr.Query, dbQueryExpr.Args)
+
+			if dbQueryExpr.Args == nil {
+				tx = tx.Where(dbQueryExpr.Query)
+			} else {
+				tx = tx.Where(dbQueryExpr.Query, dbQueryExpr.Args)
+			}
 		}
 	}
 
diff --git a/pkg/repositories/gormimpl/tag.go b/pkg/repositories/gormimpl/tag.go
index 2dd9c920..34be53b6 100644
--- a/pkg/repositories/gormimpl/tag.go
+++ b/pkg/repositories/gormimpl/tag.go
@@ -4,10 +4,12 @@ import (
 	"context"
 
 	"github.com/jinzhu/gorm"
+	"github.com/lyft/datacatalog/pkg/common"
 	"github.com/lyft/datacatalog/pkg/repositories/errors"
 	"github.com/lyft/datacatalog/pkg/repositories/interfaces"
 	"github.com/lyft/datacatalog/pkg/repositories/models"
 	idl_datacatalog "github.com/lyft/datacatalog/protos/gen"
+	"github.com/lyft/flytestdlib/logger"
 	"github.com/lyft/flytestdlib/promutils"
 )
 
@@ -25,14 +27,124 @@ func NewTagRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer, scope pro
 	}
 }
 
+// Create tags an artifact, taking the tag away from any artifact that currently holds it.
+//
+// A tag is associated with a single artifact for each partition combination.
+// When creating a tag, we remove the tag from any artifacts of the same partition,
+// then add the tag to the new artifact. Every step runs inside one transaction,
+// which is rolled back if any step fails.
 func (h *tagRepo) Create(ctx context.Context, tag models.Tag) error {
 	timer := h.repoMetrics.CreateDuration.Start(ctx)
 	defer timer.Stop()
 
-	db := h.db.Create(&tag)
+	// There are several steps that need to be done in a transaction in order for tag stealing to occur
+	tx := h.db.Begin()
+	if tx.Error != nil {
+		logger.Errorf(ctx, "Unable to begin transaction, tag: [%v], err [%v]", tag, tx.Error)
+		return h.errorTransformer.ToDataCatalogError(tx.Error)
+	}
+	defer func() {
+		if r := recover(); r != nil {
+			tx.Rollback()
+			// Re-raise after rolling back so a panic is never reported as a successful create.
+			panic(r)
+		}
+	}()
 
-	if db.Error != nil {
-		return h.errorTransformer.ToDataCatalogError(db.Error)
+	// 1. Find the set of partitions this artifact belongs to
+	var artifactToTag models.Artifact
+	findScope := tx.Preload("Partitions").Find(&artifactToTag, models.Artifact{
+		ArtifactKey: models.ArtifactKey{ArtifactID: tag.ArtifactID},
+	})
+	if findScope.Error != nil {
+		logger.Errorf(ctx, "Unable to find artifact to tag, rolling back, tag: [%v], err [%v]", tag, findScope.Error)
+		tx.Rollback()
+		return h.errorTransformer.ToDataCatalogError(findScope.Error)
+	}
+
+	// 2. List artifacts in the partitions that are currently tagged
+	modelFilters := make([]models.ModelFilter, 0, len(artifactToTag.Partitions)+1)
+	for _, partition := range artifactToTag.Partitions {
+		modelFilters = append(modelFilters, models.ModelFilter{
+			Entity: common.Partition,
+			ValueFilters: []models.ModelValueFilter{
+				NewGormValueFilter(common.Equal, "key", partition.Key),
+				NewGormValueFilter(common.Equal, "value", partition.Value),
+			},
+			JoinCondition: NewGormJoinCondition(common.Artifact, common.Partition),
+		})
+	}
+
+	modelFilters = append(modelFilters, models.ModelFilter{
+		Entity: common.Tag,
+		ValueFilters: []models.ModelValueFilter{
+			NewGormValueFilter(common.Equal, "tag_name", tag.TagName),
+			NewGormNullFilter("deleted_at"),
+		},
+		JoinCondition: NewGormJoinCondition(common.Artifact, common.Tag),
+	})
+
+	listTaggedInput := models.ListModelsInput{
+		ModelFilters: modelFilters,
+		Limit:        100,
+	}
+
+	listArtifactsScope, err := applyListModelsInput(tx, common.Artifact, listTaggedInput)
+	if err != nil {
+		logger.Errorf(ctx, "Unable to construct artifact list, rolling back, tag: [%v], err [%v]", tag, err)
+		tx.Rollback()
+		return h.errorTransformer.ToDataCatalogError(err)
+	}
+
+	var artifacts []models.Artifact
+	if err := listArtifactsScope.Find(&artifacts).Error; err != nil {
+		logger.Errorf(ctx, "Unable to find previously tagged artifacts, rolling back, tag: [%v], err [%v]", tag, err)
+		tx.Rollback()
+		return h.errorTransformer.ToDataCatalogError(err)
+	}
+
+	// 3. Soft-delete the existing tags on the artifacts that are currently tagged
+	for _, artifact := range artifacts {
+		// if the artifact to tag is already tagged, no need to remove it
+		if artifactToTag.ArtifactID == artifact.ArtifactID {
+			continue
+		}
+		oldTag := models.Tag{
+			TagKey:      models.TagKey{TagName: tag.TagName},
+			ArtifactID:  artifact.ArtifactID,
+			DatasetUUID: artifact.DatasetUUID,
+		}
+		deleteScope := tx.NewScope(&models.Tag{}).DB().Delete(&models.Tag{}, oldTag)
+		if deleteScope.Error != nil {
+			logger.Errorf(ctx, "Unable to delete previously tagged artifacts, rolling back, tag: [%v], err [%v]", tag, deleteScope.Error)
+			tx.Rollback()
+			return h.errorTransformer.ToDataCatalogError(deleteScope.Error)
+		}
+	}
+
+	// 4. If the artifact was ever previously tagged with this tag, we need to
+	// un-delete the record because we cannot tag the artifact again since
+	// the primary keys are the same.
+	undeleteScope := tx.Unscoped().Model(&tag).Update("deleted_at", gorm.Expr("NULL")) // unscope will ignore deletedAt
+	if undeleteScope.Error != nil {
+		logger.Errorf(ctx, "Unable to undelete tag, rolling back, tag: [%v], err [%v]", tag, undeleteScope.Error)
+		tx.Rollback()
+		return h.errorTransformer.ToDataCatalogError(undeleteScope.Error)
+	}
+
+	// 5. Tag the new artifact, if it didn't previously exist
+	if undeleteScope.RowsAffected == 0 {
+		if err := tx.Create(&tag).Error; err != nil {
+			logger.Errorf(ctx, "Unable to create tag, rolling back, tag: [%v], err [%v]", tag, err)
+			tx.Rollback()
+			return h.errorTransformer.ToDataCatalogError(err)
+		}
+	}
+
+	if err := tx.Commit().Error; err != nil {
+		logger.Errorf(ctx, "Unable to commit transaction, rolling back, tag: [%v], err [%v]", tag, err)
+		tx.Rollback()
+		return h.errorTransformer.ToDataCatalogError(err)
 	}
 	return nil
 }
diff --git a/pkg/repositories/gormimpl/tag_test.go b/pkg/repositories/gormimpl/tag_test.go
index 41f172e6..442b4b8d 100644
--- a/pkg/repositories/gormimpl/tag_test.go
+++ b/pkg/repositories/gormimpl/tag_test.go
@@ -44,12 +44,23 @@ func getTestTag() models.Tag {
 	}
 }
 
-func TestCreateTag(t *testing.T) {
+func TestCreateTagNew(t *testing.T) {
 	tagCreated := false
 	GlobalMock := mocket.Catcher.Reset()
 	GlobalMock.Logging = true
 
+	newArtifact := getTestArtifact()
+
 	// Only match on queries that append expected filters
+	GlobalMock.NewMock().WithQuery(
+		`SELECT * FROM "artifacts" WHERE "artifacts"."deleted_at" IS NULL AND (("artifacts"."artifact_id" = 123))`).WithReply(getDBArtifactResponse(newArtifact))
+
+	GlobalMock.NewMock().WithQuery(
+		`SELECT * FROM "partitions" WHERE "partitions"."deleted_at" IS NULL AND (("artifact_id" IN (123)))`).WithReply(getDBPartitionResponse(newArtifact))
+
+	GlobalMock.NewMock().WithQuery(
+		`SELECT "artifacts".* FROM "artifacts" JOIN partitions partitions0 ON artifacts.artifact_id = partitions0.artifact_id JOIN tags tags1 ON artifacts.artifact_id = tags1.artifact_id WHERE "artifacts"."deleted_at" IS NULL AND ((partitions0.key = region) AND (partitions0.value = SEA) AND (tags1.tag_name = test-tagname) AND (tags1.deleted_at IS NULL)) LIMIT 100 OFFSET 0`).WithReply([]map[string]interface{}{})
+
 	GlobalMock.NewMock().WithQuery(
 		`INSERT INTO "tags" ("created_at","updated_at","deleted_at","dataset_project","dataset_name","dataset_domain","dataset_version","tag_name","artifact_id","dataset_uuid") VALUES (?,?,?,?,?,?,?,?,?,?)`).WithCallback(
 		func(s string, values []driver.NamedValue) {
@@ -57,8 +68,48 @@ func TestCreateTag(t *testing.T) {
 		},
 	)
 
+	newTag := getTestTag()
+	newTag.ArtifactID = newArtifact.ArtifactID
+
 	tagRepo := NewTagRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer(), promutils.NewTestScope())
-	err := tagRepo.Create(context.Background(), getTestTag())
+	err := tagRepo.Create(context.Background(), newTag)
+
+	assert.NoError(t, err)
+	assert.True(t, tagCreated)
+}
+
+func TestStealOldTag(t *testing.T) {
+	tagCreated := false
+	GlobalMock := mocket.Catcher.Reset()
+	GlobalMock.Logging = true
+
+	oldArtifact := getTestArtifact()
+	newArtifact := getTestArtifact()
+	newArtifact.ArtifactID = "111"
+
+	// Only match on queries that append expected filters
+	GlobalMock.NewMock().WithQuery(
+		`SELECT * FROM "artifacts" WHERE "artifacts"."deleted_at" IS NULL AND (("artifacts"."artifact_id" = 111))`).WithReply(getDBArtifactResponse(newArtifact))
+
+	GlobalMock.NewMock().WithQuery(
+		`SELECT * FROM "partitions" WHERE "partitions"."deleted_at" IS NULL AND (("artifact_id" IN (111)))`).WithReply(getDBPartitionResponse(newArtifact))
+
+	GlobalMock.NewMock().WithQuery(
+		`SELECT "artifacts".* FROM "artifacts" JOIN partitions partitions0 ON artifacts.artifact_id = partitions0.artifact_id JOIN tags tags1 ON artifacts.artifact_id = tags1.artifact_id WHERE "artifacts"."deleted_at" IS NULL AND ((partitions0.key = region) AND (partitions0.value = SEA) AND (tags1.tag_name = test-tagname) AND (tags1.deleted_at IS NULL)) LIMIT 100 OFFSET 0`).WithReply(getDBArtifactResponse(oldArtifact))
+
+	GlobalMock.NewMock().WithQuery(
+		`INSERT INTO "tags" ("created_at","updated_at","deleted_at","dataset_project","dataset_name","dataset_domain","dataset_version","tag_name","artifact_id","dataset_uuid") VALUES (?,?,?,?,?,?,?,?,?,?)`).WithCallback(
+		func(s string, values []driver.NamedValue) {
+			tagCreated = true
+		},
+	)
+
+	newTag := getTestTag()
+	newTag.ArtifactID = newArtifact.ArtifactID
+
+	tagRepo := NewTagRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer(), promutils.NewTestScope())
+	err := tagRepo.Create(context.Background(), newTag)
+
 	assert.NoError(t, err)
 	assert.True(t, tagCreated)
 }
@@ -71,7 +122,7 @@ func TestGetTag(t *testing.T) {
 
 	// Only match on queries that append expected filters
 	GlobalMock.NewMock().WithQuery(
-		`SELECT * FROM "tags" WHERE "tags"."deleted_at" IS NULL AND (("tags"."dataset_project" = testProject) AND ("tags"."dataset_name" = testName) AND ("tags"."dataset_domain" = testDomain) AND ("tags"."dataset_version" = testVersion) AND ("tags"."tag_name" = test-tag)) ORDER BY tags.created_at DESC,"tags"."dataset_project" ASC LIMIT 1`).WithReply(getDBTagResponse(artifact))
+		`SELECT * FROM "tags" WHERE "tags"."deleted_at" IS NULL AND (("tags"."dataset_project" = testProject) AND ("tags"."dataset_name" = testName) AND ("tags"."dataset_domain" = testDomain) AND ("tags"."dataset_version" = testVersion) AND ("tags"."tag_name" = test-tag)) ORDER BY tags.created_at DESC,"tags"."tag_name" ASC LIMIT 1`).WithReply(getDBTagResponse(artifact))
 	GlobalMock.NewMock().WithQuery(
 		`SELECT * FROM "artifacts" WHERE "artifacts"."deleted_at" IS NULL AND ((("dataset_project","dataset_name","dataset_domain","dataset_version","artifact_id") IN ((testProject,testName,testDomain,testVersion,123))))`).WithReply(getDBArtifactResponse(artifact))
 	GlobalMock.NewMock().WithQuery(
diff --git a/pkg/repositories/models/tag.go b/pkg/repositories/models/tag.go
index 057f78c3..a0ce3f2f 100644
--- a/pkg/repositories/models/tag.go
+++ b/pkg/repositories/models/tag.go
@@ -1,17 +1,17 @@
 package models
 
 type TagKey struct {
-	DatasetProject string `gorm:"primary_key"`
-	DatasetName    string `gorm:"primary_key"`
-	DatasetDomain  string `gorm:"primary_key"`
-	DatasetVersion string `gorm:"primary_key"`
+	DatasetProject string
+	DatasetName    string
+	DatasetDomain  string
+	DatasetVersion string
 	TagName        string `gorm:"primary_key"`
 }
 
 type Tag struct {
 	BaseModel
 	TagKey
-	ArtifactID  string
+	ArtifactID  string   `gorm:"primary_key"`
 	DatasetUUID string   `gorm:"type:uuid;index:tags_dataset_uuid_idx"`
 	Artifact    Artifact `gorm:"association_foreignkey:DatasetProject,DatasetName,DatasetDomain,DatasetVersion,ArtifactID;foreignkey:DatasetProject,DatasetName,DatasetDomain,DatasetVersion,ArtifactID"`
 }