From a2bd84148ce12708a3ae1d73873234a45c9b85c7 Mon Sep 17 00:00:00 2001 From: Software Developer <7852635+dsuhinin@users.noreply.github.com> Date: Tue, 26 Nov 2024 19:08:30 +0100 Subject: [PATCH] Move `GET /mlflow/registered-models/get` endpoint. (#92) Signed-off-by: Software Developer <7852635+dsuhinin@users.noreply.github.com> --- magefiles/generate/endpoints.go | 2 +- mlflow_go/store/model_registry.py | 19 +++++++++++ pkg/contract/service/model_registry.g.go | 1 + pkg/entities/model_version.go | 31 +++++++++++++++++ pkg/entities/registered_model.go | 10 ++++++ pkg/entities/registered_model_alias.go | 18 ++++++++++ pkg/lib/model_registry.g.go | 8 +++++ pkg/model_registry/service/model_versions.go | 13 ++++++++ .../store/sql/model_versions.go | 16 ++++++--- .../store/sql/models/model_versions.go | 19 +++++++++++ .../sql/models/registered_model_aliases.go | 15 ++++++++- .../store/sql/models/registered_models.go | 33 ++++++++++++++++--- pkg/model_registry/store/store.go | 1 + pkg/server/routes/model_registry.g.go | 11 +++++++ 14 files changed, 185 insertions(+), 12 deletions(-) create mode 100644 pkg/entities/model_version.go create mode 100644 pkg/entities/registered_model_alias.go diff --git a/magefiles/generate/endpoints.go b/magefiles/generate/endpoints.go index cc761a6e..d0db3528 100644 --- a/magefiles/generate/endpoints.go +++ b/magefiles/generate/endpoints.go @@ -52,7 +52,7 @@ var ServiceInfoMap = map[string]ServiceGenerationInfo{ "renameRegisteredModel", "updateRegisteredModel", "deleteRegisteredModel", - // "getRegisteredModel", + "getRegisteredModel", // "searchRegisteredModels", "getLatestVersions", // "createModelVersion", diff --git a/mlflow_go/store/model_registry.py b/mlflow_go/store/model_registry.py index ead2ec66..16a8c977 100644 --- a/mlflow_go/store/model_registry.py +++ b/mlflow_go/store/model_registry.py @@ -5,6 +5,7 @@ from mlflow.protos.model_registry_pb2 import ( DeleteRegisteredModel, GetLatestVersions, + GetRegisteredModel, RenameRegisteredModel, UpdateRegisteredModel, ) @@ -60,6 +61,24 @@ def delete_registered_model(self, name): request = DeleteRegisteredModel(name=name) self.service.call_endpoint(get_lib().ModelRegistryServiceDeleteRegisteredModel, request) + def get_registered_model(self, name): + request = GetRegisteredModel(name=name) + response = self.service.call_endpoint( + get_lib().ModelRegistryServiceGetRegisteredModel, request + ) + + entity = RegisteredModel.from_proto(response.registered_model) + if entity.description == "": + entity.description = None + + # during convertion to proto, `version` value became a `string` value. + # convert it back to `int` value again to satisfy all the Python tests and related logic. + for key in entity.aliases: + if entity.aliases[key].isnumeric(): + entity.aliases[key] = int(entity.aliases[key]) + + return entity + def ModelRegistryStore(cls): return type(cls.__name__, (_ModelRegistryStore, cls), {}) diff --git a/pkg/contract/service/model_registry.g.go b/pkg/contract/service/model_registry.g.go index 42659efd..c1dc6a05 100644 --- a/pkg/contract/service/model_registry.g.go +++ b/pkg/contract/service/model_registry.g.go @@ -13,5 +13,6 @@ type ModelRegistryService interface { RenameRegisteredModel(ctx context.Context, input *protos.RenameRegisteredModel) (*protos.RenameRegisteredModel_Response, *contract.Error) UpdateRegisteredModel(ctx context.Context, input *protos.UpdateRegisteredModel) (*protos.UpdateRegisteredModel_Response, *contract.Error) DeleteRegisteredModel(ctx context.Context, input *protos.DeleteRegisteredModel) (*protos.DeleteRegisteredModel_Response, *contract.Error) + GetRegisteredModel(ctx context.Context, input *protos.GetRegisteredModel) (*protos.GetRegisteredModel_Response, *contract.Error) GetLatestVersions(ctx context.Context, input *protos.GetLatestVersions) (*protos.GetLatestVersions_Response, *contract.Error) } diff --git a/pkg/entities/model_version.go b/pkg/entities/model_version.go new file mode 100644 index 00000000..410fda3c --- /dev/null +++ b/pkg/entities/model_version.go @@ -0,0 +1,31 @@ +package entities + +import ( + "strconv" + + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +type ModelVersion struct { + Name string + Version int32 + CreationTime int64 + LastUpdatedTime int64 + Description string + UserID string + CurrentStage string + Source string + RunID string + Status string + StatusMessage string + RunLink string + StorageLocation string +} + +func (mv ModelVersion) ToProto() *protos.ModelVersion { + return &protos.ModelVersion{ + Version: utils.PtrTo(strconv.Itoa(int(mv.Version))), + CurrentStage: utils.PtrTo(mv.CurrentStage), + } +} diff --git a/pkg/entities/registered_model.go b/pkg/entities/registered_model.go index 8ac4cffa..f8ec14e4 100644 --- a/pkg/entities/registered_model.go +++ b/pkg/entities/registered_model.go @@ -8,6 +8,8 @@ import ( type RegisteredModel struct { Name string Tags []*RegisteredModelTag + Aliases []*RegisteredModelAlias + Versions []*ModelVersion Description *string CreationTime int64 LastUpdatedTime int64 @@ -26,5 +28,13 @@ func (m RegisteredModel) ToProto() *protos.RegisteredModel { registeredModel.Tags = append(registeredModel.Tags, tag.ToProto()) } + for _, alias := range m.Aliases { + registeredModel.Aliases = append(registeredModel.Aliases, alias.ToProto()) + } + + for _, version := range m.Versions { + registeredModel.LatestVersions = append(registeredModel.LatestVersions, version.ToProto()) + } + return ®isteredModel } diff --git a/pkg/entities/registered_model_alias.go b/pkg/entities/registered_model_alias.go new file mode 100644 index 00000000..dfb2680b --- /dev/null +++ b/pkg/entities/registered_model_alias.go @@ -0,0 +1,18 @@ +package entities + +import ( + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +type RegisteredModelAlias struct { + Alias string + Version string +} + +func (t RegisteredModelAlias) ToProto() *protos.RegisteredModelAlias { + return &protos.RegisteredModelAlias{ + Alias: utils.PtrTo(t.Alias), + Version: utils.PtrTo(t.Version), + } +} diff --git a/pkg/lib/model_registry.g.go b/pkg/lib/model_registry.g.go index 878f204c..9a5166e9 100644 --- a/pkg/lib/model_registry.g.go +++ b/pkg/lib/model_registry.g.go @@ -31,6 +31,14 @@ func ModelRegistryServiceDeleteRegisteredModel(serviceID int64, requestData unsa } return invokeServiceMethod(service.DeleteRegisteredModel, new(protos.DeleteRegisteredModel), requestData, requestSize, responseSize) } +//export ModelRegistryServiceGetRegisteredModel +func ModelRegistryServiceGetRegisteredModel(serviceID int64, requestData unsafe.Pointer, requestSize C.int, responseSize *C.int) unsafe.Pointer { + service, err := modelRegistryServices.Get(serviceID) + if err != nil { + return makePointerFromError(err, responseSize) + } + return invokeServiceMethod(service.GetRegisteredModel, new(protos.GetRegisteredModel), requestData, requestSize, responseSize) +} //export ModelRegistryServiceGetLatestVersions func ModelRegistryServiceGetLatestVersions(serviceID int64, requestData unsafe.Pointer, requestSize C.int, responseSize *C.int) unsafe.Pointer { service, err := modelRegistryServices.Get(serviceID) diff --git a/pkg/model_registry/service/model_versions.go b/pkg/model_registry/service/model_versions.go index b7e22b75..d4a89792 100644 --- a/pkg/model_registry/service/model_versions.go +++ b/pkg/model_registry/service/model_versions.go @@ -63,3 +63,16 @@ func (m *ModelRegistryService) DeleteRegisteredModel( return &protos.DeleteRegisteredModel_Response{}, nil } + +func (m *ModelRegistryService) GetRegisteredModel( + ctx context.Context, input *protos.GetRegisteredModel, +) (*protos.GetRegisteredModel_Response, *contract.Error) { + registeredModel, err := m.store.GetRegisteredModel(ctx, input.GetName()) + if err != nil { + return nil, err + } + + return &protos.GetRegisteredModel_Response{ + RegisteredModel: registeredModel.ToProto(), + }, nil +} diff --git a/pkg/model_registry/store/sql/model_versions.go b/pkg/model_registry/store/sql/model_versions.go index 2863909e..74fff363 100644 --- a/pkg/model_registry/store/sql/model_versions.go +++ b/pkg/model_registry/store/sql/model_versions.go @@ -96,7 +96,7 @@ func (m *ModelRegistrySQLStore) GetLatestVersions( return results, nil } -func (m *ModelRegistrySQLStore) GetRegisteredModelByName( +func (m *ModelRegistrySQLStore) GetRegisteredModel( ctx context.Context, name string, ) (*entities.RegisteredModel, *contract.Error) { var registeredModel models.RegisteredModel @@ -104,6 +104,12 @@ func (m *ModelRegistrySQLStore) GetRegisteredModelByName( ctx, ).Where( "name = ?", name, + ).Preload( + "Tags", + ).Preload( + "Aliases", + ).Preload( + "Versions", ).First( ®isteredModel, ).Error; err != nil { @@ -128,7 +134,7 @@ func (m *ModelRegistrySQLStore) GetRegisteredModelByName( func (m *ModelRegistrySQLStore) UpdateRegisteredModel( ctx context.Context, name, description string, ) (*entities.RegisteredModel, *contract.Error) { - registeredModel, err := m.GetRegisteredModelByName(ctx, name) + registeredModel, err := m.GetRegisteredModel(ctx, name) if err != nil { return nil, err } @@ -151,7 +157,7 @@ func (m *ModelRegistrySQLStore) UpdateRegisteredModel( func (m *ModelRegistrySQLStore) RenameRegisteredModel( ctx context.Context, name, newName string, ) (*entities.RegisteredModel, *contract.Error) { - registeredModel, err := m.GetRegisteredModelByName(ctx, name) + registeredModel, err := m.GetRegisteredModel(ctx, name) if err != nil { return nil, err } @@ -192,7 +198,7 @@ func (m *ModelRegistrySQLStore) RenameRegisteredModel( return nil, contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to rename registered model", err) } - registeredModel, err = m.GetRegisteredModelByName(ctx, newName) + registeredModel, err = m.GetRegisteredModel(ctx, newName) if err != nil { return nil, err } @@ -201,7 +207,7 @@ func (m *ModelRegistrySQLStore) RenameRegisteredModel( } func (m *ModelRegistrySQLStore) DeleteRegisteredModel(ctx context.Context, name string) *contract.Error { - registeredModel, err := m.GetRegisteredModelByName(ctx, name) + registeredModel, err := m.GetRegisteredModel(ctx, name) if err != nil { return err } diff --git a/pkg/model_registry/store/sql/models/model_versions.go b/pkg/model_registry/store/sql/models/model_versions.go index d2b373c8..49c7fda6 100644 --- a/pkg/model_registry/store/sql/models/model_versions.go +++ b/pkg/model_registry/store/sql/models/model_versions.go @@ -1,6 +1,7 @@ package models import ( + "github.com/mlflow/mlflow-go/pkg/entities" "github.com/mlflow/mlflow-go/pkg/protos" "github.com/mlflow/mlflow-go/pkg/utils" ) @@ -47,3 +48,21 @@ func (mv ModelVersion) ToProto() *protos.ModelVersion { RunLink: &mv.RunLink, } } + +func (mv ModelVersion) ToEntity() *entities.ModelVersion { + return &entities.ModelVersion{ + Name: mv.Name, + Version: mv.Version, + CreationTime: mv.CreationTime, + LastUpdatedTime: mv.LastUpdatedTime, + Description: mv.Description, + UserID: mv.UserID, + CurrentStage: mv.CurrentStage.String(), + Source: mv.Source, + RunID: mv.RunID, + Status: mv.Status, + StatusMessage: mv.StatusMessage, + RunLink: mv.RunLink, + StorageLocation: mv.StorageLocation, + } +} diff --git a/pkg/model_registry/store/sql/models/registered_model_aliases.go b/pkg/model_registry/store/sql/models/registered_model_aliases.go index 2cdf25ab..b62d335a 100644 --- a/pkg/model_registry/store/sql/models/registered_model_aliases.go +++ b/pkg/model_registry/store/sql/models/registered_model_aliases.go @@ -1,8 +1,21 @@ package models +import ( + "strconv" + + "github.com/mlflow/mlflow-go/pkg/entities" +) + // RegisteredModelAlias mapped from table . type RegisteredModelAlias struct { + Name string `db:"name" gorm:"column:name;primaryKey"` Alias string `db:"alias" gorm:"column:alias;primaryKey"` Version int32 `db:"version" gorm:"column:version;not null"` - Name string `db:"name" gorm:"column:name;primaryKey"` +} + +func (a RegisteredModelAlias) ToEntity() *entities.RegisteredModelAlias { + return &entities.RegisteredModelAlias{ + Alias: a.Alias, + Version: strconv.Itoa(int(a.Version)), + } } diff --git a/pkg/model_registry/store/sql/models/registered_models.go b/pkg/model_registry/store/sql/models/registered_models.go index 50dd3821..1895f3b8 100644 --- a/pkg/model_registry/store/sql/models/registered_models.go +++ b/pkg/model_registry/store/sql/models/registered_models.go @@ -8,17 +8,21 @@ import ( // RegisteredModel mapped from table . type RegisteredModel struct { - Name string `gorm:"column:name;primaryKey"` - Tags []RegisteredModelTag `gorm:"foreignKey:Name;references:Name"` - Description sql.NullString `gorm:"column:description"` - CreationTime int64 `gorm:"column:creation_time"` - LastUpdatedTime int64 `gorm:"column:last_updated_time"` + Name string `gorm:"column:name;primaryKey"` + Tags []RegisteredModelTag `gorm:"foreignKey:Name;references:Name"` + Aliases []RegisteredModelAlias `gorm:"foreignKey:Name;references:Name"` + Versions []ModelVersion `gorm:"foreignKey:Name;references:Name"` + Description sql.NullString `gorm:"column:description"` + CreationTime int64 `gorm:"column:creation_time"` + LastUpdatedTime int64 `gorm:"column:last_updated_time"` } func (m *RegisteredModel) ToEntity() *entities.RegisteredModel { model := entities.RegisteredModel{ Name: m.Name, Tags: make([]*entities.RegisteredModelTag, 0, len(m.Tags)), + Aliases: make([]*entities.RegisteredModelAlias, 0, len(m.Aliases)), + Versions: make([]*entities.ModelVersion, 0), CreationTime: m.CreationTime, LastUpdatedTime: m.LastUpdatedTime, } @@ -31,5 +35,24 @@ func (m *RegisteredModel) ToEntity() *entities.RegisteredModel { model.Tags = append(model.Tags, tag.ToEntity()) } + for _, alias := range m.Aliases { + model.Aliases = append(model.Aliases, alias.ToEntity()) + } + + latestVersionsByStage := map[string]*ModelVersion{} + + for _, currentVersion := range m.Versions { + stage := currentVersion.CurrentStage.String() + if stage != StageDeletedInternal { + if latestVersion, ok := latestVersionsByStage[stage]; !ok || latestVersion.Version < currentVersion.Version { + latestVersionsByStage[stage] = ¤tVersion + } + } + } + + for _, latestVersion := range latestVersionsByStage { + model.Versions = append(model.Versions, latestVersion.ToEntity()) + } + return &model } diff --git a/pkg/model_registry/store/store.go b/pkg/model_registry/store/store.go index fe19f37b..65655dcb 100644 --- a/pkg/model_registry/store/store.go +++ b/pkg/model_registry/store/store.go @@ -11,6 +11,7 @@ import ( type ModelRegistryStore interface { contract.Destroyer GetLatestVersions(ctx context.Context, name string, stages []string) ([]*protos.ModelVersion, *contract.Error) + GetRegisteredModel(ctx context.Context, name string) (*entities.RegisteredModel, *contract.Error) UpdateRegisteredModel(ctx context.Context, name, description string) (*entities.RegisteredModel, *contract.Error) RenameRegisteredModel(ctx context.Context, name, newName string) (*entities.RegisteredModel, *contract.Error) DeleteRegisteredModel(ctx context.Context, name string) *contract.Error diff --git a/pkg/server/routes/model_registry.g.go b/pkg/server/routes/model_registry.g.go index f4e6e370..1dfb0f47 100644 --- a/pkg/server/routes/model_registry.g.go +++ b/pkg/server/routes/model_registry.g.go @@ -44,6 +44,17 @@ func RegisterModelRegistryServiceRoutes(service service.ModelRegistryService, pa } return ctx.JSON(output) }) + app.Get("/mlflow/registered-models/get", func(ctx *fiber.Ctx) error { + input := &protos.GetRegisteredModel{} + if err := parser.ParseQuery(ctx, input); err != nil { + return err + } + output, err := service.GetRegisteredModel(utils.NewContextWithLoggerFromFiberContext(ctx), input) + if err != nil { + return err + } + return ctx.JSON(output) + }) app.Post("/mlflow/registered-models/get-latest-versions", func(ctx *fiber.Ctx) error { input := &protos.GetLatestVersions{} if err := parser.ParseBody(ctx, input); err != nil {