From 378408c2905585a62a5bcbeb58f21561969e8a5e Mon Sep 17 00:00:00 2001 From: Siddharth Narayanan Date: Mon, 9 Sep 2024 09:47:02 -0700 Subject: [PATCH] Faster call ID lookups; move progress bar (#16) --- ldp/alg/optimizer/opt.py | 21 +++++++++++++++++---- ldp/alg/runners.py | 10 +++------- ldp/graph/ops.py | 23 ++++++++--------------- 3 files changed, 28 insertions(+), 26 deletions(-) diff --git a/ldp/alg/optimizer/opt.py b/ldp/alg/optimizer/opt.py index afdf84d7..bfb569dc 100644 --- a/ldp/alg/optimizer/opt.py +++ b/ldp/alg/optimizer/opt.py @@ -4,6 +4,8 @@ from abc import ABC, abstractmethod from collections.abc import Iterable +from tqdm import tqdm + from ldp.data_structures import Trajectory logger = logging.getLogger(__name__) @@ -21,9 +23,18 @@ def __init_subclass__(cls) -> None: _OPTIMIZER_REGISTRY[cls.__name__] = cls return super().__init_subclass__() - def aggregate(self, trajectories: Iterable[Trajectory]) -> None: + def aggregate( + self, trajectories: Iterable[Trajectory], show_pbar: bool = False + ) -> None: """Aggregate trajectories to construct training samples.""" - for trajectory in trajectories: + trajectories_with_pbar = tqdm( + trajectories, + desc="Aggregating trajectories", + ncols=0, + mininterval=1, + disable=not show_pbar, + ) + for trajectory in trajectories_with_pbar: self.aggregate_trajectory(trajectory) @abstractmethod @@ -41,9 +52,11 @@ class ChainedOptimizer(Optimizer): def __init__(self, *optimizers: Optimizer): self.optimizers = optimizers - def aggregate(self, trajectories: Iterable[Trajectory]) -> None: + def aggregate( + self, trajectories: Iterable[Trajectory], show_pbar: bool = False + ) -> None: for optimizer in self.optimizers: - optimizer.aggregate(trajectories) + optimizer.aggregate(trajectories, show_pbar=show_pbar) async def update(self) -> None: for optimizer in self.optimizers: diff --git a/ldp/alg/runners.py b/ldp/alg/runners.py index 827a4aa6..fd6159e3 100644 --- a/ldp/alg/runners.py +++ b/ldp/alg/runners.py @@ -300,14 +300,10 @@ async def train(self) -> None: i_batch_start : i_batch_start + self.config.batch_size ] - batch_with_pbar = tqdm( - batch, - desc="Aggregating trajectories", - ncols=0, - # Only show the progress bar if we are doing full-batch optimization - disable=len(self.train_trajectories) > self.config.batch_size, + # Only show the progress bar if we are doing full-batch optimization + self.optimizer.aggregate( + batch, show_pbar=len(self.train_trajectories) <= self.config.batch_size ) - self.optimizer.aggregate(batch_with_pbar) if (training_step + 1) % self.config.update_every == 0: await self.optimizer.update() diff --git a/ldp/graph/ops.py b/ldp/graph/ops.py index 968ccf6b..8da89d30 100644 --- a/ldp/graph/ops.py +++ b/ldp/graph/ops.py @@ -338,7 +338,7 @@ def get_or_create(cls, op_name: str) -> OpCtx: def get(self, call_id: CallID, key: str, default: Any = NOT_FOUND) -> Any: """Get an attribute with an optional default, emulating dict.get.""" - value = self.data[call_id.run_id].get((call_id.fwd_id, key), default) + value = self.data.get(call_id.run_id, {}).get((call_id.fwd_id, key), default) if value is NOT_FOUND: raise KeyError(f"call_id={call_id}, key='{key}' not found in context") return value @@ -346,14 +346,6 @@ def get(self, call_id: CallID, key: str, default: Any = NOT_FOUND) -> Any: def update(self, call_id: CallID, key: str, value: Any): self.data[call_id.run_id][(call_id.fwd_id, key)] = value - @property - def call_ids(self) -> set[CallID]: - return { - CallID(run_id, fwd_id) - for run_id, calls in self.data.items() - for fwd_id, _ in calls - } - def get_input_grads(self, call_id: CallID) -> GradInType: # TODO: this function name is confusing. Let's deprecate it. We only use it # in tests as far as I can tell. @@ -463,12 +455,13 @@ def backward( """ def get_call_ids(self, run_ids: Collection[UUID] | None = None) -> set[CallID]: - call_ids = self.ctx.call_ids - return ( - call_ids - if run_ids is None - else {c for c in call_ids if c.run_id in run_ids} - ) + ctx = self.ctx + if run_ids is None: + run_ids = ctx.data.keys() + + # de-duplicate before constructing CallIDs + ids = {(run_id, fwd_id) for run_id in run_ids for fwd_id, _ in ctx.data[run_id]} + return set(itertools.starmap(CallID, ids)) # This compute_graph() decoration will do nothing if we are already inside a compute graph. # We add it here in case we are calling a bare op(), in which case we want a graph