Skip to content

Commit

Permalink
Remove CroppedMetric from cropped_metric.py and implement it in trans…
Browse files Browse the repository at this point in the history
…formed_metrics.py; add ResizeMetric class for enhanced resizing functionality with aspect ratio support.
  • Loading branch information
GabrielBG0 committed Nov 11, 2024
1 parent e347ec6 commit 8ee0ed7
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 74 deletions.
74 changes: 0 additions & 74 deletions minerva/analysis/metrics/cropped_metric.py

This file was deleted.

181 changes: 181 additions & 0 deletions minerva/analysis/metrics/transformed_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import warnings
from typing import Optional

import torch
from torchmetrics import Metric


class CroppedMetric(Metric):
def __init__(
self,
target_h_size: int,
target_w_size: int,
metric: Metric,
dist_sync_on_step: bool = False,
):
"""
Initializes a new instance of CroppedMetric.
Parameters
----------
target_h_size: int
The target height size.
target_w_size: int
The target width size.
dist_sync_on_step: bool, optional
Whether to synchronize metric state across processes at each step.
Defaults to False.
"""
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.metric = metric
self.target_h_size = target_h_size
self.target_w_size = target_w_size

def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Updates the metric state with the predictions and targets.
Parameters
----------
preds: torch.Tensor
The predicted tensor.
target:
torch.Tensor The target tensor.
"""

preds = self.crop(preds)
target = self.crop(target)
self.metric.update(preds, target)

def compute(self) -> float:
"""
Computes the cropped metric.
Returns:
float: The cropped metric.
"""
return self.metric.compute()

def crop(self, x: torch.Tensor) -> torch.Tensor:
"""crops the input tensor to the target size.
Parameters
----------
x : torch.Tensor
The input tensor.
Returns
-------
torch.Tensor
The cropped tensor.
"""
h, w = x.shape[-2:]
start_h = (h - self.target_h_size) // 2
start_w = (w - self.target_w_size) // 2
end_h = start_h + self.target_h_size
end_w = start_w + self.target_w_size

return x[..., start_h:end_h, start_w:end_w]


class ResizeMetric(Metric):
def __init__(
self,
target_h_size: Optional[int],
target_w_size: Optional[int],
metric: Metric,
keep_aspect_ratio: bool = False,
dist_sync_on_step: bool = False,
):
"""
Initializes a new instance of ResizeMetric.
Parameters
----------
target_h_size: int
The target height size.
target_w_size: int
The target width size.
dist_sync_on_step: bool, optional
Whether to synchronize metric state across processes at each step.
Defaults to False.
"""
super().__init__(dist_sync_on_step=dist_sync_on_step)

if target_h_size is None and target_w_size is None:
raise ValueError(
"At least one of target_h_size or target_w_size must be provided."
)

if (
target_h_size is not None and target_w_size is None
) and keep_aspect_ratio is False:
warnings.warn(
"A target_w_size is not provided, but keep_aspect_ratio is set to False. keep_aspect_ratio will be set to True. If you want to resize the image to a specific width, please provide a target_w_size."
)
keep_aspect_ratio = True

if (
target_w_size is not None and target_h_size is None
) and keep_aspect_ratio is False:
warnings.warn(
"A target_h_size is not provided, but keep_aspect_ratio is set to False. keep_aspect_ratio will be set to True. If you want to resize the image to a specific height, please provide a target_h_size."
)
keep_aspect_ratio = True

self.metric = metric
self.target_h_size = target_h_size
self.target_w_size = target_w_size
self.keep_aspect_ratio = keep_aspect_ratio

def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Updates the metric state with the predictions and targets.
Parameters
----------
preds: torch.Tensor
The predicted tensor.
target:
torch.Tensor The target tensor.
"""

preds = self.resize(preds)
target = self.resize(target)
self.metric.update(preds, target)

def compute(self) -> float:
"""
Computes the resized metric.
Returns:
float: The resized metric.
"""
return self.metric.compute()

def resize(self, x: torch.Tensor) -> torch.Tensor:
"""Resizes the input tensor to the target size.
Parameters
----------
x : torch.Tensor
The input tensor.
Returns
-------
torch.Tensor
The resized tensor.
"""
h, w = x.shape[-2:]

target_h_size = self.target_h_size
target_w_size = self.target_w_size
if self.keep_aspect_ratio:
if self.target_h_size is None:
scale = target_w_size / w
target_h_size = int(h * scale)
elif self.target_w_size is None:
scale = target_h_size / h
target_w_size = int(w * scale)

return torch.nn.functional.interpolate(x, size=(target_h_size, target_w_size))

0 comments on commit 8ee0ed7

Please sign in to comment.