Skip to content

Commit

Permalink
Adapt checking logic
Browse files Browse the repository at this point in the history
  • Loading branch information
TimoImhof committed Aug 4, 2024
1 parent 88bd867 commit 3de2581
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/adapters/methods/reft.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,11 @@ def _gather_adapted_states(self, hidden_states: torch.Tensor):
# if cached indexing matrices are computed for different hidden_states size -> recompute
cache_invalidated = False
if hasattr(context, "pref_idx") and hasattr(context, "suff_idx"):
cache_invalidated = context.suff_idx.size(1) != seq_len
cache_invalidated = (
torch.max(context.suff_idx) >= seq_len # indices out of bounds
or bsz != context.suff_idx.size(0) # batch size mismatch
or ddim != context.suff_idx.size(2) # hidden size mismatch
)

# no cached indexing matrices available -> compute now
if not hasattr(context, "pref_idx") and not hasattr(context, "suff_idx") or cache_invalidated:
Expand Down

0 comments on commit 3de2581

Please sign in to comment.