Skip to content

Commit a96b663

Browse files
fix: use column names for database mappers (#335)
* fix: use column names for database mappers

  When writing results where the metrics include 'confusion_matrix', only the first column name is written. In the case of the confusion_matrix, it is "true_positive". The desired behaviour is to write all column values.

* Apply same practice for CBPE mapper
* Add tests dealing with result components

---------

Co-authored-by: Niels Nuyttens <niels@nannyml.com>
1 parent bb28916 commit a96b663

File tree

2 files changed

+38
-6
lines changed

2 files changed

+38
-6
lines changed

nannyml/io/db/mappers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,9 @@ def _parse(
234234

235235
res: List[DbMetric] = []
236236

237-
for metric in [metric.column_name for metric in result.metrics]:
237+
column_names = [column_name for metric in result.metrics for column_name in metric.column_names]
238+
239+
for metric in column_names:
238240
res += (
239241
result.filter(partition='analysis')
240242
.to_df()[
@@ -288,7 +290,7 @@ def _parse(
288290

289291
res: List[Metric] = []
290292

291-
for metric in [component[1] for metric in result.metrics for component in metric.components]:
293+
for metric in [column_name for metric in result.metrics for column_name in metric.column_names]:
292294
res += (
293295
result.filter(period='analysis')
294296
.to_df()[

tests/io/test_writers.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def realized_performance_for_binary_classification_result():
112112
y_true='work_home_actual',
113113
problem_type='classification_binary',
114114
timestamp_column_name='timestamp',
115-
metrics=['roc_auc', 'f1'],
115+
metrics=['roc_auc', 'f1', 'confusion_matrix'],
116116
).fit(reference_df)
117117
result = calc.calculate(analysis_df.merge(analysis_targets_df, on='identifier'))
118118
return result
@@ -160,7 +160,7 @@ def cbpe_estimated_performance_for_binary_classification_result():
160160
y_true='work_home_actual',
161161
problem_type='classification_binary',
162162
timestamp_column_name='timestamp',
163-
metrics=['roc_auc', 'f1'],
163+
metrics=['roc_auc', 'f1', 'confusion_matrix'],
164164
).fit(reference_df)
165165
result = calc.estimate(analysis_df.merge(analysis_targets_df, on='identifier'))
166166
return result
@@ -355,14 +355,14 @@ def test_pickle_file_writer_raises_no_exceptions_when_writing(result):
355355
'data_reconstruction_feature_drift_metrics',
356356
10,
357357
),
358-
(lazy_fixture('realized_performance_for_binary_classification_result'), 'realized_performance_metrics', 40),
358+
(lazy_fixture('realized_performance_for_binary_classification_result'), 'realized_performance_metrics', 120),
359359
(
360360
lazy_fixture('realized_performance_for_multiclass_classification_result'),
361361
'realized_performance_metrics',
362362
40,
363363
),
364364
(lazy_fixture('realized_performance_for_regression_result'), 'realized_performance_metrics', 40),
365-
(lazy_fixture('cbpe_estimated_performance_for_binary_classification_result'), 'cbpe_performance_metrics', 20),
365+
(lazy_fixture('cbpe_estimated_performance_for_binary_classification_result'), 'cbpe_performance_metrics', 60),
366366
(
367367
lazy_fixture('cbpe_estimated_performance_for_multiclass_classification_result'),
368368
'cbpe_performance_metrics',
@@ -389,3 +389,33 @@ def test_database_writer_exports_correctly(result, table_name, expected_row_coun
389389

390390
finally:
391391
os.remove('test.db')
392+
393+
394+
@pytest.mark.parametrize(
395+
'result, table_name',
396+
[
397+
(lazy_fixture('realized_performance_for_binary_classification_result'), 'realized_performance_metrics'),
398+
(lazy_fixture('cbpe_estimated_performance_for_binary_classification_result'), 'cbpe_performance_metrics'),
399+
],
400+
)
401+
def test_database_writer_deals_with_metric_components(result, table_name):
402+
try:
403+
writer = DatabaseWriter(connection_string='sqlite:///test.db', model_name='test')
404+
writer.write(result.filter(metrics=['confusion_matrix']))
405+
406+
import sqlite3
407+
408+
with sqlite3.connect("test.db", uri=True) as db:
409+
res = db.cursor().execute(f"SELECT DISTINCT metric_name FROM {table_name}").fetchall()
410+
sut = [row[0] for row in res]
411+
412+
assert 'true_positive' in sut
413+
assert 'false_positive' in sut
414+
assert 'true_negative' in sut
415+
assert 'false_negative' in sut
416+
417+
except Exception as exc:
418+
pytest.fail(f"an unexpected exception occurred: {exc}")
419+
420+
finally:
421+
os.remove('test.db')

0 commit comments

Comments
 (0)