diff --git a/examples/openwebtext/README.md b/examples/openwebtext/README.md
index 99d0bf7..f10ece0 100644
--- a/examples/openwebtext/README.md
+++ b/examples/openwebtext/README.md
@@ -14,7 +14,7 @@ We will use the pre-trained Meta-Llama-3-8B model [from HuggingFace](https://hug
 To compute factors using the `ekfac` factorization strategy, run the following command, which uses 4 A100 (80GB) GPUs:
 
 ```bash
-torchrun --standalone --nnodes=1 --nproc-per-node=4 fit_factors.py --factor_batch_size 4
+torchrun --standalone --nnodes=1 --nproc-per-node=4 fit_factors.py --factors_name jul_11_2024 --factor_batch_size 4
 ```
 
 ## Computing Influence Scores
@@ -25,5 +25,5 @@ I saved some prompt and completion pairs to the file `data/data.json`.
 To compute influence scores on the generated prompt and completion pair, run the following command:
 
 ```bash
-torchrun --standalone --nnodes=1 --nproc-per-node=4 compute_scores.py --train_batch_size 8
+torchrun --standalone --nnodes=1 --nproc-per-node=4 compute_scores.py --train_batch_size 8 --query_gradient_rank 32
 ```
\ No newline at end of file
diff --git a/examples/openwebtext/fit_factors.py b/examples/openwebtext/fit_factors.py
index c96536f..740f349 100644
--- a/examples/openwebtext/fit_factors.py
+++ b/examples/openwebtext/fit_factors.py
@@ -26,6 +26,12 @@ def parse_args():
 
     parser = argparse.ArgumentParser(description="Influence factor computation on Openwebtext dataset.")
 
+    parser.add_argument(
+        "--factors_name",
+        type=str,
+        default="july_11",
+        help="Name under which to save the computed influence factors.",
+    )
     parser.add_argument(
         "--factor_strategy",
         type=str,
@@ -77,7 +83,7 @@ def main():
     dataloader_kwargs = DataLoaderKwargs(num_workers=4, collate_fn=default_data_collator, pin_memory=True)
     analyzer.set_dataloader_kwargs(dataloader_kwargs)
 
-    factors_name = args.factor_strategy
+    factors_name = args.factors_name
     factor_args = extreme_reduce_memory_factor_arguments(
         strategy=args.factor_strategy, module_partitions=1, dtype=torch.bfloat16
     )
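
For reference, here is a minimal standalone sketch of the pattern this diff introduces: a separate `--factors_name` flag that names the saved factors independently of `--factor_strategy`, so the save name no longer has to equal the strategy string. The `ekfac` default and the `print` placeholder are assumptions for illustration; the real script forwards these values to the Kronfluence analyzer instead of printing them.

```python
import argparse


def parse_args():
    parser = argparse.ArgumentParser(description="Influence factor computation on Openwebtext dataset.")
    # New in this diff: the save name for computed factors is decoupled from the
    # factorization strategy, so multiple (e.g. dated) factor runs can coexist.
    parser.add_argument(
        "--factors_name",
        type=str,
        default="july_11",
        help="Name under which to save the computed influence factors.",
    )
    parser.add_argument(
        "--factor_strategy",
        type=str,
        default="ekfac",  # assumed default; matches the strategy named in the README
        help="Strategy to compute influence factors.",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    # Before this change the save name was tied to the strategy
    # (factors_name = args.factor_strategy); now the two vary independently.
    print(f"Fitting {args.factor_strategy} factors; saving under {args.factors_name!r}")
```

Running the sketch as `python sketch.py --factors_name jul_11_2024` mirrors the flag usage in the updated README command.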