Skip to content

Commit 8ee0ed7

Browse files
committed
Remove CroppedMetric from cropped_metric.py and implement it in transformed_metrics.py; add ResizeMetric class for enhanced resizing functionality with aspect ratio support.
1 parent e347ec6 commit 8ee0ed7

File tree

2 files changed

+181
-74
lines changed

2 files changed

+181
-74
lines changed

minerva/analysis/metrics/cropped_metric.py

-74
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
import warnings
2+
from typing import Optional
3+
4+
import torch
5+
from torchmetrics import Metric
6+
7+
8+
class CroppedMetric(Metric):
9+
def __init__(
10+
self,
11+
target_h_size: int,
12+
target_w_size: int,
13+
metric: Metric,
14+
dist_sync_on_step: bool = False,
15+
):
16+
"""
17+
Initializes a new instance of CroppedMetric.
18+
19+
Parameters
20+
----------
21+
target_h_size: int
22+
The target height size.
23+
target_w_size: int
24+
The target width size.
25+
dist_sync_on_step: bool, optional
26+
Whether to synchronize metric state across processes at each step.
27+
Defaults to False.
28+
"""
29+
super().__init__(dist_sync_on_step=dist_sync_on_step)
30+
self.metric = metric
31+
self.target_h_size = target_h_size
32+
self.target_w_size = target_w_size
33+
34+
def update(self, preds: torch.Tensor, target: torch.Tensor):
35+
"""
36+
Updates the metric state with the predictions and targets.
37+
38+
Parameters
39+
----------
40+
preds: torch.Tensor
41+
The predicted tensor.
42+
target:
43+
torch.Tensor The target tensor.
44+
"""
45+
46+
preds = self.crop(preds)
47+
target = self.crop(target)
48+
self.metric.update(preds, target)
49+
50+
def compute(self) -> float:
51+
"""
52+
Computes the cropped metric.
53+
54+
Returns:
55+
float: The cropped metric.
56+
"""
57+
return self.metric.compute()
58+
59+
def crop(self, x: torch.Tensor) -> torch.Tensor:
60+
"""crops the input tensor to the target size.
61+
62+
Parameters
63+
----------
64+
x : torch.Tensor
65+
The input tensor.
66+
67+
Returns
68+
-------
69+
torch.Tensor
70+
The cropped tensor.
71+
"""
72+
h, w = x.shape[-2:]
73+
start_h = (h - self.target_h_size) // 2
74+
start_w = (w - self.target_w_size) // 2
75+
end_h = start_h + self.target_h_size
76+
end_w = start_w + self.target_w_size
77+
78+
return x[..., start_h:end_h, start_w:end_w]
79+
80+
81+
class ResizeMetric(Metric):
82+
def __init__(
83+
self,
84+
target_h_size: Optional[int],
85+
target_w_size: Optional[int],
86+
metric: Metric,
87+
keep_aspect_ratio: bool = False,
88+
dist_sync_on_step: bool = False,
89+
):
90+
"""
91+
Initializes a new instance of ResizeMetric.
92+
93+
Parameters
94+
----------
95+
target_h_size: int
96+
The target height size.
97+
target_w_size: int
98+
The target width size.
99+
dist_sync_on_step: bool, optional
100+
Whether to synchronize metric state across processes at each step.
101+
Defaults to False.
102+
"""
103+
super().__init__(dist_sync_on_step=dist_sync_on_step)
104+
105+
if target_h_size is None and target_w_size is None:
106+
raise ValueError(
107+
"At least one of target_h_size or target_w_size must be provided."
108+
)
109+
110+
if (
111+
target_h_size is not None and target_w_size is None
112+
) and keep_aspect_ratio is False:
113+
warnings.warn(
114+
"A target_w_size is not provided, but keep_aspect_ratio is set to False. keep_aspect_ratio will be set to True. If you want to resize the image to a specific width, please provide a target_w_size."
115+
)
116+
keep_aspect_ratio = True
117+
118+
if (
119+
target_w_size is not None and target_h_size is None
120+
) and keep_aspect_ratio is False:
121+
warnings.warn(
122+
"A target_h_size is not provided, but keep_aspect_ratio is set to False. keep_aspect_ratio will be set to True. If you want to resize the image to a specific height, please provide a target_h_size."
123+
)
124+
keep_aspect_ratio = True
125+
126+
self.metric = metric
127+
self.target_h_size = target_h_size
128+
self.target_w_size = target_w_size
129+
self.keep_aspect_ratio = keep_aspect_ratio
130+
131+
def update(self, preds: torch.Tensor, target: torch.Tensor):
132+
"""
133+
Updates the metric state with the predictions and targets.
134+
135+
Parameters
136+
----------
137+
preds: torch.Tensor
138+
The predicted tensor.
139+
target:
140+
torch.Tensor The target tensor.
141+
"""
142+
143+
preds = self.resize(preds)
144+
target = self.resize(target)
145+
self.metric.update(preds, target)
146+
147+
def compute(self) -> float:
148+
"""
149+
Computes the resized metric.
150+
151+
Returns:
152+
float: The resized metric.
153+
"""
154+
return self.metric.compute()
155+
156+
def resize(self, x: torch.Tensor) -> torch.Tensor:
157+
"""Resizes the input tensor to the target size.
158+
159+
Parameters
160+
----------
161+
x : torch.Tensor
162+
The input tensor.
163+
164+
Returns
165+
-------
166+
torch.Tensor
167+
The resized tensor.
168+
"""
169+
h, w = x.shape[-2:]
170+
171+
target_h_size = self.target_h_size
172+
target_w_size = self.target_w_size
173+
if self.keep_aspect_ratio:
174+
if self.target_h_size is None:
175+
scale = target_w_size / w
176+
target_h_size = int(h * scale)
177+
elif self.target_w_size is None:
178+
scale = target_h_size / h
179+
target_w_size = int(w * scale)
180+
181+
return torch.nn.functional.interpolate(x, size=(target_h_size, target_w_size))

0 commit comments

Comments
 (0)