@@ -155,19 +155,17 @@ def _hybrid_embedding(name, ids, embedding_size, vocab_size, hf_vocab_size):
155
155
dtype = flow .float ,
156
156
initializer = flow .random_uniform_initializer (minval = - 0.05 , maxval = 0.05 ),
157
157
)
158
- hf_embedding = flow .gather (params = hf_embedding_table , indices = hf_ids )#, no_duplicates_in_indices=True)
158
+ hf_embedding = flow .gather (params = hf_embedding_table , indices = hf_ids )
159
159
lf_ids = lf_ids - hf_vocab_size_constant
160
160
with flow .scope .placement ('cpu' , '0:0' ):
161
161
lf_embedding_table = flow .get_variable (
162
162
name = f'lf_{ name } ' ,
163
163
shape = (vocab_size - hf_vocab_size , embedding_size ),
164
- #shape=(vocab_size, embedding_size),
165
164
dtype = flow .float ,
166
165
initializer = flow .random_uniform_initializer (minval = - 0.05 , maxval = 0.05 ),
167
166
)
168
- lf_embedding = flow .gather (params = lf_embedding_table , indices = lf_ids )#, no_duplicates_in_indices=True)
167
+ lf_embedding = flow .gather (params = lf_embedding_table , indices = lf_ids )
169
168
unique_embedding = flow .reshape (flow .zeros_like (unique_ids , dtype = flow .float ), (- 1 , 1 )) * flow .constant (0.0 , dtype = flow .float , shape = (1 ,embedding_size ))
170
- # unique_embedding = flow.constant(0.0, dtype=flow.float, shape=(b*s, embedding_size))
171
169
unique_embedding = flow .tensor_scatter_nd_update (params = unique_embedding , updates = hf_embedding , indices = hf_indices )
172
170
unique_embedding = flow .tensor_scatter_nd_update (params = unique_embedding , updates = lf_embedding , indices = lf_indices )
173
171
unique_embedding = flow .gather (params = unique_embedding , indices = unique_ids_idx )
@@ -309,8 +307,6 @@ def print_args(args):
309
307
for arg in vars (args ):
310
308
print ("{} = {}" .format (arg , getattr (args , arg )))
311
309
print ("-" .ljust (66 , "-" ))
312
- #print("Time stamp: {}".format(
313
- # str(datetime.now().strftime("%Y-%m-%d-%H:%M:%S"))))
314
310
315
311
316
312
def main ():
@@ -320,8 +316,6 @@ def main():
320
316
flow .config .enable_model_io_v2 (True )
321
317
flow .config .enable_debug_mode (True )
322
318
flow .config .collective_boxing .nccl_enable_all_to_all (True )
323
- #flow.config.enable_numa_aware_cuda_malloc_host(True)
324
- #flow.config.collective_boxing.enable_fusion(False)
325
319
check_point = flow .train .CheckPoint ()
326
320
check_point .init ()
327
321
for i in range (FLAGS .max_iter ):
0 commit comments