diff --git a/mlem/core/metadata.py b/mlem/core/metadata.py index c3490bef..78d1a9a6 100644 --- a/mlem/core/metadata.py +++ b/mlem/core/metadata.py @@ -36,7 +36,7 @@ def get_object_metadata( obj: Any, sample_data=None, - params: Dict[str, str] = None, + params: Dict[str, Any] = None, preprocess: Union[Any, Dict[str, Any]] = None, postprocess: Union[Any, Dict[str, Any]] = None, ) -> Union[MlemData, MlemModel]: @@ -97,7 +97,7 @@ def save( project: Optional[str] = None, sample_data=None, fs: Optional[AbstractFileSystem] = None, - params: Dict[str, str] = None, + params: Dict[str, Any] = None, preprocess: Union[Any, Dict[str, Any]] = None, postprocess: Union[Any, Dict[str, Any]] = None, ) -> MlemObject: diff --git a/mlem/core/objects.py b/mlem/core/objects.py index bffbec37..e39a0d7a 100644 --- a/mlem/core/objects.py +++ b/mlem/core/objects.py @@ -5,6 +5,7 @@ import dataclasses import hashlib import itertools +import json import os import posixpath import time @@ -50,6 +51,7 @@ MlemError, MlemObjectNotFound, MlemObjectNotSavedError, + SerializationError, WrongABCType, WrongMetaSubType, WrongMetaType, @@ -91,9 +93,19 @@ class Config: object_type: ClassVar[str] location: Optional[Location] = None """MlemObject location [transient]""" - params: Dict[str, str] = {} + params: Dict[str, Any] = {} """Arbitrary map of additional parameters""" + @validator("params") + def params_are_serializable( # pylint: disable=no-self-argument + cls, value # noqa: B902 + ): + try: + json.dumps(value) + except TypeError as e: + raise SerializationError(f"Can't serialize object: {value}") from e + return value + @property def loc(self) -> Location: if self.location is None: @@ -751,7 +763,7 @@ def from_obj( model: Any, sample_data: Any = None, methods_sample_data: Dict[str, Any] = None, - params: Dict[str, str] = None, + params: Dict[str, Any] = None, preprocess: Union[Any, Dict[str, Any]] = None, postprocess: Union[Any, Dict[str, Any]] = None, ) -> "MlemModel": @@ -931,7 +943,7 @@ def data(self): def from_data( cls, data: Any, - params: Dict[str, str] = None, + params: Dict[str, Any] = None, ) -> "MlemData": data_type = DataType.create( data, diff --git a/mlem/runtime/interface.py b/mlem/runtime/interface.py index 8db9c797..1c47bcae 100644 --- a/mlem/runtime/interface.py +++ b/mlem/runtime/interface.py @@ -99,6 +99,7 @@ class VersionedInterfaceDescriptor(BaseModel): methods: InterfaceDescriptor version: str = mlem.version.__version__ """mlem version""" + meta: Any class Interface(ABC, MlemABC): @@ -201,9 +202,14 @@ def get_descriptor(self) -> InterfaceDescriptor: } ) + def get_model_meta(self): + return None + def get_versioned_descriptor(self) -> VersionedInterfaceDescriptor: return VersionedInterfaceDescriptor( - version=mlem.__version__, methods=self.get_descriptor() + version=mlem.__version__, + methods=self.get_descriptor(), + meta=self.get_model_meta(), ) @@ -267,7 +273,7 @@ def _check_no_signature(data): class ModelInterface(Interface): - """Interface that descibes model methods""" + """Interface that describes model methods""" type: ClassVar[str] = "model" model: MlemModel @@ -352,3 +358,6 @@ def get_method_args( a.name: a.type_ for a in self.model.model_type.methods[method_name].args } + + def get_model_meta(self): + return self.model.params diff --git a/mlem/runtime/server.py b/mlem/runtime/server.py index fcaca84a..97f1265e 100644 --- a/mlem/runtime/server.py +++ b/mlem/runtime/server.py @@ -335,3 +335,6 @@ def get_method_signature(self, method_name: str) -> InterfaceMethod: ], returns=self._get_response(method_name, signature.returns), ) + + def get_model_meta(self): + return getattr(getattr(self.interface, "model", None), "params", None) diff --git a/tests/contrib/test_fastapi.py b/tests/contrib/test_fastapi.py index 202411c5..ed8978db 100644 --- a/tests/contrib/test_fastapi.py +++ b/tests/contrib/test_fastapi.py @@ -145,6 +145,20 @@ def test_endpoint(f_client, f_interface: Interface, create_mlem_client, train): assert response.json() == [0] * 50 + [1] * 50 + [2] * 50 +def test_params_exposed_to_interface(): + model = MlemModel.from_obj( + lambda x: x, sample_data="sample", params={"a": "b"} + ) + interface = ModelInterface.from_model(model) + + app = FastAPIServer().app_init(interface) + client = TestClient(app) + + docs = client.get("/interface.json") + assert docs.status_code == 200, docs.json() + assert docs.json()["meta"] == {"a": "b"} + + @pytest.mark.parametrize( "data", [ diff --git a/tests/core/test_metadata.py b/tests/core/test_metadata.py index d2afc8c1..37bf0dfc 100644 --- a/tests/core/test_metadata.py +++ b/tests/core/test_metadata.py @@ -3,6 +3,7 @@ import shutil import sys import tempfile +from datetime import datetime from pathlib import Path from urllib.parse import quote_plus @@ -16,7 +17,7 @@ from mlem.api import init from mlem.contrib.heroku.meta import HerokuEnv -from mlem.core.errors import InvalidArgumentError +from mlem.core.errors import InvalidArgumentError, SerializationError from mlem.core.meta_io import MLEM_EXT from mlem.core.metadata import ( list_objects, @@ -42,9 +43,16 @@ @pytest.mark.parametrize("obj", [lazy_fixture("model"), lazy_fixture("train")]) def test_save_with_meta_fields(obj, tmpdir): path = str(tmpdir / "obj") - save(obj, path, params={"a": "b"}) + save(obj, path, params={"a": {"b": ["c", "d", 1]}}) new = load_meta(path) - assert new.params == {"a": "b"} + assert new.params == {"a": {"b": ["c", "d", 1]}} + + +@pytest.mark.parametrize("obj", [lazy_fixture("model")]) +def test_save_with_meta_fields_fails(obj, tmpdir): + path = str(tmpdir / "obj") + with pytest.raises(SerializationError): + save(obj, path, params={"a": datetime.now()}) def test_saving_with_project(model, tmpdir): diff --git a/tests/runtime/test_interface.py b/tests/runtime/test_interface.py index aaaf01ab..aa624782 100644 --- a/tests/runtime/test_interface.py +++ b/tests/runtime/test_interface.py @@ -70,6 +70,7 @@ def test_interface_descriptor__to_dict(interface: Interface): assert d.dict() == { "version": mlem.__version__, + "meta": None, "methods": { "method1": { "args": [