
Commit b7e1958

Merge branch 'main' into lazy-export

2 parents: 000202a + 5d97b70

File tree: 24 files changed, +310 −60 lines

.github/workflows/_test_template.yml

Lines changed: 11 additions & 2 deletions
@@ -60,7 +60,16 @@ jobs:
           ARG=("--runtime=nvidia --gpus all")
         fi

-        docker run --rm -d --name nemo_container_${{ github.run_id }} ${ARG[@]} --shm-size=64g --env TRANSFORMERS_OFFLINE=0 --env HYDRA_FULL_ERROR=1 --volume /mnt/datadrive/TestData:/home/TestData nemoci.azurecr.io/nemo_container:${{ github.run_id }} bash -c "sleep $(( ${{ inputs.TIMEOUT }} * 60 + 60 ))"
+        docker run \
+          --rm \
+          -d \
+          --name nemo_container_${{ github.run_id }} ${ARG[@]} \
+          --shm-size=64g \
+          --env TRANSFORMERS_OFFLINE=0 \
+          --env HYDRA_FULL_ERROR=1 \
+          --env HF_HOME=/home/TestData/HF_HOME \
+          --volume /mnt/datadrive/TestData:/home/TestData nemoci.azurecr.io/nemo_container:${{ github.run_id }} \
+          bash -c "sleep $(( ${{ inputs.TIMEOUT }} * 60 + 60 ))"

     - id: main
       name: Run main script
@@ -95,4 +104,4 @@ jobs:
       if: always()
       run: |
         docker container stop nemo_container_${{ github.run_id }} || true
-        docker container rm nemo_container_${{ github.run_id }} || true
+        docker container rm nemo_container_${{ github.run_id }} || true
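
The new HF_HOME entry points the Hugging Face cache at the mounted TestData volume, so model files pulled inside the CI container land on persistent storage. A minimal Python sketch of the same effect outside the workflow, assuming the /home/TestData/HF_HOME path from the hunk above:

import os

# Assumed path taken from the workflow change above; set it before importing
# transformers so the hub cache resolves under the mounted volume instead of
# the default ~/.cache/huggingface location.
os.environ.setdefault("HF_HOME", "/home/TestData/HF_HOME")

from transformers import AutoTokenizer  # noqa: E402  (import after the env var is set)

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # downloaded files land under $HF_HOME/hub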

examples/llm/peft/hf.py

Lines changed: 3 additions & 3 deletions
@@ -76,11 +76,11 @@ def formatting_prompts_func(examples):
     # See: https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/plugins/precision/fsdp.py#L81
     grad_clip = None
     use_dist_samp = False
-    tokenizer = llm.HfAutoModelForCausalLM.configure_tokenizer(args.model)
+    tokenizer = llm.HFAutoModelForCausalLM.configure_tokenizer(args.model)

     llm.api.finetune(
-        model=llm.HfAutoModelForCausalLM(args.model),
-        data=llm.HfDatasetDataModule(
+        model=llm.HFAutoModelForCausalLM(args.model),
+        data=llm.HFDatasetDataModule(
             mk_hf_dataset(tokenizer.tokenizer), pad_token_id=tokenizer.tokenizer.eos_token_id
         ),
         trainer=nl.Trainer(
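
The change is a drop-in rename: the "Hf" prefix becomes "HF" and the constructor signatures stay the same. A hedged sketch of updated call sites (the model id and dataset name here are illustrative placeholders, not taken from the diff):

from nemo.collections import llm

tokenizer = llm.HFAutoModelForCausalLM.configure_tokenizer("gpt2")   # was llm.HfAutoModelForCausalLM
model = llm.HFAutoModelForCausalLM("gpt2")
data = llm.HFDatasetDataModule(                                      # was llm.HfDatasetDataModule
    "rajpurkar/squad", pad_token_id=tokenizer.tokenizer.eos_token_id
)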

examples/llm/sft/hf.py

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ def squad(tokenizer) -> pl.LightningDataModule:

     from nemo.lightning.pytorch.accelerate.transformer_engine import te_accelerate

-    model = llm.HfAutoModelForCausalLM(model_name=args.model, model_accelerator=model_accelerator)
+    model = llm.HFAutoModelForCausalLM(model_name=args.model, model_accelerator=model_accelerator)
     tokenizer = model.tokenizer

     llm.api.finetune(

nemo/collections/llm/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -22,7 +22,7 @@
     AlpacaDataModule,
     DollyDataModule,
     FineTuningDataModule,
-    HfDatasetDataModule,
+    HFDatasetDataModule,
     MockDataModule,
     PreTrainingDataModule,
     SquadDataModule,
@@ -64,7 +64,7 @@
     GPTConfig126M,
     GPTConfig175B,
     GPTModel,
-    HfAutoModelForCausalLM,
+    HFAutoModelForCausalLM,
     Llama2Config7B,
     Llama2Config13B,
     Llama2Config70B,
@@ -218,7 +218,7 @@
     "dolly",
     "peft",
     "hf_dataset",
-    "HfAutoModelForCausalLM",
+    "HFAutoModelForCausalLM",
 ]
nemo/collections/llm/gpt/data/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 from nemo.collections.llm.gpt.data.alpaca import AlpacaDataModule
 from nemo.collections.llm.gpt.data.dolly import DollyDataModule
 from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule
-from nemo.collections.llm.gpt.data.hf_dataset import HfDatasetDataModule
+from nemo.collections.llm.gpt.data.hf_dataset import HFDatasetDataModule
 from nemo.collections.llm.gpt.data.mock import MockDataModule
 from nemo.collections.llm.gpt.data.pre_training import PreTrainingDataModule, build_pretraining_datamodule
 from nemo.collections.llm.gpt.data.squad import SquadDataModule
@@ -28,5 +28,5 @@
     "MockDataModule",
     "PreTrainingDataModule",
     "build_pretraining_datamodule",
-    "HfDatasetDataModule",
+    "HFDatasetDataModule",
 ]

nemo/collections/llm/gpt/data/api.py

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@
 import nemo_run as run

 from nemo.collections.llm.gpt.data.dolly import DollyDataModule
-from nemo.collections.llm.gpt.data.hf_dataset import HfDatasetDataModule
+from nemo.collections.llm.gpt.data.hf_dataset import HFDatasetDataModule
 from nemo.collections.llm.gpt.data.mock import MockDataModule
 from nemo.collections.llm.gpt.data.squad import SquadDataModule

@@ -42,7 +42,7 @@ def dolly() -> pl.LightningDataModule:
 @run.cli.factory
 @run.autoconvert
 def hf_dataset(dataset: str) -> pl.LightningDataModule:
-    return HfDatasetDataModule(dataset=dataset, global_batch_size=16, micro_batch_size=2)
+    return HFDatasetDataModule(dataset=dataset, global_batch_size=16, micro_batch_size=2)


 __all__ = ["mock", "squad", "dolly", "hf_dataset"]

nemo/collections/llm/gpt/data/hf_dataset.py

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,7 @@
 from nemo.lightning.pytorch.plugins import MegatronDataSampler


-class HfDatasetDataModule(pl.LightningDataModule):
+class HFDatasetDataModule(pl.LightningDataModule):
     def __init__(
         self,
         dataset,
@@ -88,7 +88,7 @@ def train_dataloader(self, collate_fn=None):
         from nemo.lightning.data import add_megatron_sampler

         if collate_fn is None:
-            collate_fn = lambda x: HfDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id)
+            collate_fn = lambda x: HFDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id)

         return DataLoader(
             self.dataset,
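
A hedged sketch of the override point shown above: when no collate_fn is passed, train_dataloader() pads batches through the class-level HFDatasetDataModule.collate_fn using the configured pad_token_id, and a caller may supply its own collate function instead. The constructor arguments below are illustrative, not taken from the diff:

from nemo.collections.llm.gpt.data.hf_dataset import HFDatasetDataModule

# Illustrative arguments: a Hugging Face dataset id and a placeholder pad id.
datamodule = HFDatasetDataModule("rajpurkar/squad", pad_token_id=0)

default_loader = datamodule.train_dataloader()                                # pads via HFDatasetDataModule.collate_fn
custom_loader = datamodule.train_dataloader(collate_fn=lambda batch: batch)   # bypasses the default padding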

nemo/collections/llm/gpt/model/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -45,7 +45,7 @@
     Gemma2Config27B,
     Gemma2Model,
 )
-from nemo.collections.llm.gpt.model.hf_auto_model_for_causal_lm import HfAutoModelForCausalLM
+from nemo.collections.llm.gpt.model.hf_auto_model_for_causal_lm import HFAutoModelForCausalLM
 from nemo.collections.llm.gpt.model.llama import (
     CodeLlamaConfig7B,
     CodeLlamaConfig13B,
@@ -191,5 +191,5 @@
     "transformer_engine_layer_spec",
     "transformer_engine_full_layer_spec",
     "local_layer_spec",
-    "HfAutoModelForCausalLM",
+    "HFAutoModelForCausalLM",
 ]

nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py

Lines changed: 2 additions & 2 deletions
@@ -31,7 +31,7 @@ def masked_cross_entropy(logits, targets, mask=None):
         return F.cross_entropy(logits, targets)


-class HfAutoModelForCausalLM(pl.LightningModule, io.IOMixin, fn.FNMixin):
+class HFAutoModelForCausalLM(pl.LightningModule, io.IOMixin, fn.FNMixin):
     def __init__(
         self,
         model_name='gpt2',
@@ -57,7 +57,7 @@ def __init__(
     @property
     def tokenizer(self):
         if self._tokenizer is None:
-            self._tokenizer = HfAutoModelForCausalLM.configure_tokenizer(self.model_name, self.trust_remote_code)
+            self._tokenizer = HFAutoModelForCausalLM.configure_tokenizer(self.model_name, self.trust_remote_code)
         return self._tokenizer

     @tokenizer.setter
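
The tokenizer property above is lazy: nothing is built until the first access, and the result is cached on the instance. A minimal sketch of that behaviour using only the renamed class from this commit:

from nemo.collections.llm import HFAutoModelForCausalLM

model = HFAutoModelForCausalLM(model_name="gpt2")
tok = model.tokenizer            # first access calls HFAutoModelForCausalLM.configure_tokenizer("gpt2", ...)
assert tok is model.tokenizer    # later accesses return the cached _tokenizer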

nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py

Lines changed: 10 additions & 10 deletions
@@ -23,7 +23,7 @@
 from nemo import lightning as nl
 from nemo.collections.llm.api import finetune, pretrain
 from nemo.collections.llm.gpt.data.mock import MockDataModule
-from nemo.collections.llm.gpt.model.hf_auto_model_for_causal_lm import HfAutoModelForCausalLM
+from nemo.collections.llm.gpt.model.hf_auto_model_for_causal_lm import HFAutoModelForCausalLM
 from nemo.collections.llm.peft.lora import LoRA
 from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
 from nemo.collections.llm.recipes.optim.adam import pytorch_adam_with_cosine_annealing
@@ -35,23 +35,23 @@
 @run.cli.factory(name=NAME)
 def model(model_name, load_pretrained_weights) -> run.Config[pl.LightningModule]:
     """
-    Factory function to create HfAutoModelForCausalLM model configurations.
+    Factory function to create HFAutoModelForCausalLM model configurations.

     Args:
         model_name (str): Model id on HF.

     Returns:
-        run.Config[pl.LightningModule]: Configuration for the HfAutoModelForCausalLM.
+        run.Config[pl.LightningModule]: Configuration for the HFAutoModelForCausalLM.

     Examples:
         CLI usage:
-            $ nemo llm pretrain --factory 'HfAutoModelForCausalLM(model_name="mistralai/Mistral-Nemo-Instruct-2407")'
+            $ nemo llm pretrain --factory 'HFAutoModelForCausalLM(model_name="mistralai/Mistral-Nemo-Instruct-2407")'

         Python API usage:
             >>> model_config = model(model_name="mistralai/Mistral-Nemo-Instruct-2407")
             >>> print(model_config)
     """
-    return run.Config(HfAutoModelForCausalLM, model_name=model_name, load_pretrained_weights=load_pretrained_weights)
+    return run.Config(HFAutoModelForCausalLM, model_name=model_name, load_pretrained_weights=load_pretrained_weights)


 def trainer(
@@ -69,7 +69,7 @@ def trainer(
     gradient_clip_val: float = 1.0,
 ) -> run.Config[nl.Trainer]:
     """
-    Configure the NeMo Lightning Trainer for HfAutoModelForCausalLM.
+    Configure the NeMo Lightning Trainer for HFAutoModelForCausalLM.

     This function sets up the distributed training strategy and other training parameters.

@@ -91,7 +91,7 @@ def trainer(

     Examples:
         CLI usage:
-            $ nemo llm pretrain trainer=HfAutoModelForCausalLM ...
+            $ nemo llm pretrain trainer=HFAutoModelForCausalLM ...

         Python API usage:
             >>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8)
@@ -131,7 +131,7 @@ def pretrain_recipe(
     model_name: str = '',
 ) -> run.Partial:
     """
-    Create a pre-training recipe for a HfAutoModelForCausalLM model.
+    Create a pre-training recipe for a HFAutoModelForCausalLM model.

     This function sets up a complete configuration for pre-training, including
     model, trainer, data, logging, optimization, and resumption settings.
@@ -148,7 +148,7 @@ def pretrain_recipe(

     Examples:
         CLI usage:
-            $ nemo llm pretrain --factory 'HfAutoModelForCausalLM(model_name="mistralai/Mistral-Nemo-Instruct-2407")'
+            $ nemo llm pretrain --factory 'HFAutoModelForCausalLM(model_name="mistralai/Mistral-Nemo-Instruct-2407")'

         Python API usage:
             >>> recipe = pretrain_recipe(name="auto_pretrain", num_nodes=2, model_name="mistralai/Mistral-Nemo-Instruct-2407")
@@ -179,7 +179,7 @@ def finetune_recipe(
     model_name: str = '',
 ) -> run.Partial:
     """
-    Create a fine-tuning recipe for a HfAutoModelForCausalLM model.
+    Create a fine-tuning recipe for a HFAutoModelForCausalLM model.

     This function sets up a complete configuration for fine-tuning, including
     model, trainer, data, logging, optimization, and resumption settings.

nemo/export/tensorrt_llm.py

Lines changed: 3 additions & 0 deletions
@@ -480,10 +480,13 @@ def export(

         tokenizer_path = os.path.join(nemo_export_dir, "tokenizer.model")
         tokenizer_path_nemo2 = os.path.join(nemo_export_dir, "nemo_context")
+        vocab_path = os.path.join(nemo_export_dir, "vocab.json")
         if os.path.exists(tokenizer_path):
             shutil.copy(tokenizer_path, self.model_dir)
         elif os.path.exists(tokenizer_path_nemo2):
             shutil.copytree(tokenizer_path_nemo2, Path(self.model_dir) / "nemo_context")
+        elif os.path.exists(vocab_path):
+            shutil.copy(vocab_path, os.path.join(self.model_dir, "vocab.json"))
         else:
             self.tokenizer.save_pretrained(os.path.join(self.model_dir, 'huggingface_tokenizer'))
nemo/export/tiktoken_tokenizer.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import json
+from pathlib import Path
+from typing import Dict, Optional
+
+import numpy as np
+import tiktoken
+import torch
+
+PATTERN_TIKTOKEN = "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+DEFAULT_TIKTOKEN_MAX_VOCAB = 2**17  # 131072
+SPECIAL_TOKENS = ["<unk>", "<s>", "</s>"]
+SPECIAL_TOKEN_TEMPLATE = "<SPECIAL_{id}>"
+
+
+def reload_mergeable_ranks(
+    path: str,
+    max_vocab: Optional[int] = None,
+) -> Dict[bytes, int]:
+    """
+    Reload the tokenizer JSON file and convert it to Tiktoken format.
+    """
+    assert path.endswith(".json")
+
+    # reload vocab
+    with open(path, "r", encoding='utf-8') as f:
+        vocab = json.load(f)
+    assert isinstance(vocab, list)
+    print(f"Vocab size: {len(vocab)}")
+    if max_vocab is not None:
+        vocab = vocab[:max_vocab]
+        print(f"Cutting vocab to first {len(vocab)} tokens.")
+
+    # build ranks
+    ranks: Dict[bytes, int] = {}
+    for i, x in enumerate(vocab):
+        assert x.keys() == {"rank", "token_bytes", "token_str"}
+        assert x["rank"] == i
+        merge = base64.b64decode(x["token_bytes"])
+        assert i >= 256 or merge == bytes([i])
+        ranks[merge] = x["rank"]
+
+    # sanity check
+    assert len(ranks) == len(vocab)
+    assert set(ranks.values()) == set(range(len(ranks)))
+
+    return ranks
+
+
+class TiktokenTokenizer:
+    def __init__(self, vocab_file: str):
+
+        self.num_special_tokens = 1000
+        vocab_size = DEFAULT_TIKTOKEN_MAX_VOCAB
+        pattern = PATTERN_TIKTOKEN
+        special_tokens = SPECIAL_TOKENS.copy()
+        inner_vocab_size = vocab_size - self.num_special_tokens
+
+        token2id = reload_mergeable_ranks(vocab_file, max_vocab=inner_vocab_size)
+        self.tokenizer = tiktoken.Encoding(
+            name=Path(vocab_file).parent.name,
+            pat_str=pattern,
+            mergeable_ranks=token2id,
+            special_tokens={},  # special tokens are handled manually
+        )
+
+        # BOS / EOS / Pad token IDs
+        self._bos_id = special_tokens.index("<s>")
+        self._eos_id = special_tokens.index("</s>")
+
+    def encode(self, text):
+        tokens = self.tokenizer.encode(text)
+        tokens = [t + self.num_special_tokens for t in tokens]
+        return tokens
+
+    def decode(self, tokens):
+        # Filter out special tokens and adjust the remaining tokens
+        adjusted_tokens = [
+            t - self.num_special_tokens
+            for t in tokens
+            if t not in {self._bos_id, self._eos_id} and t >= self.num_special_tokens
+        ]
+
+        # Decode only if there are tokens left after filtering
+        if adjusted_tokens:
+            return self.tokenizer.decode(adjusted_tokens)
+        else:
+            return ""  # Return an empty string if all tokens were filtered out
+
+    def batch_decode(self, ids):
+        if isinstance(ids, np.ndarray) or torch.is_tensor(ids):
+            ids = ids.tolist()
+
+        if isinstance(ids[0], list):
+            ids = ids[0]
+
+        return self.decode(ids)
+
+    @property
+    def pad_id(self):
+        return self._eos_id
+
+    @property
+    def bos_token_id(self):
+        return self._bos_id
+
+    @property
+    def eos_token_id(self):
+        return self._eos_id
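
A hedged usage sketch for the new class, assuming a tiktoken-style vocab.json such as the one the export step above copies into the engine directory (the path below is a placeholder):

from nemo.export.tiktoken_tokenizer import TiktokenTokenizer

tokenizer = TiktokenTokenizer("/path/to/model_dir/vocab.json")  # placeholder path

ids = tokenizer.encode("Hello, world!")   # ids are offset by num_special_tokens (1000)
text = tokenizer.decode(ids)              # BOS/EOS and other special ids are filtered before decoding
print(text)

# SPECIAL_TOKENS = ["<unk>", "<s>", "</s>"], so:
assert tokenizer.bos_token_id == 1
assert tokenizer.eos_token_id == 2
assert tokenizer.pad_id == tokenizer.eos_token_id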
