From bcd41e5965bac7980924d9a11f37e1911a3ede8d Mon Sep 17 00:00:00 2001 From: panxuchen Date: Wed, 4 Sep 2024 20:46:31 +0800 Subject: [PATCH 1/5] auto split files for ray mode --- data_juicer/core/ray_data.py | 215 +++++++++++++++++++++++++++---- data_juicer/core/ray_executor.py | 6 +- data_juicer/utils/constant.py | 1 + 3 files changed, 193 insertions(+), 29 deletions(-) diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/ray_data.py index 17235a2b8..f45c33b2d 100644 --- a/data_juicer/core/ray_data.py +++ b/data_juicer/core/ray_data.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import os +from typing import Any, Generator, List, Union import pyarrow as pa from loguru import logger @@ -7,10 +10,12 @@ from data_juicer.core.data import DJDataset from data_juicer.ops import Filter, Mapper from data_juicer.utils.availability_utils import AvailabilityChecking -from data_juicer.utils.constant import Fields +from data_juicer.utils.constant import DEFAULT_MAX_FILE_SIZE, Fields from data_juicer.utils.process_utils import calculate_np with AvailabilityChecking(['ray'], requires_type='dist'): + import ray + import ray.data as rd from ray.data import Dataset @@ -78,54 +83,216 @@ def get_num_gpus(op, op_proc): return 1.0 / proc_per_gpu +def split_jsonl(file_path: str, max_size: int, output_dir: str) -> List[str]: + """Split a jsonl file into multiple sub files + + Args: + file_path (`str`): path of the original jsonl file + max_size (`int`): max size of each sub file (in MB) + output_dir (`str`): directory to save the sub files + + Returns: + List[str]: list of sub file paths + """ + os.makedirs(output_dir, exist_ok=True) + sub_file_paths = [] + file_index = 0 + max_byte_size = max_size * 1024**2 + base_file_name = os.path.basename(file_path) + file_name = os.path.splitext(base_file_name)[0] + current_size = 0 + with open(file_path, 'r', encoding='utf-8') as infile: + while True: + # Create a new file if we're starting or have reached the max size + if current_size >= max_byte_size: + file_index += 1 + current_size = 0 + # Determine the output file name + output_file_name = f'{file_name}_{file_index}.jsonl' + output_file_path = os.path.join(output_dir, output_file_name) + # Open the output file in append mode + with open(output_file_path, 'a', encoding='utf-8') as outfile: + line = infile.readline() + if not line: + break + # Write the line to the current output file + outfile.write(line) + # Update the current file size + current_size += len(line.encode('utf-8')) + # Add the new file path to the list if it's a new file + if output_file_path not in sub_file_paths: + sub_file_paths.append(output_file_path) + + return sub_file_paths + + +def split_jsonl_dataset( + dataset_paths: Union[str, List[str]], + max_size: int, + output_dir: str, +) -> List[str]: + """Re-split the jsonl dataset files. + + Args: + file_path (`str`): path of the original jsonl file + max_size (`int`): max size of each sub file (in MB) + output_dir (`str`): directory to save the sub files + + Returns: + List[str]: list of sub file paths + """ + if isinstance(dataset_paths, str): + dataset_paths = [dataset_paths] + + all_sub_file_paths = [] + for path in dataset_paths: + sub_file_paths = split_jsonl(path, max_size, output_dir) + all_sub_file_paths.extend(sub_file_paths) + + return all_sub_file_paths + + +def get_jsonl_file_names(dataset_dir_path: str) -> List[str]: + """Load all jsonl files in a directory. 
+ + Args: + dataset_dir_path (`str`): path of the directory containing jsonl files + or the path of a single jsonl file + + Returns: + List[str]: list of jsonl file paths + """ + if os.path.isdir(dataset_dir_path): + jsonl_files = [ + os.path.join(dataset_dir_path, f) + for f in os.listdir(dataset_dir_path) + ] + elif os.path.isfile(dataset_dir_path) and dataset_dir_path.endswith( + '.jsonl') or dataset_dir_path.endswith('.json'): + jsonl_files = [dataset_dir_path] + else: + raise ValueError( + 'Invalid path: it should be a directory containing jsonl files' + ' or a single jsonl file.') + return jsonl_files + + +def best_file_num(cpu: int, memory: int, file_size: int) -> int: + """Calculate the best number of files in a single batch. + Each cpu should process the same number of files (at least one), + while the total memory should be at least 2 times larger than the + total file size. + + Args: + cpu (`int`): number of CPUs available + memory (`int`): memory available in MB + file_size (`int`): size of a single file in MB + + Returns: + int: best number of files in a single batch + """ + max_files_by_memory = memory // (2 * file_size) + + best_num_files = max(1, (max_files_by_memory // cpu) * cpu) + + return best_num_files + + +def load_splited_json_dataset( + file_paths: List[str], + file_num_in_batch: int, +) -> Generator[Dataset, None, None]: + """Load dataset from the splited jsonl files. + + Args: + file_paths (`List[str]`): + A list of paths to the JSONL files. + file_num_in_batch (`int`): + The number of files to be included in each batch. + + Yields: + `Dataset`: A dataset containing data from the specified batch of files. + """ + num_batches = (len(file_paths) + file_num_in_batch - 1) + for i in range(num_batches): + start_idx = i * file_num_in_batch + end_idx = min((i + 1) * file_num_in_batch, len(file_paths)) + batch_file_paths = file_paths[start_idx:end_idx] + dataset = rd.read_json(batch_file_paths) + yield dataset + + class RayDataset(DJDataset): - def __init__(self, - dataset: Dataset, - dataset_path: str = None, - cfg=None) -> None: - self.data = preprocess_dataset(dataset, dataset_path, cfg) + def __init__(self, datasets: Union[Dataset, Generator], cfg=None) -> None: + self.cfg = cfg self.num_proc = None + if isinstance(datasets, Dataset): + self.datasets = [datasets] + else: + self.datasets = datasets if cfg: self.num_proc = cfg.np + @classmethod + def read_jsonl(cls, + path: Union[str, List[str]], + cfg: Any = None) -> RayDataset: + files = split_jsonl_dataset(get_jsonl_file_names(path), + DEFAULT_MAX_FILE_SIZE, cfg.work_dir) + cpu = ray.cluster_resources().get('CPU', 0) + memory = ray.cluster_resources().get('memory', 0) / 1024 / 1024 + batch_file_num = best_file_num(cpu, memory, DEFAULT_MAX_FILE_SIZE) + return RayDataset(dataset=load_splited_json_dataset( + files, batch_file_num), + cfg=cfg) + + @classmethod + def read_item(cls, data: dict, cfg: Any = None) -> RayDataset: + return RayDataset(dataset=rd.from_items(data), cfg=cfg) + def process(self, operators, *, exporter=None, checkpointer=None, tracer=None) -> DJDataset: - if operators is None: - return self - if not isinstance(operators, list): - operators = [operators] - for op in operators: - self._run_single_op(op) + outputs = [] + for dataset in self.datasets: + # todo: pass dataset path into the function + data = preprocess_dataset(dataset, self.cfg) + if operators is None: + return self + if not isinstance(operators, list): + operators = [operators] + for op in operators: + data = self._run_single_op(op, data) 
+ outputs.append(data) return self - def _run_single_op(self, op): + def _run_single_op(self, op, dataset: Dataset) -> Dataset: op_proc = calculate_np(op._name, op.mem_required, op.cpu_required, self.num_proc, op.use_cuda()) num_gpus = get_num_gpus(op, op_proc) try: if isinstance(op, Mapper): - self.data = self.data.map_batches(op.process, - batch_size=1, - batch_format='pyarrow', - num_gpus=num_gpus) + dataset = dataset.map_batches(op.process, + batch_size=1, + batch_format='pyarrow', + num_gpus=num_gpus) elif isinstance(op, Filter): - self.data = self.data.map_batches(op.compute_stats, - batch_size=1, - batch_format='pyarrow', - num_gpus=num_gpus) + dataset = dataset.map_batches(op.compute_stats, + batch_size=1, + batch_format='pyarrow', + num_gpus=num_gpus) if op.stats_export_path is not None: - self.data.write_json(op.stats_export_path, - force_ascii=False) - self.data = self.data.filter(op.process) + dataset.write_json(op.stats_export_path, force_ascii=False) + dataset = dataset.filter(op.process) else: logger.error( 'Ray executor only support Filter and Mapper OPs for now') raise NotImplementedError + return dataset except: # noqa: E722 logger.error(f'An error occurred during Op [{op._name}].') import traceback diff --git a/data_juicer/core/ray_executor.py b/data_juicer/core/ray_executor.py index a071c2dea..a3c7eabd0 100644 --- a/data_juicer/core/ray_executor.py +++ b/data_juicer/core/ray_executor.py @@ -9,7 +9,6 @@ with AvailabilityChecking(['ray'], requires_type='dist'): import ray - import ray.data as rd class RayExecutor: @@ -57,10 +56,7 @@ def run(self, load_data_np=None): from data_juicer.format.formatter import FORMATTERS dataset = FORMATTERS.modules[obj_name](**args).load_dataset() else: - dataset = rd.read_json(self.cfg.dataset_path) - - # convert all the path in dataset to absolute path - dataset = RayDataset(dataset, self.cfg.dataset_path, self.cfg) + dataset = RayDataset(self.cfg.dataset_path, self.cfg) # 2. 
extract processes logger.info('Preparing process operators...') ops = load_ops(self.cfg.process, self.cfg.op_fusion) diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index 13bddb687..e44b0cb58 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -8,6 +8,7 @@ from loguru import logger DEFAULT_PREFIX = '__dj__' +DEFAULT_MAX_FILE_SIZE = 128 * 1024 * 1024 # 128 MB class Fields(object): From 4342d7efffdf3bebc315da4471aa3ec0469cb59a Mon Sep 17 00:00:00 2001 From: panxuchen Date: Thu, 5 Sep 2024 10:14:27 +0800 Subject: [PATCH 2/5] fix ray dataset --- data_juicer/core/ray_data.py | 15 ++++++++++++++- data_juicer/core/ray_executor.py | 2 +- data_juicer/utils/unittest_utils.py | 2 +- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/ray_data.py index f45c33b2d..f4fd86c62 100644 --- a/data_juicer/core/ray_data.py +++ b/data_juicer/core/ray_data.py @@ -3,6 +3,7 @@ import os from typing import Any, Generator, List, Union +import pandas as pd import pyarrow as pa from loguru import logger @@ -233,6 +234,7 @@ def __init__(self, datasets: Union[Dataset, Generator], cfg=None) -> None: self.datasets = datasets if cfg: self.num_proc = cfg.np + self.output_dataset = [] @classmethod def read_jsonl(cls, @@ -260,7 +262,7 @@ def process(self, outputs = [] for dataset in self.datasets: # todo: pass dataset path into the function - data = preprocess_dataset(dataset, self.cfg) + data = preprocess_dataset(dataset, dataset_path=None, cfg=self.cfg) if operators is None: return self if not isinstance(operators, list): @@ -268,6 +270,7 @@ def process(self, for op in operators: data = self._run_single_op(op, data) outputs.append(data) + self.datasets = outputs return self def _run_single_op(self, op, dataset: Dataset) -> Dataset: @@ -298,3 +301,13 @@ def _run_single_op(self, op, dataset: Dataset) -> Dataset: import traceback traceback.print_exc() exit(1) + + def to_pandas(self) -> pd.DataFrame: + dfs = [] + for data in self.datasets: + dfs.append(data.to_pandas()) + return pd.concat(dfs, ignore_index=True) + + def write_json(self, path: str, force_ascii: bool = False) -> None: + for dataset in self.datasets: + dataset.write_json(path, force_ascii) diff --git a/data_juicer/core/ray_executor.py b/data_juicer/core/ray_executor.py index a3c7eabd0..042bcf60f 100644 --- a/data_juicer/core/ray_executor.py +++ b/data_juicer/core/ray_executor.py @@ -70,5 +70,5 @@ def run(self, load_data_np=None): # 4. 
data export logger.info('Exporting dataset to disk...') - dataset.data.write_json(self.cfg.export_path, force_ascii=False) + dataset.write_json(self.cfg.export_path, force_ascii=False) return dataset diff --git a/data_juicer/utils/unittest_utils.py b/data_juicer/utils/unittest_utils.py index 604bee72d..37944563d 100644 --- a/data_juicer/utils/unittest_utils.py +++ b/data_juicer/utils/unittest_utils.py @@ -83,7 +83,7 @@ def run_single_op(self, dataset: DJDataset, op, column_names): dataset = dataset.select_columns(column_names=column_names) return dataset.to_list() elif current_tag.startswith('ray'): - dataset = dataset.data.to_pandas().get(column_names) + dataset = dataset.to_pandas().get(column_names) if dataset is None: return [] return dataset.to_dict(orient='records') From 104f4c9c324c5066377f29d6ce91ef8dac209a13 Mon Sep 17 00:00:00 2001 From: panxuchen Date: Thu, 5 Sep 2024 12:30:05 +0800 Subject: [PATCH 3/5] fix split --- data_juicer/core/ray_data.py | 86 +++++++++++++++++--------------- data_juicer/core/ray_executor.py | 8 +-- data_juicer/utils/constant.py | 2 +- 3 files changed, 52 insertions(+), 44 deletions(-) diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/ray_data.py index f4fd86c62..84ade8193 100644 --- a/data_juicer/core/ray_data.py +++ b/data_juicer/core/ray_data.py @@ -84,54 +84,60 @@ def get_num_gpus(op, op_proc): return 1.0 / proc_per_gpu -def split_jsonl(file_path: str, max_size: int, output_dir: str) -> List[str]: - """Split a jsonl file into multiple sub files +def split_jsonl(file_path: str, max_size: int, + output_dir: str) -> Generator[str]: + """Split a jsonl file into multiple sub files more efficiently. Args: file_path (`str`): path of the original jsonl file max_size (`int`): max size of each sub file (in MB) output_dir (`str`): directory to save the sub files - Returns: - List[str]: list of sub file paths + Yields: + str: path of each newly created sub file """ os.makedirs(output_dir, exist_ok=True) - sub_file_paths = [] file_index = 0 max_byte_size = max_size * 1024**2 base_file_name = os.path.basename(file_path) file_name = os.path.splitext(base_file_name)[0] current_size = 0 + buffer = [] + buffer_size = 0 + logger.info(f'Spliting {file_path}.') + with open(file_path, 'r', encoding='utf-8') as infile: while True: - # Create a new file if we're starting or have reached the max size - if current_size >= max_byte_size: - file_index += 1 - current_size = 0 # Determine the output file name output_file_name = f'{file_name}_{file_index}.jsonl' output_file_path = os.path.join(output_dir, output_file_name) - # Open the output file in append mode - with open(output_file_path, 'a', encoding='utf-8') as outfile: + + # Read lines until we reach the max buffer size + while current_size + buffer_size < max_byte_size: line = infile.readline() if not line: break - # Write the line to the current output file - outfile.write(line) - # Update the current file size - current_size += len(line.encode('utf-8')) - # Add the new file path to the list if it's a new file - if output_file_path not in sub_file_paths: - sub_file_paths.append(output_file_path) + buffer.append(line) + buffer_size += len(line) + + # Write the buffered lines to the current output file + if buffer: + with open(output_file_path, 'a', encoding='utf-8') as outfile: + outfile.writelines(buffer) + buffer = [] + buffer_size = 0 + file_index += 1 + yield output_file_path - return sub_file_paths + if not line: + break def split_jsonl_dataset( dataset_paths: Union[str, List[str]], max_size: int, 
output_dir: str, -) -> List[str]: +) -> Generator[str]: """Re-split the jsonl dataset files. Args: @@ -139,18 +145,17 @@ def split_jsonl_dataset( max_size (`int`): max size of each sub file (in MB) output_dir (`str`): directory to save the sub files - Returns: - List[str]: list of sub file paths + Yields: + str: path of each newly created sub file """ if isinstance(dataset_paths, str): dataset_paths = [dataset_paths] - all_sub_file_paths = [] + logger.info('Re-splitting dataset files...') for path in dataset_paths: - sub_file_paths = split_jsonl(path, max_size, output_dir) - all_sub_file_paths.extend(sub_file_paths) - - return all_sub_file_paths + for sub_file_path in split_jsonl(path, max_size, output_dir): + logger.info(f'Splited into {sub_file_path}') + yield sub_file_path def get_jsonl_file_names(dataset_dir_path: str) -> List[str]: @@ -194,19 +199,19 @@ def best_file_num(cpu: int, memory: int, file_size: int) -> int: """ max_files_by_memory = memory // (2 * file_size) - best_num_files = max(1, (max_files_by_memory // cpu) * cpu) - + best_num_files = max(1, (max_files_by_memory // cpu)) * cpu + logger.info(f'Best number of files in a single batch: {best_num_files}') return best_num_files def load_splited_json_dataset( - file_paths: List[str], + file_paths: Generator[str], file_num_in_batch: int, ) -> Generator[Dataset, None, None]: """Load dataset from the splited jsonl files. Args: - file_paths (`List[str]`): + file_paths (`Generator[str]`): A list of paths to the JSONL files. file_num_in_batch (`int`): The number of files to be included in each batch. @@ -214,13 +219,14 @@ def load_splited_json_dataset( Yields: `Dataset`: A dataset containing data from the specified batch of files. """ - num_batches = (len(file_paths) + file_num_in_batch - 1) - for i in range(num_batches): - start_idx = i * file_num_in_batch - end_idx = min((i + 1) * file_num_in_batch, len(file_paths)) - batch_file_paths = file_paths[start_idx:end_idx] - dataset = rd.read_json(batch_file_paths) - yield dataset + files = [] + for file_path in file_paths: + files.append(file_path) + if len(files) >= file_num_in_batch: + yield rd.read_json(files) + files.clear() + if len(files) > 0: + yield rd.read_json(files) class RayDataset(DJDataset): @@ -245,7 +251,7 @@ def read_jsonl(cls, cpu = ray.cluster_resources().get('CPU', 0) memory = ray.cluster_resources().get('memory', 0) / 1024 / 1024 batch_file_num = best_file_num(cpu, memory, DEFAULT_MAX_FILE_SIZE) - return RayDataset(dataset=load_splited_json_dataset( + return RayDataset(datasets=load_splited_json_dataset( files, batch_file_num), cfg=cfg) @@ -310,4 +316,4 @@ def to_pandas(self) -> pd.DataFrame: def write_json(self, path: str, force_ascii: bool = False) -> None: for dataset in self.datasets: - dataset.write_json(path, force_ascii) + dataset.write_json(path, force_ascii=force_ascii) diff --git a/data_juicer/core/ray_executor.py b/data_juicer/core/ray_executor.py index 042bcf60f..d914112e0 100644 --- a/data_juicer/core/ray_executor.py +++ b/data_juicer/core/ray_executor.py @@ -56,7 +56,7 @@ def run(self, load_data_np=None): from data_juicer.format.formatter import FORMATTERS dataset = FORMATTERS.modules[obj_name](**args).load_dataset() else: - dataset = RayDataset(self.cfg.dataset_path, self.cfg) + dataset = RayDataset.read_jsonl(self.cfg.dataset_path, self.cfg) # 2. 
extract processes logger.info('Preparing process operators...') ops = load_ops(self.cfg.process, self.cfg.op_fusion) @@ -65,10 +65,12 @@ def run(self, load_data_np=None): logger.info('Processing data...') tstart = time.time() dataset.process(ops) - tend = time.time() - logger.info(f'All Ops are done in {tend - tstart:.3f}s.') # 4. data export logger.info('Exporting dataset to disk...') dataset.write_json(self.cfg.export_path, force_ascii=False) + + tend = time.time() + logger.info(f'All Ops are done in {tend - tstart:.3f}s.') + return dataset diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index e44b0cb58..6910d7191 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -8,7 +8,7 @@ from loguru import logger DEFAULT_PREFIX = '__dj__' -DEFAULT_MAX_FILE_SIZE = 128 * 1024 * 1024 # 128 MB +DEFAULT_MAX_FILE_SIZE = 128 # 128 MB class Fields(object): From e85412715eecdaecf552e84c909145db90189284 Mon Sep 17 00:00:00 2001 From: panxuchen Date: Thu, 5 Sep 2024 15:34:00 +0800 Subject: [PATCH 4/5] fix time cal --- data_juicer/core/ray_data.py | 5 +++-- data_juicer/core/ray_executor.py | 3 +-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/ray_data.py index 84ade8193..763f566fe 100644 --- a/data_juicer/core/ray_data.py +++ b/data_juicer/core/ray_data.py @@ -122,7 +122,7 @@ def split_jsonl(file_path: str, max_size: int, # Write the buffered lines to the current output file if buffer: - with open(output_file_path, 'a', encoding='utf-8') as outfile: + with open(output_file_path, 'w', encoding='utf-8') as outfile: outfile.writelines(buffer) buffer = [] buffer_size = 0 @@ -197,7 +197,7 @@ def best_file_num(cpu: int, memory: int, file_size: int) -> int: Returns: int: best number of files in a single batch """ - max_files_by_memory = memory // (2 * file_size) + max_files_by_memory = memory // (16 * file_size) best_num_files = max(1, (max_files_by_memory // cpu)) * cpu logger.info(f'Best number of files in a single batch: {best_num_files}') @@ -250,6 +250,7 @@ def read_jsonl(cls, DEFAULT_MAX_FILE_SIZE, cfg.work_dir) cpu = ray.cluster_resources().get('CPU', 0) memory = ray.cluster_resources().get('memory', 0) / 1024 / 1024 + logger.info(f'CPU: {cpu}, Memory: {memory}') batch_file_num = best_file_num(cpu, memory, DEFAULT_MAX_FILE_SIZE) return RayDataset(datasets=load_splited_json_dataset( files, batch_file_num), diff --git a/data_juicer/core/ray_executor.py b/data_juicer/core/ray_executor.py index d914112e0..2dcce7660 100644 --- a/data_juicer/core/ray_executor.py +++ b/data_juicer/core/ray_executor.py @@ -46,7 +46,7 @@ def run(self, load_data_np=None): """ # 1. load data logger.info('Loading dataset with Ray...') - + tstart = time.time() if self.cfg.get('generated_dataset_config', None): generated_dataset_config = self.cfg.generated_dataset_config assert isinstance(generated_dataset_config, @@ -63,7 +63,6 @@ def run(self, load_data_np=None): # 3. data process logger.info('Processing data...') - tstart = time.time() dataset.process(ops) # 4. 
data export From b0caeaf811fe5cd888e6a9017cf3224308e90391 Mon Sep 17 00:00:00 2001 From: panxuchen Date: Fri, 6 Sep 2024 16:19:24 +0800 Subject: [PATCH 5/5] fix spliting --- data_juicer/core/ray_data.py | 152 +++++++++++++++++-------------- data_juicer/core/ray_executor.py | 4 +- 2 files changed, 88 insertions(+), 68 deletions(-) diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/ray_data.py index 763f566fe..330771f87 100644 --- a/data_juicer/core/ray_data.py +++ b/data_juicer/core/ray_data.py @@ -1,5 +1,6 @@ from __future__ import annotations +import concurrent.futures import os from typing import Any, Generator, List, Union @@ -10,6 +11,7 @@ from data_juicer import cuda_device_count from data_juicer.core.data import DJDataset from data_juicer.ops import Filter, Mapper +from data_juicer.ops.base_op import OP from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import DEFAULT_MAX_FILE_SIZE, Fields from data_juicer.utils.process_utils import calculate_np @@ -44,37 +46,41 @@ def convert_to_absolute_paths(dict_with_paths, dataset_dir, path_keys): return dict_with_paths -# TODO: check path for nestdataset -def set_dataset_to_absolute_path(dataset, dataset_path, cfg): - """ - Set all the path in input data to absolute path. - Checks dataset_dir and project_dir for valid paths. - """ - if not (cfg.video_key in dataset.columns() or cfg.image_key - in dataset.columns() or cfg.audio_key in dataset.columns()): - return dataset - dataset_dir = os.path.dirname(dataset_path) - dataset = dataset.map(lambda item: convert_to_absolute_paths( - item, dataset_dir, [cfg.video_key, cfg.image_key, cfg.audio_key])) - logger.info(f"transfer {dataset.count()} sample's paths") - return dataset - +class RayPreprocessOperator(OP): -def preprocess_dataset(dataset: Dataset, dataset_path, cfg) -> Dataset: - if dataset_path: - dataset = set_dataset_to_absolute_path(dataset, dataset_path, cfg) - columns = dataset.columns() - if Fields.stats not in columns: - logger.info(f'columns {columns}') - - def process_batch_arrow(table: pa.Table) -> pa.Table: - new_column_data = [{} for _ in range(len(table))] - new_talbe = table.append_column(Fields.stats, [new_column_data]) - return new_talbe + def __init__(self, dataset_path=None, cfg=None) -> None: + super().__init__() + self.dataset_path = dataset_path + self.cfg = cfg + self._name = 'RayPreporcess' + + def run(self, dataset: Dataset) -> Dataset: + columns = dataset.columns() + if Fields.stats not in columns: + logger.info(f'columns {columns}') + + def process_batch_arrow(table: pa.Table) -> pa.Table: + new_column_data = [{} for _ in range(len(table))] + new_talbe = table.append_column(Fields.stats, + [new_column_data]) + return new_talbe + + dataset = dataset.map_batches(process_batch_arrow, + batch_format='pyarrow') + if self.dataset_path: + # TODO: check path for nestdataset + if not (self.cfg.video_key in dataset.columns() + or self.cfg.image_key in dataset.columns() + or self.cfg.audio_key in dataset.columns()): + return dataset + dataset_dir = os.path.dirname(self.dataset_path) + dataset = dataset.map(lambda item: convert_to_absolute_paths( + item, dataset_dir, + [self.cfg.video_key, self.cfg.image_key, self.cfg.audio_key])) + return dataset - dataset = dataset.map_batches(process_batch_arrow, - batch_format='pyarrow') - return dataset + def use_cuda(self): + return False def get_num_gpus(op, op_proc): @@ -133,6 +139,11 @@ def split_jsonl(file_path: str, max_size: int, break +def parallel_split_jsonl(file_path, 
max_size, output_dir) -> List[str]: + """Wrapper function for using with ThreadPoolExecutor.""" + return list(split_jsonl(file_path, max_size, output_dir)) + + def split_jsonl_dataset( dataset_paths: Union[str, List[str]], max_size: int, @@ -152,10 +163,21 @@ def split_jsonl_dataset( dataset_paths = [dataset_paths] logger.info('Re-splitting dataset files...') - for path in dataset_paths: - for sub_file_path in split_jsonl(path, max_size, output_dir): - logger.info(f'Splited into {sub_file_path}') - yield sub_file_path + + with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor: + futures = { + executor.submit(parallel_split_jsonl, path, max_size, output_dir): + path + for path in dataset_paths + } + for future in concurrent.futures.as_completed(futures): + try: + results = future.result() + for result in results: + logger.info(f'Splited into {result}') + yield result + except Exception as e: + logger.error(f'Failed to split file: {e}') def get_jsonl_file_names(dataset_dir_path: str) -> List[str]: @@ -185,8 +207,7 @@ def get_jsonl_file_names(dataset_dir_path: str) -> List[str]: def best_file_num(cpu: int, memory: int, file_size: int) -> int: """Calculate the best number of files in a single batch. - Each cpu should process the same number of files (at least one), - while the total memory should be at least 2 times larger than the + The total memory should be at least 4 times larger than the total file size. Args: @@ -197,9 +218,8 @@ def best_file_num(cpu: int, memory: int, file_size: int) -> int: Returns: int: best number of files in a single batch """ - max_files_by_memory = memory // (16 * file_size) - - best_num_files = max(1, (max_files_by_memory // cpu)) * cpu + max_files_by_memory = memory // (4 * file_size) + best_num_files = min(cpu, max_files_by_memory) logger.info(f'Best number of files in a single batch: {best_num_files}') return best_num_files @@ -241,43 +261,36 @@ def __init__(self, datasets: Union[Dataset, Generator], cfg=None) -> None: if cfg: self.num_proc = cfg.np self.output_dataset = [] + self._ops = [RayPreprocessOperator(dataset_path=None, cfg=cfg)] @classmethod def read_jsonl(cls, path: Union[str, List[str]], - cfg: Any = None) -> RayDataset: - files = split_jsonl_dataset(get_jsonl_file_names(path), - DEFAULT_MAX_FILE_SIZE, cfg.work_dir) - cpu = ray.cluster_resources().get('CPU', 0) - memory = ray.cluster_resources().get('memory', 0) / 1024 / 1024 - logger.info(f'CPU: {cpu}, Memory: {memory}') - batch_file_num = best_file_num(cpu, memory, DEFAULT_MAX_FILE_SIZE) - return RayDataset(datasets=load_splited_json_dataset( - files, batch_file_num), - cfg=cfg) + cfg: Any = None, + resplit: bool = True) -> RayDataset: + if resplit: + resplit_dir = os.path.join(cfg.work_dir, 'resplit') + os.makedirs(resplit_dir, exist_ok=True) + files = split_jsonl_dataset(get_jsonl_file_names(path), + DEFAULT_MAX_FILE_SIZE, resplit_dir) + cpu = ray.cluster_resources().get('CPU', 0) + memory = ray.cluster_resources().get('memory', 0) / 1024 / 1024 + logger.info(f'CPU: {cpu}, Memory: {memory}') + batch_file_num = best_file_num(cpu, memory, DEFAULT_MAX_FILE_SIZE) + return RayDataset(datasets=load_splited_json_dataset( + files, batch_file_num), + cfg=cfg) + else: + return RayDataset(datasets=rd.read_json(path), cfg=cfg) @classmethod def read_item(cls, data: dict, cfg: Any = None) -> RayDataset: return RayDataset(dataset=rd.from_items(data), cfg=cfg) - def process(self, - operators, - *, - exporter=None, - checkpointer=None, - tracer=None) -> DJDataset: - outputs = [] - for dataset in 
self.datasets: - # todo: pass dataset path into the function - data = preprocess_dataset(dataset, dataset_path=None, cfg=self.cfg) - if operators is None: - return self - if not isinstance(operators, list): - operators = [operators] - for op in operators: - data = self._run_single_op(op, data) - outputs.append(data) - self.datasets = outputs + def process(self, operators) -> DJDataset: + if not isinstance(operators, list): + operators = [operators] + self._ops.extend(operators) return self def _run_single_op(self, op, dataset: Dataset) -> Dataset: @@ -298,6 +311,8 @@ def _run_single_op(self, op, dataset: Dataset) -> Dataset: if op.stats_export_path is not None: dataset.write_json(op.stats_export_path, force_ascii=False) dataset = dataset.filter(op.process) + elif isinstance(op, RayPreprocessOperator): + dataset = op.run(dataset) else: logger.error( 'Ray executor only support Filter and Mapper OPs for now') @@ -317,4 +332,7 @@ def to_pandas(self) -> pd.DataFrame: def write_json(self, path: str, force_ascii: bool = False) -> None: for dataset in self.datasets: - dataset.write_json(path, force_ascii=force_ascii) + if len(self._ops) > 0: + for op in self._ops: + dataset = self._run_single_op(op, dataset) + dataset.write_json(path, force_ascii=force_ascii) diff --git a/data_juicer/core/ray_executor.py b/data_juicer/core/ray_executor.py index 2dcce7660..14e99160a 100644 --- a/data_juicer/core/ray_executor.py +++ b/data_juicer/core/ray_executor.py @@ -56,7 +56,9 @@ def run(self, load_data_np=None): from data_juicer.format.formatter import FORMATTERS dataset = FORMATTERS.modules[obj_name](**args).load_dataset() else: - dataset = RayDataset.read_jsonl(self.cfg.dataset_path, self.cfg) + dataset = RayDataset.read_jsonl(self.cfg.dataset_path, + self.cfg, + resplit=True) # 2. extract processes logger.info('Preparing process operators...') ops = load_ops(self.cfg.process, self.cfg.op_fusion)
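Taken together, the series switches the Ray path from a single eager rd.read_json call to re-split JSONL loading with deferred operator execution: read_jsonl splits oversized files into at most 128 MB pieces under <work_dir>/resplit, groups them into batches sized from the cluster's CPU and memory, and the operators recorded by process() only run when write_json drains each batch. A minimal usage sketch of the resulting API (illustrative only, not part of the patches; it assumes an initialized Ray cluster and a parsed Data-Juicer config `cfg` carrying dataset_path, work_dir, np, process, op_fusion and export_path, e.g. built with init_configs):

import ray

from data_juicer.config import init_configs
from data_juicer.core.ray_data import RayDataset
from data_juicer.ops import load_ops

cfg = init_configs()   # config construction is assumed here, not changed by this series
ray.init()             # an available Ray cluster is assumed; RayExecutor normally sets this up

# Re-split the input into <=128 MB jsonl pieces under cfg.work_dir/resplit and
# load them lazily in CPU/memory-sized batches.
dataset = RayDataset.read_jsonl(cfg.dataset_path, cfg, resplit=True)

# After PATCH 5/5, process() only records the operators; nothing executes yet.
ops = load_ops(cfg.process, cfg.op_fusion)
dataset.process(ops)

# Preprocessing and the recorded ops run batch by batch during export.
dataset.write_json(cfg.export_path, force_ascii=False)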