From 75a0f87f087220441b02a869b874a025e84667f5 Mon Sep 17 00:00:00 2001 From: panzezhong Date: Tue, 30 Jan 2024 12:54:17 +0800 Subject: [PATCH] =?UTF-8?q?fix=20(api):=20=E4=B8=BArms=20norm=E5=8A=A0?= =?UTF-8?q?=E4=B8=8Afp32=E5=BC=BA=E8=BD=AC=EF=BC=8C=E4=B8=BAload=20param?= =?UTF-8?q?=E5=8A=A0=E7=B1=BB=E5=9E=8B=E8=BD=AC=E6=8D=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/09python_ffi/src/refactor_graph/frontend/modeling.py | 9 +++++---- .../src/refactor_graph/frontend/transformer/rms_norm.py | 4 +++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/09python_ffi/src/refactor_graph/frontend/modeling.py b/src/09python_ffi/src/refactor_graph/frontend/modeling.py index 785a20ec..3031d5db 100644 --- a/src/09python_ffi/src/refactor_graph/frontend/modeling.py +++ b/src/09python_ffi/src/refactor_graph/frontend/modeling.py @@ -489,19 +489,20 @@ def make_onnx( return onnx.helper.make_model(graph) def load_params(self, data: Dict[str, np.ndarray]): - if len(self._parameters) != len(data): - print(f"Warning: the number of loaded params does not match current model.") + if len(self._parameters) != len(data) or set(data.keys()) != set(self._parameters.keys()): + print(f"Warning: the number or name of loaded params does not match current model.") for name in self._parameters: new_data = data.get(name) if new_data is not None: if self._parameters[name].shape != new_data.shape: print( - f"Warning: Shape mismatch for {name}, expecting {self._parameters[name].shape} but get {new_data.shape}" + f"Warning: Shape mismatch for {name}, expecting {self._parameters[name].shape} but get {new_data.shape}." ) if self._parameters[name].dtype != new_data.dtype: print( - f"Warning: Type mismatch for {name}, expecting {self._parameters[name].dtype} but get {new_data.dtype}" + f"Warning: Type mismatch for {name}. Casting to {self._parameters[name].dtype} from {new_data.dtype}." ) + new_data = new_data.astype(self._parameters[name].dtype) self._parameters[name] = new_data else: print(f"Warning: Value for {name} is not provided for loading.") diff --git a/src/09python_ffi/src/refactor_graph/frontend/transformer/rms_norm.py b/src/09python_ffi/src/refactor_graph/frontend/transformer/rms_norm.py index 48c52a27..2bfe55e8 100644 --- a/src/09python_ffi/src/refactor_graph/frontend/transformer/rms_norm.py +++ b/src/09python_ffi/src/refactor_graph/frontend/transformer/rms_norm.py @@ -9,11 +9,13 @@ def __init__(self, hidden_size, eps: float = 1e-6, dtype=DTYPE.F32, **kwargs): self.hidden_size = hidden_size self.dtype = dtype self.weight = self.parameter( - np.ones(self.hidden_size, dtype=self.dtype.np_type()), "weight" + np.ones(self.hidden_size, dtype=DTYPE.F32.np_type()), "weight" ) def __call__(self, hidden_states): super().__call__([hidden_states]) + hidden_states = self.cast(hidden_states, DTYPE.F32) hidden_states = self.rms_norm(hidden_states, self.weight, self.eps) + hidden_states = self.cast(hidden_states, self.dtype) self.outputs = [hidden_states] return hidden_states