Commit d24d24b

add a hyper connections specifically for 2d channel first
1 parent e36ce11 commit d24d24b

File tree

3 files changed: +251 -1 lines changed

hyper_connections/hyper_connections_channel_first.py
pyproject.toml
tests/test_hyper_connections.py
hyper_connections/hyper_connections_channel_first.py (new file)

Lines changed: 207 additions & 0 deletions
from __future__ import annotations
from typing import Callable

from functools import partial
from random import randrange

import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F
from torch.utils._pytree import tree_flatten, tree_unflatten

from einops import rearrange, repeat, reduce, einsum
from einops.layers.torch import Reduce, Rearrange

from hyper_connections.hyper_connections import (
    Residual,
    RMSNorm
)

"""
ein notation:
b - batch
d - feature dimension
s - residual streams
t - residual streams + num branch inputs
"""

# helper functions

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def identity(t):
    return t

# main functions

def get_expand_reduce_stream_functions(num_streams, disable = False):

    if num_streams == 1 or disable:
        return (nn.Identity(), nn.Identity())

    expand_fn = Reduce(pattern = 'b ... -> (b s) ...', reduction = 'repeat', s = num_streams)
    reduce_fn = Reduce(pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)

    return expand_fn, reduce_fn

def get_init_and_expand_reduce_stream_functions(num_streams, disable = False):

    hyper_conn_klass = HyperConnections if not disable else Residual

    init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
    expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, disable = disable)

    return (init_hyper_conn_fn, *expand_reduce_fns)

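# usage sketch (hedged, mirrors the test added in this commit): the factory above is typically unpacked as
#
#   init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
#
# expand_stream repeats the batch into 4 residual streams, (b, d, h, w) -> (b * 4, d, h, w),
# and reduce_stream sums the streams back to (b, d, h, w) after the trunk; with disable = True
# or a single stream, both are plain nn.Identity()
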
# norms

class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.zeros(dim, 1, 1))

    def forward(self, x):
        return F.normalize(x, dim = 1) * self.scale * (self.gamma + 1)

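# note: this channel-first RMSNorm shadows the one imported above - it normalizes over the channel
# dimension (dim = 1) and keeps gamma shaped (dim, 1, 1) so it broadcasts across the spatial dims;
# gamma is zero-initialized, so (gamma + 1) starts out as a plain rms normalization
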
# hyper connection residual streams

class HyperConnections(Module):
    def __init__(
        self,
        num_residual_streams,
        *,
        dim,
        branch: Module | None = None,
        layer_index = None,
        tanh = True,
        channel_first = True,
        dropout = 0.,
        residual_transform: Module | None = None, # to support resnet blocks where dimension in is not equal to dimension out - usually a residual conv
    ):
        """
        Appendix J, Algorithm 2 in - https://arxiv.org/abs/2409.19606
        """
        super().__init__()

        self.branch = branch

        # activation - results were seemingly wishy washy depending on whether tanh was used

        self.act = nn.Tanh() if tanh else nn.Identity()

        self.norm = RMSNorm(dim) # they used layernorm in the paper, but rmsnorm is fine given what we know now

        assert num_residual_streams > 0, '`num_residual_streams` must be greater than 0'

        self.num_residual_streams = num_residual_streams
        init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given

        self.static_beta = nn.Parameter(torch.ones(num_residual_streams))

        init_alpha0 = torch.zeros((num_residual_streams, 1))
        init_alpha0[init_residual_index, 0] = 1.

        self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))

        self.dynamic_alpha_fn = nn.Conv2d(dim, num_residual_streams + 1, 1, bias = False)
        nn.init.zeros_(self.dynamic_alpha_fn.weight)

        self.dynamic_beta_fn = nn.Sequential(
            nn.Conv2d(dim, 1, 1, bias = False),
            Rearrange('b 1 ... -> b ...')
        )

        nn.init.zeros_(self.dynamic_beta_fn[0].weight)

        self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
        self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)

        # dropouts

        self.dropout = nn.Dropout(dropout)

        # maybe residual transform

        self.residual_transform = default(residual_transform, nn.Identity())

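    # parameter summary (descriptive), with s = num_residual_streams:
    #   static_alpha - (s, s + 1): column 0 initially routes one chosen stream into the branch,
    #                  the identity block keeps each residual stream mapped to itself
    #   static_beta  - (s,): how the branch output is written back onto each stream
    #   dynamic_alpha_fn / dynamic_beta_fn - 1x1 convs giving per-position corrections, both
    #                  zero-initialized and scaled by 1e-2, so the layer starts close to a plain residual
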
    def width_connection(self, residuals):

        maybe_transformed_residuals = self.residual_transform(residuals)

        # width connection

        normed = self.norm(residuals)

        # alpha for weighted sum of residuals going into branch

        wc_weight = self.act(self.dynamic_alpha_fn(normed))
        dynamic_alpha = wc_weight * self.dynamic_alpha_scale

        dynamic_alpha = rearrange(dynamic_alpha, '(b s) ... -> b s ...', s = self.num_residual_streams)
        alpha = dynamic_alpha + rearrange(self.static_alpha, 's t -> s t 1 1')

        # beta for weights from branch output back to residual streams

        dc_weight = self.act(self.dynamic_beta_fn(normed))
        dynamic_beta = dc_weight * self.dynamic_beta_scale
        dynamic_beta = rearrange(dynamic_beta, '(b s) ... -> b s ...', s = self.num_residual_streams)
        beta = dynamic_beta + rearrange(self.static_beta, 's -> s 1 1')

        residuals = rearrange(residuals, '(b s) ... -> b s ...', s = self.num_residual_streams)
        mix_h = einsum(alpha, residuals, 'b s t ..., b s d ... -> b t d ...')

        branch_input, residuals = mix_h[:, 0, ...], mix_h[:, 1:, ...]

        return branch_input, maybe_transformed_residuals, dict(beta = beta)

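    # shape walkthrough (descriptive): with s streams and input (b * s, d, h, w), alpha is
    # (b, s, s + 1, h, w) and the stacked residuals are (b, s, d, h, w), so mix_h is
    # (b, s + 1, d, h, w); slot 0 is the branch input (b, d, h, w), while the (possibly
    # conv-transformed) residuals are carried forward for the depth connection
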
    def depth_connection(self, branch_output, residuals, *, beta):
        # 'depth' connection

        output = einsum(branch_output, beta, 'b d ..., b s ... -> b s d ...')
        output = rearrange(output, 'b s d ... -> (b s) d ...')

        residuals = residuals + output

        return self.dropout(residuals)

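    # shape walkthrough (descriptive): branch_output (b, d', h, w) weighted by beta (b, s, h, w)
    # gives (b, s, d', h, w), folded back to (b * s, d', h, w) and added onto the residual streams,
    # so each stream receives its own per-position scaled copy of the branch output
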
    def decorate_branch(self, branch: Callable):
        assert not exists(self.branch), 'branch was already wrapped on init'

        def forward_and_add_residual(residual, *args, **kwargs):
            branch_input, add_residual = self.forward(residual)

            branch_output = branch(branch_input, *args, **kwargs)

            residual = add_residual(branch_output)

            return residual

        return forward_and_add_residual

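    # hedged usage sketch (the block name below is illustrative): decorate an existing callable after init
    #
    #   hyper_conn = HyperConnections(4, dim = 512)             # no branch passed at init
    #   wrapped = hyper_conn.decorate_branch(some_conv_block)   # some_conv_block: any callable on the branch input
    #   residual = wrapped(residual)                            # width connection -> branch -> depth connection
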
    def forward(self, residuals, *branch_args, **branch_kwargs):

        branch_input, residuals, residual_kwargs = self.width_connection(residuals)

        def add_residual_fn(branch_out):
            (branch_out, *rest), tree_spec = tree_flatten(branch_out)

            branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)

            return tree_unflatten((branch_out, *rest), tree_spec)

        if not exists(self.branch):
            return branch_input, add_residual_fn

        branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)

        return add_residual_fn(branch_output)

HyperConnections.get_expand_reduce_stream_functions = staticmethod(get_expand_reduce_stream_functions)
HyperConnections.get_init_and_expand_reduce_stream_functions = staticmethod(get_init_and_expand_reduce_stream_functions)
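
For orientation, a minimal sketch of the two-phase path that `forward` exposes when no branch is wrapped at init - the module path comes from the test below, while the conv block and tensor sizes are illustrative assumptions:

    import torch
    from torch import nn

    from hyper_connections.hyper_connections_channel_first import HyperConnections

    hyper_conn = HyperConnections(4, dim = 64)   # 4 residual streams, channel-first, no branch wrapped
    block = nn.Conv2d(64, 64, 3, padding = 1)    # any branch operating on (b, 64, h, w)

    residuals = torch.randn(2 * 4, 64, 8, 8)     # batch of 2, already expanded to 4 streams

    branch_input, add_residual = hyper_conn(residuals)   # width connection -> (2, 64, 8, 8)
    branch_output = block(branch_input)                  # run the branch as usual
    residuals = add_residual(branch_output)              # depth connection -> back to (8, 64, 8, 8)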

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "hyper-connections"
-version = "0.1.7"
+version = "0.1.8"
 description = "Hyper-Connections"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_hyper_connections.py

Lines changed: 43 additions & 0 deletions

@@ -179,3 +179,46 @@ def test_residual_transform(disable):
    after_residual = reduce_stream(residual)

    assert before_residual.shape == after_residual.shape

@pytest.mark.parametrize('disable', (False, True))
def test_channel_first_hyper_connection(disable):

    # a single branch layer

    branch = nn.Sequential(
        nn.Conv2d(512, 512, 3, padding = 1),
        nn.SiLU(),
        nn.Conv2d(512, 256, 3, padding = 1)
    )

    residual_fn = nn.Conv2d(512, 256, 1)

    # before

    residual = torch.randn(2, 512, 16, 16)

    before_residual = branch(residual) + residual_fn(residual)

    # after, say 4 streams as in the paper

    from hyper_connections.hyper_connections_channel_first import get_init_and_expand_reduce_stream_functions

    init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4, disable = disable)

    # 1. wrap your branch function

    hyper_conn_branch = init_hyper_conn(dim = 512, branch = branch, residual_transform = residual_fn)

    # 2. expand to 4 streams, this must be done before your trunk, typically a for-loop with many branch functions

    residual = expand_stream(residual)

    # 3. forward your residual as usual into the wrapped branch function(s)

    residual = hyper_conn_branch(residual)

    # 4. reduce the 4 streams with a summation, this has to be done after your for-loop trunk. for a transformer, unsure whether to do it before or after the final norm

    after_residual = reduce_stream(residual)

    assert before_residual.shape == after_residual.shape
