From dca0de62591354e2da2ad816107e9f3510f81cb2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Thu, 10 Oct 2024 12:07:49 +0200 Subject: [PATCH 1/3] Add `TextGenerationWithCotReflection` task --- src/distilabel/steps/tasks/__init__.py | 7 +- src/distilabel/steps/tasks/text_generation.py | 130 ++++++++++++++++++ 2 files changed, 136 insertions(+), 1 deletion(-) diff --git a/src/distilabel/steps/tasks/__init__.py b/src/distilabel/steps/tasks/__init__.py index 725fd065f..d3f66a28a 100644 --- a/src/distilabel/steps/tasks/__init__.py +++ b/src/distilabel/steps/tasks/__init__.py @@ -49,7 +49,11 @@ from distilabel.steps.tasks.sentence_transformers import GenerateSentencePair from distilabel.steps.tasks.structured_generation import StructuredGeneration from distilabel.steps.tasks.text_classification import TextClassification -from distilabel.steps.tasks.text_generation import ChatGeneration, TextGeneration +from distilabel.steps.tasks.text_generation import ( + ChatGeneration, + TextGeneration, + TextGenerationWithCotReflection, +) from distilabel.steps.tasks.typing import ChatItem, ChatType from distilabel.steps.tasks.ultrafeedback import UltraFeedback from distilabel.steps.tasks.urial import URIAL @@ -88,6 +92,7 @@ "TextClassification", "ChatGeneration", "TextGeneration", + "TextGenerationWithCotReflection", "ChatItem", "ChatType", "CLAIR", diff --git a/src/distilabel/steps/tasks/text_generation.py b/src/distilabel/steps/tasks/text_generation.py index a8b2048e5..9454254e1 100644 --- a/src/distilabel/steps/tasks/text_generation.py +++ b/src/distilabel/steps/tasks/text_generation.py @@ -299,6 +299,136 @@ def format_output( return {"generation": output} +# Slighty modified version: https://github.com/codelion/optillm/blob/main/optillm/cot_reflection.py +COT_REFLECTION_SYSTEM_PROMPT = """ +You are an AI assistant that uses a Chain of Thought (CoT) approach with reflection to answer queries. Follow these steps: + +1. Think through the problem step by step within the tags. +2. Reflect on your thinking to check for any errors or improvements within the tags. +3. Make any necessary adjustments based on your reflection. +4. Provide your final, concise answer within the tags. + +Important: The and sections are for your internal reasoning process only. +Do not include any part of the final answer in these sections. +You can only create one and the blocks must be contained within it. +The actual response to the query must be entirely contained within the tags. +You must always include at the end of the generation. + +Use the following format for your response: + +``` + +[Your step-by-step reasoning goes here. This is your internal thought process, not the final answer.] + +[Your reflection on your reasoning, checking for errors or improvements] + +[Any adjustments to your thinking based on your reflection] + + +[Your final, concise answer to the query. This is the only part that will be shown to the user.] + +``` +""".lstrip() + +COT_REFLECTION_OUTPUT_REGEX = re.compile( + r"([\s\S]*?)\s*([\s\S]*?)" +) + + +class TextGenerationWithCotReflection(Task): + """Text generation with an `LLM` using Chain of Thought (CoT) reflection. + + `TextGenerationWithCotReflection` is a `Task` that allows generating a response for + a given instruction using a Chain of Thought (CoT) approach with reflection. The `LLM` + will first think through the problem step by step, reflect on the thinking process, make + any necessary adjustments based on the reflection, and provide a final, concise answer. + This method usually helps in generating more accurate and thoughtful responses at the + cost of generating more tokens and being slower. + + Attributes: + system_prompt: The system prompt to use in the generation and that will be appended + to the CoT Reflection system prompt. If not provided, then it will check if + the input row has a column named `system_prompt` and use it. If not, then no + system prompt will be used. Defaults to `None`. + + Input columns: + - instruction (`str`): The instruction to generate the response. + - system_prompt (`str`, optional): The system prompt to use in the generation and + that will be appended to the CoT Reflection system prompt. Defaults to `None`. + + Output columns: + - thinking (`str`): The step-by-step reasoning process. + + Categories: + - text-generation + + Examples: + Generate text from an instruction: + + ```python + from distilabel.llms import InferenceEndpointsLLM + from distilabel.steps.tasks import TextGenerationWithCotReflection + + task = TextGenerationWithCotReflection( + llm=InferenceEndpointsLLM( + model_id="meta-llama/Meta-Llama-3.1-70B-Instruct", + generation_kwargs={"temperature": 0.8, "max_new_tokens": 2048}, + ), + use_cache=False, + ) + + task.load() + + + result = next( + task.process_applying_mappings( + [ + { + "instruction": "If all cats have whiskers, and Fluffy is a cat, but Fluffy doesn't have whiskers, what can we conclude about this situation?" + } + ] + ) + ) + # { + # "instruction": "If all cats have whiskers, and Fluffy is a cat, but Fluffy doesn't have whiskers, what can we conclude about this situation?", + # "thinking": "Let's break down the information provided: \n- All cats have whiskers.\n- Fluffy is a cat.\n- Fluffy doesn't have whiskers...", + # "output": 'We can conclude that either the general rule "all cats have whiskers" is incorrect, ...', + # } + ``` + """ + + system_prompt: Union[str, None] = None + + @property + def inputs(self) -> "StepColumns": + return {"instruction": True, "system_prompt": False} + + def format_input(self, input: Dict[str, Any]) -> "ChatType": + system_prompt = COT_REFLECTION_SYSTEM_PROMPT + if additional_system_prompt := input.get("system_prompt", self.system_prompt): + system_prompt = f"{additional_system_prompt}\n\n{system_prompt}" + return [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": input["instruction"]}, + ] + + @property + def outputs(self) -> "StepColumns": + return ["thinking", "output"] + + def format_output( + self, output: Union[str, None], input: Union[Dict[str, Any], None] = None + ) -> Dict[str, Any]: + if output is None: + return {"thinking": None, "output": None} + + match = COT_REFLECTION_OUTPUT_REGEX.search(output) + if match is None: + return {"thinking": None, "output": output} + + return {"thinking": match.group(1).strip(), "output": match.group(2).strip()} + + class ChatGeneration(Task): """Generates text based on a conversation. From 445ef76a4b0ff1c1ad4ce5358a458e7b767af731 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Thu, 10 Oct 2024 12:58:15 +0200 Subject: [PATCH 2/3] Add unit tests for `TextGenerationWithCotReflection` --- src/distilabel/steps/tasks/text_generation.py | 5 +- .../unit/steps/tasks/test_text_generation.py | 73 ++++++++++++++++++- 2 files changed, 75 insertions(+), 3 deletions(-) diff --git a/src/distilabel/steps/tasks/text_generation.py b/src/distilabel/steps/tasks/text_generation.py index 9454254e1..5083e5c78 100644 --- a/src/distilabel/steps/tasks/text_generation.py +++ b/src/distilabel/steps/tasks/text_generation.py @@ -330,8 +330,9 @@ def format_output( ``` """.lstrip() +# Sometimes `LLM`s doesn't generate the `` that's why it's optional COT_REFLECTION_OUTPUT_REGEX = re.compile( - r"([\s\S]*?)\s*([\s\S]*?)" + r"([\s\S]*?)\s*([\s\S]*?)(?:)?" ) @@ -424,7 +425,7 @@ def format_output( match = COT_REFLECTION_OUTPUT_REGEX.search(output) if match is None: - return {"thinking": None, "output": output} + return {"thinking": None, "output": None} return {"thinking": match.group(1).strip(), "output": match.group(2).strip()} diff --git a/tests/unit/steps/tasks/test_text_generation.py b/tests/unit/steps/tasks/test_text_generation.py index 2a6abefb2..889c41a38 100644 --- a/tests/unit/steps/tasks/test_text_generation.py +++ b/tests/unit/steps/tasks/test_text_generation.py @@ -18,7 +18,12 @@ from distilabel.errors import DistilabelUserError from distilabel.pipeline.local import Pipeline -from distilabel.steps.tasks.text_generation import ChatGeneration, TextGeneration +from distilabel.steps.tasks.text_generation import ( + COT_REFLECTION_SYSTEM_PROMPT, + ChatGeneration, + TextGeneration, + TextGenerationWithCotReflection, +) from tests.unit.conftest import DummyAsyncLLM @@ -175,6 +180,72 @@ def test_format_input_custom_columns_expected_errors( task.load() +class TestTextGenerationWithCotReflection: + def test_format_input(self) -> None: + llm = DummyAsyncLLM() + task = TextGenerationWithCotReflection(name="task", llm=llm) + task.load() + + assert task.format_input({"instruction": "test"}) == [ + {"role": "system", "content": COT_REFLECTION_SYSTEM_PROMPT}, + {"role": "user", "content": "test"}, + ] + + def test_format_input_with_system_prompt(self) -> None: + llm = DummyAsyncLLM() + task = TextGenerationWithCotReflection( + name="task", llm=llm, system_prompt="test" + ) + task.load() + + assert task.format_input({"instruction": "test"}) == [ + {"role": "system", "content": "test\n\n" + COT_REFLECTION_SYSTEM_PROMPT}, + {"role": "user", "content": "test"}, + ] + + def test_format_input_with_row_system_prompt(self) -> None: + llm = DummyAsyncLLM() + task = TextGenerationWithCotReflection(name="task", llm=llm) + task.load() + + assert task.format_input({"instruction": "test", "system_prompt": "test"}) == [ + {"role": "system", "content": "test\n\n" + COT_REFLECTION_SYSTEM_PROMPT}, + {"role": "user", "content": "test"}, + ] + + def test_format_input_with_row_system_prompt_and_system_prompt(self) -> None: + llm = DummyAsyncLLM() + task = TextGenerationWithCotReflection( + name="task", llm=llm, system_prompt="i won't be used" + ) + task.load() + + assert task.format_input({"instruction": "test", "system_prompt": "test"}) == [ + {"role": "system", "content": "test\n\n" + COT_REFLECTION_SYSTEM_PROMPT}, + {"role": "user", "content": "test"}, + ] + + def test_format_ouptut(self) -> None: + llm = DummyAsyncLLM() + task = TextGenerationWithCotReflection( + name="task", llm=llm, system_prompt="i won't be used" + ) + task.load() + + assert task.format_output(None) == {"thinking": None, "output": None} + assert task.format_output("i'm not following the output format") == { + "thinking": None, + "output": None, + } + + assert task.format_output( + output="\ni'm thinking\n\nI'm having a reflection\n\n\n\ni'm the output\n" + ) == { + "thinking": "i'm thinking\n\nI'm having a reflection\n", + "output": "i'm the output", + } + + class TestChatGeneration: def test_format_input(self) -> None: pipeline = Pipeline(name="unit-test-pipeline") From b04a80713c1d2405fb86d7f7c2ec92cc5a096fec Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 10 Oct 2024 11:01:16 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/arena_hard.py | 3 ++- examples/deepseek_prover.py | 7 ++++--- examples/structured_generation_with_instructor.py | 3 ++- examples/structured_generation_with_outlines.py | 5 +++-- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/examples/arena_hard.py b/examples/arena_hard.py index 4f6f0c05c..b193bc234 100644 --- a/examples/arena_hard.py +++ b/examples/arena_hard.py @@ -15,11 +15,12 @@ import re from typing import Any, Dict, List, Optional, Union +from typing_extensions import override + from distilabel.steps import GlobalStep, StepInput from distilabel.steps.tasks.base import Task from distilabel.steps.tasks.typing import ChatType from distilabel.steps.typing import StepOutput -from typing_extensions import override class ArenaHard(Task): diff --git a/examples/deepseek_prover.py b/examples/deepseek_prover.py index b61f1c683..07b050964 100644 --- a/examples/deepseek_prover.py +++ b/examples/deepseek_prover.py @@ -17,14 +17,15 @@ from textwrap import dedent from typing import Any, Dict, List, Optional, Union +from jinja2 import Template +from pydantic import PrivateAttr +from typing_extensions import override + from distilabel.llms import InferenceEndpointsLLM from distilabel.pipeline import Pipeline from distilabel.steps import LoadDataFromHub from distilabel.steps.tasks.base import Task from distilabel.steps.tasks.typing import ChatType -from jinja2 import Template -from pydantic import PrivateAttr -from typing_extensions import override _PARSE_DEEPSEEK_PROVER_AUTOFORMAL_REGEX = r"```lean4(.*?)```" diff --git a/examples/structured_generation_with_instructor.py b/examples/structured_generation_with_instructor.py index 48082886f..0808e56ca 100644 --- a/examples/structured_generation_with_instructor.py +++ b/examples/structured_generation_with_instructor.py @@ -14,11 +14,12 @@ from typing import List +from pydantic import BaseModel, Field + from distilabel.llms import MistralLLM from distilabel.pipeline import Pipeline from distilabel.steps import LoadDataFromDicts from distilabel.steps.tasks import TextGeneration -from pydantic import BaseModel, Field class Node(BaseModel): diff --git a/examples/structured_generation_with_outlines.py b/examples/structured_generation_with_outlines.py index 98ee59ed6..b92cb6082 100644 --- a/examples/structured_generation_with_outlines.py +++ b/examples/structured_generation_with_outlines.py @@ -15,12 +15,13 @@ from enum import Enum from pathlib import Path +from pydantic import BaseModel, StringConstraints, conint +from typing_extensions import Annotated + from distilabel.llms import LlamaCppLLM from distilabel.pipeline import Pipeline from distilabel.steps import LoadDataFromDicts from distilabel.steps.tasks import TextGeneration -from pydantic import BaseModel, StringConstraints, conint -from typing_extensions import Annotated class Weapon(str, Enum):