diff --git a/mamba/mamba_inner_fn.py b/mamba/mamba_inner_fn.py index 78d86f3..3d3de64 100644 --- a/mamba/mamba_inner_fn.py +++ b/mamba/mamba_inner_fn.py @@ -162,7 +162,7 @@ def backward(ctx, grad_output): # type: ignore x.mT.contiguous(), conv1d_w, conv1d_b, x_grad.contiguous(), act=1 ) - dx = torch.cat([dx_pre_conv.mT.contiguous(), dres], dim=-1) + dx = torch.cat([dx_pre_conv.mT.contiguous(), dres], dim=-1) #TODO: This is pretty slow, probably better to fuse with bwd on cc return ( dx,