forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mlp.py
230 lines (200 loc) · 7.98 KB
/
mlp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..functional import ACT2FN, concat
from ..module import Module
from ..quantization import QuantMode
from .linear import ColumnLinear, RowLinear
from .lora import LoraRuntimeParams
class MLP(Module):
    """Feed-forward block: ``fc`` (ColumnLinear) -> activation -> ``proj`` (RowLinear).

    The activation is looked up by name in ``ACT2FN``. When ``hidden_act`` is
    ``'swiglu'`` the ``fc`` layer is built with twice ``ffn_hidden_size``
    output features; for every other activation it has ``ffn_hidden_size``.
    """

    def __init__(
        self,
        hidden_size,
        ffn_hidden_size,
        hidden_act,
        bias=True,
        dtype=None,
        tp_group=None,
        tp_size=1,
        quant_mode=QuantMode(0),
    ):
        """Build the ``fc`` and ``proj`` layers.

        Args:
            hidden_size: Input size of ``fc`` and output size of ``proj``.
            ffn_hidden_size: Input size of ``proj``; also the output size of
                ``fc`` (doubled when ``hidden_act == 'swiglu'``).
            hidden_act: Key into ``ACT2FN`` naming the activation.
            bias: Forwarded to both linear layers.
            dtype: Forwarded to both linear layers.
            tp_group: Forwarded to both linear layers.
            tp_size: Forwarded to both linear layers.
            quant_mode: Stored on the instance; not otherwise used here.

        Raises:
            ValueError: If ``hidden_act`` is not a key of ``ACT2FN``.
        """
        super().__init__()
        if hidden_act not in ACT2FN:
            message = 'unsupported activation function: {}'.format(hidden_act)
            raise ValueError(message)

        if hidden_act == 'swiglu':
            fc_output_size = 2 * ffn_hidden_size
        else:
            fc_output_size = ffn_hidden_size

        # Settings shared by the two linear layers.
        shared_kwargs = dict(bias=bias,
                             dtype=dtype,
                             tp_group=tp_group,
                             tp_size=tp_size)

        # `fc` is registered before `proj`, as in the original layout.
        self.fc = ColumnLinear(hidden_size,
                               fc_output_size,
                               gather_output=False,
                               **shared_kwargs)
        self.proj = RowLinear(ffn_hidden_size, hidden_size, **shared_kwargs)

        self.hidden_act = hidden_act
        self.dtype = dtype
        self.bias = bias
        self.quant_mode = quant_mode

    def forward(self, hidden_states, lora_layer_params=None):
        """Apply ``fc``, the activation, then ``proj``.

        Args:
            hidden_states: Input passed to ``fc``.
            lora_layer_params: Optional object exposing
                ``get_runtime_params(index, name)``. When given, the entries
                for ``"mlp_h_to_4h"`` and ``"mlp_4h_to_h"`` are handed to
                ``fc`` and ``proj`` respectively.

        Returns:
            The result of ``proj``.
        """
        if lora_layer_params is None:
            fc_lora = None
            proj_lora = None
        else:
            fc_lora = lora_layer_params.get_runtime_params(0, "mlp_h_to_4h")
            proj_lora = lora_layer_params.get_runtime_params(0, "mlp_4h_to_h")

        activation = ACT2FN[self.hidden_act]
        projected_up = self.fc(hidden_states, fc_lora)
        activated = activation(projected_up)
        return self.proj(activated, lora_runtime_params=proj_lora)
class GatedMLP(MLP):
    """``MLP`` variant with an extra ``gate`` ColumnLinear branch.

    ``forward`` computes ``proj(act(fc(x)) * gate(x))``.
    """

    def __init__(
        self,
        hidden_size,
        ffn_hidden_size,
        hidden_act,
        bias=True,
        dtype=None,
        tp_group=None,
        tp_size=1,
        quant_mode=QuantMode(0),
    ):
        """Build the base ``MLP`` layers and add the ``gate`` layer.

        Args:
            hidden_size: Input size of ``fc`` and ``gate``; output size of
                ``proj``.
            ffn_hidden_size: Output size of ``gate``; passed to the base
                class for ``fc`` and ``proj``.
            hidden_act: Activation name, validated by the base class.
            bias: Forwarded to every linear layer.
            dtype: Forwarded to every linear layer.
            tp_group: Forwarded to every linear layer.
            tp_size: Forwarded to every linear layer.
            quant_mode: Forwarded to the base class.
        """
        shared_kwargs = dict(bias=bias,
                             dtype=dtype,
                             tp_group=tp_group,
                             tp_size=tp_size)
        super().__init__(hidden_size,
                         ffn_hidden_size,
                         hidden_act,
                         quant_mode=quant_mode,
                         **shared_kwargs)

        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.tp_group = tp_group
        self.tp_size = tp_size

        # Second projection of the input, multiplied with the activated `fc`
        # output in `forward`.
        self.gate = ColumnLinear(hidden_size,
                                 ffn_hidden_size,
                                 gather_output=False,
                                 **shared_kwargs)

    def forward(self, hidden_states, lora_layer_params=None):
        """Apply the gated feed-forward computation.

        Args:
            hidden_states: Input passed to both ``fc`` and ``gate``.
            lora_layer_params: Optional object exposing
                ``get_runtime_params(index, name)``. When given, the entries
                for ``"mlp_h_to_4h"``, ``"mlp_gate"`` and ``"mlp_4h_to_h"``
                are handed to ``fc``, ``gate`` and ``proj`` respectively.

        Returns:
            The result of ``proj``.
        """
        if lora_layer_params is None:
            fc_lora = None
            gate_lora = None
            proj_lora = None
        else:
            # Same lookup order as before: fc, gate, proj.
            fc_lora = lora_layer_params.get_runtime_params(0, "mlp_h_to_4h")
            gate_lora = lora_layer_params.get_runtime_params(0, "mlp_gate")
            proj_lora = lora_layer_params.get_runtime_params(0, "mlp_4h_to_h")

        activation = ACT2FN[self.hidden_act]
        activated = activation(self.fc(hidden_states, fc_lora))
        gate_out = self.gate(hidden_states, gate_lora)
        gated = activated * gate_out
        return self.proj(gated, lora_runtime_params=proj_lora)
class FusedGatedMLP(Module):
    """Gated MLP whose ``fc`` and ``gate`` projections share one ColumnLinear.

    ``fused_fc`` produces ``2 * ffn_hidden_size`` output features, and
    ``forward`` feeds them to ``ACT2FN['swiglu']`` before ``proj``. Only
    ``hidden_act == 'silu'`` is accepted by ``forward``.
    """

    def __init__(
        self,
        hidden_size,
        ffn_hidden_size,
        hidden_act,
        bias=True,
        dtype=None,
        tp_group=None,
        tp_size=1,
        quant_mode=QuantMode(0),
    ):
        """Store the configuration and build ``fused_fc`` and ``proj``.

        Args:
            hidden_size: Input size of ``fused_fc``; output size of ``proj``.
            ffn_hidden_size: Input size of ``proj``; ``fused_fc`` outputs
                twice this many features.
            hidden_act: Activation name; checked in ``forward``, not here.
            bias: Forwarded to both linear layers.
            dtype: Forwarded to both linear layers.
            tp_group: Forwarded to both linear layers.
            tp_size: Forwarded to both linear layers.
            quant_mode: Stored on the instance; not otherwise used here.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.hidden_act = hidden_act
        self.bias = bias
        self.dtype = dtype
        self.tp_group = tp_group
        self.tp_size = tp_size
        self.quant_mode = quant_mode

        shared_kwargs = dict(bias=bias,
                             dtype=dtype,
                             tp_group=tp_group,
                             tp_size=tp_size)
        self.fused_fc = ColumnLinear(hidden_size,
                                     ffn_hidden_size * 2,
                                     gather_output=False,
                                     **shared_kwargs)
        self.proj = RowLinear(ffn_hidden_size, hidden_size, **shared_kwargs)

    def _fused_input_lora(self, hidden_states, lora_layer_params):
        """Return the LoRA term to add to the ``fused_fc`` output, or ``None``.

        ``None`` is returned when no LoRA parameters are supplied, or when
        either the ``"mlp_h_to_4h"`` or the ``"mlp_gate"`` entry is missing.
        """
        if lora_layer_params is None:
            return None

        fc_params = lora_layer_params.get_runtime_params(0, "mlp_h_to_4h")
        gate_params = lora_layer_params.get_runtime_params(0, "mlp_gate")
        if fc_params is None or gate_params is None:
            return None

        # Merge the two single-module parameter sets into one; the host-side
        # fields are taken from the fc entry.
        fused_params = LoraRuntimeParams(
            lora_ranks=[
                fc_params.lora_ranks[0],
                gate_params.lora_ranks[0],
            ],
            lora_weights_pointers=[
                fc_params.lora_weights_pointers[0],
                gate_params.lora_weights_pointers[0],
            ],
            host_request_types=fc_params.host_request_types,
            host_context_lengths=fc_params.host_context_lengths,
            max_context_length=fc_params.max_context_length)

        # NOTE(review): `mlp_in_lora` is not assigned in the `__init__` shown
        # in this class; presumably it is attached elsewhere -- confirm.
        fc_delta, gate_delta = self.mlp_in_lora(hidden_states, fused_params)

        # Gate part first, fc part second, joined along the last axis.
        return concat([gate_delta, fc_delta], dim=fc_delta.rank() - 1)

    def forward(self, hidden_states, lora_layer_params=None):
        """Apply ``fused_fc``, the optional LoRA term, SwiGLU, then ``proj``.

        Args:
            hidden_states: Input passed to ``fused_fc``.
            lora_layer_params: Optional object exposing
                ``get_runtime_params(index, name)``.

        Returns:
            The result of ``proj``.

        Raises:
            NotImplementedError: If ``hidden_act`` is not ``'silu'``.
        """
        # Replaces the unfused pattern
        #
        #   SiLU(FC(x)) * Gate(x)
        #
        # with
        #
        #   SwiGLU(FusedFC(x))
        #
        # so the weight-loading paths do not each have to concatenate weights.
        inter = self.fused_fc(hidden_states)

        lora_term = self._fused_input_lora(hidden_states, lora_layer_params)
        if lora_term is not None:
            inter = inter + lora_term

        if self.hidden_act != 'silu':
            raise NotImplementedError(
                f"Activation {self.hidden_act} not yet implemented for FusedGatedMLP"
            )
        inter = ACT2FN['swiglu'](inter)

        if lora_layer_params is None:
            proj_lora = None
        else:
            proj_lora = lora_layer_params.get_runtime_params(0, "mlp_4h_to_h")

        return self.proj(inter, lora_runtime_params=proj_lora)