diff --git a/setup.cfg b/setup.cfg index e05f599067..946c462fb0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -22,7 +22,7 @@ install_requires = requests ~= 2.25.0 # TODO We should use httpx instead to avoid a new dependency # Client pandas >=1.0.0,<2.0.0 # For data loading - + pydantic ~= 1.8.1 [options.packages.find] where = src @@ -36,7 +36,6 @@ where = src server = # Basic dependencies fastapi ~= 0.63.0 - pydantic ~= 1.7.0 uvicorn[standard] ~= 0.13.4 elasticsearch >= 7.1.0,<8.0.0 smart-open diff --git a/src/rubrix/client/models.py b/src/rubrix/client/models.py index c34150ece9..c804186ba3 100644 --- a/src/rubrix/client/models.py +++ b/src/rubrix/client/models.py @@ -52,7 +52,7 @@ class TextClassificationRecord(BaseModel): Attributes ---------- - inputs : `Dict[str, Any]` + inputs : `Dict[str, Union[str, List[str]]]` The inputs of the record prediction : `List[Tuple[str, float]]`, optional A list of tuples containing the predictions for the record. The first entry of the tuple is the predicted label, @@ -80,7 +80,7 @@ class TextClassificationRecord(BaseModel): The timestamp of the record. Default: None """ - inputs: Dict[str, Any] + inputs: Dict[str, Union[str, List[str]]] prediction: Optional[List[Tuple[str, float]]] = None annotation: Optional[Union[str, List[str]]] = None diff --git a/src/rubrix/server/tasks/text_classification/api/model.py b/src/rubrix/server/tasks/text_classification/api/model.py index 9d89ef37cf..4287614cff 100644 --- a/src/rubrix/server/tasks/text_classification/api/model.py +++ b/src/rubrix/server/tasks/text_classification/api/model.py @@ -81,7 +81,7 @@ class CreationTextClassificationRecord(BaseRecord[TextClassificationAnnotation]) Attributes: ----------- - inputs: Dict[str, Any] + inputs: Dict[str, Union[str, List[str]]] The input data text multi_label: bool @@ -93,7 +93,7 @@ class CreationTextClassificationRecord(BaseRecord[TextClassificationAnnotation]) The dictionary key must be aligned with provided record text. Optional """ - inputs: Dict[str, Any] + inputs: Dict[str, Union[str, List[str]]] multi_label: bool = False explanation: Dict[str, List[TokenAttributions]] = None diff --git a/tests/client/test_models.py b/tests/client/test_models.py index f6983118e9..404f517804 100644 --- a/tests/client/test_models.py +++ b/tests/client/test_models.py @@ -1,4 +1,5 @@ import pytest +from pydantic import ValidationError from rubrix.client.models import TextClassificationRecord from rubrix.client.models import TokenClassificationRecord @@ -36,3 +37,9 @@ def test_token_classification_record(annotation, status, expected_status): text="test text", tokens=["test", "text"], annotation=annotation, status=status ) assert record.status == expected_status + + +def test_text_classification_record_none_inputs(): + """Test validation error for None in inputs""" + with pytest.raises(ValidationError): + TextClassificationRecord(inputs={"text": None}) \ No newline at end of file diff --git a/tests/server/text_classification/test_model.py b/tests/server/text_classification/test_model.py index f280768412..3b7b90fb3e 100644 --- a/tests/server/text_classification/test_model.py +++ b/tests/server/text_classification/test_model.py @@ -9,16 +9,6 @@ ) -def test_flatten_inputs(): - data = { - "inputs": { - "mail": {"subject": "The mail subject", "body": "This is a large text body"} - } - } - record = TextClassificationRecord.parse_obj(data) - assert list(record.inputs.keys()) == ["mail.subject", "mail.body"] - - def test_flatten_metadata(): data = { "inputs": {"text": "bogh"},