Skip to content

Commit

Permalink
Update the code for generating the training dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
fengxinmin committed Feb 12, 2025
1 parent 0c95153 commit 2f5c56e
Show file tree
Hide file tree
Showing 253 changed files with 459 additions and 2 deletions.
26 changes: 24 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,29 @@

</div>

## Training Dataset

## :running_woman: Previous work
The training dataset is available at [Baidu Cloud](https://pan.baidu.com/s/1ZMPZqOcQS_gri_pzSq2vGA?pwd=tmxn). We used 668 4K sequences with 32 frames from the BVI-DVC dataset, Tencent Video Dataset, and UVG dataset. These sequences were cropped or downsampled to create datasets with four different resolutions: 3840x2160, 1920x1080, 960x544, and 480x272. We organized the training dataset using HDF5 format, which includes the following files:

[Partition Map Prediction for Fast Block Partitioning in VVC Intra-frame Coding](https://github.com/AolinFeng/PMP-VVC-TIP2023)
- `train_seqs.h5`: Luma components of the original sequences.
- `train_qp22.h5`: Training dataset label for basic QP22.
- `train_qp27.h5`: Training dataset label for basic QP27.
- `train_qp32.h5`: Training dataset label for basic QP32.
- `train_qp37.h5`: Training dataset label for basic QP37.

To further support subsequent research, we also provide the code for generating the training dataset, which includes:

1. Modified VTM source code `codec/print_encoder` and the executable file `codec/exe/print_encoder.exe` for extracting block partitioning statistics from YUV sequences. Code `dataset_preparation.py` for extracting the statistics into `DepthSaving/` with multiple threads.
2. Code `depth2dataset.py` for converting the statistics into partition maps.





<!-- ## :running_woman: TODO -->

## References

1. [Partition Map Prediction for Fast Block Partitioning in VVC Intra-frame Coding](https://github.com/AolinFeng/PMP-VVC-TIP2023)

2.
File renamed without changes.
118 changes: 118 additions & 0 deletions dataset_preparation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import os
import argparse
import time
from concurrent.futures import ThreadPoolExecutor

def str2bool(v):
    """Parse a command-line flag into a bool (intended as an argparse ``type=``).

    Accepts an actual bool unchanged, or a case-insensitive textual truthy
    ('yes', 'true', 't', 'y', '1') / falsy ('no', 'false', 'f', 'n', '0') token.

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    if isinstance(v, bool):
        return v
    token = v.lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def write_seq_name_from_dir(Train_sequence_dir):
    """Write one entry per file in *Train_sequence_dir* to ./sequences_name.txt,
    with the '.yuv' extension removed.

    BUG FIX: the original used ``seq_name.strip('.yuv')``, which strips any of
    the characters '.', 'y', 'u', 'v' from BOTH ends of the name (e.g.
    'video.yuv' -> 'ideo'), instead of removing the '.yuv' suffix.
    """
    seq_list = os.listdir(Train_sequence_dir)
    with open('sequences_name.txt', 'w') as f:
        for seq_name in seq_list:
            base, ext = os.path.splitext(seq_name)
            # Drop the extension only when it really is '.yuv'; other files
            # are written unchanged (matching the directory listing).
            f.write((base if ext == '.yuv' else seq_name) + '\n')

def load_ifo_from_cfg(cfg_path):
    """Parse a VTM-style per-sequence .cfg file.

    Looks for the ``InputFile``, ``InputBitDepth``, ``SourceWidth``,
    ``SourceHeight`` and ``FramesToBeEncoded`` keys (format ``Key : value``,
    optional ``#`` comments).

    Returns:
        (input_path, bit_depth, width, height, frame_num) on success, or
        ``None`` (after printing a message) if any of the first four keys
        is missing — the original contract, preserved for callers.

    FIX: the original opened the file without ever closing it; a ``with``
    block now guarantees the handle is released.
    """
    input_path = None
    bit_depth = None
    width = None
    height = None
    frame_num = None
    with open(cfg_path) as fp:
        for line in fp:
            if "InputFile" in line:
                # Strip newline/spaces and trailing '#' comment, then take
                # everything after the first ':' as the path.
                line = line.rstrip('\n').replace(' ', '').split('#')[0]
                loc = line.find(':')
                input_path = line[loc+1:]
            elif "InputBitDepth" in line:
                bit_depth = int(line.rstrip('\n').replace(' ', '').split('#')[0].split(':')[-1])
            elif "SourceWidth" in line:
                width = int(line.rstrip('\n').replace(' ', '').split('#')[0].split(':')[-1])
            elif "SourceHeight" in line:
                height = int(line.rstrip('\n').replace(' ', '').split('#')[0].split(':')[-1])
            elif "FramesToBeEncoded" in line:
                frame_num = int(line.rstrip('\n').replace(' ', '').split('#')[0].split(':')[-1])
    if (input_path is None) or (bit_depth is None) or (width is None) or (height is None):
        print("Format of CFG error !!!!!!!!")
        return
    return input_path, bit_depth, width, height, frame_num



def make_dataset(qp, seq_name, seq_dir, exe_name, bit_depth, enc_mode):
    """Encode one sequence with the modified VTM encoder to extract block
    partitioning depth statistics (only for depth labels).

    Args:
        qp: basic quantization parameter (22 / 27 / 32 / 37).
        seq_name: sequence name WITHOUT the '.yuv' extension.
        seq_dir: directory containing the raw YUV sequences.
        enc_mode: coding configuration, e.g. 'LDP', 'LDB' or 'RA'.
        exe_name: encoder executable name inside ./codec.
        bit_depth: input bit depth (8 or 10); currently informational only —
            the original computed an unused ``is10bit`` flag from it.

    NOTE: reads the module-level ``args`` namespace (``args.test_frm_num``,
    ``args.cfg_dir``), so it must run after argument parsing in ``__main__``.
    """
    cur_path = os.getcwd()
    exe_dir = os.path.join(cur_path, "codec")
    log_path = os.path.join(cur_path, "log", "QP" + str(qp))
    out_path = os.path.join(cur_path, 'output', enc_mode, 'QP' + str(qp), "train_dataset")
    depth_path = os.path.join(cur_path, "DepthSaving")
    seq_path = os.path.join(seq_dir, seq_name + '.yuv')
    frame_num = args.test_frm_num
    per_cfg_path = os.path.join(args.cfg_dir, seq_name + '.cfg')
    enc_cfg_name = 'encoder_randomaccess_vtm.cfg'
    enc_log_name = 'enc_' + seq_name + '.log'
    # BUG FIX: create every directory we write into — the original only
    # created out_path, so the ">> log" shell redirection failed whenever
    # log/QPxx did not already exist.
    for needed_dir in (out_path, log_path, depth_path):
        if not os.path.exists(needed_dir):
            os.makedirs(needed_dir)
    # Start from a clean slate: remove stale depth/log files of a previous run.
    if os.path.exists(os.path.join(depth_path, seq_name + '_Depth.txt')):
        os.remove(os.path.join(depth_path, seq_name + '_Depth.txt'))
    if os.path.exists(os.path.join(log_path, enc_log_name)):
        os.remove(os.path.join(log_path, enc_log_name))
    # BUG FIX: the original branched on ``args.train`` — a flag that was never
    # registered with argparse (AttributeError) — and both branches built the
    # exact same command; collapsed into a single command string.
    encoder_order = os.path.join(exe_dir, exe_name) + ' -c ' + os.path.join(args.cfg_dir, enc_cfg_name)\
        + ' -c ' + per_cfg_path + " -f " + str(frame_num) + ' -i ' + seq_path + " -q " + str(qp) \
        + ' -b ' + os.path.join(out_path, seq_name + '.bin') \
        + " >> " + os.path.join(log_path, enc_log_name)
    print(encoder_order)
    # NOTE(review): os.system with a string-built command is shell-injection
    # prone; acceptable here only because all inputs are local trusted paths.
    os.system(encoder_order)


if __name__ == '__main__':
    # Encode each sequence and generate partition depth files, e.g.,
    # DepthSaving/BasketballDrillText_832x480_50_Depth.txt
    parser = argparse.ArgumentParser()
    parser.add_argument('--qp', type=int, default=None, help='option: 22 27 32 37')
    parser.add_argument('--sequences_dir', type=str, help='Path to the training sequences')
    parser.add_argument('--cfg_dir', type=str, help='Path to the configuration file of the training sequences')
    parser.add_argument('--test_frm_num', type=int, default=32)
    parser.add_argument('--bit_depth', type=int, default=8)
    # BUG FIX: ``args.enc_mode`` was used below but never registered with
    # argparse in the original script (AttributeError at submit time).
    parser.add_argument('--enc_mode', type=str, default='RA', help='option: LDP LDB RA')

    args = parser.parse_args()
    qp_list = [32, 27, 22, 37] if args.qp is None else [args.qp]
    exe_name = 'print_encoder.exe'

    write_seq_name_from_dir(args.sequences_dir)
    seq_list = []
    with open("sequences_name.txt") as f:
        for line in f:
            seq_list.append(line.strip('\n'))

    # Fan the (qp, sequence) encodings out over a thread pool; the workers
    # mostly block on the external encoder process, so threads are fine.
    future_list = []
    with ThreadPoolExecutor(max_workers=35) as po:
        for qp in qp_list:
            for seq_name in seq_list:
                # BUG FIX: the original submit() dropped the leading ``qp``
                # argument, shifting every parameter of make_dataset by one.
                future_list.append(po.submit(make_dataset, qp, seq_name,
                                             args.sequences_dir, exe_name,
                                             args.bit_depth, args.enc_mode))
        # Replaces the original done()-polling loop, which both busy-waited
        # and silently swallowed worker exceptions; result() re-raises them.
        for future in future_list:
            future.result()
Loading

0 comments on commit 2f5c56e

Please sign in to comment.