diff --git a/src/adam/pytorch/torch_like.py b/src/adam/pytorch/torch_like.py index 0bc98ea1..71962f65 100644 --- a/src/adam/pytorch/torch_like.py +++ b/src/adam/pytorch/torch_like.py @@ -59,6 +59,8 @@ def __matmul__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike": if type(self) is type(other): return TorchLike(self.array @ other.array) + if isinstance(other, torch.Tensor): + return TorchLike(self.array @ other) else: return TorchLike(self.array @ torch.tensor(other))