diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml index 238f8269..599bff6c 100644 --- a/examples/config_multilingual_nanoset.yaml +++ b/examples/config_multilingual_nanoset.yaml @@ -9,8 +9,8 @@ data_stages: dataset: training_folder: datasets/c4-es/train validation_folder: datasets/c4-es/validation - dataset_tokens: - - 15 + lang_to_ids: + es: 128002 num_loading_workers: 1 seed: 42 name: General purpose training (Single dataset) @@ -25,10 +25,10 @@ data_stages: - datasets/c4-es/validation - datasets/c4-en/validation - datasets/c4-fr/validation - dataset_tokens: - - 15 - - 16 - - 17 + lang_to_ids: + es: 128002 + en: 128003 + fr: 128004 num_loading_workers: 1 seed: 42 name: Second purpose training (> 1 dataset) @@ -43,10 +43,10 @@ data_stages: - datasets/c4-es/validation - datasets/c4-en/validation - datasets/c4-fr/validation - dataset_tokens: - - 15 - - 16 - - 17 + lang_to_ids: + es: 128002 + en: 128003 + fr: 128004 num_loading_workers: 1 seed: 42 diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index f1881faa..d90f13fb 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -111,7 +111,7 @@ def __post_init__(self): class MultilingualNanosetDatasetsArgs: training_folder: Union[str, dict, List[str]] validation_folder: Union[str, List[str]] - dataset_tokens: List[int] # Set token for each language previously defined + lang_to_ids: dict # Mapping from the previously defined folders to tokens. Respect the order def __post_init__(self): if isinstance(self.training_folder, str): # Case 1: 1 Dataset folder @@ -125,8 +125,13 @@ def __post_init__(self): self.training_folder = list(tmp_training_folder.keys()) self.dataset_weights = list(tmp_training_folder.values()) - assert len(self.training_folder) == len(self.validation_folder) - assert len(self.training_folder) == len(self.dataset_tokens) + self.dataset_tokens = list(self.lang_to_ids.values()) + assert len(self.training_folder) == len( + self.validation_folder + ), f"The sizes of training_folder and validation_folder mismatch ({len(self.training_folder)} vs {len(self.validation_folder)})" + assert len(self.training_folder) == len( + self.dataset_tokens + ), f"The sizes of training_folder and lang_to_ids mismatch ({len(self.training_folder)} vs {len(self.dataset_tokens)})" @dataclass