Skip to content

Commit

Permalink
Add factors name arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Jul 11, 2024
1 parent c6be439 commit 05d06a0
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
4 changes: 2 additions & 2 deletions examples/openwebtext/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ We will use the pre-trained Meta-Llama-3-8B model [from HuggingFace](https://hug
To compute factors using the `ekfac` factorization strategy, run the following command which uses 4 A100 (80GB) GPUs:

```bash
torchrun --standalone --nnodes=1 --nproc-per-node=4 fit_factors.py --factor_batch_size 4
torchrun --standalone --nnodes=1 --nproc-per-node=4 fit_factors.py --factors_name jul_11_2024 --factor_batch_size 4
```

## Computing Influence Scores
Expand All @@ -25,5 +25,5 @@ I saved some prompt and completition pair to the directory `data/data.json`.
To compute influence scores on the generated prompt and compleition pair, run the following command:

```bash
torchrun --standalone --nnodes=1 --nproc-per-node=4 compute_scores.py --train_batch_size 8
torchrun --standalone --nnodes=1 --nproc-per-node=4 compute_scores.py --train_batch_size 8 --query_gradient_rank 32
```
8 changes: 7 additions & 1 deletion examples/openwebtext/fit_factors.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@
def parse_args():
parser = argparse.ArgumentParser(description="Influence factor computation on Openwebtext dataset.")

parser.add_argument(
"--factors_name",
type=str,
default="july_11",
help="Strategy to compute influence factors.",
)
parser.add_argument(
"--factor_strategy",
type=str,
Expand Down Expand Up @@ -77,7 +83,7 @@ def main():
dataloader_kwargs = DataLoaderKwargs(num_workers=4, collate_fn=default_data_collator, pin_memory=True)
analyzer.set_dataloader_kwargs(dataloader_kwargs)

factors_name = args.factor_strategy
factors_name = args.factors_name
factor_args = extreme_reduce_memory_factor_arguments(
strategy=args.factor_strategy, module_partitions=1, dtype=torch.bfloat16
)
Expand Down

0 comments on commit 05d06a0

Please sign in to comment.