Commit 3ef60ff

Merge pull request #746 from scap3yvt/scap3yvt-patch-black
Consistent black across the codebase
2 parents 10821bd + ec115db commit 3ef60ff

4 files changed, +46 -45 lines changed

GANDLF/data/patch_miner/opm/utils.py

Lines changed: 32 additions & 10 deletions
@@ -10,8 +10,9 @@
 from skimage.morphology import remove_small_holes
 from skimage.color.colorconv import rgb2hsv
 import cv2
-#from skimage.exposure import rescale_intensity
-#from skimage.color import rgb2hed
+
+# from skimage.exposure import rescale_intensity
+# from skimage.color import rgb2hed
 
 # import matplotlib.pyplot as plt
 import yaml
@@ -238,7 +239,8 @@ def alpha_rgb_2d_channel_check(img):
     else:
         return False
 
-#def pen_marking_check(img, intensity_thresh=225, intensity_thresh_saturation =50, intensity_thresh_b = 128):
+
+# def pen_marking_check(img, intensity_thresh=225, intensity_thresh_saturation =50, intensity_thresh_b = 128):
 #     """
 #     This function is used to curate patches from the input image. It is used to remove patches that have pen markings.
 #     Args:
@@ -259,7 +261,14 @@ def alpha_rgb_2d_channel_check(img):
 #     #Assume patch is valid
 #     return True
 
-def patch_artifact_check(img, intensity_thresh = 250, intensity_thresh_saturation = 5, intensity_thresh_b = 128, patch_size = (256,256)):
+
+def patch_artifact_check(
+    img,
+    intensity_thresh=250,
+    intensity_thresh_saturation=5,
+    intensity_thresh_b=128,
+    patch_size=(256, 256),
+):
     """
     This function is used to curate patches from the input image. It is used to remove patches that are mostly background.
     Args:
@@ -271,23 +280,36 @@ def patch_artifact_check(img, intensity_thresh = 250, intensity_thresh_saturatio
     Returns:
         bool: Whether the patch is valid (True) or not (False)
     """
-    #patch_size = config["patch_size"]
+    # patch_size = config["patch_size"]
     patch_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
     count_white_pixels = np.sum(np.logical_and.reduce(img > intensity_thresh, axis=2))
     percent_pixels = count_white_pixels / (patch_size[0] * patch_size[1])
     count_black_pixels = np.sum(np.logical_and.reduce(img < intensity_thresh_b, axis=2))
     percent_pixel_b = count_black_pixels / (patch_size[0] * patch_size[1])
-    percent_pixel_2 = np.sum(patch_hsv[...,1] < intensity_thresh_saturation) / (patch_size[0] * patch_size[1])
-    percent_pixel_3 = np.sum(patch_hsv[...,2] > intensity_thresh) / (patch_size[0] * patch_size[1])
+    percent_pixel_2 = np.sum(patch_hsv[..., 1] < intensity_thresh_saturation) / (
+        patch_size[0] * patch_size[1]
+    )
+    percent_pixel_3 = np.sum(patch_hsv[..., 2] > intensity_thresh) / (
+        patch_size[0] * patch_size[1]
+    )
 
-    if percent_pixel_2 > 0.99 or np.mean(patch_hsv[...,1]) < 5 or percent_pixel_3 > 0.99:
+    if (
+        percent_pixel_2 > 0.99
+        or np.mean(patch_hsv[..., 1]) < 5
+        or percent_pixel_3 > 0.99
+    ):
         if percent_pixel_2 < 0.1:
             return False
-    elif (percent_pixel_2 > 0.99 and percent_pixel_3 > 0.99) or percent_pixel_b > 0.99 or percent_pixels > 0.9:
+    elif (
+        (percent_pixel_2 > 0.99 and percent_pixel_3 > 0.99)
+        or percent_pixel_b > 0.99
+        or percent_pixels > 0.9
+    ):
         return False
     # assume that the patch is valid
     return True
 
+
 def parse_config(config_file):
     """
     Parse config file and return a dictionary of config values.
@@ -304,7 +326,7 @@ def parse_config(config_file):
     config["value_map"] = config.get("value_map", None)
     config["read_type"] = config.get("read_type", "random")
     config["overlap_factor"] = config.get("overlap_factor", 0.0)
-    config["patch_size"] = config.get("patch_size", [256,256])
+    config["patch_size"] = config.get("patch_size", [256, 256])
 
     return config
 

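As a reading aid for the re-wrapped signature above: patch_artifact_check still takes one RGB patch plus the same thresholds; only the argument layout changed. A minimal, hypothetical call might look like this (the synthetic patch and the direct import are assumptions for illustration, not part of this commit):

import numpy as np

from GANDLF.data.patch_miner.opm.utils import patch_artifact_check

# A synthetic 256x256 RGB patch (uint8, as cv2.cvtColor expects).
patch = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)

# Returns True if the patch looks usable, False if it is judged to be
# mostly background by the white/black and HSV saturation/value thresholds.
keep = patch_artifact_check(patch, patch_size=(256, 256))
print(keep)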
GANDLF/metrics/classification.py

Lines changed: 1 addition & 3 deletions
@@ -64,9 +64,7 @@ def overall_stats(predictions, ground_truth, params):
         "aucroc": tm.AUROC(
             task=task,
             num_classes=params["model"]["num_classes"],
-            average=average_type_key
-            if average_type_key != "micro"
-            else "macro",
+            average=average_type_key if average_type_key != "micro" else "macro",
         ),
     }
     for metric_name, calculator in calculators.items():

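The collapsed conditional expression is behaviourally identical to the three-line form it replaces; black only re-wraps it onto one line. A quick standalone illustration of the fallback (sketch only, not GANDLF code):

# "micro" averaging is mapped to "macro" for AUROC; everything else passes through.
for average_type_key in ("macro", "weighted", "micro"):
    average = average_type_key if average_type_key != "micro" else "macro"
    print(average_type_key, "->", average)
# macro -> macro, weighted -> weighted, micro -> macro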
GANDLF/metrics/generic.py

Lines changed: 6 additions & 18 deletions
@@ -15,9 +15,7 @@
 )
 
 
-def generic_function_output_with_check(
-    predicted_classes, label, metric_function
-):
+def generic_function_output_with_check(predicted_classes, label, metric_function):
     if torch.min(predicted_classes) < 0:
         print(
             "WARNING: Negative values detected in prediction, cannot compute torchmetrics calculations."
@@ -32,16 +30,12 @@ def generic_function_output_with_check(
             max_clamp_val = metric_function.num_classes - 1
         except AttributeError:
             max_clamp_val = 1
-        predicted_new = torch.clamp(
-            predicted_classes.cpu().int(), max=max_clamp_val
-        )
+        predicted_new = torch.clamp(predicted_classes.cpu().int(), max=max_clamp_val)
         predicted_new = predicted_new.reshape(label.shape)
         return metric_function(predicted_new, label.cpu().int())
 
 
-def generic_torchmetrics_score(
-    output, label, metric_class, metric_key, params
-):
+def generic_torchmetrics_score(output, label, metric_class, metric_key, params):
     task = determine_classification_task_type(params)
     num_classes = params["model"]["num_classes"]
     predicted_classes = output
@@ -67,25 +61,19 @@ def recall_score(output, label, params):
 
 
 def precision_score(output, label, params):
-    return generic_torchmetrics_score(
-        output, label, Precision, "precision", params
-    )
+    return generic_torchmetrics_score(output, label, Precision, "precision", params)
 
 
 def f1_score(output, label, params):
     return generic_torchmetrics_score(output, label, F1Score, "f1", params)
 
 
 def accuracy(output, label, params):
-    return generic_torchmetrics_score(
-        output, label, Accuracy, "accuracy", params
-    )
+    return generic_torchmetrics_score(output, label, Accuracy, "accuracy", params)
 
 
 def specificity_score(output, label, params):
-    return generic_torchmetrics_score(
-        output, label, Specificity, "specificity", params
-    )
+    return generic_torchmetrics_score(output, label, Specificity, "specificity", params)
 
 
 def iou_score(output, label, params):

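The one-line torch.clamp call above is the same guard as before: predictions are clamped to the metric's valid class range and then reshaped to the label's shape before the torchmetrics call. A standalone sketch of that pattern (the tensors and num_classes value are made up for illustration):

import torch

num_classes = 3
predicted_classes = torch.tensor([0, 1, 5, 2])  # 5 is out of range
label = torch.tensor([[0, 1], [2, 2]])

# Clamp to the highest valid class index, then match the label's shape,
# mirroring the guard in generic_function_output_with_check.
max_clamp_val = num_classes - 1
predicted_new = torch.clamp(predicted_classes.cpu().int(), max=max_clamp_val)
predicted_new = predicted_new.reshape(label.shape)
print(predicted_new)  # tensor([[0, 1], [2, 2]], dtype=torch.int32)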
GANDLF/utils/generic.py

Lines changed: 7 additions & 14 deletions
@@ -49,7 +49,9 @@ def checkPatchDivisibility(patch_size, number=16):
     return True
 
 
-def determine_classification_task_type(params: Dict[str, Union[Dict[str, Any], Any]]) -> str:
+def determine_classification_task_type(
+    params: Dict[str, Union[Dict[str, Any], Any]]
+) -> str:
     """Determine the task (binary or multiclass) from the model config.
     Args:
         params (dict): The parameter dictionary containing training and data information.
@@ -159,10 +161,7 @@ def checkPatchDimensions(patch_size, numlay):
         patch_size_to_check = patch_size_to_check[:-1]
 
     if all(
-        [
-            x >= 2 ** (numlay + 1) and x % 2**numlay == 0
-            for x in patch_size_to_check
-        ]
+        [x >= 2 ** (numlay + 1) and x % 2**numlay == 0 for x in patch_size_to_check]
     ):
         return numlay
     else:
@@ -198,9 +197,7 @@ def get_array_from_image_or_tensor(input_tensor_or_image):
     elif isinstance(input_tensor_or_image, np.ndarray):
        return input_tensor_or_image
     else:
-        raise ValueError(
-            "Input must be a torch.Tensor or sitk.Image or np.ndarray"
-        )
+        raise ValueError("Input must be a torch.Tensor or sitk.Image or np.ndarray")
 
 
 def set_determinism(seed=42):
@@ -270,9 +267,7 @@ def __update_metric_from_list_to_single_string(input_metrics_dict) -> dict:
     output_metrics_dict = deepcopy(cohort_level_metrics)
     for metric in metrics_dict_from_parameters:
         if isinstance(sample_level_metrics[metric], np.ndarray):
-            to_print = (
-                sample_level_metrics[metric] / length_of_dataloader
-            ).tolist()
+            to_print = (sample_level_metrics[metric] / length_of_dataloader).tolist()
         else:
             to_print = sample_level_metrics[metric] / length_of_dataloader
         output_metrics_dict[metric] = to_print
@@ -315,7 +310,5 @@ def define_multidim_average_type_key(params, metric_name) -> str:
     Returns:
         str: The average type key.
     """
-    average_type_key = params["metrics"][metric_name].get(
-        "multidim_average", "global"
-    )
+    average_type_key = params["metrics"][metric_name].get("multidim_average", "global")
     return average_type_key

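For reference, the list comprehension black collapsed in checkPatchDimensions encodes a simple divisibility rule: every patch dimension must be at least 2 ** (numlay + 1) and divisible by 2 ** numlay. A standalone sketch of that condition (the example values are assumptions, not from this commit):

numlay = 4  # e.g. number of downsampling layers
patch_size_to_check = [128, 128, 64]

fits = all(
    [x >= 2 ** (numlay + 1) and x % 2**numlay == 0 for x in patch_size_to_check]
)
print(fits)  # True: every dimension is >= 32 and a multiple of 16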