Commit

fixing tests once more
adefossez committed Oct 9, 2023
1 parent ba62e00 commit 19ee715
Showing 4 changed files with 35 additions and 32 deletions.
24 changes: 14 additions & 10 deletions audiocraft/modules/rope.py
@@ -81,13 +81,16 @@ def get_rotation(self, start: int, end: int):
             self.rotation = torch.polar(torch.ones_like(angles), angles)
         return self.rotation[start:end]
 
-    def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False):
+    def rotate(self, x: torch.Tensor, start: int = 0, time_dim: int = 1, invert_decay: bool = False):
         """Apply rope rotation to query or key tensor."""
-        T = x.shape[1]
-        rotation = self.get_rotation(start, start + T).unsqueeze(0).unsqueeze(2)
+        T = x.shape[time_dim]
+        target_shape = [1] * x.dim()
+        target_shape[time_dim] = T
+        target_shape[-1] = -1
+        rotation = self.get_rotation(start, start + T).view(target_shape)
 
         if self.xpos:
-            decay = self.xpos.get_decay(start, start + T).unsqueeze(0).unsqueeze(2)
+            decay = self.xpos.get_decay(start, start + T).view(target_shape)
         else:
             decay = 1.0
 
@@ -96,11 +99,11 @@ def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False):
 
         x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
         scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
-        x_out = torch.view_as_real(x_complex * scaled_rotation).flatten(-2)
+        x_out = torch.view_as_real(x_complex * scaled_rotation).view_as(x)
 
         return x_out.type_as(x)
 
-    def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0):
+    def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0, time_dim: int = 1):
         """ Apply rope rotation to both query and key tensors.
         Supports streaming mode, in which query and key are not expected to have the same shape.
         In streaming mode, key will be of length [P + C] with P the cached past timesteps, but
@@ -110,12 +113,13 @@ def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0):
             query (torch.Tensor): Query to rotate.
             key (torch.Tensor): Key to rotate.
             start (int): Start index of the sequence for time offset.
+            time_dim (int): which dimension represent the time steps.
         """
-        query_timesteps = query.shape[1]
-        key_timesteps = key.shape[1]
+        query_timesteps = query.shape[time_dim]
+        key_timesteps = key.shape[time_dim]
         streaming_offset = key_timesteps - query_timesteps
 
-        query_out = self.rotate(query, start + streaming_offset)
-        key_out = self.rotate(key, start, invert_decay=True)
+        query_out = self.rotate(query, start + streaming_offset, time_dim)
+        key_out = self.rotate(key, start, time_dim, invert_decay=True)
 
         return query_out, key_out
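
The key change in rope.py is that rotate no longer assumes the time axis is dimension 1: it builds a broadcast shape of ones, places the sequence length at time_dim and the rotation channels on the last axis, then views the cached rotation into that shape. A minimal standalone sketch of that broadcasting idea (the broadcast_rotation helper and the shapes below are illustrative, not part of audiocraft):

import torch

def broadcast_rotation(rotation: torch.Tensor, x_complex: torch.Tensor, time_dim: int) -> torch.Tensor:
    # Mirror of the target_shape logic in the diff above: ones everywhere,
    # the sequence length on the time axis, the C/2 rotation channels last.
    target_shape = [1] * x_complex.dim()
    target_shape[time_dim] = x_complex.shape[time_dim]
    target_shape[-1] = -1
    return rotation.view(target_shape)

B, T, H, C = 2, 5, 3, 8
angles = torch.outer(torch.arange(T, dtype=torch.float32), torch.rand(C // 2))
rotation = torch.polar(torch.ones_like(angles), angles)   # [T, C/2], complex

x = torch.randn(B, T, H, C)                               # layout "b t h d", time_dim=1
xc_bthd = torch.view_as_complex(x.reshape(B, T, H, -1, 2))
out_bthd = xc_bthd * broadcast_rotation(rotation, xc_bthd, time_dim=1)

xt = x.transpose(1, 2).contiguous()                       # layout "b h t d", time_dim=2
xc_bhtd = torch.view_as_complex(xt.reshape(B, H, T, -1, 2))
out_bhtd = xc_bhtd * broadcast_rotation(rotation, xc_bhtd, time_dim=2)

# The same per-timestep rotation is applied in both layouts.
assert torch.allclose(out_bthd.transpose(1, 2), out_bhtd)

Because the rotation is purely elementwise once broadcast, the same RoPE module can serve either tensor layout; which one is in use is decided in transformer.py below.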
23 changes: 11 additions & 12 deletions audiocraft/modules/transformer.py
@@ -35,8 +35,8 @@ def set_efficient_attention_backend(backend: str = 'torch'):
     _efficient_attention_backend = backend
 
 
-def _get_attention_time_dimension() -> int:
-    if _efficient_attention_backend == 'torch':
+def _get_attention_time_dimension(memory_efficient: bool) -> int:
+    if _efficient_attention_backend == 'torch' and memory_efficient:
         return 2
     else:
         return 1
@@ -89,11 +89,11 @@ def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float =
     return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
 
 
-def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) -> torch.Tensor:
     """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers."""
     if n_rep == 1:
         return x
-    if _efficient_attention_backend == 'torch':
+    if _efficient_attention_backend == 'torch' and memory_efficient:
         bs, n_kv_heads, slen, head_dim = x.shape
         return (
             x[:, :, None, :, :]
@@ -234,7 +234,7 @@ def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype
         # Return a causal mask, accounting for potentially stored past keys/values
         # We actually return a bias for the attention score, as this has the same
         # convention both in the builtin MHA in Pytorch, and Xformers functions.
-        time_dim = _get_attention_time_dimension()
+        time_dim = _get_attention_time_dimension(self.memory_efficient)
         if self.memory_efficient:
             from xformers.ops import LowerTriangularMask
             if current_steps == 1:
@@ -264,7 +264,7 @@ def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype
                 torch.full([], float('-inf'), device=device, dtype=dtype))
 
     def _complete_kv(self, k, v):
-        time_dim = _get_attention_time_dimension()
+        time_dim = _get_attention_time_dimension(self.memory_efficient)
         if self.cross_attention:
             # With cross attention we assume all keys and values
             # are already available, and streaming is with respect
@@ -298,8 +298,7 @@ def _complete_kv(self, k, v):
         return nk, nv
 
     def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
-        # TODO: fix and verify layout.
-        assert _efficient_attention_backend == 'xformers', "Rope not supported with torch attn."
+        time_dim = _get_attention_time_dimension(self.memory_efficient)
         # Apply rope embeddings to query and key tensors.
         assert self.rope is not None
         if 'past_keys' in self._streaming_state:
@@ -311,7 +310,7 @@ def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
         else:
             past_context_offset = 0
         streaming_offset = past_context_offset + past_keys_offset
-        return self.rope.rotate_qk(query, key, start=streaming_offset)
+        return self.rope.rotate_qk(query, key, start=streaming_offset, time_dim=time_dim)
 
     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                 key_padding_mask=None, need_weights=False, attn_mask=None,
@@ -320,7 +319,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
         assert not is_causal, ("New param added in torch 2.0.1 not supported, "
                                "use the causal args in the constructor.")
 
-        time_dim = _get_attention_time_dimension()
+        time_dim = _get_attention_time_dimension(self.memory_efficient)
         if time_dim == 2:
             layout = "b h t d"
         else:
@@ -394,8 +393,8 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                 q, k = self._apply_rope(q, k)
             k, v = self._complete_kv(k, v)
             if self.kv_repeat > 1:
-                k = expand_repeated_kv(k, self.kv_repeat)
-                v = expand_repeated_kv(v, self.kv_repeat)
+                k = expand_repeated_kv(k, self.kv_repeat, self.memory_efficient)
+                v = expand_repeated_kv(v, self.kv_repeat, self.memory_efficient)
             if self.attention_as_float32:
                 q, k, v = [x.float() for x in [q, k, v]]
             if self.memory_efficient:
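
In transformer.py the time axis is now chosen per layer rather than per backend alone: _get_attention_time_dimension takes the layer's memory_efficient flag, so only the memory-efficient torch path uses the "b h t d" layout (time_dim 2), and the chosen value is threaded through _apply_rope into rope.rotate_qk, which is what lets the old "Rope not supported with torch attn." assertion be dropped. A rough sketch of that dispatch, using stand-in names (pick_time_dimension, layout_for) that are not the library's API:

# Illustrative sketch of the dispatch added in this commit; the real logic
# lives in audiocraft/modules/transformer.py.
def pick_time_dimension(backend: str, memory_efficient: bool) -> int:
    # Only the memory-efficient torch path lays tensors out as [batch, heads, time, dim].
    if backend == 'torch' and memory_efficient:
        return 2
    return 1

def layout_for(time_dim: int) -> str:
    return "b h t d" if time_dim == 2 else "b t h d"

for backend, memory_efficient in [('torch', True), ('torch', False), ('xformers', True)]:
    time_dim = pick_time_dimension(backend, memory_efficient)
    print(backend, memory_efficient, layout_for(time_dim))
    # The attention layer then forwards this choice to RoPE, e.g.:
    # q, k = rope.rotate_qk(q, k, start=offset, time_dim=time_dim)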
16 changes: 8 additions & 8 deletions tests/modules/test_rope.py
@@ -11,7 +11,7 @@
 
 
 def test_rope():
-    set_efficient_attention_backend('xformers')
+    set_efficient_attention_backend('torch')
     B, T, H, C = 8, 75, 16, 128
 
     rope = RotaryEmbedding(dim=C)
@@ -24,7 +24,7 @@ def test_rope():
 
 
 def test_rope_io_dtypes():
-    set_efficient_attention_backend('xformers')
+    set_efficient_attention_backend('torch')
     B, T, H, C = 8, 75, 16, 128
 
     rope_32 = RotaryEmbedding(dim=C, dtype=torch.float32)
@@ -48,7 +48,7 @@ def test_rope_io_dtypes():
 
 
 def test_transformer_with_rope():
-    set_efficient_attention_backend('xformers')
+    set_efficient_attention_backend('torch')
     torch.manual_seed(1234)
     for pos in ['rope', 'sin_rope']:
         tr = StreamingTransformer(
@@ -64,7 +64,7 @@ def test_transformer_with_rope():
 
 @torch.no_grad()
 def test_rope_streaming():
-    set_efficient_attention_backend('xformers')
+    set_efficient_attention_backend('torch')
     torch.manual_seed(1234)
     tr = StreamingTransformer(
         16, 4, 2, causal=True, dropout=0.,
@@ -92,7 +92,7 @@ def test_rope_streaming():
 
 @torch.no_grad()
 def test_rope_streaming_past_context():
-    set_efficient_attention_backend('xformers')
+    set_efficient_attention_backend('torch')
     torch.manual_seed(1234)
 
     for context in [None, 10]:
@@ -122,7 +122,7 @@ def test_rope_streaming_past_context():
 
 
 def test_rope_memory_efficient():
-    set_efficient_attention_backend('xformers')
+    set_efficient_attention_backend('torch')
     torch.manual_seed(1234)
     tr = StreamingTransformer(
         16, 4, 2, custom=True, dropout=0., layer_scale=0.1,
@@ -143,7 +143,7 @@ def test_rope_memory_efficient():
 
 
 def test_rope_with_xpos():
-    set_efficient_attention_backend('xformers')
+    set_efficient_attention_backend('torch')
    B, T, H, C = 8, 75, 16, 128
 
     rope = RotaryEmbedding(dim=C, xpos=True)
@@ -156,7 +156,7 @@ def test_rope_with_xpos():
 
 
 def test_positional_scale():
-    set_efficient_attention_backend('xformers')
+    set_efficient_attention_backend('torch')
     B, T, H, C = 8, 75, 16, 128
 
     rope = RotaryEmbedding(dim=C, xpos=True, scale=0.0)
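
With the RoPE tests now defaulting to the torch backend, the time_dim argument is what keeps the rotation consistent across both tensor layouts. A sketch of an equivalence check one could run next to these tests, assuming the same RotaryEmbedding import used by the existing test file and the rotate signature shown in the rope.py diff above (this snippet is not part of the test suite):

import torch
from audiocraft.modules.rope import RotaryEmbedding  # same import as the tests above

def check_rotate_time_dim_equivalence():
    # Rotating a [B, T, H, C] tensor with time_dim=1 should match rotating the
    # transposed [B, H, T, C] tensor with time_dim=2, up to that transpose.
    B, T, H, C = 4, 11, 3, 32
    rope = RotaryEmbedding(dim=C)
    x = torch.randn(B, T, H, C)

    out_t1 = rope.rotate(x, time_dim=1)
    out_t2 = rope.rotate(x.transpose(1, 2), time_dim=2).transpose(1, 2)
    assert torch.allclose(out_t1, out_t2, atol=1e-6)

check_rotate_time_dim_equivalence()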
4 changes: 2 additions & 2 deletions tests/modules/test_transformer.py
@@ -86,7 +86,7 @@ def test_streaming_api():
 
 
 def test_memory_efficient():
-    for backend in ['torch', 'xformers']:
+    for backend in ['torch']:
         torch.manual_seed(1234)
         set_efficient_attention_backend(backend)
 
@@ -132,7 +132,7 @@ def test_attention_as_float32():
 
 @torch.no_grad()
 def test_streaming_memory_efficient():
-    for backend in ['torch', 'xformers']:
+    for backend in ['torch']:
         torch.manual_seed(1234)
         set_efficient_attention_backend(backend)
         tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0., custom=True)
