diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 384ef2229..95dd5e8c7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 # Use the ref you want to point at + rev: v5.0.0 # Use the ref you want to point at hooks: - id: trailing-whitespace - id: check-ast @@ -16,7 +16,7 @@ repos: - id: check-toml - repo: https://github.com/astral-sh/ruff-pre-commit - rev: 'v0.6.5' + rev: 'v0.9.2' hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] @@ -25,7 +25,7 @@ repos: types_or: [python, jupyter] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.11.2 + rev: v1.14.1 hooks: - id: mypy entry: python3 -m mypy --config-file pyproject.toml @@ -41,7 +41,7 @@ repos: entry: python3 -m nbstripout - repo: https://github.com/nbQA-dev/nbQA - rev: 1.8.7 + rev: 1.9.1 hooks: - id: nbqa-black - id: nbqa-ruff diff --git a/cyclops/data/df/feature.py b/cyclops/data/df/feature.py index db8c6e5a5..6ac30200e 100644 --- a/cyclops/data/df/feature.py +++ b/cyclops/data/df/feature.py @@ -60,7 +60,7 @@ def __init__(self, **kwargs: Any) -> None: if kwargs[FEATURE_TYPE_ATTR] not in FEATURE_TYPES: raise ValueError( f"""Feature type '{kwargs[FEATURE_TYPE_ATTR]}' - not in {', '.join(FEATURE_TYPES)}.""", + not in {", ".join(FEATURE_TYPES)}.""", ) # Set attributes diff --git a/cyclops/data/features/medical_image.py b/cyclops/data/features/medical_image.py index 7b7d9cae9..3dcfd9583 100644 --- a/cyclops/data/features/medical_image.py +++ b/cyclops/data/features/medical_image.py @@ -209,11 +209,14 @@ def decode_example( use_auth_token = token_per_repo_id.get(repo_id) except ValueError: use_auth_token = None - with xopen( - path, - "rb", - use_auth_token=use_auth_token, - ) as file_obj, BytesIO(file_obj.read()) as buffer: + with ( + xopen( + path, + "rb", + use_auth_token=use_auth_token, + ) as file_obj, + BytesIO(file_obj.read()) as buffer, + ): image, metadata = self._read_file_from_bytes(buffer) metadata["filename_or_obj"] = path diff --git a/cyclops/data/impute.py b/cyclops/data/impute.py index 4595e92be..6ab284faa 100644 --- a/cyclops/data/impute.py +++ b/cyclops/data/impute.py @@ -304,7 +304,7 @@ def _process_imputefunc( if imputefunc not in IMPUTEFUNCS: raise ValueError( f"""Imputefunc string {imputefunc} not supported. - Supporting: {','.join(IMPUTEFUNCS)}""", + Supporting: {",".join(IMPUTEFUNCS)}""", ) func = IMPUTEFUNCS[imputefunc] elif callable(imputefunc): diff --git a/cyclops/data/slicer.py b/cyclops/data/slicer.py index ad06a0728..92e74563a 100644 --- a/cyclops/data/slicer.py +++ b/cyclops/data/slicer.py @@ -825,8 +825,7 @@ def filter_string_contains( example_values = pa.array(examples[column_name]) if not pa.types.is_string(example_values.type): raise ValueError( - "Expected string feature, but got feature of type " - f"{example_values.type}.", + f"Expected string feature, but got feature of type {example_values.type}.", ) # get all the values that contain the given substring diff --git a/cyclops/evaluate/evaluator.py b/cyclops/evaluate/evaluator.py index cf4d01611..f7b455b3b 100644 --- a/cyclops/evaluate/evaluator.py +++ b/cyclops/evaluate/evaluator.py @@ -170,8 +170,7 @@ def _load_data( if split is None: split = choose_split(dataset, **load_dataset_kwargs) LOGGER.warning( - "Got `split=None` but `dataset` is a string. " - "Using `split=%s` instead.", + "Got `split=None` but `dataset` is a string. Using `split=%s` instead.", split, ) diff --git a/cyclops/evaluate/fairness/evaluator.py b/cyclops/evaluate/fairness/evaluator.py index 22b39d175..7f78bbea6 100644 --- a/cyclops/evaluate/fairness/evaluator.py +++ b/cyclops/evaluate/fairness/evaluator.py @@ -150,7 +150,7 @@ def evaluate_fairness( # noqa: PLR0912 # input validation and formatting if not isinstance(dataset, Dataset): raise TypeError( - "Expected `dataset` to be of type `Dataset`, but got " f"{type(dataset)}.", + f"Expected `dataset` to be of type `Dataset`, but got {type(dataset)}.", ) if array_lib not in _SUPPORTED_ARRAY_LIBS: raise NotImplementedError(f"The array library `{array_lib}` is not supported.") @@ -520,8 +520,7 @@ def _validate_group_bins( for group, bins in group_bins.items(): if not isinstance(bins, (list, int)): raise TypeError( - f"The bins for {group} must be a list or an integer. " - f"Got {type(bins)}.", + f"The bins for {group} must be a list or an integer. Got {type(bins)}.", ) if isinstance(bins, int) and not 2 <= bins < len(unique_values[group]): diff --git a/cyclops/evaluate/metrics/accuracy.py b/cyclops/evaluate/metrics/accuracy.py index 07fdc0b43..228b0ffe3 100644 --- a/cyclops/evaluate/metrics/accuracy.py +++ b/cyclops/evaluate/metrics/accuracy.py @@ -339,9 +339,9 @@ def __new__( # type: ignore # mypy expects a subclass of Accuracy zero_division=zero_division, ) if task == "multiclass": - assert ( - isinstance(num_classes, int) and num_classes > 0 - ), "Number of classes must be specified for multiclass classification." + assert isinstance(num_classes, int) and num_classes > 0, ( + "Number of classes must be specified for multiclass classification." + ) return MulticlassAccuracy( num_classes=num_classes, top_k=top_k, @@ -349,9 +349,9 @@ def __new__( # type: ignore # mypy expects a subclass of Accuracy zero_division=zero_division, ) if task == "multilabel": - assert ( - isinstance(num_labels, int) and num_labels > 0 - ), "Number of labels must be specified for multilabel classification." + assert isinstance(num_labels, int) and num_labels > 0, ( + "Number of labels must be specified for multilabel classification." + ) return MultilabelAccuracy( num_labels=num_labels, threshold=threshold, diff --git a/cyclops/evaluate/metrics/auroc.py b/cyclops/evaluate/metrics/auroc.py index 903886cee..950d4f421 100644 --- a/cyclops/evaluate/metrics/auroc.py +++ b/cyclops/evaluate/metrics/auroc.py @@ -336,18 +336,18 @@ def __new__( # type: ignore # mypy expects a subclass of AUROC if task == "binary": return BinaryAUROC(max_fpr=max_fpr, thresholds=thresholds) if task == "multiclass": - assert ( - isinstance(num_classes, int) and num_classes > 0 - ), "Number of classes must be a positive integer." + assert isinstance(num_classes, int) and num_classes > 0, ( + "Number of classes must be a positive integer." + ) return MulticlassAUROC( num_classes=num_classes, thresholds=thresholds, average=average, # type: ignore ) if task == "multilabel": - assert ( - isinstance(num_labels, int) and num_labels > 0 - ), "Number of labels must be a positive integer." + assert isinstance(num_labels, int) and num_labels > 0, ( + "Number of labels must be a positive integer." + ) return MultilabelAUROC( num_labels=num_labels, thresholds=thresholds, diff --git a/cyclops/evaluate/metrics/experimental/utils/types.py b/cyclops/evaluate/metrics/experimental/utils/types.py index e3543ddfc..96184c661 100644 --- a/cyclops/evaluate/metrics/experimental/utils/types.py +++ b/cyclops/evaluate/metrics/experimental/utils/types.py @@ -1,3 +1,5 @@ +# noqa: A005 + """Utilities for array-API compatibility.""" import builtins diff --git a/cyclops/evaluate/metrics/f_beta.py b/cyclops/evaluate/metrics/f_beta.py index 9c0ffda5d..053dd56e8 100644 --- a/cyclops/evaluate/metrics/f_beta.py +++ b/cyclops/evaluate/metrics/f_beta.py @@ -363,9 +363,9 @@ def __new__( # type: ignore # mypy expects a subclass of FbetaScore zero_division=zero_division, ) if task == "multiclass": - assert ( - isinstance(num_classes, int) and num_classes > 0 - ), "Number of classes must be specified for multiclass classification." + assert isinstance(num_classes, int) and num_classes > 0, ( + "Number of classes must be specified for multiclass classification." + ) return MulticlassFbetaScore( beta=beta, num_classes=num_classes, @@ -374,9 +374,9 @@ def __new__( # type: ignore # mypy expects a subclass of FbetaScore zero_division=zero_division, ) if task == "multilabel": - assert ( - isinstance(num_labels, int) and num_labels > 0 - ), "Number of labels must be specified for multilabel classification." + assert isinstance(num_labels, int) and num_labels > 0, ( + "Number of labels must be specified for multilabel classification." + ) return MultilabelFbetaScore( beta=beta, num_labels=num_labels, @@ -682,9 +682,9 @@ def __new__( # type: ignore # mypy expects a subclass of F1Score zero_division=zero_division, ) if task == "multiclass": - assert ( - isinstance(num_classes, int) and num_classes > 0 - ), "Number of classes must be specified for multiclass classification." + assert isinstance(num_classes, int) and num_classes > 0, ( + "Number of classes must be specified for multiclass classification." + ) return MulticlassF1Score( num_classes=num_classes, top_k=top_k, @@ -692,9 +692,9 @@ def __new__( # type: ignore # mypy expects a subclass of F1Score zero_division=zero_division, ) if task == "multilabel": - assert ( - isinstance(num_labels, int) and num_labels > 0 - ), "Number of labels must be specified for multilabel classification." + assert isinstance(num_labels, int) and num_labels > 0, ( + "Number of labels must be specified for multilabel classification." + ) return MultilabelF1Score( num_labels=num_labels, threshold=threshold, diff --git a/cyclops/evaluate/metrics/functional/accuracy.py b/cyclops/evaluate/metrics/functional/accuracy.py index 3fbaaa9cf..66eeb87c9 100644 --- a/cyclops/evaluate/metrics/functional/accuracy.py +++ b/cyclops/evaluate/metrics/functional/accuracy.py @@ -446,9 +446,9 @@ def accuracy( zero_division=zero_division, ) elif task == "multiclass": - assert ( - isinstance(num_classes, int) and num_classes > 0 - ), "Number of classes must be specified for multiclass classification." + assert isinstance(num_classes, int) and num_classes > 0, ( + "Number of classes must be specified for multiclass classification." + ) accuracy_score = multiclass_accuracy( target, preds, @@ -458,9 +458,9 @@ def accuracy( zero_division=zero_division, ) elif task == "multilabel": - assert ( - isinstance(num_labels, int) and num_labels > 0 - ), "Number of labels must be specified for multilabel classification." + assert isinstance(num_labels, int) and num_labels > 0, ( + "Number of labels must be specified for multilabel classification." + ) accuracy_score = multilabel_accuracy( target, preds, diff --git a/cyclops/evaluate/metrics/functional/auroc.py b/cyclops/evaluate/metrics/functional/auroc.py index 316d8d638..9f194de3d 100644 --- a/cyclops/evaluate/metrics/functional/auroc.py +++ b/cyclops/evaluate/metrics/functional/auroc.py @@ -574,9 +574,9 @@ def auroc( if task == "binary": return binary_auroc(target, preds, max_fpr=max_fpr, thresholds=thresholds) if task == "multiclass": - assert ( - isinstance(num_classes, int) and num_classes > 0 - ), "Number of classes must be a positive integer." + assert isinstance(num_classes, int) and num_classes > 0, ( + "Number of classes must be a positive integer." + ) return multiclass_auroc( target, preds, @@ -585,9 +585,9 @@ def auroc( average=average, # type: ignore[arg-type] ) if task == "multilabel": - assert ( - isinstance(num_labels, int) and num_labels > 0 - ), "Number of labels must be a positive integer." + assert isinstance(num_labels, int) and num_labels > 0, ( + "Number of labels must be a positive integer." + ) return multilabel_auroc( target, preds, diff --git a/cyclops/evaluate/metrics/functional/f_beta.py b/cyclops/evaluate/metrics/functional/f_beta.py index d3ce7def0..c531eaefe 100644 --- a/cyclops/evaluate/metrics/functional/f_beta.py +++ b/cyclops/evaluate/metrics/functional/f_beta.py @@ -468,9 +468,9 @@ def fbeta_score( zero_division=zero_division, ) if task == "multiclass": - assert ( - isinstance(num_classes, int) and num_classes > 0 - ), "Number of classes must be specified for multiclass classification." + assert isinstance(num_classes, int) and num_classes > 0, ( + "Number of classes must be specified for multiclass classification." + ) return multiclass_fbeta_score( target, preds, @@ -481,9 +481,9 @@ def fbeta_score( zero_division=zero_division, ) if task == "multilabel": - assert ( - isinstance(num_labels, int) and num_labels > 0 - ), "Number of labels must be specified for multilabel classification." + assert isinstance(num_labels, int) and num_labels > 0, ( + "Number of labels must be specified for multilabel classification." + ) return multilabel_fbeta_score( target, preds, diff --git a/cyclops/evaluate/metrics/functional/precision_recall.py b/cyclops/evaluate/metrics/functional/precision_recall.py index 318e416c8..a355a12b9 100644 --- a/cyclops/evaluate/metrics/functional/precision_recall.py +++ b/cyclops/evaluate/metrics/functional/precision_recall.py @@ -427,9 +427,9 @@ def precision( zero_division=zero_division, ) if task == "multiclass": - assert ( - isinstance(num_classes, int) and num_classes > 0 - ), "Number of classes must be specified for multiclass classification." + assert isinstance(num_classes, int) and num_classes > 0, ( + "Number of classes must be specified for multiclass classification." + ) return multiclass_precision( target, preds, @@ -439,9 +439,9 @@ def precision( zero_division=zero_division, ) if task == "multilabel": - assert ( - isinstance(num_labels, int) and num_labels > 0 - ), "Number of labels must be specified for multilabel classification." + assert isinstance(num_labels, int) and num_labels > 0, ( + "Number of labels must be specified for multilabel classification." + ) return multilabel_precision( target, preds, @@ -786,9 +786,9 @@ def recall( zero_division=zero_division, ) if task == "multiclass": - assert ( - isinstance(num_classes, int) and num_classes > 0 - ), "Number of classes must be specified for multiclass classification." + assert isinstance(num_classes, int) and num_classes > 0, ( + "Number of classes must be specified for multiclass classification." + ) return multiclass_recall( target, preds, @@ -798,9 +798,9 @@ def recall( zero_division=zero_division, ) if task == "multilabel": - assert ( - isinstance(num_labels, int) and num_labels > 0 - ), "Number of labels must be specified for multilabel classification." + assert isinstance(num_labels, int) and num_labels > 0, ( + "Number of labels must be specified for multilabel classification." + ) return multilabel_recall( target, preds, diff --git a/cyclops/evaluate/metrics/functional/precision_recall_curve.py b/cyclops/evaluate/metrics/functional/precision_recall_curve.py index a0f9b69e3..c982a5423 100644 --- a/cyclops/evaluate/metrics/functional/precision_recall_curve.py +++ b/cyclops/evaluate/metrics/functional/precision_recall_curve.py @@ -1056,9 +1056,9 @@ def precision_recall_curve( pos_label=pos_label, ) if task == "multiclass": - assert ( - isinstance(num_classes, int) and num_classes > 0 - ), "Number of classes must be a positive integer." + assert isinstance(num_classes, int) and num_classes > 0, ( + "Number of classes must be a positive integer." + ) return multiclass_precision_recall_curve( target, @@ -1067,9 +1067,9 @@ def precision_recall_curve( thresholds=thresholds, ) if task == "multilabel": - assert ( - isinstance(num_labels, int) and num_labels > 0 - ), "Number of labels must be a positive integer." + assert isinstance(num_labels, int) and num_labels > 0, ( + "Number of labels must be a positive integer." + ) return multilabel_precision_recall_curve( target, diff --git a/cyclops/evaluate/metrics/functional/specificity.py b/cyclops/evaluate/metrics/functional/specificity.py index b1f1822c7..4f77ba350 100644 --- a/cyclops/evaluate/metrics/functional/specificity.py +++ b/cyclops/evaluate/metrics/functional/specificity.py @@ -434,9 +434,9 @@ def specificity( zero_division=zero_division, ) if task == "multiclass": - assert ( - isinstance(num_classes, int) and num_classes > 0 - ), "Number of classes must be specified for multiclass classification." + assert isinstance(num_classes, int) and num_classes > 0, ( + "Number of classes must be specified for multiclass classification." + ) return multiclass_specificity( target, preds, @@ -446,9 +446,9 @@ def specificity( zero_division=zero_division, ) if task == "multilabel": - assert ( - isinstance(num_labels, int) and num_labels > 0 - ), "Number of labels must be specified for multilabel classification." + assert isinstance(num_labels, int) and num_labels > 0, ( + "Number of labels must be specified for multilabel classification." + ) return multilabel_specificity( target, preds, diff --git a/cyclops/evaluate/metrics/functional/stat_scores.py b/cyclops/evaluate/metrics/functional/stat_scores.py index 34a881689..fc7db1ffb 100644 --- a/cyclops/evaluate/metrics/functional/stat_scores.py +++ b/cyclops/evaluate/metrics/functional/stat_scores.py @@ -215,7 +215,7 @@ def _binary_stat_scores_format( if check: raise ValueError( f"Detected the following values in `target`: {unique_values} but" - f" expected only the following values {[0,1]}.", + f" expected only the following values {[0, 1]}.", ) # If preds is label array, also check that it only contains [0,1] values @@ -823,9 +823,9 @@ def stat_scores( threshold=threshold, ) elif task == "multiclass": - assert ( - isinstance(num_classes, int) and num_classes > 0 - ), "Number of classes must be a positive integer." + assert isinstance(num_classes, int) and num_classes > 0, ( + "Number of classes must be a positive integer." + ) scores = multiclass_stat_scores( target, preds, @@ -834,9 +834,9 @@ def stat_scores( top_k=top_k, ) elif task == "multilabel": - assert ( - isinstance(num_labels, int) and num_labels > 0 - ), "Number of labels must be a positive integer." + assert isinstance(num_labels, int) and num_labels > 0, ( + "Number of labels must be a positive integer." + ) scores = multilabel_stat_scores( target, preds, diff --git a/cyclops/evaluate/metrics/precision_recall.py b/cyclops/evaluate/metrics/precision_recall.py index 6de0f33c4..2aabf4be9 100644 --- a/cyclops/evaluate/metrics/precision_recall.py +++ b/cyclops/evaluate/metrics/precision_recall.py @@ -345,9 +345,9 @@ def __new__( # type: ignore # mypy expects a subclass of Precision zero_division=zero_division, ) if task == "multiclass": - assert ( - isinstance(num_classes, int) and num_classes > 0 - ), "Number of classes must be specified for multiclass classification." + assert isinstance(num_classes, int) and num_classes > 0, ( + "Number of classes must be specified for multiclass classification." + ) return MulticlassPrecision( num_classes=num_classes, top_k=top_k, @@ -355,9 +355,9 @@ def __new__( # type: ignore # mypy expects a subclass of Precision zero_division=zero_division, ) if task == "multilabel": - assert ( - isinstance(num_labels, int) and num_labels > 0 - ), "Number of labels must be specified for multilabel classification." + assert isinstance(num_labels, int) and num_labels > 0, ( + "Number of labels must be specified for multilabel classification." + ) return MultilabelPrecision( num_labels=num_labels, threshold=threshold, @@ -694,9 +694,9 @@ def __new__( # type: ignore # mypy expects a subclass of Recall zero_division=zero_division, ) if task == "multiclass": - assert ( - isinstance(num_classes, int) and num_classes > 0 - ), "Number of classes must be specified for multiclass classification." + assert isinstance(num_classes, int) and num_classes > 0, ( + "Number of classes must be specified for multiclass classification." + ) return MulticlassRecall( num_classes=num_classes, top_k=top_k, @@ -704,9 +704,9 @@ def __new__( # type: ignore # mypy expects a subclass of Recall zero_division=zero_division, ) if task == "multilabel": - assert ( - isinstance(num_labels, int) and num_labels > 0 - ), "Number of labels must be specified for multilabel classification." + assert isinstance(num_labels, int) and num_labels > 0, ( + "Number of labels must be specified for multilabel classification." + ) return MultilabelRecall( num_labels=num_labels, threshold=threshold, diff --git a/cyclops/evaluate/metrics/precision_recall_curve.py b/cyclops/evaluate/metrics/precision_recall_curve.py index 9a5ce76b5..e220c7d05 100644 --- a/cyclops/evaluate/metrics/precision_recall_curve.py +++ b/cyclops/evaluate/metrics/precision_recall_curve.py @@ -566,17 +566,17 @@ def __new__( # type: ignore # mypy expects a subclass of PrecisionRecallCurve pos_label=pos_label, ) if task == "multiclass": - assert ( - isinstance(num_classes, int) and num_classes > 0 - ), "Number of classes must be a positive integer." + assert isinstance(num_classes, int) and num_classes > 0, ( + "Number of classes must be a positive integer." + ) return MulticlassPrecisionRecallCurve( num_classes=num_classes, thresholds=thresholds, ) if task == "multilabel": - assert ( - isinstance(num_labels, int) and num_labels > 0 - ), "Number of labels must be a positive integer." + assert isinstance(num_labels, int) and num_labels > 0, ( + "Number of labels must be a positive integer." + ) return MultilabelPrecisionRecallCurve( num_labels=num_labels, thresholds=thresholds, diff --git a/cyclops/evaluate/metrics/sensitivity.py b/cyclops/evaluate/metrics/sensitivity.py index 5ea9ab5df..29284c630 100644 --- a/cyclops/evaluate/metrics/sensitivity.py +++ b/cyclops/evaluate/metrics/sensitivity.py @@ -302,9 +302,9 @@ def __new__( # type: ignore # mypy expects a subclass of Sensitivity zero_division=zero_division, ) if task == "multiclass": - assert ( - isinstance(num_classes, int) and num_classes > 0 - ), "Number of classes must be specified for multiclass classification." + assert isinstance(num_classes, int) and num_classes > 0, ( + "Number of classes must be specified for multiclass classification." + ) return MulticlassSensitivity( num_classes=num_classes, top_k=top_k, @@ -312,9 +312,9 @@ def __new__( # type: ignore # mypy expects a subclass of Sensitivity zero_division=zero_division, ) if task == "multilabel": - assert ( - isinstance(num_labels, int) and num_labels > 0 - ), "Number of labels must be specified for multilabel classification." + assert isinstance(num_labels, int) and num_labels > 0, ( + "Number of labels must be specified for multilabel classification." + ) return MultilabelSensitivity( num_labels=num_labels, threshold=threshold, diff --git a/cyclops/evaluate/metrics/specificity.py b/cyclops/evaluate/metrics/specificity.py index e8efabca5..efe4574b0 100644 --- a/cyclops/evaluate/metrics/specificity.py +++ b/cyclops/evaluate/metrics/specificity.py @@ -375,9 +375,9 @@ def __new__( # type: ignore # mypy expects a subclass of Specificity zero_division=zero_division, ) if task == "multiclass": - assert ( - isinstance(num_classes, int) and num_classes > 0 - ), "Number of classes must be specified for multiclass classification." + assert isinstance(num_classes, int) and num_classes > 0, ( + "Number of classes must be specified for multiclass classification." + ) return MulticlassSpecificity( num_classes=num_classes, top_k=top_k, @@ -385,9 +385,9 @@ def __new__( # type: ignore # mypy expects a subclass of Specificity zero_division=zero_division, ) if task == "multilabel": - assert ( - isinstance(num_labels, int) and num_labels > 0 - ), "Number of labels must be specified for multilabel classification." + assert isinstance(num_labels, int) and num_labels > 0, ( + "Number of labels must be specified for multilabel classification." + ) return MultilabelSpecificity( num_labels=num_labels, threshold=threshold, diff --git a/cyclops/evaluate/metrics/stat_scores.py b/cyclops/evaluate/metrics/stat_scores.py index 1b3fee8e9..13ac23fce 100644 --- a/cyclops/evaluate/metrics/stat_scores.py +++ b/cyclops/evaluate/metrics/stat_scores.py @@ -487,18 +487,18 @@ def __new__( # type: ignore # mypy expects a subclass of StatScores if task == "binary": return BinaryStatScores(threshold=threshold, pos_label=pos_label) if task == "multiclass": - assert ( - isinstance(num_classes, int) and num_classes > 0 - ), "Number of classes must be a positive integer." + assert isinstance(num_classes, int) and num_classes > 0, ( + "Number of classes must be a positive integer." + ) return MulticlassStatScores( num_classes=num_classes, top_k=top_k, classwise=classwise, ) if task == "multilabel": - assert ( - isinstance(num_labels, int) and num_labels > 0 - ), "Number of labels must be a positive integer." + assert isinstance(num_labels, int) and num_labels > 0, ( + "Number of labels must be a positive integer." + ) return MultilabelStatScores( num_labels=num_labels, threshold=threshold, diff --git a/cyclops/evaluate/utils.py b/cyclops/evaluate/utils.py index 835de6263..1bbe99d1e 100644 --- a/cyclops/evaluate/utils.py +++ b/cyclops/evaluate/utils.py @@ -159,9 +159,11 @@ def get_columns_as_array( if isinstance(columns, str): columns = [columns] - with dataset.formatted_as("arrow", columns=columns, output_all_columns=True) if ( - isinstance(dataset, Dataset) and dataset.format != "arrow" - ) else nullcontext(): + with ( + dataset.formatted_as("arrow", columns=columns, output_all_columns=True) + if (isinstance(dataset, Dataset) and dataset.format != "arrow") + else nullcontext() + ): out_arr = squeeze_all( xp.stack( [xp.asarray(dataset[col].to_pylist()) for col in columns], axis=-1 diff --git a/cyclops/models/catalog.py b/cyclops/models/catalog.py index a2571f919..082ea9cfa 100644 --- a/cyclops/models/catalog.py +++ b/cyclops/models/catalog.py @@ -39,13 +39,11 @@ LOGGER = logging.getLogger(__name__) setup_logging(print_level="WARN", logger=LOGGER) _xgboost_unavailable_message = ( - "The XGBoost library is required to use the `XGBClassifier` model. " - "Please install it as an extra using `python3 -m pip install 'pycyclops[xgboost]'`\ + "The XGBoost library is required to use the `XGBClassifier` model. Please install it as an extra using `python3 -m pip install 'pycyclops[xgboost]'`\ or using `python3 -m pip install xgboost`." ) _torchxrayvision_unavailable_message = ( - "The torchxrayvision library is required to use the `densenet` or `resnet` model. " - "Please install it as an extra using `python3 -m pip install 'pycyclops[torchxrayvision]'`\ + "The torchxrayvision library is required to use the `densenet` or `resnet` model. Please install it as an extra using `python3 -m pip install 'pycyclops[torchxrayvision]'`\ or using `python3 -m pip install torchxrayvision`." ) _torch_unavailable_message = ( diff --git a/cyclops/models/wrappers/pt_model.py b/cyclops/models/wrappers/pt_model.py index 50b3a4302..37499dec2 100644 --- a/cyclops/models/wrappers/pt_model.py +++ b/cyclops/models/wrappers/pt_model.py @@ -968,14 +968,17 @@ def fit( splits_mapping["validation"] = val_split format_kwargs = {} if transforms is None else {"transform": transforms} - with X[train_split].formatted_as( - "custom" if transforms is not None else "torch", - columns=feature_columns + target_columns, - **format_kwargs, - ), X[val_split].formatted_as( - "custom" if transforms is not None else "torch", - columns=feature_columns + target_columns, - **format_kwargs, + with ( + X[train_split].formatted_as( + "custom" if transforms is not None else "torch", + columns=feature_columns + target_columns, + **format_kwargs, + ), + X[val_split].formatted_as( + "custom" if transforms is not None else "torch", + columns=feature_columns + target_columns, + **format_kwargs, + ), ): self.partial_fit( X, @@ -1309,7 +1312,7 @@ def save_model(self, filepath: str, overwrite: bool = True, **kwargs): if include_lr_scheduler: state_dict["lr_scheduler"] = self.lr_scheduler_.state_dict() # type: ignore[attr-defined] - epoch = kwargs.get("epoch", None) + epoch = kwargs.get("epoch") if epoch is not None: filename, extension = os.path.basename(filepath).split(".") filepath = join( diff --git a/cyclops/models/wrappers/sk_model.py b/cyclops/models/wrappers/sk_model.py index b0d534f1e..07d92304e 100644 --- a/cyclops/models/wrappers/sk_model.py +++ b/cyclops/models/wrappers/sk_model.py @@ -400,7 +400,7 @@ def partial_fit( splits_mapping = {"train": "train"} if not hasattr(self.model_, "partial_fit"): raise AttributeError( - f"Model {self.model_name}" "does not have a `partial_fit` method.", + f"Model {self.model_name}does not have a `partial_fit` method.", ) # Train data is a Hugging Face Dataset Dictionary. if isinstance(X, DatasetDict): @@ -687,7 +687,7 @@ def predict_proba( splits_mapping = {"test": "test"} if not hasattr(self.model_, "predict_proba"): raise AttributeError( - f"Model {self.model_name}" "does not have a `predict_proba` method.", + f"Model {self.model_name}does not have a `predict_proba` method.", ) # Data is a Hugging Face Dataset Dictionary. if isinstance(X, DatasetDict): diff --git a/cyclops/monitor/tester.py b/cyclops/monitor/tester.py index 9a37d9d0f..ac4b1d7c5 100644 --- a/cyclops/monitor/tester.py +++ b/cyclops/monitor/tester.py @@ -1057,9 +1057,9 @@ def get_record(self, record_type): def counts(self, record_type, max_ensemble_size=None) -> np.ndarray: """Get counts.""" - assert ( - max_ensemble_size is None or max_ensemble_size > 0 - ), "max_ensemble_size must be positive or None" + assert max_ensemble_size is None or max_ensemble_size > 0, ( + "max_ensemble_size must be positive or None" + ) rec = self.get_record(record_type) counts = [] for i in rec.seed.unique(): diff --git a/cyclops/monitor/utils.py b/cyclops/monitor/utils.py index 7584462ca..d990c5096 100644 --- a/cyclops/monitor/utils.py +++ b/cyclops/monitor/utils.py @@ -258,12 +258,12 @@ def get_args(obj: Any, kwargs: Dict[str, Any]) -> Dict[str, Any]: """ args = {} - for key in kwargs: + for key, value in kwargs.items(): if (inspect.isclass(obj) and key in inspect.signature(obj).parameters) or ( (inspect.ismethod(obj) or inspect.isfunction(obj)) and key in inspect.getfullargspec(obj).args ): - args[key] = kwargs[key] + args[key] = value return args diff --git a/cyclops/report/plot/classification.py b/cyclops/report/plot/classification.py index 60da0b884..7a5e062e1 100644 --- a/cyclops/report/plot/classification.py +++ b/cyclops/report/plot/classification.py @@ -86,13 +86,13 @@ def _set_class_names(self, class_names: List[str]) -> None: """ if class_names is not None: - assert ( - len(class_names) == self.class_num - ), "class_names must be equal to class_num" + assert len(class_names) == self.class_num, ( + "class_names must be equal to class_num" + ) elif self.task_type == "multilabel": - class_names = [f"Label_{i+1}" for i in range(self.class_num)] + class_names = [f"Label_{i + 1}" for i in range(self.class_num)] else: - class_names = [f"Class_{i+1}" for i in range(self.class_num)] + class_names = [f"Class_{i + 1}" for i in range(self.class_num)] self.class_names = class_names def calibration( @@ -257,7 +257,9 @@ def threshperf( == len(roc_curve.thresholds) == len(ppv) == len(npv) - ), "Length mismatch between ROC curve, PPV, NPV. All curves need to be computed using the same thresholds" + ), ( + "Length mismatch between ROC curve, PPV, NPV. All curves need to be computed using the same thresholds" + ) # Define hover template to show three decimal places hover_template = "Threshold: %{x:.3f}
Metric Value: %{y:.3f}" # Create a subplot for each metric @@ -397,16 +399,18 @@ def roc_curve( ), ) else: - assert ( - len(fprs) == len(tprs) == self.class_num - ), "fprs and tprs must be of length class_num for \ + assert len(fprs) == len(tprs) == self.class_num, ( + "fprs and tprs must be of length class_num for \ multiclass/multilabel tasks" + ) for i in range(self.class_num): if auroc is not None: assert ( len(auroc) == self.class_num # type: ignore[arg-type] - ), "AUROCs must be of length class_num for \ + ), ( + "AUROCs must be of length class_num for \ multiclass/multilabel tasks" + ) name = f"{self.class_names[i]} (AUC = {auroc[i]:.2f})" # type: ignore[index] # noqa: E501 else: name = self.class_names[i] @@ -496,16 +500,18 @@ def roc_curve_comparison( ) else: for slice_name, slice_curve in roc_curves.items(): - assert ( - len(slice_curve[0]) == len(slice_curve[1]) == self.class_num - ), f"FPRs and TPRs must be of length class_num for \ + assert len(slice_curve[0]) == len(slice_curve[1]) == self.class_num, ( + f"FPRs and TPRs must be of length class_num for \ multiclass/multilabel tasks in slice {slice_name}" + ) for i in range(self.class_num): if aurocs and slice_name in aurocs: assert ( len(aurocs[slice_name]) == self.class_num # type: ignore[arg-type] # noqa: E501 - ), "AUROCs must be of length class_num for \ + ), ( + "AUROCs must be of length class_num for \ multiclass/multilabel tasks" + ) name = f"{slice_name}, {self.class_names[i]} \ (AUC = {aurocs[i]:.2f})" # type: ignore[index] else: @@ -583,10 +589,10 @@ def precision_recall_curve( ) else: trace = [] - assert ( - len(recalls) == len(precisions) == self.class_num - ), "Recalls and precisions must be of length class_num for \ + assert len(recalls) == len(precisions) == self.class_num, ( + "Recalls and precisions must be of length class_num for \ multiclass/multilabel tasks" + ) for i in range(self.class_num): trace.append( line_plot( @@ -668,14 +674,18 @@ def precision_recall_curve_comparison( len(slice_curve.precision) == len(slice_curve.recall) == self.class_num - ), f"Recalls and precisions must be of length class_num for \ + ), ( + f"Recalls and precisions must be of length class_num for \ multiclass/multilabel tasks in slice {slice_name}" + ) for i in range(self.class_num): if auprcs and slice_name in auprcs: assert ( len(auprcs[slice_name]) == self.class_num # type: ignore[arg-type] # noqa: E501 - ), "AUPRCs must be of length class_num for \ + ), ( + "AUPRCs must be of length class_num for \ multiclass/multilabel tasks" + ) name = f"{slice_name}, {self.class_names[i]} \ (AUC = {auprcs[i]:.2f})" else: @@ -747,8 +757,10 @@ def metrics_value( assert all( len(value) == self.class_num # type: ignore[arg-type] for value in metrics.values() - ), "Every metric must be of length class_num for \ + ), ( + "Every metric must be of length class_num for \ multiclass/multilabel tasks" + ) for i in range(self.class_num): trace.append( bar_plot( @@ -981,10 +993,10 @@ def metrics_comparison_radar( if isinstance(metric_values, list) or ( isinstance(metric_values, np.ndarray) and metric_values.ndim > 0 ): - assert ( - len(metric_values) == self.class_num - ), "Metric values must be of length class_num for \ + assert len(metric_values) == self.class_num, ( + "Metric values must be of length class_num for \ multiclass/multilabel tasks" + ) radial_data.extend(metric_values) theta = [ f"{metric_name}: {self.class_names[i]}" diff --git a/cyclops/report/report.py b/cyclops/report/report.py index 786dae53c..19f916581 100644 --- a/cyclops/report/report.py +++ b/cyclops/report/report.py @@ -668,9 +668,9 @@ def log_dataset( """ # sensitive features must be in features if features is not None and sensitive_features is not None: - assert all( - feature in features for feature in sensitive_features - ), "All sensitive features must be in the features list." + assert all(feature in features for feature in sensitive_features), ( + "All sensitive features must be in the features list." + ) # TODO: plot dataset distribution data = { diff --git a/cyclops/report/utils.py b/cyclops/report/utils.py index 94c8a556b..a4f9a3a04 100644 --- a/cyclops/report/utils.py +++ b/cyclops/report/utils.py @@ -304,9 +304,9 @@ def get_metrics_trends( slice_names=slice_names, metric_names=metric_names, ) - assert ( - len(performance_history) > 0 - ), "No performance history found. Check slice and metric names." + assert len(performance_history) > 0, ( + "No performance history found. Check slice and metric names." + ) performance_recent = [] for metric_name, metric_value in flat_results.items(): name_split = metric_name.split("/") @@ -323,9 +323,9 @@ def get_metrics_trends( slice_names=slice_names, metric_names=metric_names, ) - assert ( - len(performance_recent) > 0 - ), "No performance metrics found. Check slice and metric names." + assert len(performance_recent) > 0, ( + "No performance metrics found. Check slice and metric names." + ) today = dt_date.today().strftime("%Y-%m-%d") now = dt_datetime.now().strftime("%H-%M-%S") if keep_timestamps: diff --git a/cyclops/tasks/utils.py b/cyclops/tasks/utils.py index cb442e0bb..734227b97 100644 --- a/cyclops/tasks/utils.py +++ b/cyclops/tasks/utils.py @@ -108,8 +108,10 @@ def prepare_models( models_dict = {model_name: models} # models contains one model name elif isinstance(models, str): - assert models in list_models(), f"Model name is not registered! \ + assert models in list_models(), ( + f"Model name is not registered! \ Available models are: {list_models()}" + ) models_dict = {models: create_model(models)} # models contains a list or tuple of model names or wrapped models elif isinstance(models, (list, tuple)): @@ -118,8 +120,10 @@ def prepare_models( model_name = _model_names_mapping.get(model.model.__name__) models_dict[model_name] = model elif isinstance(model, str): - assert model in list_models(), f"Model name is not registered! \ + assert model in list_models(), ( + f"Model name is not registered! \ Available models are: {list_models()}" + ) models_dict[model] = create_model(model) else: raise TypeError( diff --git a/cyclops/utils/optional.py b/cyclops/utils/optional.py index e7795f344..07d06c240 100644 --- a/cyclops/utils/optional.py +++ b/cyclops/utils/optional.py @@ -64,8 +64,7 @@ def import_optional_module( return module except ModuleNotFoundError as exc: msg = ( - f"Missing optional dependency '{name}'. " - f"Use pip or conda to install {name}." + f"Missing optional dependency '{name}'. Use pip or conda to install {name}." ) if error == "raise": raise type(exc)(msg) from None diff --git a/cyclops/utils/profile.py b/cyclops/utils/profile.py index fd77ae7c2..1a5a9115e 100644 --- a/cyclops/utils/profile.py +++ b/cyclops/utils/profile.py @@ -1,3 +1,5 @@ +# noqa: A005 + """Useful functions for timing, profiling.""" import logging diff --git a/docs/source/examples/metrics.ipynb b/docs/source/examples/metrics.ipynb index cf2b10ba4..5e867b1f8 100644 --- a/docs/source/examples/metrics.ipynb +++ b/docs/source/examples/metrics.ipynb @@ -15,6 +15,8 @@ "metadata": {}, "outputs": [], "source": [ + "\"\"\"Imports.\"\"\"\n", + "\n", "import numpy as np\n", "import pandas as pd\n", "from datasets.arrow_dataset import Dataset\n", @@ -57,7 +59,7 @@ "outputs": [], "source": [ "df = breast_cancer_data.frame\n", - "df.describe().T" + "print(df.describe().T)" ] }, { @@ -222,7 +224,7 @@ "outputs": [], "source": [ "slice_spec = SliceSpec(spec_list, intersections=2)\n", - "slice_spec" + "print(slice_spec)" ] }, { @@ -316,15 +318,8 @@ " target_columns=\"target\",\n", " prediction_columns=\"preds_prob\",\n", ")\n", - "fairness_result" + "print(fairness_result)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/docs/source/examples/report.ipynb b/docs/source/examples/report.ipynb index 6312ba6f5..457f7572d 100644 --- a/docs/source/examples/report.ipynb +++ b/docs/source/examples/report.ipynb @@ -1,11 +1,21 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Report Generation for Heart Failure Prediction\n", + "Here's an example to demonstrate how we can generate a report as we proceed through all the steps to train and evaluate a model. For this purpose, we are going to use Kaggle's heart prediction failure dataset and gradually populate the report with information about dataset, model and results." + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ + "\"\"\"Imports.\"\"\"\n", + "\n", "import copy\n", "import inspect\n", "import os\n", @@ -23,6 +33,7 @@ "from tqdm import tqdm\n", "\n", "from cyclops.data.slicer import SliceSpec\n", + "from cyclops.evaluate import evaluator\n", "from cyclops.evaluate.metrics import create_metric\n", "from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict\n", "from cyclops.report import ModelCardReport\n", @@ -30,14 +41,6 @@ "from cyclops.report.utils import flatten_results_dict" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Report Generation for Heart Failure Prediction\n", - "Here's an example to demonstrate how we can generate a report as we proceed through all the steps to train and evaluate a model. For this purpose, we are going to use Kaggle's heart prediction failure dataset and gradually populate the report with information about dataset, model and results." - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -105,7 +108,7 @@ "metadata": {}, "outputs": [], "source": [ - "df.describe().T" + "print(df.describe().T)" ] }, { @@ -531,9 +534,6 @@ "metadata": {}, "outputs": [], "source": [ - "from cyclops.evaluate import evaluator\n", - "\n", - "\n", "# Create Dataset object\n", "heart_failure_data = Dataset.from_pandas(df_test)\n", "\n", @@ -557,24 +557,6 @@ ")" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "result" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "results_flat" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -878,7 +860,7 @@ ], "metadata": { "kernelspec": { - "display_name": "cyclops", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -892,9 +874,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.12" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/docs/source/tutorials/diabetes_130/readmission_prediction.ipynb b/docs/source/tutorials/diabetes_130/readmission_prediction.ipynb index bb7bdb3dd..61766e525 100644 --- a/docs/source/tutorials/diabetes_130/readmission_prediction.ipynb +++ b/docs/source/tutorials/diabetes_130/readmission_prediction.ipynb @@ -152,17 +152,6 @@ "variables = diabetes_130_data[\"variables\"]" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "metadata" - ] - }, { "cell_type": "code", "execution_count": null, diff --git a/docs/source/tutorials/diabetes_130/readmission_prediction_detectron.ipynb b/docs/source/tutorials/diabetes_130/readmission_prediction_detectron.ipynb index f34093832..e62cd2649 100644 --- a/docs/source/tutorials/diabetes_130/readmission_prediction_detectron.ipynb +++ b/docs/source/tutorials/diabetes_130/readmission_prediction_detectron.ipynb @@ -115,17 +115,6 @@ "variables = diabetes_130_data[\"variables\"]" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "metadata" - ] - }, { "cell_type": "code", "execution_count": null, @@ -749,7 +738,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/tests/cyclops/evaluate/metrics/experimental/test_metric.py b/tests/cyclops/evaluate/metrics/experimental/test_metric.py index ceea71694..a925112d2 100644 --- a/tests/cyclops/evaluate/metrics/experimental/test_metric.py +++ b/tests/cyclops/evaluate/metrics/experimental/test_metric.py @@ -46,9 +46,9 @@ def test_add_state_factory(): metric.add_state_default_factory("a", lambda xp: xp.asarray(0), None) # type: ignore reduce_fn = metric._reductions["a"] assert reduce_fn is None, "Saved reduction function is not None." - assert ( - metric._default_factories.get("a") is not None - ), "Default factory was not correctly created." + assert metric._default_factories.get("a") is not None, ( + "Default factory was not correctly created." + ) # default_factory is 'list' metric.add_state_default_factory("b", list) # type: ignore @@ -61,8 +61,7 @@ def test_add_state_factory(): reduce_fn = metric._reductions["c"] assert callable(reduce_fn), "Saved reduction function is not callable." assert reduce_fn is dim_zero_sum, ( - "Saved reduction function is not the same as the one used to " - "create the state." + "Saved reduction function is not the same as the one used to create the state." ) assert reduce_fn(anp.asarray([1, 1])) == anp.asarray( 2, @@ -73,8 +72,7 @@ def test_add_state_factory(): reduce_fn = metric._reductions["d"] assert callable(reduce_fn), "Saved reduction function is not callable." assert reduce_fn is dim_zero_mean, ( - "Saved reduction function is not the same as the one used to " - "create the state." + "Saved reduction function is not the same as the one used to create the state." ) assert np.allclose( reduce_fn(anp.asarray([1.0, 2.0])), @@ -86,8 +84,7 @@ def test_add_state_factory(): reduce_fn = metric._reductions["e"] assert callable(reduce_fn), "Saved reduction function is not callable." assert reduce_fn is dim_zero_cat, ( - "Saved reduction function is not the same as the one used to " - "create the state." + "Saved reduction function is not the same as the one used to create the state." ) np.testing.assert_array_equal( reduce_fn([anp.asarray([1]), anp.asarray([1])]), @@ -100,8 +97,7 @@ def test_add_state_factory(): reduce_fn = metric._reductions["f"] assert callable(reduce_fn), "Saved reduction function is not callable." assert reduce_fn is dim_zero_max, ( - "Saved reduction function is not the same as the one used to " - "create the state." + "Saved reduction function is not the same as the one used to create the state." ) np.testing.assert_array_equal( reduce_fn(anp.asarray([1, 2])), @@ -115,8 +111,7 @@ def test_add_state_factory(): reduce_fn = metric._reductions["g"] assert callable(reduce_fn), "Saved reduction function is not callable." assert reduce_fn is dim_zero_min, ( - "Saved reduction function is not the same as the one used to " - "create the state." + "Saved reduction function is not the same as the one used to create the state." ) np.testing.assert_array_equal( reduce_fn(anp.asarray([1, 2])), diff --git a/tests/cyclops/evaluate/metrics/experimental/test_metric_dict.py b/tests/cyclops/evaluate/metrics/experimental/test_metric_dict.py index 401726060..508ce2d6a 100644 --- a/tests/cyclops/evaluate/metrics/experimental/test_metric_dict.py +++ b/tests/cyclops/evaluate/metrics/experimental/test_metric_dict.py @@ -94,16 +94,16 @@ def test_metric_dict_adfix(prefix, postfix): # test __call__ output = metrics(anp.asarray(1, dtype=anp.float32)) for name in names: - assert ( - name in output - ), f"`MetricDict` output does not contain metric {name} when called." + assert name in output, ( + f"`MetricDict` output does not contain metric {name} when called." + ) # test `compute` output = metrics.compute() for name in names: - assert ( - name in output - ), f"`MetricDict` output does not contain metric {name} using the `compute` method." + assert name in output, ( + f"`MetricDict` output does not contain metric {name} using the `compute` method." + ) # test `clone` new_metrics = metrics.clone(prefix="new_") @@ -112,9 +112,9 @@ def test_metric_dict_adfix(prefix, postfix): n[len(prefix) :] if prefix is not None else n for n in names ] for name in names: - assert ( - f"new_{name}" in output - ), f"`MetricDict` output does not contain metric new_{name} when cloned." + assert f"new_{name}" in output, ( + f"`MetricDict` output does not contain metric new_{name} when cloned." + ) for k in new_metrics: assert "new_" in k diff --git a/tests/cyclops/evaluate/metrics/experimental/testers.py b/tests/cyclops/evaluate/metrics/experimental/testers.py index 4d58305d3..a0d887d93 100644 --- a/tests/cyclops/evaluate/metrics/experimental/testers.py +++ b/tests/cyclops/evaluate/metrics/experimental/testers.py @@ -85,9 +85,9 @@ def _class_impl_test( # noqa: PLR0912 t_size = target.shape[0] p_size = preds.shape[0] - assert ( - p_size == t_size - ), f"`preds` and `target` have different number of samples: {p_size} and {t_size}." + assert p_size == t_size, ( + f"`preds` and `target` have different number of samples: {p_size} and {t_size}." + ) num_batches = p_size # instantiate metric @@ -176,9 +176,9 @@ def _function_impl_test( t_size = target.shape[0] p_size = preds.shape[0] - assert ( - p_size == t_size - ), f"`preds` and `target` have different number of samples: {p_size} and {t_size}." + assert p_size == t_size, ( + f"`preds` and `target` have different number of samples: {p_size} and {t_size}." + ) metric_args = metric_args or {} metric = partial(metric_function, **metric_args) diff --git a/tests/cyclops/evaluate/metrics/experimental/utils/test_ops.py b/tests/cyclops/evaluate/metrics/experimental/utils/test_ops.py index 65e61c18b..032b53733 100644 --- a/tests/cyclops/evaluate/metrics/experimental/utils/test_ops.py +++ b/tests/cyclops/evaluate/metrics/experimental/utils/test_ops.py @@ -189,19 +189,17 @@ def test_apply_to_nested_collections(self): }, } - for k in expected_result: + for k, v in expected_result.items(): assert k in result - if isinstance(expected_result[k], dict): - for kk in expected_result[k]: + if isinstance(v, dict): + for kk in v: assert kk in result[k] - assert anp.all(expected_result[k][kk] == result[k][kk]) - elif isinstance(expected_result[k], (tuple, list)): - assert all( - anp.all(a == b) for a, b in zip(result[k], expected_result[k]) - ) + assert anp.all(v[kk] == result[k][kk]) + elif isinstance(v, (tuple, list)): + assert all(anp.all(a == b) for a, b in zip(result[k], v)) else: - assert anp.all(expected_result[k] == result[k]) + assert anp.all(v == result[k]) class TestBincount: diff --git a/tests/cyclops/evaluate/metrics/helpers.py b/tests/cyclops/evaluate/metrics/helpers.py index 724ab89a5..bc030e43b 100644 --- a/tests/cyclops/evaluate/metrics/helpers.py +++ b/tests/cyclops/evaluate/metrics/helpers.py @@ -146,5 +146,5 @@ def _assert_allclose(data_a: Any, data_b: Any, atol: float = 1e-8): _assert_allclose(data_a[key], data_b[key], atol=atol) else: raise ValueError( - f"Unknown format for comparison: {type(data_a)} and" f" {type(data_b)}", + f"Unknown format for comparison: {type(data_a)} and {type(data_b)}", )