From e4131f66640eacea62ef4138dc8e2a77724410cf Mon Sep 17 00:00:00 2001 From: Gabriel Gutierrez Date: Fri, 15 Nov 2024 00:46:37 -0300 Subject: [PATCH] Refactor tensor type conversion in ResizedMetric to handle LongTensor and uint8 formats correctly --- minerva/analysis/metrics/transformed_metrics.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/minerva/analysis/metrics/transformed_metrics.py b/minerva/analysis/metrics/transformed_metrics.py index 23d65c0..e3d3e56 100644 --- a/minerva/analysis/metrics/transformed_metrics.py +++ b/minerva/analysis/metrics/transformed_metrics.py @@ -177,5 +177,15 @@ def resize(self, x: torch.Tensor) -> torch.Tensor: elif self.target_w_size is None: scale = target_h_size / h target_w_size = int(w * scale) - x = x.to(torch.uint8) if "LongTensor" in x.type() else x - return torch.nn.functional.interpolate(x, size=(target_h_size, target_w_size)) + type_convert = False + if "LongTensor" in x.type(): + x = x.to(torch.uint8) + type_convert = True + + return ( + torch.nn.functional.interpolate(x, size=(target_h_size, target_w_size)) + if not type_convert + else torch.nn.functional.interpolate( + x, size=(target_h_size, target_w_size) + ).to(torch.long) + )