Add toggle + param group for point embed
andrew-healey committed Oct 4, 2023
1 parent a22b44d commit e0c234d
Showing 4 changed files with 29 additions and 83 deletions.
2 changes: 1 addition & 1 deletion llama-recipes
92 changes: 16 additions & 76 deletions src/llama2d/modal/finetuning.py
@@ -92,6 +92,7 @@ def main(Llama, LlamaCfg, **kwargs):
kwargs = {
"use_2d": use_2d,
"lbd_start_value": train_config.lbd_start_value,
"use_point_embed": train_config.use_point_embed,
} # if use_2d else {}

# Load the pre-trained model and setup its configuration
@@ -125,6 +126,7 @@ def main(Llama, LlamaCfg, **kwargs):

llama_config.use_2d = use_2d
llama_config.lbd_start_value = train_config.lbd_start_value
llama_config.use_point_embed = train_config.use_point_embed

with torch.device("meta"):
model = Llama(llama_config)
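
As context for the lines above: instantiating the model under torch.device("meta") builds the module tree without allocating real weight storage, so the large parameters can be materialized later (e.g. by FSDP). A minimal sketch of that behavior, using a stand-in nn.Linear rather than the real Llama class:

import torch
import torch.nn as nn

# Requires PyTorch 2.0+, where torch.device works as a context manager.
with torch.device("meta"):
    layer = nn.Linear(4096, 4096)  # stand-in for Llama(llama_config)

print(layer.weight.device)   # meta: no storage has been allocated yet
print(layer.weight.shape)    # shape metadata is still available
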
@@ -269,84 +271,22 @@ def main(Llama, LlamaCfg, **kwargs):
collate_fn=default_data_collator,
)

def print_generations():
# broken right now
return

# if train_config.enable_fsdp and rank != 0:
# print(f"Skipping generation on rank {rank}")
# return

# show generations

num_samples = 5
rand_idxes = torch.randint(0, len(dataset_val), (num_samples,))
print("-----Sample generation-------")

for test_sample in range(num_samples):
# get a sample from the val dataset
test_sample = dataset_val[rand_idxes[test_sample]]

print("GT:")
print(tokenizer.decode(test_sample["input_ids"], skip_special_tokens=True))

# get last positive label idx
last_outputted_chunk_idx = (
torch.nonzero(test_sample["labels"] > 0)[-1].item() + 1
)
# get the first label idx of that chunk
first_idx_of_chunk = (
torch.nonzero(test_sample["labels"][:last_outputted_chunk_idx] <= 0)[
-1
].item()
+ 1
)

batched_sample = {
k: v[None, :first_idx_of_chunk, ...] for k, v in test_sample.items()
}

assert (
batched_sample["input_ids"][0, -1] == 0
), "input_ids for valid set must end with pad token"

# move to current device
batched_sample = {k: v.to("cuda") for k, v in batched_sample.items()}

print(batched_sample["input_ids"])

print("Generated:")
with torch.no_grad():
print(
tokenizer.decode(
model.generate(**batched_sample, max_new_tokens=100)[0],
skip_special_tokens=True,
)
)

print_generations()

# Initialize the optimizer and learning rate scheduler

# make 2 param groups: for *.lbd.* and for the rest
param_groups = [
{
"params": [
p
for n, p in model.named_parameters()
if any([k in n for k in train_config.target_modules])
],
"lr": train_config.lambda_lr,
},
{
"params": [
p
for n, p in model.named_parameters()
if not any([k in n for k in train_config.target_modules])
],
"lr": train_config.lr,
},
]
# make custom param groups
group_substrs = {
    "lambda":[train_config.lambda_lr,"lbd"],
    "point_embed":[train_config.point_embed_lr,"is_a_point_embed"],
}
param_groups = []
for n,p in model.named_parameters():
    for group_name,(lr,substr) in group_substrs.items():
        if substr in n:
            param_groups.append({"params":[p],"lr":lr})
            break
    else:
        param_groups.append({"params":[p],"lr":train_config.lr})


if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
optimizer = AnyPrecisionAdamW(
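
To make the new grouping logic concrete, here is a minimal, self-contained sketch (toy module names, not the real llama2d modules) of how the for/else loop above routes parameters whose names contain "lbd" or "is_a_point_embed" into their own learning-rate groups, and how a standard PyTorch optimizer consumes the result:

import torch
import torch.nn as nn

# Toy stand-in for the fine-tuned model; submodule names are chosen so the
# substring matching below has something to hit.
toy_model = nn.ModuleDict({
    "attn": nn.Linear(16, 16),             # ordinary weights -> base lr
    "lbd": nn.Linear(16, 16),              # name contains "lbd" -> lambda_lr
    "is_a_point_embed": nn.Linear(2, 16),  # -> point_embed_lr
})

base_lr, lambda_lr, point_embed_lr = 3e-5, 3e-2, 3e-2
group_substrs = {
    "lambda": (lambda_lr, "lbd"),
    "point_embed": (point_embed_lr, "is_a_point_embed"),
}

param_groups = []
for n, p in toy_model.named_parameters():
    for group_name, (lr, substr) in group_substrs.items():
        if substr in n:
            param_groups.append({"params": [p], "lr": lr})
            break
    else:  # inner loop finished without a break: no substring matched
        param_groups.append({"params": [p], "lr": base_lr})

# Any torch.optim optimizer accepts a list of per-group dicts like this.
optimizer = torch.optim.AdamW(param_groups)
for group in optimizer.param_groups:
    print(len(group["params"]), group["lr"])

Note that appending one single-parameter dict per match puts every parameter in its own optimizer group; that is functionally fine, it just makes optimizer.param_groups long.
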
16 changes: 11 additions & 5 deletions src/llama2d/modal/train.py
@@ -96,14 +96,18 @@ def main(
run_id: str = "",
num_epochs: int = 10,
batch_size: int = 16,
use_2d: bool = True,
peft: bool = False,
repo: str = "llama2d/llama2d-mind2web",
lbd_start_value: float = 0.0,
lr: float = 3e-5,
lambda_lr: float = 3e-4,
keep_fraction: float = 1.0,
seed: int = 0,

peft: bool = False,
use_2d: bool = True,
use_point_embed: bool = True,
lbd_start_value: float = 0.0,
lr: float = 3e-5,
lambda_lr: float = 3e-2,
point_embed_lr: float = 3e-2,

# wandb args
group: str = None,
name: str = None,
@@ -157,6 +161,8 @@ def main(
"repo": repo,
"lbd_start_value": lbd_start_value,
"seed": seed,
"use_point_embed": use_point_embed,
"point_embed_lr": point_embed_lr,
}
)

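To illustrate what the new use_point_embed flag controls, here is a hypothetical sketch (invented module and class names, not the llama2d source): a learned embedding of 2D point coordinates that is only created when the flag is set, with its parameter named so that "is_a_point_embed" appears in named_parameters() and therefore picks up point_embed_lr from the grouping code in finetuning.py:

import torch
import torch.nn as nn

class PointEmbed(nn.Module):
    """Projects (x, y) screen coordinates in [0, 1] into the hidden size."""
    def __init__(self, hidden_size: int):
        super().__init__()
        # Named so that "is_a_point_embed" shows up in named_parameters().
        self.is_a_point_embed = nn.Linear(2, hidden_size, bias=False)

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        return self.is_a_point_embed(coords)

class ToyDecoder(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int, use_point_embed: bool):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        # The toggle: the extra embedding exists (and is trained) only when enabled.
        self.point_embed = PointEmbed(hidden_size) if use_point_embed else None

    def forward(self, input_ids: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
        hidden = self.embed_tokens(input_ids)
        if self.point_embed is not None:
            hidden = hidden + self.point_embed(coords)
        return hidden

model = ToyDecoder(vocab_size=32000, hidden_size=64, use_point_embed=True)
out = model(torch.randint(0, 32000, (1, 8)), torch.rand(1, 8, 2))
print(out.shape)  # torch.Size([1, 8, 64])
print([n for n, _ in model.named_parameters() if "is_a_point_embed" in n])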
