From 0ba1ce07202a045cb31feb121f2e11bc6ab55372 Mon Sep 17 00:00:00 2001 From: Botir Khaltaev Date: Fri, 19 Jul 2024 01:34:29 +0100 Subject: [PATCH 01/14] Add dataset config class --- huggingface_pipelines/dataset.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 huggingface_pipelines/dataset.py diff --git a/huggingface_pipelines/dataset.py b/huggingface_pipelines/dataset.py new file mode 100644 index 0000000..e69de29 From 347489f739875835892180a1c9baced70cb53e7d Mon Sep 17 00:00:00 2001 From: Botir Khaltaev Date: Fri, 19 Jul 2024 01:37:05 +0100 Subject: [PATCH 02/14] Add dataset config class --- huggingface_pipelines/dataset.py | 65 ++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/huggingface_pipelines/dataset.py b/huggingface_pipelines/dataset.py index e69de29..afb3bd8 100644 --- a/huggingface_pipelines/dataset.py +++ b/huggingface_pipelines/dataset.py @@ -0,0 +1,65 @@ +from typing import TypedDict +from dataclasses import dataclass, field, replace +import uuid + + +class DatasetOverwrites(TypedDict, total=False): + dataset_name: str + dataset_split: str + world_size: int + rank: int + cache_dir: str + + +@dataclass +class DatasetConfig(): + """ + Configuration class for loading and sharding datasets. + + Attributes: + dataset_name (str): The name of the dataset to load. + dataset_split (str): The split of the dataset to use (e.g., 'train', 'test', 'validation'). + world_size (int): The number of shards to split the dataset into. Defaults to 1. + rank (int): The ID of the shard to retrieve. Defaults to 0. + cache_dir (str): The directory to cache the loaded dataset. Defaults to None. + trust_remote_code (bool): Whether to trust remote code when loading the dataset. Defaults to False. 
+ """ + dataset_name: str + dataset_split: str + config: str = None + world_size: int = 1 + rank: int = 0 + cache_dir: str = None + trust_remote_code: bool = False + uuid: str = field(default_factory=lambda: str(uuid.uuid4())) + + def load_dataset(self): + """ + Loads and shards the dataset based on the configuration settings. + + Returns: + datasets.Dataset: The loaded and sharded dataset. + """ + from datasets import load_dataset + + dataset_kwargs = { + "path": self.dataset_name, + "name": self.config, + "split": self.dataset_split, + "cache_dir": self.cache_dir, + "trust_remote_code": self.trust_remote_code, + } + dataset_kwargs = {k: v for k, + v in dataset_kwargs.items() if v is not None} + + dataset = load_dataset(**dataset_kwargs) + + # Shard the dataset + if self.world_size > 1: + dataset = dataset.shard( + num_shards=self.world_size, index=self.rank) + + return dataset + + def with_overwrites(self, overwrites: DatasetOverwrites): + return replace(self, **overwrites) From 87cceaf764fcaaa2125df09bb547516f9686d129 Mon Sep 17 00:00:00 2001 From: Botir Khaltaev Date: Mon, 29 Jul 2024 23:34:07 +0100 Subject: [PATCH 03/14] Make dataset now handle writing dataset values to disk --- huggingface_pipelines/dataset.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/huggingface_pipelines/dataset.py b/huggingface_pipelines/dataset.py index afb3bd8..b35223f 100644 --- a/huggingface_pipelines/dataset.py +++ b/huggingface_pipelines/dataset.py @@ -32,6 +32,7 @@ class DatasetConfig(): cache_dir: str = None trust_remote_code: bool = False uuid: str = field(default_factory=lambda: str(uuid.uuid4())) + output_dir: str = "results" def load_dataset(self): """ @@ -63,3 +64,4 @@ def load_dataset(self): def with_overwrites(self, overwrites: DatasetOverwrites): return replace(self, **overwrites) + From b55793804494d09ac9a03216a7545933ae8e6a11 Mon Sep 17 00:00:00 2001 From: Botir Khaltaev Date: Tue, 30 Jul 2024 23:31:44 +0100 Subject: [PATCH 04/14] Implement option to stream 
datasets --- huggingface_pipelines/dataset.py | 61 ++++++++++++++++++++++++++------ 1 file changed, 50 insertions(+), 11 deletions(-) diff --git a/huggingface_pipelines/dataset.py b/huggingface_pipelines/dataset.py index b35223f..d56aa2b 100644 --- a/huggingface_pipelines/dataset.py +++ b/huggingface_pipelines/dataset.py @@ -12,34 +12,57 @@ class DatasetOverwrites(TypedDict, total=False): @dataclass -class DatasetConfig(): +class DatasetConfig: """ Configuration class for loading and sharding datasets. + This class provides a structured way to configure dataset loading parameters + and includes methods for loading and sharding datasets. + Attributes: - dataset_name (str): The name of the dataset to load. + dataset_name (str): The name or path of the dataset to load. dataset_split (str): The split of the dataset to use (e.g., 'train', 'test', 'validation'). - world_size (int): The number of shards to split the dataset into. Defaults to 1. - rank (int): The ID of the shard to retrieve. Defaults to 0. - cache_dir (str): The directory to cache the loaded dataset. Defaults to None. - trust_remote_code (bool): Whether to trust remote code when loading the dataset. Defaults to False. + output_dir (str): The directory to store the output datasets. + streaming (bool): Whether to stream the dataset or load it entirely into memory. Defaults to False. + config (str): The specific configuration of the dataset to load, if applicable. Defaults to None. + world_size (int): The total number of shards to split the dataset into. Defaults to 1. + rank (int): The index of the shard to retrieve (0-based). Defaults to 0. + cache_dir (str): The directory to cache the loaded dataset. If None, uses the default cache. Defaults to None. + trust_remote_code (bool): Whether to trust remote code when loading the dataset. Use with caution. Defaults to False. + uuid (str): A unique identifier for this configuration instance. Automatically generated. 
+ + Note: + The `world_size` and `rank` attributes are particularly useful for distributed data processing, + allowing the dataset to be split across multiple processes or machines. """ + dataset_name: str dataset_split: str + output_dir: str + streaming: bool = False config: str = None world_size: int = 1 rank: int = 0 cache_dir: str = None trust_remote_code: bool = False uuid: str = field(default_factory=lambda: str(uuid.uuid4())) - output_dir: str = "results" def load_dataset(self): """ - Loads and shards the dataset based on the configuration settings. + Loads and optionally shards the dataset based on the configuration settings. + + This method uses the Hugging Face datasets library to load the dataset. + If `world_size` is greater than 1, it also shards the dataset. Returns: - datasets.Dataset: The loaded and sharded dataset. + datasets.Dataset: The loaded and potentially sharded dataset. + + Raises: + ValueError: If the dataset cannot be loaded with the given configuration. + ImportError: If the 'datasets' library is not installed. + + Note: + Ensure that the 'datasets' library is installed before calling this method. """ from datasets import load_dataset @@ -49,19 +72,35 @@ def load_dataset(self): "split": self.dataset_split, "cache_dir": self.cache_dir, "trust_remote_code": self.trust_remote_code, + "streaming": self.streaming } dataset_kwargs = {k: v for k, v in dataset_kwargs.items() if v is not None} dataset = load_dataset(**dataset_kwargs) - # Shard the dataset - if self.world_size > 1: + # Shard the dataset if world_size > 1 + if not self.streaming and self.world_size > 1: dataset = dataset.shard( num_shards=self.world_size, index=self.rank) return dataset def with_overwrites(self, overwrites: DatasetOverwrites): + """ + Creates a new DatasetConfig instance with specified overwrites. + + This method allows for the creation of a new configuration object + with some attributes overwritten, without modifying the original instance. 
+ + Args: + overwrites (DatasetOverwrites): A dictionary of attributes to overwrite. + + Returns: + DatasetConfig: A new instance of DatasetConfig with the specified overwrites applied. + + Example: + new_config = config.with_overwrites({"dataset_split": "test", "world_size": 4}) + """ return replace(self, **overwrites) From 2340d79b8d6da229c185a9f1501ba96a8c2bb7ba Mon Sep 17 00:00:00 2001 From: Botir Khaltaev Date: Wed, 31 Jul 2024 15:43:09 +0100 Subject: [PATCH 05/14] Allow different types of configs based on type of dataset --- huggingface_pipelines/dataset.py | 136 +++++++++++++++++++++++++------ 1 file changed, 112 insertions(+), 24 deletions(-) diff --git a/huggingface_pipelines/dataset.py b/huggingface_pipelines/dataset.py index d56aa2b..9ee0a22 100644 --- a/huggingface_pipelines/dataset.py +++ b/huggingface_pipelines/dataset.py @@ -1,35 +1,51 @@ -from typing import TypedDict -from dataclasses import dataclass, field, replace +from abc import ABC import uuid +from dataclasses import dataclass, field, replace +from typing import TypedDict, Any class DatasetOverwrites(TypedDict, total=False): + """ + TypedDict for dataset configuration overwrites. + + Attributes: + dataset_name (str): Name of the dataset. + dataset_split (str): Split of the dataset (e.g., 'train', 'test'). + output_dir (str): Directory for output. + streaming (bool): Whether to use streaming mode. + config (str): Specific dataset configuration. + trust_remote_code (bool): Whether to trust remote code. + world_size (int): Number of shards for distributed processing. + rank (int): Rank of the current process. + """ dataset_name: str dataset_split: str + output_dir: str + streaming: bool + config: str + trust_remote_code: bool world_size: int rank: int - cache_dir: str @dataclass -class DatasetConfig: +class DatasetConfig(ABC): """ - Configuration class for loading and sharding datasets. + Abstract base configuration class for loading and sharding datasets. 
This class provides a structured way to configure dataset loading parameters - and includes methods for loading and sharding datasets. + and includes methods for loading and sharding datasets from Hugging Face. Attributes: dataset_name (str): The name or path of the dataset to load. dataset_split (str): The split of the dataset to use (e.g., 'train', 'test', 'validation'). output_dir (str): The directory to store the output datasets. - streaming (bool): Whether to stream the dataset or load it entirely into memory. Defaults to False. - config (str): The specific configuration of the dataset to load, if applicable. Defaults to None. - world_size (int): The total number of shards to split the dataset into. Defaults to 1. - rank (int): The index of the shard to retrieve (0-based). Defaults to 0. - cache_dir (str): The directory to cache the loaded dataset. If None, uses the default cache. Defaults to None. - trust_remote_code (bool): Whether to trust remote code when loading the dataset. Use with caution. Defaults to False. - uuid (str): A unique identifier for this configuration instance. Automatically generated. + streaming (bool): Whether to stream the dataset or load it entirely into memory. + config (str): The specific configuration of the dataset to load, if applicable. + world_size (int): The total number of shards to split the dataset into. + rank (int): The index of the shard to retrieve (0-based). + trust_remote_code (bool): Whether to trust remote code when loading the dataset. + uuid (str): A unique identifier for this configuration instance. 
Note: The `world_size` and `rank` attributes are particularly useful for distributed data processing, @@ -43,7 +59,6 @@ class DatasetConfig: config: str = None world_size: int = 1 rank: int = 0 - cache_dir: str = None trust_remote_code: bool = False uuid: str = field(default_factory=lambda: str(uuid.uuid4())) @@ -66,29 +81,63 @@ def load_dataset(self): """ from datasets import load_dataset + dataset_kwargs = self.get_dataset_kwargs() + dataset = load_dataset(**dataset_kwargs) + + self.validate_world_size_and_rank() + + if not self.streaming and self.world_size > 1: + dataset = dataset.shard( + num_shards=self.world_size, index=self.rank) + + return self.post_process_dataset(dataset) + + def get_dataset_kwargs(self) -> dict[str, Any]: + """ + Returns the kwargs for load_dataset function. + + This method prepares the keyword arguments used in the load_dataset function. + + Returns: + dict[str, Any]: A dictionary of keyword arguments for load_dataset. + """ dataset_kwargs = { "path": self.dataset_name, "name": self.config, "split": self.dataset_split, - "cache_dir": self.cache_dir, "trust_remote_code": self.trust_remote_code, "streaming": self.streaming } - dataset_kwargs = {k: v for k, - v in dataset_kwargs.items() if v is not None} + return {k: v for k, v in dataset_kwargs.items() if v is not None} - dataset = load_dataset(**dataset_kwargs) + def validate_world_size_and_rank(self): + """ + Validates world_size and rank. - # Shard the dataset if world_size > 1 - if not self.streaming and self.world_size > 1: - dataset = dataset.shard( - num_shards=self.world_size, index=self.rank) + Raises: + AssertionError: If world_size or rank are invalid. + """ + assert self.world_size >= 1, f"Invalid world_size: {self.world_size}. It should be >= 1." + assert 0 <= self.rank < self.world_size, f"Invalid rank: {self.rank}. It should be between 0 and {self.world_size - 1}." + + def post_process_dataset(self, dataset): + """ + Performs any post-processing on the dataset. 
+ This method can be overridden in subclasses to implement + dataset-specific post-processing. + + Args: + dataset (datasets.Dataset): The loaded dataset. + + Returns: + datasets.Dataset: The post-processed dataset. + """ return dataset def with_overwrites(self, overwrites: DatasetOverwrites): """ - Creates a new DatasetConfig instance with specified overwrites. + Creates a new instance with specified overwrites. This method allows for the creation of a new configuration object with some attributes overwritten, without modifying the original instance. @@ -97,10 +146,49 @@ def with_overwrites(self, overwrites: DatasetOverwrites): overwrites (DatasetOverwrites): A dictionary of attributes to overwrite. Returns: - DatasetConfig: A new instance of DatasetConfig with the specified overwrites applied. + BaseDatasetConfig: A new instance with the specified overwrites applied. Example: new_config = config.with_overwrites({"dataset_split": "test", "world_size": 4}) """ return replace(self, **overwrites) + +@dataclass +class TextDatasetConfig(DatasetConfig): + """ + Configuration for text datasets. + + This class inherits from BaseDatasetConfig and can be used for + text-specific dataset configurations. + """ + + +@dataclass +class AudioDatasetConfig(DatasetConfig): + """ + Configuration for audio datasets. + + This class inherits from BaseDatasetConfig and includes + audio-specific attributes and processing. + + Attributes: + sampling_rate (int): The target sampling rate for audio data. + """ + sampling_rate: int = 16000 + + def post_process_dataset(self, dataset): + """ + Performs audio-specific post-processing on the dataset. + + This method can be used to implement audio-specific processing + such as resampling or feature extraction. + + Args: + dataset (datasets.Dataset): The loaded audio dataset. + + Returns: + datasets.Dataset: The post-processed audio dataset. 
+ """ + return dataset + From f726fc5eab3dc36f8c0fae783797e8e0eb8f6c38 Mon Sep 17 00:00:00 2001 From: Botir Khaltaev Date: Wed, 7 Aug 2024 14:08:12 +0100 Subject: [PATCH 06/14] Add transformers and datasets as dependencies to pyproject.toml --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 6385767..3d2ef3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,8 @@ dependencies = [ "tqdm", "overrides", "typing_extensions", + "transformers", + "datasets" ] [project.optional-dependencies] From b58f2ade8b929b4c92bdf634912d64decfbd7381 Mon Sep 17 00:00:00 2001 From: Botir Khaltaev Date: Mon, 12 Aug 2024 11:03:04 +0100 Subject: [PATCH 07/14] Remove post process dataset as that should be done in the preprocessing pipelines --- huggingface_pipelines/dataset.py | 61 ++------------------------------ 1 file changed, 2 insertions(+), 59 deletions(-) diff --git a/huggingface_pipelines/dataset.py b/huggingface_pipelines/dataset.py index 9ee0a22..44553eb 100644 --- a/huggingface_pipelines/dataset.py +++ b/huggingface_pipelines/dataset.py @@ -1,6 +1,5 @@ from abc import ABC -import uuid -from dataclasses import dataclass, field, replace +from dataclasses import dataclass, replace from typing import TypedDict, Any @@ -45,7 +44,6 @@ class DatasetConfig(ABC): world_size (int): The total number of shards to split the dataset into. rank (int): The index of the shard to retrieve (0-based). trust_remote_code (bool): Whether to trust remote code when loading the dataset. - uuid (str): A unique identifier for this configuration instance. 
Note: The `world_size` and `rank` attributes are particularly useful for distributed data processing, @@ -60,7 +58,6 @@ class DatasetConfig(ABC): world_size: int = 1 rank: int = 0 trust_remote_code: bool = False - uuid: str = field(default_factory=lambda: str(uuid.uuid4())) def load_dataset(self): """ @@ -90,7 +87,7 @@ def load_dataset(self): dataset = dataset.shard( num_shards=self.world_size, index=self.rank) - return self.post_process_dataset(dataset) + return dataset def get_dataset_kwargs(self) -> dict[str, Any]: """ @@ -120,21 +117,6 @@ def validate_world_size_and_rank(self): assert self.world_size >= 1, f"Invalid world_size: {self.world_size}. It should be >= 1." assert 0 <= self.rank < self.world_size, f"Invalid rank: {self.rank}. It should be between 0 and {self.world_size - 1}." - def post_process_dataset(self, dataset): - """ - Performs any post-processing on the dataset. - - This method can be overridden in subclasses to implement - dataset-specific post-processing. - - Args: - dataset (datasets.Dataset): The loaded dataset. - - Returns: - datasets.Dataset: The post-processed dataset. - """ - return dataset - def with_overwrites(self, overwrites: DatasetOverwrites): """ Creates a new instance with specified overwrites. @@ -153,42 +135,3 @@ def with_overwrites(self, overwrites: DatasetOverwrites): """ return replace(self, **overwrites) - -@dataclass -class TextDatasetConfig(DatasetConfig): - """ - Configuration for text datasets. - - This class inherits from BaseDatasetConfig and can be used for - text-specific dataset configurations. - """ - - -@dataclass -class AudioDatasetConfig(DatasetConfig): - """ - Configuration for audio datasets. - - This class inherits from BaseDatasetConfig and includes - audio-specific attributes and processing. - - Attributes: - sampling_rate (int): The target sampling rate for audio data. 
- """ - sampling_rate: int = 16000 - - def post_process_dataset(self, dataset): - """ - Performs audio-specific post-processing on the dataset. - - This method can be used to implement audio-specific processing - such as resampling or feature extraction. - - Args: - dataset (datasets.Dataset): The loaded audio dataset. - - Returns: - datasets.Dataset: The post-processed audio dataset. - """ - return dataset - From 6532ebbc57c02569247ce3adbf6cff5317409301 Mon Sep 17 00:00:00 2001 From: Botir Khaltaev Date: Mon, 12 Aug 2024 13:33:21 +0100 Subject: [PATCH 08/14] Add optional dependencies for huggingface using tag --- pyproject.toml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 3d2ef3e..453489b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,13 @@ dependencies = [ "pylint>=2.8.0", ] + hg = [ + "transformers>=4.44.0", + "datasets>=2.20.0" + ] + + + [project.urls] Source = "https://github.com/facebookresearch/SONAR" Tracker = "https://github.com/facebookresearch/SONAR/issues" From c1fdd02d17c54d434f0594d697372fa05b7ab5f5 Mon Sep 17 00:00:00 2001 From: Botir Khaltaev Date: Mon, 12 Aug 2024 13:34:58 +0100 Subject: [PATCH 09/14] Remove commas between configs --- pyproject.toml | 75 ++++++++++++++++++++++++-------------------------- 1 file changed, 36 insertions(+), 39 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 453489b..5b4758d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,18 +7,27 @@ build-backend = "flit_core.buildapi" [project] name = "sonar-space" readme = "README.md" -authors = [{name = "Meta AI Research"}] +authors = [{ name = "Meta AI Research" }] requires-python = ">=3.8" dynamic = ["version", "description"] -keywords = ["sentence embeddings", "sentence representation", "sentence encoder", - "sonar models", "speech2speech", "text2text", "speech2text", - "text2speech", "multi-modal models", "multi-language models"] +keywords = [ + "sentence embeddings", + "sentence 
representation", + "sentence encoder", + "sonar models", + "speech2speech", + "text2text", + "speech2text", + "text2speech", + "multi-modal models", + "multi-language models", +] # zip_safe = false -classifiers=[ - "Topic :: Scientific/Engineering", - "Development Status :: 4 - Beta", +classifiers = [ + "Topic :: Scientific/Engineering", + "Development Status :: 4 - Beta", ] dependencies = [ @@ -31,35 +40,29 @@ dependencies = [ "tqdm", "overrides", "typing_extensions", - "transformers", - "datasets" ] [project.optional-dependencies] - dev = [ - # Test - "pytest>=4.3.0", - "pytest-asyncio>=0.15.0", - "pytest-cov>=2.6.1", - "coverage[toml]>=5.1", - # Format - "black==24.3.0", - "isort>=5.10.1", - # Linters - "mypy>=0.782", - "pylint>=2.8.0", - ] - - hg = [ - "transformers>=4.44.0", - "datasets>=2.20.0" - ] +dev = [ + # Test + "pytest>=4.3.0", + "pytest-asyncio>=0.15.0", + "pytest-cov>=2.6.1", + "coverage[toml]>=5.1", + # Format + "black==24.3.0", + "isort>=5.10.1", + # Linters + "mypy>=0.782", + "pylint>=2.8.0", +] +hg = ["transformers>=4.44.0", "datasets>=2.20.0"] [project.urls] - Source = "https://github.com/facebookresearch/SONAR" - Tracker = "https://github.com/facebookresearch/SONAR/issues" +Source = "https://github.com/facebookresearch/SONAR" +Tracker = "https://github.com/facebookresearch/SONAR/issues" [tool.flit.module] name = "sonar" @@ -78,10 +81,8 @@ extend-exclude = ''' ''' [tool.flake8] -extend_ignore = ["E", "Y"] # Black -per-file-ignores = [ - "__init__.py:F401", -] +extend_ignore = ["E", "Y"] # Black +per-file-ignores = ["__init__.py:F401"] [tool.isort] profile = "black" @@ -93,13 +94,9 @@ python_version = "3.8" show_error_codes = true check_untyped_defs = true -files = [ - "sonar/" -] +files = ["sonar/"] [tool.pytest.ini_options] minversion = "6.0" testpaths = ["tests/"] -python_files = [ - "test_*.py" -] +python_files = ["test_*.py"] From 42482df557c858b134fc38e038c85d69f54e1cbb Mon Sep 17 00:00:00 2001 From: Botir Khaltaev Date: Wed, 14 Aug 2024 
09:12:12 +0100 Subject: [PATCH 10/14] Fix import sorts and dict subscript issues --- huggingface_pipelines/dataset.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/huggingface_pipelines/dataset.py b/huggingface_pipelines/dataset.py index 44553eb..a81b906 100644 --- a/huggingface_pipelines/dataset.py +++ b/huggingface_pipelines/dataset.py @@ -1,6 +1,6 @@ from abc import ABC from dataclasses import dataclass, replace -from typing import TypedDict, Any +from typing import Any, TypedDict, Dict class DatasetOverwrites(TypedDict, total=False): @@ -89,7 +89,7 @@ def load_dataset(self): return dataset - def get_dataset_kwargs(self) -> dict[str, Any]: + def get_dataset_kwargs(self) -> Dict[str, Any]: """ Returns the kwargs for load_dataset function. @@ -134,4 +134,3 @@ def with_overwrites(self, overwrites: DatasetOverwrites): new_config = config.with_overwrites({"dataset_split": "test", "world_size": 4}) """ return replace(self, **overwrites) - From 441315dfa4299c69b5c1fc6f3a6f695a9d7df075 Mon Sep 17 00:00:00 2001 From: Botir Khaltaev Date: Wed, 14 Aug 2024 11:07:46 +0100 Subject: [PATCH 11/14] Fix linting issues in dataset file --- huggingface_pipelines/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/huggingface_pipelines/dataset.py b/huggingface_pipelines/dataset.py index a81b906..4a211a5 100644 --- a/huggingface_pipelines/dataset.py +++ b/huggingface_pipelines/dataset.py @@ -1,6 +1,6 @@ from abc import ABC from dataclasses import dataclass, replace -from typing import Any, TypedDict, Dict +from typing import Any, Dict, TypedDict class DatasetOverwrites(TypedDict, total=False): From c0e2c3739c75bb79fe4b3ae4e6962c331ad21a73 Mon Sep 17 00:00:00 2001 From: Botir Khaltaev Date: Wed, 14 Aug 2024 11:09:31 +0100 Subject: [PATCH 12/14] Make config optional to allow setting to None --- huggingface_pipelines/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/huggingface_pipelines/dataset.py 
b/huggingface_pipelines/dataset.py index 4a211a5..bfaf48f 100644 --- a/huggingface_pipelines/dataset.py +++ b/huggingface_pipelines/dataset.py @@ -1,6 +1,6 @@ from abc import ABC from dataclasses import dataclass, replace -from typing import Any, Dict, TypedDict +from typing import Any, Dict, Optional, TypedDict class DatasetOverwrites(TypedDict, total=False): @@ -54,7 +54,7 @@ class DatasetConfig(ABC): dataset_split: str output_dir: str streaming: bool = False - config: str = None + config: Optional[str] = None world_size: int = 1 rank: int = 0 trust_remote_code: bool = False From cd384abab5948aee62b6120c5025acd0aed0c293 Mon Sep 17 00:00:00 2001 From: Botir Khaltaev Date: Wed, 14 Aug 2024 11:11:31 +0100 Subject: [PATCH 13/14] Ignore checking HF stubs --- huggingface_pipelines/dataset.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/huggingface_pipelines/dataset.py b/huggingface_pipelines/dataset.py index bfaf48f..6014166 100644 --- a/huggingface_pipelines/dataset.py +++ b/huggingface_pipelines/dataset.py @@ -76,7 +76,9 @@ def load_dataset(self): Note: Ensure that the 'datasets' library is installed before calling this method. """ - from datasets import load_dataset + # We ignore because no official HF stubs are available + + from datasets import load_dataset # type: ignore dataset_kwargs = self.get_dataset_kwargs() dataset = load_dataset(**dataset_kwargs) From d3701c1bc073170d2ab257488e2a350e830f1a21 Mon Sep 17 00:00:00 2001 From: Botir Khaltaev Date: Wed, 14 Aug 2024 11:14:33 +0100 Subject: [PATCH 14/14] Reformat file with black --- huggingface_pipelines/dataset.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/huggingface_pipelines/dataset.py b/huggingface_pipelines/dataset.py index 6014166..009e774 100644 --- a/huggingface_pipelines/dataset.py +++ b/huggingface_pipelines/dataset.py @@ -17,6 +17,7 @@ class DatasetOverwrites(TypedDict, total=False): world_size (int): Number of shards for distributed processing. 
rank (int): Rank of the current process. """ + dataset_name: str dataset_split: str output_dir: str @@ -86,8 +87,7 @@ def load_dataset(self): self.validate_world_size_and_rank() if not self.streaming and self.world_size > 1: - dataset = dataset.shard( - num_shards=self.world_size, index=self.rank) + dataset = dataset.shard(num_shards=self.world_size, index=self.rank) return dataset @@ -105,7 +105,7 @@ def get_dataset_kwargs(self) -> Dict[str, Any]: "name": self.config, "split": self.dataset_split, "trust_remote_code": self.trust_remote_code, - "streaming": self.streaming + "streaming": self.streaming, } return {k: v for k, v in dataset_kwargs.items() if v is not None} @@ -116,8 +116,12 @@ def validate_world_size_and_rank(self): Raises: AssertionError: If world_size or rank are invalid. """ - assert self.world_size >= 1, f"Invalid world_size: {self.world_size}. It should be >= 1." - assert 0 <= self.rank < self.world_size, f"Invalid rank: {self.rank}. It should be between 0 and {self.world_size - 1}." + assert ( + self.world_size >= 1 + ), f"Invalid world_size: {self.world_size}. It should be >= 1." + assert ( + 0 <= self.rank < self.world_size + ), f"Invalid rank: {self.rank}. It should be between 0 and {self.world_size - 1}." def with_overwrites(self, overwrites: DatasetOverwrites): """