keras/src/trainers/compile_utils.py (221 additions & 1 deletion)
@@ -175,6 +175,40 @@ def variables(self):
        return vars

    def build(self, y_true, y_pred):
        # Handle nested structures using tree utilities similar to CompileLoss
        if tree.is_nested(y_true) and tree.is_nested(y_pred):
            try:
                # Align the metrics configuration with the structure of y_pred
                if self._has_nested_structure(self._user_metrics) or self._has_nested_structure(self._user_weighted_metrics):
                    self._flat_metrics = self._build_nested_metrics(
                        self._user_metrics, y_true, y_pred, "metrics"
                    )
                    self._flat_weighted_metrics = self._build_nested_metrics(
                        self._user_weighted_metrics, y_true, y_pred, "weighted_metrics"
                    )
                else:
                    self._build_with_flat_structure(y_true, y_pred)
            except (ValueError, TypeError):
                self._build_with_flat_structure(y_true, y_pred)
        else:
            self._build_with_flat_structure(y_true, y_pred)
        self.built = True

    def _has_nested_structure(self, obj):
        """Helper method to check if object has nested dict/list structure."""
        if obj is None:
            return False
        if isinstance(obj, dict):
            for value in obj.values():
                if isinstance(value, (dict, list)):
                    return True
        elif isinstance(obj, list):
            for item in obj:
                if isinstance(item, (dict, list)):
                    return True
        return False

    def _build_with_flat_structure(self, y_true, y_pred):
        num_outputs = 1  # default
        # Resolve output names. If y_pred is a dict, prefer its keys.
        if isinstance(y_pred, dict):
@@ -219,7 +253,193 @@ def build(self, y_true, y_pred):
                y_pred,
                argument_name="weighted_metrics",
            )
        self.built = True

    def _build_nested_metrics(self, metrics_config, y_true, y_pred, argument_name):
        """Build metrics for nested structures following y_pred structure."""
        if metrics_config is None:
            # If metrics_config is None, create None placeholders for each output
            return self._build_flat_placeholders(y_true, y_pred)

        if (isinstance(metrics_config, dict) and
                isinstance(y_pred, dict) and
                set(metrics_config.keys()).issubset(set(y_pred.keys())) and
                not any(tree.is_nested(v) for v in y_pred.values())):

            return self._build_metrics_set_for_nested(metrics_config, y_true, y_pred, argument_name)

        # Handle metrics configuration with tree structure similar to y_pred
        def build_recursive_metrics(metrics_cfg, yt, yp, path=(), is_nested_path=False):
            """Recursively build metrics for nested structures."""
            if isinstance(metrics_cfg, dict) and isinstance(yp, dict):
                # Both metrics and predictions are dicts, process recursively
                flat_metrics = []
                for key in yp.keys():
                    current_path = path + (key,)
                    if key in metrics_cfg:
                        if isinstance(yp[key], dict) and isinstance(metrics_cfg[key], dict):
                            flat_metrics.extend(build_recursive_metrics(metrics_cfg[key], yt[key], yp[key], current_path, True))
                        elif isinstance(yp[key], (list, tuple)) and isinstance(metrics_cfg[key], (list, tuple)):
                            flat_metrics.extend(build_recursive_metrics(metrics_cfg[key], yt[key], yp[key], current_path, True))
                        else:
                            output_name = "_".join(map(str, current_path)) if is_nested_path else None
                            flat_metrics.append(self._build_single_output_metrics(metrics_cfg[key], yt[key], yp[key], argument_name, output_name=output_name))
                    else:
                        flat_metrics.append(None)
                return flat_metrics
            elif isinstance(metrics_cfg, (list, tuple)) and isinstance(yp, (list, tuple)):

                flat_metrics = []
                for i, (m_cfg, y_t_elem, y_p_elem) in enumerate(zip(metrics_cfg, yt, yp)):
Contributor comment (medium): When handling lists/tuples in build_recursive_metrics, using zip can silently truncate sequences if metrics_cfg, yt, and yp have different lengths. This is inconsistent with _build_metrics_set and _build_metrics_set_for_nested, which explicitly check for length mismatches and raise a ValueError. For a consistent and predictable user experience, you should add a length check here to ensure the structures are compatible and raise an informative error if they are not.
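A minimal sketch of the kind of guard the comment describes, written as a standalone helper for illustration only; the helper name, message wording, and where it would be called are assumptions, not part of this PR:

    def _check_structures_same_length(metrics_cfg, yt, yp, argument_name, path):
        # Hypothetical helper (not in the PR): fail loudly instead of letting
        # zip() silently truncate when the three sequences differ in length.
        if not (len(metrics_cfg) == len(yt) == len(yp)):
            raise ValueError(
                f"At path {path}, the `{argument_name}` list has "
                f"{len(metrics_cfg)} entries, but the corresponding output "
                f"structure has {len(yp)} outputs and {len(yt)} targets."
            )

Such a check would run just before the zip in the loop below, mirroring the length validation already done in _build_metrics_set_for_nested.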

                    current_path = path + (i,)
                    if isinstance(y_p_elem, (dict, list, tuple)) and isinstance(m_cfg, (dict, list, tuple)):
                        flat_metrics.extend(build_recursive_metrics(m_cfg, y_t_elem, y_p_elem, current_path, True))
                    else:
                        output_name = "_".join(map(str, current_path)) if is_nested_path else None
                        flat_metrics.append(self._build_single_output_metrics(m_cfg, y_t_elem, y_p_elem, argument_name, output_name=output_name))
                return flat_metrics
            else:
                output_name = "_".join(map(str, path)) if path and is_nested_path else None
                return [self._build_single_output_metrics(metrics_cfg, yt, yp, argument_name, output_name=output_name)]

        # For truly complex nested structures, use recursive approach
        return build_recursive_metrics(metrics_config, y_true, y_pred)

    def _build_single_output_metrics(self, metric_config, y_true, y_pred, argument_name, output_name=None):
        """Build metrics for a single output."""
        if metric_config is None:
            return None
        elif not isinstance(metric_config, list):
            metric_config = [metric_config]
        if not all(is_function_like(m) for m in metric_config):
            raise ValueError(
                f"All entries in the sublists of the "
                f"`{argument_name}` structure should be metric objects. "
                f"Found the following with unknown types: {metric_config}"
            )
        return MetricsList(
            [
                get_metric(m, y_true, y_pred)
                for m in metric_config
                if m is not None
            ],
            output_name=output_name
        )

    def _build_flat_placeholders(self, y_true, y_pred):
        """Create None placeholders for each output when config is None."""
        flat_y_pred = tree.flatten(y_pred)
        return [None] * len(flat_y_pred)

    def _build_metrics_set_for_nested(self, metrics, y_true, y_pred, argument_name):
        """Alternative method to build metrics when we detect nested structures."""
        flat_y_pred = tree.flatten(y_pred)
        flat_y_true = tree.flatten(y_true)

        if isinstance(y_pred, dict):
            flat_output_names = tree.flatten(y_pred)
            output_names = self._flatten_dict_keys(y_pred)
        else:
            output_names = [None] * len(flat_y_pred) if self.output_names is None else self.output_names

        # If metrics is a flat dict that should map to the outputs
        if isinstance(metrics, dict):
            flat_metrics = []
            if isinstance(y_pred, dict):
                # Map metrics dict to y_pred dict keys
                for idx, (name, yt, yp) in enumerate(zip(y_pred.keys(), flat_y_true, flat_y_pred)):
                    if name in metrics:
                        metric_list = metrics[name]
                        if not isinstance(metric_list, list):
                            metric_list = [metric_list]
                        if not all(is_function_like(e) for e in metric_list):
                            raise ValueError(
                                f"All entries in the sublists of the "
                                f"`{argument_name}` dict should be metric objects. "
                                f"At key '{name}', found the following with unknown types: {metric_list}"
                            )
                        flat_metrics.append(
                            MetricsList(
                                [
                                    get_metric(m, yt, yp)
                                    for m in metric_list
                                    if m is not None
                                ],
                                output_name=name,
                            )
                        )
                    else:
                        flat_metrics.append(None)
            else:
                return self._build_metrics_set(metrics, len(flat_y_pred), output_names, flat_y_true, flat_y_pred, argument_name)
        elif isinstance(metrics, (list, tuple)):
            # Handle list/tuple case for nested outputs
            if len(metrics) != len(flat_y_pred):
                raise ValueError(
                    f"For a model with multiple outputs, "
                    f"when providing the `{argument_name}` argument as a "
                    f"list, it should have as many entries as the model has "
                    f"outputs. Received:\n{argument_name}={metrics}\nof "
                    f"length {len(metrics)} whereas the model has "
                    f"{len(flat_y_pred)} outputs."
                )
            flat_metrics = []
            for idx, (mls, yt, yp) in enumerate(zip(metrics, flat_y_true, flat_y_pred)):
                if not isinstance(mls, list):
                    mls = [mls]
                name = output_names[idx] if output_names and idx < len(output_names) else None
                if not all(is_function_like(e) for e in mls):
                    raise ValueError(
                        f"All entries in the sublists of the "
                        f"`{argument_name}` list should be metric objects. "
                        f"Found the following sublist with unknown types: {mls}"
                    )
                flat_metrics.append(
                    MetricsList(
                        [
                            get_metric(m, yt, yp)
                            for m in mls
                            if m is not None
                        ],
                        output_name=name,
                    )
                )
        else:
            # Handle single metric applied to all outputs
            flat_metrics = []
            for idx, (yt, yp) in enumerate(zip(flat_y_true, flat_y_pred)):
                name = output_names[idx] if output_names and idx < len(output_names) else None
                if metrics is None:
                    flat_metrics.append(None)
                else:
                    if not is_function_like(metrics):
                        raise ValueError(
                            f"Expected all entries in the `{argument_name}` list "
                            f"to be metric objects. Received instead:\n"
                            f"{argument_name}={metrics}"
                        )
                    flat_metrics.append(
                        MetricsList(
                            [get_metric(metrics, yt, yp)],
                            output_name=name,
                        )
                    )

        return flat_metrics

    def _flatten_dict_keys(self, d):
        """Flatten dict to get key names in order."""
        if isinstance(d, dict):
            return list(d.keys())
        elif isinstance(d, (list, tuple)):
            result = []
            for item in d:
                if isinstance(item, dict):
                    result.extend(list(item.keys()))
                else:
                    result.append(None)
            return result
        else:
            return [None]

    def _build_metrics_set(
        self, metrics, num_outputs, output_names, y_true, y_pred, argument_name
keras/src/trainers/compile_utils_test.py (95 additions & 0 deletions)
@@ -538,6 +538,101 @@ def test_struct_loss_namedtuple(self):
        value = compile_loss(y_true, y_pred)
        self.assertAllClose(value, 1.07666, atol=1e-5)

    def test_nested_dict_metrics(self):
Contributor comment (medium): The test method test_nested_dict_metrics is defined within the TestCompileLoss class, but it is testing the functionality of CompileMetrics. For better code organization and clarity, this test should be moved to the TestCompileMetrics class.

        import numpy as np
        from keras.src import Input
        from keras.src import Model
        from keras.src import layers
        from keras.src import metrics as metrics_module

        # Create test data matching the nested structure
        y_true = {
            'a': np.random.rand(10, 10),
            'b': {
                'c': np.random.rand(10, 10),
                'd': np.random.rand(10, 10)
            }
        }
        y_pred = {
            'a': np.random.rand(10, 10),
            'b': {
                'c': np.random.rand(10, 10),
                'd': np.random.rand(10, 10)
            }
        }

        # Test compiling with nested dict metrics
        compile_metrics = CompileMetrics(
            metrics={
                'a': [metrics_module.MeanSquaredError()],
                'b': {
                    'c': [metrics_module.MeanSquaredError(), metrics_module.MeanAbsoluteError()],
                    'd': [metrics_module.MeanSquaredError()]
                }
            },
            weighted_metrics=None,
        )

        # Build the metrics
        compile_metrics.build(y_true, y_pred)

        # Update state and get results
        compile_metrics.update_state(y_true, y_pred, sample_weight=None)
        result = compile_metrics.result()

        # Check that expected metrics are present
        # The actual names might be different based on how MetricsList handles output names
        expected_metric_names = []
        for key in result.keys():
            if 'mean_squared_error' in key or 'mean_absolute_error' in key:
                expected_metric_names.append(key)

        # At least some metrics should be computed
        self.assertGreater(len(expected_metric_names), 0)
Contributor comment on lines +585 to +591 (medium): The assertions in this test are too weak. It only checks that len(expected_metric_names) is greater than zero, which doesn't confirm that the correct metrics are created with the correct names for the nested structure. The test should be more specific by asserting the presence of expected metric names (e.g., a_mean_squared_error, b_c_mean_squared_error). This would provide a much stronger validation of the implementation.

Suggested change:
-        expected_metric_names = []
-        for key in result.keys():
-            if 'mean_squared_error' in key or 'mean_absolute_error' in key:
-                expected_metric_names.append(key)
-        # At least some metrics should be computed
-        self.assertGreater(len(expected_metric_names), 0)
+        self.assertIn("a_mean_squared_error", result)
+        self.assertIn("b_c_mean_squared_error", result)
+        self.assertIn("b_c_mean_absolute_error", result)
+        self.assertIn("b_d_mean_squared_error", result)
+        self.assertEqual(len(result), 4)


        # Test with symbolic tensors as well
        y_true_symb = tree.map_structure(lambda _: backend.KerasTensor((10, 10)), y_true)
        y_pred_symb = tree.map_structure(lambda _: backend.KerasTensor((10, 10)), y_pred)
        compile_metrics_symbolic = CompileMetrics(
            metrics={
                'a': [metrics_module.MeanSquaredError()],
                'b': {
                    'c': [metrics_module.MeanSquaredError(), metrics_module.MeanAbsoluteError()],
                    'd': [metrics_module.MeanSquaredError()]
                }
            },
            weighted_metrics=None,
        )
        compile_metrics_symbolic.build(y_true_symb, y_pred_symb)
        self.assertTrue(compile_metrics_symbolic.built)


def test_nested_dict_metrics():
    import numpy as np
    from keras.src import layers
    from keras.src import Input
    from keras.src import Model

    X = np.random.rand(100, 32)
    Y1 = np.random.rand(100, 10)
    Y2 = np.random.rand(100, 10)
    Y3 = np.random.rand(100, 10)

    def create_model():
        x = Input(shape=(32,))
        y1 = layers.Dense(10)(x)
        y2 = layers.Dense(10)(x)
        y3 = layers.Dense(10)(x)
        return Model(inputs=x, outputs={'a': y1, 'b': {'c': y2, 'd': y3}})

    model = create_model()
    model.compile(
        optimizer='adam',
        loss={'a': 'mse', 'b': {'c': 'mse', 'd': 'mse'}},
        metrics={'a': ['mae'], 'b': {'c': 'mse', 'd': 'mae'}},
    )
    model.train_on_batch(X, {'a': Y1, 'b': {'c': Y2, 'd': Y3}})
Contributor comment on lines +610 to +634 (high): This test function test_nested_dict_metrics lacks assertions, which means it only verifies that the code runs without raising an exception. A test should validate the output or behavior. Please add assertions to check the metrics returned by train_on_batch or the model.metrics_names attribute after compilation.

Additionally, having a standalone function with the same name as a class method in the same file is confusing. It would be better to integrate this into a test class and give it a more descriptive name, such as test_nested_metrics_with_model_compile.
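A rough sketch of assertions along the lines the comment asks for, continuing the test function above; return_dict=True is a standard train_on_batch option, but the metric key names for nested outputs ('a_mae', 'b_c_mse', 'b_d_mae') are assumptions about this PR's naming scheme, not verified values:

    # Hypothetical follow-up to the model.train_on_batch(...) call above:
    # with return_dict=True, train_on_batch returns a dict of loss/metric
    # values keyed by name, which can then be asserted on directly.
    logs = model.train_on_batch(
        X, {'a': Y1, 'b': {'c': Y2, 'd': Y3}}, return_dict=True
    )
    assert "loss" in logs
    # Assumed metric keys derived from the nested output paths:
    for key in ("a_mae", "b_c_mse", "b_d_mae"):
        assert key in logs, f"missing metric '{key}' in {sorted(logs)}"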


    def test_struct_loss_invalid_path(self):
        y_true = {
            "a": {