Commit
resolved #2406 #2407 #2408 #2405 (#2411)
# Description

During `autotrain` integration, I came across some issues with the `prepare_for_training` methods (a minimal sketch of the resulting `Text2Text` flow is shown below the issue list).
Closes #2406 
Closes #2407
Closes #2408 
Closes #2405
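
For context, a minimal sketch of the `Text2Text` `prepare_for_training` flow these changes enable (it mirrors the new docs cell and tests in this PR; the record values are illustrative):

```python
import argilla as rg

# A single Text2Text record; records without an annotation are skipped by
# prepare_for_training, so this one carries a target annotation.
dataset_rg = rg.DatasetForText2Text(
    [
        rg.Text2TextRecord(
            text="I live in Madrid",
            annotation="I live in Spain",
        )
    ]
)

# With the transformers framework the records become a datasets.Dataset
# with "text" and "target" string columns.
ds = dataset_rg.prepare_for_training(framework="transformers")
print(ds[0])
# {'text': 'I live in Madrid', 'target': 'I live in Spain'}
```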

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [X] Bug fix (non-breaking change which fixes an issue)
- [X] New feature (non-breaking change which adds functionality)
- [X] Refactor (change restructuring the codebase without changing
functionality)
- [X] Improvement (change adding some improvement to an existing
functionality)
- [X] Documentation update

**How Has This Been Tested**

N.A.

**Checklist**

- [X] I added relevant documentation
- [X] I did a self-review of my code
- [X] I made corresponding changes to the documentation

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: frascuchon <francis@argilla.io>
3 people authored Feb 24, 2023
1 parent d0f8882 commit 12aa8d6
Showing 3 changed files with 109 additions and 17 deletions.
40 changes: 38 additions & 2 deletions docs/_source/guides/log_load_and_prepare_data.ipynb
@@ -648,14 +648,15 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "907714da-5496-4a49-bc6d-2d9008d0082f",
"metadata": {},
"source": [
"### TokenClassification\n",
"\n",
"For token classification tasks, it converts the annotations of a record into integers representing BIO tags and writes them in a `ner_tags` column:\n",
"By passing the `framework` variable as `transformers` or `spacy`. "
"By passing the `framework` variable as `transformers`, `spark-nlp` or `spacy`. "
]
},
{
@@ -692,6 +693,41 @@
"# Output:\n",
"# <pd.DataFrame>"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "d14246b6",
"metadata": {},
"source": [
"### Text2Text\n",
"\n",
"For text generation tasks like `summarization` and translation tasks, it converts the annotations of a record `text` and `target` columns.\n",
"By passing the `framework` variable as `transformers`. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f749b193",
"metadata": {},
"outputs": [],
"source": [
"import argilla as rg\n",
"\n",
"dataset_rg = rg.DatasetForText2Text()(\n",
" [\n",
" rg.Text2TextRecord()(\n",
" text=\"I live in Madrid\",\n",
" annotation=\"I live in Spain\",\n",
" )\n",
" ]\n",
")\n",
"\n",
"dataset_rg.prepare_for_training(framework=\"transformers\")[0]\n",
"# Output:\n",
"# {..., 'tokens': ['I', 'live', 'in', 'Madrid'], 'ner_tags': [0, 0, 0, 1], ...}+"
]
}
],
"metadata": {
@@ -710,7 +746,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
"version": "3.9.15"
},
"vscode": {
"interpreter": {
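
As an aside on the BIO conversion the TokenClassification cell above describes, here is an illustrative-only sketch of how character-level entity spans map to per-token BIO tags. This is not Argilla's internal `spans2iob`; the helper name and example sentence are made up:

```python
def spans_to_bio(text, tokens, spans):
    """Illustrative-only: map (label, start, end) character spans to BIO tags."""
    # Compute (start, end) character offsets for each token.
    offsets, cursor = [], 0
    for token in tokens:
        start = text.index(token, cursor)
        offsets.append((start, start + len(token)))
        cursor = start + len(token)

    # A token gets B-<label> at the start of a span, I-<label> inside it, else O.
    tags = []
    for start, end in offsets:
        tag = "O"
        for label, span_start, span_end in spans:
            if span_start <= start and end <= span_end:
                tag = ("B-" if start == span_start else "I-") + label
                break
        tags.append(tag)
    return tags


print(spans_to_bio("I live in Madrid", ["I", "live", "in", "Madrid"], [("LOC", 10, 16)]))
# ['O', 'O', 'O', 'B-LOC']
```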
50 changes: 35 additions & 15 deletions src/argilla/client/datasets.py
@@ -181,7 +181,6 @@ def to_datasets(self) -> "datasets.Dataset":
return dataset

def _to_datasets_dict(self) -> Dict:
"""Helper method to transform a argilla dataset into a dict that is compatible with `datasets.Dataset`"""
raise NotImplementedError

@classmethod
@@ -484,8 +483,8 @@ def prepare_for_training(
)
elif framework is Framework.SPACY and lang is None:
raise ValueError(
"Please provide a spacy language model to prepare the"
" dataset for training with the spacy framework."
"Please provide a spacy language model to prepare the dataset for"
" training with the spacy framework."
)
elif framework in [Framework.SPACY, Framework.SPARK_NLP]:
if train_size and test_size:
@@ -847,7 +846,7 @@ def _prepare_for_training_with_transformers(
},
features=datasets.Features(feature_dict),
)
if test_size:
if test_size is not None and test_size != 0:
ds = ds.train_test_split(
train_size=train_size, test_size=test_size, seed=seed
)
@@ -1007,7 +1006,7 @@ def from_datasets(
for row in dataset:
# TODO: fails with a KeyError if no tokens column is present and no mapping is indicated
if not row["tokens"]:
_LOGGER.warning(f"Ignoring row with no tokens.")
_LOGGER.warning("Ignoring row with no tokens.")
continue

if row.get("tags"):
@@ -1079,10 +1078,11 @@ def spans2iob(example):
.map(lambda example: {"ner_tags": spans2iob(example)})
)
new_features = ds.features.copy()
new_features["ner_tags"] = [class_tags]
new_features["ner_tags"] = datasets.Sequence(feature=class_tags)
ds = ds.cast(new_features)
ds = ds.remove_columns(set(ds.column_names) - set(["tokens", "ner_tags"]))

if train_size or test_size:
if test_size is not None and test_size != 0:
ds = ds.train_test_split(
train_size=train_size, test_size=test_size, seed=seed
)
@@ -1155,6 +1155,7 @@ def __only_annotations__(self, data) -> bool:

def _to_datasets_dict(self) -> Dict:
"""Helper method to put token classification records in a `datasets.Dataset`"""

# create a dict first, where we make the necessary transformations
def entities_to_dict(
entities: Optional[
@@ -1352,17 +1353,36 @@ def _from_pandas(cls, dataframe: pd.DataFrame) -> "DatasetForText2Text":
return cls([Text2TextRecord(**row) for row in dataframe.to_dict("records")])

@_requires_datasets
def prepare_for_training(self, **kwargs) -> "datasets.Dataset":
"""Prepares the dataset for training.
def _prepare_for_training_with_transformers(
self,
train_size: Optional[float] = None,
test_size: Optional[float] = None,
seed: Optional[int] = None,
):
import datasets

Args:
**kwargs: Specific to the task of the dataset.
ds_dict = {"text": [], "target": []}
for rec in self._records:
if rec.annotation is None:
continue
ds_dict["text"].append(rec.text)
ds_dict["target"].append(rec.annotation)

Returns:
A datasets Dataset.
"""
feature_dict = {
"text": datasets.Value("string"),
"target": datasets.Value("string"),
}

raise NotImplementedError
ds = datasets.Dataset.from_dict(
ds_dict, features=datasets.Features(feature_dict)
)

if test_size is not None and test_size != 0:
ds = ds.train_test_split(
train_size=train_size, test_size=test_size, seed=seed
)

return ds


Dataset = Union[
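
As a quick usage note on the new `Text2Text` transformers path and the tightened `test_size is not None and test_size != 0` checks above, here is a hedged sketch of the resulting behaviour; the values mirror the new tests in `tests/client/test_dataset.py` below:

```python
import argilla as rg

ds = rg.DatasetForText2Text(
    [rg.Text2TextRecord(text="mock", annotation="mock")] * 10
)

# train_size=1: a single datasets.Dataset with "text"/"target" columns.
train = ds.prepare_for_training(framework="transformers", train_size=1)
print(len(train), train.column_names)             # 10 ['text', 'target']

# train_size=0.5: a DatasetDict with equally sized "train" and "test" splits.
splits = ds.prepare_for_training(framework="transformers", train_size=0.5)
print(len(splits["train"]), len(splits["test"]))  # 5 5
```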
36 changes: 36 additions & 0 deletions tests/client/test_dataset.py
@@ -701,6 +701,8 @@ def test_prepare_for_training(self):
]

train = rb_dataset.prepare_for_training()
assert (set(train.column_names)) == set(["tokens", "ner_tags"])

assert isinstance(train, datasets.DatasetD.Dataset) or isinstance(
train, datasets.Dataset
)
@@ -867,6 +869,40 @@ def test_to_from_pandas(self, text2text_records):
for rec, expected in zip(dataset, expected_dataset):
assert rec == expected

def test_prepare_for_training(self):
ds = ar.DatasetForText2Text(
[ar.Text2TextRecord(text="mock", annotation="mock")] * 10
)
train = ds.prepare_for_training(train_size=1)

assert isinstance(train, datasets.Dataset)
assert train.column_names == ["text", "target"]
assert len(train) == 10
assert train[1]["text"] == "mock"
assert train[1]["target"] == "mock"
assert train.features["text"] == datasets.Value("string")
assert train.features["target"] == datasets.Value("string")

train_test = ds.prepare_for_training(train_size=0.5)
assert len(train_test["train"]) == 5
assert len(train_test["test"]) == 5
for split in ["train", "test"]:
assert train_test[split].column_names == ["text", "target"]

def test_prepare_for_training_spacy(self):
ds = ar.DatasetForText2Text(
[ar.Text2TextRecord(text="mock", annotation="mock")] * 10
)
with pytest.raises(NotImplementedError):
ds.prepare_for_training("spacy", lang=spacy.blank("en"), train_size=1)

def test_prepare_for_training_spark_nlp(self):
ds = ar.DatasetForText2Text(
[ar.Text2TextRecord(text="mock", annotation="mock")] * 10
)
with pytest.raises(NotImplementedError):
ds.prepare_for_training("spark-nlp", train_size=1)

@pytest.mark.skipif(
_HF_HUB_ACCESS_TOKEN is None,
reason="You need a HF Hub access token to test the push_to_hub feature",

0 comments on commit 12aa8d6
