Skip to content

Commit

Permalink
process -> thread
Browse files Browse the repository at this point in the history
  • Loading branch information
ShawnXuan committed Jun 4, 2021
1 parent e659921 commit 30f9480
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions ClickThroughRate/WideDeepLearning/wdl_train_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ def str_list(x):
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_format', type=str, default='ofrecord', help='ofrecord or onerec')
parser.add_argument(
"--use_single_dataloader_process",
"--use_single_dataloader_thread",
action="store_true",
help="use single dataloader processes per node or not."
help="use single dataloader threads per node or not."
)
parser.add_argument('--num_dataloader_process_per_gpu', type=int, default=1)
parser.add_argument('--num_dataloader_thread_per_gpu', type=int, default=2)
parser.add_argument('--train_data_dir', type=str, default='')
parser.add_argument('--train_data_part_num', type=int, default=1)
parser.add_argument('--train_part_name_suffix_length', type=int, default=-1)
Expand Down Expand Up @@ -66,12 +66,12 @@ def str_list(x):
DEEP_HIDDEN_UNITS = [FLAGS.hidden_size for i in range(FLAGS.hidden_units_num)]

def _data_loader(data_dir, data_part_num, batch_size, part_name_suffix_length=-1, shuffle=True):
assert FLAGS.num_dataloader_process_per_gpu >= 1
if FLAGS.use_single_dataloader_process:
assert FLAGS.num_dataloader_thread_per_gpu >= 1
if FLAGS.use_single_dataloader_thread:
devices = ['{}:0'.format(i) for i in range(FLAGS.num_nodes)]
else:
num_dataloader_process = FLAGS.num_dataloader_process_per_gpu * FLAGS.gpu_num_per_node
devices = ['{}:0-{}'.format(i, num_dataloader_process - 1) for i in range(FLAGS.num_nodes)]
num_dataloader_thread = FLAGS.num_dataloader_thread_per_gpu * FLAGS.gpu_num_per_node
devices = ['{}:0-{}'.format(i, num_dataloader_thread - 1) for i in range(FLAGS.num_nodes)]
with flow.scope.placement("cpu", devices):
if FLAGS.dataset_format == 'ofrecord':
data = _data_loader_ofrecord(data_dir, data_part_num, batch_size,
Expand Down

0 comments on commit 30f9480

Please sign in to comment.