Skip to content

Commit

Permalink
Move GET /mlflow/registered-models/get endpoint. (#92)
Browse files Browse the repository at this point in the history
Signed-off-by: Software Developer <7852635+dsuhinin@users.noreply.github.com>
  • Loading branch information
dsuhinin authored Nov 26, 2024
1 parent 0cf60de commit a2bd841
Show file tree
Hide file tree
Showing 14 changed files with 185 additions and 12 deletions.
2 changes: 1 addition & 1 deletion magefiles/generate/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ var ServiceInfoMap = map[string]ServiceGenerationInfo{
"renameRegisteredModel",
"updateRegisteredModel",
"deleteRegisteredModel",
// "getRegisteredModel",
"getRegisteredModel",
// "searchRegisteredModels",
"getLatestVersions",
// "createModelVersion",
Expand Down
19 changes: 19 additions & 0 deletions mlflow_go/store/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from mlflow.protos.model_registry_pb2 import (
DeleteRegisteredModel,
GetLatestVersions,
GetRegisteredModel,
RenameRegisteredModel,
UpdateRegisteredModel,
)
Expand Down Expand Up @@ -60,6 +61,24 @@ def delete_registered_model(self, name):
request = DeleteRegisteredModel(name=name)
self.service.call_endpoint(get_lib().ModelRegistryServiceDeleteRegisteredModel, request)

def get_registered_model(self, name):
request = GetRegisteredModel(name=name)
response = self.service.call_endpoint(
get_lib().ModelRegistryServiceGetRegisteredModel, request
)

entity = RegisteredModel.from_proto(response.registered_model)
if entity.description == "":
entity.description = None

# during convertion to proto, `version` value became a `string` value.
# convert it back to `int` value again to satisfy all the Python tests and related logic.
for key in entity.aliases:
if entity.aliases[key].isnumeric():
entity.aliases[key] = int(entity.aliases[key])

return entity


def ModelRegistryStore(cls):
return type(cls.__name__, (_ModelRegistryStore, cls), {})
Expand Down
1 change: 1 addition & 0 deletions pkg/contract/service/model_registry.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 31 additions & 0 deletions pkg/entities/model_version.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package entities

import (
"strconv"

"github.com/mlflow/mlflow-go/pkg/protos"
"github.com/mlflow/mlflow-go/pkg/utils"
)

type ModelVersion struct {
Name string
Version int32
CreationTime int64
LastUpdatedTime int64
Description string
UserID string
CurrentStage string
Source string
RunID string
Status string
StatusMessage string
RunLink string
StorageLocation string
}

func (mv ModelVersion) ToProto() *protos.ModelVersion {
return &protos.ModelVersion{
Version: utils.PtrTo(strconv.Itoa(int(mv.Version))),
CurrentStage: utils.PtrTo(mv.CurrentStage),
}
}
10 changes: 10 additions & 0 deletions pkg/entities/registered_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
type RegisteredModel struct {
Name string
Tags []*RegisteredModelTag
Aliases []*RegisteredModelAlias
Versions []*ModelVersion
Description *string
CreationTime int64
LastUpdatedTime int64
Expand All @@ -26,5 +28,13 @@ func (m RegisteredModel) ToProto() *protos.RegisteredModel {
registeredModel.Tags = append(registeredModel.Tags, tag.ToProto())
}

for _, alias := range m.Aliases {
registeredModel.Aliases = append(registeredModel.Aliases, alias.ToProto())
}

for _, version := range m.Versions {
registeredModel.LatestVersions = append(registeredModel.LatestVersions, version.ToProto())
}

return &registeredModel
}
18 changes: 18 additions & 0 deletions pkg/entities/registered_model_alias.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package entities

import (
"github.com/mlflow/mlflow-go/pkg/protos"
"github.com/mlflow/mlflow-go/pkg/utils"
)

type RegisteredModelAlias struct {
Alias string
Version string
}

func (t RegisteredModelAlias) ToProto() *protos.RegisteredModelAlias {
return &protos.RegisteredModelAlias{
Alias: utils.PtrTo(t.Alias),
Version: utils.PtrTo(t.Version),
}
}
8 changes: 8 additions & 0 deletions pkg/lib/model_registry.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions pkg/model_registry/service/model_versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,16 @@ func (m *ModelRegistryService) DeleteRegisteredModel(

return &protos.DeleteRegisteredModel_Response{}, nil
}

func (m *ModelRegistryService) GetRegisteredModel(
ctx context.Context, input *protos.GetRegisteredModel,
) (*protos.GetRegisteredModel_Response, *contract.Error) {
registeredModel, err := m.store.GetRegisteredModel(ctx, input.GetName())
if err != nil {
return nil, err
}

return &protos.GetRegisteredModel_Response{
RegisteredModel: registeredModel.ToProto(),
}, nil
}
16 changes: 11 additions & 5 deletions pkg/model_registry/store/sql/model_versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,20 @@ func (m *ModelRegistrySQLStore) GetLatestVersions(
return results, nil
}

func (m *ModelRegistrySQLStore) GetRegisteredModelByName(
func (m *ModelRegistrySQLStore) GetRegisteredModel(
ctx context.Context, name string,
) (*entities.RegisteredModel, *contract.Error) {
var registeredModel models.RegisteredModel
if err := m.db.WithContext(
ctx,
).Where(
"name = ?", name,
).Preload(
"Tags",
).Preload(
"Aliases",
).Preload(
"Versions",
).First(
&registeredModel,
).Error; err != nil {
Expand All @@ -128,7 +134,7 @@ func (m *ModelRegistrySQLStore) GetRegisteredModelByName(
func (m *ModelRegistrySQLStore) UpdateRegisteredModel(
ctx context.Context, name, description string,
) (*entities.RegisteredModel, *contract.Error) {
registeredModel, err := m.GetRegisteredModelByName(ctx, name)
registeredModel, err := m.GetRegisteredModel(ctx, name)
if err != nil {
return nil, err
}
Expand All @@ -151,7 +157,7 @@ func (m *ModelRegistrySQLStore) UpdateRegisteredModel(
func (m *ModelRegistrySQLStore) RenameRegisteredModel(
ctx context.Context, name, newName string,
) (*entities.RegisteredModel, *contract.Error) {
registeredModel, err := m.GetRegisteredModelByName(ctx, name)
registeredModel, err := m.GetRegisteredModel(ctx, name)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -192,7 +198,7 @@ func (m *ModelRegistrySQLStore) RenameRegisteredModel(
return nil, contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to rename registered model", err)
}

registeredModel, err = m.GetRegisteredModelByName(ctx, newName)
registeredModel, err = m.GetRegisteredModel(ctx, newName)
if err != nil {
return nil, err
}
Expand All @@ -201,7 +207,7 @@ func (m *ModelRegistrySQLStore) RenameRegisteredModel(
}

func (m *ModelRegistrySQLStore) DeleteRegisteredModel(ctx context.Context, name string) *contract.Error {
registeredModel, err := m.GetRegisteredModelByName(ctx, name)
registeredModel, err := m.GetRegisteredModel(ctx, name)
if err != nil {
return err
}
Expand Down
19 changes: 19 additions & 0 deletions pkg/model_registry/store/sql/models/model_versions.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package models

import (
"github.com/mlflow/mlflow-go/pkg/entities"
"github.com/mlflow/mlflow-go/pkg/protos"
"github.com/mlflow/mlflow-go/pkg/utils"
)
Expand Down Expand Up @@ -47,3 +48,21 @@ func (mv ModelVersion) ToProto() *protos.ModelVersion {
RunLink: &mv.RunLink,
}
}

func (mv ModelVersion) ToEntity() *entities.ModelVersion {
return &entities.ModelVersion{
Name: mv.Name,
Version: mv.Version,
CreationTime: mv.CreationTime,
LastUpdatedTime: mv.LastUpdatedTime,
Description: mv.Description,
UserID: mv.UserID,
CurrentStage: mv.CurrentStage.String(),
Source: mv.Source,
RunID: mv.RunID,
Status: mv.Status,
StatusMessage: mv.StatusMessage,
RunLink: mv.RunLink,
StorageLocation: mv.StorageLocation,
}
}
15 changes: 14 additions & 1 deletion pkg/model_registry/store/sql/models/registered_model_aliases.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
package models

import (
"strconv"

"github.com/mlflow/mlflow-go/pkg/entities"
)

// RegisteredModelAlias mapped from table <registered_model_aliases>.
type RegisteredModelAlias struct {
Name string `db:"name" gorm:"column:name;primaryKey"`
Alias string `db:"alias" gorm:"column:alias;primaryKey"`
Version int32 `db:"version" gorm:"column:version;not null"`
Name string `db:"name" gorm:"column:name;primaryKey"`
}

func (a RegisteredModelAlias) ToEntity() *entities.RegisteredModelAlias {
return &entities.RegisteredModelAlias{
Alias: a.Alias,
Version: strconv.Itoa(int(a.Version)),
}
}
33 changes: 28 additions & 5 deletions pkg/model_registry/store/sql/models/registered_models.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,21 @@ import (

// RegisteredModel mapped from table <registered_models>.
type RegisteredModel struct {
Name string `gorm:"column:name;primaryKey"`
Tags []RegisteredModelTag `gorm:"foreignKey:Name;references:Name"`
Description sql.NullString `gorm:"column:description"`
CreationTime int64 `gorm:"column:creation_time"`
LastUpdatedTime int64 `gorm:"column:last_updated_time"`
Name string `gorm:"column:name;primaryKey"`
Tags []RegisteredModelTag `gorm:"foreignKey:Name;references:Name"`
Aliases []RegisteredModelAlias `gorm:"foreignKey:Name;references:Name"`
Versions []ModelVersion `gorm:"foreignKey:Name;references:Name"`
Description sql.NullString `gorm:"column:description"`
CreationTime int64 `gorm:"column:creation_time"`
LastUpdatedTime int64 `gorm:"column:last_updated_time"`
}

func (m *RegisteredModel) ToEntity() *entities.RegisteredModel {
model := entities.RegisteredModel{
Name: m.Name,
Tags: make([]*entities.RegisteredModelTag, 0, len(m.Tags)),
Aliases: make([]*entities.RegisteredModelAlias, 0, len(m.Aliases)),
Versions: make([]*entities.ModelVersion, 0),
CreationTime: m.CreationTime,
LastUpdatedTime: m.LastUpdatedTime,
}
Expand All @@ -31,5 +35,24 @@ func (m *RegisteredModel) ToEntity() *entities.RegisteredModel {
model.Tags = append(model.Tags, tag.ToEntity())
}

for _, alias := range m.Aliases {
model.Aliases = append(model.Aliases, alias.ToEntity())
}

latestVersionsByStage := map[string]*ModelVersion{}

for _, currentVersion := range m.Versions {
stage := currentVersion.CurrentStage.String()
if stage != StageDeletedInternal {
if latestVersion, ok := latestVersionsByStage[stage]; !ok || latestVersion.Version < currentVersion.Version {
latestVersionsByStage[stage] = &currentVersion
}
}
}

for _, latestVersion := range latestVersionsByStage {
model.Versions = append(model.Versions, latestVersion.ToEntity())
}

return &model
}
1 change: 1 addition & 0 deletions pkg/model_registry/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
type ModelRegistryStore interface {
contract.Destroyer
GetLatestVersions(ctx context.Context, name string, stages []string) ([]*protos.ModelVersion, *contract.Error)
GetRegisteredModel(ctx context.Context, name string) (*entities.RegisteredModel, *contract.Error)
UpdateRegisteredModel(ctx context.Context, name, description string) (*entities.RegisteredModel, *contract.Error)
RenameRegisteredModel(ctx context.Context, name, newName string) (*entities.RegisteredModel, *contract.Error)
DeleteRegisteredModel(ctx context.Context, name string) *contract.Error
Expand Down
11 changes: 11 additions & 0 deletions pkg/server/routes/model_registry.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit a2bd841

Please sign in to comment.