Skip to content

Commit e161a97

Browse files
committed
Add additional validation checks
Signed-off-by: nojaf <florian.verdonck@outlook.com>
1 parent e059f04 commit e161a97

File tree

3 files changed

+32
-1
lines changed

3 files changed

+32
-1
lines changed

magefiles/generate/validations.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,9 @@ var validations = map[string]string{
3232
"SetExperimentTag_ExperimentId": "required",
3333
"SetExperimentTag_Key": "required,max=250,validMetricParamOrTagName",
3434
"SetExperimentTag_Value": "max=5000",
35-
"LogInputs_RunId": "required",
35+
"LogInputs_RunId": "required,runId",
3636
"LogInputs_Datasets": "required",
37+
"DatasetInput_Dataset": "required",
3738
"Dataset_Name": "required,max=500",
3839
"Dataset_Digest": "required,max=36",
3940
"Dataset_SourceType": "required",

pkg/tracking/service/inputs.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ import (
1111
func (ts TrackingService) LogInputs(
1212
ctx context.Context, input *protos.LogInputs,
1313
) (*protos.LogInputs_Response, *contract.Error) {
14+
if len(input.GetDatasets()) == 0 {
15+
return &protos.LogInputs_Response{}, nil
16+
}
17+
1418
datasets := make([]*entities.DatasetInput, 0, len(input.GetDatasets()))
1519
for _, d := range input.GetDatasets() {
1620
datasets = append(datasets, entities.NewDatasetInputFromProto(d))

pkg/tracking/store/sql/inputs.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,39 @@ package sql
22

33
import (
44
"context"
5+
"errors"
6+
"fmt"
7+
8+
"gorm.io/gorm"
59

610
"github.com/mlflow/mlflow-go/pkg/contract"
711
"github.com/mlflow/mlflow-go/pkg/entities"
12+
"github.com/mlflow/mlflow-go/pkg/protos"
813
)
914

1015
func (store TrackingSQLStore) LogInputs(
1116
ctx context.Context, runID string, datasets []*entities.DatasetInput,
1217
) *contract.Error {
18+
err := store.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error {
19+
contractError := checkRunIsActive(transaction, runID)
20+
if contractError != nil {
21+
return contractError
22+
}
23+
24+
return nil
25+
})
26+
if err != nil {
27+
var contractError *contract.Error
28+
if errors.As(err, &contractError) {
29+
return contractError
30+
}
31+
32+
return contract.NewErrorWith(
33+
protos.ErrorCode_INTERNAL_ERROR,
34+
fmt.Sprintf("log inputs transaction failed for %q", runID),
35+
err,
36+
)
37+
}
38+
1339
return nil
1440
}

0 commit comments

Comments
 (0)