@@ -100,6 +100,8 @@ def qkv_weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str]
100100
101101 def out_proj_weight_loader (self , param , loaded_weight , loaded_shard_id : Optional [str ] = None ):
102102 loaded_weight = loaded_weight .transpose ([1 , 0 ])
103+ if not param ._is_initialized ():
104+ param .initialize ()
103105 assert param .shape == loaded_weight .shape , (
104106 f" Attempted to load weight ({ loaded_weight .shape } ) " f"into parameter ({ param .shape } )"
105107 )
@@ -109,7 +111,7 @@ def out_proj_weight_loader(self, param, loaded_weight, loaded_shard_id: Optional
109111 loaded_weight = loaded_weight .view (param .dtype )
110112 else :
111113 loaded_weight = loaded_weight .cast (param .dtype )
112- param . copy_ ( loaded_weight , False )
114+ h2d_copy ( param , loaded_weight )
113115
114116 def forward (
115117 self ,
@@ -287,6 +289,8 @@ def __init__(self, config):
287289
288290 def weight_loader (self , param , loaded_weight , loaded_shard_id : Optional [str ] = None ):
289291 loaded_weight = loaded_weight .transpose ([1 , 0 ])
292+ if not param ._is_initialized ():
293+ param .initialize ()
290294 assert param .shape == loaded_weight .shape , (
291295 f" Attempted to load weight ({ loaded_weight .shape } ) " f"into parameter ({ param .shape } )"
292296 )
@@ -296,7 +300,7 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
296300 loaded_weight = loaded_weight .view (param .dtype )
297301 else :
298302 loaded_weight = loaded_weight .cast (param .dtype )
299- param . copy_ ( loaded_weight , False )
303+ h2d_copy ( param , loaded_weight )
300304
301305 def forward (self , hidden_states : paddle .Tensor ) -> paddle .Tensor :
302306 hidden_states = self .fc1 (hidden_states )
0 commit comments