
Commit 1f0fd6e
fix (api): adapt to the new RMS norm and MatMul operators to avoid transposing weights at load time
PanZezhong1725 authored and YdrMaster committed Jan 31, 2024
1 parent 2bd0635 commit 1f0fd6e
Showing 4 changed files with 20 additions and 21 deletions.
21 changes: 16 additions & 5 deletions src/09python_ffi/src/refactor_graph/frontend/modeling.py
@@ -186,7 +186,7 @@ def make_op(
         inputs: Tuple[Union[str, np.ndarray], ...],
         outputs: Tuple[str, ...] | int = 1,
         name: str | None = None,
-        use_onnx_standard: bool = True
+        use_onnx_standard: bool = True,
     ):
         if use_onnx_standard and not op_type.startswith("onnx::"):
             op_type = "onnx::" + op_type
@@ -357,7 +357,7 @@ def run(
                 self._nodes, operators, edges, input_names, output_names
             )
             self._executor = self._compiler.compile_on(
-                find_device(self._device, self._device_id), "default", []
+                find_device(self._device, self._device_id), "default", ["ce"]
             )
         elif recompile or self._executor is None:
             # Set input info
@@ -382,7 +382,7 @@
                 self._compiler.set_input(input_names.index(name), dynamic_tensor)

             self._executor = self._compiler.compile_on(
-                find_device(self._device, self._device_id), "default", []
+                find_device(self._device, self._device_id), "default", ["ce"]
             )

         ## Executor should have been created at this point
@@ -489,6 +489,8 @@ def make_onnx(
         return onnx.helper.make_model(graph)

     def load_params(self, data: Dict[str, np.ndarray]):
+        if len(self._parameters) != len(data):
+            print("Warning: the number of loaded params does not match the current model.")
         for name in self._parameters:
             new_data = data.get(name)
             if new_data is not None:
@@ -534,8 +536,8 @@ def div(self, A, B, C="") -> str:
     def pow(self, A, B, C="") -> str:
         return self.make_op("Pow", {}, (A, B), (C,))[0]

-    def matmul(self, A, B, Y="") -> str:
-        return self.make_op("MatMul", {}, (A, B), (Y,))[0]
+    def matmul(self, A, B, Y="", transA=0, transB=0) -> str:
+        return self.make_op("llm::MatMul", {"transA": transA, "transB": transB}, (A, B), (Y,), use_onnx_standard=False)[0]

     def gemm(self, A, B, C=None, Y="", alpha=1.0, beta=1.0, transA=0, transB=0) -> str:
         inputs = (A, B, C) if C is not None else (A, B)
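
For reference, the new transA/transB attributes follow the usual Gemm convention: each flag transposes the last two axes of the corresponding input before the multiply. A minimal NumPy sketch of the assumed semantics (illustration only; the actual behavior is defined by the llm::MatMul kernel):

    import numpy as np

    def matmul_ref(a, b, trans_a=0, trans_b=0):
        """Assumed reference semantics of llm::MatMul with transA/transB."""
        if trans_a:
            a = np.swapaxes(a, -1, -2)  # transpose the last two axes of A
        if trans_b:
            b = np.swapaxes(b, -1, -2)  # transpose the last two axes of B
        return a @ b

With transB=1 a graph can consume a weight stored as (out_features, in_features), the layout safetensors checkpoints already use, so the loader no longer needs a transpose copy per weight (see the removed np.transpose calls further down).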
@@ -615,3 +617,12 @@ def cast(self, input, to: DTYPE, output="") -> str:

     def softmax(self, input, axis=-1, output="") -> str:
         return self.make_op("Softmax", {"axis": axis}, (input,), (output,))[0]
+
+    def rms_norm(self, input, weight, eps=1e-5, output=""):
+        return self.make_op(
+            "llm::RmsNormalization",
+            {"epsilon": eps},
+            (input, weight),
+            (output,),
+            use_onnx_standard=False,
+        )[0]
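
The fused llm::RmsNormalization operator replaces the chain of primitive ops (Pow, ReduceMean, Add, Sqrt, Div, Mul) that RMSNorm emitted before this commit (last file in this diff). A NumPy sketch of the computation, reconstructed from that removed expansion:

    import numpy as np

    def rms_norm_ref(x, weight, eps=1e-5):
        """RMS normalization over the last axis: x / sqrt(mean(x^2) + eps) * weight."""
        variance = np.mean(np.square(x), axis=-1, keepdims=True)
        return x / np.sqrt(variance + eps) * weight

Note that the builder's default epsilon is 1e-5, while the RMSNorm module below keeps its own default of 1e-6 and passes it explicitly.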
@@ -172,7 +172,6 @@ def load_param_safetensors(self, model_path):
             tensor_name = ""
             if naming[0] == "lm_head":
                 tensor_name = "LlamaModel/lm_head/weight"
-                data = np.transpose(data, (1, 0)).copy()
             elif naming[0] == "model" and naming[1] == "embed_tokens":
                 tensor_name = "LlamaModel/embed_tokens"
             elif naming[0] == "model" and naming[1] == "norm":
@@ -181,10 +180,8 @@
                 tensor_name = f"LlamaModel/Decoder_{naming[2]}/"
                 if naming[3] == "self_attn":
                     tensor_name += f"Attention/{naming[4]}/weight"
-                    data = np.transpose(data, (1, 0)).copy()
                 elif naming[3] == "mlp":
                     tensor_name += f"FeedForward/{naming[4]}/weight"
-                    data = np.transpose(data, (1, 0)).copy()
                 else:
                     tensor_name += f"{naming[3]}/weight"
             else:
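
The transposes removed above were compensating for safetensors storing linear-layer weights as (out_features, in_features); with llm::MatMul's transB flag the graph now consumes that layout directly. A sketch for inspecting the layout on a checkpoint (the file name and tensor key are hypothetical examples in the Llama naming scheme parsed above):

    from safetensors import safe_open  # assumes the safetensors package is installed

    with safe_open("model.safetensors", framework="np") as f:
        w = f.get_tensor("model.layers.0.self_attn.q_proj.weight")
        print(w.shape)  # (out_features, in_features), loaded as-is with no transpose copy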
4 changes: 2 additions & 2 deletions src/09python_ffi/src/refactor_graph/frontend/nn/nn.py
@@ -17,7 +17,7 @@ def __init__(
         **kwargs,
     ):
         super().__init__(**kwargs)
-        shape = (in_features, out_features)
+        shape = (out_features, in_features)
         self.weight = self.parameter(
             (np.random.random(shape)).astype(dtype.np_type()), "weight"
         )
@@ -29,7 +29,7 @@

     def __call__(self, input):
         super().__call__([input])
-        output = self.matmul(input, self.weight)
+        output = self.matmul(input, self.weight, transB=1)
         if self.use_bias:
             output = self.add(output, self.bias)
         self.outputs.append(output)
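
These two changes are coupled: the weight is now created in (out_features, in_features) layout, and the call site compensates with transB=1, so the layer still computes y = x @ W.T. A small shape sketch (sizes are arbitrary examples):

    import numpy as np

    batch, in_features, out_features = 2, 4096, 11008  # arbitrary example sizes
    x = np.random.random((batch, in_features)).astype(np.float32)
    w = np.random.random((out_features, in_features)).astype(np.float32)  # new weight layout
    y = x @ w.T  # what matmul(input, weight, transB=1) computes
    assert y.shape == (batch, out_features)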
@@ -5,7 +5,7 @@
 class RMSNorm(InfiniTensorModel):
     def __init__(self, hidden_size, eps: float = 1e-6, dtype=DTYPE.F32, **kwargs):
         super().__init__(**kwargs)
-        self.eps = np.array(eps, dtype=dtype.np_type())
+        self.eps = eps
         self.hidden_size = hidden_size
         self.dtype = dtype
         self.weight = self.parameter(
@@ -14,15 +14,6 @@ def __init__(self, hidden_size, eps: float = 1e-6, dtype=DTYPE.F32, **kwargs):

     def __call__(self, hidden_states):
         super().__call__([hidden_states])
-        variance = self.reduce_mean(
-            self.pow(hidden_states, np.array(2, dtype=self.dtype.np_type())), -1
-        )
-        hidden_states = self.mul(
-            hidden_states,
-            self.div(
-                np.array(1, dtype=self.dtype.np_type()), self.sqrt(self.add(variance, self.eps))
-            ),
-        )
-        hidden_states = self.mul(hidden_states, self.weight)
+        hidden_states = self.rms_norm(hidden_states, self.weight, self.eps)
         self.outputs = [hidden_states]
         return hidden_states
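
After this change a RMSNorm call emits a single llm::RmsNormalization node instead of the seven primitive nodes removed above. A hypothetical usage sketch (the edge name and hidden size are invented for illustration):

    norm = RMSNorm(hidden_size=4096, eps=1e-6, dtype=DTYPE.F32)
    output = norm("hidden_states")  # one llm::RmsNormalization node: (input, weight) -> output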
