Skip to content

Commit

Permalink
chore(fix): Restore the server backward compatibility code (#2327)
Browse files Browse the repository at this point in the history
# Description

The code included in PR #2248 breaks backward compatibility for older
server versions. This PR restores the codebase for providing
compatibility in those cases.

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [x] Bug fix (non-breaking change which fixes an issue)


**Checklist**

- [x] I have merged the original branch into my forked branch
- [x] I added relevant documentation
- [x] follows the style guidelines of this project
- [x] I did a self-review of my code
- [x] I added comments to my code
- [x] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works

---------

Co-authored-by: Francisco Aranda <francisco@recogn.ai>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 9, 2023
1 parent 106060f commit e55ea3e
Show file tree
Hide file tree
Showing 6 changed files with 281 additions and 55 deletions.
4 changes: 3 additions & 1 deletion src/argilla/client/apis/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def __enter__(self):
if api_version.is_devrelease:
api_version = parse(api_version.base_version)
if not api_version >= self._min_version:
raise ApiCompatibilityError(str(self._min_version))
raise ApiCompatibilityError(
str(self._min_version), api_version=api_version
)
pass

def __exit__(
Expand Down
207 changes: 154 additions & 53 deletions src/argilla/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import re
import warnings
from asyncio import Future
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union

from tqdm.auto import tqdm

Expand Down Expand Up @@ -49,6 +49,7 @@
from argilla.client.sdk.commons.api import async_bulk
from argilla.client.sdk.commons.errors import (
AlreadyExistsApiError,
ApiCompatibilityError,
InputValueError,
NotFoundApiError,
)
Expand All @@ -67,13 +68,15 @@
LabelingRule,
LabelingRuleMetricsSummary,
TextClassificationBulkData,
TextClassificationQuery,
)
from argilla.client.sdk.text_classification.models import (
TextClassificationRecord as SdkTextClassificationRecord,
)
from argilla.client.sdk.token_classification.models import (
CreationTokenClassificationRecord,
TokenClassificationBulkData,
TokenClassificationQuery,
)
from argilla.client.sdk.token_classification.models import (
TokenClassificationRecord as SdkTokenClassificationRecord,
Expand Down Expand Up @@ -489,65 +492,37 @@ def load(
" `rg.load('my_dataset').to_pandas()`.",
)

dataset = self.datasets.find_by_name(name=name)
task = dataset.task

task_config = {
TaskType.text_classification: (
SdkTextClassificationRecord,
DatasetForTextClassification,
),
TaskType.token_classification: (
SdkTokenClassificationRecord,
DatasetForTokenClassification,
),
TaskType.text2text: (
SdkText2TextRecord,
DatasetForText2Text,
),
}

try:
sdk_record_class, dataset_class = task_config[task]
except KeyError:
raise ValueError(
f"Load method not supported for the '{task}' task. Supported Tasks: "
f"{[TaskType.text_classification, TaskType.token_classification, TaskType.text2text]}"
return self._load_records_new_fashion(
name=name,
query=query,
vector=vector,
ids=ids,
limit=limit,
id_from=id_from,
)
except ApiCompatibilityError as err: # Api backward compatibility
from argilla import __version__ as version

if vector:
vector_search = VectorSearch(
name=vector[0],
value=vector[1],
warnings.warn(
message=f"Using python client argilla=={version},"
f" however deployed server version is {err.api_version}."
" This might lead to compatibility issues.\n"
f" Preferably, update your server version to {version}"
" or downgrade your Python API at the loss"
" of functionality and robustness via\n"
f"`pip install argilla=={err.api_version}`",
category=UserWarning,
)
results = self.search.search_records(

return self._load_records_old_fashion(
name=name,
task=task,
size=limit or 100,
# query args
query_text=query,
vector=vector_search,
query=query,
ids=ids,
limit=limit,
id_from=id_from,
)

return dataset_class(results.records)

records = self.datasets.scan(
name=name,
projection={"*"},
limit=limit,
id_from=id_from,
# Query
query_text=query,
ids=ids,
)
records = [sdk_record_class.parse_obj(r).to_client() for r in records]
try:
records_sorted_by_id = sorted(records, key=lambda x: x.id)
# record ids can be a mix of int/str -> sort all as str type
except TypeError:
records_sorted_by_id = sorted(records, key=lambda x: str(x.id))
return dataset_class(records_sorted_by_id)

def dataset_metrics(self, name: str) -> List[MetricInfo]:
response = datasets_api.get_dataset(self._client, name)
response = metrics_api.get_dataset_metrics(
Expand Down Expand Up @@ -653,3 +628,129 @@ def rule_metrics_for_dataset(
)

return LabelingRuleMetricsSummary.parse_obj(response.parsed)

def _load_records_old_fashion(
self,
name: str,
query: Optional[str] = None,
ids: Optional[List[Union[str, int]]] = None,
limit: Optional[int] = None,
id_from: Optional[str] = None,
) -> Dataset:
from argilla.client.sdk.text2text import api as text2text_api
from argilla.client.sdk.text2text.models import Text2TextQuery
from argilla.client.sdk.text_classification import (
api as text_classification_api,
)
from argilla.client.sdk.token_classification import (
api as token_classification_api,
)

response = datasets_api.get_dataset(client=self._client, name=name)
task = response.parsed.task

task_config = {
TaskType.text_classification: (
text_classification_api.data,
TextClassificationQuery,
DatasetForTextClassification,
),
TaskType.token_classification: (
token_classification_api.data,
TokenClassificationQuery,
DatasetForTokenClassification,
),
TaskType.text2text: (
text2text_api.data,
Text2TextQuery,
DatasetForText2Text,
),
}

try:
get_dataset_data, request_class, dataset_class = task_config[task]
except KeyError:
raise ValueError(
f"Load method not supported for the '{task}' task. Supported tasks: "
f"{[TaskType.text_classification, TaskType.token_classification, TaskType.text2text]}"
)
response = get_dataset_data(
client=self._client,
name=name,
request=request_class(ids=ids, query_text=query),
limit=limit,
id_from=id_from,
)

records = [sdk_record.to_client() for sdk_record in response.parsed]
return dataset_class(self.__sort_records_by_id__(records))

def _load_records_new_fashion(
self,
name: str,
query: Optional[str] = None,
vector: Optional[Tuple[str, List[float]]] = None,
ids: Optional[List[Union[str, int]]] = None,
limit: Optional[int] = None,
id_from: Optional[str] = None,
) -> Dataset:
dataset = self.datasets.find_by_name(name=name)
task = dataset.task

task_config = {
TaskType.text_classification: (
SdkTextClassificationRecord,
DatasetForTextClassification,
),
TaskType.token_classification: (
SdkTokenClassificationRecord,
DatasetForTokenClassification,
),
TaskType.text2text: (
SdkText2TextRecord,
DatasetForText2Text,
),
}

try:
sdk_record_class, dataset_class = task_config[task]
except KeyError:
raise ValueError(
f"Load method not supported for the '{task}' task. Supported Tasks: "
f"{[TaskType.text_classification, TaskType.token_classification, TaskType.text2text]}"
)

if vector:
vector_search = VectorSearch(
name=vector[0],
value=vector[1],
)
results = self.search.search_records(
name=name,
task=task,
size=limit or 100,
# query args
query_text=query,
vector=vector_search,
)
return dataset_class(results.records)

records = self.datasets.scan(
name=name,
projection={"*"},
limit=limit,
id_from=id_from,
# Query
query_text=query,
ids=ids,
)
records = [sdk_record_class.parse_obj(r).to_client() for r in records]
return dataset_class(self.__sort_records_by_id__(records))

def __sort_records_by_id__(self, records: list) -> list:
try:
records_sorted_by_id = sorted(records, key=lambda x: x.id)
# record ids can be a mix of int/str -> sort all as str type
except TypeError:
records_sorted_by_id = sorted(records, key=lambda x: str(x.id))
return records_sorted_by_id
3 changes: 2 additions & 1 deletion src/argilla/client/sdk/commons/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ class InputValueError(BaseClientError):


class ApiCompatibilityError(BaseClientError):
def __init__(self, min_version: str):
def __init__(self, min_version: str, api_version: str):
self.min_version = min_version
self.api_version = api_version

def __str__(self):
return (
Expand Down
46 changes: 46 additions & 0 deletions src/argilla/client/sdk/text2text/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# coding=utf-8
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union

import httpx

from argilla.client.sdk.client import AuthenticatedClient
from argilla.client.sdk.commons.api import build_data_response, build_param_dict
from argilla.client.sdk.commons.models import (
ErrorMessage,
HTTPValidationError,
Response,
)
from argilla.client.sdk.text2text.models import Text2TextQuery, Text2TextRecord


def data(
client: AuthenticatedClient,
name: str,
request: Optional[Text2TextQuery] = None,
limit: Optional[int] = None,
id_from: Optional[str] = None,
) -> Response[Union[List[Text2TextRecord], HTTPValidationError, ErrorMessage]]:

path = f"/api/datasets/{name}/Text2Text/data"
params = build_param_dict(id_from, limit)

with client.stream(
method="POST",
path=path,
params=params if params else None,
json=request.dict() if request else {},
) as response:
return build_data_response(response=response, data_type=Text2TextRecord)
22 changes: 22 additions & 0 deletions src/argilla/client/sdk/text_classification/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,28 @@
)


def data(
client: AuthenticatedClient,
name: str,
request: Optional[TextClassificationQuery] = None,
limit: Optional[int] = None,
id_from: Optional[str] = None,
) -> Response[Union[List[TextClassificationRecord], HTTPValidationError, ErrorMessage]]:

path = f"/api/datasets/{name}/TextClassification/data"
params = build_param_dict(id_from, limit)

with client.stream(
method="POST",
path=path,
params=params if params else None,
json=request.dict() if request else {},
) as response:
return build_data_response(
response=response, data_type=TextClassificationRecord
)


def add_dataset_labeling_rule(
client: AuthenticatedClient,
name: str,
Expand Down
Loading

0 comments on commit e55ea3e

Please sign in to comment.