Skip to content

Commit

Permalink
Move /mlflow/traces/{request_id}/tags endpoint. (#11)
Browse files Browse the repository at this point in the history
Signed-off-by: Software Developer <7852635+dsuhinin@users.noreply.github.com>
  • Loading branch information
dsuhinin authored Oct 14, 2024
1 parent 8f6ef94 commit 2f25223
Show file tree
Hide file tree
Showing 18 changed files with 165 additions and 19 deletions.
2 changes: 1 addition & 1 deletion magefiles/generate/discovery/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ var routeParameterRegex = regexp.MustCompile(`<[^>]+:([^>]+)>`)
// Get the safe path to use in Fiber registration.
func (e Endpoint) GetFiberPath() string {
// e.Path cannot be trusted, it could be something like /mlflow-artifacts/artifacts/<path:artifact_path>
// Which would need to converted to /mlflow-artifacts/artifacts/:path
// which would need to be converted to /mlflow-artifacts/artifacts/:path
path := routeParameterRegex.ReplaceAllStringFunc(e.Path, func(s string) string {
parts := strings.Split(s, ":")

Expand Down
3 changes: 2 additions & 1 deletion magefiles/generate/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ var ServiceInfoMap = map[string]ServiceGenerationInfo{
"logParam",
"setExperimentTag",
// "setTag",
// "setTraceTag",
"setTraceTag",
// "deleteTraceTag",
"deleteTraceTag",
"deleteTag",
"searchRuns",
Expand Down
2 changes: 2 additions & 0 deletions magefiles/generate/validations.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ var validations = map[string]string{
"LogMetric_Key": "required",
"LogMetric_Value": "required",
"LogMetric_Timestamp": "required",
"SetTraceTag_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique",
"SetTraceTag_Value": "omitempty,truncate=8000",
"DeleteTag_RunId": "required",
"DeleteTag_Key": "required",
"SetExperimentTag_ExperimentId": "required",
Expand Down
17 changes: 15 additions & 2 deletions mlflow_go/store/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
RunInfo,
ViewType,
)
from mlflow.environment_variables import MLFLOW_TRUNCATE_LONG_VALUES
from mlflow.exceptions import MlflowException
from mlflow.protos import databricks_pb2
from mlflow.protos.service_pb2 import (
Expand All @@ -25,6 +26,7 @@
RestoreExperiment,
RestoreRun,
SearchRuns,
SetTraceTag,
UpdateExperiment,
UpdateRun,
)
Expand All @@ -47,9 +49,12 @@ def __init__(self, *args, **kwargs):
)
config = json.dumps(
{
"default_artifact_root": resolve_uri_if_local(default_artifact_root),
"tracking_store_uri": store_uri,
"log_level": logging.getLevelName(_logger.getEffectiveLevel()),
"python_tests_env": {
"MLFLOW_TRUNCATE_LONG_VALUES": MLFLOW_TRUNCATE_LONG_VALUES.get()
},
"tracking_store_uri": store_uri,
"default_artifact_root": resolve_uri_if_local(default_artifact_root),
}
).encode("utf-8")
self.service = _ServiceProxy(get_lib().CreateTrackingService(config, len(config)))
Expand Down Expand Up @@ -176,6 +181,14 @@ def log_param(self, run_id, param):
)
self.service.call_endpoint(get_lib().TrackingServiceLogParam, request)

def set_trace_tag(self, request_id: str, key: str, value: str):
request = SetTraceTag(
key=key,
value=value,
request_id=request_id,
)
self.service.call_endpoint(get_lib().TrackingServiceSetTraceTag, request)

def delete_tag(self, run_id, key):
request = DeleteTag(run_id=run_id, key=key)
self.service.call_endpoint(get_lib().TrackingServiceDeleteTag, request)
Expand Down
23 changes: 12 additions & 11 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,18 @@ func (d *Duration) UnmarshalJSON(b []byte) error {
}

type Config struct {
Address string `json:"address"`
DefaultArtifactRoot string `json:"default_artifact_root"`
LogLevel string `json:"log_level"`
ModelRegistryStoreURI string `json:"model_registry_store_uri"`
PythonAddress string `json:"python_address"`
PythonCommand []string `json:"python_command"`
PythonEnv []string `json:"python_env"`
ShutdownTimeout Duration `json:"shutdown_timeout"`
StaticFolder string `json:"static_folder"`
TrackingStoreURI string `json:"tracking_store_uri"`
Version string `json:"version"`
Address string `json:"address"`
DefaultArtifactRoot string `json:"default_artifact_root"`
LogLevel string `json:"log_level"`
ModelRegistryStoreURI string `json:"model_registry_store_uri"`
PythonEnv []string `json:"python_env"`
PythonAddress string `json:"python_address"`
PythonCommand []string `json:"python_command"`
PythonTestsENV map[string]interface{} `json:"python_tests_env"`
ShutdownTimeout Duration `json:"shutdown_timeout"`
StaticFolder string `json:"static_folder"`
TrackingStoreURI string `json:"tracking_store_uri"`
Version string `json:"version"`
}

func NewConfigFromBytes(cfgBytes []byte) (*Config, error) {
Expand Down
1 change: 1 addition & 0 deletions pkg/contract/service/tracking.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions pkg/entities/trace_tag.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
package entities

import "github.com/mlflow/mlflow-go/pkg/protos"

type TraceTag struct {
Key string
Value string
RequestID string
}

func NewTraceTagFromProto(proto *protos.SetTraceTag) *TraceTag {
return &TraceTag{
Key: proto.GetKey(),
Value: proto.GetValue(),
RequestID: proto.GetRequestId(),
}
}
10 changes: 10 additions & 0 deletions pkg/lib/instance_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package main

import (
"context"
"fmt"
"os"
"sync"

"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -48,6 +50,14 @@ func (s *instanceMap[T]) Create(
return -1
}

for key, value := range cfg.PythonTestsENV {
if err := os.Setenv(key, fmt.Sprintf("%v", value)); err != nil {
logrus.Error("Failed to set env: ", err)

return -1
}
}

logger := utils.NewLoggerFromConfig(cfg)

logger.Debugf("Loaded config: %#v", cfg)
Expand Down
8 changes: 8 additions & 0 deletions pkg/lib/tracking.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion pkg/lib/tracking.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ var trackingServices = newInstanceMap[*service.TrackingService]()
//export CreateTrackingService
func CreateTrackingService(configData unsafe.Pointer, configSize C.int) int64 {
//nolint:nlreturn
return trackingServices.Create(service.NewTrackingService, C.GoBytes(configData, configSize))
return trackingServices.Create(
service.NewTrackingService, C.GoBytes(configData, configSize),
)
}

//export DestroyTrackingService
Expand Down
4 changes: 2 additions & 2 deletions pkg/protos/service.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions pkg/server/routes/tracking.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions pkg/tracking/service/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,18 @@ import (
"github.com/mlflow/mlflow-go/pkg/protos"
)

func (ts TrackingService) SetTraceTag(
ctx context.Context, input *protos.SetTraceTag,
) (*protos.SetTraceTag_Response, *contract.Error) {
if err := ts.Store.SetTraceTag(
ctx, input.GetRequestId(), input.GetKey(), input.GetValue(),
); err != nil {
return nil, contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to create trace_tag", err)
}

return &protos.SetTraceTag_Response{}, nil
}

func (ts TrackingService) DeleteTraceTag(
ctx context.Context, input *protos.DeleteTraceTag,
) (*protos.DeleteTraceTag_Response, *contract.Error) {
Expand Down
49 changes: 49 additions & 0 deletions pkg/tracking/store/mock_tracking_store.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions pkg/tracking/store/sql/models/trace_tag.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,11 @@ func (t TraceTag) ToEntity() *entities.TraceTag {
RequestID: t.RequestID,
}
}

func NewTraceTagFromEntity(entity *entities.TraceTag) TraceTag {
return TraceTag{
Key: entity.Key,
Value: entity.Value,
RequestID: entity.RequestID,
}
}
18 changes: 18 additions & 0 deletions pkg/tracking/store/sql/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,31 @@ import (
"fmt"

"gorm.io/gorm"
"gorm.io/gorm/clause"

"github.com/mlflow/mlflow-go/pkg/contract"
"github.com/mlflow/mlflow-go/pkg/entities"
"github.com/mlflow/mlflow-go/pkg/protos"
"github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models"
)

func (s TrackingSQLStore) SetTraceTag(
ctx context.Context, requestID, key, value string,
) error {
if err := s.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}, {Name: "request_id"}},
DoUpdates: clause.AssignmentColumns([]string{"value"}),
}).Create(models.TraceTag{
Key: key,
Value: value,
RequestID: requestID,
}).Error; err != nil {
return err
}

return nil
}

func (s TrackingSQLStore) GetTraceTag(
ctx context.Context, requestID, key string,
) (*entities.TraceTag, *contract.Error) {
Expand Down
1 change: 1 addition & 0 deletions pkg/tracking/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type (
DeleteTag(ctx context.Context, runID, key string) *contract.Error
}
TraceTrackingStore interface {
SetTraceTag(ctx context.Context, requestID, key, value string) error
GetTraceTag(ctx context.Context, requestID, key string) (*entities.TraceTag, *contract.Error)
DeleteTraceTag(ctx context.Context, tag *entities.TraceTag) *contract.Error
}
Expand Down
1 change: 0 additions & 1 deletion pkg/validation/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ func truncateFn(fieldLevel validator.FieldLevel) bool {

truncateLongValues, shouldTruncate := os.LookupEnv("MLFLOW_TRUNCATE_LONG_VALUES")
shouldTruncate = shouldTruncate && truncateLongValues == "true"

field := fieldLevel.Field()

if field.Kind() == reflect.String {
Expand Down

0 comments on commit 2f25223

Please sign in to comment.