From 79359bf24f0ba60471e1f7e348336e4b32930b62 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 24 Jan 2023 17:38:44 +0100 Subject: [PATCH] fix MHA API fixes #3 when pykeops is not installed atleast --- src/models/ssm_seq.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/models/ssm_seq.py b/src/models/ssm_seq.py index ce629f3..decec0e 100644 --- a/src/models/ssm_seq.py +++ b/src/models/ssm_seq.py @@ -30,7 +30,8 @@ def create_mixer_cls(ssm_cls=H3, ssm_cfg=None, attn_layer_idx=None, attn_cfg=None, layer_idx=None): if attn_layer_idx is not None and layer_idx in attn_layer_idx: causal = True if attn_cfg is None else attn_cfg.pop('causal', True) - mixer_cls = partial(MHA, causal=causal, **(attn_cfg if attn_cfg is not None else {})) + mixer_cls = partial(MHA, layer_idx=layer_idx, causal=causal, + **(attn_cfg if attn_cfg is not None else {})) else: mixer_cls = partial(ssm_cls, layer_idx=layer_idx, **(ssm_cfg if ssm_cfg is not None else {}))