add & config mypy on common package
Signed-off-by: wiseaidev <business@wiseai.dev>
wiseaidev committed Apr 1, 2023
1 parent 78c3235 commit 2cc1abe
Showing 17 changed files with 218 additions and 84 deletions.
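
The mypy configuration added by this commit is not among the files visible in the loaded portion of the diff, so the snippet below is only a hypothetical sketch of how the common package might be type-checked programmatically; the flags and the target directory are assumptions, not taken from the commit.

# Hypothetical sketch: run mypy against the common package via its Python API.
# The flags are assumptions; torchrec/torchsnapshot typically ship without type
# stubs, hence --ignore-missing-imports.
from mypy import api

stdout, stderr, exit_status = api.run(
    [
        "--ignore-missing-imports",
        "--follow-imports=silent",
        "common/",
    ]
)
print(stdout)
if exit_status != 0:
    raise SystemExit(f"mypy reported type errors (exit status {exit_status})")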
46 changes: 32 additions & 14 deletions common/batch.py
@@ -1,22 +1,40 @@
"""Extension of torchrec.dataset.utils.Batch to cover any dataset.
"""
# flake8: noqa
from __future__ import annotations
from typing import Dict
from __future__ import (
annotations,
)

import abc
from dataclasses import dataclass
import dataclasses
from collections import (
UserDict,
)
from dataclasses import (
dataclass,
)
from typing import (
Any,
Dict,
List,
TypeVar,
)

import torch
from torchrec.streamable import Pipelineable
from torchrec.streamable import (
Pipelineable,
)

_KT = TypeVar("_KT") # key type
_VT = TypeVar("_VT") # value type


class BatchBase(Pipelineable, abc.ABC):
@abc.abstractmethod
def as_dict(self) -> Dict:
def as_dict(self) -> Dict[str, Any]:
raise NotImplementedError

def to(self, device: torch.device, non_blocking: bool = False):
def to(self, device: torch.device, non_blocking: bool = False) -> BatchBase:
args = {}
for feature_name, feature_value in self.as_dict().items():
args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking)
@@ -26,14 +44,14 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
for feature_value in self.as_dict().values():
feature_value.record_stream(stream)

def pin_memory(self):
def pin_memory(self) -> BatchBase:
args = {}
for feature_name, feature_value in self.as_dict().items():
args[feature_name] = feature_value.pin_memory()
return self.__class__(**args)

def __repr__(self) -> str:
def obj2str(v):
def obj2str(v: Any) -> str:
return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"

return "\n".join([f"{k}: {obj2str(v)}," for k, v in self.as_dict().items()])
@@ -52,18 +70,18 @@ def batch_size(self) -> int:
@dataclass
class DataclassBatch(BatchBase):
@classmethod
def feature_names(cls):
def feature_names(cls) -> List[str]:
return list(cls.__dataclass_fields__.keys())

def as_dict(self):
def as_dict(self) -> Dict[str, Any]:
return {
feature_name: getattr(self, feature_name)
for feature_name in self.feature_names()
if hasattr(self, feature_name)
}

@staticmethod
def from_schema(name: str, schema):
def from_schema(name: str, schema: Any) -> type:
"""Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor."""
return dataclasses.make_dataclass(
cls_name=name,
@@ -72,14 +90,14 @@ def from_schema(name: str, schema):
)

@staticmethod
def from_fields(name: str, fields: dict):
def from_fields(name: str, fields: Dict[str, Any]) -> type:
return dataclasses.make_dataclass(
cls_name=name,
fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],
bases=(DataclassBatch,),
)


class DictionaryBatch(BatchBase, dict):
def as_dict(self) -> Dict:
class DictionaryBatch(BatchBase, UserDict[_KT, _VT]):
def as_dict(self) -> Dict[str, Any]:
return self
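
The annotations above make DataclassBatch fully checkable; below is a minimal usage sketch, assuming the module is importable as tml.common.batch and that plain torch tensors are used as field values (the field names are illustrative, not from the commit).

# Illustrative only: build a batch subclass dynamically with the typed from_fields.
import torch

from tml.common.batch import DataclassBatch

ExampleBatch = DataclassBatch.from_fields(
    "ExampleBatch",
    {"user_embedding": torch.Tensor, "labels": torch.Tensor},
)

batch = ExampleBatch(
    user_embedding=torch.zeros(8, 16),
    labels=torch.zeros(8),
)

# as_dict() is now typed as Dict[str, Any]; to() and pin_memory() return BatchBase.
for name, value in batch.as_dict().items():
    print(name, value.size())

cpu_batch = batch.to(torch.device("cpu"), non_blocking=False)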
5 changes: 4 additions & 1 deletion common/checkpointing/__init__.py
@@ -1 +1,4 @@
from tml.common.checkpointing.snapshot import get_checkpoint, Snapshot
from tml.common.checkpointing.snapshot import (
Snapshot,
get_checkpoint,
)
42 changes: 28 additions & 14 deletions common/checkpointing/snapshot.py
@@ -1,12 +1,24 @@
import os
import time
from typing import Any, Dict, List, Optional

from tml.ml_logging.torch_logging import logging
from tml.common.filesystem import infer_fs, is_gcs_fs
from typing import (
Any,
Dict,
Generator,
List,
Optional,
)

import torchsnapshot

from tml.common.filesystem import (
infer_fs,
is_gcs_fs,
)
from tml.ml_logging.torch_logging import (
logging,
)
from torch import (
FloatTensor,
)

DONE_EVAL_SUBDIR = "evaled_by"
GCS_PREFIX = "gs://"
@@ -25,22 +37,22 @@ def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0)

@property
def step(self):
def step(self) -> int:
return self.state["extra_state"]["step"]

@step.setter
def step(self, step: int) -> None:
self.state["extra_state"]["step"] = step

@property
def walltime(self):
def walltime(self) -> float:
return self.state["extra_state"]["walltime"]

@walltime.setter
def walltime(self, walltime: float) -> None:
self.state["extra_state"]["walltime"] = walltime

def save(self, global_step: int) -> "PendingSnapshot":
def save(self, global_step: int) -> "PendingSnapshot": # type: ignore
"""Saves checkpoint with given global_step."""
path = os.path.join(self.save_dir, str(global_step))
logging.info(f"Saving snapshot global_step {global_step} to {path}.")
@@ -98,7 +110,7 @@ def load_snapshot_to_weight(
cls,
embedding_snapshot: torchsnapshot.Snapshot,
snapshot_emb_name: str,
weight_tensor,
weight_tensor: FloatTensor,
) -> None:
"""Loads pretrained embedding from the snapshot to the model.
Utilise the partial loading mechanism from torchsnapshot.
@@ -128,19 +140,21 @@ def _eval_done_path(checkpoint_path: str, eval_partition: str) -> str:
return os.path.join(_eval_subdir(checkpoint_path), f"{eval_partition}_DONE")


def is_done_eval(checkpoint_path: str, eval_partition: str):
return get_checkpoint(checkpoint_path).exists(_eval_done_path(checkpoint_path, eval_partition))
def is_done_eval(checkpoint_path: str, eval_partition: str) -> bool:
return get_checkpoint(checkpoint_path).exists(_eval_done_path(checkpoint_path, eval_partition)) # type: ignore[attr-defined]


def mark_done_eval(checkpoint_path: str, eval_partition: str):
def mark_done_eval(checkpoint_path: str, eval_partition: str) -> Any:
infer_fs(checkpoint_path).touch(_eval_done_path(checkpoint_path, eval_partition))


def step_from_checkpoint(checkpoint: str) -> int:
return int(os.path.basename(checkpoint))


def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int = 1800):
def checkpoints_iterator(
save_dir: str, seconds_to_sleep: int = 30, timeout: int = 1800
) -> Generator[str, None, None]:
"""Simplified equivalent of tf.train.checkpoints_iterator.
Args:
@@ -149,7 +163,7 @@ def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int
"""

def _poll(last_checkpoint: Optional[str] = None):
def _poll(last_checkpoint: Optional[str] = None) -> Optional[str]:
stop_time = time.time() + timeout
while True:
_checkpoint_path = get_checkpoint(save_dir, missing_ok=True)
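
With the Generator[str, None, None] annotation the yield type of checkpoints_iterator is explicit; here is a minimal consumption sketch, where the save directory and polling intervals are illustrative assumptions.

# Illustrative only: poll a snapshot directory laid out as step-numbered
# subdirectories, the layout PendingSnapshot.save() produces above.
from tml.common.checkpointing.snapshot import (
    checkpoints_iterator,
    step_from_checkpoint,
)

for checkpoint_path in checkpoints_iterator(
    "/tmp/snapshots", seconds_to_sleep=5, timeout=60
):
    step = step_from_checkpoint(checkpoint_path)
    print(f"New checkpoint at step {step}: {checkpoint_path}")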
2 changes: 1 addition & 1 deletion common/device.py
@@ -4,7 +4,7 @@
import torch.distributed as dist


def maybe_setup_tensorflow():
def maybe_setup_tensorflow() -> None:
try:
import tensorflow as tf
except ImportError:
6 changes: 5 additions & 1 deletion common/filesystem/__init__.py
@@ -1 +1,5 @@
from tml.common.filesystem.util import infer_fs, is_gcs_fs, is_local_fs
from tml.common.filesystem.util import (
infer_fs,
is_gcs_fs,
is_local_fs,
)
4 changes: 3 additions & 1 deletion common/filesystem/test_infer_fs.py
@@ -2,7 +2,9 @@
Mostly a test that it returns an object
"""
from tml.common.filesystem import infer_fs
from tml.common.filesystem import (
infer_fs,
)


def test_infer_fs():
15 changes: 10 additions & 5 deletions common/filesystem/util.py
@@ -1,13 +1,18 @@
"""Utilities for interacting with the file systems."""
from fsspec.implementations.local import LocalFileSystem
import gcsfs
from typing import (
Union,
)

import gcsfs
from fsspec.implementations.local import (
LocalFileSystem,
)

GCS_FS = gcsfs.GCSFileSystem(cache_timeout=-1)
LOCAL_FS = LocalFileSystem()


def infer_fs(path: str):
def infer_fs(path: str) -> Union[LocalFileSystem, gcsfs.core.GCSFileSystem, NotImplementedError]:
if path.startswith("gs://"):
return GCS_FS
elif path.startswith("hdfs://"):
@@ -17,9 +22,9 @@ def infer_fs(path: str):
return LOCAL_FS


def is_local_fs(fs):
def is_local_fs(fs: LocalFileSystem) -> bool:
return fs == LOCAL_FS


def is_gcs_fs(fs):
def is_gcs_fs(fs: gcsfs.core.GCSFileSystem) -> bool:
return fs == GCS_FS
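
A short sketch exercising the annotated helpers through the re-exports in tml.common.filesystem shown earlier; the paths are illustrative only.

# Illustrative only: infer_fs dispatches on the path prefix; gs:// paths map to the
# module-level GCSFileSystem instance and plain local paths to LocalFileSystem.
from tml.common.filesystem import infer_fs, is_gcs_fs, is_local_fs

local_fs = infer_fs("/tmp/data/train.parquet")
gcs_fs = infer_fs("gs://example-bucket/data/train.parquet")

assert is_local_fs(local_fs)
assert is_gcs_fs(gcs_fs)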
29 changes: 20 additions & 9 deletions common/log_weights.py
@@ -1,17 +1,28 @@
"""For logging model weights."""
import itertools
from typing import Callable, Dict, List, Optional, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Union,
)

from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined]
import torch
import torch.distributed as dist
from torchrec.distributed.model_parallel import DistributedModelParallel
from tml.ml_logging.torch_logging import (
logging,
)
from torchrec.distributed.model_parallel import (
DistributedModelParallel,
)


def weights_to_log(
model: torch.nn.Module,
how_to_log: Optional[Union[Callable, Dict[str, Callable]]] = None,
):
how_to_log: Optional[Union[Callable[[Any], Any], Dict[str, Callable[[Any], Any]]]] = None,
) -> Optional[Dict[str, Any]]:
"""Creates dict of reduced weights to log to give sense of training.
Args:
@@ -21,7 +32,7 @@ def weights_to_log(
"""
if not how_to_log:
return
return None

to_log = dict()
named_parameters = model.named_parameters()
@@ -38,14 +49,14 @@ def weights_to_log(
how = how_to_log
else:
how = how_to_log.get(param_name) # type: ignore[assignment]
if not how:
continue # type: ignore
if how is None:
continue
to_log[f"model/{how.__name__}/{param_name}"] = how(params.detach()).cpu().numpy()
return to_log


def log_ebc_norms(
model_state_dict,
model_state_dict: Dict[str, Any],
ebc_keys: List[str],
sample_size: int = 4_000_000,
) -> Dict[str, torch.Tensor]:
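
The widened signature of weights_to_log accepts either a single callable or a per-parameter dict of callables; a small illustrative sketch follows, assuming the tml package and its torchrec dependency are installed (the stand-in model is not part of the commit).

# Illustrative only: reduce parameters to scalars for logging.
import torch

from tml.common.log_weights import weights_to_log

model = torch.nn.Linear(4, 2)

# A single callable is applied to every named parameter.
norms = weights_to_log(model, how_to_log=torch.linalg.norm)

# A dict applies per-parameter reducers; parameters without an entry are skipped.
means = weights_to_log(model, how_to_log={"weight": torch.mean})

# With how_to_log left unset the function now returns None explicitly.
assert weights_to_log(model) is None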
11 changes: 7 additions & 4 deletions common/modules/embedding/config.py
@@ -1,10 +1,13 @@
from typing import List
from enum import Enum

import tml.core.config as base_config
from tml.optimizers.config import OptimizerConfig
from typing import (
List,
)

import pydantic
import tml.core.config as base_config
from tml.optimizers.config import (
OptimizerConfig,
)


class DataType(str, Enum):
28 changes: 20 additions & 8 deletions common/modules/embedding/embedding.py
@@ -1,13 +1,25 @@
from tml.common.modules.embedding.config import LargeEmbeddingsConfig, DataType
from tml.ml_logging.torch_logging import logging

import numpy as np
import torch
from torch import nn
import torchrec
from torchrec.modules import embedding_configs
from torchrec import EmbeddingBagConfig, EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
import numpy as np
from tml.common.modules.embedding.config import (
DataType,
LargeEmbeddingsConfig,
)
from tml.ml_logging.torch_logging import (
logging,
)
from torch import nn
from torchrec import (
EmbeddingBagCollection,
EmbeddingBagConfig,
)
from torchrec.modules import (
embedding_configs,
)
from torchrec.sparse.jagged_tensor import (
KeyedJaggedTensor,
KeyedTensor,
)


class LargeEmbeddings(nn.Module):
