diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 384ef2229..95dd5e8c7 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.6.0 # Use the ref you want to point at
+ rev: v5.0.0 # Use the ref you want to point at
hooks:
- id: trailing-whitespace
- id: check-ast
@@ -16,7 +16,7 @@ repos:
- id: check-toml
- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: 'v0.6.5'
+ rev: 'v0.9.2'
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
@@ -25,7 +25,7 @@ repos:
types_or: [python, jupyter]
- repo: https://github.com/pre-commit/mirrors-mypy
- rev: v1.11.2
+ rev: v1.14.1
hooks:
- id: mypy
entry: python3 -m mypy --config-file pyproject.toml
@@ -41,7 +41,7 @@ repos:
entry: python3 -m nbstripout
- repo: https://github.com/nbQA-dev/nbQA
- rev: 1.8.7
+ rev: 1.9.1
hooks:
- id: nbqa-black
- id: nbqa-ruff
diff --git a/cyclops/data/df/feature.py b/cyclops/data/df/feature.py
index db8c6e5a5..6ac30200e 100644
--- a/cyclops/data/df/feature.py
+++ b/cyclops/data/df/feature.py
@@ -60,7 +60,7 @@ def __init__(self, **kwargs: Any) -> None:
if kwargs[FEATURE_TYPE_ATTR] not in FEATURE_TYPES:
raise ValueError(
f"""Feature type '{kwargs[FEATURE_TYPE_ATTR]}'
- not in {', '.join(FEATURE_TYPES)}.""",
+ not in {", ".join(FEATURE_TYPES)}.""",
)
# Set attributes
diff --git a/cyclops/data/features/medical_image.py b/cyclops/data/features/medical_image.py
index 7b7d9cae9..3dcfd9583 100644
--- a/cyclops/data/features/medical_image.py
+++ b/cyclops/data/features/medical_image.py
@@ -209,11 +209,14 @@ def decode_example(
use_auth_token = token_per_repo_id.get(repo_id)
except ValueError:
use_auth_token = None
- with xopen(
- path,
- "rb",
- use_auth_token=use_auth_token,
- ) as file_obj, BytesIO(file_obj.read()) as buffer:
+ with (
+ xopen(
+ path,
+ "rb",
+ use_auth_token=use_auth_token,
+ ) as file_obj,
+ BytesIO(file_obj.read()) as buffer,
+ ):
image, metadata = self._read_file_from_bytes(buffer)
metadata["filename_or_obj"] = path
diff --git a/cyclops/data/impute.py b/cyclops/data/impute.py
index 4595e92be..6ab284faa 100644
--- a/cyclops/data/impute.py
+++ b/cyclops/data/impute.py
@@ -304,7 +304,7 @@ def _process_imputefunc(
if imputefunc not in IMPUTEFUNCS:
raise ValueError(
f"""Imputefunc string {imputefunc} not supported.
- Supporting: {','.join(IMPUTEFUNCS)}""",
+ Supporting: {",".join(IMPUTEFUNCS)}""",
)
func = IMPUTEFUNCS[imputefunc]
elif callable(imputefunc):
diff --git a/cyclops/data/slicer.py b/cyclops/data/slicer.py
index ad06a0728..92e74563a 100644
--- a/cyclops/data/slicer.py
+++ b/cyclops/data/slicer.py
@@ -825,8 +825,7 @@ def filter_string_contains(
example_values = pa.array(examples[column_name])
if not pa.types.is_string(example_values.type):
raise ValueError(
- "Expected string feature, but got feature of type "
- f"{example_values.type}.",
+ f"Expected string feature, but got feature of type {example_values.type}.",
)
# get all the values that contain the given substring
diff --git a/cyclops/evaluate/evaluator.py b/cyclops/evaluate/evaluator.py
index cf4d01611..f7b455b3b 100644
--- a/cyclops/evaluate/evaluator.py
+++ b/cyclops/evaluate/evaluator.py
@@ -170,8 +170,7 @@ def _load_data(
if split is None:
split = choose_split(dataset, **load_dataset_kwargs)
LOGGER.warning(
- "Got `split=None` but `dataset` is a string. "
- "Using `split=%s` instead.",
+ "Got `split=None` but `dataset` is a string. Using `split=%s` instead.",
split,
)
diff --git a/cyclops/evaluate/fairness/evaluator.py b/cyclops/evaluate/fairness/evaluator.py
index 22b39d175..7f78bbea6 100644
--- a/cyclops/evaluate/fairness/evaluator.py
+++ b/cyclops/evaluate/fairness/evaluator.py
@@ -150,7 +150,7 @@ def evaluate_fairness( # noqa: PLR0912
# input validation and formatting
if not isinstance(dataset, Dataset):
raise TypeError(
- "Expected `dataset` to be of type `Dataset`, but got " f"{type(dataset)}.",
+ f"Expected `dataset` to be of type `Dataset`, but got {type(dataset)}.",
)
if array_lib not in _SUPPORTED_ARRAY_LIBS:
raise NotImplementedError(f"The array library `{array_lib}` is not supported.")
@@ -520,8 +520,7 @@ def _validate_group_bins(
for group, bins in group_bins.items():
if not isinstance(bins, (list, int)):
raise TypeError(
- f"The bins for {group} must be a list or an integer. "
- f"Got {type(bins)}.",
+ f"The bins for {group} must be a list or an integer. Got {type(bins)}.",
)
if isinstance(bins, int) and not 2 <= bins < len(unique_values[group]):
diff --git a/cyclops/evaluate/metrics/accuracy.py b/cyclops/evaluate/metrics/accuracy.py
index 07fdc0b43..228b0ffe3 100644
--- a/cyclops/evaluate/metrics/accuracy.py
+++ b/cyclops/evaluate/metrics/accuracy.py
@@ -339,9 +339,9 @@ def __new__( # type: ignore # mypy expects a subclass of Accuracy
zero_division=zero_division,
)
if task == "multiclass":
- assert (
- isinstance(num_classes, int) and num_classes > 0
- ), "Number of classes must be specified for multiclass classification."
+ assert isinstance(num_classes, int) and num_classes > 0, (
+ "Number of classes must be specified for multiclass classification."
+ )
return MulticlassAccuracy(
num_classes=num_classes,
top_k=top_k,
@@ -349,9 +349,9 @@ def __new__( # type: ignore # mypy expects a subclass of Accuracy
zero_division=zero_division,
)
if task == "multilabel":
- assert (
- isinstance(num_labels, int) and num_labels > 0
- ), "Number of labels must be specified for multilabel classification."
+ assert isinstance(num_labels, int) and num_labels > 0, (
+ "Number of labels must be specified for multilabel classification."
+ )
return MultilabelAccuracy(
num_labels=num_labels,
threshold=threshold,
diff --git a/cyclops/evaluate/metrics/auroc.py b/cyclops/evaluate/metrics/auroc.py
index 903886cee..950d4f421 100644
--- a/cyclops/evaluate/metrics/auroc.py
+++ b/cyclops/evaluate/metrics/auroc.py
@@ -336,18 +336,18 @@ def __new__( # type: ignore # mypy expects a subclass of AUROC
if task == "binary":
return BinaryAUROC(max_fpr=max_fpr, thresholds=thresholds)
if task == "multiclass":
- assert (
- isinstance(num_classes, int) and num_classes > 0
- ), "Number of classes must be a positive integer."
+ assert isinstance(num_classes, int) and num_classes > 0, (
+ "Number of classes must be a positive integer."
+ )
return MulticlassAUROC(
num_classes=num_classes,
thresholds=thresholds,
average=average, # type: ignore
)
if task == "multilabel":
- assert (
- isinstance(num_labels, int) and num_labels > 0
- ), "Number of labels must be a positive integer."
+ assert isinstance(num_labels, int) and num_labels > 0, (
+ "Number of labels must be a positive integer."
+ )
return MultilabelAUROC(
num_labels=num_labels,
thresholds=thresholds,
diff --git a/cyclops/evaluate/metrics/experimental/utils/types.py b/cyclops/evaluate/metrics/experimental/utils/types.py
index e3543ddfc..96184c661 100644
--- a/cyclops/evaluate/metrics/experimental/utils/types.py
+++ b/cyclops/evaluate/metrics/experimental/utils/types.py
@@ -1,3 +1,5 @@
+# noqa: A005
+
"""Utilities for array-API compatibility."""
import builtins
diff --git a/cyclops/evaluate/metrics/f_beta.py b/cyclops/evaluate/metrics/f_beta.py
index 9c0ffda5d..053dd56e8 100644
--- a/cyclops/evaluate/metrics/f_beta.py
+++ b/cyclops/evaluate/metrics/f_beta.py
@@ -363,9 +363,9 @@ def __new__( # type: ignore # mypy expects a subclass of FbetaScore
zero_division=zero_division,
)
if task == "multiclass":
- assert (
- isinstance(num_classes, int) and num_classes > 0
- ), "Number of classes must be specified for multiclass classification."
+ assert isinstance(num_classes, int) and num_classes > 0, (
+ "Number of classes must be specified for multiclass classification."
+ )
return MulticlassFbetaScore(
beta=beta,
num_classes=num_classes,
@@ -374,9 +374,9 @@ def __new__( # type: ignore # mypy expects a subclass of FbetaScore
zero_division=zero_division,
)
if task == "multilabel":
- assert (
- isinstance(num_labels, int) and num_labels > 0
- ), "Number of labels must be specified for multilabel classification."
+ assert isinstance(num_labels, int) and num_labels > 0, (
+ "Number of labels must be specified for multilabel classification."
+ )
return MultilabelFbetaScore(
beta=beta,
num_labels=num_labels,
@@ -682,9 +682,9 @@ def __new__( # type: ignore # mypy expects a subclass of F1Score
zero_division=zero_division,
)
if task == "multiclass":
- assert (
- isinstance(num_classes, int) and num_classes > 0
- ), "Number of classes must be specified for multiclass classification."
+ assert isinstance(num_classes, int) and num_classes > 0, (
+ "Number of classes must be specified for multiclass classification."
+ )
return MulticlassF1Score(
num_classes=num_classes,
top_k=top_k,
@@ -692,9 +692,9 @@ def __new__( # type: ignore # mypy expects a subclass of F1Score
zero_division=zero_division,
)
if task == "multilabel":
- assert (
- isinstance(num_labels, int) and num_labels > 0
- ), "Number of labels must be specified for multilabel classification."
+ assert isinstance(num_labels, int) and num_labels > 0, (
+ "Number of labels must be specified for multilabel classification."
+ )
return MultilabelF1Score(
num_labels=num_labels,
threshold=threshold,
diff --git a/cyclops/evaluate/metrics/functional/accuracy.py b/cyclops/evaluate/metrics/functional/accuracy.py
index 3fbaaa9cf..66eeb87c9 100644
--- a/cyclops/evaluate/metrics/functional/accuracy.py
+++ b/cyclops/evaluate/metrics/functional/accuracy.py
@@ -446,9 +446,9 @@ def accuracy(
zero_division=zero_division,
)
elif task == "multiclass":
- assert (
- isinstance(num_classes, int) and num_classes > 0
- ), "Number of classes must be specified for multiclass classification."
+ assert isinstance(num_classes, int) and num_classes > 0, (
+ "Number of classes must be specified for multiclass classification."
+ )
accuracy_score = multiclass_accuracy(
target,
preds,
@@ -458,9 +458,9 @@ def accuracy(
zero_division=zero_division,
)
elif task == "multilabel":
- assert (
- isinstance(num_labels, int) and num_labels > 0
- ), "Number of labels must be specified for multilabel classification."
+ assert isinstance(num_labels, int) and num_labels > 0, (
+ "Number of labels must be specified for multilabel classification."
+ )
accuracy_score = multilabel_accuracy(
target,
preds,
diff --git a/cyclops/evaluate/metrics/functional/auroc.py b/cyclops/evaluate/metrics/functional/auroc.py
index 316d8d638..9f194de3d 100644
--- a/cyclops/evaluate/metrics/functional/auroc.py
+++ b/cyclops/evaluate/metrics/functional/auroc.py
@@ -574,9 +574,9 @@ def auroc(
if task == "binary":
return binary_auroc(target, preds, max_fpr=max_fpr, thresholds=thresholds)
if task == "multiclass":
- assert (
- isinstance(num_classes, int) and num_classes > 0
- ), "Number of classes must be a positive integer."
+ assert isinstance(num_classes, int) and num_classes > 0, (
+ "Number of classes must be a positive integer."
+ )
return multiclass_auroc(
target,
preds,
@@ -585,9 +585,9 @@ def auroc(
average=average, # type: ignore[arg-type]
)
if task == "multilabel":
- assert (
- isinstance(num_labels, int) and num_labels > 0
- ), "Number of labels must be a positive integer."
+ assert isinstance(num_labels, int) and num_labels > 0, (
+ "Number of labels must be a positive integer."
+ )
return multilabel_auroc(
target,
preds,
diff --git a/cyclops/evaluate/metrics/functional/f_beta.py b/cyclops/evaluate/metrics/functional/f_beta.py
index d3ce7def0..c531eaefe 100644
--- a/cyclops/evaluate/metrics/functional/f_beta.py
+++ b/cyclops/evaluate/metrics/functional/f_beta.py
@@ -468,9 +468,9 @@ def fbeta_score(
zero_division=zero_division,
)
if task == "multiclass":
- assert (
- isinstance(num_classes, int) and num_classes > 0
- ), "Number of classes must be specified for multiclass classification."
+ assert isinstance(num_classes, int) and num_classes > 0, (
+ "Number of classes must be specified for multiclass classification."
+ )
return multiclass_fbeta_score(
target,
preds,
@@ -481,9 +481,9 @@ def fbeta_score(
zero_division=zero_division,
)
if task == "multilabel":
- assert (
- isinstance(num_labels, int) and num_labels > 0
- ), "Number of labels must be specified for multilabel classification."
+ assert isinstance(num_labels, int) and num_labels > 0, (
+ "Number of labels must be specified for multilabel classification."
+ )
return multilabel_fbeta_score(
target,
preds,
diff --git a/cyclops/evaluate/metrics/functional/precision_recall.py b/cyclops/evaluate/metrics/functional/precision_recall.py
index 318e416c8..a355a12b9 100644
--- a/cyclops/evaluate/metrics/functional/precision_recall.py
+++ b/cyclops/evaluate/metrics/functional/precision_recall.py
@@ -427,9 +427,9 @@ def precision(
zero_division=zero_division,
)
if task == "multiclass":
- assert (
- isinstance(num_classes, int) and num_classes > 0
- ), "Number of classes must be specified for multiclass classification."
+ assert isinstance(num_classes, int) and num_classes > 0, (
+ "Number of classes must be specified for multiclass classification."
+ )
return multiclass_precision(
target,
preds,
@@ -439,9 +439,9 @@ def precision(
zero_division=zero_division,
)
if task == "multilabel":
- assert (
- isinstance(num_labels, int) and num_labels > 0
- ), "Number of labels must be specified for multilabel classification."
+ assert isinstance(num_labels, int) and num_labels > 0, (
+ "Number of labels must be specified for multilabel classification."
+ )
return multilabel_precision(
target,
preds,
@@ -786,9 +786,9 @@ def recall(
zero_division=zero_division,
)
if task == "multiclass":
- assert (
- isinstance(num_classes, int) and num_classes > 0
- ), "Number of classes must be specified for multiclass classification."
+ assert isinstance(num_classes, int) and num_classes > 0, (
+ "Number of classes must be specified for multiclass classification."
+ )
return multiclass_recall(
target,
preds,
@@ -798,9 +798,9 @@ def recall(
zero_division=zero_division,
)
if task == "multilabel":
- assert (
- isinstance(num_labels, int) and num_labels > 0
- ), "Number of labels must be specified for multilabel classification."
+ assert isinstance(num_labels, int) and num_labels > 0, (
+ "Number of labels must be specified for multilabel classification."
+ )
return multilabel_recall(
target,
preds,
diff --git a/cyclops/evaluate/metrics/functional/precision_recall_curve.py b/cyclops/evaluate/metrics/functional/precision_recall_curve.py
index a0f9b69e3..c982a5423 100644
--- a/cyclops/evaluate/metrics/functional/precision_recall_curve.py
+++ b/cyclops/evaluate/metrics/functional/precision_recall_curve.py
@@ -1056,9 +1056,9 @@ def precision_recall_curve(
pos_label=pos_label,
)
if task == "multiclass":
- assert (
- isinstance(num_classes, int) and num_classes > 0
- ), "Number of classes must be a positive integer."
+ assert isinstance(num_classes, int) and num_classes > 0, (
+ "Number of classes must be a positive integer."
+ )
return multiclass_precision_recall_curve(
target,
@@ -1067,9 +1067,9 @@ def precision_recall_curve(
thresholds=thresholds,
)
if task == "multilabel":
- assert (
- isinstance(num_labels, int) and num_labels > 0
- ), "Number of labels must be a positive integer."
+ assert isinstance(num_labels, int) and num_labels > 0, (
+ "Number of labels must be a positive integer."
+ )
return multilabel_precision_recall_curve(
target,
diff --git a/cyclops/evaluate/metrics/functional/specificity.py b/cyclops/evaluate/metrics/functional/specificity.py
index b1f1822c7..4f77ba350 100644
--- a/cyclops/evaluate/metrics/functional/specificity.py
+++ b/cyclops/evaluate/metrics/functional/specificity.py
@@ -434,9 +434,9 @@ def specificity(
zero_division=zero_division,
)
if task == "multiclass":
- assert (
- isinstance(num_classes, int) and num_classes > 0
- ), "Number of classes must be specified for multiclass classification."
+ assert isinstance(num_classes, int) and num_classes > 0, (
+ "Number of classes must be specified for multiclass classification."
+ )
return multiclass_specificity(
target,
preds,
@@ -446,9 +446,9 @@ def specificity(
zero_division=zero_division,
)
if task == "multilabel":
- assert (
- isinstance(num_labels, int) and num_labels > 0
- ), "Number of labels must be specified for multilabel classification."
+ assert isinstance(num_labels, int) and num_labels > 0, (
+ "Number of labels must be specified for multilabel classification."
+ )
return multilabel_specificity(
target,
preds,
diff --git a/cyclops/evaluate/metrics/functional/stat_scores.py b/cyclops/evaluate/metrics/functional/stat_scores.py
index 34a881689..fc7db1ffb 100644
--- a/cyclops/evaluate/metrics/functional/stat_scores.py
+++ b/cyclops/evaluate/metrics/functional/stat_scores.py
@@ -215,7 +215,7 @@ def _binary_stat_scores_format(
if check:
raise ValueError(
f"Detected the following values in `target`: {unique_values} but"
- f" expected only the following values {[0,1]}.",
+ f" expected only the following values {[0, 1]}.",
)
# If preds is label array, also check that it only contains [0,1] values
@@ -823,9 +823,9 @@ def stat_scores(
threshold=threshold,
)
elif task == "multiclass":
- assert (
- isinstance(num_classes, int) and num_classes > 0
- ), "Number of classes must be a positive integer."
+ assert isinstance(num_classes, int) and num_classes > 0, (
+ "Number of classes must be a positive integer."
+ )
scores = multiclass_stat_scores(
target,
preds,
@@ -834,9 +834,9 @@ def stat_scores(
top_k=top_k,
)
elif task == "multilabel":
- assert (
- isinstance(num_labels, int) and num_labels > 0
- ), "Number of labels must be a positive integer."
+ assert isinstance(num_labels, int) and num_labels > 0, (
+ "Number of labels must be a positive integer."
+ )
scores = multilabel_stat_scores(
target,
preds,
diff --git a/cyclops/evaluate/metrics/precision_recall.py b/cyclops/evaluate/metrics/precision_recall.py
index 6de0f33c4..2aabf4be9 100644
--- a/cyclops/evaluate/metrics/precision_recall.py
+++ b/cyclops/evaluate/metrics/precision_recall.py
@@ -345,9 +345,9 @@ def __new__( # type: ignore # mypy expects a subclass of Precision
zero_division=zero_division,
)
if task == "multiclass":
- assert (
- isinstance(num_classes, int) and num_classes > 0
- ), "Number of classes must be specified for multiclass classification."
+ assert isinstance(num_classes, int) and num_classes > 0, (
+ "Number of classes must be specified for multiclass classification."
+ )
return MulticlassPrecision(
num_classes=num_classes,
top_k=top_k,
@@ -355,9 +355,9 @@ def __new__( # type: ignore # mypy expects a subclass of Precision
zero_division=zero_division,
)
if task == "multilabel":
- assert (
- isinstance(num_labels, int) and num_labels > 0
- ), "Number of labels must be specified for multilabel classification."
+ assert isinstance(num_labels, int) and num_labels > 0, (
+ "Number of labels must be specified for multilabel classification."
+ )
return MultilabelPrecision(
num_labels=num_labels,
threshold=threshold,
@@ -694,9 +694,9 @@ def __new__( # type: ignore # mypy expects a subclass of Recall
zero_division=zero_division,
)
if task == "multiclass":
- assert (
- isinstance(num_classes, int) and num_classes > 0
- ), "Number of classes must be specified for multiclass classification."
+ assert isinstance(num_classes, int) and num_classes > 0, (
+ "Number of classes must be specified for multiclass classification."
+ )
return MulticlassRecall(
num_classes=num_classes,
top_k=top_k,
@@ -704,9 +704,9 @@ def __new__( # type: ignore # mypy expects a subclass of Recall
zero_division=zero_division,
)
if task == "multilabel":
- assert (
- isinstance(num_labels, int) and num_labels > 0
- ), "Number of labels must be specified for multilabel classification."
+ assert isinstance(num_labels, int) and num_labels > 0, (
+ "Number of labels must be specified for multilabel classification."
+ )
return MultilabelRecall(
num_labels=num_labels,
threshold=threshold,
diff --git a/cyclops/evaluate/metrics/precision_recall_curve.py b/cyclops/evaluate/metrics/precision_recall_curve.py
index 9a5ce76b5..e220c7d05 100644
--- a/cyclops/evaluate/metrics/precision_recall_curve.py
+++ b/cyclops/evaluate/metrics/precision_recall_curve.py
@@ -566,17 +566,17 @@ def __new__( # type: ignore # mypy expects a subclass of PrecisionRecallCurve
pos_label=pos_label,
)
if task == "multiclass":
- assert (
- isinstance(num_classes, int) and num_classes > 0
- ), "Number of classes must be a positive integer."
+ assert isinstance(num_classes, int) and num_classes > 0, (
+ "Number of classes must be a positive integer."
+ )
return MulticlassPrecisionRecallCurve(
num_classes=num_classes,
thresholds=thresholds,
)
if task == "multilabel":
- assert (
- isinstance(num_labels, int) and num_labels > 0
- ), "Number of labels must be a positive integer."
+ assert isinstance(num_labels, int) and num_labels > 0, (
+ "Number of labels must be a positive integer."
+ )
return MultilabelPrecisionRecallCurve(
num_labels=num_labels,
thresholds=thresholds,
diff --git a/cyclops/evaluate/metrics/sensitivity.py b/cyclops/evaluate/metrics/sensitivity.py
index 5ea9ab5df..29284c630 100644
--- a/cyclops/evaluate/metrics/sensitivity.py
+++ b/cyclops/evaluate/metrics/sensitivity.py
@@ -302,9 +302,9 @@ def __new__( # type: ignore # mypy expects a subclass of Sensitivity
zero_division=zero_division,
)
if task == "multiclass":
- assert (
- isinstance(num_classes, int) and num_classes > 0
- ), "Number of classes must be specified for multiclass classification."
+ assert isinstance(num_classes, int) and num_classes > 0, (
+ "Number of classes must be specified for multiclass classification."
+ )
return MulticlassSensitivity(
num_classes=num_classes,
top_k=top_k,
@@ -312,9 +312,9 @@ def __new__( # type: ignore # mypy expects a subclass of Sensitivity
zero_division=zero_division,
)
if task == "multilabel":
- assert (
- isinstance(num_labels, int) and num_labels > 0
- ), "Number of labels must be specified for multilabel classification."
+ assert isinstance(num_labels, int) and num_labels > 0, (
+ "Number of labels must be specified for multilabel classification."
+ )
return MultilabelSensitivity(
num_labels=num_labels,
threshold=threshold,
diff --git a/cyclops/evaluate/metrics/specificity.py b/cyclops/evaluate/metrics/specificity.py
index e8efabca5..efe4574b0 100644
--- a/cyclops/evaluate/metrics/specificity.py
+++ b/cyclops/evaluate/metrics/specificity.py
@@ -375,9 +375,9 @@ def __new__( # type: ignore # mypy expects a subclass of Specificity
zero_division=zero_division,
)
if task == "multiclass":
- assert (
- isinstance(num_classes, int) and num_classes > 0
- ), "Number of classes must be specified for multiclass classification."
+ assert isinstance(num_classes, int) and num_classes > 0, (
+ "Number of classes must be specified for multiclass classification."
+ )
return MulticlassSpecificity(
num_classes=num_classes,
top_k=top_k,
@@ -385,9 +385,9 @@ def __new__( # type: ignore # mypy expects a subclass of Specificity
zero_division=zero_division,
)
if task == "multilabel":
- assert (
- isinstance(num_labels, int) and num_labels > 0
- ), "Number of labels must be specified for multilabel classification."
+ assert isinstance(num_labels, int) and num_labels > 0, (
+ "Number of labels must be specified for multilabel classification."
+ )
return MultilabelSpecificity(
num_labels=num_labels,
threshold=threshold,
diff --git a/cyclops/evaluate/metrics/stat_scores.py b/cyclops/evaluate/metrics/stat_scores.py
index 1b3fee8e9..13ac23fce 100644
--- a/cyclops/evaluate/metrics/stat_scores.py
+++ b/cyclops/evaluate/metrics/stat_scores.py
@@ -487,18 +487,18 @@ def __new__( # type: ignore # mypy expects a subclass of StatScores
if task == "binary":
return BinaryStatScores(threshold=threshold, pos_label=pos_label)
if task == "multiclass":
- assert (
- isinstance(num_classes, int) and num_classes > 0
- ), "Number of classes must be a positive integer."
+ assert isinstance(num_classes, int) and num_classes > 0, (
+ "Number of classes must be a positive integer."
+ )
return MulticlassStatScores(
num_classes=num_classes,
top_k=top_k,
classwise=classwise,
)
if task == "multilabel":
- assert (
- isinstance(num_labels, int) and num_labels > 0
- ), "Number of labels must be a positive integer."
+ assert isinstance(num_labels, int) and num_labels > 0, (
+ "Number of labels must be a positive integer."
+ )
return MultilabelStatScores(
num_labels=num_labels,
threshold=threshold,
diff --git a/cyclops/evaluate/utils.py b/cyclops/evaluate/utils.py
index 835de6263..1bbe99d1e 100644
--- a/cyclops/evaluate/utils.py
+++ b/cyclops/evaluate/utils.py
@@ -159,9 +159,11 @@ def get_columns_as_array(
if isinstance(columns, str):
columns = [columns]
- with dataset.formatted_as("arrow", columns=columns, output_all_columns=True) if (
- isinstance(dataset, Dataset) and dataset.format != "arrow"
- ) else nullcontext():
+ with (
+ dataset.formatted_as("arrow", columns=columns, output_all_columns=True)
+ if (isinstance(dataset, Dataset) and dataset.format != "arrow")
+ else nullcontext()
+ ):
out_arr = squeeze_all(
xp.stack(
[xp.asarray(dataset[col].to_pylist()) for col in columns], axis=-1
diff --git a/cyclops/models/catalog.py b/cyclops/models/catalog.py
index a2571f919..082ea9cfa 100644
--- a/cyclops/models/catalog.py
+++ b/cyclops/models/catalog.py
@@ -39,13 +39,11 @@
LOGGER = logging.getLogger(__name__)
setup_logging(print_level="WARN", logger=LOGGER)
_xgboost_unavailable_message = (
- "The XGBoost library is required to use the `XGBClassifier` model. "
- "Please install it as an extra using `python3 -m pip install 'pycyclops[xgboost]'`\
+ "The XGBoost library is required to use the `XGBClassifier` model. Please install it as an extra using `python3 -m pip install 'pycyclops[xgboost]'`\
or using `python3 -m pip install xgboost`."
)
_torchxrayvision_unavailable_message = (
- "The torchxrayvision library is required to use the `densenet` or `resnet` model. "
- "Please install it as an extra using `python3 -m pip install 'pycyclops[torchxrayvision]'`\
+ "The torchxrayvision library is required to use the `densenet` or `resnet` model. Please install it as an extra using `python3 -m pip install 'pycyclops[torchxrayvision]'`\
or using `python3 -m pip install torchxrayvision`."
)
_torch_unavailable_message = (
diff --git a/cyclops/models/wrappers/pt_model.py b/cyclops/models/wrappers/pt_model.py
index 50b3a4302..37499dec2 100644
--- a/cyclops/models/wrappers/pt_model.py
+++ b/cyclops/models/wrappers/pt_model.py
@@ -968,14 +968,17 @@ def fit(
splits_mapping["validation"] = val_split
format_kwargs = {} if transforms is None else {"transform": transforms}
- with X[train_split].formatted_as(
- "custom" if transforms is not None else "torch",
- columns=feature_columns + target_columns,
- **format_kwargs,
- ), X[val_split].formatted_as(
- "custom" if transforms is not None else "torch",
- columns=feature_columns + target_columns,
- **format_kwargs,
+ with (
+ X[train_split].formatted_as(
+ "custom" if transforms is not None else "torch",
+ columns=feature_columns + target_columns,
+ **format_kwargs,
+ ),
+ X[val_split].formatted_as(
+ "custom" if transforms is not None else "torch",
+ columns=feature_columns + target_columns,
+ **format_kwargs,
+ ),
):
self.partial_fit(
X,
@@ -1309,7 +1312,7 @@ def save_model(self, filepath: str, overwrite: bool = True, **kwargs):
if include_lr_scheduler:
state_dict["lr_scheduler"] = self.lr_scheduler_.state_dict() # type: ignore[attr-defined]
- epoch = kwargs.get("epoch", None)
+ epoch = kwargs.get("epoch")
if epoch is not None:
filename, extension = os.path.basename(filepath).split(".")
filepath = join(
diff --git a/cyclops/models/wrappers/sk_model.py b/cyclops/models/wrappers/sk_model.py
index b0d534f1e..07d92304e 100644
--- a/cyclops/models/wrappers/sk_model.py
+++ b/cyclops/models/wrappers/sk_model.py
@@ -400,7 +400,7 @@ def partial_fit(
splits_mapping = {"train": "train"}
if not hasattr(self.model_, "partial_fit"):
raise AttributeError(
- f"Model {self.model_name}" "does not have a `partial_fit` method.",
                f"Model {self.model_name} does not have a `partial_fit` method.",
)
# Train data is a Hugging Face Dataset Dictionary.
if isinstance(X, DatasetDict):
@@ -687,7 +687,7 @@ def predict_proba(
splits_mapping = {"test": "test"}
if not hasattr(self.model_, "predict_proba"):
raise AttributeError(
- f"Model {self.model_name}" "does not have a `predict_proba` method.",
                f"Model {self.model_name} does not have a `predict_proba` method.",
)
# Data is a Hugging Face Dataset Dictionary.
if isinstance(X, DatasetDict):
diff --git a/cyclops/monitor/tester.py b/cyclops/monitor/tester.py
index 9a37d9d0f..ac4b1d7c5 100644
--- a/cyclops/monitor/tester.py
+++ b/cyclops/monitor/tester.py
@@ -1057,9 +1057,9 @@ def get_record(self, record_type):
def counts(self, record_type, max_ensemble_size=None) -> np.ndarray:
"""Get counts."""
- assert (
- max_ensemble_size is None or max_ensemble_size > 0
- ), "max_ensemble_size must be positive or None"
+ assert max_ensemble_size is None or max_ensemble_size > 0, (
+ "max_ensemble_size must be positive or None"
+ )
rec = self.get_record(record_type)
counts = []
for i in rec.seed.unique():
diff --git a/cyclops/monitor/utils.py b/cyclops/monitor/utils.py
index 7584462ca..d990c5096 100644
--- a/cyclops/monitor/utils.py
+++ b/cyclops/monitor/utils.py
@@ -258,12 +258,12 @@ def get_args(obj: Any, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""
args = {}
- for key in kwargs:
+ for key, value in kwargs.items():
if (inspect.isclass(obj) and key in inspect.signature(obj).parameters) or (
(inspect.ismethod(obj) or inspect.isfunction(obj))
and key in inspect.getfullargspec(obj).args
):
- args[key] = kwargs[key]
+ args[key] = value
return args
diff --git a/cyclops/report/plot/classification.py b/cyclops/report/plot/classification.py
index 60da0b884..7a5e062e1 100644
--- a/cyclops/report/plot/classification.py
+++ b/cyclops/report/plot/classification.py
@@ -86,13 +86,13 @@ def _set_class_names(self, class_names: List[str]) -> None:
"""
if class_names is not None:
- assert (
- len(class_names) == self.class_num
- ), "class_names must be equal to class_num"
+ assert len(class_names) == self.class_num, (
+ "class_names must be equal to class_num"
+ )
elif self.task_type == "multilabel":
- class_names = [f"Label_{i+1}" for i in range(self.class_num)]
+ class_names = [f"Label_{i + 1}" for i in range(self.class_num)]
else:
- class_names = [f"Class_{i+1}" for i in range(self.class_num)]
+ class_names = [f"Class_{i + 1}" for i in range(self.class_num)]
self.class_names = class_names
def calibration(
@@ -257,7 +257,9 @@ def threshperf(
== len(roc_curve.thresholds)
== len(ppv)
== len(npv)
- ), "Length mismatch between ROC curve, PPV, NPV. All curves need to be computed using the same thresholds"
+ ), (
+ "Length mismatch between ROC curve, PPV, NPV. All curves need to be computed using the same thresholds"
+ )
# Define hover template to show three decimal places
        hover_template = "Threshold: %{x:.3f}<br>Metric Value: %{y:.3f}"
# Create a subplot for each metric
@@ -397,16 +399,18 @@ def roc_curve(
),
)
else:
- assert (
- len(fprs) == len(tprs) == self.class_num
- ), "fprs and tprs must be of length class_num for \
+ assert len(fprs) == len(tprs) == self.class_num, (
+ "fprs and tprs must be of length class_num for \
multiclass/multilabel tasks"
+ )
for i in range(self.class_num):
if auroc is not None:
assert (
len(auroc) == self.class_num # type: ignore[arg-type]
- ), "AUROCs must be of length class_num for \
+ ), (
+ "AUROCs must be of length class_num for \
multiclass/multilabel tasks"
+ )
name = f"{self.class_names[i]} (AUC = {auroc[i]:.2f})" # type: ignore[index] # noqa: E501
else:
name = self.class_names[i]
@@ -496,16 +500,18 @@ def roc_curve_comparison(
)
else:
for slice_name, slice_curve in roc_curves.items():
- assert (
- len(slice_curve[0]) == len(slice_curve[1]) == self.class_num
- ), f"FPRs and TPRs must be of length class_num for \
+ assert len(slice_curve[0]) == len(slice_curve[1]) == self.class_num, (
+ f"FPRs and TPRs must be of length class_num for \
multiclass/multilabel tasks in slice {slice_name}"
+ )
for i in range(self.class_num):
if aurocs and slice_name in aurocs:
assert (
len(aurocs[slice_name]) == self.class_num # type: ignore[arg-type] # noqa: E501
- ), "AUROCs must be of length class_num for \
+ ), (
+ "AUROCs must be of length class_num for \
multiclass/multilabel tasks"
+ )
name = f"{slice_name}, {self.class_names[i]} \
(AUC = {aurocs[i]:.2f})" # type: ignore[index]
else:
@@ -583,10 +589,10 @@ def precision_recall_curve(
)
else:
trace = []
- assert (
- len(recalls) == len(precisions) == self.class_num
- ), "Recalls and precisions must be of length class_num for \
+ assert len(recalls) == len(precisions) == self.class_num, (
+ "Recalls and precisions must be of length class_num for \
multiclass/multilabel tasks"
+ )
for i in range(self.class_num):
trace.append(
line_plot(
@@ -668,14 +674,18 @@ def precision_recall_curve_comparison(
len(slice_curve.precision)
== len(slice_curve.recall)
== self.class_num
- ), f"Recalls and precisions must be of length class_num for \
+ ), (
+ f"Recalls and precisions must be of length class_num for \
multiclass/multilabel tasks in slice {slice_name}"
+ )
for i in range(self.class_num):
if auprcs and slice_name in auprcs:
assert (
len(auprcs[slice_name]) == self.class_num # type: ignore[arg-type] # noqa: E501
- ), "AUPRCs must be of length class_num for \
+ ), (
+ "AUPRCs must be of length class_num for \
multiclass/multilabel tasks"
+ )
name = f"{slice_name}, {self.class_names[i]} \
(AUC = {auprcs[i]:.2f})"
else:
@@ -747,8 +757,10 @@ def metrics_value(
assert all(
len(value) == self.class_num # type: ignore[arg-type]
for value in metrics.values()
- ), "Every metric must be of length class_num for \
+ ), (
+ "Every metric must be of length class_num for \
multiclass/multilabel tasks"
+ )
for i in range(self.class_num):
trace.append(
bar_plot(
@@ -981,10 +993,10 @@ def metrics_comparison_radar(
if isinstance(metric_values, list) or (
isinstance(metric_values, np.ndarray) and metric_values.ndim > 0
):
- assert (
- len(metric_values) == self.class_num
- ), "Metric values must be of length class_num for \
+ assert len(metric_values) == self.class_num, (
+ "Metric values must be of length class_num for \
multiclass/multilabel tasks"
+ )
radial_data.extend(metric_values)
theta = [
f"{metric_name}: {self.class_names[i]}"
diff --git a/cyclops/report/report.py b/cyclops/report/report.py
index 786dae53c..19f916581 100644
--- a/cyclops/report/report.py
+++ b/cyclops/report/report.py
@@ -668,9 +668,9 @@ def log_dataset(
"""
# sensitive features must be in features
if features is not None and sensitive_features is not None:
- assert all(
- feature in features for feature in sensitive_features
- ), "All sensitive features must be in the features list."
+ assert all(feature in features for feature in sensitive_features), (
+ "All sensitive features must be in the features list."
+ )
# TODO: plot dataset distribution
data = {
diff --git a/cyclops/report/utils.py b/cyclops/report/utils.py
index 94c8a556b..a4f9a3a04 100644
--- a/cyclops/report/utils.py
+++ b/cyclops/report/utils.py
@@ -304,9 +304,9 @@ def get_metrics_trends(
slice_names=slice_names,
metric_names=metric_names,
)
- assert (
- len(performance_history) > 0
- ), "No performance history found. Check slice and metric names."
+ assert len(performance_history) > 0, (
+ "No performance history found. Check slice and metric names."
+ )
performance_recent = []
for metric_name, metric_value in flat_results.items():
name_split = metric_name.split("/")
@@ -323,9 +323,9 @@ def get_metrics_trends(
slice_names=slice_names,
metric_names=metric_names,
)
- assert (
- len(performance_recent) > 0
- ), "No performance metrics found. Check slice and metric names."
+ assert len(performance_recent) > 0, (
+ "No performance metrics found. Check slice and metric names."
+ )
today = dt_date.today().strftime("%Y-%m-%d")
now = dt_datetime.now().strftime("%H-%M-%S")
if keep_timestamps:
diff --git a/cyclops/tasks/utils.py b/cyclops/tasks/utils.py
index cb442e0bb..734227b97 100644
--- a/cyclops/tasks/utils.py
+++ b/cyclops/tasks/utils.py
@@ -108,8 +108,10 @@ def prepare_models(
models_dict = {model_name: models}
# models contains one model name
elif isinstance(models, str):
- assert models in list_models(), f"Model name is not registered! \
+ assert models in list_models(), (
+ f"Model name is not registered! \
Available models are: {list_models()}"
+ )
models_dict = {models: create_model(models)}
# models contains a list or tuple of model names or wrapped models
elif isinstance(models, (list, tuple)):
@@ -118,8 +120,10 @@ def prepare_models(
model_name = _model_names_mapping.get(model.model.__name__)
models_dict[model_name] = model
elif isinstance(model, str):
- assert model in list_models(), f"Model name is not registered! \
+ assert model in list_models(), (
+ f"Model name is not registered! \
Available models are: {list_models()}"
+ )
models_dict[model] = create_model(model)
else:
raise TypeError(
diff --git a/cyclops/utils/optional.py b/cyclops/utils/optional.py
index e7795f344..07d06c240 100644
--- a/cyclops/utils/optional.py
+++ b/cyclops/utils/optional.py
@@ -64,8 +64,7 @@ def import_optional_module(
return module
except ModuleNotFoundError as exc:
msg = (
- f"Missing optional dependency '{name}'. "
- f"Use pip or conda to install {name}."
+ f"Missing optional dependency '{name}'. Use pip or conda to install {name}."
)
if error == "raise":
raise type(exc)(msg) from None
diff --git a/cyclops/utils/profile.py b/cyclops/utils/profile.py
index fd77ae7c2..1a5a9115e 100644
--- a/cyclops/utils/profile.py
+++ b/cyclops/utils/profile.py
@@ -1,3 +1,5 @@
+# noqa: A005
+
"""Useful functions for timing, profiling."""
import logging
diff --git a/docs/source/examples/metrics.ipynb b/docs/source/examples/metrics.ipynb
index cf2b10ba4..5e867b1f8 100644
--- a/docs/source/examples/metrics.ipynb
+++ b/docs/source/examples/metrics.ipynb
@@ -15,6 +15,8 @@
"metadata": {},
"outputs": [],
"source": [
+ "\"\"\"Imports.\"\"\"\n",
+ "\n",
"import numpy as np\n",
"import pandas as pd\n",
"from datasets.arrow_dataset import Dataset\n",
@@ -57,7 +59,7 @@
"outputs": [],
"source": [
"df = breast_cancer_data.frame\n",
- "df.describe().T"
+ "print(df.describe().T)"
]
},
{
@@ -222,7 +224,7 @@
"outputs": [],
"source": [
"slice_spec = SliceSpec(spec_list, intersections=2)\n",
- "slice_spec"
+ "print(slice_spec)"
]
},
{
@@ -316,15 +318,8 @@
" target_columns=\"target\",\n",
" prediction_columns=\"preds_prob\",\n",
")\n",
- "fairness_result"
+ "print(fairness_result)"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/docs/source/examples/report.ipynb b/docs/source/examples/report.ipynb
index 6312ba6f5..457f7572d 100644
--- a/docs/source/examples/report.ipynb
+++ b/docs/source/examples/report.ipynb
@@ -1,11 +1,21 @@
{
"cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Report Generation for Heart Failure Prediction\n",
+ "Here's an example to demonstrate how we can generate a report as we proceed through all the steps to train and evaluate a model. For this purpose, we are going to use Kaggle's heart failure prediction dataset and gradually populate the report with information about the dataset, model, and results."
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
+ "\"\"\"Imports.\"\"\"\n",
+ "\n",
"import copy\n",
"import inspect\n",
"import os\n",
@@ -23,6 +33,7 @@
"from tqdm import tqdm\n",
"\n",
"from cyclops.data.slicer import SliceSpec\n",
+ "from cyclops.evaluate import evaluator\n",
"from cyclops.evaluate.metrics import create_metric\n",
"from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict\n",
"from cyclops.report import ModelCardReport\n",
@@ -30,14 +41,6 @@
"from cyclops.report.utils import flatten_results_dict"
]
},
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Report Generation for Heart Failure Prediction\n",
- "Here's an example to demonstrate how we can generate a report as we proceed through all the steps to train and evaluate a model. For this purpose, we are going to use Kaggle's heart prediction failure dataset and gradually populate the report with information about dataset, model and results."
- ]
- },
{
"cell_type": "markdown",
"metadata": {},
@@ -105,7 +108,7 @@
"metadata": {},
"outputs": [],
"source": [
- "df.describe().T"
+ "print(df.describe().T)"
]
},
{
@@ -531,9 +534,6 @@
"metadata": {},
"outputs": [],
"source": [
- "from cyclops.evaluate import evaluator\n",
- "\n",
- "\n",
"# Create Dataset object\n",
"heart_failure_data = Dataset.from_pandas(df_test)\n",
"\n",
@@ -557,24 +557,6 @@
")"
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "result"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "results_flat"
- ]
- },
{
"cell_type": "markdown",
"metadata": {},
@@ -878,7 +860,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "cyclops",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -892,9 +874,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.11"
+ "version": "3.10.12"
}
},
"nbformat": 4,
- "nbformat_minor": 2
+ "nbformat_minor": 4
}
diff --git a/docs/source/tutorials/diabetes_130/readmission_prediction.ipynb b/docs/source/tutorials/diabetes_130/readmission_prediction.ipynb
index bb7bdb3dd..61766e525 100644
--- a/docs/source/tutorials/diabetes_130/readmission_prediction.ipynb
+++ b/docs/source/tutorials/diabetes_130/readmission_prediction.ipynb
@@ -152,17 +152,6 @@
"variables = diabetes_130_data[\"variables\"]"
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "metadata"
- ]
- },
{
"cell_type": "code",
"execution_count": null,
diff --git a/docs/source/tutorials/diabetes_130/readmission_prediction_detectron.ipynb b/docs/source/tutorials/diabetes_130/readmission_prediction_detectron.ipynb
index f34093832..e62cd2649 100644
--- a/docs/source/tutorials/diabetes_130/readmission_prediction_detectron.ipynb
+++ b/docs/source/tutorials/diabetes_130/readmission_prediction_detectron.ipynb
@@ -115,17 +115,6 @@
"variables = diabetes_130_data[\"variables\"]"
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "metadata"
- ]
- },
{
"cell_type": "code",
"execution_count": null,
@@ -749,7 +738,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.9.7"
+ "version": "3.10.12"
}
},
"nbformat": 4,
diff --git a/tests/cyclops/evaluate/metrics/experimental/test_metric.py b/tests/cyclops/evaluate/metrics/experimental/test_metric.py
index ceea71694..a925112d2 100644
--- a/tests/cyclops/evaluate/metrics/experimental/test_metric.py
+++ b/tests/cyclops/evaluate/metrics/experimental/test_metric.py
@@ -46,9 +46,9 @@ def test_add_state_factory():
metric.add_state_default_factory("a", lambda xp: xp.asarray(0), None) # type: ignore
reduce_fn = metric._reductions["a"]
assert reduce_fn is None, "Saved reduction function is not None."
- assert (
- metric._default_factories.get("a") is not None
- ), "Default factory was not correctly created."
+ assert metric._default_factories.get("a") is not None, (
+ "Default factory was not correctly created."
+ )
# default_factory is 'list'
metric.add_state_default_factory("b", list) # type: ignore
@@ -61,8 +61,7 @@ def test_add_state_factory():
reduce_fn = metric._reductions["c"]
assert callable(reduce_fn), "Saved reduction function is not callable."
assert reduce_fn is dim_zero_sum, (
- "Saved reduction function is not the same as the one used to "
- "create the state."
+ "Saved reduction function is not the same as the one used to create the state."
)
assert reduce_fn(anp.asarray([1, 1])) == anp.asarray(
2,
@@ -73,8 +72,7 @@ def test_add_state_factory():
reduce_fn = metric._reductions["d"]
assert callable(reduce_fn), "Saved reduction function is not callable."
assert reduce_fn is dim_zero_mean, (
- "Saved reduction function is not the same as the one used to "
- "create the state."
+ "Saved reduction function is not the same as the one used to create the state."
)
assert np.allclose(
reduce_fn(anp.asarray([1.0, 2.0])),
@@ -86,8 +84,7 @@ def test_add_state_factory():
reduce_fn = metric._reductions["e"]
assert callable(reduce_fn), "Saved reduction function is not callable."
assert reduce_fn is dim_zero_cat, (
- "Saved reduction function is not the same as the one used to "
- "create the state."
+ "Saved reduction function is not the same as the one used to create the state."
)
np.testing.assert_array_equal(
reduce_fn([anp.asarray([1]), anp.asarray([1])]),
@@ -100,8 +97,7 @@ def test_add_state_factory():
reduce_fn = metric._reductions["f"]
assert callable(reduce_fn), "Saved reduction function is not callable."
assert reduce_fn is dim_zero_max, (
- "Saved reduction function is not the same as the one used to "
- "create the state."
+ "Saved reduction function is not the same as the one used to create the state."
)
np.testing.assert_array_equal(
reduce_fn(anp.asarray([1, 2])),
@@ -115,8 +111,7 @@ def test_add_state_factory():
reduce_fn = metric._reductions["g"]
assert callable(reduce_fn), "Saved reduction function is not callable."
assert reduce_fn is dim_zero_min, (
- "Saved reduction function is not the same as the one used to "
- "create the state."
+ "Saved reduction function is not the same as the one used to create the state."
)
np.testing.assert_array_equal(
reduce_fn(anp.asarray([1, 2])),
diff --git a/tests/cyclops/evaluate/metrics/experimental/test_metric_dict.py b/tests/cyclops/evaluate/metrics/experimental/test_metric_dict.py
index 401726060..508ce2d6a 100644
--- a/tests/cyclops/evaluate/metrics/experimental/test_metric_dict.py
+++ b/tests/cyclops/evaluate/metrics/experimental/test_metric_dict.py
@@ -94,16 +94,16 @@ def test_metric_dict_adfix(prefix, postfix):
# test __call__
output = metrics(anp.asarray(1, dtype=anp.float32))
for name in names:
- assert (
- name in output
- ), f"`MetricDict` output does not contain metric {name} when called."
+ assert name in output, (
+ f"`MetricDict` output does not contain metric {name} when called."
+ )
# test `compute`
output = metrics.compute()
for name in names:
- assert (
- name in output
- ), f"`MetricDict` output does not contain metric {name} using the `compute` method."
+ assert name in output, (
+ f"`MetricDict` output does not contain metric {name} using the `compute` method."
+ )
# test `clone`
new_metrics = metrics.clone(prefix="new_")
@@ -112,9 +112,9 @@ def test_metric_dict_adfix(prefix, postfix):
n[len(prefix) :] if prefix is not None else n for n in names
]
for name in names:
- assert (
- f"new_{name}" in output
- ), f"`MetricDict` output does not contain metric new_{name} when cloned."
+ assert f"new_{name}" in output, (
+ f"`MetricDict` output does not contain metric new_{name} when cloned."
+ )
for k in new_metrics:
assert "new_" in k
diff --git a/tests/cyclops/evaluate/metrics/experimental/testers.py b/tests/cyclops/evaluate/metrics/experimental/testers.py
index 4d58305d3..a0d887d93 100644
--- a/tests/cyclops/evaluate/metrics/experimental/testers.py
+++ b/tests/cyclops/evaluate/metrics/experimental/testers.py
@@ -85,9 +85,9 @@ def _class_impl_test( # noqa: PLR0912
t_size = target.shape[0]
p_size = preds.shape[0]
- assert (
- p_size == t_size
- ), f"`preds` and `target` have different number of samples: {p_size} and {t_size}."
+ assert p_size == t_size, (
+ f"`preds` and `target` have different number of samples: {p_size} and {t_size}."
+ )
num_batches = p_size
# instantiate metric
@@ -176,9 +176,9 @@ def _function_impl_test(
t_size = target.shape[0]
p_size = preds.shape[0]
- assert (
- p_size == t_size
- ), f"`preds` and `target` have different number of samples: {p_size} and {t_size}."
+ assert p_size == t_size, (
+ f"`preds` and `target` have different number of samples: {p_size} and {t_size}."
+ )
metric_args = metric_args or {}
metric = partial(metric_function, **metric_args)
diff --git a/tests/cyclops/evaluate/metrics/experimental/utils/test_ops.py b/tests/cyclops/evaluate/metrics/experimental/utils/test_ops.py
index 65e61c18b..032b53733 100644
--- a/tests/cyclops/evaluate/metrics/experimental/utils/test_ops.py
+++ b/tests/cyclops/evaluate/metrics/experimental/utils/test_ops.py
@@ -189,19 +189,17 @@ def test_apply_to_nested_collections(self):
},
}
- for k in expected_result:
+ for k, v in expected_result.items():
assert k in result
- if isinstance(expected_result[k], dict):
- for kk in expected_result[k]:
+ if isinstance(v, dict):
+ for kk in v:
assert kk in result[k]
- assert anp.all(expected_result[k][kk] == result[k][kk])
- elif isinstance(expected_result[k], (tuple, list)):
- assert all(
- anp.all(a == b) for a, b in zip(result[k], expected_result[k])
- )
+ assert anp.all(v[kk] == result[k][kk])
+ elif isinstance(v, (tuple, list)):
+ assert all(anp.all(a == b) for a, b in zip(result[k], v))
else:
- assert anp.all(expected_result[k] == result[k])
+ assert anp.all(v == result[k])
class TestBincount:
diff --git a/tests/cyclops/evaluate/metrics/helpers.py b/tests/cyclops/evaluate/metrics/helpers.py
index 724ab89a5..bc030e43b 100644
--- a/tests/cyclops/evaluate/metrics/helpers.py
+++ b/tests/cyclops/evaluate/metrics/helpers.py
@@ -146,5 +146,5 @@ def _assert_allclose(data_a: Any, data_b: Any, atol: float = 1e-8):
_assert_allclose(data_a[key], data_b[key], atol=atol)
else:
raise ValueError(
- f"Unknown format for comparison: {type(data_a)} and" f" {type(data_b)}",
+ f"Unknown format for comparison: {type(data_a)} and {type(data_b)}",
)