
Commit

After lunch
TJ-Solergibert authored and Negar Foroutan Eghlidi committed Sep 8, 2024
1 parent 7be77b7 commit 6cbe054
Showing 5 changed files with 61 additions and 46 deletions.
42 changes: 28 additions & 14 deletions examples/config_multilingual_nanoset.yaml
@@ -7,7 +7,8 @@ checkpoints:
data_stages:
- data:
dataset:
dataset_folder: /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized
training_folder: datasets/c4-es/train
validation_folder: datasets/c4-es/validation
dataset_tokens:
- 15
num_loading_workers: 1
@@ -16,24 +17,37 @@ data_stages:
start_training_step: 1
- data:
dataset:
dataset_folder:
- /mloscratch/homes/solergib/nanotrove/nanotron/datasets/SlimPajama-6B/tokenized
- /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized
training_folder:
- datasets/c4-es/train
- datasets/c4-en/train
- datasets/c4-fr/train
validation_folder:
- datasets/c4-es/validation
- datasets/c4-en/validation
- datasets/c4-fr/validation
dataset_tokens:
- 16
- 15
- 16
- 17
num_loading_workers: 1
seed: 42
name: Second purpose training (> 1 dataset)
start_training_step: 15
- data:
dataset:
dataset_folder:
/mloscratch/homes/solergib/nanotrove/nanotron/datasets/SlimPajama-6B/tokenized: 0.8
/mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized: 0.2
training_folder:
datasets/c4-es/train: 0.6
datasets/c4-en/train: 0.3
datasets/c4-fr/train: 0.1
validation_folder:
- datasets/c4-es/validation
- datasets/c4-en/validation
- datasets/c4-fr/validation
dataset_tokens:
- 16
- 15
- 16
- 17

num_loading_workers: 1
seed: 42
name: Third purpose training (Blended dataset)
@@ -61,12 +75,12 @@ model:
bos_token_id: 1
eos_token_id: 2
hidden_act: silu
hidden_size: 4096
hidden_size: 512
initializer_range: 0.02
intermediate_size: 11008
intermediate_size: 512
is_llama_config: true
max_position_embeddings: 1024
num_hidden_layers: 32
num_hidden_layers: 2
num_attention_heads: 32
num_key_value_heads: 8
pad_token_id: null
@@ -108,13 +122,13 @@ parallelism:
profiler: null
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: gpt2
tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B
tokenizer_revision: null
tokens:
batch_accumulation_per_replica: 1
limit_test_batches: 0
limit_val_batches: 10
micro_batch_size: 2
micro_batch_size: 4
sequence_length: 1024
train_steps: 200
val_check_interval: -1
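Note (not part of the commit): a rough back-of-the-envelope sketch of how the third stage's 0.6/0.3/0.1 weights translate into per-dataset training samples. The data-parallel size is an assumption here, since the parallelism block is collapsed in this diff; all other numbers come from the config above and from the train_split_num_samples formula used in run_train.py below.

# Back-of-the-envelope sketch; dp_size = 1 is an assumption for illustration only.
train_steps = 200
micro_batch_size = 4
batch_accumulation_per_replica = 1
dp_size = 1

global_batch_size = micro_batch_size * batch_accumulation_per_replica * dp_size
train_split_num_samples = train_steps * global_batch_size  # mirrors run_train.py below

weights = {
    "datasets/c4-es/train": 0.6,
    "datasets/c4-en/train": 0.3,
    "datasets/c4-fr/train": 0.1,
}
for folder, w in weights.items():
    print(f"{folder}: ~{int(train_split_num_samples * w)} samples")
# datasets/c4-es/train: ~480 samples
# datasets/c4-en/train: ~240 samples
# datasets/c4-fr/train: ~80 samples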
6 changes: 2 additions & 4 deletions run_train.py
@@ -189,7 +189,7 @@ def get_dataloader_from_data_stage(

with main_rank_first(trainer.parallel_context.world_pg):
train_dataset = MultilingualNanoset(
dataset_folders=data.dataset.dataset_folder,
dataset_folders=data.dataset.training_folder,
dataset_weights=data.dataset.dataset_weights,
sequence_length=trainer.sequence_length,
token_size=token_size,
@@ -238,11 +238,9 @@ def get_valid_dataloader_from_data_stage(

with main_rank_first(trainer.parallel_context.world_pg):
valid_dataset = MultilingualNanoset(
dataset_folders=data.dataset.dataset_folder,
dataset_weights=data.dataset.dataset_weights,
dataset_folders=data.dataset.validation_folder,
sequence_length=trainer.sequence_length,
token_size=token_size,
train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size,
dataset_tokens=data.dataset.dataset_tokens,
is_valid=True,
random_seed=data.seed,
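A minimal sketch of the two call sites after this change, with names taken from the diff above. The trailing keyword arguments of the training call are cut off by the diff, so they are inferred here from the surrounding code; treat them as assumptions. Training reads data.dataset.training_folder with optional weights, while validation reads data.dataset.validation_folder, passes is_valid=True, and no longer needs dataset_weights or train_split_num_samples.

# Sketch only: mirrors the two MultilingualNanoset call sites in run_train.py.
train_dataset = MultilingualNanoset(
    dataset_folders=data.dataset.training_folder,   # blended / weighted training data
    dataset_weights=data.dataset.dataset_weights,    # None -> consume all samples
    sequence_length=trainer.sequence_length,
    token_size=token_size,
    train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size,
    dataset_tokens=data.dataset.dataset_tokens,
    random_seed=data.seed,
)

valid_dataset = MultilingualNanoset(
    dataset_folders=data.dataset.validation_folder,  # dedicated validation split
    sequence_length=trainer.sequence_length,
    token_size=token_size,
    dataset_tokens=data.dataset.dataset_tokens,
    is_valid=True,                                   # full pass over every validation sample
    random_seed=data.seed,
)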
21 changes: 10 additions & 11 deletions src/nanotron/config/config.py
@@ -110,21 +110,20 @@ def __post_init__(self):
@dataclass
class MultilingualNanosetDatasetsArgs:
training_folder: Union[str, dict, List[str]]
validation_folder: Union[str, dict, List[str]]
dataset_tokens: List[
int
] # Set token for each language previously defined. We use a List and not a dict because this way we support specifying weights (dict) or not (List[str])
validation_folder: Union[str, List[str]]
dataset_tokens: List[int] # Set token for each language previously defined

def __post_init__(self):
if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset file
self.dataset_folder = [self.dataset_folder]
if isinstance(self.training_folder, str): # Case 1: 1 Dataset folder
self.training_folder = [self.training_folder]
self.validation_folder = [self.validation_folder]
self.dataset_weights = [1]
elif isinstance(self.dataset_folder, List): # Case 2: > 1 Dataset file
elif isinstance(self.training_folder, List): # Case 2: > 1 Dataset folder
self.dataset_weights = None # Set to None so we consume all the samples randomly
elif isinstance(self.dataset_folder, dict): # Case 3: dict with > 1 dataset_folder and weights
tmp_dataset_folder = self.dataset_folder.copy()
self.dataset_folder = list(tmp_dataset_folder.keys())
self.dataset_weights = list(tmp_dataset_folder.values())
elif isinstance(self.training_folder, dict): # Case 3: dict with > 1 training_folder and weights
tmp_training_folder = self.training_folder.copy()
self.training_folder = list(tmp_training_folder.keys())
self.dataset_weights = list(tmp_training_folder.values())

assert len(self.training_folder) == len(self.validation_folder)
assert len(self.training_folder) == len(self.dataset_tokens)
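To make the three cases concrete, here is a small self-contained sketch (not from the repository) that mirrors the new __post_init__ above, fed with the blended-stage values from the example config:

# Illustrative stand-in for MultilingualNanosetDatasetsArgs, showing the three
# accepted shapes of training_folder and the dataset_weights each one produces.
from dataclasses import dataclass
from typing import List, Union

@dataclass
class Args:
    training_folder: Union[str, dict, List[str]]
    validation_folder: Union[str, List[str]]
    dataset_tokens: List[int]

    def __post_init__(self):
        if isinstance(self.training_folder, str):      # Case 1: single dataset folder
            self.training_folder = [self.training_folder]
            self.validation_folder = [self.validation_folder]
            self.dataset_weights = [1]
        elif isinstance(self.training_folder, list):   # Case 2: list -> consume all samples
            self.dataset_weights = None
        elif isinstance(self.training_folder, dict):   # Case 3: dict -> folders + weights
            tmp = self.training_folder.copy()
            self.training_folder = list(tmp.keys())
            self.dataset_weights = list(tmp.values())
        assert len(self.training_folder) == len(self.validation_folder)
        assert len(self.training_folder) == len(self.dataset_tokens)

blended = Args(
    training_folder={"datasets/c4-es/train": 0.6, "datasets/c4-en/train": 0.3, "datasets/c4-fr/train": 0.1},
    validation_folder=["datasets/c4-es/validation", "datasets/c4-en/validation", "datasets/c4-fr/validation"],
    dataset_tokens=[15, 16, 17],
)
print(blended.training_folder)  # ['datasets/c4-es/train', 'datasets/c4-en/train', 'datasets/c4-fr/train']
print(blended.dataset_weights)  # [0.6, 0.3, 0.1]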
33 changes: 17 additions & 16 deletions src/nanotron/data/multilingual_nanoset.py
@@ -30,8 +30,8 @@ def __init__(
dataset_folders: List[str],
sequence_length: int,
token_size: int,
train_split_num_samples: int,
dataset_tokens: List[int],
train_split_num_samples: int = None,
is_valid: bool = False,
dataset_weights: Union[List[float], None] = None,
random_seed: int = 1234,
@@ -78,7 +78,7 @@ def __init__(
), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided."
## Build dataset index and dataset sample index
if is_valid: # Valid MultilingualNanoset
self.dataset_index, self.dataset_sample_index = self.build_valid_nanoset_index(self.dataset_lengths)
self.dataset_index, self.dataset_sample_index = build_valid_nanoset_index(self.dataset_lengths)

else: # Train MultilingualNanoset
self.dataset_index, self.dataset_sample_index = self.build_train_nanoset_index()
@@ -136,20 +136,6 @@ def build_train_nanoset_index(self) -> np.ndarray:

return dataset_index, dataset_sample_index

@jit(nopython=True, cache=True)
def build_valid_nanoset_index(dataset_lengths: List[int]) -> np.ndarray:
"""
Build valid dataset index and dataset sample index
"""
dataset_index = []
dataset_sample_index = []

for i, length in enumerate(dataset_lengths):
dataset_index.extend([i] * length)
dataset_sample_index.extend(range(length))

return np.array(dataset_index, dtype="uint"), np.array(dataset_sample_index, dtype="long")

def print_nanoset_info(self):

log_rank(
@@ -211,3 +197,18 @@ def build_train_nanoset_index_helper(
current_samples[max_error_index] += 1

return dataset_index, dataset_sample_index


@jit(nopython=True, cache=True)
def build_valid_nanoset_index(dataset_lengths: List[int]) -> np.ndarray:
"""
Build valid dataset index and dataset sample index
"""
dataset_index = []
dataset_sample_index = []

for i, length in enumerate(dataset_lengths):
dataset_index.extend([i] * length)
dataset_sample_index.extend(range(length))

return np.array(dataset_index, dtype="uint"), np.array(dataset_sample_index, dtype="long")
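build_valid_nanoset_index now lives at module level next to build_train_nanoset_index_helper, presumably so the numba @jit(nopython=True) decorator compiles it as a plain function over dataset_lengths rather than as a method. A tiny pure-Python illustration of its output (so it runs without numba), using two hypothetical validation folders:

# Pure-Python equivalent of what build_valid_nanoset_index returns, for two
# hypothetical validation folders holding 3 and 2 samples respectively.
dataset_lengths = [3, 2]

dataset_index, dataset_sample_index = [], []
for i, length in enumerate(dataset_lengths):
    dataset_index.extend([i] * length)          # which folder each sample comes from
    dataset_sample_index.extend(range(length))  # index of the sample inside that folder

print(dataset_index)         # [0, 0, 0, 1, 1]
print(dataset_sample_index)  # [0, 1, 2, 0, 1]
# Every validation sample is visited exactly once, in order, with no weighting.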
5 changes: 4 additions & 1 deletion tools/preprocess_data.py
@@ -98,7 +98,9 @@ def main(args):
dataset_options={"split": args.split},
)
elif args.readers == "parquet":
datatrove_reader = ParquetReader(data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern)
datatrove_reader = ParquetReader(
data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern
)
else:
datatrove_reader = JsonlReader(data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern)

@@ -107,6 +109,7 @@
datatrove_reader,
DocumentTokenizer(
output_folder=args.output_folder,
shuffle=False,
tokenizer_name_or_path=args.tokenizer_name_or_path,
eos_token=args.eos_token,
shuffle=False,
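For context, a hedged sketch of the tokenization pipeline this script builds for the parquet case. The import paths and the LocalPipelineExecutor wrapper are assumptions about typical datatrove usage rather than lines from this diff, and the concrete folder, column, tokenizer, and eos_token values are placeholders standing in for the script's args.* parameters; the ParquetReader and DocumentTokenizer arguments themselves mirror the code above, including the shuffle=False added by this commit.

# Sketch under assumptions: imports and executor are not shown in the diff.
from datatrove.executor import LocalPipelineExecutor
from datatrove.pipeline.readers import ParquetReader
from datatrove.pipeline.tokens import DocumentTokenizer

pipeline = [
    ParquetReader(
        data_folder="raw/c4-es",                         # placeholder for args.dataset
        text_key="text",                                 # placeholder for args.column
        glob_pattern="*.parquet",                        # placeholder for args.glob_pattern
    ),
    DocumentTokenizer(
        output_folder="datasets/c4-es/train",            # placeholder for args.output_folder
        tokenizer_name_or_path="meta-llama/Meta-Llama-3-8B",
        eos_token="<|end_of_text|>",                     # placeholder for args.eos_token
        shuffle=False,                                   # added by this commit: don't shuffle documents when writing
    ),
]

LocalPipelineExecutor(pipeline=pipeline).run()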
