feat(trainer): add the dummy_train_loop function to check whether the current configuration can run properly #189

Open · wants to merge 2 commits into main
80 changes: 69 additions & 11 deletions collie/controller/trainer.py
@@ -7,6 +7,7 @@
import json
import logging
import os
import random
from collections import OrderedDict
from functools import reduce
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
@@ -26,7 +27,7 @@
from collie.callbacks.callback import Callback
from collie.callbacks.callback_manager import CallbackManager, prepare_callback
from collie.config import CollieConfig
from collie.data import CollieDataLoader
from collie.data import CollieDataLoader, CollieDatasetForTraining
from collie.driver.io import IODriver
from collie.log import logger
from collie.models.base import CollieModelForCausalLM
@@ -316,6 +317,52 @@ def setup_parallel_model(self):
        )
        deepspeed.utils.logging.logger.setLevel(deepspeed_logging_level)

    def dummy_train_loop(self, max_length: Optional[int] = None):
        r"""
        Build two batches of dummy data from the user-configured batch size
        and ``max_length``, then run a two-epoch trial to check that the
        current configuration can train without errors. If the dataset does
        not define ``max_length``, it must be passed in manually so the test
        samples can be constructed.
        """
        if not max_length:
            max_length = self.train_dataset.max_length
        if max_length > 0:
            batch_size = self.config.train_micro_batch_size
            dataset = []
            # Two batches of random token ids; labels are all -100 so no
            # position contributes to the loss.
            for _ in range(2):
                batch = []
                for _ in range(batch_size):
                    tokens = torch.randint(500, 10000, (1, max_length))[0]
                    labels = [-100] * max_length
                    sample = {
                        "tokens": tokens,
                        "labels": labels,
                    }
                    batch.append(sample)
                dataset += batch
            dataset = CollieDatasetForTraining(dataset)
            dataloader = CollieDataLoader(
                dataset,
                batch_size,
                self.config.gradient_accumulation_steps,
                shuffle=True,
                collate_fn=self.train_dataset_collate_fn,
                drop_last=False,
                num_workers=self.config.dataloader_num_workers,
            )
            for epoch in range(2):
                for batch_idx, batch in enumerate(dataloader):
                    try:
                        loss = self.train_fn(self, batch, self.global_batch_idx)
                    except RuntimeError as e:
                        # Catch CUDA OOM so the trial reports a readable
                        # diagnosis instead of crashing the process.
                        if "out of memory" in str(e):
                            logger.error(
                                f"OOM error occurred at epoch {epoch}, "
                                f"batch index {batch_idx}; please reduce "
                                f"batch_size or max_length."
                            )
                            torch.cuda.empty_cache()
                        else:
                            raise
            logger.info("dummy_train_loop finished successfully.")
        else:
            logger.error(
                "max_length is not set in the dataset and must be passed "
                "in manually for testing."
            )
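
For orientation, a minimal usage sketch of the new check follows; the Trainer construction is abbreviated, and the model id, dataset, and argument values are illustrative assumptions rather than part of this diff:

# Hypothetical usage sketch (not part of this PR): probe the configuration
# with synthetic batches before committing to a full training run.
from collie import CollieConfig, Trainer  # assumed top-level exports

config = CollieConfig.from_pretrained("huggyllama/llama-7b")  # assumed model id
config.train_micro_batch_size = 4
config.gradient_accumulation_steps = 2

trainer = Trainer(model=model, config=config, train_dataset=train_dataset)
# Pass max_length explicitly if the dataset does not define one; an OOM
# here means batch_size or max_length should be reduced before train().
trainer.dummy_train_loop(max_length=2048)
trainer.train()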

    def train(self, dataloader: Optional[Iterable] = None):
        """Training loop

@@ -691,40 +738,51 @@ def save_model(
        except Exception as e:
            logger.rank_zero_warning("Save config and tokenizer failed")
            logger.rank_zero_warning(str(e))

        # If the model is wrapped by PEFT, save the adapter weights first,
        # then unwrap to the base model for the full save below.
        if isinstance(self.model, PeftModel):
            self.save_peft(
                path=path,
                protocol=protocol,
                process_exclusion=process_exclusion,
                **kwargs,
            )
            model_to_save = self.engine.module.get_base_model()
        else:
            model_to_save = self.engine.module

        if isinstance(model_to_save, CollieModelForCausalLM) or isinstance(
            model_to_save, PipelineModel
        ):
            if is_zero3_enabled(self.config):
                state_dict = {}
                self._checkpoint_prologue()
                # Gather ZeRO-3 partitioned parameters one by one and copy
                # them to CPU on data-parallel rank 0.
                for name, param in model_to_save.named_parameters():
                    with deepspeed.zero.GatheredParameters(param):
                        if env.dp_rank == 0:
                            state_dict[name] = param.detach().cpu()
                self._checkpoint_epilogue()
            else:
                if env.dp_rank == 0:
                    state_dict = model_to_save.state_dict()
                else:
                    state_dict = {}
            model_to_save.save_parallel_state_dict(
                state_dict=state_dict,
                path=path,
                config=self.config,
                process_exclusion=process_exclusion,
                protocol=protocol,
            )
        elif isinstance(model_to_save, PreTrainedModel):
            if is_zero3_enabled(self.config):
                self._checkpoint_prologue()
                with deepspeed.zero.GatheredParameters(
                    list(model_to_save.parameters(recurse=True))
                ):
                    model_to_save.save_pretrained(save_directory=path, **kwargs)
                self._checkpoint_epilogue()
            else:
                model_to_save.save_pretrained(save_directory=path, **kwargs)
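
Since save_model now unwraps PEFT-wrapped models, a LoRA run can save through the same entry point. A hedged sketch of the intended call pattern, with the base model and config abbreviated and all names and values illustrative:

# Hypothetical sketch (not part of this PR): saving a LoRA-wrapped model.
from peft import LoraConfig, get_peft_model

peft_config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
model = get_peft_model(base_model, peft_config)  # base_model defined elsewhere

trainer = Trainer(model=model, config=config, train_dataset=train_dataset)
trainer.train()
# save_model() first writes the adapter via save_peft(), then unwraps with
# get_base_model() and saves the base model's full state dict to `path`.
trainer.save_model(path="./checkpoint")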

    def load_model(self, path: str, process_exclusion: bool = False, **kwargs):
        ...