def dummy_train_loop(self, max_length: Optional[int] = None):
    r"""
    Dry-run the training step on synthetic data to surface OOM errors early.

    Two batches of random token ids are built from the configured
    ``train_micro_batch_size`` and ``max_length`` and fed through
    ``self.train_fn`` for two epochs.

    :param max_length: Sequence length of every synthetic sample. When
        omitted (or falsy), ``self.train_dataset.max_length`` is used; if
        that is missing or not positive as well, an error is logged and
        nothing is run, so the caller has to pass the value explicitly.
    :raises RuntimeError: Any runtime error raised by ``self.train_fn`` that
        is not an out-of-memory error is re-raised unchanged.
    """
    if not max_length:
        # Fall back to the dataset's setting; a missing dataset or attribute
        # is treated the same as "not configured" instead of crashing.
        max_length = getattr(self.train_dataset, "max_length", None)
    if max_length is None or max_length <= 0:
        logger.error('The max_length value is not set in the dataset and needs to be manually passed in for testing.')
        return

    batch_size = self.config.train_micro_batch_size
    # Two full micro-batches of random token ids. Every label is -100.
    # NOTE(review): -100 is presumably the loss ignore index, so no position
    # contributes to the loss - confirm this is intended.
    samples = [
        {
            "tokens": torch.randint(500, 10000, (max_length,)),
            "labels": [-100] * max_length,
        }
        for _ in range(2 * batch_size)
    ]
    dataloader = CollieDataLoader(
        CollieDatasetForTraining(samples),
        batch_size,
        self.config.gradient_accumulation_steps,
        shuffle=True,
        collate_fn=self.train_dataset_collate_fn,
        drop_last=False,
        num_workers=self.config.dataloader_num_workers,
    )

    for epoch in range(2):
        for batch_idx, batch in enumerate(dataloader):
            try:
                self.train_fn(self, batch, self.global_batch_idx)
            except RuntimeError as e:
                if 'out of memory' not in str(e):
                    # Not an OOM: propagate with the original traceback.
                    raise
                logger.error(f"OOM error occurred at epoch: {epoch} batch index: {batch_idx}")
                logger.error("Please reduce the batch size or max_length size")
                torch.cuda.empty_cache()
                # The dry run has failed; retrying identically sized batches
                # would only hit the same limit, and success must not be
                # reported afterwards.
                return
    logger.info('Success finish dummy_train_loop.')
list(model_to_save.parameters(recurse=True)) ): - self.engine.module.save_pretrained(save_directory=path, **kwargs) + model_to_save.save_pretrained(save_directory=path, **kwargs) self._checkpoint_epilogue() else: - self.engine.module.save_pretrained(save_directory=path, **kwargs) + model_to_save.save_pretrained(save_directory=path, **kwargs) def load_model(self, path: str, process_exclusion: bool = False, **kwargs): ...