From e55ea3e0d22b8cad213c84a363482ce4fa223a33 Mon Sep 17 00:00:00 2001 From: Francisco Aranda <2518789+frascuchon@users.noreply.github.com> Date: Thu, 9 Feb 2023 17:41:07 +0100 Subject: [PATCH] chore(fix): Restore the server backward compatibility code (#2327) # Description The code included in PR #2248 breaks backward compatibility for older server versions. This PR restores the code that provides compatibility in those cases. **Type of change** (Please delete options that are not relevant. Remember to title the PR according to the type of change) - [x] Bug fix (non-breaking change which fixes an issue) **Checklist** - [x] I have merged the original branch into my forked branch - [x] I added relevant documentation - [x] follows the style guidelines of this project - [x] I did a self-review of my code - [x] I added comments to my code - [x] I made corresponding changes to the documentation - [x] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works --------- Co-authored-by: Francisco Aranda Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/argilla/client/apis/status.py | 4 +- src/argilla/client/client.py | 207 +++++++++++++----- src/argilla/client/sdk/commons/errors.py | 3 +- src/argilla/client/sdk/text2text/api.py | 46 ++++ .../client/sdk/text_classification/api.py | 22 ++ .../client/sdk/token_classification/api.py | 54 +++++ 6 files changed, 281 insertions(+), 55 deletions(-) create mode 100644 src/argilla/client/sdk/text2text/api.py create mode 100644 src/argilla/client/sdk/token_classification/api.py diff --git a/src/argilla/client/apis/status.py b/src/argilla/client/apis/status.py index 4d29c03afd..fade8cf6ad 100644 --- a/src/argilla/client/apis/status.py +++ b/src/argilla/client/apis/status.py @@ -64,7 +64,9 @@ def __enter__(self): if api_version.is_devrelease: api_version = parse(api_version.base_version) if not api_version >= self._min_version: 
- raise ApiCompatibilityError(str(self._min_version)) + raise ApiCompatibilityError( + str(self._min_version), api_version=api_version + ) pass def __exit__( diff --git a/src/argilla/client/client.py b/src/argilla/client/client.py index db1d4f89d4..5a3974ede3 100644 --- a/src/argilla/client/client.py +++ b/src/argilla/client/client.py @@ -18,7 +18,7 @@ import re import warnings from asyncio import Future -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union from tqdm.auto import tqdm @@ -49,6 +49,7 @@ from argilla.client.sdk.commons.api import async_bulk from argilla.client.sdk.commons.errors import ( AlreadyExistsApiError, + ApiCompatibilityError, InputValueError, NotFoundApiError, ) @@ -67,6 +68,7 @@ LabelingRule, LabelingRuleMetricsSummary, TextClassificationBulkData, + TextClassificationQuery, ) from argilla.client.sdk.text_classification.models import ( TextClassificationRecord as SdkTextClassificationRecord, @@ -74,6 +76,7 @@ from argilla.client.sdk.token_classification.models import ( CreationTokenClassificationRecord, TokenClassificationBulkData, + TokenClassificationQuery, ) from argilla.client.sdk.token_classification.models import ( TokenClassificationRecord as SdkTokenClassificationRecord, @@ -489,65 +492,37 @@ def load( " `rg.load('my_dataset').to_pandas()`.", ) - dataset = self.datasets.find_by_name(name=name) - task = dataset.task - - task_config = { - TaskType.text_classification: ( - SdkTextClassificationRecord, - DatasetForTextClassification, - ), - TaskType.token_classification: ( - SdkTokenClassificationRecord, - DatasetForTokenClassification, - ), - TaskType.text2text: ( - SdkText2TextRecord, - DatasetForText2Text, - ), - } - try: - sdk_record_class, dataset_class = task_config[task] - except KeyError: - raise ValueError( - f"Load method not supported for the '{task}' task. 
Supported Tasks: " - f"{[TaskType.text_classification, TaskType.token_classification, TaskType.text2text]}" + return self._load_records_new_fashion( + name=name, + query=query, + vector=vector, + ids=ids, + limit=limit, + id_from=id_from, ) + except ApiCompatibilityError as err: # Api backward compatibility + from argilla import __version__ as version - if vector: - vector_search = VectorSearch( - name=vector[0], - value=vector[1], + warnings.warn( + message=f"Using python client argilla=={version}," + f" however deployed server version is {err.api_version}." + " This might lead to compatibility issues.\n" + f" Preferably, update your server version to {version}" + " or downgrade your Python API at the loss" + " of functionality and robustness via\n" + f"`pip install argilla=={err.api_version}`", + category=UserWarning, ) - results = self.search.search_records( + + return self._load_records_old_fashion( name=name, - task=task, - size=limit or 100, - # query args - query_text=query, - vector=vector_search, + query=query, + ids=ids, + limit=limit, + id_from=id_from, ) - return dataset_class(results.records) - - records = self.datasets.scan( - name=name, - projection={"*"}, - limit=limit, - id_from=id_from, - # Query - query_text=query, - ids=ids, - ) - records = [sdk_record_class.parse_obj(r).to_client() for r in records] - try: - records_sorted_by_id = sorted(records, key=lambda x: x.id) - # record ids can be a mix of int/str -> sort all as str type - except TypeError: - records_sorted_by_id = sorted(records, key=lambda x: str(x.id)) - return dataset_class(records_sorted_by_id) - def dataset_metrics(self, name: str) -> List[MetricInfo]: response = datasets_api.get_dataset(self._client, name) response = metrics_api.get_dataset_metrics( @@ -653,3 +628,129 @@ def rule_metrics_for_dataset( ) return LabelingRuleMetricsSummary.parse_obj(response.parsed) + + def _load_records_old_fashion( + self, + name: str, + query: Optional[str] = None, + ids: Optional[List[Union[str, 
int]]] = None, + limit: Optional[int] = None, + id_from: Optional[str] = None, + ) -> Dataset: + from argilla.client.sdk.text2text import api as text2text_api + from argilla.client.sdk.text2text.models import Text2TextQuery + from argilla.client.sdk.text_classification import ( + api as text_classification_api, + ) + from argilla.client.sdk.token_classification import ( + api as token_classification_api, + ) + + response = datasets_api.get_dataset(client=self._client, name=name) + task = response.parsed.task + + task_config = { + TaskType.text_classification: ( + text_classification_api.data, + TextClassificationQuery, + DatasetForTextClassification, + ), + TaskType.token_classification: ( + token_classification_api.data, + TokenClassificationQuery, + DatasetForTokenClassification, + ), + TaskType.text2text: ( + text2text_api.data, + Text2TextQuery, + DatasetForText2Text, + ), + } + + try: + get_dataset_data, request_class, dataset_class = task_config[task] + except KeyError: + raise ValueError( + f"Load method not supported for the '{task}' task. 
Supported tasks: " + f"{[TaskType.text_classification, TaskType.token_classification, TaskType.text2text]}" + ) + response = get_dataset_data( + client=self._client, + name=name, + request=request_class(ids=ids, query_text=query), + limit=limit, + id_from=id_from, + ) + + records = [sdk_record.to_client() for sdk_record in response.parsed] + return dataset_class(self.__sort_records_by_id__(records)) + + def _load_records_new_fashion( + self, + name: str, + query: Optional[str] = None, + vector: Optional[Tuple[str, List[float]]] = None, + ids: Optional[List[Union[str, int]]] = None, + limit: Optional[int] = None, + id_from: Optional[str] = None, + ) -> Dataset: + dataset = self.datasets.find_by_name(name=name) + task = dataset.task + + task_config = { + TaskType.text_classification: ( + SdkTextClassificationRecord, + DatasetForTextClassification, + ), + TaskType.token_classification: ( + SdkTokenClassificationRecord, + DatasetForTokenClassification, + ), + TaskType.text2text: ( + SdkText2TextRecord, + DatasetForText2Text, + ), + } + + try: + sdk_record_class, dataset_class = task_config[task] + except KeyError: + raise ValueError( + f"Load method not supported for the '{task}' task. 
Supported Tasks: " + f"{[TaskType.text_classification, TaskType.token_classification, TaskType.text2text]}" + ) + + if vector: + vector_search = VectorSearch( + name=vector[0], + value=vector[1], + ) + results = self.search.search_records( + name=name, + task=task, + size=limit or 100, + # query args + query_text=query, + vector=vector_search, + ) + return dataset_class(results.records) + + records = self.datasets.scan( + name=name, + projection={"*"}, + limit=limit, + id_from=id_from, + # Query + query_text=query, + ids=ids, + ) + records = [sdk_record_class.parse_obj(r).to_client() for r in records] + return dataset_class(self.__sort_records_by_id__(records)) + + def __sort_records_by_id__(self, records: list) -> list: + try: + records_sorted_by_id = sorted(records, key=lambda x: x.id) + # record ids can be a mix of int/str -> sort all as str type + except TypeError: + records_sorted_by_id = sorted(records, key=lambda x: str(x.id)) + return records_sorted_by_id diff --git a/src/argilla/client/sdk/commons/errors.py b/src/argilla/client/sdk/commons/errors.py index bb38fa0432..355f5fd754 100644 --- a/src/argilla/client/sdk/commons/errors.py +++ b/src/argilla/client/sdk/commons/errors.py @@ -38,8 +38,9 @@ class InputValueError(BaseClientError): class ApiCompatibilityError(BaseClientError): - def __init__(self, min_version: str): + def __init__(self, min_version: str, api_version: str): self.min_version = min_version + self.api_version = api_version def __str__(self): return ( diff --git a/src/argilla/client/sdk/text2text/api.py b/src/argilla/client/sdk/text2text/api.py new file mode 100644 index 0000000000..2baab48725 --- /dev/null +++ b/src/argilla/client/sdk/text2text/api.py @@ -0,0 +1,46 @@ +# coding=utf-8 +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Union + +import httpx + +from argilla.client.sdk.client import AuthenticatedClient +from argilla.client.sdk.commons.api import build_data_response, build_param_dict +from argilla.client.sdk.commons.models import ( + ErrorMessage, + HTTPValidationError, + Response, +) +from argilla.client.sdk.text2text.models import Text2TextQuery, Text2TextRecord + + +def data( + client: AuthenticatedClient, + name: str, + request: Optional[Text2TextQuery] = None, + limit: Optional[int] = None, + id_from: Optional[str] = None, +) -> Response[Union[List[Text2TextRecord], HTTPValidationError, ErrorMessage]]: + + path = f"/api/datasets/{name}/Text2Text/data" + params = build_param_dict(id_from, limit) + + with client.stream( + method="POST", + path=path, + params=params if params else None, + json=request.dict() if request else {}, + ) as response: + return build_data_response(response=response, data_type=Text2TextRecord) diff --git a/src/argilla/client/sdk/text_classification/api.py b/src/argilla/client/sdk/text_classification/api.py index 5699566eba..024997cf69 100644 --- a/src/argilla/client/sdk/text_classification/api.py +++ b/src/argilla/client/sdk/text_classification/api.py @@ -36,6 +36,28 @@ ) +def data( + client: AuthenticatedClient, + name: str, + request: Optional[TextClassificationQuery] = None, + limit: Optional[int] = None, + id_from: Optional[str] = None, +) -> Response[Union[List[TextClassificationRecord], HTTPValidationError, ErrorMessage]]: + + path = f"/api/datasets/{name}/TextClassification/data" + params = 
build_param_dict(id_from, limit) + + with client.stream( + method="POST", + path=path, + params=params if params else None, + json=request.dict() if request else {}, + ) as response: + return build_data_response( + response=response, data_type=TextClassificationRecord + ) + + def add_dataset_labeling_rule( client: AuthenticatedClient, name: str, diff --git a/src/argilla/client/sdk/token_classification/api.py b/src/argilla/client/sdk/token_classification/api.py new file mode 100644 index 0000000000..f5a95ba36d --- /dev/null +++ b/src/argilla/client/sdk/token_classification/api.py @@ -0,0 +1,54 @@ +# coding=utf-8 +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List, Optional, Union + +import httpx + +from argilla.client.sdk.client import AuthenticatedClient +from argilla.client.sdk.commons.api import build_data_response, build_param_dict +from argilla.client.sdk.commons.models import ( + ErrorMessage, + HTTPValidationError, + Response, +) +from argilla.client.sdk.token_classification.models import ( + TokenClassificationQuery, + TokenClassificationRecord, +) + + +def data( + client: AuthenticatedClient, + name: str, + request: Optional[TokenClassificationQuery] = None, + limit: Optional[int] = None, + id_from: Optional[str] = None, +) -> Response[ + Union[List[TokenClassificationRecord], HTTPValidationError, ErrorMessage] +]: + + path = f"/api/datasets/{name}/TokenClassification/data" + params = build_param_dict(id_from, limit) + + with client.stream( + path=path, + method="POST", + params=params if params else None, + json=request.dict() if request else {}, + ) as response: + return build_data_response( + response=response, data_type=TokenClassificationRecord + )