diff --git a/llama-recipes b/llama-recipes
index 09dd581..4634605 160000
--- a/llama-recipes
+++ b/llama-recipes
@@ -1 +1 @@
-Subproject commit 09dd581f392701e7de618a76ef2fcb2733aad43a
+Subproject commit 46346059ae86f60a2af3ed1df273069d1b85c687
diff --git a/src/llama2d/modal/finetuning.py b/src/llama2d/modal/finetuning.py
index b2f3a44..e2756dc 100644
--- a/src/llama2d/modal/finetuning.py
+++ b/src/llama2d/modal/finetuning.py
@@ -92,6 +92,7 @@ def main(Llama, LlamaCfg, **kwargs):
     kwargs = {
         "use_2d": use_2d,
         "lbd_start_value": train_config.lbd_start_value,
+        "use_point_embed": train_config.use_point_embed,
     }  # if use_2d else {}
 
     # Load the pre-trained model and setup its configuration
@@ -125,6 +126,7 @@ def main(Llama, LlamaCfg, **kwargs):
 
         llama_config.use_2d = use_2d
         llama_config.lbd_start_value = train_config.lbd_start_value
+        llama_config.use_point_embed = train_config.use_point_embed
 
         with torch.device("meta"):
             model = Llama(llama_config)
@@ -269,84 +271,22 @@ def main(Llama, LlamaCfg, **kwargs):
             collate_fn=default_data_collator,
         )
 
-    def print_generations():
-        # broken right now
-        return
-
-        # if train_config.enable_fsdp and rank != 0:
-        #     print(f"Skipping generation on rank {rank}")
-        #     return
-
-        # show generations
-
-        num_samples = 5
-        rand_idxes = torch.randint(0, len(dataset_val), (num_samples,))
-        print("-----Sample generation-------")
-
-        for test_sample in range(num_samples):
-            # get a sample from the val dataset
-            test_sample = dataset_val[rand_idxes[test_sample]]
-
-            print("GT:")
-            print(tokenizer.decode(test_sample["input_ids"], skip_special_tokens=True))
-
-            # get last positive label idx
-            last_outputted_chunk_idx = (
-                torch.nonzero(test_sample["labels"] > 0)[-1].item() + 1
-            )
-            # get the first label idx of that chunk
-            first_idx_of_chunk = (
-                torch.nonzero(test_sample["labels"][:last_outputted_chunk_idx] <= 0)[
-                    -1
-                ].item()
-                + 1
-            )
-
-            batched_sample = {
-                k: v[None, :first_idx_of_chunk, ...] for k, v in test_sample.items()
-            }
-
-            assert (
-                batched_sample["input_ids"][0, -1] == 0
-            ), "input_ids for valid set must end with pad token"
-
-            # move to current device
-            batched_sample = {k: v.to("cuda") for k, v in batched_sample.items()}
-
-            print(batched_sample["input_ids"])
-
-            print("Generated:")
-            with torch.no_grad():
-                print(
-                    tokenizer.decode(
-                        model.generate(**batched_sample, max_new_tokens=100)[0],
-                        skip_special_tokens=True,
-                    )
-                )
-
-    print_generations()
-
     # Initialize the optimizer and learning rate scheduler
-    # make 2 param groups: for *.lbd.* and for the rest
-    param_groups = [
-        {
-            "params": [
-                p
-                for n, p in model.named_parameters()
-                if any([k in n for k in train_config.target_modules])
-            ],
-            "lr": train_config.lambda_lr,
-        },
-        {
-            "params": [
-                p
-                for n, p in model.named_parameters()
-                if not any([k in n for k in train_config.target_modules])
-            ],
-            "lr": train_config.lr,
-        },
-    ]
+    # make custom param groups
+    group_substrs = {
+        "lambda": [train_config.lambda_lr, "lbd"],
+        "point_embed": [train_config.point_embed_lr, "is_a_point_embed"],
+    }
+    param_groups = []
+    for n, p in model.named_parameters():
+        for group_name, (lr, substr) in group_substrs.items():
+            if substr in n:
+                param_groups.append({"params": [p], "lr": lr})
+                break
+        else:
+            param_groups.append({"params": [p], "lr": train_config.lr})
+
     if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
         optimizer = AnyPrecisionAdamW(
diff --git a/src/llama2d/modal/train.py b/src/llama2d/modal/train.py
index 86e4d7c..fec6a47 100644
--- a/src/llama2d/modal/train.py
+++ b/src/llama2d/modal/train.py
@@ -96,14 +96,18 @@ def main(
     run_id: str = "",
     num_epochs: int = 10,
     batch_size: int = 16,
-    use_2d: bool = True,
-    peft: bool = False,
     repo: str = "llama2d/llama2d-mind2web",
-    lbd_start_value: float = 0.0,
-    lr: float = 3e-5,
-    lambda_lr: float = 3e-4,
     keep_fraction: float = 1.0,
     seed: int = 0,
+
+    peft: bool = False,
+    use_2d: bool = True,
+    use_point_embed: bool = True,
+    lbd_start_value: float = 0.0,
+    lr: float = 3e-5,
+    lambda_lr: float = 3e-2,
+    point_embed_lr: float = 3e-2,
+
     # wandb args
     group: str = None,
     name: str = None,
@@ -157,6 +161,8 @@ def main(
             "repo": repo,
             "lbd_start_value": lbd_start_value,
             "seed": seed,
+            "use_point_embed": use_point_embed,
+            "point_embed_lr": point_embed_lr,
         }
     )
 
diff --git a/transformers b/transformers
index 7d6eb4c..f836244 160000
--- a/transformers
+++ b/transformers
@@ -1 +1 @@
-Subproject commit 7d6eb4ccf6329f353a7cf0a57660329b6f84aeaf
+Subproject commit f836244a8f15639e4397d3b5df559b4a7c0aae77
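
For reference, the rewritten parameter-group logic in `finetuning.py` above assigns a per-group learning rate to any parameter whose name contains a configured substring (`"lbd"` for the lambda gate, `"is_a_point_embed"` for the point embedding) and falls back to `train_config.lr` otherwise. Below is a minimal, self-contained sketch of that pattern in plain PyTorch; `ToyBlock` and the hard-coded learning rates are hypothetical stand-ins, not the repo's real model or config, and only the substrings `"lbd"` and `"is_a_point_embed"` come from the diff.

```python
# Sketch of the substring-based param-group construction, assuming plain PyTorch.
# ToyBlock and the learning-rate values are illustrative stand-ins; only the
# substrings "lbd" and "is_a_point_embed" come from the diff above.
import torch
from torch import nn

base_lr, lambda_lr, point_embed_lr = 3e-5, 3e-2, 3e-2  # cf. train.py defaults


class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)                  # falls into the base-lr group
        self.lbd = nn.Parameter(torch.zeros(1))      # name contains "lbd"
        self.is_a_point_embed = nn.Embedding(10, 4)  # name contains "is_a_point_embed"


model = ToyBlock()

# group name -> (learning rate, substring to look for in the parameter name)
group_substrs = {
    "lambda": (lambda_lr, "lbd"),
    "point_embed": (point_embed_lr, "is_a_point_embed"),
}

param_groups = []
for name, param in model.named_parameters():
    for _group, (lr, substr) in group_substrs.items():
        if substr in name:
            # First matching substring wins; use that group's learning rate.
            param_groups.append({"params": [param], "lr": lr})
            break
    else:
        # No substring matched: fall back to the base learning rate.
        param_groups.append({"params": [param], "lr": base_lr})

optimizer = torch.optim.AdamW(param_groups)
for group in optimizer.param_groups:
    print(group["lr"], [tuple(p.shape) for p in group["params"]])
```

One group per parameter keeps the mapping simple; parameters that share a learning rate could equally be merged into a single group without changing the optimizer's behavior.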