From 1f0fd6e372ecf3022a57dd6d2ef164e1e3d82f9a Mon Sep 17 00:00:00 2001
From: panzezhong
Date: Mon, 29 Jan 2024 09:54:41 +0800
Subject: [PATCH] fix (api): adapt to the new rms norm and matmul operators to
 avoid transpose at load time
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../src/refactor_graph/frontend/modeling.py   | 21 ++++++++++++++-----
 .../frontend/models/llama/llama.py            |  3 ---
 .../src/refactor_graph/frontend/nn/nn.py      |  4 ++--
 .../frontend/transformer/rms_norm.py          | 13 ++----------
 4 files changed, 20 insertions(+), 21 deletions(-)

diff --git a/src/09python_ffi/src/refactor_graph/frontend/modeling.py b/src/09python_ffi/src/refactor_graph/frontend/modeling.py
index a9f60ba1..785a20ec 100644
--- a/src/09python_ffi/src/refactor_graph/frontend/modeling.py
+++ b/src/09python_ffi/src/refactor_graph/frontend/modeling.py
@@ -186,7 +186,7 @@ def make_op(
         inputs: Tuple[Union[str, np.ndarray], ...],
         outputs: Tuple[str, ...] | int = 1,
         name: str | None = None,
-        use_onnx_standard: bool = True
+        use_onnx_standard: bool = True,
     ):
         if use_onnx_standard and not op_type.startswith("onnx::"):
             op_type = "onnx::" + op_type
@@ -357,7 +357,7 @@ def run(
                 self._nodes, operators, edges, input_names, output_names
             )
             self._executor = self._compiler.compile_on(
-                find_device(self._device, self._device_id), "default", []
+                find_device(self._device, self._device_id), "default", ["ce"]
             )
         elif recompile or self._executor is None:
             # Set input info
@@ -382,7 +382,7 @@ def run(
                 self._compiler.set_input(input_names.index(name), dynamic_tensor)
 
             self._executor = self._compiler.compile_on(
-                find_device(self._device, self._device_id), "default", []
+                find_device(self._device, self._device_id), "default", ["ce"]
             )
 
         ## Executor should have been created at this point
@@ -489,6 +489,8 @@ def make_onnx(
         return onnx.helper.make_model(graph)
 
     def load_params(self, data: Dict[str, np.ndarray]):
+        if len(self._parameters) != len(data):
+            print(f"Warning: number of loaded params ({len(data)}) does not match the current model ({len(self._parameters)}).")
         for name in self._parameters:
             new_data = data.get(name)
             if new_data is not None:
@@ -534,8 +536,8 @@ def div(self, A, B, C="") -> str:
     def pow(self, A, B, C="") -> str:
         return self.make_op("Pow", {}, (A, B), (C,))[0]
 
-    def matmul(self, A, B, Y="") -> str:
-        return self.make_op("MatMul", {}, (A, B), (Y,))[0]
+    def matmul(self, A, B, Y="", transA=0, transB=0) -> str:
+        return self.make_op("llm::MatMul", {"transA": transA, "transB": transB}, (A, B), (Y,), use_onnx_standard=False)[0]
 
     def gemm(self, A, B, C=None, Y="", alpha=1.0, beta=1.0, transA=0, transB=0) -> str:
         inputs = (A, B, C) if C is not None else (A, B)
@@ -615,3 +617,12 @@ def cast(self, input, to: DTYPE, output="") -> str:
 
     def softmax(self, input, axis=-1, output="") -> str:
         return self.make_op("Softmax", {"axis": axis}, (input,), (output,))[0]
+
+    def rms_norm(self, input, weight, eps=1e-5, output="") -> str:
+        return self.make_op(
+            "llm::RmsNormalization",
+            {"epsilon": eps},
+            (input, weight),
+            (output,),
+            use_onnx_standard=False,
+        )[0]
diff --git a/src/09python_ffi/src/refactor_graph/frontend/models/llama/llama.py b/src/09python_ffi/src/refactor_graph/frontend/models/llama/llama.py
index c88b8de1..e6ec9587 100644
--- a/src/09python_ffi/src/refactor_graph/frontend/models/llama/llama.py
+++ b/src/09python_ffi/src/refactor_graph/frontend/models/llama/llama.py
@@ -172,7 +172,6 @@ def load_param_safetensors(self, model_path):
tensor_name = "" if naming[0] == "lm_head": tensor_name = "LlamaModel/lm_head/weight" - data = np.transpose(data, (1, 0)).copy() elif naming[0] == "model" and naming[1] == "embed_tokens": tensor_name = "LlamaModel/embed_tokens" elif naming[0] == "model" and naming[1] == "norm": @@ -181,10 +180,8 @@ def load_param_safetensors(self, model_path): tensor_name = f"LlamaModel/Decoder_{naming[2]}/" if naming[3] == "self_attn": tensor_name += f"Attention/{naming[4]}/weight" - data = np.transpose(data, (1, 0)).copy() elif naming[3] == "mlp": tensor_name += f"FeedForward/{naming[4]}/weight" - data = np.transpose(data, (1, 0)).copy() else: tensor_name += f"{naming[3]}/weight" else: diff --git a/src/09python_ffi/src/refactor_graph/frontend/nn/nn.py b/src/09python_ffi/src/refactor_graph/frontend/nn/nn.py index 47285d08..4c73a0df 100644 --- a/src/09python_ffi/src/refactor_graph/frontend/nn/nn.py +++ b/src/09python_ffi/src/refactor_graph/frontend/nn/nn.py @@ -17,7 +17,7 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - shape = (in_features, out_features) + shape = (out_features, in_features) self.weight = self.parameter( (np.random.random(shape)).astype(dtype.np_type()), "weight" ) @@ -29,7 +29,7 @@ def __init__( def __call__(self, input): super().__call__([input]) - output = self.matmul(input, self.weight) + output = self.matmul(input, self.weight, transB=1) if self.use_bias: output = self.add(output, self.bias) self.outputs.append(output) diff --git a/src/09python_ffi/src/refactor_graph/frontend/transformer/rms_norm.py b/src/09python_ffi/src/refactor_graph/frontend/transformer/rms_norm.py index b14a9fff..48c52a27 100644 --- a/src/09python_ffi/src/refactor_graph/frontend/transformer/rms_norm.py +++ b/src/09python_ffi/src/refactor_graph/frontend/transformer/rms_norm.py @@ -5,7 +5,7 @@ class RMSNorm(InfiniTensorModel): def __init__(self, hidden_size, eps: float = 1e-6, dtype=DTYPE.F32, **kwargs): super().__init__(**kwargs) - self.eps = np.array(eps, dtype=dtype.np_type()) + self.eps = eps self.hidden_size = hidden_size self.dtype = dtype self.weight = self.parameter( @@ -14,15 +14,6 @@ def __init__(self, hidden_size, eps: float = 1e-6, dtype=DTYPE.F32, **kwargs): def __call__(self, hidden_states): super().__call__([hidden_states]) - variance = self.reduce_mean( - self.pow(hidden_states, np.array(2, dtype=self.dtype.np_type())), -1 - ) - hidden_states = self.mul( - hidden_states, - self.div( - np.array(1, dtype=self.dtype.np_type()), self.sqrt(self.add(variance, self.eps)) - ), - ) - hidden_states = self.mul(hidden_states, self.weight) + hidden_states = self.rms_norm(hidden_states, self.weight, self.eps) self.outputs = [hidden_states] return hidden_states