Commit

add proper logging

samsja committed Sep 20, 2024
1 parent 2feb040 commit c710fe8

Showing 4 changed files with 75 additions and 32 deletions.
16 changes: 15 additions & 1 deletion src/zeroband/models/llama/__init__.py
@@ -58,4 +58,18 @@
multiple_of=4096,
rope_theta=500000,
),
}
}

def get_model(name_model: str, type_model: str, vocab_size: int) -> Transformer:
    """get the transformer model"""

    if type_model == "llama2":
        config = llama2_configs[name_model]
    elif type_model == "llama3":
        config = llama3_configs[name_model]
    else:
        raise ValueError(f"Model type {type_model} not supported")

    config.vocab_size = vocab_size
    return Transformer(config)

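For reference, a minimal usage sketch of the extracted helper (a hypothetical call site; the "debugmodel" name and TEST_VOCAB_SIZE simply mirror how train.py uses them):

# hypothetical: build the debug model with the test vocab size instead of a tokenizer's
from zeroband.data import TEST_VOCAB_SIZE
from zeroband.models.llama import get_model

model = get_model(name_model="debugmodel", type_model="llama2", vocab_size=TEST_VOCAB_SIZE)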
38 changes: 8 additions & 30 deletions src/zeroband/train.py
@@ -1,6 +1,5 @@
import os
from contextlib import nullcontext
import logging # Added logging import
from typing import Literal

import torch
@@ -20,23 +19,11 @@
from zeroband.utils import get_sharding_strategy
from zeroband.utils.monitor import WandbMonitor, DummyMonitor
from zeroband.data import TEST_VOCAB_SIZE, get_dataloader
from zeroband.models.llama import llama2_configs, llama3_configs, Transformer
from zeroband.utils.world_info import WorldInfo
from zeroband.models.llama import get_model
from zeroband.utils.world_info import get_world_info
from zeroband.utils.logging import get_logger


### TODO
# fix logger

world_info = WorldInfo()

if world_info.local_rank == 0:
    log_level = os.getenv("ZERO_BAND_LOG_LEVEL", "INFO")
    logging.basicConfig(level=getattr(logging, log_level, logging.INFO))
else:
    logging.basicConfig(level=logging.CRITICAL)  # Disable logging for non-zero ranks

logger = logging.getLogger(__name__)

# Function to initialize the distributed process group
def ddp_setup():
    init_process_group()
@@ -87,19 +74,6 @@ class Config(BaseConfig):



def get_model(name_model: str, type_model: str, tokenizer: AutoTokenizer) -> Transformer:
    """get the transformer model"""

    if type_model == "llama2":
        config = llama2_configs[name_model]
    elif type_model == "llama3":
        config = llama3_configs[name_model]
    else:
        raise ValueError(f"Model type {type_model} not supported")

    config.vocab_size = tokenizer.vocab_size if name_model != "debugmodel" else TEST_VOCAB_SIZE
    return Transformer(config)

def train(config: Config):
    sharding_strategy = get_sharding_strategy(config.train.sharding_strategy)

@@ -116,7 +90,7 @@ def train(config: Config):
    logger.debug("tokenizer loaded")
    train_dataloader = get_dataloader(tokenizer.pad_token_id, world_info.world_size, world_info.rank, config.data.seq_length, config.train.micro_bs, config.data.num_workers)

    model = get_model(config.name_model, config.type_model, tokenizer=tokenizer)
    model = get_model(config.name_model, config.type_model, vocab_size=tokenizer.vocab_size if config.name_model != "debugmodel" else TEST_VOCAB_SIZE)
    model = model.to(world_info.local_rank)
    logger.debug("model loaded")

@@ -213,6 +187,10 @@ def train(config: Config):
# However, in development, we want to know that we broke torch compile
torch._dynamo.config.suppress_errors = "ZERO_BAND_DEV" not in os.environ
torch.set_float32_matmul_precision("high")

world_info = get_world_info()
logger = get_logger()

ddp_setup()

config = Config(**parse_argv())
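Taken together, the entry point now builds its singletons explicitly instead of configuring logging at import time. A minimal sketch of the resulting startup order (the final call into train is assumed from context and not visible in the truncated hunk):

world_info = get_world_info()  # parses WORLD_SIZE/RANK/LOCAL_RANK/LOCAL_WORLD_SIZE once
logger = get_logger()          # rank-aware logger; ranks other than 0 only emit CRITICAL
ddp_setup()
config = Config(**parse_argv())
train(config)                  # assumed entry call, truncated in the diff above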
39 changes: 39 additions & 0 deletions src/zeroband/utils/logging.py
@@ -0,0 +1,39 @@
import logging
import os

from zeroband.utils.world_info import get_world_info

logger = None

class CustomFormatter(logging.Formatter):
    def __init__(self, local_rank: int):
        super().__init__()
        self.local_rank = local_rank

    def format(self, record):
        log_format = "{asctime} [{levelname}] [Rank {local_rank}] {message}"
        formatter = logging.Formatter(log_format, style='{', datefmt="%H:%M:%S")
        record.local_rank = self.local_rank  # expose the local rank to the format string
        return formatter.format(record)

def get_logger():
    global logger  # rebind the module-level singleton defined above
    if logger is not None:
        return logger

    world_info = get_world_info()
    logger = logging.getLogger(__name__)

    if world_info.local_rank == 0:
        log_level = os.getenv("ZERO_BAND_LOG_LEVEL", "INFO")
        logging.basicConfig(level=getattr(logging, log_level, logging.INFO))
    else:
        logging.basicConfig(level=logging.CRITICAL)  # Disable logging for non-zero ranks

    handler = logging.StreamHandler()
    handler.setFormatter(CustomFormatter(world_info.local_rank))
    logger.addHandler(handler)
    logger.propagate = False  # Prevent the log messages from being propagated to the root logger

    return logger

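A minimal usage sketch of the new logging helper (it assumes the torch distributed env vars are already set, since get_logger calls get_world_info internally):

from zeroband.utils.logging import get_logger

logger = get_logger()            # first call builds the singleton, later calls return it
logger.info("dataloader ready")  # rank 0 prints e.g. "12:04:33 [INFO] [Rank 0] dataloader ready"
# ranks other than 0 are capped at CRITICAL, so info/debug messages are dropped there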
14 changes: 13 additions & 1 deletion src/zeroband/utils/world_info.py
@@ -1,5 +1,7 @@
import os

world_info = None

class WorldInfo:
    """This class parses the torch distributed env vars into class attributes."""
    world_size: int
@@ -11,4 +13,14 @@ def __init__(self):
        self.world_size = int(os.environ["WORLD_SIZE"])
        self.rank = int(os.environ["RANK"])
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
        self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])

def get_world_info() -> WorldInfo:
    """
    Return a WorldInfo singleton.
    """
    global world_info
    if world_info is None:
        world_info = WorldInfo()
    return world_info

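Finally, a minimal sketch of the world-info singleton (the env vars below are normally exported by torchrun; the literal values here are placeholders):

import os
from zeroband.utils.world_info import get_world_info

# placeholders for what torchrun would normally export
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("LOCAL_RANK", "0")
os.environ.setdefault("LOCAL_WORLD_SIZE", "1")

world_info = get_world_info()
assert world_info is get_world_info()  # the second call returns the same cached instance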