diff --git a/ClickThroughRate/WideDeepLearning/wdl_train_eval.py b/ClickThroughRate/WideDeepLearning/wdl_train_eval.py index ea3c87a..726fb69 100644 --- a/ClickThroughRate/WideDeepLearning/wdl_train_eval.py +++ b/ClickThroughRate/WideDeepLearning/wdl_train_eval.py @@ -27,11 +27,11 @@ def str_list(x): parser = argparse.ArgumentParser() parser.add_argument('--dataset_format', type=str, default='ofrecord', help='ofrecord or onerec') parser.add_argument( - "--use_single_dataloader_process", + "--use_single_dataloader_thread", action="store_true", - help="use single dataloader processes per node or not." + help="use single dataloader threads per node or not." ) -parser.add_argument('--num_dataloader_process_per_gpu', type=int, default=1) +parser.add_argument('--num_dataloader_thread_per_gpu', type=int, default=2) parser.add_argument('--train_data_dir', type=str, default='') parser.add_argument('--train_data_part_num', type=int, default=1) parser.add_argument('--train_part_name_suffix_length', type=int, default=-1) @@ -66,12 +66,12 @@ def str_list(x): DEEP_HIDDEN_UNITS = [FLAGS.hidden_size for i in range(FLAGS.hidden_units_num)] def _data_loader(data_dir, data_part_num, batch_size, part_name_suffix_length=-1, shuffle=True): - assert FLAGS.num_dataloader_process_per_gpu >= 1 - if FLAGS.use_single_dataloader_process: + assert FLAGS.num_dataloader_thread_per_gpu >= 1 + if FLAGS.use_single_dataloader_thread: devices = ['{}:0'.format(i) for i in range(FLAGS.num_nodes)] else: - num_dataloader_process = FLAGS.num_dataloader_process_per_gpu * FLAGS.gpu_num_per_node - devices = ['{}:0-{}'.format(i, num_dataloader_process - 1) for i in range(FLAGS.num_nodes)] + num_dataloader_thread = FLAGS.num_dataloader_thread_per_gpu * FLAGS.gpu_num_per_node + devices = ['{}:0-{}'.format(i, num_dataloader_thread - 1) for i in range(FLAGS.num_nodes)] with flow.scope.placement("cpu", devices): if FLAGS.dataset_format == 'ofrecord': data = _data_loader_ofrecord(data_dir, data_part_num, batch_size,