diff --git a/magefiles/generate/validations.go b/magefiles/generate/validations.go index a1bd28de..9a5ac7bf 100644 --- a/magefiles/generate/validations.go +++ b/magefiles/generate/validations.go @@ -48,10 +48,15 @@ var validations = map[string]string{ "Dataset_Schema": "max:1048575", "InputTag_Key": "required,max=255", "InputTag_Value": "required,max=500", + "RenameRegisteredModel_Name": "notEmpty,required", + "RenameRegisteredModel_NewName": "notEmpty,required", + "SetRegisteredModelTag_Name": "required", "SetRegisteredModelTag_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique", "SetRegisteredModelTag_Value": "omitempty,max=5000", + "CreateRegisteredModel_Name": "notEmpty,required", "CreateRegisteredModel_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique", "CreateRegisteredModel_Value": "omitempty,max=5000,truncate=5000", + "DeleteRegisteredModelTag_Name": "required", "DeleteRegisteredModelTag_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique", "SetRegisteredModelAlias_Name": "required", "SetRegisteredModelAlias_Alias": "required,max=255,validMetricParamOrTagName,pathIsUnique", diff --git a/pkg/model_registry/service/registered_models.go b/pkg/model_registry/service/registered_models.go index 186eaf19..bfdb87b0 100644 --- a/pkg/model_registry/service/registered_models.go +++ b/pkg/model_registry/service/registered_models.go @@ -31,15 +31,7 @@ func (m *ModelRegistryService) UpdateRegisteredModel( func (m *ModelRegistryService) RenameRegisteredModel( ctx context.Context, input *protos.RenameRegisteredModel, ) (*protos.RenameRegisteredModel_Response, *contract.Error) { - newName := input.GetNewName() - if newName == "" { - return nil, contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - "Registered model name cannot be empty", - ) - } - - registeredModel, err := m.store.RenameRegisteredModel(ctx, input.GetName(), newName) + registeredModel, err := m.store.RenameRegisteredModel(ctx, input.GetName(), input.GetNewName()) if err != nil { return nil, err } @@ -75,15 +67,7 @@ func (m *ModelRegistryService) GetRegisteredModel( func (m *ModelRegistryService) SetRegisteredModelTag( ctx context.Context, input *protos.SetRegisteredModelTag, ) (*protos.SetRegisteredModelTag_Response, *contract.Error) { - name := input.GetName() - if name == "" { - return nil, contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - "Registered model name cannot be empty", - ) - } - - if err := m.store.SetRegisteredModelTag(ctx, name, input.GetKey(), input.GetValue()); err != nil { + if err := m.store.SetRegisteredModelTag(ctx, input.GetName(), input.GetKey(), input.GetValue()); err != nil { return nil, err } @@ -93,14 +77,6 @@ func (m *ModelRegistryService) SetRegisteredModelTag( func (m *ModelRegistryService) CreateRegisteredModel( ctx context.Context, input *protos.CreateRegisteredModel, ) (*protos.CreateRegisteredModel_Response, *contract.Error) { - name := input.GetName() - if name == "" { - return nil, contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - "Registered model name cannot be empty.", - ) - } - tags := make([]*entities.RegisteredModelTag, 0, len(input.GetTags())) for _, tag := range input.GetTags() { tags = append(tags, entities.NewRegisteredModelTagFromProto(tag)) @@ -119,15 +95,7 @@ func (m *ModelRegistryService) CreateRegisteredModel( func (m *ModelRegistryService) DeleteRegisteredModelTag( ctx context.Context, input *protos.DeleteRegisteredModelTag, ) (*protos.DeleteRegisteredModelTag_Response, *contract.Error) { - name := input.GetName() - if name == "" { - return nil, contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - "Registered model name cannot be empty", - ) - } - - if err := m.store.DeleteRegisteredModelTag(ctx, name, input.GetKey()); err != nil { + if err := m.store.DeleteRegisteredModelTag(ctx, input.GetName(), input.GetKey()); err != nil { return nil, err } diff --git a/pkg/protos/model_registry.pb.go b/pkg/protos/model_registry.pb.go index cacc5d3f..bebc6072 100644 --- a/pkg/protos/model_registry.pb.go +++ b/pkg/protos/model_registry.pb.go @@ -365,7 +365,7 @@ type CreateRegisteredModel struct { unknownFields protoimpl.UnknownFields // Register models under this name - Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty" query:"name" params:"name"` + Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty" query:"name" params:"name" validate:"notEmpty,required"` // Additional metadata for registered model. Tags []*RegisteredModelTag `protobuf:"bytes,2,rep,name=tags" json:"tags,omitempty" query:"tags" params:"tags"` // Optional description for registered model. @@ -429,9 +429,9 @@ type RenameRegisteredModel struct { unknownFields protoimpl.UnknownFields // Registered model unique name identifier. - Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty" query:"name" params:"name"` + Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty" query:"name" params:"name" validate:"notEmpty,required"` // If provided, updates the name for this “registered_model“. - NewName *string `protobuf:"bytes,2,opt,name=new_name,json=newName" json:"new_name,omitempty" query:"new_name" params:"new_name"` + NewName *string `protobuf:"bytes,2,opt,name=new_name,json=newName" json:"new_name,omitempty" query:"new_name" params:"new_name" validate:"notEmpty,required"` } func (x *RenameRegisteredModel) Reset() { @@ -1361,7 +1361,7 @@ type SetRegisteredModelTag struct { unknownFields protoimpl.UnknownFields // Unique name of the model. - Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty" query:"name" params:"name"` + Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty" query:"name" params:"name" validate:"required"` // Name of the tag. Maximum size depends on storage backend. // If a tag with this name already exists, its preexisting value will be replaced by the specified `value`. // All storage backends are guaranteed to support key values up to 250 bytes in size. @@ -1502,7 +1502,7 @@ type DeleteRegisteredModelTag struct { unknownFields protoimpl.UnknownFields // Name of the registered model that the tag was logged under. - Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty" query:"name" params:"name"` + Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty" query:"name" params:"name" validate:"required"` // Name of the tag. The name must be an exact match; wild-card deletion is not supported. Maximum size is 250 bytes. Key *string `protobuf:"bytes,2,opt,name=key" json:"key,omitempty" query:"key" params:"key" validate:"required,max=250,validMetricParamOrTagName,pathIsUnique"` } diff --git a/pkg/validation/validation.go b/pkg/validation/validation.go index d1611634..dd08038c 100644 --- a/pkg/validation/validation.go +++ b/pkg/validation/validation.go @@ -87,6 +87,10 @@ func pathIsClean(fl validator.FieldLevel) bool { return !(norm != valueStr || norm == "." || strings.HasPrefix(norm, "..") || strings.HasPrefix(norm, "/")) } +func notEmptyValidation(fl validator.FieldLevel) bool { + return fl.Field().String() != "" +} + func regexValidation(regex *regexp.Regexp) validator.Func { return func(fl validator.FieldLevel) bool { valueStr := fl.Field().String() @@ -146,6 +150,7 @@ func truncateFn(fieldLevel validator.FieldLevel) bool { return true } +//nolint:cyclop func NewValidator() (*validator.Validate, error) { validate := validator.New() @@ -199,6 +204,10 @@ func NewValidator() (*validator.Validate, error) { return nil, fmt.Errorf("validation registration for 'positiveNonZeroInteger' failed: %w", err) } + if err := validate.RegisterValidation("notEmpty", notEmptyValidation); err != nil { + return nil, fmt.Errorf("validation registration for 'notEmpty' failed: %w", err) + } + validate.RegisterStructValidation(validateLogBatchLimits, &protos.LogBatch{}) validate.RegisterStructValidation(validateSetTagRunIDExists, &protos.SetTag{}) @@ -299,10 +308,10 @@ func NewErrorFromValidationError(err error) *contract.Error { value := dereference(err.Value()) switch tag { - case "required": + case "notEmpty", "required": validationErrors = append( validationErrors, - fmt.Sprintf("Missing value for required parameter '%s'", field), + fmt.Sprintf("Missing value for required parameter '%s'.", field), ) case "truncate": validationErrors = append(validationErrors, mkTruncateValidationError(field, value, err))