diff --git a/ciclo/callbacks.py b/ciclo/callbacks.py
index 30e0355..0e4b477 100644
--- a/ciclo/callbacks.py
+++ b/ciclo/callbacks.py
@@ -4,12 +4,23 @@
 from dataclasses import dataclass, replace
 from datetime import datetime
 from enum import Enum, auto
-from typing import Any, Callable, Dict, Optional, Tuple, Union, overload
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    Tuple,
+    Union,
+    overload,
+)
 
 from pkbar import Kbar
 from tqdm import tqdm
 
-from ciclo.logging import Logs
+from ciclo.logging import Collection, Entry, History, Logs
 from ciclo.loops.loop import (
     CallbackOutput,
     LoopCallbackBase,
@@ -22,6 +33,11 @@ from ciclo.types import Batch, S
 from ciclo.utils import get_batch_size, is_scalar
 
+AggregationFn = Callable[[List[Any]], Any]
+InnerLoopAggregation = Union[
+    Literal["last", "mean", "sum", "min", "max", "first"], AggregationFn
+]
+
 
 def unavailable_dependency(msg: str) -> Any:
     class DependencyNotAvailable(LoopCallbackBase[S]):
@@ -39,6 +55,22 @@ class OptimizationMode(str, Enum):
     max = auto()
 
 
+def _transpose_history(
+    log_history: History,
+) -> Mapping[Collection, Mapping[Entry, List[Any]]]:
+    """Convert a list of (nested) log dictionaries into a (nested) dictionary of lists."""
+    result = {}
+    for log_dict in log_history:
+        for collection, entries in log_dict.items():
+            if collection not in result:
+                result[collection] = {}
+            for entry, value in entries.items():
+                if entry not in result[collection]:
+                    result[collection][entry] = []
+                result[collection][entry].append(value)
+    return result
+
+
 class inner_loop(LoopCallbackBase[S]):
     @overload
     def __init__(
@@ -65,6 +97,9 @@ def __init__(
         maybe_loop_fn: Optional[Callable[[S], LoopOutput[S]]] = None,
         *,
         output_state: bool = False,
+        aggregation: Union[
+            InnerLoopAggregation, Mapping[Collection, InnerLoopAggregation]
+        ] = "last",
     ):
         if isinstance(name_or_loop_fn, str):
             assert maybe_loop_fn is not None
@@ -75,17 +110,20 @@
             self.name = None
             self.loop_fn = name_or_loop_fn
         self.output_state = output_state
+        self.aggregation = aggregation
 
     def __call__(self, state: S) -> Tuple[Logs, S]:
         inner_state, log_history, _ = self.loop_fn(state)
-        logs = log_history[-1] if len(log_history) > 0 else Logs()
+        logs = _transpose_history(log_history)
         logs = Logs(
             {
                 collection: {
-                    k + f"_{self.name}" if self.name else k: v
-                    for k, v in values.items()
+                    entry + f"_{self.name}"
+                    if self.name
+                    else entry: self.__get_aggregation_fn(collection)(values)
+                    for entry, values in entries.items()
                 }
-                for collection, values in logs.items()
+                for collection, entries in logs.items()
                 if collection != "elapsed"
             }
         )
@@ -94,6 +132,36 @@ def __call__(self, state: S) -> Tuple[Logs, S]:
     def __loop_callback__(self, loop_state: LoopState[S]) -> CallbackOutput[S]:
         return self(loop_state.state)
 
+    def __get_aggregation_fn(self, collection: Collection) -> AggregationFn:
+        if isinstance(self.aggregation, Mapping):
+            aggregation = self.aggregation.get(collection, "last")
+            error_message = f"The aggregation ({aggregation}) for collection {collection} must be a str or Callable."
+        else:
+            aggregation = self.aggregation
+            error_message = (
+                f"The aggregation ({aggregation}) must be a str or Callable."
+            )
+
+        if not (isinstance(aggregation, str) or isinstance(aggregation, Callable)):
+            raise ValueError(error_message)
+
+        if aggregation == "last":
+            return lambda x: x[-1]
+        elif aggregation == "mean":
+            return lambda x: sum(x) / len(x)
+        elif aggregation == "sum":
+            return sum
+        elif aggregation == "min":
+            return min
+        elif aggregation == "max":
+            return max
+        elif aggregation == "first":
+            return lambda x: x[0]
+        elif isinstance(aggregation, Callable):
+            return aggregation
+        else:
+            raise ValueError(f"Invalid aggregation: {aggregation}")
+
 
 if importlib.util.find_spec("tensorflow") is not None:
     from flax.training import checkpoints as flax_checkpoints
diff --git a/ciclo/loops/common.py b/ciclo/loops/common.py
index a849270..73513bd 100644
--- a/ciclo/loops/common.py
+++ b/ciclo/loops/common.py
@@ -73,6 +73,7 @@ def train_loop(
     catch_keyboard_interrupt: bool = True,
     metadata: Optional[Any] = None,
     batch_size_fn: Optional[Callable[[List[Tuple[int, ...]]], int]] = None,
+    inner_loop_kwargs: Optional[Dict[str, Any]] = None,
 ) -> LoopOutput[S]:
     if tasks is None:
         tasks = {}
@@ -83,6 +84,9 @@
     if isinstance(test_duration, int):
         test_duration = Period.create(steps=test_duration)
 
+    if inner_loop_kwargs is None:
+        inner_loop_kwargs = {}
+
     additionl_tasks: Dict[ScheduleLike, CallbackOrList] = {}
     named_tasks: Dict[str, CallbackOrList] = {}
     for schedule in list(tasks.keys()):
@@ -145,6 +149,7 @@
                     stop=test_duration,
                     batch_size_fn=batch_size_fn,
                 ),
+                **inner_loop_kwargs,
             )
         )
         test_tasks += named_tasks.pop(ON_EPOCH_END, [])
diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py
new file mode 100644
index 0000000..e98bfe9
--- /dev/null
+++ b/tests/test_callbacks.py
@@ -0,0 +1,135 @@
+import jax.numpy as jnp
+
+import ciclo
+
+
+def dummy_inner_loop_fn(_):
+    log_history = [
+        {
+            "stateful_metrics": {
+                "A": jnp.array(1.0, dtype=jnp.float32),
+                "B": jnp.array(1.0, dtype=jnp.float32),
+            },
+            "metrics": {
+                "C": jnp.array(0.0, dtype=jnp.float32),
+                "D": jnp.array(0.0, dtype=jnp.float32),
+            },
+            "elapsed": {
+                "steps": 1,
+                "samples": 1,
+            },
+        },
+        {
+            "stateful_metrics": {
+                "A": jnp.array(0.0, dtype=jnp.float32),
+            },
+            "metrics": {
+                "C": jnp.array(1.0, dtype=jnp.float32),
+            },
+            "elapsed": {
+                "steps": 2,
+                "samples": 2,
+            },
+        },
+    ]
+    return None, log_history, None
+
+
+class TestCallbacks:
+    def test_inner_loop_default_aggregation(self):
+        inner_loop = ciclo.callbacks.inner_loop(
+            "test",
+            dummy_inner_loop_fn,
+        )
+
+        log_history, _ = inner_loop(None)
+
+        assert log_history == {
+            "stateful_metrics": {
+                "A_test": jnp.array(0.0, dtype=jnp.float32),
+                "B_test": jnp.array(1.0, dtype=jnp.float32),
+            },
+            "metrics": {
+                "C_test": jnp.array(1.0, dtype=jnp.float32),
+                "D_test": jnp.array(0.0, dtype=jnp.float32),
+            },
+        }
+
+    def test_inner_loop_callable_aggregation(self):
+        inner_loop = ciclo.callbacks.inner_loop(
+            "test",
+            dummy_inner_loop_fn,
+            aggregation=sum,
+        )
+
+        log_history, _ = inner_loop(None)
+
+        assert log_history == {
+            "stateful_metrics": {
+                "A_test": jnp.array(1.0, dtype=jnp.float32),
+                "B_test": jnp.array(1.0, dtype=jnp.float32),
+            },
+            "metrics": {
+                "C_test": jnp.array(1.0, dtype=jnp.float32),
+                "D_test": jnp.array(0.0, dtype=jnp.float32),
+            },
+        }
+
+    def test_inner_loop_mean_aggregation(self):
+        inner_loop = ciclo.callbacks.inner_loop(
+            "test",
+            dummy_inner_loop_fn,
+            aggregation="mean",
+        )
+
+        log_history, _ = inner_loop(None)
+
+        assert log_history == {
+            "stateful_metrics": {
+                "A_test": jnp.array(0.5, dtype=jnp.float32),
+                "B_test": jnp.array(1.0, dtype=jnp.float32),
+            },
+            "metrics": {
+                "C_test": jnp.array(0.5, dtype=jnp.float32),
+                "D_test": jnp.array(0.0, dtype=jnp.float32),
+            },
+        }
+
+    def test_inner_loop_aggregation_dict(self):
+        inner_loop = ciclo.callbacks.inner_loop(
+            "test",
+            dummy_inner_loop_fn,
+            aggregation={"stateful_metrics": "sum", "metrics": "min"},
+        )
+
+        log_history, _ = inner_loop(None)
+
+        assert log_history == {
+            "stateful_metrics": {
+                "A_test": jnp.array(1.0, dtype=jnp.float32),
+                "B_test": jnp.array(1.0, dtype=jnp.float32),
+            },
+            "metrics": {
+                "C_test": jnp.array(0.0, dtype=jnp.float32),
+                "D_test": jnp.array(0.0, dtype=jnp.float32),
+            },
+        }
+
+        inner_loop = ciclo.callbacks.inner_loop(
+            "test",
+            dummy_inner_loop_fn,
+            aggregation={"stateful_metrics": "first"},
+        )
+
+        log_history, _ = inner_loop(None)
+
+        assert log_history == {
+            "stateful_metrics": {
+                "A_test": jnp.array(1.0, dtype=jnp.float32),
+                "B_test": jnp.array(1.0, dtype=jnp.float32),
+            },
+            "metrics": {
+                "C_test": jnp.array(1.0, dtype=jnp.float32),
+                "D_test": jnp.array(0.0, dtype=jnp.float32),
+            },
+        }
diff --git a/tests/test_loops.py b/tests/test_loops.py
index bea5e20..5bf2c7d 100644
--- a/tests/test_loops.py
+++ b/tests/test_loops.py
@@ -72,3 +72,41 @@ def data():
 
         assert a_list == list(range(1, 4))
         assert b_list == list(range(-1, -4, -1))
+
+    def test_inner_loop_kwargs(self):
+        def increment(state, key):
+            state[key] += 1
+            logs = ciclo.logs()
+            logs.add_metric(key, state[key])
+            return logs, state
+
+        state = {"a": 0}
+
+        state, history, _ = ciclo.train_loop(
+            state,
+            ciclo.elapse(range(1)),
+            {
+                ciclo.on_test_step: lambda state: increment(state, "a"),
+            },
+            test_dataset=lambda: ciclo.elapse(range(4)),
+            epoch_duration=1,
+            stop=1,
+        )
+
+        assert history[0]["metrics"]["a_test"] == 4
+
+        state = {"a": 0}
+
+        state, history, _ = ciclo.train_loop(
+            state,
+            ciclo.elapse(range(1)),
+            {
+                ciclo.on_test_step: lambda state: increment(state, "a"),
+            },
+            test_dataset=lambda: ciclo.elapse(range(4)),
+            epoch_duration=1,
+            stop=1,
+            inner_loop_kwargs={"aggregation": "sum"},
+        )
+
+        assert history[0]["metrics"]["a_test"] == 10
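Usage sketch (not part of the patch): a minimal example of how the new aggregation option reaches the test inner loop through train_loop's inner_loop_kwargs, using only the API already exercised by the tests above (ciclo.train_loop, ciclo.elapse, ciclo.logs, ciclo.on_test_step); the eval_step callback, the metric name, and the toy datasets are illustrative.

    import ciclo

    def eval_step(state):
        # Log one value per test step; _transpose_history collects the logged
        # values into a list, which the chosen aggregation then reduces.
        state["a"] += 1
        logs = ciclo.logs()
        logs.add_metric("a", state["a"])
        return logs, state

    state = {"a": 0}

    # Average the "metrics" collection over the test steps instead of keeping
    # only the last logged value; collections not listed fall back to "last".
    state, history, _ = ciclo.train_loop(
        state,
        ciclo.elapse(range(1)),
        {ciclo.on_test_step: eval_step},
        test_dataset=lambda: ciclo.elapse(range(4)),
        epoch_duration=1,
        stop=1,
        inner_loop_kwargs={"aggregation": {"metrics": "mean"}},
    )

    # Four test steps log 1, 2, 3, 4, so the aggregated entry is their mean.
    assert history[0]["metrics"]["a_test"] == 2.5

The default aggregation is "last", which keeps the most recent value logged for each entry.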