From 90951038118a128126fe4304f32e787e7aae0b56 Mon Sep 17 00:00:00 2001 From: giulero Date: Thu, 25 Jan 2024 18:23:24 +0100 Subject: [PATCH] Do not wrap array in a torch tensor if already a tensor --- src/adam/pytorch/torch_like.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/adam/pytorch/torch_like.py b/src/adam/pytorch/torch_like.py index 0bc98ea1..71962f65 100644 --- a/src/adam/pytorch/torch_like.py +++ b/src/adam/pytorch/torch_like.py @@ -59,6 +59,8 @@ def __matmul__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike": if type(self) is type(other): return TorchLike(self.array @ other.array) + if isinstance(other, torch.Tensor): + return TorchLike(self.array @ other) else: return TorchLike(self.array @ torch.tensor(other))