From 1f9990e3acfaf054b41ad111d65750a9d927b74f Mon Sep 17 00:00:00 2001 From: Benjamin Fattori Date: Sat, 9 Nov 2024 21:23:52 -0500 Subject: [PATCH] add #TODO --- mamba/mamba_inner_fn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mamba/mamba_inner_fn.py b/mamba/mamba_inner_fn.py index 78d86f3..3d3de64 100644 --- a/mamba/mamba_inner_fn.py +++ b/mamba/mamba_inner_fn.py @@ -162,7 +162,7 @@ def backward(ctx, grad_output): # type: ignore x.mT.contiguous(), conv1d_w, conv1d_b, x_grad.contiguous(), act=1 ) - dx = torch.cat([dx_pre_conv.mT.contiguous(), dres], dim=-1) + dx = torch.cat([dx_pre_conv.mT.contiguous(), dres], dim=-1) #TODO: This is pretty slow, probably better to fuse with bwd on cc return ( dx,