Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add entry in dataset card to help fine-tuning using TRL with the generated dataset #1079

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion src/distilabel/distiset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from collections import defaultdict
from os import PathLike
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Generator, Iterator, List, Optional, Union

import fsspec
import yaml
Expand Down Expand Up @@ -50,6 +50,7 @@

if TYPE_CHECKING:
from distilabel.pipeline._dag import DAG
from distilabel.steps.typing import DatasetUse


class Distiset(dict):
Expand All @@ -73,6 +74,7 @@ class Distiset(dict):
_artifacts_path: Optional[Path] = None
_log_filename_path: Optional[Path] = None
_citations: Optional[List[str]] = None
_dataset_uses: Optional[list["DatasetUse"]] = None

def push_to_hub(
self,
Expand Down Expand Up @@ -199,6 +201,8 @@ def _get_card(
),
"tags": ["synthetic", "distilabel", "rlaif"],
}
# The variables must be passed by name here to be rendered in the template.
uses = self._get_dataset_uses(dataset_name=repo_id)

card = DistilabelDatasetCard.from_template(
card_data=DatasetCardData(**metadata),
Expand All @@ -208,6 +212,7 @@ def _get_card(
filename_py=filename_py,
artifacts=self._get_artifacts_metadata(),
references=self.citations,
dataset_uses=list(uses) if uses else [],
)

return card
Expand Down Expand Up @@ -238,6 +243,30 @@ def iterdir_ignore_hidden(path: Path) -> Generator[Path, None, None]:

return dict(artifacts_metadata)

def _get_dataset_uses(self, **kwargs: Any) -> Iterator[Dict[str, str]]:
    """Yield the rendered "uses" sections to include in the dataset card.

    Each entry of `self._dataset_uses` must provide a "title", a "template" string
    (rendered as a Jinja2 template) and the list of "variables" the template needs.
    Only the keyword arguments whose name appears in "variables" are forwarded to the
    template, so the names in `kwargs` must match the variable names in the template.

    This is a generator function, so it never returns `None`: when no uses have been
    set (this is done when calling `create_distiset`) the returned iterator is empty.

    Args:
        **kwargs: Candidate values for the templates, keyed by variable name.

    Yields:
        A dictionary with the "title" of the use and its rendered "content".
    """
    if not self._dataset_uses:
        # Nothing to render: finish the generator without yielding anything.
        return

    # Imported after the guard so jinja2 is only loaded when there is a template
    # to render.
    from jinja2 import Template

    for dataset_use in self._dataset_uses:
        template = Template(dataset_use["template"])
        variables = dataset_use["variables"]
        # Forward only the values this particular template declared it needs.
        to_render = {
            var_name: value
            for var_name, value in kwargs.items()
            if var_name in variables
        }

        yield {
            "title": dataset_use["title"],
            "content": template.render(**to_render),
        }

def _extract_readme_metadata(
self, repo_id: str, token: Optional[str]
) -> Dict[str, Any]:
Expand Down Expand Up @@ -669,6 +698,7 @@ def create_distiset( # noqa: C901

if dag:
distiset._citations = _grab_citations(dag)
distiset._dataset_uses = _get_dataset_uses(dag)

return distiset

Expand Down Expand Up @@ -706,3 +736,21 @@ def _grab_citations(dag: "DAG") -> List[str]:
print(f"Untracked error: {e}")
citations.extend(bibtex_refs)
return citations


def _get_dataset_uses(dag: "DAG") -> list["DatasetUse"]:
    """Collect the dataset uses declared by the steps of a pipeline.

    Every step in the `DAG` is asked for its `_dataset_use()`; the steps that return
    a falsy value (nothing to declare) are left out of the result.

    Args:
        dag: `DAG` contained in the pipeline that created the `Distiset`.

    Returns:
        List of uses to add to the `Distiset`, in the iteration order of the `DAG`.
    """
    # Lazily query each step so the calls happen one by one, in DAG order.
    declared = (
        dag.get_step(name)[STEP_ATTR_NAME]._dataset_use() for name in dag
    )
    return [use for use in declared if use]
2 changes: 2 additions & 0 deletions src/distilabel/steps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from distilabel.steps.embeddings.nearest_neighbour import FaissNearestNeighbour
from distilabel.steps.filtering.embedding import EmbeddingDedup
from distilabel.steps.filtering.minhash import MinHashDedup
from distilabel.steps.formatting.apo import FormatAPO
from distilabel.steps.formatting.conversation import ConversationTemplate
from distilabel.steps.formatting.dpo import (
FormatChatGenerationDPO,
Expand Down Expand Up @@ -79,6 +80,7 @@
"EmbeddingGeneration",
"FaissNearestNeighbour",
"ConversationTemplate",
"FormatAPO",
"FormatChatGenerationDPO",
"FormatTextGenerationDPO",
"FormatChatGenerationSFT",
Expand Down
19 changes: 18 additions & 1 deletion src/distilabel/steps/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,12 @@
DownstreamConnectableSteps,
UpstreamConnectableSteps,
)
from distilabel.steps.typing import GeneratorStepOutput, StepColumns, StepOutput
from distilabel.steps.typing import (
DatasetUse,
GeneratorStepOutput,
StepColumns,
StepOutput,
)


DEFAULT_INPUT_BATCH_SIZE = 50
Expand Down Expand Up @@ -610,6 +615,18 @@ def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]:
dump["runtime_parameters_info"] = self.get_runtime_parameters_info()
return dump

def _dataset_use(self) -> "Optional[DatasetUse]":
    """Return the information this step contributes to the final dataset card.

    A given step can override this method to include information about the uses of
    the dataset, like for example, how to fine-tune on the final dataset.
    If overridden, it must return a dictionary with 3 keys:
        "title": the title of the use, shown together with the rendered content.
        "template": a string that can be converted to a Jinja2 Template.
        "variables": a list with the names of the variables used to fill the template.

    This info will be grabbed when the Distiset is created and pushed to the hub.

    Returns:
        `None` by default, meaning the step has no use to add to the dataset card.
    """
    # The annotation is kept as a string so it is not evaluated at runtime.
    return None


class Step(_Step, ABC):
"""Base class for the steps that can be included in a `Pipeline`.
Expand Down
148 changes: 148 additions & 0 deletions src/distilabel/steps/formatting/apo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
from typing import TYPE_CHECKING

from typing_extensions import override

from distilabel.steps import Step
from distilabel.steps.base import StepInput
from distilabel.utils.card.dataset_card import get_dataset_use_template

if TYPE_CHECKING:
from distilabel.steps.typing import DatasetUse, StepColumns, StepOutput


class FormatAPO(Step):
    """Format the output of `CLAIR` task for Anchored Preference Optimization (APO).

    `FormatAPO` is a `Step` that formats the output of a `CLAIR` task for
    Anchored Preference Optimization (APO) following the standard formatting from `TRL`.
    The `revision` is used as the `chosen` completion and the original `response` as
    the `rejected` one.

    Input columns:
        - prompt (`str`): The instruction used to generate the `response` with the `LLM`.
        - response (`str`): The generation produced by the `LLM`.
        - revision (`str`): The revised text.

    Output columns:
        - prompt (`str`): The instruction used to generate the `response` with the `LLM`.
        - chosen (`List[Dict[str, str]]`): The conversation ending with the `revision`.
        - rejected (`List[Dict[str, str]]`): The conversation ending with the `response`.
        - prompt_id (`str`): The `SHA256` hash of the `prompt`.

    Categories:
        - format
        - preference
        - instruction
        - generation

    Examples:
        Format your dataset for APO fine tuning:

        ```python
        from distilabel.steps import FormatAPO

        formatter = FormatAPO()
        formatter.load()

        result = next(
            formatter.process(
                [
                    {
                        "prompt": "How many gaps are there between the earth and the moon?",
                        "response": '''There are no gaps between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon's orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range.\n\nSo, to summarize, there are no gaps between the Earth and the Moon. The Moon is simply a satellite that orbits the Earth, and its distance from our planet varies slightly due to the elliptical shape of its orbit.''',
                        "revision": '''There are no physical gaps or empty spaces between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a significant separation or gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range. This variation in distance is a result of the Moon\'s orbital path, not the presence of any gaps.\n\nIn summary, the Moon\'s orbit is continuous, with no intervening gaps, and its distance from the Earth varies due to the elliptical shape of its orbit.''',
                    }
                ]
            )
        )
        # >>> result
        # [{'prompt': 'How many gaps are there between the earth and the moon?',
        # 'response': 'There are no gaps between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range.\n\nSo, to summarize, there are no gaps between the Earth and the Moon. The Moon is simply a satellite that orbits the Earth, and its distance from our planet varies slightly due to the elliptical shape of its orbit.',
        # 'revision': 'There are no physical gaps or empty spaces between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a significant separation or gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range. This variation in distance is a result of the Moon\'s orbital path, not the presence of any gaps.\n\nIn summary, the Moon\'s orbit is continuous, with no intervening gaps, and its distance from the Earth varies due to the elliptical shape of its orbit.',
        # 'prompt_id': 'd5e8924f2856fe7180c0aef3ec186f7a421b2ba11551b9ebfffeb7638ec5b021',
        # 'chosen': [{'role': 'user',
        # 'content': 'How many gaps are there between the earth and the moon?'},
        # {'role': 'assistant',
        # 'content': 'There are no physical gaps or empty spaces between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a significant separation or gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range. This variation in distance is a result of the Moon\'s orbital path, not the presence of any gaps.\n\nIn summary, the Moon\'s orbit is continuous, with no intervening gaps, and its distance from the Earth varies due to the elliptical shape of its orbit.'}],
        # 'rejected': [{'role': 'user',
        # 'content': 'How many gaps are there between the earth and the moon?'},
        # {'role': 'assistant',
        # 'content': 'There are no gaps between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range.\n\nSo, to summarize, there are no gaps between the Earth and the Moon. The Moon is simply a satellite that orbits the Earth, and its distance from our planet varies slightly due to the elliptical shape of its orbit.'}]}]
        ```
    """

    @property
    def inputs(self) -> "StepColumns":
        """Columns every input row must provide."""
        return ["prompt", "response", "revision"]

    @property
    def outputs(self) -> "StepColumns":
        """Columns this step declares as its outputs."""
        return ["prompt", "chosen", "rejected", "prompt_id"]

    def process(self, *inputs: StepInput) -> "StepOutput":  # type: ignore
        """The `process` method formats the received `StepInput` or list of `StepInput`
        according to the APO formatting standard (DPO with loss_type equal to apo_zero
        or apo_down in TRL).

        Each row is updated in place with the `prompt_id`, `chosen` and `rejected`
        columns. If a row has a non-empty string `system_prompt`, it is prepended as a
        system message to both conversations.

        Args:
            *inputs: A list of `StepInput` to be combined.

        Yields:
            A `StepOutput` with batches of formatted `StepInput` following the APO standard.
        """
        # `batch` rather than `input`, to avoid shadowing the builtin.
        for batch in inputs:
            for item in batch:
                messages = [
                    {"role": "user", "content": item["prompt"]},  # type: ignore
                ]
                system_prompt = item.get("system_prompt")
                if isinstance(system_prompt, str) and system_prompt:
                    messages.insert(0, {"role": "system", "content": system_prompt})

                item["prompt_id"] = hashlib.sha256(
                    item["prompt"].encode("utf-8")  # type: ignore
                ).hexdigest()

                # The revised text is the preferred completion; the original
                # response is the rejected one.
                item["chosen"] = messages + [
                    {"role": "assistant", "content": item["revision"]}
                ]
                item["rejected"] = messages + [
                    {"role": "assistant", "content": item["response"]}
                ]
            yield batch

    @override
    def _dataset_use(self) -> "DatasetUse":
        """Describe how to fine-tune on the dataset produced by this step.

        Returns:
            A dictionary with the "title" of the use, the "template" string read from
            the dataset-use template file, and the "variables" the template expects.
        """
        # NOTE(review): this previously loaded the "sft" template, which looks like a
        # copy-paste slip: the step outputs `chosen`/`rejected` pairs and APO is run
        # as DPO with a specific `loss_type`, so the "dpo" template is used instead.
        # Confirm whether a dedicated APO template should be added.
        with open(get_dataset_use_template("dpo"), encoding="utf-8") as f:
            template = f.read()

        return {
            "title": "Anchored Preference Optimization (APO)",
            "template": template,
            "variables": ["dataset_name"],
        }
27 changes: 26 additions & 1 deletion src/distilabel/steps/formatting/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
import hashlib
from typing import TYPE_CHECKING, List

from typing_extensions import override

from distilabel.steps.base import Step, StepInput
from distilabel.utils.card.dataset_card import get_dataset_use_template

if TYPE_CHECKING:
from distilabel.steps.typing import StepColumns, StepOutput
from distilabel.steps.typing import DatasetUse, StepColumns, StepOutput


class FormatTextGenerationDPO(Step):
Expand Down Expand Up @@ -194,6 +197,17 @@ def process(self, *inputs: StepInput) -> "StepOutput": # type: ignore

yield input

@override
def _dataset_use(self) -> "DatasetUse":
    """Describe the DPO use of the dataset produced by this step.

    Returns:
        A dictionary with the "title" of the use, the "template" string read from
        the "dpo" dataset-use template file, and the "variables" the template expects.
    """
    # Explicit encoding so the template is decoded the same way on every platform.
    with open(get_dataset_use_template("dpo"), encoding="utf-8") as f:
        template = f.read()

    return {
        "title": "Direct Preference Optimization (DPO)",
        "template": template,
        "variables": ["dataset_name"],
    }


class FormatChatGenerationDPO(Step):
"""Format the output of a combination of a `ChatGeneration` + a preference task for Direct Preference Optimization (DPO).
Expand Down Expand Up @@ -354,3 +368,14 @@ def process(self, *inputs: StepInput) -> "StepOutput": # type: ignore
item["rejected_rating"] = item["ratings"][rejected_idx]

yield input

@override
def _dataset_use(self) -> "DatasetUse":
    """Describe the DPO use of the dataset produced by this step.

    Returns:
        A dictionary with the "title" of the use, the "template" string read from
        the "dpo" dataset-use template file, and the "variables" the template expects.
    """
    # Explicit encoding so the template is decoded the same way on every platform.
    with open(get_dataset_use_template("dpo"), encoding="utf-8") as f:
        template = f.read()

    return {
        "title": "Direct Preference Optimization (DPO)",
        "template": template,
        "variables": ["dataset_name"],
    }
27 changes: 26 additions & 1 deletion src/distilabel/steps/formatting/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
import hashlib
from typing import TYPE_CHECKING, List

from typing_extensions import override

from distilabel.steps.base import Step, StepInput
from distilabel.utils.card.dataset_card import get_dataset_use_template

if TYPE_CHECKING:
from distilabel.steps.typing import StepColumns, StepOutput
from distilabel.steps.typing import DatasetUse, StepColumns, StepOutput


class FormatTextGenerationSFT(Step):
Expand Down Expand Up @@ -140,6 +143,17 @@ def process(self, *inputs: StepInput) -> "StepOutput": # type: ignore

yield input

@override
def _dataset_use(self) -> "DatasetUse":
    """Describe the SFT use of the dataset produced by this step.

    Returns:
        A dictionary with the "title" of the use, the "template" string read from
        the "sft" dataset-use template file, and the "variables" the template expects.
    """
    # Explicit encoding so the template is decoded the same way on every platform.
    with open(get_dataset_use_template("sft"), encoding="utf-8") as f:
        template = f.read()

    return {
        "title": "Supervised Fine-Tuning (SFT)",
        "template": template,
        "variables": ["dataset_name"],
    }


class FormatChatGenerationSFT(Step):
"""Format the output of a `ChatGeneration` task for Supervised Fine-Tuning (SFT).
Expand Down Expand Up @@ -244,3 +258,14 @@ def process(self, *inputs: StepInput) -> "StepOutput": # type: ignore
{"role": "assistant", "content": item["generation"]}, # type: ignore
]
yield input

@override
def _dataset_use(self) -> "DatasetUse":
    """Describe the SFT use of the dataset produced by this step.

    Returns:
        A dictionary with the "title" of the use, the "template" string read from
        the "sft" dataset-use template file, and the "variables" the template expects.
    """
    # Explicit encoding so the template is decoded the same way on every platform.
    with open(get_dataset_use_template("sft"), encoding="utf-8") as f:
        template = f.read()

    return {
        "title": "Supervised Fine-Tuning (SFT)",
        "template": template,
        "variables": ["dataset_name"],
    }
Loading
Loading