13 changes: 12 additions & 1 deletion mlem/contrib/numpy.py
@@ -107,7 +107,10 @@ def _abstract_shape(shape):
     @classmethod
     def process(cls, obj, **kwargs) -> DataType:
         return NumpyNdarrayType(
-            shape=cls._abstract_shape(obj.shape), dtype=obj.dtype.name
+            shape=cls._abstract_shape(obj.shape)
+            if not kwargs.get("is_dynamic")
+            else tuple(None for _ in obj.shape),
+            dtype=obj.dtype.name,
         )

     @classmethod
@@ -258,3 +261,11 @@ def read_batch(
         self, artifacts: Artifacts, batch_size: int
     ) -> Iterator[DataType]:
         raise NotImplementedError
+
+
+def apply_shape_pattern(
+    abs_shape: Tuple[Optional[int], ...], shape: Tuple[int, ...]
+):
+    return tuple(
+        s if s is not None else shape[i] for i, s in enumerate(abs_shape)
+    )
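
For context: apply_shape_pattern resolves a shape pattern containing None wildcards against a concrete shape, so callers can compare exact tuples. A quick illustration of the helper added above:

    from mlem.contrib.numpy import apply_shape_pattern

    # Wildcard (None) dimensions take the concrete value; fixed ones are kept.
    assert apply_shape_pattern((None, None), (32, 10)) == (32, 10)
    assert apply_shape_pattern((None, 10), (32, 10)) == (32, 10)
    # A disagreeing fixed dimension survives, so an equality check fails:
    assert apply_shape_pattern((None, 10), (32, 5)) == (32, 10)  # != (32, 5)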
13 changes: 10 additions & 3 deletions mlem/contrib/tensorflow.py
@@ -14,7 +14,10 @@
 from pydantic import conlist, create_model
 from tensorflow.python.keras.saving.saved_model_experimental import sequential

-from mlem.contrib.numpy import python_type_from_np_string_repr
+from mlem.contrib.numpy import (
+    apply_shape_pattern,
+    python_type_from_np_string_repr,
+)
 from mlem.core.artifacts import Artifacts, Storage
 from mlem.core.data_type import (
     DataHook,
@@ -60,7 +63,9 @@ def tf_type(self):
         return getattr(tf, self.dtype)

     def check_shape(self, tensor, exc_type):
-        if tuple(tensor.shape)[1:] != self.shape[1:]:
+        if tuple(tensor.shape) != apply_shape_pattern(
+            self.shape, tensor.shape
+        ):
             raise exc_type(
                 f"given tensor is of shape: {(None,) + tuple(tensor.shape)[1:]}, expected: {self.shape}"
             )
@@ -85,7 +90,9 @@ def subtype(self, subshape: Tuple[Optional[int], ...]):
     @classmethod
     def process(cls, obj: tf.Tensor, **kwargs) -> DataType:
         return TFTensorDataType(
-            shape=(None,) + tuple(obj.shape)[1:],
+            shape=(None,) + tuple(obj.shape)[1:]
+            if not kwargs.get("is_dynamic")
+            else tuple(None for _ in obj.shape),
             dtype=obj.dtype.name,
         )

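
With is_dynamic, every dimension of the recorded shape becomes a wildcard, so check_shape now effectively validates rank instead of pinning every dimension but the first. A sketch of the intended behavior, assuming SerializationError from mlem.core.errors is the exception type passed in:

    import tensorflow as tf

    from mlem.core.data_type import DataAnalyzer
    from mlem.core.errors import SerializationError

    # With is_dynamic=True the inferred shape is (None, None): all wildcards.
    dt = DataAnalyzer.analyze(tf.zeros((2, 3)), is_dynamic=True)

    dt.check_shape(tf.zeros((5, 7)), SerializationError)  # passes: rank matches
    dt.check_shape(tf.zeros((5, 7, 1)), SerializationError)  # raises: extra dim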
13 changes: 10 additions & 3 deletions mlem/contrib/torch.py
@@ -13,7 +13,10 @@
 from pydantic import conlist, create_model

 from mlem.config import MlemConfigBase
-from mlem.contrib.numpy import python_type_from_np_string_repr
+from mlem.contrib.numpy import (
+    apply_shape_pattern,
+    python_type_from_np_string_repr,
+)
 from mlem.core.artifacts import Artifacts, FSSpecArtifact, Storage
 from mlem.core.data_type import (
     DataHook,
@@ -66,7 +69,9 @@ class TorchTensorDataType(
"""Type name of `torch.Tensor` elements"""

def check_shape(self, tensor, exc_type):
if tuple(tensor.shape)[1:] != self.shape[1:]:
if tuple(tensor.shape) != apply_shape_pattern(
self.shape, tensor.shape
):
raise exc_type(
f"given tensor is of shape: {(None,) + tuple(tensor.shape)[1:]}, expected: {self.shape}"
)
@@ -91,7 +96,9 @@ def subtype(self, subshape: Tuple[Optional[int], ...]):
     @classmethod
     def process(cls, obj: torch.Tensor, **kwargs) -> DataType:
         return TorchTensorDataType(
-            shape=(None,) + obj.shape[1:],
+            shape=(None,) + obj.shape[1:]
+            if not kwargs.get("is_dynamic")
+            else tuple(None for _ in obj.shape),
             dtype=str(obj.dtype)[len("torch") + 1 :],
         )

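
The torch changes mirror the TensorFlow ones. Incidentally, the dtype expression works by stripping the "torch." prefix from the dtype's string form:

    import torch

    t = torch.zeros(2, 3)
    assert str(t.dtype) == "torch.float32"
    assert str(t.dtype)[len("torch") + 1 :] == "float32"  # stored as dtype name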
232 changes: 232 additions & 0 deletions mlem/contrib/transformers.py
@@ -0,0 +1,232 @@
import os
import tempfile
from enum import Enum
from importlib import import_module
from typing import Any, ClassVar, Dict, Iterator, Optional, Tuple

from transformers import (
    AutoModel,
    AutoTokenizer,
    BatchEncoding,
    PreTrainedTokenizer,
    TensorType,
)
from transformers.modeling_utils import PreTrainedModel

from mlem.core.artifacts import Artifacts, Storage
from mlem.core.data_type import (
    DataAnalyzer,
    DataHook,
    DataType,
    DataWriter,
    DictReader,
    DictSerializer,
    DictType,
    DictWriter,
)
from mlem.core.hooks import IsInstanceHookMixin
from mlem.core.model import BufferModelIO, ModelHook, ModelType, Signature
from mlem.core.requirements import InstallableRequirement, Requirements


class ObjectType(str, Enum):
    MODEL = "model"
    TOKENIZER = "tokenizer"


_loaders = {ObjectType.MODEL: AutoModel, ObjectType.TOKENIZER: AutoTokenizer}

_bases = {
    PreTrainedModel: ObjectType.MODEL,
    PreTrainedTokenizer: ObjectType.TOKENIZER,
}


def get_object_type(obj) -> ObjectType:
    for base, obj_type in _bases.items():
        if isinstance(obj, base):
            return obj_type
    raise ValueError(f"Cannot determine object type for {obj}")


class TransformersIO(BufferModelIO):
    type: ClassVar = "transformers"

    class Config:
        use_enum_values = True

    obj_type: ObjectType

    def save_model(self, model: PreTrainedModel, path: str):
        model.save_pretrained(path)

    @property
    def load_class(self):
        return _loaders[self.obj_type]

    def load(self, artifacts: Artifacts):
        with tempfile.TemporaryDirectory() as tmpdir:
            for name, art in artifacts.items():
                art.materialize(os.path.join(tmpdir, name))
            return self.load_class.from_pretrained(tmpdir)


class TokenizerModelType(ModelType, ModelHook, IsInstanceHookMixin):
    type: ClassVar = "transformers"
    valid_types: ClassVar = (PreTrainedModel, PreTrainedTokenizer)

    class Config:
        use_enum_values = True

    return_tensors: Optional[TensorType] = None
    io: TransformersIO

    @classmethod
    def process(
        cls,
        obj: Any,
        sample_data: Optional[Any] = None,
        methods_sample_data: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> ModelType:
        call_kwargs = {}
        return_tensors = kwargs.get("return_tensors")
        if return_tensors:
            call_kwargs["return_tensors"] = return_tensors
        sample_data = (methods_sample_data or {}).get("__call__", sample_data)
        signature = Signature.from_method(
            obj.__call__,
            sample_data,
            auto_infer=sample_data is not None,
            **call_kwargs,
        )
        [a for a in signature.args if a.name == "return_tensors"][
            0
        ].default = return_tensors
        return TokenizerModelType(
            methods={"__call__": signature},
            io=TransformersIO(obj_type=get_object_type(obj)),
        )

    def get_requirements(self) -> Requirements:
        reqs = super().get_requirements()
        if self.io.obj_type == ObjectType.TOKENIZER:
            try:
                reqs += InstallableRequirement.from_module(
                    import_module("sentencepiece")
                )
                reqs += InstallableRequirement.from_module(
                    import_module("google.protobuf"), package_name="protobuf"
                )
            except ImportError:
                pass
        return reqs
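
Together, the IO and model-type hooks above let a tokenizer be saved and reloaded like any other mlem model; return_tensors travels through **kwargs into process and becomes the default of the inferred __call__ signature. A hedged usage sketch (model name and paths are illustrative):

    from transformers import AutoTokenizer

    from mlem.api import load, save

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # sample_data lets mlem auto-infer the __call__ signature; return_tensors
    # is forwarded via **kwargs down to TokenizerModelType.process.
    save(tokenizer, "my_tokenizer", sample_data="hello world", return_tensors="np")

    loaded = load("my_tokenizer")
    batch = loaded("hello world")  # a BatchEncoding holding numpy arrays

The new module continues with data-level support for BatchEncoding: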


ADDITIONAL_DEPS = {
    TensorType.NUMPY: "numpy",
    TensorType.PYTORCH: "torch",
    TensorType.TENSORFLOW: "tensorflow",
}


class BatchEncodingType(DictType, DataHook, IsInstanceHookMixin):
    class Config:
        use_enum_values = True

    type: ClassVar = "batch_encoding"
    valid_types: ClassVar = BatchEncoding
    return_tensors: Optional[TensorType] = None

    @staticmethod
    def get_tensors_type(obj: BatchEncoding) -> Optional[TensorType]:
        types = {type(v) for v in obj.values()}
        if len(types) > 1:
            raise ValueError(f"Mixed tensor types in {obj}")
        type_ = next(iter(types))
        if type_.__module__ == "torch":
            return TensorType.PYTORCH
        if type_.__module__.startswith("tensorflow"):
            return TensorType.TENSORFLOW
        if type_.__module__.startswith("numpy"):
            return TensorType.NUMPY
        if type_ is list:
            return None
        raise ValueError(f"Unknown tensor type {type_}")

    @property
    def return_tensors_enum(self) -> Optional[TensorType]:
        if self.return_tensors is not None and not isinstance(
            self.return_tensors, TensorType
        ):
            return TensorType(self.return_tensors)
        return self.return_tensors

    @classmethod
    def process(cls, obj: BatchEncoding, **kwargs) -> DataType:
        return BatchEncodingType(
            return_tensors=cls.get_tensors_type(obj),
            item_types={
                k: DataAnalyzer.analyze(v, is_dynamic=True, **kwargs)
                for (k, v) in obj.items()
            },
        )

    def get_requirements(self) -> Requirements:
        new = Requirements.new("transformers")
        if self.return_tensors_enum in ADDITIONAL_DEPS:
            new += Requirements.new(ADDITIONAL_DEPS[self.return_tensors_enum])
        return new

    def get_writer(
        self, project: str = None, filename: str = None, **kwargs
    ) -> DataWriter:
        return BatchEncodingWriter(**kwargs)


class BatchEncodingSerializer(DictSerializer):
    data_class: ClassVar = BatchEncodingType
    is_default: ClassVar = True

    @staticmethod
    def _check_type_and_keys(data_type, obj, exc_type):
        data_type.check_type(obj, (dict, BatchEncoding), exc_type)
        if set(obj.keys()) != set(data_type.item_types.keys()):
            raise exc_type(
                f"given dict has keys: {set(obj.keys())}, expected: {set(data_type.item_types.keys())}"
            )

    def deserialize(self, data_type: DictType, obj):
        assert isinstance(data_type, BatchEncodingType)
        return BatchEncoding(
            super().deserialize(data_type, obj),
            tensor_type=data_type.return_tensors_enum,
        )


class BatchEncodingReader(DictReader):
    type: ClassVar = "batch_encoding"

    def read(self, artifacts: Artifacts) -> DictType:
        res = super().read(artifacts)
        return res.bind(BatchEncoding(res.data))

    def read_batch(
        self, artifacts: Artifacts, batch_size: int
    ) -> Iterator[DictType]:
        raise NotImplementedError


class BatchEncodingWriter(DictWriter):
    type: ClassVar = "batch_encoding"

    def write(
        self, data: DataType, storage: Storage, path: str
    ) -> Tuple[DictReader, Artifacts]:
        res, art = super().write(data, storage, path)
        return (
            BatchEncodingReader(
                data_type=res.data_type, item_readers=res.item_readers
            ),
            art,
        )
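
On the data side, BatchEncodingType analyzes each field with is_dynamic=True, so batch and sequence dimensions stay wildcards instead of being frozen to the sample's sizes. A hedged sketch of a round trip (paths illustrative):

    import numpy as np
    from transformers import AutoTokenizer

    from mlem.api import load, save

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    batch = tokenizer(["hello world"], return_tensors="np")  # a BatchEncoding

    save(batch, "my_batch")  # dispatched to BatchEncodingType/BatchEncodingWriter
    restored = load("my_batch")  # read back through BatchEncodingReader
    assert isinstance(restored["input_ids"], np.ndarray)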
2 changes: 1 addition & 1 deletion mlem/core/data_type.py
@@ -816,7 +816,7 @@ class DictWriter(DataWriter):

     def write(
         self, data: DataType, storage: Storage, path: str
-    ) -> Tuple[DataReader, Artifacts]:
+    ) -> Tuple["DictReader", Artifacts]:
         if not isinstance(data, DictType):
             raise ValueError(
                 f"expected data to be of DictType, got {type(data)} instead"
9 changes: 5 additions & 4 deletions mlem/core/metadata.py
@@ -39,14 +39,12 @@ def get_object_metadata(
     params: Dict[str, str] = None,
     preprocess: Union[Any, Dict[str, Any]] = None,
     postprocess: Union[Any, Dict[str, Any]] = None,
+    **kwargs,
 ) -> Union[MlemData, MlemModel]:
     """Convert given object to appropriate MlemObject subclass"""
     if preprocess is None and postprocess is None:
         try:
-            return MlemData.from_data(
-                obj,
-                params=params,
-            )
+            return MlemData.from_data(obj, params=params, **kwargs)
         except HookNotFound:
             pass

@@ -56,6 +54,7 @@ def get_object_metadata(
         params=params,
         preprocess=preprocess,
         postprocess=postprocess,
+        **kwargs,
     )


@@ -100,6 +99,7 @@ def save(
     params: Dict[str, str] = None,
     preprocess: Union[Any, Dict[str, Any]] = None,
     postprocess: Union[Any, Dict[str, Any]] = None,
+    **kwargs,
 ) -> MlemObject:
     """Saves given object to a given path

@@ -125,6 +125,7 @@
         params=params,
         preprocess=preprocess,
         postprocess=postprocess,
+        **kwargs,
     )
     log_meta_params(meta, add_object_type=True)
     path = os.fspath(path)
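
Threading **kwargs through save and get_object_metadata is what carries flags such as is_dynamic (numpy/tensorflow/torch hooks) and return_tensors (transformers hook) down to the respective process methods. A minimal sketch, assuming mlem.api.save re-exports this function (path illustrative):

    import numpy as np

    from mlem.api import save

    # is_dynamic flows save -> get_object_metadata -> MlemData.from_data
    # -> the numpy hook's process, which records shape=(None, None).
    save(np.zeros((2, 3)), "my_array", is_dynamic=True)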
2 changes: 1 addition & 1 deletion mlem/core/model.py
@@ -131,7 +131,7 @@ def from_argspec(
f"auto_infer=True, but no value for {name} argument"
)
type_ = DataAnalyzer.analyze(
defaults.get(name, call_kwargs.get(name))
call_kwargs.get(name, defaults.get(name))
)
else:
type_ = UnspecifiedDataType()
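
The flipped lookup order means an explicitly supplied call kwarg now takes precedence over the argument's declared default during signature inference; previously a declared default shadowed the supplied value. In miniature:

    defaults = {"return_tensors": None}  # from the function's signature
    call_kwargs = {"return_tensors": "np"}  # supplied by the caller

    # old: defaults.get(name, call_kwargs.get(name)) -> None (default wins)
    # new: call_kwargs.get(name, defaults.get(name)) -> "np" (caller wins)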