diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 94449ad1..a011f35f 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -304,13 +304,13 @@ def train(config: Config): memory_profiler.step() if config.diloco is not None: - # if config.train.log_model_hash: - # with FSDP.summon_full_params(model): - # logger.debug("Pre diloco model: %s", get_module_signature(model)) + if config.train.log_model_hash: + logger.debug("Pre diloco model: %s", get_module_signature(model)) + diloco.step(model) - # if config.train.log_model_hash: - # with FSDP.summon_full_params(model): - # logger.debug("Post diloco model: %s", get_module_signature(model)) + + if config.train.log_model_hash: + logger.debug("Post diloco model: %s", get_module_signature(model)) training_progress.outer_step += 1 diff --git a/src/zeroband/utils/__init__.py b/src/zeroband/utils/__init__.py index 04d08ad2..32b6279d 100644 --- a/src/zeroband/utils/__init__.py +++ b/src/zeroband/utils/__init__.py @@ -3,6 +3,7 @@ from typing import Any import torch from torch.distributed.fsdp import ShardingStrategy +from torch.distributed._tensor.api import DTensor from zeroband.utils.logging import get_logger @@ -105,6 +106,10 @@ def get_tensor_signature(a: torch.Tensor | torch.nn.Parameter) -> str: """ while isinstance(a, torch.nn.Parameter): a = a.data + + if isinstance(a, DTensor): + a = a.full_tensor() + if a.numel() < TENSOR_SIG_SAMPLE_SIZE: b = a.as_strided(size=(a.numel(),), stride=(1,)) else: