Skip to content

Commit

Permalink
Merge pull request #199 from Oneflow-Inc/wdl_multi_dataloader_process
Browse files Browse the repository at this point in the history
Wdl multi dataloader thread
  • Loading branch information
ShawnXuan authored Jun 9, 2021
2 parents c9a9342 + 30f9480 commit 707e226
Showing 1 changed file with 38 additions and 24 deletions.
62 changes: 38 additions & 24 deletions ClickThroughRate/WideDeepLearning/wdl_train_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ def str_list(x):
return x.split(',')
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_format', type=str, default='ofrecord', help='ofrecord or onerec')
parser.add_argument(
"--use_single_dataloader_thread",
action="store_true",
help="use single dataloader threads per node or not."
)
parser.add_argument('--num_dataloader_thread_per_gpu', type=int, default=2)
parser.add_argument('--train_data_dir', type=str, default='')
parser.add_argument('--train_data_part_num', type=int, default=1)
parser.add_argument('--train_part_name_suffix_length', type=int, default=-1)
Expand Down Expand Up @@ -60,16 +66,23 @@ def str_list(x):
DEEP_HIDDEN_UNITS = [FLAGS.hidden_size for i in range(FLAGS.hidden_units_num)]

def _data_loader(data_dir, data_part_num, batch_size, part_name_suffix_length=-1, shuffle=True):
if FLAGS.dataset_format == 'ofrecord':
return _data_loader_ofrecord(data_dir, data_part_num, batch_size, part_name_suffix_length,
shuffle)
elif FLAGS.dataset_format == 'onerec':
return _data_loader_onerec(data_dir, batch_size, shuffle)
elif FLAGS.dataset_format == 'synthetic':
return _data_loader_synthetic(batch_size)
assert FLAGS.num_dataloader_thread_per_gpu >= 1
if FLAGS.use_single_dataloader_thread:
devices = ['{}:0'.format(i) for i in range(FLAGS.num_nodes)]
else:
assert 0, "Please specify dataset_type as `ofrecord`, `onerec` or `synthetic`."

num_dataloader_thread = FLAGS.num_dataloader_thread_per_gpu * FLAGS.gpu_num_per_node
devices = ['{}:0-{}'.format(i, num_dataloader_thread - 1) for i in range(FLAGS.num_nodes)]
with flow.scope.placement("cpu", devices):
if FLAGS.dataset_format == 'ofrecord':
data = _data_loader_ofrecord(data_dir, data_part_num, batch_size,
part_name_suffix_length, shuffle)
elif FLAGS.dataset_format == 'onerec':
data = _data_loader_onerec(data_dir, batch_size, shuffle)
elif FLAGS.dataset_format == 'synthetic':
data = _data_loader_synthetic(batch_size)
else:
assert 0, "Please specify dataset_type as `ofrecord`, `onerec` or `synthetic`."
return flow.identity_n(data)


def _data_loader_ofrecord(data_dir, data_part_num, batch_size, part_name_suffix_length=-1,
Expand All @@ -88,22 +101,20 @@ def _blob_decoder(bn, shape, dtype=flow.int32):
dense_fields = _blob_decoder("dense_fields", (FLAGS.num_dense_fields,), flow.float)
wide_sparse_fields = _blob_decoder("wide_sparse_fields", (FLAGS.num_wide_sparse_fields,))
deep_sparse_fields = _blob_decoder("deep_sparse_fields", (FLAGS.num_deep_sparse_fields,))
return flow.identity_n([labels, dense_fields, wide_sparse_fields, deep_sparse_fields])
return [labels, dense_fields, wide_sparse_fields, deep_sparse_fields]


def _data_loader_synthetic(batch_size):
devices = ['{}:0-{}'.format(i, FLAGS.gpu_num_per_node - 1) for i in range(FLAGS.num_nodes)]
with flow.scope.placement("cpu", devices):
def _blob_random(shape, dtype=flow.int32, initializer=flow.zeros_initializer(flow.int32)):
return flow.data.decode_random(shape=shape, dtype=dtype, batch_size=batch_size,
initializer=initializer)
labels = _blob_random((1,), initializer=flow.random_uniform_initializer(dtype=flow.int32))
dense_fields = _blob_random((FLAGS.num_dense_fields,), dtype=flow.float,
initializer=flow.random_uniform_initializer())
wide_sparse_fields = _blob_random((FLAGS.num_wide_sparse_fields,))
deep_sparse_fields = _blob_random((FLAGS.num_deep_sparse_fields,))
print('use synthetic data')
return flow.identity_n([labels, dense_fields, wide_sparse_fields, deep_sparse_fields])
def _blob_random(shape, dtype=flow.int32, initializer=flow.zeros_initializer(flow.int32)):
return flow.data.decode_random(shape=shape, dtype=dtype, batch_size=batch_size,
initializer=initializer)
labels = _blob_random((1,), initializer=flow.random_uniform_initializer(dtype=flow.int32))
dense_fields = _blob_random((FLAGS.num_dense_fields,), dtype=flow.float,
initializer=flow.random_uniform_initializer())
wide_sparse_fields = _blob_random((FLAGS.num_wide_sparse_fields,))
deep_sparse_fields = _blob_random((FLAGS.num_deep_sparse_fields,))
print('use synthetic data')
return [labels, dense_fields, wide_sparse_fields, deep_sparse_fields]


def _data_loader_onerec(data_dir, batch_size, shuffle):
Expand All @@ -112,6 +123,7 @@ def _data_loader_onerec(data_dir, batch_size, shuffle):
files = glob.glob(os.path.join(data_dir, '*.onerec'))
readdata = flow.data.onerec_reader(files=files, batch_size=batch_size, random_shuffle=shuffle,
verify_example=False,
shuffle_mode="batch",
shuffle_buffer_size=64,
shuffle_after_epoch=shuffle)

Expand All @@ -122,7 +134,7 @@ def _blob_decoder(bn, shape, dtype=flow.int32):
dense_fields = _blob_decoder("dense_fields", (FLAGS.num_dense_fields,), flow.float)
wide_sparse_fields = _blob_decoder("wide_sparse_fields", (FLAGS.num_wide_sparse_fields,))
deep_sparse_fields = _blob_decoder("deep_sparse_fields", (FLAGS.num_deep_sparse_fields,))
return flow.identity_n([labels, dense_fields, wide_sparse_fields, deep_sparse_fields])
return [labels, dense_fields, wide_sparse_fields, deep_sparse_fields]


def _model(dense_fields, wide_sparse_fields, deep_sparse_fields):
Expand Down Expand Up @@ -260,7 +272,9 @@ def main():
flow.config.gpu_device_num(FLAGS.gpu_num_per_node)
flow.config.enable_model_io_v2(True)
flow.config.enable_debug_mode(True)
flow.config.collective_boxing.nccl_enable_all_to_all(True)
flow.config.enable_legacy_model_io(True)
flow.config.nccl_use_compute_stream(True)
# flow.config.collective_boxing.nccl_enable_all_to_all(True)
#flow.config.enable_numa_aware_cuda_malloc_host(True)
#flow.config.collective_boxing.enable_fusion(False)
check_point = flow.train.CheckPoint()
Expand Down

0 comments on commit 707e226

Please sign in to comment.