- 
                Notifications
    You must be signed in to change notification settings 
- Fork 19.6k
Fix Nested Metrics Handling in CompileMetrics #21761
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|  | @@ -538,6 +538,101 @@ def test_struct_loss_namedtuple(self): | |||||||||||||||||||||||||
| value = compile_loss(y_true, y_pred) | ||||||||||||||||||||||||||
| self.assertAllClose(value, 1.07666, atol=1e-5) | ||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||
| def test_nested_dict_metrics(self): | ||||||||||||||||||||||||||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. | ||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||
| from keras.src import Input | ||||||||||||||||||||||||||
| from keras.src import Model | ||||||||||||||||||||||||||
| from keras.src import layers | ||||||||||||||||||||||||||
| from keras.src import metrics as metrics_module | ||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||
| # Create test data matching the nested structure | ||||||||||||||||||||||||||
| y_true = { | ||||||||||||||||||||||||||
| 'a': np.random.rand(10, 10), | ||||||||||||||||||||||||||
| 'b': { | ||||||||||||||||||||||||||
| 'c': np.random.rand(10, 10), | ||||||||||||||||||||||||||
| 'd': np.random.rand(10, 10) | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| y_pred = { | ||||||||||||||||||||||||||
| 'a': np.random.rand(10, 10), | ||||||||||||||||||||||||||
| 'b': { | ||||||||||||||||||||||||||
| 'c': np.random.rand(10, 10), | ||||||||||||||||||||||||||
| 'd': np.random.rand(10, 10) | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||
| # Test compiling with nested dict metrics | ||||||||||||||||||||||||||
| compile_metrics = CompileMetrics( | ||||||||||||||||||||||||||
| metrics={ | ||||||||||||||||||||||||||
| 'a': [metrics_module.MeanSquaredError()], | ||||||||||||||||||||||||||
| 'b': { | ||||||||||||||||||||||||||
| 'c': [metrics_module.MeanSquaredError(), metrics_module.MeanAbsoluteError()], | ||||||||||||||||||||||||||
| 'd': [metrics_module.MeanSquaredError()] | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||
| weighted_metrics=None, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||
| # Build the metrics | ||||||||||||||||||||||||||
| compile_metrics.build(y_true, y_pred) | ||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||
| # Update state and get results | ||||||||||||||||||||||||||
| compile_metrics.update_state(y_true, y_pred, sample_weight=None) | ||||||||||||||||||||||||||
| result = compile_metrics.result() | ||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||
| # Check that expected metrics are present | ||||||||||||||||||||||||||
| # The actual names might be different based on how MetricsList handles output names | ||||||||||||||||||||||||||
| expected_metric_names = [] | ||||||||||||||||||||||||||
| for key in result.keys(): | ||||||||||||||||||||||||||
| if 'mean_squared_error' in key or 'mean_absolute_error' in key: | ||||||||||||||||||||||||||
| expected_metric_names.append(key) | ||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||
| # At least some metrics should be computed | ||||||||||||||||||||||||||
| self.assertGreater(len(expected_metric_names), 0) | ||||||||||||||||||||||||||
| 
      Comment on lines
    
      +585
     to 
      +591
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The assertions in this test are too weak. It only checks that  
        Suggested change
       
 | ||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||
| # Test with symbolic tensors as well | ||||||||||||||||||||||||||
| y_true_symb = tree.map_structure(lambda _: backend.KerasTensor((10, 10)), y_true) | ||||||||||||||||||||||||||
| y_pred_symb = tree.map_structure(lambda _: backend.KerasTensor((10, 10)), y_pred) | ||||||||||||||||||||||||||
| compile_metrics_symbolic = CompileMetrics( | ||||||||||||||||||||||||||
| metrics={ | ||||||||||||||||||||||||||
| 'a': [metrics_module.MeanSquaredError()], | ||||||||||||||||||||||||||
| 'b': { | ||||||||||||||||||||||||||
| 'c': [metrics_module.MeanSquaredError(), metrics_module.MeanAbsoluteError()], | ||||||||||||||||||||||||||
| 'd': [metrics_module.MeanSquaredError()] | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||
| weighted_metrics=None, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| compile_metrics_symbolic.build(y_true_symb, y_pred_symb) | ||||||||||||||||||||||||||
| self.assertTrue(compile_metrics_symbolic.built) | ||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||
| def test_nested_dict_metrics(): | ||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||
| from keras.src import layers | ||||||||||||||||||||||||||
| from keras.src import Input | ||||||||||||||||||||||||||
| from keras.src import Model | ||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||
| X = np.random.rand(100, 32) | ||||||||||||||||||||||||||
| Y1 = np.random.rand(100, 10) | ||||||||||||||||||||||||||
| Y2 = np.random.rand(100, 10) | ||||||||||||||||||||||||||
| Y3 = np.random.rand(100, 10) | ||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||
| def create_model(): | ||||||||||||||||||||||||||
| x = Input(shape=(32,)) | ||||||||||||||||||||||||||
| y1 = layers.Dense(10)(x) | ||||||||||||||||||||||||||
| y2 = layers.Dense(10)(x) | ||||||||||||||||||||||||||
| y3 = layers.Dense(10)(x) | ||||||||||||||||||||||||||
| return Model(inputs=x, outputs={'a': y1, 'b': {'c': y2, 'd': y3}}) | ||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||
| model = create_model() | ||||||||||||||||||||||||||
| model.compile( | ||||||||||||||||||||||||||
| optimizer='adam', | ||||||||||||||||||||||||||
| loss={'a': 'mse', 'b': {'c': 'mse', 'd': 'mse'}}, | ||||||||||||||||||||||||||
| metrics={'a': ['mae'], 'b': {'c': 'mse', 'd': 'mae'}}, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| model.train_on_batch(X, {'a': Y1, 'b': {'c': Y2, 'd': Y3}}) | ||||||||||||||||||||||||||
| 
      Comment on lines
    
      +610
     to 
      +634
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test function  Additionally, having a standalone function with the same name as a class method in the same file is confusing. It would be better to integrate this into a test class and give it a more descriptive name, such as  | ||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||
| def test_struct_loss_invalid_path(self): | ||||||||||||||||||||||||||
| y_true = { | ||||||||||||||||||||||||||
| "a": { | ||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When handling lists/tuples in
build_recursive_metrics, usingzipcan silently truncate sequences ifmetrics_cfg,yt, andyphave different lengths. This is inconsistent with_build_metrics_setand_build_metrics_set_for_nested, which explicitly check for length mismatches and raise aValueError. For a consistent and predictable user experience, you should add a length check here to ensure the structures are compatible and raise an informative error if they are not.