Skip to content

Commit

Permalink
Refactor tensor type conversion in ResizedMetric to handle LongTensor…
Browse files Browse the repository at this point in the history
… and uint8 formats correctly
  • Loading branch information
GabrielBG0 committed Nov 15, 2024
1 parent 7e63fd5 commit e4131f6
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions minerva/analysis/metrics/transformed_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,5 +177,15 @@ def resize(self, x: torch.Tensor) -> torch.Tensor:
elif self.target_w_size is None:
scale = target_h_size / h
target_w_size = int(w * scale)
x = x.to(torch.uint8) if "LongTensor" in x.type() else x
return torch.nn.functional.interpolate(x, size=(target_h_size, target_w_size))
type_convert = False
if "LongTensor" in x.type():
x = x.to(torch.uint8)
type_convert = True

return (
torch.nn.functional.interpolate(x, size=(target_h_size, target_w_size))
if not type_convert
else torch.nn.functional.interpolate(
x, size=(target_h_size, target_w_size)
).to(torch.long)
)

0 comments on commit e4131f6

Please sign in to comment.