Skip to content

Commit

Permalink
Truncate in validation
Browse files Browse the repository at this point in the history
jgiannuzzi/mlflow-go#24

Signed-off-by: nojaf <florian.verdonck@outlook.com>
  • Loading branch information
nojaf authored and jgiannuzzi committed Oct 1, 2024
1 parent acee863 commit 67c2b94
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 83 deletions.
2 changes: 1 addition & 1 deletion .mlflow.ref
Original file line number Diff line number Diff line change
@@ -1 +1 @@
https://github.com/nojaf/mlflow.git#tweak_validate_max_results_param-error
https://github.com/mlflow/mlflow.git#master
18 changes: 8 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -276,30 +276,28 @@ pre-commit run golangci-lint --all-files
The following Python tests are currently failing:

```
======================================================================= short test summary info ========================================================================
============================================================================================================= short test summary info ==============================================================================================================
FAILED .mlflow.repo/tests/tracking/test_rest_tracking.py::test_log_metrics_params_tags[sqlalchemy] - mlflow.exceptions.RestException: INVALID_PARAMETER_VALUE: Invalid value "NaN" for parameter 'value' supplied
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_delete_restore_experiment_with_runs - mlflow.exceptions.MlflowException: assert 1725699783457 is None
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_delete_restore_experiment_with_runs - mlflow.exceptions.MlflowException: assert 1725952832087 is None
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_create_experiments - Failed: DID NOT RAISE <class 'mlflow.exceptions.MlflowException'>
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_log_metric_concurrent_logging_succeeds - mlflow.exceptions.MlflowException: error creating metrics in batch for run_uuid "36d0d0457d35466c9eb287119d037098"
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_log_null_metric - AssertionError: Regex pattern did not match.
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_delete_restore_run - mlflow.exceptions.MlflowException: assert 1725699804046 is None
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_log_metric_concurrent_logging_succeeds - mlflow.exceptions.MlflowException: error creating metrics in batch for run_uuid "2fed3348475545e9bfb391eb784910f0"
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_delete_restore_run - mlflow.exceptions.MlflowException: assert 1725952839377 is None
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_error_logging_to_deleted_run - Failed: DID NOT RAISE <class 'mlflow.exceptions.MlflowException'>
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_order_by_metric_tag_param - mlflow.exceptions.MlflowException: error getting runs: [INTERNAL_ERROR] Failed to query search runs: no such column: x
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_order_by_attributes - AssertionError: assert ['-123', 'Non... '456', '789'] == ['-123', '123...'789', 'None']
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_search_attrs - Failed: DID NOT RAISE <class 'mlflow.exceptions.MlflowException'>
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_search_runs_pagination - AssertionError: assert '' is None
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_search_runs_datasets - AssertionError: assert {'22cd946d5d1...fc87c5bd5a3e'} == {'2f2080ed1fa...fc87c5bd5a3e'}
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_search_runs_datasets - AssertionError: assert {'3ef36935def...ceb4c329e6f0'} == {'4de45f332f7...ceb4c329e6f0'}
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_log_batch_param_overwrite_disallowed_single_req - AssertionError: Regex pattern did not match.
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_log_batch_internal_error - Failed: DID NOT RAISE <class 'mlflow.exceptions.MlflowException'>
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_log_batch_nonexistent_run - AssertionError: Regex pattern did not match.
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_log_batch_null_metrics - TypeError: must be real number, not NoneType
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_log_batch_params_max_length_value - AssertionError: Regex pattern did not match.
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_log_batch_params_max_length_value - mlflow.exceptions.MlflowException: Invalid value "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx..." for parameter 'params[0].value' supplied: length 6001 exceeded length limit of 6000
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_log_inputs_with_large_inputs_limit_check - AssertionError: assert {'digest': 'd...ema': '', ...} == {'digest': 'd...a': None, ...}
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_sqlalchemy_store_behaves_as_expected_with_inmemory_sqlite_db - mlflow.exceptions.MlflowException: failed to create experiment
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_create_experiment_appends_to_artifact_local_path_file_uri_correctly[#path/to/local/folder?-{cwd}/#path/to/local/folder?/{e}] - AssertionError: assert '/workspaces/...local/folder?' == '/workspaces/...cal/folder?/1'
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_create_run_appends_to_artifact_local_path_file_uri_correctly[#path/to/local/folder?-{cwd}/#path/to/local/folder?/{e}/{r}/artifacts] - AssertionError: assert '/workspaces/...local/folder?' == '/workspaces/...6e2/artifacts'
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_create_run_appends_to_artifact_local_path_file_uri_correctly[#path/to/local/folder?-{cwd}/#path/to/local/folder?/{e}/{r}/artifacts] - AssertionError: assert '/workspaces/...local/folder?' == '/workspaces/...98c/artifacts'
FAILED .mlflow.repo/tests/store/model_registry/test_sqlalchemy_store.py::test_get_latest_versions - AssertionError: assert {'None': '1',...Staging': '4'} == {'None': 1, '... 'Staging': 4}
========================================== 22 failed, 337 passed, 9 skipped, 128 deselected, 10 warnings in 510.33s (0:08:30) ==========================================
================================================================================ 20 failed, 339 passed, 9 skipped, 128 deselected, 10 warnings in 226.63s (0:03:46) ================================================================================
```

## Debug failing tests
Expand Down
6 changes: 5 additions & 1 deletion magefiles/dev.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ import (
func Dev() error {
mg.Deps(Generate)

return sh.RunV(
envs := make(map[string]string)
envs["MLFLOW_TRUNCATE_LONG_VALUES"] = "false"

return sh.RunWithV(
envs,
"mlflow-go",
"server",
"--backend-store-uri",
Expand Down
6 changes: 3 additions & 3 deletions magefiles/generate/validations.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ var validations = map[string]string{
"SearchRuns_MaxResults": "gt=0,max=50000",
"DeleteExperiment_ExperimentId": "required,stringAsPositiveInteger",
"LogBatch_RunId": "required,runId",
"LogBatch_Params": "omitempty,uniqueParams,max=100",
"LogBatch_Metrics": "max=1000,dip",
"LogBatch_Params": "omitempty,uniqueParams,max=100,dive",
"LogBatch_Metrics": "max=1000,dive",
"LogBatch_Tags": "max=100",
"RunTag_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique",
"RunTag_Value": "omitempty,max=5000",
"Param_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique",
"Param_Value": "omitempty,max=6000",
"Param_Value": "omitempty,truncate=6000",
"Metric_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique",
"Metric_Timestamp": "required",
"Metric_Value": "required",
Expand Down
6 changes: 3 additions & 3 deletions pkg/protos/service.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

139 changes: 86 additions & 53 deletions pkg/validation/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net/url"
"os"
"path/filepath"
"reflect"
"regexp"
Expand All @@ -18,8 +19,9 @@ import (
)

const (
QuoteLength = 2
MaxEntitiesPerBatch = 1000
QuoteLength = 2
MaxEntitiesPerBatch = 1000
MaxValidationInputLength = 100
)

// regex for valid param and metric names: may only contain slashes, alphanumerics,
Expand All @@ -29,39 +31,6 @@ var paramAndMetricNameRegex = regexp.MustCompile(`^[/\w.\- ]*$`)
// regex for valid run IDs: must be an alphanumeric string of length 1 to 256.
var runIDRegex = regexp.MustCompile(`^[a-zA-Z0-9][\w\-]{0,255}$`)

func getValue(x reflect.Value) reflect.Value {
if x.Kind() == reflect.Pointer {
return x.Elem()
}

return x
}

func validateNested(validate *validator.Validate, current reflect.Value) bool {
val := getValue(current)

//nolint:exhaustive
switch val.Kind() {
case reflect.Slice, reflect.Array:
for i := range val.Len() {
if !validateNested(validate, val.Index(i)) {
return false
}
}

case reflect.Struct:
if err := validate.Struct(val.Interface()); err != nil {
return false
}
default:
if err := validate.Var(val.Interface(), ""); err != nil {
return false
}
}

return true
}

func stringAsPositiveIntegerValidation(fl validator.FieldLevel) bool {
valueStr := fl.Field().String()

Expand Down Expand Up @@ -138,6 +107,37 @@ func validateLogBatchLimits(structLevel validator.StructLevel) {
}
}

func truncateFn(fieldLevel validator.FieldLevel) bool {
param := fieldLevel.Param() // Get the parameter from the tag

maxLength, err := strconv.Atoi(param)
if err != nil {
return false // If the parameter isn't a valid integer, fail the validation.
}

truncateLongValues, shouldTruncate := os.LookupEnv("MLFLOW_TRUNCATE_LONG_VALUES")
shouldTruncate = shouldTruncate && truncateLongValues == "true"

field := fieldLevel.Field()

if field.Kind() == reflect.String {
strValue := field.String()
if len(strValue) <= maxLength {
return true
}

if shouldTruncate {
field.SetString(strValue[:maxLength])

return true
}

return false
}

return true
}

func NewValidator() (*validator.Validate, error) {
validate := validator.New()

Expand All @@ -151,18 +151,6 @@ func NewValidator() (*validator.Validate, error) {
return name
})

// Validate nested content of a struct field while reporting a problem on the current level.
if err := validate.RegisterValidation(
"dip",
func(fl validator.FieldLevel) bool {
val := fl.Field()

return validateNested(validate, val)
},
); err != nil {
return nil, fmt.Errorf("validation registration for 'dip' failed: %w", err)
}

// Verify that the input string is a positive integer.
if err := validate.RegisterValidation(
"stringAsPositiveInteger", stringAsPositiveIntegerValidation,
Expand Down Expand Up @@ -195,6 +183,10 @@ func NewValidator() (*validator.Validate, error) {
return nil, fmt.Errorf("validation registration for 'runId' failed: %w", err)
}

if err := validate.RegisterValidation("truncate", truncateFn); err != nil {
return nil, fmt.Errorf("validation registration for 'truncateFn' failed: %w", err)
}

validate.RegisterStructValidation(validateLogBatchLimits, &protos.LogBatch{})

return validate, nil
Expand All @@ -213,13 +205,36 @@ func dereference(value interface{}) interface{} {
return value
}

func getErrorPath(err validator.FieldError) string {
path := err.Field()

if err.Namespace() != "" {
// Strip first item in struct namespace
idx := strings.Index(err.Namespace(), ".")
if idx != -1 {
path = err.Namespace()[(idx + 1):]
}
}

return path
}

func constructValidationError(field string, value any, suffix string) string {
formattedValue, err := json.Marshal(value)
if err != nil {
formattedValue = []byte(fmt.Sprintf("%v", value))
}

return fmt.Sprintf("Invalid value %s for parameter '%s' supplied%s", formattedValue, field, suffix)
}

func NewErrorFromValidationError(err error) *contract.Error {
var ve validator.ValidationErrors
if errors.As(err, &ve) {
validationErrors := make([]string, 0)

for _, err := range ve {
field := err.Field()
field := getErrorPath(err)
tag := err.Tag()
value := dereference(err.Value())

Expand All @@ -229,15 +244,33 @@ func NewErrorFromValidationError(err error) *contract.Error {
validationErrors,
fmt.Sprintf("Missing value for required parameter '%s'", field),
)
default:
formattedValue, err := json.Marshal(value)
if err != nil {
formattedValue = []byte(fmt.Sprintf("%v", value))
case "truncate":
strValue, ok := value.(string)
if ok {
expected := len(strValue)

if expected > MaxValidationInputLength {
strValue = strValue[:MaxValidationInputLength] + "..."
}

validationErrors = append(
validationErrors,
constructValidationError(
field,
strValue,
fmt.Sprintf(": length %d exceeded length limit of %s", expected, err.Param())),
)
} else {
validationErrors = append(
validationErrors,
constructValidationError(field, value, ""),
)
}

default:
validationErrors = append(
validationErrors,
fmt.Sprintf("Invalid value %s for parameter '%s' supplied", formattedValue, field),
constructValidationError(field, value, ""),
)
}
}
Expand Down
77 changes: 65 additions & 12 deletions pkg/validation/validation_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
package validation_test

import (
"errors"
"strings"
"testing"

"github.com/go-playground/validator/v10"
"github.com/stretchr/testify/require"

"github.com/mlflow/mlflow-go/pkg/protos"
Expand Down Expand Up @@ -172,20 +171,74 @@ func TestMissingTimestampInNestedMetric(t *testing.T) {

err = serverValidator.Struct(&logBatch)
if err == nil {
t.Error("Expected dip validation error, got none")
t.Error("Expected dive validation error, got none")
}

var validationErrors validator.ValidationErrors
if errors.As(err, &validationErrors) {
if len(validationErrors) != 1 {
t.Errorf("Expected 1 validation error, got %v", len(validationErrors))
msg := validation.NewErrorFromValidationError(err).Message
if !strings.Contains(msg, "metrics[0].timestamp") {
t.Errorf("Expected required validation error for nested property, got %v", msg)
}
}

type avecTruncate struct {
X *string `validate:"truncate=3"`
Y string `validate:"truncate=3"`
}

func TestTruncate(t *testing.T) {
input := &avecTruncate{
X: utils.PtrTo("123456"),
Y: "654321",
}

t.Setenv("MLFLOW_TRUNCATE_LONG_VALUES", "true")

validator, err := validation.NewValidator()
require.NoError(t, err)

err = validator.Struct(input)
require.NoError(t, err)

if len(*input.X) != 3 {
t.Errorf("Expected the length of x to be 3, was %d", len(*input.X))
}

if len(input.Y) != 3 {
t.Errorf("Expected the length of y to be 3, was %d", len(input.Y))
}
}

// This unit test is a sanity test that confirms the `dive` validation
// enters a nested slice of pointer structs.
func TestNestedErrorsInSubCollection(t *testing.T) {
t.Parallel()

value := strings.Repeat("X", 6001) + "Y"

logBatchRequest := &protos.LogBatch{
RunId: utils.PtrTo("odcppTsGTMkHeDcqfZOYDMZSf"),
Params: []*protos.Param{
{Key: utils.PtrTo("key1"), Value: utils.PtrTo(value)},
{Key: utils.PtrTo("key2"), Value: utils.PtrTo(value)},
},
}

validator, err := validation.NewValidator()
require.NoError(t, err)

err = validator.Struct(logBatchRequest)
if err != nil {
msg := validation.NewErrorFromValidationError(err).Message
// Assert the root struct name is not present in the error message
if strings.Contains(msg, "logBatch") {
t.Errorf("Validation message contained root struct name, got %s", msg)
}

validationError := validationErrors[0]
if validationError.Tag() != "dip" {
t.Errorf("Expected dip validation error, got %v", validationError.Tag())
// Assert the index is listed in the parameter path
if !strings.Contains(msg, "params[0].value") ||
!strings.Contains(msg, "params[1].value") ||
!strings.Contains(msg, "length 6002 exceeded length limit of 6000") {
t.Errorf("Unexpected validation error message, got %s", msg)
}
} else {
t.Error("Expected validation error, got none")
}
}

0 comments on commit 67c2b94

Please sign in to comment.