From d02a5749a6f6e1dc558fff4e99b1bc795ba75c15 Mon Sep 17 00:00:00 2001 From: "pierre.delaunay" Date: Wed, 22 Jan 2025 09:37:33 -0500 Subject: [PATCH] Update llm-lora-ddp-gpus --- .../configs/llama3_8B_lora_single_device.yaml | 58 ++- benchmarks/llm/requirements.cuda.txt | 479 ++++++++++++++++++ config/base.yaml | 4 +- 3 files changed, 518 insertions(+), 23 deletions(-) create mode 100644 benchmarks/llm/requirements.cuda.txt diff --git a/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml b/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml index 4dad835e..a6c7b070 100644 --- a/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml +++ b/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml @@ -1,49 +1,58 @@ # Config for single device LoRA finetuning in lora_finetune_single_device.py -# using a Llama3 8B Instruct model +# using a Llama3.1 8B Instruct model # # This config assumes that you've run the following command before launching # this run: -# tune download meta-llama/Meta-Llama-3-8B-Instruct --output-dir /tmp/Meta-Llama-3-8B-Instruct --hf-token +# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" # # To launch on a single device, run the following command from root: -# tune run lora_finetune_single_device --config llama3/8B_lora_single_device +# tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device # # You can add specific overrides through the command line. For example # to override the checkpointer directory while launching training # you can run: -# tune run lora_finetune_single_device --config llama3/8B_lora_single_device checkpointer.checkpoint_dir= +# tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device checkpointer.checkpoint_dir= # # This config works only for training on single device. +output_dir: /tmp/torchtune/llama3_1_8B/lora_single_device # /tmp may be deleted by your system. Change it to your preference. 
+ # Model Arguments model: _component_: torchtune.models.llama3_1.lora_llama3_1_8b - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank + lora_dropout: 0.0 # Tokenizer tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer - path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model + path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model + max_seq_len: null checkpointer: - _component_: torchtune.training.FullModelMetaCheckpointer - checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/ + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ checkpoint_files: [ - consolidated.00.pth + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors ] recipe_checkpoint: null - output_dir: /tmp/Meta-Llama-3-8B-Instruct/ + output_dir: ${output_dir} model_type: LLAMA3 resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset and Sampler dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -51,33 +60,38 @@ batch_size: 2 # Optimizer and Scheduler optimizer: _component_: torch.optim.AdamW + fused: True weight_decay: 0.01 lr: 3e-4 lr_scheduler: - _component_: torchtune.modules.get_cosine_schedule_with_warmup + _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 64 -compile: False +gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null +compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging -output_dir: /tmp/lora_finetune_output metric_logger: _component_: torchtune.training.metric_logging.DiskLogger - log_dir: ${output_dir} + log_dir: ${output_dir}/logs log_every_n_steps: 1 -log_peak_memory_stats: False +log_peak_memory_stats: True # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: True + +# Activations Memory +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory + # Profiler (disabled) profiler: @@ -100,6 +114,6 @@ profiler: # `torch.profiler.schedule` options: # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat wait_steps: 5 - warmup_steps: 5 + warmup_steps: 3 active_steps: 2 num_cycles: 1 diff --git a/benchmarks/llm/requirements.cuda.txt b/benchmarks/llm/requirements.cuda.txt new file mode 100644 index 00000000..5bf56ae8 --- /dev/null +++ b/benchmarks/llm/requirements.cuda.txt @@ -0,0 +1,479 @@ +# +# This file is autogenerated by pip-compile with Python 3.10 +# by the following command: +# +# pip-compile --output-file=benchmarks/llm/requirements.cuda.txt .pin/tmp-constraints-cuda-llm-full-mp-nodes.txt benchmarks/llm/requirements.in +# +--extra-index-url https://pypi.ngc.nvidia.com +--extra-index-url https://download.pytorch.org/whl/cu121 +--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +--find-links 
https://data.pyg.org/whl/torch-2.5.1+cu121.html +--trusted-host pypi.ngc.nvidia.com + +accelerate==1.3.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # -r benchmarks/llm/requirements.in +aiohappyeyeballs==2.4.4 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # aiohttp +aiohttp==3.11.11 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # datasets + # fsspec +aiosignal==1.3.2 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # aiohttp +antlr4-python3-runtime==4.9.3 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # omegaconf +argklass==1.4.4 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # -r benchmarks/llm/requirements.in +asttokens==3.0.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # giving +async-timeout==5.0.1 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # aiohttp +attrs==24.3.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # aiohttp +blobfile==3.0.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # -r benchmarks/llm/requirements.txt + # torchtune +certifi==2024.12.14 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # requests +charset-normalizer==3.4.1 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # requests +codefind==0.1.7 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # ptera +datasets==3.2.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # torchtune +dill==0.3.8 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # datasets + # multiprocess +executing==2.1.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # varname +fairscale==0.4.13 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # -r benchmarks/llm/requirements.in + # -r benchmarks/llm/requirements.txt +filelock==3.17.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # blobfile + # datasets + # huggingface-hub + # torch + # transformers + # triton +fire==0.7.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # -r benchmarks/llm/requirements.txt +frozenlist==1.5.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # aiohttp + # aiosignal +fsspec[http]==2024.9.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # datasets + # huggingface-hub + # torch +giving==0.4.3 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # ptera + # voir +hjson==3.1.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # argklass +huggingface-hub[hf-transfer,hf_transfer]==0.27.1 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # accelerate + # datasets + # tokenizers + # torchtune + # transformers +idna==3.10 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # requests + # yarl +importlib-resources==6.5.2 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # argklass +jax[cuda12]==0.5.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # -r .pin/../constraints/extra/torch.cuda.txt +jax-cuda12-pjrt==0.5.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # jax-cuda12-plugin +jax-cuda12-plugin[with-cuda]==0.5.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # jax +jaxlib==0.5.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # jax +jinja2==3.1.5 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # torch +kagglehub==0.3.6 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # torchtune +lxml==5.3.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # blobfile +markdown-it-py==3.0.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # rich +markupsafe==3.0.2 + # via + # -c 
.pin/../.pin/constraints-cuda-torch.txt + # jinja2 +mdurl==0.1.2 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # markdown-it-py +ml-dtypes==0.5.1 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # jax + # jaxlib +mpmath==1.3.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # sympy +multidict==6.1.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # aiohttp + # yarl +multiprocess==0.70.16 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # datasets +networkx==3.4.2 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # torch +numpy==1.26.4 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # accelerate + # datasets + # fairscale + # jax + # jaxlib + # ml-dtypes + # pandas + # scipy + # torchtune + # transformers + # xformers +nvidia-cublas-cu12==12.1.3.1 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # jax-cuda12-plugin + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch +nvidia-cuda-cupti-cu12==12.1.105 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # jax-cuda12-plugin + # torch +nvidia-cuda-nvcc-cu12==12.6.85 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # jax-cuda12-plugin +nvidia-cuda-nvrtc-cu12==12.1.105 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # torch +nvidia-cuda-runtime-cu12==12.1.105 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # jax-cuda12-plugin + # torch +nvidia-cudnn-cu12==9.1.0.70 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # jax-cuda12-plugin + # torch +nvidia-cufft-cu12==11.0.2.54 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # jax-cuda12-plugin + # torch +nvidia-curand-cu12==10.3.2.106 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # torch +nvidia-cusolver-cu12==11.4.5.107 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # jax-cuda12-plugin + # torch +nvidia-cusparse-cu12==12.1.0.106 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # jax-cuda12-plugin + # nvidia-cusolver-cu12 + # torch +nvidia-ml-py==12.560.30 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # voir +nvidia-nccl-cu12==2.21.5 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # jax-cuda12-plugin + # torch +nvidia-nvjitlink-cu12==12.6.85 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # jax-cuda12-plugin + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 +nvidia-nvtx-cu12==12.1.105 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # torch +omegaconf==2.3.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # torchtune + # voir +opt-einsum==3.4.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # jax +ovld==0.3.9 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # voir +packaging==24.2 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # accelerate + # datasets + # huggingface-hub + # kagglehub + # transformers +pandas==2.2.3 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # datasets +pillow==11.1.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # torchtune +propcache==0.2.1 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # aiohttp + # yarl +psutil==5.9.8 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # accelerate + # torchtune + # voir +ptera==1.4.1 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # voir +pyarrow==19.0.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # datasets +pycryptodomex==3.21.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # blobfile +pygments==2.19.1 + # via + # -c 
.pin/../.pin/constraints-cuda-torch.txt + # rich +python-dateutil==2.9.0.post0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # pandas +pytz==2024.2 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # pandas +pyyaml==6.0.2 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # -r benchmarks/llm/requirements.in + # accelerate + # datasets + # huggingface-hub + # omegaconf + # transformers +reactivex==4.0.4 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # giving +regex==2024.11.6 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # tiktoken + # transformers +requests==2.32.3 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # datasets + # huggingface-hub + # kagglehub + # tiktoken + # transformers +rich==13.9.4 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # voir +safetensors==0.5.2 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # accelerate + # torchtune + # transformers +scipy==1.15.1 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # jax + # jaxlib +sentencepiece==0.2.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # torchtune +six==1.17.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # python-dateutil +sympy==1.13.1 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # torch +termcolor==2.5.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # fire +tiktoken==0.8.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # torchtune +tokenizers==0.21.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # transformers +torch==2.5.1+cu121 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # -c .pin/../constraints/cuda.txt + # -r benchmarks/llm/requirements.in + # -r benchmarks/llm/requirements.txt + # accelerate + # fairscale + # xformers +torchao==0.8.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # -r benchmarks/llm/requirements.in +torchtune==0.5.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # -r benchmarks/llm/requirements.in +tqdm==4.67.1 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # datasets + # huggingface-hub + # kagglehub + # torchtune + # transformers +transformers==4.48.1 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # -r benchmarks/llm/requirements.in +triton==3.1.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # torch +typing-extensions==4.12.2 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # huggingface-hub + # multidict + # reactivex + # rich + # torch +tzdata==2025.1 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # pandas +urllib3==2.3.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # blobfile + # requests +varname==0.14.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # giving +voir==0.2.19 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # -c .pin/../constraints/cuda.txt + # -r benchmarks/llm/requirements.in +xformers==0.0.29.post1 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # -r .pin/../constraints/extra/torch.cuda.txt +xxhash==3.5.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # datasets +yarl==1.18.3 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # aiohttp diff --git a/config/base.yaml b/config/base.yaml index b96dfdd9..3b8b5048 100644 --- a/config/base.yaml +++ b/config/base.yaml @@ -591,8 +591,9 @@ llm-lora-ddp-gpus: epochs=1: true output_dir={milabench_extra}/output: true tokenizer.path={milabench_data}/llama3_8B/original/tokenizer.model: true - 
checkpointer.checkpoint_dir={milabench_data}/llama3_8B/original: true + checkpointer.checkpoint_dir={milabench_data}/llama3_8B/: true checkpointer.output_dir={milabench_data}/llama3_8B/: true + safetensors=true: true metric_logger.log_dir={milabench_extra}/metrics: true repo_id="meta-llama/Meta-Llama-3.1-8B": true batch_size=8: true @@ -619,6 +620,7 @@ llm-lora-ddp-nodes: tokenizer.path={milabench_data}/llama3_8B/original/tokenizer.model: true checkpointer.checkpoint_dir={milabench_data}/llama3_8B/original: true checkpointer.output_dir={milabench_data}/llama3_8B/: true + safetensors=true: true metric_logger.log_dir={milabench_extra}/metrics: true repo_id="meta-llama/Meta-Llama-3.1-8B": true batch_size=8: true
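
Usage sketch (illustrative, not part of the patch): the commands below simply mirror the comments added to llama3_8B_lora_single_device.yaml and assume the torchtune CLI pinned in requirements.cuda.txt is installed; the paths are examples only.

  # Fetch the HF-format (safetensors) weights; the Meta-format
  # original/consolidated.00.pth is skipped, since the config now uses
  # FullModelHFCheckpointer with the four model-*-of-00004.safetensors shards.
  tune download meta-llama/Meta-Llama-3.1-8B-Instruct \
      --output-dir /tmp/Meta-Llama-3.1-8B-Instruct \
      --ignore-patterns "original/consolidated.00.pth"

  # Launch the single-device LoRA recipe; any field can be overridden on the
  # command line, e.g. pointing the checkpointer at the download directory.
  tune run lora_finetune_single_device \
      --config llama3_1/8B_lora_single_device \
      checkpointer.checkpoint_dir=/tmp/Meta-Llama-3.1-8B-Instruct/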