Skip to content

Commit

Permalink
fix MHA API
Browse files Browse the repository at this point in the history
fixes HazyResearch#3 when pykeops is not installed atleast
  • Loading branch information
kashif authored Jan 24, 2023
1 parent 68be180 commit 79359bf
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/models/ssm_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
def create_mixer_cls(ssm_cls=H3, ssm_cfg=None, attn_layer_idx=None, attn_cfg=None, layer_idx=None):
if attn_layer_idx is not None and layer_idx in attn_layer_idx:
causal = True if attn_cfg is None else attn_cfg.pop('causal', True)
mixer_cls = partial(MHA, causal=causal, **(attn_cfg if attn_cfg is not None else {}))
mixer_cls = partial(MHA, layer_idx=layer_idx, causal=causal,
**(attn_cfg if attn_cfg is not None else {}))
else:
mixer_cls = partial(ssm_cls, layer_idx=layer_idx,
**(ssm_cfg if ssm_cfg is not None else {}))
Expand Down

0 comments on commit 79359bf

Please sign in to comment.