Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Commit

Permalink
Expose params passed to mlem.api.save to fastapi's `interface.jso…
Browse files Browse the repository at this point in the history
…n` (#670)

close #664
  • Loading branch information
aguschin authored May 11, 2023
1 parent 6ca9d0e commit 5936765
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 10 deletions.
4 changes: 2 additions & 2 deletions mlem/core/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
def get_object_metadata(
obj: Any,
sample_data=None,
params: Dict[str, str] = None,
params: Dict[str, Any] = None,
preprocess: Union[Any, Dict[str, Any]] = None,
postprocess: Union[Any, Dict[str, Any]] = None,
) -> Union[MlemData, MlemModel]:
Expand Down Expand Up @@ -97,7 +97,7 @@ def save(
project: Optional[str] = None,
sample_data=None,
fs: Optional[AbstractFileSystem] = None,
params: Dict[str, str] = None,
params: Dict[str, Any] = None,
preprocess: Union[Any, Dict[str, Any]] = None,
postprocess: Union[Any, Dict[str, Any]] = None,
) -> MlemObject:
Expand Down
18 changes: 15 additions & 3 deletions mlem/core/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import dataclasses
import hashlib
import itertools
import json
import os
import posixpath
import time
Expand Down Expand Up @@ -50,6 +51,7 @@
MlemError,
MlemObjectNotFound,
MlemObjectNotSavedError,
SerializationError,
WrongABCType,
WrongMetaSubType,
WrongMetaType,
Expand Down Expand Up @@ -91,9 +93,19 @@ class Config:
object_type: ClassVar[str]
location: Optional[Location] = None
"""MlemObject location [transient]"""
params: Dict[str, str] = {}
params: Dict[str, Any] = {}
"""Arbitrary map of additional parameters"""

@validator("params")
def params_are_serializable( # pylint: disable=no-self-argument
cls, value # noqa: B902
):
try:
json.dumps(value)
except TypeError as e:
raise SerializationError(f"Can't serialize object: {value}") from e
return value

@property
def loc(self) -> Location:
if self.location is None:
Expand Down Expand Up @@ -751,7 +763,7 @@ def from_obj(
model: Any,
sample_data: Any = None,
methods_sample_data: Dict[str, Any] = None,
params: Dict[str, str] = None,
params: Dict[str, Any] = None,
preprocess: Union[Any, Dict[str, Any]] = None,
postprocess: Union[Any, Dict[str, Any]] = None,
) -> "MlemModel":
Expand Down Expand Up @@ -931,7 +943,7 @@ def data(self):
def from_data(
cls,
data: Any,
params: Dict[str, str] = None,
params: Dict[str, Any] = None,
) -> "MlemData":
data_type = DataType.create(
data,
Expand Down
13 changes: 11 additions & 2 deletions mlem/runtime/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class VersionedInterfaceDescriptor(BaseModel):
methods: InterfaceDescriptor
version: str = mlem.version.__version__
"""mlem version"""
meta: Any


class Interface(ABC, MlemABC):
Expand Down Expand Up @@ -201,9 +202,14 @@ def get_descriptor(self) -> InterfaceDescriptor:
}
)

def get_model_meta(self):
return None

def get_versioned_descriptor(self) -> VersionedInterfaceDescriptor:
return VersionedInterfaceDescriptor(
version=mlem.__version__, methods=self.get_descriptor()
version=mlem.__version__,
methods=self.get_descriptor(),
meta=self.get_model_meta(),
)


Expand Down Expand Up @@ -267,7 +273,7 @@ def _check_no_signature(data):


class ModelInterface(Interface):
"""Interface that descibes model methods"""
"""Interface that describes model methods"""

type: ClassVar[str] = "model"
model: MlemModel
Expand Down Expand Up @@ -352,3 +358,6 @@ def get_method_args(
a.name: a.type_
for a in self.model.model_type.methods[method_name].args
}

def get_model_meta(self):
return self.model.params
3 changes: 3 additions & 0 deletions mlem/runtime/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,3 +335,6 @@ def get_method_signature(self, method_name: str) -> InterfaceMethod:
],
returns=self._get_response(method_name, signature.returns),
)

def get_model_meta(self):
return getattr(getattr(self.interface, "model", None), "params", None)
14 changes: 14 additions & 0 deletions tests/contrib/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,20 @@ def test_endpoint(f_client, f_interface: Interface, create_mlem_client, train):
assert response.json() == [0] * 50 + [1] * 50 + [2] * 50


def test_params_exposed_to_interface():
model = MlemModel.from_obj(
lambda x: x, sample_data="sample", params={"a": "b"}
)
interface = ModelInterface.from_model(model)

app = FastAPIServer().app_init(interface)
client = TestClient(app)

docs = client.get("/interface.json")
assert docs.status_code == 200, docs.json()
assert docs.json()["meta"] == {"a": "b"}


@pytest.mark.parametrize(
"data",
[
Expand Down
14 changes: 11 additions & 3 deletions tests/core/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import shutil
import sys
import tempfile
from datetime import datetime
from pathlib import Path
from urllib.parse import quote_plus

Expand All @@ -16,7 +17,7 @@

from mlem.api import init
from mlem.contrib.heroku.meta import HerokuEnv
from mlem.core.errors import InvalidArgumentError
from mlem.core.errors import InvalidArgumentError, SerializationError
from mlem.core.meta_io import MLEM_EXT
from mlem.core.metadata import (
list_objects,
Expand All @@ -42,9 +43,16 @@
@pytest.mark.parametrize("obj", [lazy_fixture("model"), lazy_fixture("train")])
def test_save_with_meta_fields(obj, tmpdir):
path = str(tmpdir / "obj")
save(obj, path, params={"a": "b"})
save(obj, path, params={"a": {"b": ["c", "d", 1]}})
new = load_meta(path)
assert new.params == {"a": "b"}
assert new.params == {"a": {"b": ["c", "d", 1]}}


@pytest.mark.parametrize("obj", [lazy_fixture("model")])
def test_save_with_meta_fields_fails(obj, tmpdir):
path = str(tmpdir / "obj")
with pytest.raises(SerializationError):
save(obj, path, params={"a": datetime.now()})


def test_saving_with_project(model, tmpdir):
Expand Down
1 change: 1 addition & 0 deletions tests/runtime/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def test_interface_descriptor__to_dict(interface: Interface):

assert d.dict() == {
"version": mlem.__version__,
"meta": None,
"methods": {
"method1": {
"args": [
Expand Down

0 comments on commit 5936765

Please sign in to comment.