Skip to content

Commit

Permalink
Flatten dataclass hyperparameters for logging (#18906)
Browse files Browse the repository at this point in the history
Co-authored-by: jaswon <jason@jwon.xyz>
  • Loading branch information
jaswon and jaswon authored Nov 3, 2023
1 parent ed7cc27 commit 8d68607
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/lightning/fabric/utilities/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from argparse import Namespace
from dataclasses import asdict, is_dataclass
from typing import Any, Dict, Mapping, MutableMapping, Optional, Union

import numpy as np
Expand Down Expand Up @@ -88,8 +89,11 @@ def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent
result: Dict[str, Any] = {}
for k, v in params.items():
new_key = parent_key + delimiter + str(k) if parent_key else str(k)
if isinstance(v, Namespace):
if is_dataclass(v):
v = asdict(v)
elif isinstance(v, Namespace):
v = vars(v)

if isinstance(v, MutableMapping):
result = {**result, **_flatten_dict(v, parent_key=new_key, delimiter=delimiter)}
else:
Expand Down
16 changes: 16 additions & 0 deletions tests/tests_fabric/utilities/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from argparse import Namespace
from dataclasses import dataclass

import numpy as np
import torch
Expand Down Expand Up @@ -73,6 +74,21 @@ def test_flatten_dict():
assert "a" not in params
assert "b" not in params

# Test flattening of dataclass objects
@dataclass
class A:
c: int
d: int

@dataclass
class B:
a: A
b: int

params = {"params": B(a=A(c=1, d=2), b=3), "param": 4}
params = _flatten_dict(params)
assert params == {"param": 4, "params/b": 3, "params/a/c": 1, "params/a/d": 2}


def test_sanitize_callable_params():
"""Callback function are not serializiable.
Expand Down

0 comments on commit 8d68607

Please sign in to comment.