@@ -112,7 +112,7 @@ def realized_performance_for_binary_classification_result():
112
112
y_true = 'work_home_actual' ,
113
113
problem_type = 'classification_binary' ,
114
114
timestamp_column_name = 'timestamp' ,
115
- metrics = ['roc_auc' , 'f1' ],
115
+ metrics = ['roc_auc' , 'f1' , 'confusion_matrix' ],
116
116
).fit (reference_df )
117
117
result = calc .calculate (analysis_df .merge (analysis_targets_df , on = 'identifier' ))
118
118
return result
@@ -160,7 +160,7 @@ def cbpe_estimated_performance_for_binary_classification_result():
160
160
y_true = 'work_home_actual' ,
161
161
problem_type = 'classification_binary' ,
162
162
timestamp_column_name = 'timestamp' ,
163
- metrics = ['roc_auc' , 'f1' ],
163
+ metrics = ['roc_auc' , 'f1' , 'confusion_matrix' ],
164
164
).fit (reference_df )
165
165
result = calc .estimate (analysis_df .merge (analysis_targets_df , on = 'identifier' ))
166
166
return result
@@ -355,14 +355,14 @@ def test_pickle_file_writer_raises_no_exceptions_when_writing(result):
355
355
'data_reconstruction_feature_drift_metrics' ,
356
356
10 ,
357
357
),
358
- (lazy_fixture ('realized_performance_for_binary_classification_result' ), 'realized_performance_metrics' , 40 ),
358
+ (lazy_fixture ('realized_performance_for_binary_classification_result' ), 'realized_performance_metrics' , 120 ),
359
359
(
360
360
lazy_fixture ('realized_performance_for_multiclass_classification_result' ),
361
361
'realized_performance_metrics' ,
362
362
40 ,
363
363
),
364
364
(lazy_fixture ('realized_performance_for_regression_result' ), 'realized_performance_metrics' , 40 ),
365
- (lazy_fixture ('cbpe_estimated_performance_for_binary_classification_result' ), 'cbpe_performance_metrics' , 20 ),
365
+ (lazy_fixture ('cbpe_estimated_performance_for_binary_classification_result' ), 'cbpe_performance_metrics' , 60 ),
366
366
(
367
367
lazy_fixture ('cbpe_estimated_performance_for_multiclass_classification_result' ),
368
368
'cbpe_performance_metrics' ,
@@ -389,3 +389,33 @@ def test_database_writer_exports_correctly(result, table_name, expected_row_coun
389
389
390
390
finally :
391
391
os .remove ('test.db' )
392
+
393
+
394
+ @pytest .mark .parametrize (
395
+ 'result, table_name' ,
396
+ [
397
+ (lazy_fixture ('realized_performance_for_binary_classification_result' ), 'realized_performance_metrics' ),
398
+ (lazy_fixture ('cbpe_estimated_performance_for_binary_classification_result' ), 'cbpe_performance_metrics' ),
399
+ ],
400
+ )
401
+ def test_database_writer_deals_with_metric_components (result , table_name ):
402
+ try :
403
+ writer = DatabaseWriter (connection_string = 'sqlite:///test.db' , model_name = 'test' )
404
+ writer .write (result .filter (metrics = ['confusion_matrix' ]))
405
+
406
+ import sqlite3
407
+
408
+ with sqlite3 .connect ("test.db" , uri = True ) as db :
409
+ res = db .cursor ().execute (f"SELECT DISTINCT metric_name FROM { table_name } " ).fetchall ()
410
+ sut = [row [0 ] for row in res ]
411
+
412
+ assert 'true_positive' in sut
413
+ assert 'false_positive' in sut
414
+ assert 'true_negative' in sut
415
+ assert 'false_negative' in sut
416
+
417
+ except Exception as exc :
418
+ pytest .fail (f"an unexpected exception occurred: { exc } " )
419
+
420
+ finally :
421
+ os .remove ('test.db' )
0 commit comments