This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Commit c7f9161
add support for sklearn transformers (#538)
close #514

Co-authored-by: Yury <yury@iterative.ai>
Co-authored-by: Alexander Guschin <1aguschin@gmail.com>
3 people authored Jan 19, 2023
1 parent 892a039 commit c7f9161
Showing 9 changed files with 391 additions and 26 deletions.
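
For context, a minimal sketch of the workflow this commit enables: saving a fitted sklearn transformer and applying its `transform` method through MLEM. The `TfidfVectorizer`, the toy corpus, and the exact keyword arguments are illustrative assumptions, not part of the commit:

    from sklearn.feature_extraction.text import TfidfVectorizer

    from mlem.api import apply, save

    corpus = ["mlem saves models", "and now transformers too"]
    vectorizer = TfidfVectorizer().fit(corpus)

    # The new SklearnTransformer hook matches the object and records a
    # `transform` signature inferred from the sample data.
    save(vectorizer, "tfidf", sample_data=corpus)

    # transform() returns a scipy csr_matrix, which the ScipySparseMatrix
    # data type added below can describe and serialize.
    result = apply("tfidf", corpus, method="transform")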
13 changes: 8 additions & 5 deletions mlem/api/commands.py
@@ -32,6 +32,7 @@
     MlemLink,
     MlemModel,
     MlemObject,
+    _ModelMethodCall,
 )
 from mlem.runtime.client import Client
 from mlem.runtime.interface import ModelInterface
@@ -85,18 +86,20 @@ def apply(
     except WrongMethodError:
         resolved_method = PREDICT_METHOD_NAME
     echo(EMOJI_APPLY + f"Applying `{resolved_method}` method...")
+    method_call = _ModelMethodCall(
+        name=resolved_method,
+        order=model.call_orders[resolved_method],
+        model=model,
+    )
     if batch_size:
         res: Any = []
         for part in data:
             batch_data = get_data_value(part, batch_size)
             for batch in batch_data:
-                preds = w.call_method(resolved_method, batch.data)
+                preds = method_call(batch.data)
                 res += [*preds]  # TODO: merge results
     else:
-        res = [
-            w.call_method(resolved_method, get_data_value(part))
-            for part in data
-        ]
+        res = [method_call(get_data_value(part)) for part in data]
     if output is None:
         if len(res) == 1:
             return res[0]
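
Reading the diff alone, `apply` now builds a single `_ModelMethodCall` from the resolved method name and the model's `call_orders` entry and reuses it on both the batched and non-batched paths, rather than dispatching through `w.call_method` for every part; method resolution happens once, in one place.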
1 change: 1 addition & 0 deletions mlem/constants.py
@@ -4,5 +4,6 @@
 PREDICT_METHOD_NAME = "predict"
 PREDICT_PROBA_METHOD_NAME = "predict_proba"
 PREDICT_ARG_NAME = "data"
+TRANSFORM_METHOD_NAME = "transform"
 
 MLEM_CONFIG_FILE_NAME = ".mlem.yaml"
37 changes: 19 additions & 18 deletions mlem/contrib/numpy.py
@@ -44,6 +44,23 @@ def np_type_from_string(string_repr) -> np.dtype:
         raise ValueError(f"Unknown numpy type {string_repr}") from e
 
 
+def check_shape(shape, array, exc_type):
+    if shape is not None:
+        if len(array.shape) != len(shape):
+            raise exc_type(
+                f"given array is of rank: {len(array.shape)}, expected: {len(shape)}"
+            )
+
+        array_shape = tuple(
+            None if expected_dim is None else array_dim
+            for array_dim, expected_dim in zip(array.shape, shape)
+        )
+        if tuple(array_shape) != shape:
+            raise exc_type(
+                f"given array is of shape: {array_shape}, expected: {shape}"
+            )
+
+
 class NumpyNumberType(
     WithDefaultSerializer, LibRequirementsMixin, DataType, DataHook
 ):
@@ -123,22 +140,6 @@ def subtype(self, subshape: Tuple[Optional[int], ...]):
             max_items=subshape[0],
         )
 
-    def check_shape(self, array, exc_type):
-        if self.shape is not None:
-            if len(array.shape) != len(self.shape):
-                raise exc_type(
-                    f"given array is of rank: {len(array.shape)}, expected: {len(self.shape)}"
-                )
-
-            array_shape = tuple(
-                None if expected_dim is None else array_dim
-                for array_dim, expected_dim in zip(array.shape, self.shape)
-            )
-            if tuple(array_shape) != self.shape:
-                raise exc_type(
-                    f"given array is of shape: {array_shape}, expected: {self.shape}"
-                )
-
     def get_writer(self, project: str = None, filename: str = None, **kwargs):
         return NumpyArrayWriter()
 
@@ -171,7 +172,7 @@ def deserialize(self, data_type, obj):
                 f"given object: {obj} could not be converted to array "
                 f"of type: {np_type_from_string(data_type.dtype)}"
             ) from e
-        data_type.check_shape(ret, DeserializationError)
+        check_shape(data_type.shape, ret, DeserializationError)
         return ret
 
@@ -181,7 +182,7 @@ def serialize(self, data_type, instance: np.ndarray):
             raise SerializationError(
                 f"given array is of type: {instance.dtype}, expected: {exp_type}"
             )
-        data_type.check_shape(instance, SerializationError)
+        check_shape(data_type.shape, instance, SerializationError)
         return instance.tolist()
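
Hoisting `check_shape` from an ndarray-type method to module level lets the new scipy data type reuse it (see the import in `mlem/contrib/scipy.py` below). An illustrative sketch of its contract, with `None` acting as a wildcard dimension:

    import numpy as np

    from mlem.contrib.numpy import check_shape

    arr = np.zeros((3, 4))
    check_shape((None, 4), arr, ValueError)  # passes: first dim is a wildcard
    check_shape((None, 5), arr, ValueError)  # raises ValueError:
    #   "given array is of shape: (None, 4), expected: (None, 5)"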
150 changes: 150 additions & 0 deletions mlem/contrib/scipy.py
@@ -0,0 +1,150 @@
+"""Scipy Sparse matrices support
+Extension type: data
+DataType, Reader and Writer implementations for `scipy.sparse`
+"""
+from typing import ClassVar, Iterator, List, Optional, Tuple, Type, Union
+
+import scipy
+from pydantic import BaseModel
+from pydantic.main import create_model
+from pydantic.types import conlist
+from scipy import sparse
+from scipy.sparse import spmatrix
+
+from mlem.contrib.numpy import (
+    check_shape,
+    np_type_from_string,
+    python_type_from_np_string_repr,
+)
+from mlem.core.artifacts import Artifacts, Storage
+from mlem.core.data_type import (
+    DataHook,
+    DataReader,
+    DataSerializer,
+    DataType,
+    DataWriter,
+    WithDefaultSerializer,
+)
+from mlem.core.errors import DeserializationError, SerializationError
+from mlem.core.hooks import IsInstanceHookMixin
+from mlem.core.requirements import InstallableRequirement, Requirements
+
+
+class ScipySparseMatrix(
+    WithDefaultSerializer, DataType, DataHook, IsInstanceHookMixin
+):
+    """
+    DataType implementation for scipy sparse matrix
+    """
+
+    type: ClassVar[str] = "csr_matrix"
+    valid_types: ClassVar = (spmatrix,)
+    shape: Optional[Tuple]
+    """Shape of `sparse.csr_matrix` object in data"""
+    dtype: str
+    """Dtype of `sparse.csr_matrix` object in data"""
+
+    def get_requirements(self) -> Requirements:
+        return Requirements.new([InstallableRequirement.from_module(scipy)])
+
+    @classmethod
+    def process(cls, obj: sparse.csr_matrix, **kwargs) -> DataType:
+        return ScipySparseMatrix(dtype=obj.dtype.name, shape=obj.shape)
+
+    def get_writer(
+        self, project: str = None, filename: str = None, **kwargs
+    ) -> DataWriter:
+        return ScipyWriter(**kwargs)
+
+    def subtype(self, subshape: Tuple[Optional[int], ...]):
+        if len(subshape) == 0:
+            return python_type_from_np_string_repr(self.dtype)
+        return conlist(
+            self.subtype(subshape[1:]),
+            min_items=subshape[0],
+            max_items=subshape[0],
+        )
+
+
+class ScipyWriter(DataWriter[ScipySparseMatrix]):
+    """
+    Write scipy matrix to npz format
+    """
+
+    type: ClassVar[str] = "csr_matrix"
+
+    def write(
+        self, data: DataType, storage: Storage, path: str
+    ) -> Tuple[DataReader, Artifacts]:
+        with storage.open(path) as (f, art):
+            sparse.save_npz(f, data.data)
+        return ScipyReader(data_type=data), {self.art_name: art}
+
+
+class ScipyReader(DataReader):
+    """
+    Read scipy matrix from npz format
+    """
+
+    type: ClassVar[str] = "csr_matrix"
+
+    def read_batch(
+        self, artifacts: Artifacts, batch_size: int
+    ) -> Iterator[DataType]:
+        raise NotImplementedError
+
+    def read(self, artifacts: Artifacts) -> Iterator[DataType]:
+        if DataWriter.art_name not in artifacts:
+            raise ValueError(
+                f"Wrong artifacts {artifacts}: should be one {DataWriter.art_name} file"
+            )
+        with artifacts[DataWriter.art_name].open() as f:
+            data = sparse.load_npz(f)
+        return self.data_type.copy().bind(data)
+
+
+class ScipySparseMatrixSerializer(DataSerializer[ScipySparseMatrix]):
+    """
+    Serializer for scipy sparse matrices
+    """
+
+    is_default: ClassVar = True
+    data_class: ClassVar = ScipySparseMatrix
+
+    def get_model(
+        self, data_type: ScipySparseMatrix, prefix: str = ""
+    ) -> Union[Type[BaseModel], type]:
+        item_type = List[data_type.subtype(data_type.shape[1:])]  # type: ignore[index]
+        return create_model(
+            prefix + "ScipySparse",
+            __root__=(item_type, ...),
+        )
+
+    def serialize(self, data_type: ScipySparseMatrix, instance: spmatrix):
+        data_type.check_type(instance, sparse.csr_matrix, SerializationError)
+        if instance.dtype != np_type_from_string(data_type.dtype):
+            raise SerializationError(
+                f"given matrix is of dtype: {instance.dtype}, "
+                f"expected: {data_type.dtype}"
+            )
+        check_shape(data_type.shape, instance, SerializationError)
+        coordinate_matrix = instance.tocoo()
+        data = coordinate_matrix.data
+        row = coordinate_matrix.row
+        col = coordinate_matrix.col
+        return data, (row, col)
+
+    def deserialize(self, data_type, obj) -> sparse.csr_matrix:
+        try:
+            mat = sparse.csr_matrix(
+                obj, dtype=data_type.dtype, shape=data_type.shape
+            )
+        except ValueError as e:
+            raise DeserializationError(
+                f"Given object {obj} could not be converted "
+                f"to sparse matrix of type: {data_type.type}"
+            ) from e
+        check_shape(data_type.shape, mat, DeserializationError)
+        return mat
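
A short sketch of the new data type in action; the path name and equality idiom are illustrative, not from the commit:

    import numpy as np
    from scipy import sparse

    from mlem.api import load, save

    mat = sparse.csr_matrix(np.eye(3, dtype="float32"))

    save(mat, "eye-matrix")      # ScipySparseMatrix captures dtype="float32", shape=(3, 3)
    loaded = load("eye-matrix")  # ScipyReader restores the matrix via sparse.load_npz
    assert (loaded != mat).nnz == 0  # no entries differ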
36 changes: 36 additions & 0 deletions mlem/contrib/sklearn.py
@@ -7,8 +7,11 @@
 
 import sklearn
 from sklearn.base import ClassifierMixin, RegressorMixin
+from sklearn.feature_extraction.text import TransformerMixin
 from sklearn.pipeline import Pipeline
+from sklearn.preprocessing._encoders import _BaseEncoder
 
+from mlem.constants import TRANSFORM_METHOD_NAME
 from mlem.core.hooks import IsInstanceHookMixin
 from mlem.core.model import (
     ModelHook,
@@ -132,3 +135,36 @@ def process(
                 **predict_proba_args
             )
         return mt
+
+
+class SklearnTransformer(SklearnModel):
+    """
+    Model Type implementation for sklearn transformers
+    """
+
+    valid_types: ClassVar = (TransformerMixin, _BaseEncoder)
+    type: ClassVar = "sklearn_transformer"
+
+    @classmethod
+    def process(
+        cls,
+        obj: Any,
+        sample_data: Optional[Any] = None,
+        methods_sample_data: Optional[Dict[str, Any]] = None,
+        **kwargs
+    ) -> ModelType:
+        methods_sample_data = methods_sample_data or {}
+        sample_data = methods_sample_data.get(
+            TRANSFORM_METHOD_NAME, sample_data
+        )
+        methods = {
+            TRANSFORM_METHOD_NAME: Signature.from_method(
+                obj.transform,
+                auto_infer=sample_data is not None,
+                X=sample_data,
+            ),
+        }
+
+        return SklearnTransformer(io=SimplePickleIO(), methods=methods).bind(
+            obj
+        )
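
`valid_types` covers both generic `TransformerMixin` transformers and sklearn's private `_BaseEncoder` hierarchy (e.g. `OneHotEncoder`). A hedged sketch of how a fitted encoder would be picked up; the `MlemModel.from_obj` call and the sample data are illustrative assumptions:

    import numpy as np
    from sklearn.preprocessing import OneHotEncoder

    from mlem.core.objects import MlemModel

    enc = OneHotEncoder().fit(np.array([["a"], ["b"]]))

    # Analysis dispatches to SklearnTransformer, so the resulting MlemModel
    # exposes a single `transform` method inferred from the sample.
    meta = MlemModel.from_obj(enc, sample_data=np.array([["a"]]))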
1 change: 1 addition & 0 deletions mlem/ext.py
@@ -129,6 +129,7 @@ class ExtensionLoader:
             False,
         ),
         Extension("mlem.contrib.git", ["pygit2"], True),
+        Extension("mlem.contrib.scipy", ["scipy"], False),
         Extension(
             "mlem.contrib.flyio", ["docker", "fastapi", "uvicorn"], False
         ),
6 changes: 6 additions & 0 deletions setup.py
@@ -64,6 +64,7 @@
         "pandas": ["pandas"],
         "numpy": ["numpy"],
         "sklearn": ["scikit-learn"],
+        "scipy": ["scipy"],
         "onnx": ["onnx"],
         "onnxruntime": [
             "protobuf==3.20.1",
@@ -246,6 +247,11 @@
             "serializer.xgboost_dmatrix = mlem.contrib.xgboost:DMatrixSerializer",
             "model_type.xgboost = mlem.contrib.xgboost:XGBoostModel",
             "model_io.xgboost_io = mlem.contrib.xgboost:XGBoostModelIO",
+            "data_type.csr_matrix = mlem.contrib.scipy:ScipySparseMatrix",
+            "data_writer.csr_matrix = mlem.contrib.scipy:ScipyWriter",
+            "data_reader.csr_matrix = mlem.contrib.scipy:ScipyReader",
+            "model_type.sklearn_transformer = mlem.contrib.sklearn:SklearnTransformer",
+            "serializer.csr_matrix = mlem.contrib.scipy:ScipySparseMatrixSerializer",
         ],
         "mlem.config": [
             "core = mlem.config:MlemConfig",
(2 of the 9 changed files are not shown above.)