Commit

Faster call ID lookups; move progress bar (#16)
sidnarayanan authored Sep 9, 2024
1 parent 642e128 commit 378408c
Showing 3 changed files with 28 additions and 26 deletions.
ldp/alg/optimizer/opt.py (21 changes: 17 additions & 4 deletions)
@@ -4,6 +4,8 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
 
+from tqdm import tqdm
+
 from ldp.data_structures import Trajectory
 
 logger = logging.getLogger(__name__)
@@ -21,9 +23,18 @@ def __init_subclass__(cls) -> None:
         _OPTIMIZER_REGISTRY[cls.__name__] = cls
         return super().__init_subclass__()
 
-    def aggregate(self, trajectories: Iterable[Trajectory]) -> None:
+    def aggregate(
+        self, trajectories: Iterable[Trajectory], show_pbar: bool = False
+    ) -> None:
         """Aggregate trajectories to construct training samples."""
-        for trajectory in trajectories:
+        trajectories_with_pbar = tqdm(
+            trajectories,
+            desc="Aggregating trajectories",
+            ncols=0,
+            mininterval=1,
+            disable=not show_pbar,
+        )
+        for trajectory in trajectories_with_pbar:
             self.aggregate_trajectory(trajectory)
 
     @abstractmethod
@@ -41,9 +52,11 @@ class ChainedOptimizer(Optimizer):
     def __init__(self, *optimizers: Optimizer):
         self.optimizers = optimizers
 
-    def aggregate(self, trajectories: Iterable[Trajectory]) -> None:
+    def aggregate(
+        self, trajectories: Iterable[Trajectory], show_pbar: bool = False
+    ) -> None:
         for optimizer in self.optimizers:
-            optimizer.aggregate(trajectories)
+            optimizer.aggregate(trajectories, show_pbar=show_pbar)
 
     async def update(self) -> None:
         for optimizer in self.optimizers:
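
With this change the progress bar lives inside `Optimizer.aggregate` itself and is off by default; callers opt in via the new `show_pbar` argument, and `ChainedOptimizer` simply forwards the flag to each wrapped optimizer. A minimal, self-contained sketch of the same `disable=not show_pbar` pattern (a toy function rather than the real ldp classes):

```python
from collections.abc import Iterable

from tqdm import tqdm


def aggregate(items: Iterable[int], show_pbar: bool = False) -> int:
    """Toy stand-in for Optimizer.aggregate: iterate with an optional progress bar."""
    # disable=not show_pbar keeps tqdm silent unless the caller asks for a bar;
    # mininterval=1 throttles redraws to at most once per second.
    items_with_pbar = tqdm(
        items, desc="Aggregating items", ncols=0, mininterval=1, disable=not show_pbar
    )
    return sum(items_with_pbar)


print(aggregate(range(1_000_000)))  # silent
print(aggregate(range(1_000_000), show_pbar=True))  # shows a bar
```
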
ldp/alg/runners.py (10 changes: 3 additions & 7 deletions)
@@ -300,14 +300,10 @@ async def train(self) -> None:
                 i_batch_start : i_batch_start + self.config.batch_size
             ]
 
-            batch_with_pbar = tqdm(
-                batch,
-                desc="Aggregating trajectories",
-                ncols=0,
-                # Only show the progress bar if we are doing full-batch optimization
-                disable=len(self.train_trajectories) > self.config.batch_size,
+            # Only show the progress bar if we are doing full-batch optimization
+            self.optimizer.aggregate(
+                batch, show_pbar=len(self.train_trajectories) <= self.config.batch_size
             )
-            self.optimizer.aggregate(batch_with_pbar)
 
             if (training_step + 1) % self.config.update_every == 0:
                 await self.optimizer.update()
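
On the runner side, `train()` no longer wraps the batch in tqdm; it just passes `show_pbar`, which is true only when a single batch spans the entire training set (full-batch optimization), so per-minibatch bars do not spam the log. A small sketch of that batching and flag logic, with hypothetical sizes and plain variables standing in for the trainer state:

```python
# Hypothetical stand-ins for self.train_trajectories and self.config.batch_size.
train_trajectories = list(range(10))
batch_size = 4

for training_step, i_batch_start in enumerate(
    range(0, len(train_trajectories), batch_size)
):
    batch = train_trajectories[i_batch_start : i_batch_start + batch_size]
    # Equivalent to the old disable=len(...) > batch_size, with the condition
    # inverted: only show the bar when one batch covers the whole training set.
    show_pbar = len(train_trajectories) <= batch_size
    print(training_step, batch, show_pbar)  # show_pbar is False here (10 > 4)
```
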
ldp/graph/ops.py (23 changes: 8 additions & 15 deletions)
@@ -338,22 +338,14 @@ def get_or_create(cls, op_name: str) -> OpCtx:
 
     def get(self, call_id: CallID, key: str, default: Any = NOT_FOUND) -> Any:
         """Get an attribute with an optional default, emulating dict.get."""
-        value = self.data[call_id.run_id].get((call_id.fwd_id, key), default)
+        value = self.data.get(call_id.run_id, {}).get((call_id.fwd_id, key), default)
         if value is NOT_FOUND:
             raise KeyError(f"call_id={call_id}, key='{key}' not found in context")
         return value
 
     def update(self, call_id: CallID, key: str, value: Any):
         self.data[call_id.run_id][(call_id.fwd_id, key)] = value
 
-    @property
-    def call_ids(self) -> set[CallID]:
-        return {
-            CallID(run_id, fwd_id)
-            for run_id, calls in self.data.items()
-            for fwd_id, _ in calls
-        }
-
     def get_input_grads(self, call_id: CallID) -> GradInType:
         # TODO: this function name is confusing. Let's deprecate it. We only use it
         # in tests as far as I can tell.
@@ -463,12 +455,13 @@ def backward(
         """
 
     def get_call_ids(self, run_ids: Collection[UUID] | None = None) -> set[CallID]:
-        call_ids = self.ctx.call_ids
-        return (
-            call_ids
-            if run_ids is None
-            else {c for c in call_ids if c.run_id in run_ids}
-        )
+        ctx = self.ctx
+        if run_ids is None:
+            run_ids = ctx.data.keys()
+
+        # de-duplicate before constructing CallIDs
+        ids = {(run_id, fwd_id) for run_id in run_ids for fwd_id, _ in ctx.data[run_id]}
+        return set(itertools.starmap(CallID, ids))
 
    # This compute_graph() decoration will do nothing if we are already inside a compute graph.
    # We add it here in case we are calling a bare op(), in which case we want a graph
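
The ops.py changes speed up call-ID lookups in two ways: `OpCtx.get` now uses a chained `.get()`, so probing an unknown `run_id` neither raises nor (if `data` is a defaultdict) inserts an empty entry as a side effect, and `get_call_ids` walks only the requested runs instead of materializing a `CallID` for every call via the removed `call_ids` property and filtering afterwards. A self-contained sketch of both patterns, assuming `OpCtx.data` maps `run_id -> {(fwd_id, key): value}` and using a toy `CallID`:

```python
import itertools
from collections import defaultdict
from typing import Any, NamedTuple
from uuid import UUID, uuid4


class CallID(NamedTuple):
    """Toy stand-in for ldp's CallID."""
    run_id: UUID
    fwd_id: UUID


NOT_FOUND = object()

# Assumed shape of OpCtx.data: run_id -> {(fwd_id, key): value}.
data: defaultdict[UUID, dict[tuple[UUID, str], Any]] = defaultdict(dict)
run_a, run_b, fwd_1 = uuid4(), uuid4(), uuid4()
data[run_a][(fwd_1, "output")] = 42

# The old lookup, data[run_b].get(...), would insert an empty dict for run_b
# as a side effect; the chained .get() leaves the context untouched.
value = data.get(run_b, {}).get((fwd_1, "output"), NOT_FOUND)
assert value is NOT_FOUND
assert run_b not in data  # no spurious entry created

# get_call_ids-style lookup: iterate only the requested runs, de-duplicate the
# (run_id, fwd_id) pairs, then construct CallIDs once at the end.
run_ids = [run_a]
ids = {(run_id, fwd_id) for run_id in run_ids for fwd_id, _ in data[run_id]}
print(set(itertools.starmap(CallID, ids)))
```
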
