From c710fe83f3c849422c3fdc04d57a85dd6395fe98 Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Fri, 20 Sep 2024 21:40:10 +0000
Subject: [PATCH] add proper logging

---
 src/zeroband/models/llama/__init__.py | 16 ++++++++++-
 src/zeroband/train.py                 | 38 ++++++--------------------
 src/zeroband/utils/logging.py         | 39 +++++++++++++++++++++++++++
 src/zeroband/utils/world_info.py      | 14 +++++++++-
 4 files changed, 75 insertions(+), 32 deletions(-)
 create mode 100644 src/zeroband/utils/logging.py

diff --git a/src/zeroband/models/llama/__init__.py b/src/zeroband/models/llama/__init__.py
index ce3a676f..6bfdae29 100644
--- a/src/zeroband/models/llama/__init__.py
+++ b/src/zeroband/models/llama/__init__.py
@@ -58,4 +58,18 @@
         multiple_of=4096,
         rope_theta=500000,
     ),
-}
\ No newline at end of file
+}
+
+def get_model(name_model: str, type_model: str, vocab_size: int) -> Transformer:
+    """get the transformer model"""
+
+    if type_model == "llama2":
+        config = llama2_configs[name_model]
+    elif type_model == "llama3":
+        config = llama3_configs[name_model]
+    else:
+        raise ValueError(f"Model type {type_model} not supported")
+
+    config.vocab_size = vocab_size
+    return Transformer(config)
+
diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index ff020ff4..b5159dc2 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -1,6 +1,5 @@
 import os
 from contextlib import nullcontext
-import logging  # Added logging import
 from typing import Literal
 
 import torch
@@ -20,23 +19,11 @@
 from zeroband.utils import get_sharding_strategy
 from zeroband.utils.monitor import WandbMonitor, DummyMonitor
 from zeroband.data import TEST_VOCAB_SIZE, get_dataloader
-from zeroband.models.llama import llama2_configs, llama3_configs, Transformer
-from zeroband.utils.world_info import WorldInfo
+from zeroband.models.llama import get_model
+from zeroband.utils.world_info import get_world_info
+from zeroband.utils.logging import get_logger
 
-### TODO
-# fix logger
-
-world_info = WorldInfo()
-
-if world_info.local_rank == 0:
-    log_level = os.getenv("ZERO_BAND_LOG_LEVEL", "INFO")
-    logging.basicConfig(level=getattr(logging, log_level, logging.INFO))
-else:
-    logging.basicConfig(level=logging.CRITICAL)  # Disable logging for non-zero ranks
-
-logger = logging.getLogger(__name__)
-
 
 # Function to initialize the distributed process group
 def ddp_setup():
     init_process_group()
@@ -87,19 +74,6 @@ class Config(BaseConfig):
 
 
 
-def get_model(name_model: str, type_model: str, tokenizer: AutoTokenizer) -> Transformer:
-    """get the transformer model"""
-
-    if type_model == "llama2":
-        config = llama2_configs[name_model]
-    elif type_model == "llama3":
-        config = llama3_configs[name_model]
-    else:
-        raise ValueError(f"Model type {type_model} not supported")
-
-    config.vocab_size = tokenizer.vocab_size if name_model != "debugmodel" else TEST_VOCAB_SIZE
-    return Transformer(config)
-
 def train(config: Config):
     sharding_strategy = get_sharding_strategy(config.train.sharding_strategy)
 
@@ -116,7 +90,7 @@ def train(config: Config):
     logger.debug("tokenizer loaded")
     train_dataloader = get_dataloader(tokenizer.pad_token_id, world_info.world_size, world_info.rank, config.data.seq_length, config.train.micro_bs, config.data.num_workers)
 
-    model = get_model(config.name_model, config.type_model, tokenizer=tokenizer)
+    model = get_model(config.name_model, config.type_model, vocab_size=tokenizer.vocab_size if config.name_model != "debugmodel" else TEST_VOCAB_SIZE)
     model = model.to(world_info.local_rank)
     logger.debug("model loaded")
 
@@ -213,6 +187,10 @@ def train(config: Config):
     # However, in development, we want to know that we broke torch compile
     torch._dynamo.config.suppress_errors = "ZERO_BAND_DEV" not in os.environ
     torch.set_float32_matmul_precision("high")
+
+    world_info = get_world_info()
+    logger = get_logger()
+
     ddp_setup()
 
     config = Config(**parse_argv())
diff --git a/src/zeroband/utils/logging.py b/src/zeroband/utils/logging.py
new file mode 100644
index 00000000..de6a4ff8
--- /dev/null
+++ b/src/zeroband/utils/logging.py
@@ -0,0 +1,39 @@
+import logging
+import os
+
+from zeroband.utils.world_info import get_world_info
+
+logger = None
+
+class CustomFormatter(logging.Formatter):
+    def __init__(self, local_rank: int):
+        super().__init__()
+        self.local_rank = local_rank
+
+    def format(self, record):
+        log_format = "{asctime} [{levelname}] [Rank {local_rank}] {message}"
+        formatter = logging.Formatter(log_format, style='{', datefmt="%H:%M:%S")
+        record.local_rank = self.local_rank  # expose the local rank to the format string above
+        return formatter.format(record)
+
+def get_logger():
+    global logger  # rebind the module-level logger singleton on first call
+    if logger is not None:
+        return logger
+
+    world_info = get_world_info()
+    logger = logging.getLogger(__name__)
+
+    if world_info.local_rank == 0:
+        log_level = os.getenv("ZERO_BAND_LOG_LEVEL", "INFO")
+        logging.basicConfig(level=getattr(logging, log_level, logging.INFO))
+    else:
+        logging.basicConfig(level=logging.CRITICAL)  # Disable logging for non-zero ranks
+
+    handler = logging.StreamHandler()
+    handler.setFormatter(CustomFormatter(world_info.local_rank))
+    logger.addHandler(handler)
+    logger.propagate = False  # Prevent the log messages from being propagated to the root logger
+
+    return logger
+
diff --git a/src/zeroband/utils/world_info.py b/src/zeroband/utils/world_info.py
index c3bfe22d..c21c74d9 100644
--- a/src/zeroband/utils/world_info.py
+++ b/src/zeroband/utils/world_info.py
@@ -1,5 +1,7 @@
 import os
 
+world_info = None
+
 class WorldInfo:
     """This class parse env var about torch world into class variables."""
     world_size: int
@@ -11,4 +13,14 @@ def __init__(self):
         self.world_size = int(os.environ["WORLD_SIZE"])
         self.rank = int(os.environ["RANK"])
         self.local_rank = int(os.environ["LOCAL_RANK"])
-        self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
\ No newline at end of file
+        self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
+
+def get_world_info() -> WorldInfo:
+    """
+    Return a WorldInfo singleton.
+    """
+    global world_info
+    if world_info is None:
+        world_info = WorldInfo()
+    return world_info
+
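
Reviewer note (not part of the patch): a minimal sketch of how the helpers introduced
here are meant to be used together. It assumes the process is started with torchrun so
that WORLD_SIZE, RANK, LOCAL_RANK and LOCAL_WORLD_SIZE are set in the environment; the
model name and vocab size below are illustrative placeholders, not values taken from
the diff.

    # sketch only: wiring up the singletons added by this patch
    from zeroband.utils.world_info import get_world_info
    from zeroband.utils.logging import get_logger
    from zeroband.models.llama import get_model

    world_info = get_world_info()  # env vars parsed once, cached as a module-level singleton
    logger = get_logger()          # rank-tagged logger; non-zero local ranks are set to CRITICAL

    logger.info(f"rank {world_info.rank}/{world_info.world_size} ready")

    # get_model now takes an explicit vocab_size instead of a tokenizer object
    model = get_model("debugmodel", "llama2", vocab_size=1024)  # illustrative values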