From 8631a20319251be8f72176defe19cc3531cb824d Mon Sep 17 00:00:00 2001 From: Ankur Srivastava Date: Wed, 3 Jul 2024 18:04:17 -0700 Subject: [PATCH 1/8] Added files Signed-off-by: Ankur Srivastava --- 3.test_cases/23.SMHP-esm2/README.md | 110 +++ 3.test_cases/23.SMHP-esm2/download_data.py | 225 +++++ 3.test_cases/23.SMHP-esm2/requirements.txt | 4 + .../23.SMHP-esm2/tokenize_uniref_csv.py | 228 +++++ 3.test_cases/23.SMHP-esm2/train.py | 780 ++++++++++++++++++ 5 files changed, 1347 insertions(+) create mode 100644 3.test_cases/23.SMHP-esm2/README.md create mode 100644 3.test_cases/23.SMHP-esm2/download_data.py create mode 100644 3.test_cases/23.SMHP-esm2/requirements.txt create mode 100644 3.test_cases/23.SMHP-esm2/tokenize_uniref_csv.py create mode 100644 3.test_cases/23.SMHP-esm2/train.py diff --git a/3.test_cases/23.SMHP-esm2/README.md b/3.test_cases/23.SMHP-esm2/README.md new file mode 100644 index 00000000..c8b1a413 --- /dev/null +++ b/3.test_cases/23.SMHP-esm2/README.md @@ -0,0 +1,110 @@ +# How to pretrain ESM2 with SageMaker Hyperpod using Amazon G5 instances + +## What is SageMaker Hyperpod? +[Amazon SageMaker Hyperpod](https://aws.amazon.com/sagemaker/hyperpod/) offers advanced training tools to help you accelerate scalable, reliable, and secure generative AI application development. It removes the undifferentiated heavy lifting involved in building and optimizing machine learning (ML) infrastructure for training foundation models (FMs) significantly reducing training time. SageMaker Hyperpod ensure customers can continue FM training uninterrupted by periodically saving checkpoints. When a hardware failure occurs during training, SageMaker Hyperpod automatically detects the failure, repairs, or replaces the faulty instance, and resumes the training from the last saved checkpoint, removing the need for customers to manually manage this process and helping them train for week or months in a distributed setting without disruption. + + +## What is ESM-2? +[ESM-2](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v1) is a pLM trained using unsupervied masked language modelling on 250 Million protein sequences by researchers at [Facebook AI Research (FAIR)](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v1). It is available in several sizes, ranging from 8 Million to 15 Billion parameters. The smaller models are suitable for various sequence and token classification tasks. The FAIR team also adapted the 3 Billion parameter version into the ESMFold protein structure prediction algorithm. They have since used ESMFold to predict the struture of [more than 700 Million metagenomic proteins](https://esmatlas.com/about). + +ESM-2 is a powerful pLM. We will demonstrate how to use QLoRA to fine-tune ESM-2 on g5.24xlarge instances. We will use ESM-2 to predict [subcellular localization](https://academic.oup.com/nar/article/50/W1/W228/6576357?login=false). Understanding where proteins appear in cells can help us understand their role in disease and find new drug targets. + +## 0. Prerequisites +You will need to set up a SageMaker Hyperpod cluster using 2 g5.24xlarge instances with a shared parallel filesystem such as [Amazon FSx for Lustre](https://docs.aws.amazon.com/fsx/latest/LustreGuide/getting-started.html). See the sagemaker-hyperpod section in the [Sagemaker Hyperpod](https://github.com/aws-samples/awsome-distributed-training/tree/main/1.architectures/5.sagemaker-hyperpod) folder for setup instructions. + +## 1. Install conda + +You can install MiniConda as follows: + +```bash +wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh +chmod +x Miniconda3-latest-Linux-x86_64.sh +./Miniconda3-latest-Linux-x86_64.sh -b -f -p ./miniconda3 + +source ./miniconda3/bin/activate +``` +## 2. Create conda environment + +You can create conda environment as follows: + +```bash + conda create --name esm python=3.10 + conda activate esm + pip3 install -r requirements.txt +``` + +## 3. Prepare dataset + +Next we need to download the Uniref50 training data. You can do so by running: + +```bash +python3 download_data.py +``` +It would download the data and partitions the data in 50 .csv files in `/fsx/ubuntu/csv` folder. The whole process should take less than 30 mins. + +```bash +(esm) (CONTROLLER) ubuntu@ip-10-1-71-160:~$ python3 download_data.py +07/03/2024 21:07:01 - INFO - Parsing arguments +07/03/2024 21:07:01 - INFO - Downloading FASTA +07/03/2024 21:07:01 - INFO - Downloading https://ftp.uniprot.org/pub/databases/uniprot/uniref/uniref50/uniref50.fasta.gz to /fsx/ubuntu/tmp9kq51ybi/fasta +https://ftp.uniprot.org/pub/databases/uniprot/uniref/uniref50/uniref50.fasta.gz: 100%|████████████████████████████████████████████████████████████████████████████████| 12.8G/12.8G [06:11<00:00, 36.8MB/s] +07/03/2024 21:13:13 - INFO - Generating csv files +Reading FASTA file +498383it [00:12, 59276.95it/s]07/03/2024 21:13:26 - INFO - Writing 500000 records to /fsx/ubuntu/csv/x000.csv +994642it [00:47, 77930.58it/s]07/03/2024 21:14:00 - INFO - Writing 500000 records to /fsx/ubuntu/csv/x001.csv +1495773it [01:08, 88755.06it/s]07/03/2024 21:14:22 - INFO - Writing 500000 records to /fsx/ubuntu/csv/x002.csv +1993826it [01:26, 98115.08it/s]07/03/2024 21:14:40 - INFO - Writing 500000 records to /fsx/ubuntu/csv/x003.csv +... +... +65446537it [11:32, 608611.75it/s]07/03/2024 21:24:46 - INFO - Writing 500000 records to /fsx/ubuntu/csv/x130.csv +65672468it [11:33, 94696.65it/s] +07/03/2024 21:24:47 - INFO - Writing 172468 records to /fsx/ubuntu/csv/x131.csv +07/03/2024 21:24:49 - INFO - Save complete +``` + +## 4. Convert CSVs to HuggingFace Dataset and Tokenize + +Next we need to tokenize the dataset. This will split the data in training, test and validation folders, tokenize them and save the arrow files in `processed` folder. + +```bash +(esm) (CONTROLLER) ubuntu@ip-10-1-71-160:~$ python3 tokenize_uniref_csv.py +07/03/2024 23:07:49 - INFO - Parsing arguments +07/03/2024 23:07:49 - INFO - Loading csv files from /fsx/ubuntu/csv +Resolving data files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:00<00:00, 356272.93it/s] +07/03/2024 23:07:51 - INFO - DatasetDict({ + train: Dataset({ + features: ['text'], + num_rows: 65672468 + }) +}) +07/03/2024 23:07:51 - INFO - Splitting dataset +Flattening the indices: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000000/10000000 [06:02<00:00, 27582.63 examples/s] +Flattening the indices: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:00<00:00, 59268.14 examples/s] +Flattening the indices: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:00<00:00, 62442.35 examples/s] +07/03/2024 23:14:01 - INFO - Saving splits to csv +Creating CSV from Arrow format: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [01:51<00:00, 89.70ba/s] +Creating CSV from Arrow format: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 89.99ba/s] +Creating CSV from Arrow format: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 89.29ba/s] +/fsx/ubuntu/miniconda3/envs/esm/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`. + warnings.warn( +tokenizer_config.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 95.0/95.0 [00:00<00:00, 949kB/s] +vocab.txt: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 93.0/93.0 [00:00<00:00, 1.09MB/s] +special_tokens_map.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 1.55MB/s] +07/03/2024 23:15:56 - INFO - Processing line by line +Running tokenizer on dataset line_by_line (num_proc=8): 100%|█████████████████████████████████████████████████████████████████████████████████████████| 10000000/10000000 [23:46<00:00, 7008.67 examples/s] +Running tokenizer on dataset line_by_line (num_proc=8): 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:57<00:00, 870.72 examples/s] +Running tokenizer on dataset line_by_line (num_proc=8): 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:08<00:00, 5695.93 examples/s] +Saving the dataset (62/62 shards): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000000/10000000 [00:55<00:00, 180076.96 examples/s] +Saving the dataset (1/1 shards): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:00<00:00, 177160.38 examples/s] +Saving the dataset (1/1 shards): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:00<00:00, 182452.27 examples/s] +``` + + + + + + + + + + diff --git a/3.test_cases/23.SMHP-esm2/download_data.py b/3.test_cases/23.SMHP-esm2/download_data.py new file mode 100644 index 00000000..f9e45745 --- /dev/null +++ b/3.test_cases/23.SMHP-esm2/download_data.py @@ -0,0 +1,225 @@ +import argparse +import boto3 +import csv +import datasets +import logging +import os +import pyfastx +import random +import requests +import tempfile +import tqdm +from urllib.parse import urlparse + +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) + + +def parse_args(): + """Parse the arguments.""" + logging.info("Parsing arguments") + parser = argparse.ArgumentParser() + + parser.add_argument( + "--max_records_per_partition", + type=int, + default=500000, + help="Max number of sequence records per csv partition", + ) + parser.add_argument( + "--output_dir", + type=str, + default=os.getcwd(), + help="Output dir for processed files", + ) + parser.add_argument( + "--save_arrow", + type=bool, + default=False, + help="Save Apache Arrow files to output dir?", + ) + parser.add_argument( + "--save_csv", + type=bool, + default=True, + help="Save csv files to output dir?", + ) + parser.add_argument( + "--save_fasta", + type=bool, + default=False, + help="Save FASTA file to output dir?", + ) + parser.add_argument( + "--save_parquet", + type=bool, + default=False, + help="Save Apache Parquet files to output dir?", + ) + parser.add_argument( + "--shuffle", + type=bool, + default=True, + help="Shuffle the records in each csv partition?", + ) + parser.add_argument( + "--source", + type=str, + default="https://ftp.uniprot.org/pub/databases/uniprot/uniref/uniref50/uniref50.fasta.gz", + help="Path to input .fasta or .fasta.gz file, e.g. s3://myfasta.fa, http://myfasta.fasta.gz, ~/myfasta.fasta, etc", + ) + + args, _ = parser.parse_known_args() + return args + + +def main(args): + """Transform fasta file into dataset""" + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + tmp_dir = tempfile.TemporaryDirectory(dir=os.getcwd()) + + logging.info("Downloading FASTA") + fasta_dir = ( + os.path.join(args.output_dir, "fasta") + if args.save_fasta + else os.path.join(tmp_dir.name, "fasta") + ) + fasta_path = download(args.source, fasta_dir) + + logging.info("Generating csv files") + csv_dir = ( + os.path.join(args.output_dir, "csv") + if args.save_csv + else os.path.join(tmp_dir.name, "csv") + ) + csv_path = fasta_to_csv( + fasta_path, csv_dir, args.max_records_per_partition + ) + + if args.save_arrow or args.save_parquet: + logging.info("Loading csv files into dataset") + ds = datasets.load_dataset( + "csv", + data_dir=csv_path, + num_proc=os.cpu_count(), + cache_dir=os.path.join(tmp_dir.name, "dataset_cache"), + ) + + logging.info("Saving dataset in Arrow format") + if args.save_arrow: + ds.save_to_disk(os.path.join(args.output_dir, "arrow")) + + logging.info("Saving dataset in Parquet format") + if args.save_parquet: + for split in ds.keys(): + ds[split].to_parquet( + f"{os.path.join(args.output_dir, 'parquet')}/data.parquet" + ) + + tmp_dir.cleanup() + logging.info("Save complete") + return args.output_dir + + +def download(source: str, filename: str) -> str: + logging.info(f"Downloading {source} to {filename}") + output_dir = os.path.dirname(filename) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + if source.startswith("s3"): + s3 = boto3.client("s3") + parsed = urlparse(source, allow_fragments=False) + bucket = parsed.netloc + key = parsed.path[1:] + total = s3.head_object(Bucket=bucket, Key=key)["ContentLength"] + tqdm_params = { + "desc": source, + "total": total, + "miniters": 1, + "unit": "B", + "unit_scale": True, + "unit_divisor": 1024, + } + with tqdm.tqdm(**tqdm_params) as pb: + s3.download_file( + parsed.netloc, + parsed.path[1:], + filename, + Callback=lambda bytes_transferred: pb.update(bytes_transferred), + ) + elif source.startswith("http"): + with open(filename, "wb") as f: + with requests.get(source, stream=True) as r: + r.raise_for_status() + total = int(r.headers.get("content-length", 0)) + + tqdm_params = { + "desc": source, + "total": total, + "miniters": 1, + "unit": "B", + "unit_scale": True, + "unit_divisor": 1024, + } + with tqdm.tqdm(**tqdm_params) as pb: + for chunk in r.iter_content(chunk_size=8192): + pb.update(len(chunk)) + f.write(chunk) + elif os.path.isfile(source): + logging.info(f"{source} already exists") + else: + raise ValueError(f"Invalid source: {source}") + return filename + + +def fasta_to_csv( + fasta: str, + output_dir: str = "csv", + max_records_per_partition=2000000, + shuffle=False, +) -> list: + """Split a .fasta or .fasta.gz file into multiple .csv files.""" + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + print("Reading FASTA file") + fasta_list = [] + fasta_idx = 0 + + for i, seq in tqdm.tqdm( + enumerate(pyfastx.Fasta(fasta, build_index=False, uppercase=True)) + ): + fasta_list.append(seq) + + if (i + 1) % max_records_per_partition == 0: + if shuffle: + random.shuffle(fasta_list) + fasta_idx = int(i / max_records_per_partition) + _write_seq_record_to_csv(fasta_list, output_dir, fasta_idx) + fasta_list = [] + else: + _write_seq_record_to_csv(fasta_list, output_dir, fasta_idx + 1) + return output_dir + + +def _write_seq_record_to_csv(content_list, output_dir, index): + output_path = os.path.join(output_dir, f"x{str(index).rjust(3, '0')}.csv") + logging.info(f"Writing {len(content_list)} records to {output_path}") + + with open(output_path, "w") as f: + writer = csv.writer(f) + writer.writerow(("id", "text")) + writer.writerows(content_list) + return None + + +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/3.test_cases/23.SMHP-esm2/requirements.txt b/3.test_cases/23.SMHP-esm2/requirements.txt new file mode 100644 index 00000000..de61eed2 --- /dev/null +++ b/3.test_cases/23.SMHP-esm2/requirements.txt @@ -0,0 +1,4 @@ +accelerate==0.25.0 +datasets==2.16.1 +pyfastx==2.0.2 +transformers==4.37.2 diff --git a/3.test_cases/23.SMHP-esm2/tokenize_uniref_csv.py b/3.test_cases/23.SMHP-esm2/tokenize_uniref_csv.py new file mode 100644 index 00000000..778a4456 --- /dev/null +++ b/3.test_cases/23.SMHP-esm2/tokenize_uniref_csv.py @@ -0,0 +1,228 @@ +import argparse +import datasets +from itertools import chain +import logging +import os +import transformers +from urllib.parse import urlparse + +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) + + +def parse_args(): + """Parse the arguments.""" + logging.info("Parsing arguments") + parser = argparse.ArgumentParser() + + parser.add_argument( + "--input_dir", + type=str, + default="/fsx/ubuntu/csv", + help="Input dir for protein sequence csv", + ) + parser.add_argument( + "--output_dir", + type=str, + default="/fsx/ubuntu/processed", + help="Output dir for processed files", + ) + parser.add_argument( + "--train_size", + type=int, + default=10000000, + help="The number of samples used for a test set", + ) + parser.add_argument( + "--validation_size", + type=int, + default=50000, + help="The number of samples used for a validation set", + ) + parser.add_argument( + "--test_size", + type=int, + default=50000, + help="The number of samples used for a test set", + ) + parser.add_argument( + "--max_seq_length", + type=int, + default=512, + help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated.", + ) + parser.add_argument( + "--pad_to_max_length", + type=bool, + default=True, + help="Whether to pad all samples to `max_seq_length`. If False, will pad the samples dynamically when batching to the maximum length in the batch.", + ) + parser.add_argument( + "--preprocessing_num_workers", + type=int, + default=8, + help="The number of workers to use for the preprocessing.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default="facebook/esm2_t30_150M_UR50D", + help="Pretrained tokenizer name to use.", + ) + parser.add_argument( + "--line_by_line", + type=bool, + default=True, + help="Whether distinct lines of text in the dataset are to be handled as distinct sequences.", + ) + + args, _ = parser.parse_known_args() + return args + + +def main(args): + + logging.info(f"Loading csv files from {args.input_dir}") + extension = "csv" + data_files = [ + os.path.join(args.input_dir, f) + for f in os.listdir(args.input_dir) + if f.endswith(extension) + ] + + raw_data = datasets.load_dataset( + "csv", + data_files=data_files, + num_proc=args.preprocessing_num_workers, + ) + + raw_data = raw_data.remove_columns("id") + + logging.info(raw_data) + + logging.info("Splitting dataset") + train_testvalid = raw_data["train"].train_test_split( + train_size=args.train_size, test_size=args.validation_size + args.test_size + ) + test_valid = train_testvalid["test"].train_test_split( + train_size=args.validation_size, test_size=args.test_size + ) + raw_data = datasets.DatasetDict( + { + "train": train_testvalid["train"], + "validation": test_valid["train"], + "test": test_valid["test"], + } + ) + del train_testvalid + del test_valid + + raw_data.flatten_indices() + + logging.info("Saving splits to csv") + + for dir in ["train", "val", "test"]: + path = os.path.join(args.output_dir, "csv/" + dir) + if not os.path.exists(path): + os.makedirs(path) + + raw_data["train"].to_csv(os.path.join(args.output_dir, "csv/train", "x000.csv")) + raw_data["validation"].to_csv(os.path.join(args.output_dir, "csv/val", "x001.csv")) + raw_data["test"].to_csv(os.path.join(args.output_dir, "csv/test", "x002.csv")) + + column_names = list(raw_data["train"].features) + text_column_name = "text" if "text" in column_names else column_names[0] + + tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer_name) + + if args.line_by_line == True: + logging.info("Processing line by line") + + # When using line_by_line, we just tokenize each nonempty line. + padding = "max_length" if args.pad_to_max_length else False + + def tokenize_function(examples): + # Remove empty lines + examples[text_column_name] = [ + line + for line in examples[text_column_name] + if len(line) > 0 and not line.isspace() + ] + return tokenizer( + examples[text_column_name], + padding=padding, + truncation=True, + max_length=args.max_seq_length, + return_special_tokens_mask=True, + ) + + tokenized_datasets = raw_data.map( + tokenize_function, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=[text_column_name], + desc="Running tokenizer on dataset line_by_line", + ) + else: + # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts. + # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more + # efficient when it receives the `special_tokens_mask`. + def tokenize_function(examples): + return tokenizer( + examples[text_column_name], return_special_tokens_mask=True + ) + + tokenized_datasets = raw_data.map( + tokenize_function, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=column_names, + desc="Running tokenizer on every text in dataset", + ) + + # Main data processing function that will concatenate all texts from our dataset and generate chunks of + # max_seq_length. + def group_texts(examples): + # Concatenate all texts. + concatenated_examples = { + k: list(chain(*examples[k])) for k in examples.keys() + } + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, and if the total_length < max_seq_length we exclude this batch and return an empty dict. + # We could add padding if the model supported it instead of this drop, you can customize this part to your needs. + total_length = (total_length // args.max_seq_length) * args.max_seq_length + # Split by chunks of max_len. + result = { + k: [ + t[i : i + args.max_seq_length] + for i in range(0, total_length, args.max_seq_length) + ] + for k, t in concatenated_examples.items() + } + return result + + # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a + # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value + # might be slower to preprocess. + # + # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: + # https://huggingface.co/docs/datasets/process#map + + tokenized_datasets = tokenized_datasets.map( + group_texts, + batched=True, + num_proc=args.preprocessing_num_workers, + desc=f"Grouping texts in chunks of {args.max_seq_length}", + ) + arrow_output_path = os.path.join(args.output_dir, "arrow") + tokenized_datasets.save_to_disk(arrow_output_path) + + return arrow_output_path + + +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/3.test_cases/23.SMHP-esm2/train.py b/3.test_cases/23.SMHP-esm2/train.py new file mode 100644 index 00000000..eb269ba4 --- /dev/null +++ b/3.test_cases/23.SMHP-esm2/train.py @@ -0,0 +1,780 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2020 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) on a text file or a dataset. + +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=fill-mask +""" +# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments. +from dataclasses import dataclass, field +import datasets +from datasets import load_dataset +import evaluate +from itertools import chain +import logging +import math +import os +import sys +from typing import Optional +import transformers +from transformers import ( + CONFIG_MAPPING, + MODEL_FOR_MASKED_LM_MAPPING, + AutoConfig, + AutoModelForMaskedLM, + AutoTokenizer, + DataCollatorForLanguageModeling, + HfArgumentParser, + Trainer, + TrainingArguments, + is_torch_tpu_available, + set_seed, +) +from transformers.trainer_utils import get_last_checkpoint +from transformers.utils.versions import require_version +import warnings + +logger = logging.getLogger(__name__) +MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. + """ + + model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": ( + "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch." + ) + }, + ) + model_type: Optional[str] = field( + default=None, + metadata={ + "help": "If training from scratch, pass a model type from the list: " + + ", ".join(MODEL_TYPES) + }, + ) + config_overrides: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Override some existing default config settings when a model is trained from scratch. Example: " + "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" + ) + }, + ) + config_name: Optional[str] = field( + default="facebook/esm2_t30_150M_UR50D", + metadata={ + "help": "Pretrained config name or path if not the same as model_name" + }, + ) + tokenizer_name: Optional[str] = field( + default="facebook/esm2_t30_150M_UR50D", + metadata={ + "help": "Pretrained tokenizer name or path if not the same as model_name" + }, + ) + cache_dir: Optional[str] = field( + default=None, + metadata={ + "help": "Where do you want to store the pretrained models downloaded from huggingface.co" + }, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={ + "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not." + }, + ) + model_revision: str = field( + default="main", + metadata={ + "help": "The specific model version to use (can be a branch name, tag name or commit id)." + }, + ) + token: str = field( + default=None, + metadata={ + "help": ( + "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token " + "generated when running `huggingface-cli login` (stored in `~/.huggingface`)." + ) + }, + ) + use_auth_token: bool = field( + default=None, + metadata={ + "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead." + }, + ) + trust_remote_code: bool = field( + default=False, + metadata={ + "help": ( + "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option" + "should only be set to `True` for repositories you trust and in which you have read the code, as it will " + "execute code present on the Hub on your local machine." + ) + }, + ) + low_cpu_mem_usage: bool = field( + default=False, + metadata={ + "help": ( + "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. " + "set True will benefit LLM loading time and RAM consumption." + ) + }, + ) + + def __post_init__(self): + if self.config_overrides is not None and ( + self.config_name is not None or self.model_name_or_path is not None + ): + raise ValueError( + "--config_overrides can't be used in combination with --config_name or --model_name_or_path" + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: Optional[str] = field( + default=None, + metadata={"help": "The name of the dataset to use (via the datasets library)."}, + ) + dataset_config_name: Optional[str] = field( + default=None, + metadata={ + "help": "The configuration name of the dataset to use (via the datasets library)." + }, + ) + train_file: Optional[str] = field( + default=None, metadata={"help": "The input training data file (a text file)."} + ) + validation_file: Optional[str] = field( + default=None, + metadata={ + "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)." + }, + ) + overwrite_cache: bool = field( + default=False, + metadata={"help": "Overwrite the cached training and evaluation sets"}, + ) + validation_split_percentage: Optional[int] = field( + default=5, + metadata={ + "help": "The percentage of the train set used as validation set in case there's no validation split" + }, + ) + max_seq_length: Optional[int] = field( + default=None, + metadata={ + "help": ( + "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated." + ) + }, + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + mlm_probability: float = field( + default=0.15, + metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}, + ) + line_by_line: bool = field( + default=False, + metadata={ + "help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences." + }, + ) + pad_to_max_length: bool = field( + default=False, + metadata={ + "help": ( + "Whether to pad all samples to `max_seq_length`. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch." + ) + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ) + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + ) + }, + ) + streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"}) + ##### New stuff + do_preprocess: bool = field( + default=True, metadata={"help": "Enable data preprocessing and tokenization"} + ) + dataset_dir: Optional[str] = field( + default=None, metadata={"help": "The input training data folder (a dir)."} + ) + + ##### + + def __post_init__(self): + if self.streaming: + require_version( + "datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`" + ) + + if ( + self.dataset_name is None + and self.train_file is None + and self.validation_file is None + and self.dataset_dir is None + # and self.validation_dir is None + ): + raise ValueError( + "Need either a dataset name or a training/validation file/folder." + ) + else: + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + if extension not in ["csv", "json", "txt"]: + raise ValueError( + "`train_file` should be a csv, a json or a txt file." + ) + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + if extension not in ["csv", "json", "txt"]: + raise ValueError( + "`validation_file` should be a csv, a json or a txt file." + ) + + +def main(): + # See all possible arguments in https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser( + (ModelArguments, DataTrainingArguments, TrainingArguments) + ) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file( + json_file=os.path.abspath(sys.argv[1]) + ) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if model_args.use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.", + FutureWarning, + ) + if model_args.token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + model_args.token = model_args.use_auth_token + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + if training_args.should_log: + # The default of training_args.log_level is passive, so we set log level at info here to have that default. + transformers.utils.logging.set_verbosity_info() + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, " + + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" + ) + # Set the verbosity to info of the Transformers logger (on main process only): + logger.info(f"Training/evaluation parameters {training_args}") + + # Detecting last checkpoint. + last_checkpoint = None + if ( + os.path.isdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif ( + last_checkpoint is not None and training_args.resume_from_checkpoint is None + ): + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Set seed before initializing model. + set_seed(training_args.seed) + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub + # + # For CSV/JSON files, this script will use the column called 'text' or the first column. You can easily tweak this + # behavior (see below) + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if data_args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + cache_dir=model_args.cache_dir, + token=model_args.token, + streaming=data_args.streaming, + ) + if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=f"train[:{data_args.validation_split_percentage}%]", + cache_dir=model_args.cache_dir, + token=model_args.token, + streaming=data_args.streaming, + ) + raw_datasets["train"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=f"train[{data_args.validation_split_percentage}%:]", + cache_dir=model_args.cache_dir, + token=model_args.token, + streaming=data_args.streaming, + ) + elif data_args.dataset_dir is not None: + raw_datasets = datasets.load_from_disk(data_args.dataset_dir) + else: + data_files = {} + if data_args.train_file is not None: + data_files["train"] = data_args.train_file + extension = data_args.train_file.split(".")[-1] + if data_args.validation_file is not None: + data_files["validation"] = data_args.validation_file + extension = data_args.validation_file.split(".")[-1] + if extension == "txt": + extension = "text" + raw_datasets = load_dataset( + extension, + data_files=data_files, + cache_dir=model_args.cache_dir, + token=model_args.token, + ) + + # If no validation data is there, validation_split_percentage will be used to divide the dataset. + if "validation" not in raw_datasets.keys() and training_args.do_eval: + raw_datasets = raw_datasets["train"].train_test_split( + train_test_split=data_args.validation_split_percentage + ) + raw_datasets["validation"] = raw_datasets.pop("test") + + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets. + + # Load pretrained model and tokenizer + # + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + config_kwargs = { + "cache_dir": model_args.cache_dir, + "revision": model_args.model_revision, + "token": model_args.token, + "trust_remote_code": model_args.trust_remote_code, + } + if model_args.config_name: + config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) + elif model_args.model_name_or_path: + config = AutoConfig.from_pretrained( + model_args.model_name_or_path, **config_kwargs + ) + else: + config = CONFIG_MAPPING[model_args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + if model_args.config_overrides is not None: + logger.info(f"Overriding config: {model_args.config_overrides}") + config.update_from_string(model_args.config_overrides) + logger.info(f"New config: {config}") + + tokenizer_kwargs = { + "cache_dir": model_args.cache_dir, + "use_fast": model_args.use_fast_tokenizer, + "revision": model_args.model_revision, + "token": model_args.token, + "trust_remote_code": model_args.trust_remote_code, + } + if model_args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name, **tokenizer_kwargs + ) + elif model_args.model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, **tokenizer_kwargs + ) + else: + raise ValueError( + "You are instantiating a new tokenizer from scratch. This is not supported by this script. " + "You can do it from another script, save it, and load it from here, using --tokenizer_name." + ) + + if model_args.model_name_or_path: + model = AutoModelForMaskedLM.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + token=model_args.token, + trust_remote_code=model_args.trust_remote_code, + low_cpu_mem_usage=model_args.low_cpu_mem_usage, + ) + else: + logger.info("Training new model from scratch") + model = AutoModelForMaskedLM.from_config( + config, trust_remote_code=model_args.trust_remote_code + ) + + # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch + # on a small vocab and want a smaller embedding size, remove this test. + embedding_size = model.get_input_embeddings().weight.shape[0] + if len(tokenizer) > embedding_size: + model.resize_token_embeddings(len(tokenizer)) + + if data_args.do_preprocess: + tokenized_datasets = load_and_tokenize_data( + raw_datasets, tokenizer, training_args, data_args + ) + else: + tokenized_datasets = raw_datasets + + if training_args.do_train: + if "train" not in tokenized_datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = tokenized_datasets["train"] + if data_args.max_train_samples is not None: + max_train_samples = min(len(train_dataset), data_args.max_train_samples) + train_dataset = train_dataset.select(range(max_train_samples)) + + if training_args.do_eval: + if "validation" not in tokenized_datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_dataset = tokenized_datasets["validation"] + if data_args.max_eval_samples is not None: + max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) + eval_dataset = eval_dataset.select(range(max_eval_samples)) + + def preprocess_logits_for_metrics(logits, labels): + if isinstance(logits, tuple): + # Depending on the model and config, logits may contain extra tensors, + # like past_key_values, but logits always come first + logits = logits[0] + return logits.argmax(dim=-1) + + metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir) + + def compute_metrics(eval_preds): + preds, labels = eval_preds + # preds have the same shape as the labels, after the argmax(-1) has been calculated + # by preprocess_logits_for_metrics + labels = labels.reshape(-1) + preds = preds.reshape(-1) + mask = labels != -100 + labels = labels[mask] + preds = preds[mask] + return metric.compute(predictions=preds, references=labels) + + # Data collator + # This one will take care of randomly masking the tokens. + pad_to_multiple_of_8 = ( + data_args.line_by_line + and training_args.fp16 + and not data_args.pad_to_max_length + ) + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm_probability=data_args.mlm_probability, + pad_to_multiple_of=8 if pad_to_multiple_of_8 else None, + ) + + # Initialize our Trainer + + # Need a custom class due to these issues: + # https://github.com/huggingface/transformers/issues/21118 + # https://github.com/huggingface/transformers/issues/24714 + + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=( + compute_metrics + if training_args.do_eval and not is_torch_tpu_available() + else None + ), + preprocess_logits_for_metrics=( + preprocess_logits_for_metrics + if training_args.do_eval and not is_torch_tpu_available() + else None + ), + ) + + # Training + if training_args.do_train: + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + trainer.save_model() # Saves the tokenizer too for easy upload + metrics = train_result.metrics + + if not data_args.streaming: + max_train_samples = ( + data_args.max_train_samples + if data_args.max_train_samples is not None + else len(train_dataset) + ) + metrics["train_samples"] = min(max_train_samples, len(train_dataset)) + + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + # Evaluation + if training_args.do_eval: + logger.info("*** Evaluate ***") + + metrics = trainer.evaluate() + logger.info(f"Metrics are {metrics}") + if not data_args.streaming: + max_eval_samples = ( + data_args.max_eval_samples + if data_args.max_eval_samples is not None + else len(eval_dataset) + ) + metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) + logger.info("Calculating perplexity") + try: + perplexity = math.exp(metrics["eval_loss"]) + except OverflowError: + perplexity = float("inf") + logger.info(f"Perplexity: {perplexity}") + metrics["perplexity"] = perplexity + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "fill-mask"} + if data_args.dataset_name is not None: + kwargs["dataset_tags"] = data_args.dataset_name + if data_args.dataset_config_name is not None: + kwargs["dataset_args"] = data_args.dataset_config_name + kwargs["dataset"] = ( + f"{data_args.dataset_name} {data_args.dataset_config_name}" + ) + else: + kwargs["dataset"] = data_args.dataset_name + + if training_args.push_to_hub: + trainer.push_to_hub(**kwargs) + else: + trainer.create_model_card(**kwargs) + + +# def _mp_fn(index): +# # For xla_spawn (TPUs) +# main() + + +def load_and_tokenize_data(raw_datasets, tokenizer, training_args, data_args): + if training_args.do_train: + column_names = list(raw_datasets["train"].features) + else: + column_names = list(raw_datasets["validation"].features) + text_column_name = "text" if "text" in column_names else column_names[0] + + if data_args.max_seq_length is None: + max_seq_length = tokenizer.model_max_length + if max_seq_length > 1024: + logger.warning( + "The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value" + " of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can" + " override this default with `--block_size xxx`." + ) + max_seq_length = 1024 + else: + if data_args.max_seq_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the " + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + + if data_args.line_by_line: + # When using line_by_line, we just tokenize each nonempty line. + padding = "max_length" if data_args.pad_to_max_length else False + + def tokenize_function(examples): + # Remove empty lines + examples[text_column_name] = [ + line + for line in examples[text_column_name] + if len(line) > 0 and not line.isspace() + ] + return tokenizer( + examples[text_column_name], + padding=padding, + truncation=True, + max_length=max_seq_length, + # We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it + # receives the `special_tokens_mask`. + return_special_tokens_mask=True, + ) + + with training_args.main_process_first(desc="dataset map tokenization"): + if not data_args.streaming: + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on dataset line_by_line", + ) + else: + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + remove_columns=column_names, + ) + else: + # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts. + # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more + # efficient when it receives the `special_tokens_mask`. + def tokenize_function(examples): + return tokenizer( + examples[text_column_name], return_special_tokens_mask=True + ) + + with training_args.main_process_first(desc="dataset map tokenization"): + if not data_args.streaming: + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on every text in dataset", + ) + else: + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + remove_columns=column_names, + ) + + # Main data processing function that will concatenate all texts from our dataset and generate chunks of + # max_seq_length. + def group_texts(examples): + # Concatenate all texts. + concatenated_examples = { + k: list(chain(*examples[k])) for k in examples.keys() + } + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, and if the total_length < max_seq_length we exclude this batch and return an empty dict. + # We could add padding if the model supported it instead of this drop, you can customize this part to your needs. + total_length = (total_length // max_seq_length) * max_seq_length + # Split by chunks of max_len. + result = { + k: [ + t[i : i + max_seq_length] + for i in range(0, total_length, max_seq_length) + ] + for k, t in concatenated_examples.items() + } + return result + + # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a + # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value + # might be slower to preprocess. + # + # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: + # https://huggingface.co/docs/datasets/process#map + + with training_args.main_process_first(desc="grouping texts together"): + if not data_args.streaming: + tokenized_datasets = tokenized_datasets.map( + group_texts, + batched=True, + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=not data_args.overwrite_cache, + desc=f"Grouping texts in chunks of {max_seq_length}", + ) + else: + tokenized_datasets = tokenized_datasets.map( + group_texts, + batched=True, + ) + return tokenized_datasets + + +if __name__ == "__main__": + main() \ No newline at end of file From 251e015c278683fc62a8e624f59855da9d91d5c7 Mon Sep 17 00:00:00 2001 From: Ankur Srivastava Date: Tue, 16 Jul 2024 17:28:54 -0700 Subject: [PATCH 2/8] Updated with training example Signed-off-by: Ankur Srivastava --- 3.test_cases/23.SMHP-esm2/README.md | 37 +++++++++ 3.test_cases/23.SMHP-esm2/requirements.txt | 13 ++- 3.test_cases/23.SMHP-esm2/submit_train_g5.sh | 84 ++++++++++++++++++++ 3 files changed, 130 insertions(+), 4 deletions(-) create mode 100644 3.test_cases/23.SMHP-esm2/submit_train_g5.sh diff --git a/3.test_cases/23.SMHP-esm2/README.md b/3.test_cases/23.SMHP-esm2/README.md index c8b1a413..fd21b0ca 100644 --- a/3.test_cases/23.SMHP-esm2/README.md +++ b/3.test_cases/23.SMHP-esm2/README.md @@ -30,6 +30,7 @@ You can create conda environment as follows: ```bash conda create --name esm python=3.10 conda activate esm + conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia pip3 install -r requirements.txt ``` @@ -99,7 +100,43 @@ Saving the dataset (1/1 shards): 100%|██████████████ Saving the dataset (1/1 shards): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:00<00:00, 182452.27 examples/s] ``` +## 5. Submit training job +Once data is processed, we are ready to train the ESM2 model. + +``` +sbatch submit_train_g5.sh +``` + +``` +1: [INFO|trainer.py:2128] 2024-07-17 00:18:20,620 >> ***** Running training ***** +1: [INFO|trainer.py:2129] 2024-07-17 00:18:20,620 >> Num examples = 100,000 +1: [INFO|trainer.py:2130] 2024-07-17 00:18:20,620 >> Num Epochs = 1 +1: [INFO|trainer.py:2131] 2024-07-17 00:18:20,620 >> Instantaneous batch size per device = 8 +1: [INFO|trainer.py:2134] 2024-07-17 00:18:20,620 >> Total train batch size (w. parallel, distributed & accumulation) = 1,024 +1: [INFO|trainer.py:2135] 2024-07-17 00:18:20,620 >> Gradient Accumulation steps = 16 +1: [INFO|trainer.py:2136] 2024-07-17 00:18:20,620 >> Total optimization steps = 97 +1: [INFO|trainer.py:2137] 2024-07-17 00:18:20,622 >> Number of trainable parameters = 148,796,794 +0: [INFO|trainer.py:2128] 2024-07-17 00:18:20,685 >> ***** Running training ***** +0: [INFO|trainer.py:2129] 2024-07-17 00:18:20,685 >> Num examples = 100,000 +0: [INFO|trainer.py:2130] 2024-07-17 00:18:20,685 >> Num Epochs = 1 +0: [INFO|trainer.py:2131] 2024-07-17 00:18:20,685 >> Instantaneous batch size per device = 8 +0: [INFO|trainer.py:2134] 2024-07-17 00:18:20,685 >> Total train batch size (w. parallel, distributed & accumulation) = 1,024 +0: [INFO|trainer.py:2135] 2024-07-17 00:18:20,685 >> Gradient Accumulation steps = 16 +0: [INFO|trainer.py:2136] 2024-07-17 00:18:20,685 >> Total optimization steps = 97 +0: [INFO|trainer.py:2137] 2024-07-17 00:18:20,687 >> Number of trainable parameters = 148,796,794 +0: {'loss': 2.9859, 'grad_norm': 0.9704080820083618, 'learning_rate': 4.175257731958763e-05, 'epoch': 0.16} + 19%|█▊ | 18/97 [01:50<08:31, 6.39s/it] +0: {'loss': 2.8209, 'grad_norm': 2.9741921424865723, 'learning_rate': 3.3505154639175256e-05, 'epoch': 0.33} + 36%|███▌ | 35/97 [03:39<06:42, 6.39s/it] +0: {'loss': 2.716, 'grad_norm': 2.2170701026916504, 'learning_rate': 2.5257731958762887e-05, 'epoch': 0.49} + 53%|█████▎ | 50/97 [05:21<05:00, 6.39s/it] +0: {'loss': 2.6697, 'grad_norm': 0.8555800318717957, 'learning_rate': 1.7010309278350517e-05, 'epoch': 0.66} + 68%|██████▊ | 65/97 [06:57<03:24, 6.38s/it] +0: {'loss': 2.6591, 'grad_norm': 0.5596509575843811, 'learning_rate': 8.762886597938144e-06, 'epoch': 0.82} + 82%|████████▏ | 80/97 [08:32< + +``` diff --git a/3.test_cases/23.SMHP-esm2/requirements.txt b/3.test_cases/23.SMHP-esm2/requirements.txt index de61eed2..abd537e6 100644 --- a/3.test_cases/23.SMHP-esm2/requirements.txt +++ b/3.test_cases/23.SMHP-esm2/requirements.txt @@ -1,4 +1,9 @@ -accelerate==0.25.0 -datasets==2.16.1 -pyfastx==2.0.2 -transformers==4.37.2 +accelerate +datasets==2.20.0 +pyfastx +transformers +boto3 +huggingface_hub==0.23.4 +chardet +evaluate +scikit-learn \ No newline at end of file diff --git a/3.test_cases/23.SMHP-esm2/submit_train_g5.sh b/3.test_cases/23.SMHP-esm2/submit_train_g5.sh new file mode 100644 index 00000000..636155d5 --- /dev/null +++ b/3.test_cases/23.SMHP-esm2/submit_train_g5.sh @@ -0,0 +1,84 @@ +#!/bin/bash + +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 + +#SBATCH --nodes=4 # number of nodes to use +#SBATCH --job-name=FSDP # name of your job +#SBATCH --exclusive # job has exclusive use of the resource, no sharing + +set -ex; + +########################### +###### User Variables ##### +########################### + +GPUS_PER_NODE=4 # 4 for G5.12x, 8 for P4/P5 + +########################### +## Environment Variables ## +########################### + +## Plenty of EFA level variables +## Comment out for non-efa instances (G4d, P3) +## For G5.12x, Comment out RDMA and Fork safe +## For G4dn and other G5, comment out all +## export FI_EFA_USE_DEVICE_RDMA=1 # use for p4d +export FI_EFA_FORK_SAFE=1 +export FI_LOG_LEVEL=1 +export FI_PROVIDER=efa +export NCCL_DEBUG=INFO +## Switching SYNC_MEMOPS to zero can boost throughput with FSDP +## Disables CU_POINTER_ATTRIBUTE_SYNC_MEMOPS +## Reduces memory synchronizations +## https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__UNIFIED.html +## export FI_EFA_SET_CUDA_SYNC_MEMOPS=0 + +########################### +####### Torch Dist ####### +########################### + +declare -a TORCHRUN_ARGS=( + --nproc_per_node=$GPUS_PER_NODE + --nnodes=$SLURM_JOB_NUM_NODES + --rdzv_id=$SLURM_JOB_ID + --rdzv_backend=c10d + --rdzv_endpoint=$(hostname) +) + +source /fsx/ubuntu/miniconda3/bin/activate +conda activate esm2 + +export TRAIN_SCRIPT=/fsx/ubuntu/train.py + +############################ +# Llama 2 Training Params ## +############################ + +declare -a TRAINING_ARGS=( + --config_name "facebook/esm2_t30_150M_UR50D" \ + --dataloader_num_workers 8 \ + --bf16 True \ + --do_eval True \ + --do_preprocess False \ + --do_train True \ + --gradient_accumulation_steps 16 \ + --logging_steps 16 \ + --num_train_epochs 1 \ + --output_dir "/fsx//ubuntu/output" \ + --per_device_train_batch_size 8 \ + --max_train_samples 100000 \ + --tokenizer_name "facebook/esm2_t30_150M_UR50D" \ + --dataset_dir "/fsx/ubuntu/processed/arrow/" \ + --torch_compile False \ + --pad_to_max_length True \ + --max_seq_length 512 +) + +AUTO_RESUME="" +if [ -d "/opt/sagemaker_cluster" ]; then + echo "Detected Hyperpod cluster.. enabling --auto-resume=1" + AUTO_RESUME="--auto-resume=1" +fi + +srun ${AUTO_RESUME} -l torchrun "${TORCHRUN_ARGS[@]}" $TRAIN_SCRIPT "${TRAINING_ARGS[@]}" \ No newline at end of file From 82d2e4e16c92ff5120ce7fa2ede32a06dc07cdc6 Mon Sep 17 00:00:00 2001 From: Ankur Srivastava Date: Wed, 24 Jul 2024 23:27:39 -0700 Subject: [PATCH 3/8] Added ESM2 training on SMHP Signed-off-by: Ankur Srivastava --- .../{download_data.py => 0.download_data.py} | 0 ...uniref_csv.py => 1.tokenize_uniref_csv.py} | 0 .../{submit_train_g5.sh => 2.train_ddp.sh} | 9 +-- 3.test_cases/23.SMHP-esm2/3.train_fsdp.sh | 72 +++++++++++++++++++ 3.test_cases/23.SMHP-esm2/README.md | 26 +++++-- 3.test_cases/23.SMHP-esm2/requirements.txt | 14 ++-- 6 files changed, 104 insertions(+), 17 deletions(-) rename 3.test_cases/23.SMHP-esm2/{download_data.py => 0.download_data.py} (100%) rename 3.test_cases/23.SMHP-esm2/{tokenize_uniref_csv.py => 1.tokenize_uniref_csv.py} (100%) rename 3.test_cases/23.SMHP-esm2/{submit_train_g5.sh => 2.train_ddp.sh} (87%) create mode 100644 3.test_cases/23.SMHP-esm2/3.train_fsdp.sh diff --git a/3.test_cases/23.SMHP-esm2/download_data.py b/3.test_cases/23.SMHP-esm2/0.download_data.py similarity index 100% rename from 3.test_cases/23.SMHP-esm2/download_data.py rename to 3.test_cases/23.SMHP-esm2/0.download_data.py diff --git a/3.test_cases/23.SMHP-esm2/tokenize_uniref_csv.py b/3.test_cases/23.SMHP-esm2/1.tokenize_uniref_csv.py similarity index 100% rename from 3.test_cases/23.SMHP-esm2/tokenize_uniref_csv.py rename to 3.test_cases/23.SMHP-esm2/1.tokenize_uniref_csv.py diff --git a/3.test_cases/23.SMHP-esm2/submit_train_g5.sh b/3.test_cases/23.SMHP-esm2/2.train_ddp.sh similarity index 87% rename from 3.test_cases/23.SMHP-esm2/submit_train_g5.sh rename to 3.test_cases/23.SMHP-esm2/2.train_ddp.sh index 636155d5..96713ceb 100644 --- a/3.test_cases/23.SMHP-esm2/submit_train_g5.sh +++ b/3.test_cases/23.SMHP-esm2/2.train_ddp.sh @@ -3,7 +3,7 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: MIT-0 -#SBATCH --nodes=4 # number of nodes to use +#SBATCH --nodes=2 # number of nodes to use #SBATCH --job-name=FSDP # name of your job #SBATCH --exclusive # job has exclusive use of the resource, no sharing @@ -23,16 +23,13 @@ GPUS_PER_NODE=4 # 4 for G5.12x, 8 for P4/P5 ## Comment out for non-efa instances (G4d, P3) ## For G5.12x, Comment out RDMA and Fork safe ## For G4dn and other G5, comment out all + ## export FI_EFA_USE_DEVICE_RDMA=1 # use for p4d export FI_EFA_FORK_SAFE=1 export FI_LOG_LEVEL=1 export FI_PROVIDER=efa export NCCL_DEBUG=INFO -## Switching SYNC_MEMOPS to zero can boost throughput with FSDP -## Disables CU_POINTER_ATTRIBUTE_SYNC_MEMOPS -## Reduces memory synchronizations -## https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__UNIFIED.html -## export FI_EFA_SET_CUDA_SYNC_MEMOPS=0 + ########################### ####### Torch Dist ####### diff --git a/3.test_cases/23.SMHP-esm2/3.train_fsdp.sh b/3.test_cases/23.SMHP-esm2/3.train_fsdp.sh new file mode 100644 index 00000000..68347617 --- /dev/null +++ b/3.test_cases/23.SMHP-esm2/3.train_fsdp.sh @@ -0,0 +1,72 @@ +#!/bin/bash + +#SBATCH --job-name=esm2-accelerate +#SBATCH -D . +#SBATCH --output=accelerate-%x.%j.out +#SBATCH --nodes=2 # number of nodes +#SBATCH --ntasks-per-node=1 # number of MP tasks + + +###################### +### Set enviroment ### +###################### +source /fsx/ubuntu/miniconda3/bin/activate +conda activate esm2 + +export GPUS_PER_NODE=4 +###################### + +## Plenty of EFA level variables +## export FI_EFA_USE_DEVICE_RDMA=1 # use for p4d +export FI_EFA_FORK_SAFE=1 +export FI_LOG_LEVEL=1 +export FI_PROVIDER=efa +export NCCL_DEBUG=INFO + +###################### +#### Set network ##### +###################### +head_node_ip=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +###################### + +export LAUNCHER="accelerate launch \ + --num_processes $((SLURM_NNODES * GPUS_PER_NODE)) \ + --num_machines $SLURM_NNODES \ + --rdzv_backend c10d \ + --main_process_ip $head_node_ip \ + --main_process_port 29500 \ + --machine_rank $SLURM_PROCID \ + --use_fsdp \ + --fsdp_sharding_strategy FULL_SHARD \ + --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP \ + --fsdp_transformer_layer_cls_to_wrap EsmLayer + --fsdp_backward_prefetch BACKWARD_PRE \ + --fsdp_cpu_ram_efficient_loading True \ + --fsdp_sync_module_states True \ + --fsdp_use_orig_params True \ + " + +export TRAIN_SCRIPT="/fsx/ubuntu/train.py" +export TRAIN_SCRIPT_ARGS=" \ + --config_name "facebook/esm2_t30_150M_UR50D" \ + --dataloader_num_workers 8 \ + --bf16 True \ + --do_eval True \ + --do_preprocess False \ + --do_train True \ + --gradient_accumulation_steps 16 \ + --logging_steps 16 \ + --num_train_epochs 1 \ + --output_dir "/fsx//ubuntu/output" \ + --per_device_train_batch_size 8 \ + --max_train_samples 100000 \ + --tokenizer_name "facebook/esm2_t30_150M_UR50D" \ + --dataset_dir "/fsx/ubuntu/processed/arrow/" \ + --torch_compile False \ + --pad_to_max_length True \ + --max_seq_length 512 + " + +# This step is necessary because accelerate launch does not handle multiline arguments properly +export CMD="$LAUNCHER $TRAIN_SCRIPT $TRAIN_SCRIPT_ARGS" +srun $CMD \ No newline at end of file diff --git a/3.test_cases/23.SMHP-esm2/README.md b/3.test_cases/23.SMHP-esm2/README.md index fd21b0ca..03e97b27 100644 --- a/3.test_cases/23.SMHP-esm2/README.md +++ b/3.test_cases/23.SMHP-esm2/README.md @@ -39,7 +39,7 @@ You can create conda environment as follows: Next we need to download the Uniref50 training data. You can do so by running: ```bash -python3 download_data.py +python3 0.download_data.py ``` It would download the data and partitions the data in 50 .csv files in `/fsx/ubuntu/csv` folder. The whole process should take less than 30 mins. @@ -68,7 +68,7 @@ Reading FASTA file Next we need to tokenize the dataset. This will split the data in training, test and validation folders, tokenize them and save the arrow files in `processed` folder. ```bash -(esm) (CONTROLLER) ubuntu@ip-10-1-71-160:~$ python3 tokenize_uniref_csv.py +(esm) (CONTROLLER) ubuntu@ip-10-1-71-160:~$ python3 1.tokenize_uniref_csv.py 07/03/2024 23:07:49 - INFO - Parsing arguments 07/03/2024 23:07:49 - INFO - Loading csv files from /fsx/ubuntu/csv Resolving data files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:00<00:00, 356272.93it/s] @@ -102,10 +102,10 @@ Saving the dataset (1/1 shards): 100%|██████████████ ## 5. Submit training job -Once data is processed, we are ready to train the ESM2 model. +Once data is processed, we are ready to train the ESM2 model. To run distributed data parallel (DDP) training, we provide the `train_ddp.sh` script which you can submt as below and training should start: ``` -sbatch submit_train_g5.sh +sbatch train_ddp.sh ``` ``` @@ -138,7 +138,25 @@ sbatch submit_train_g5.sh ``` +### 5.1 Accelerate training with torch.compile +HuggingFace provides an easy to use [Trainer](https://huggingface.co/docs/transformers/en/main_classes/trainer) class that also provides an option to compile the model graph. For more details on torch.compile, follow this [blog](https://pytorch.org/blog/maximizing-training-throughput/) We notice a speedup of 43% when pre-training ESM2 with torch.compile: + +| Model | device_batch_size | num_nodes | torch.compile | Instance | Throughput | +|:------:|:-----------------:|:---------:|:-------------:| :------------: | :------------: | +| ESM2 | 8 | 2 | No | g5.12xlarge | 160 samples/s | +| ESM2 | 8 | 2 | Yes | g5.12xlarge | 229 samples/s | + + +### 5.2 Train larger models with Fully Sharded Data Parallel (FSDP) + +A disadvantage of DDP training is that it requires the entire model to fit in the GPU memory. The ESM2 150M parameter model easily fits in the A10G GPU of g5.12xlarge instances. However, as models get bigger that may not be possible. For such situations, [PyTorch FSDP](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) api is an effective way to shard model parameters which includes optimizer states, gradients and model parameters. However, as model is sharded more, the communication burden between the GPUs also increases. So the best practice is to use FSDP only when DDP is not sufficient. + +We use the [HuggingFace Accelerate](https://github.com/huggingface/accelerate) repo to setup FSDP with ESM2 on a slurm multinode cluster. To this end, we provide the `train_fsd.sh` script that you can submit as below: + +```bash +sbatch train_fsdp.sh +``` diff --git a/3.test_cases/23.SMHP-esm2/requirements.txt b/3.test_cases/23.SMHP-esm2/requirements.txt index abd537e6..d6aefff5 100644 --- a/3.test_cases/23.SMHP-esm2/requirements.txt +++ b/3.test_cases/23.SMHP-esm2/requirements.txt @@ -1,9 +1,9 @@ -accelerate +accelerate==0.32.1 datasets==2.20.0 -pyfastx -transformers -boto3 +pyfastx==2.1.0 +transformers==4.42.4 +boto3==1.34.144 huggingface_hub==0.23.4 -chardet -evaluate -scikit-learn \ No newline at end of file +chardet==5.2.0 +evaluate== 0.4.2 +scikit-learn==1.5.1 \ No newline at end of file From a2d77662e8b0988bbf7b071ea425a2273bc644ee Mon Sep 17 00:00:00 2001 From: Ankur Srivastava Date: Wed, 24 Jul 2024 23:31:20 -0700 Subject: [PATCH 4/8] Added ESM2 training on SMHP Signed-off-by: Ankur Srivastava --- 3.test_cases/23.SMHP-esm2/README.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/3.test_cases/23.SMHP-esm2/README.md b/3.test_cases/23.SMHP-esm2/README.md index 03e97b27..e53815ab 100644 --- a/3.test_cases/23.SMHP-esm2/README.md +++ b/3.test_cases/23.SMHP-esm2/README.md @@ -105,7 +105,7 @@ Saving the dataset (1/1 shards): 100%|██████████████ Once data is processed, we are ready to train the ESM2 model. To run distributed data parallel (DDP) training, we provide the `train_ddp.sh` script which you can submt as below and training should start: ``` -sbatch train_ddp.sh +sbatch 2.train_ddp.sh ``` ``` @@ -155,10 +155,13 @@ A disadvantage of DDP training is that it requires the entire model to fit in th We use the [HuggingFace Accelerate](https://github.com/huggingface/accelerate) repo to setup FSDP with ESM2 on a slurm multinode cluster. To this end, we provide the `train_fsd.sh` script that you can submit as below: ```bash -sbatch train_fsdp.sh +sbatch 3.train_fsdp.sh ``` - +| Model | device_batch_size | num_nodes | Strategy | Instance | Throughput | +|:------:|:-----------------:|:---------:|:--------:| :------------: | :------------: | +| ESM2 | 14 | 2 | DDP | g5.12xlarge | 253 samples/s | +| ESM2 | 14 | 2 | FSDP | g5.12xlarge | 162 samples/s | From d87b8624b7a085975fe9708f1fd66b35ed327c3e Mon Sep 17 00:00:00 2001 From: Keita Watanabe Date: Tue, 5 Nov 2024 11:07:17 +0900 Subject: [PATCH 5/8] Update 3.test_cases/23.SMHP-esm2/2.train_ddp.sh --- 3.test_cases/23.SMHP-esm2/2.train_ddp.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3.test_cases/23.SMHP-esm2/2.train_ddp.sh b/3.test_cases/23.SMHP-esm2/2.train_ddp.sh index 96713ceb..ad353391 100644 --- a/3.test_cases/23.SMHP-esm2/2.train_ddp.sh +++ b/3.test_cases/23.SMHP-esm2/2.train_ddp.sh @@ -4,7 +4,7 @@ # SPDX-License-Identifier: MIT-0 #SBATCH --nodes=2 # number of nodes to use -#SBATCH --job-name=FSDP # name of your job +#SBATCH --job-name=DDP # name of your job #SBATCH --exclusive # job has exclusive use of the resource, no sharing set -ex; From f5b054372d7929c692756eb43528dae779c78699 Mon Sep 17 00:00:00 2001 From: Keita Watanabe Date: Tue, 5 Nov 2024 11:07:30 +0900 Subject: [PATCH 6/8] Update 3.test_cases/23.SMHP-esm2/README.md --- 3.test_cases/23.SMHP-esm2/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3.test_cases/23.SMHP-esm2/README.md b/3.test_cases/23.SMHP-esm2/README.md index e53815ab..3f9f60ed 100644 --- a/3.test_cases/23.SMHP-esm2/README.md +++ b/3.test_cases/23.SMHP-esm2/README.md @@ -1,4 +1,4 @@ -# How to pretrain ESM2 with SageMaker Hyperpod using Amazon G5 instances +# How to finetune ESM2 with SageMaker Hyperpod using Amazon G5 instances ## What is SageMaker Hyperpod? [Amazon SageMaker Hyperpod](https://aws.amazon.com/sagemaker/hyperpod/) offers advanced training tools to help you accelerate scalable, reliable, and secure generative AI application development. It removes the undifferentiated heavy lifting involved in building and optimizing machine learning (ML) infrastructure for training foundation models (FMs) significantly reducing training time. SageMaker Hyperpod ensure customers can continue FM training uninterrupted by periodically saving checkpoints. When a hardware failure occurs during training, SageMaker Hyperpod automatically detects the failure, repairs, or replaces the faulty instance, and resumes the training from the last saved checkpoint, removing the need for customers to manually manage this process and helping them train for week or months in a distributed setting without disruption. From 30cb87928ac4bc653d534695c3d192dab0808c9a Mon Sep 17 00:00:00 2001 From: Keita Watanabe Date: Tue, 5 Nov 2024 11:07:37 +0900 Subject: [PATCH 7/8] Update 3.test_cases/23.SMHP-esm2/3.train_fsdp.sh --- 3.test_cases/23.SMHP-esm2/3.train_fsdp.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/3.test_cases/23.SMHP-esm2/3.train_fsdp.sh b/3.test_cases/23.SMHP-esm2/3.train_fsdp.sh index 68347617..2572be57 100644 --- a/3.test_cases/23.SMHP-esm2/3.train_fsdp.sh +++ b/3.test_cases/23.SMHP-esm2/3.train_fsdp.sh @@ -1,7 +1,6 @@ #!/bin/bash #SBATCH --job-name=esm2-accelerate -#SBATCH -D . #SBATCH --output=accelerate-%x.%j.out #SBATCH --nodes=2 # number of nodes #SBATCH --ntasks-per-node=1 # number of MP tasks From d63b2c6d4e3efff34738e3cb16ff8186d0718302 Mon Sep 17 00:00:00 2001 From: Keita Watanabe Date: Tue, 5 Nov 2024 11:07:47 +0900 Subject: [PATCH 8/8] Update 3.test_cases/23.SMHP-esm2/3.train_fsdp.sh --- 3.test_cases/23.SMHP-esm2/3.train_fsdp.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3.test_cases/23.SMHP-esm2/3.train_fsdp.sh b/3.test_cases/23.SMHP-esm2/3.train_fsdp.sh index 2572be57..65f731df 100644 --- a/3.test_cases/23.SMHP-esm2/3.train_fsdp.sh +++ b/3.test_cases/23.SMHP-esm2/3.train_fsdp.sh @@ -3,7 +3,7 @@ #SBATCH --job-name=esm2-accelerate #SBATCH --output=accelerate-%x.%j.out #SBATCH --nodes=2 # number of nodes -#SBATCH --ntasks-per-node=1 # number of MP tasks +#SBATCH --exclusive # job has exclusive use of the resource, no sharing ######################