Skip to content

Commit d1275e6

Browse files
committed
rm usless comments
1 parent 3aec0ec commit d1275e6

File tree

1 file changed

+2
-8
lines changed

1 file changed

+2
-8
lines changed

ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -155,19 +155,17 @@ def _hybrid_embedding(name, ids, embedding_size, vocab_size, hf_vocab_size):
155155
dtype=flow.float,
156156
initializer=flow.random_uniform_initializer(minval=-0.05, maxval=0.05),
157157
)
158-
hf_embedding = flow.gather(params=hf_embedding_table, indices=hf_ids)#, no_duplicates_in_indices=True)
158+
hf_embedding = flow.gather(params=hf_embedding_table, indices=hf_ids)
159159
lf_ids = lf_ids - hf_vocab_size_constant
160160
with flow.scope.placement('cpu', '0:0'):
161161
lf_embedding_table = flow.get_variable(
162162
name=f'lf_{name}',
163163
shape=(vocab_size - hf_vocab_size, embedding_size),
164-
#shape=(vocab_size, embedding_size),
165164
dtype=flow.float,
166165
initializer=flow.random_uniform_initializer(minval=-0.05, maxval=0.05),
167166
)
168-
lf_embedding = flow.gather(params=lf_embedding_table, indices=lf_ids)#, no_duplicates_in_indices=True)
167+
lf_embedding = flow.gather(params=lf_embedding_table, indices=lf_ids)
169168
unique_embedding = flow.reshape(flow.zeros_like(unique_ids, dtype=flow.float), (-1, 1)) * flow.constant(0.0, dtype=flow.float, shape=(1,embedding_size))
170-
# unique_embedding = flow.constant(0.0, dtype=flow.float, shape=(b*s, embedding_size))
171169
unique_embedding = flow.tensor_scatter_nd_update(params=unique_embedding, updates=hf_embedding, indices=hf_indices)
172170
unique_embedding = flow.tensor_scatter_nd_update(params=unique_embedding, updates=lf_embedding, indices=lf_indices)
173171
unique_embedding = flow.gather(params=unique_embedding, indices=unique_ids_idx)
@@ -309,8 +307,6 @@ def print_args(args):
309307
for arg in vars(args):
310308
print("{} = {}".format(arg, getattr(args, arg)))
311309
print("-".ljust(66, "-"))
312-
#print("Time stamp: {}".format(
313-
# str(datetime.now().strftime("%Y-%m-%d-%H:%M:%S"))))
314310

315311

316312
def main():
@@ -320,8 +316,6 @@ def main():
320316
flow.config.enable_model_io_v2(True)
321317
flow.config.enable_debug_mode(True)
322318
flow.config.collective_boxing.nccl_enable_all_to_all(True)
323-
#flow.config.enable_numa_aware_cuda_malloc_host(True)
324-
#flow.config.collective_boxing.enable_fusion(False)
325319
check_point = flow.train.CheckPoint()
326320
check_point.init()
327321
for i in range(FLAGS.max_iter):

0 commit comments

Comments
 (0)