add a test for the multi input stream variant
lucidrains committed Dec 29, 2024
1 parent 091370e commit 4656f36
Showing 4 changed files with 160 additions and 3 deletions.
19 changes: 19 additions & 0 deletions .github/workflows/test.yml
@@ -0,0 +1,19 @@
name: Tests the examples in README
on: push

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Install Python
        uses: actions/setup-python@v4
      - name: Install the latest version of rye
        uses: eifinger/setup-rye@v2
      - name: Use UV instead of pip
        run: rye config --set-bool behavior.use-uv=true
      - name: Install dependencies
        run: |
          rye sync
      - name: Run pytest
        run: rye run pytest tests/
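To reproduce the CI run locally, the same sequence should work: rye sync followed by rye run pytest tests/ (this assumes rye is already installed; in CI it is set up via eifinger/setup-rye, with uv enabled as the installer).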
@@ -185,7 +185,7 @@ def __init__(
        init_alpha0 = torch.zeros((num_residual_streams, 1))
        init_alpha0[init_residual_index, 0] = 1.

-        self.dynamic_alpha_and_branch_input = ProjActScale(dim, num_residual_streams + 1)
+        self.dynamic_alpha_and_branch_input = ProjActScale(dim, num_residual_streams + 1, activation = act)
        self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))

        self.dynamic_beta = ProjActScale(dim, 1, activation = act, squeeze_output = True)
@@ -200,7 +200,7 @@ def __init__(

        self.additional_norms = ModuleList([RMSNorm(dim) for _, dim in additional_input_paths])
        self.additional_to_dynamic_input = ModuleList([ProjActScale(dim, 1, activation = act, squeeze_output = True) for _ , dim in additional_input_paths])
-        self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0])])
+        self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0]) for _ in additional_input_paths])

        self.additional_input_paths = additional_input_paths

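Of the two hunks above, the second is the substantive fix for the multi-input variant: the old line registered a single static-input parameter regardless of how many additional input paths were configured, while the new comprehension registers one per path (the first hunk simply threads the activation through to ProjActScale). A minimal sketch of what the fixed registration produces, assuming 4 residual streams, an initial residual index of 0, and the two additional paths used in the new test below; the standalone names here are illustrative, not the library's internals:

import torch
from torch import nn

# illustrative: mirrors the fixed line with 4 streams and two additional input paths
num_residual_streams = 4
init_residual_index = 0  # assumed for illustration
additional_input_paths = [(1, 256), ('third', 128)]

init_alpha0 = torch.zeros((num_residual_streams, 1))
init_alpha0[init_residual_index, 0] = 1.

# one (num_residual_streams,) parameter per additional input path, instead of a single one
additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0]) for _ in additional_input_paths])

assert len(additional_static_input) == len(additional_input_paths)
assert all(p.shape == (num_residual_streams,) for p in additional_static_input)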
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "hyper-connections"
-version = "0.0.24"
+version = "0.1.0"
description = "Hyper-Connections"
authors = [
    { name = "Phil Wang", email = "lucidrains@gmail.com" }
138 changes: 138 additions & 0 deletions tests/test_hyper_connections.py
@@ -0,0 +1,138 @@
import pytest

import torch
from torch import nn

@pytest.mark.parametrize('disable', (False, True))
def test_readme(disable):

    # a single branch layer

    branch = nn.Linear(512, 512)

    # before

    residual = torch.randn(2, 1024, 512)

    residual = branch(residual) + residual

    # after, say 4 streams in paper

    from hyper_connections import get_init_and_expand_reduce_stream_functions

    init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4, disable = disable)

    # 1. wrap your branch function

    hyper_conn_branch = init_hyper_conn(dim = 512, branch = branch)

    # 2. expand to 4 streams, this must be done before your trunk, typically a for-loop with many branch functions

    residual = expand_stream(residual)

    # 3. forward your residual as usual into the wrapped branch function(s)

    residual = hyper_conn_branch(residual)

    # 4. reduce 4 streams with a summation, this has to be done after your for-loop trunk. for transformer, unsure whether to do before or after final norm

    residual = reduce_stream(residual)

    assert residual.shape == (2, 1024, 512)

def test_manual():
    # a single branch layer

    branch = nn.Linear(512, 512)

    # before

    residual = torch.randn(2, 1024, 512)

    residual = branch(residual) + residual

    # after, say 4 streams in paper

    from hyper_connections import get_init_and_expand_reduce_stream_functions

    init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)

    # 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above

    hyper_conn = init_hyper_conn(dim = 512)

    # 2. expand to 4 streams

    residual = expand_stream(residual)

    # 3. forward your residual into hyper connection for the branch input + add residual function (learned betas)

    branch_input, add_residual = hyper_conn(residual)

    branch_output = branch(branch_input)

    residual = add_residual(branch_output)

    # or you can do it in one line as so -> residual = hyper_conn.decorate_branch(branch)(residual)

    # 4. reduce 4 streams with a summation, this has to be done after your for loop trunk

    residual = reduce_stream(residual)
    assert residual.shape == (2, 1024, 512)

@pytest.mark.parametrize('disable', (False, True))
def test_multi_input_hyper_connections(disable):

    # two branch layers

    class CustomModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(512, 512)
            self.second_linear = nn.Linear(256, 512)
            self.third_linear = nn.Linear(128, 512)

        def forward(self, x, second, *, third):
            return self.linear(x) + self.second_linear(second) + self.third_linear(third), 3.

    branch = CustomModule()

    # before

    residual = torch.randn(3, 1024, 512)
    second_residual = torch.randn(3, 1024, 256)
    third_residual = torch.randn(3, 1024, 128)

    # residual = branch1(residual) + branch2(residual) + residual

    # after, say 4 streams in paper

    from hyper_connections.hyper_connections_with_multi_input_streams import HyperConnections

    init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = disable)

    # 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above

    hyper_conn = init_hyper_conn(
        dim = 512,
        branch = branch,
        additional_input_paths = [
            (1, 256),      # points at second residual stream, first arg
            ('third', 128) # points at third residual stream, keyword argument 'third'
        ],
        layer_index = 1,
    )

    # 2. expand to 4 streams

    residual = expand_stream(residual)
    second_residual = expand_stream(second_residual)
    third_residual = expand_stream(third_residual)

    # 3. forward your residual into hyper connection for the branch input + add residual function (learned betas)

    residual, rest_output = hyper_conn(residual, second_residual, third = third_residual)

    residual = reduce_stream(residual)

    assert residual.shape == (3, 1024, 512)
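Each test above exercises a single wrapped branch. The comments in test_readme describe the intended larger pattern: expand the residual once before the trunk, run every hyper-connection-wrapped branch inside it, and reduce once afterwards. A minimal sketch of that trunk loop, assuming layers is a list of modules returned by init_hyper_conn(dim = 512, branch = ...); this is illustrative only and not part of the commit:

def run_trunk(residual, layers, expand_stream, reduce_stream):
    # expand to the configured number of streams once, before the trunk
    residual = expand_stream(residual)

    # each layer is a hyper-connection-wrapped branch; it consumes and returns the expanded residual
    for layer in layers:
        residual = layer(residual)

    # reduce the streams back down with a summation after the trunk
    return reduce_stream(residual)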
