Add an interface to Python API for a set of files and jobspecs (#158)
Co-authored-by: Devin Robison <drobison00@users.noreply.github.com>
edknv and drobison00 authored Oct 23, 2024
1 parent 3deb2dc commit cdf1b64
Showing 18 changed files with 973 additions and 348 deletions.
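
For orientation, a minimal sketch of the batch interface this commit exposes on the client, grounded in the call-site change in processing.py below; the constructor defaults, the file names, and the empty task mapping are illustrative assumptions:

from nv_ingest_client.client import NvIngestClient

client = NvIngestClient()  # assumes default connection settings
tasks = {}  # task name ("split", "extract_pdf", "embed", ...) -> validated task object
# Creates client-side job specs for the whole set of files and returns one job index per
# file the client could read successfully.
job_indices = client.create_jobs_for_batch(["report.pdf", "notes.txt"], tasks)
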
38 changes: 5 additions & 33 deletions client/src/nv_ingest_client/cli/util/click.py
@@ -3,7 +3,6 @@
# SPDX-License-Identifier: Apache-2.0


import glob
import json
import logging
import os
@@ -29,6 +28,7 @@
from nv_ingest_client.primitives.tasks.split import SplitTaskSchema
from nv_ingest_client.primitives.tasks.store import StoreTaskSchema
from nv_ingest_client.primitives.tasks.vdb_upload import VdbUploadTaskSchema
from nv_ingest_client.util.util import generate_matching_files

logger = logging.getLogger(__name__)

@@ -137,6 +137,9 @@ def click_validate_task(ctx, param, value):
else:
raise ValueError(f"Unsupported task type: {task_id}")

if new_task_id in validated_tasks:
raise ValueError(f"Duplicate task detected: {new_task_id}")

logger.debug("Adding task: %s", new_task_id)
validated_tasks[new_task_id] = new_task
except ValueError as e:
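
A standalone illustration of the guard added above (just the duplicate check, outside of Click; the placeholder task object is hypothetical):

validated_tasks = {}
for task_id in ["split", "split"]:
    if task_id in validated_tasks:
        raise ValueError(f"Duplicate task detected: {task_id}")
    validated_tasks[task_id] = object()  # stand-in for the validated task instance
# The second "split" raises: ValueError: Duplicate task detected: split
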
@@ -190,37 +193,6 @@ def pre_process_dataset(dataset_json: str, shuffle_dataset: bool):
return file_source


def _generate_matching_files(file_sources):
"""
Generates a list of file paths that match the given patterns specified in file_sources.
Parameters
----------
file_sources : list of str
A list containing the file source patterns to match against.
Returns
-------
generator
A generator yielding paths to files that match the specified patterns.
Notes
-----
This function utilizes glob pattern matching to find files that match the specified patterns.
It yields each matching file path, allowing for efficient processing of potentially large
sets of files.
"""

files = [
file_path
for pattern in file_sources
for file_path in glob.glob(pattern, recursive=True)
if os.path.isfile(file_path)
]
for file_path in files:
yield file_path


def click_match_and_validate_files(ctx, param, value):
"""
Matches and validates files based on the provided file source patterns.
@@ -239,7 +211,7 @@ def click_match_and_validate_files(ctx, param, value):
if not value:
return []

matching_files = list(_generate_matching_files(value))
matching_files = list(generate_matching_files(value))
if not matching_files:
logger.warning("No files found matching the specified patterns.")
return []
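
The CLI now calls the shared helper from nv_ingest_client.util.util instead of the private copy removed above; a usage sketch, assuming the shared function keeps the same recursive-glob behavior:

from nv_ingest_client.util.util import generate_matching_files

# Recursive glob patterns, as accepted by the CLI; only regular files are yielded.
for path in generate_matching_files(["docs/**/*.pdf", "notes/*.txt"]):
    print(path)
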
255 changes: 29 additions & 226 deletions client/src/nv_ingest_client/cli/util/processing.py
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import concurrent

import json
import logging
import os
@@ -12,7 +12,8 @@
from concurrent.futures import as_completed
from statistics import mean
from statistics import median
from typing import Dict, Any
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple
from typing import Type
@@ -23,9 +24,7 @@
from tqdm import tqdm

from nv_ingest_client.client import NvIngestClient
from nv_ingest_client.primitives import JobSpec
from nv_ingest_client.util.file_processing.extract import extract_file_content
from nv_ingest_client.util.util import check_ingest_result
from nv_ingest_client.util.processing import handle_future_result
from nv_ingest_client.util.util import estimate_page_count

logger = logging.getLogger(__name__)
@@ -131,7 +130,7 @@ def check_schema(schema: Type[BaseModel], options: dict, task_id: str, original_


def report_stage_statistics(
stage_elapsed_times: defaultdict(list), total_trace_elapsed: float, abs_elapsed: float
stage_elapsed_times: defaultdict(list), total_trace_elapsed: float, abs_elapsed: float
) -> None:
"""
Reports the statistics for each processing stage, including average, median, total time spent,
@@ -206,10 +205,10 @@ def report_overall_speed(total_pages_processed: int, start_time_ns: int, total_f


def report_statistics(
start_time_ns: int,
stage_elapsed_times: defaultdict,
total_pages_processed: int,
total_files: int,
start_time_ns: int,
stage_elapsed_times: defaultdict,
total_pages_processed: int,
total_files: int,
) -> None:
"""
Aggregates and reports statistics for the entire processing session.
@@ -425,136 +424,20 @@ def save_response_data(response, output_directory):
f.write(json.dumps(documents, indent=2))


def create_job_specs_for_batch(files_batch: List[str], tasks: Dict[str, Any], client: NvIngestClient) -> List[str]:
"""
Create and submit job specifications (JobSpecs) for a batch of files, returning the job IDs.
This function takes a batch of files, processes each file to extract its content and type,
creates a job specification (JobSpec) for each file, and adds tasks from the provided task
list. It then submits the jobs to the client and collects their job IDs.
Parameters
----------
files_batch : List[str]
A list of file paths to be processed. Each file is assumed to be in a format compatible
with the `extract_file_content` function, which extracts the file's content and type.
tasks : Dict[str, Any]
A dictionary of tasks to be added to each job. The keys represent task names, and the
values represent task specifications or configurations. Standard tasks include "split",
"extract", "store", "caption", "dedup", "filter", "embed", and "vdb_upload".
client : NvIngestClient
An instance of NvIngestClient, which handles the job submission. The `add_job` method of
the client is used to submit each job specification.
Returns
-------
Tuple[List[JobSpec], List[str]]
A Tuple containing the list of JobSpecs and list of job IDs corresponding to the submitted jobs.
Each job ID is returned by the client's `add_job` method.
Raises
------
ValueError
If there is an error extracting the file content or type from any of the files, a
ValueError will be logged, and the corresponding file will be skipped.
Notes
-----
- The function assumes that a utility function `extract_file_content` is defined elsewhere,
which extracts the content and type from the provided file paths.
- For each file, a `JobSpec` is created with relevant metadata, including document type and
file content. Various tasks are conditionally added based on the provided `tasks` dictionary.
- The job specification includes tracing options with a timestamp (in nanoseconds) for
diagnostic purposes.
Examples
--------
Suppose you have a batch of files and tasks to process:
>>> files_batch = ["file1.txt", "file2.pdf"]
>>> tasks = {"split": ..., "extract_txt": ..., "store": ...}
>>> client = NvIngestClient()
>>> job_ids = create_job_specs_for_batch(files_batch, tasks, client)
>>> print(job_ids)
['job_12345', 'job_67890']
In this example, jobs are created and submitted for the files in `files_batch`, with the
tasks in `tasks` being added to each job specification. The returned job IDs are then
printed.
See Also
--------
extract_file_content : Function that extracts the content and type of a file.
JobSpec : The class representing a job specification.
NvIngestClient : Client class used to submit jobs to a job processing system.
"""

job_ids = []
for file_name in files_batch:
try:
file_content, file_type = extract_file_content(file_name) # Assume these are defined
file_type = file_type.value
except ValueError as ve:
logger.error(f"Error extracting content from {file_name}: {ve}")
continue

job_spec = JobSpec(
document_type=file_type,
payload=file_content,
source_id=file_name,
source_name=file_name,
extended_options={"tracing_options": {"trace": True, "ts_send": time.time_ns()}},
)

logger.debug(f"Tasks: {tasks.keys()}")
for task in tasks:
logger.debug(f"Task: {task}")

# TODO(Devin): Formalize this later, don't have time right now.
if "split" in tasks:
job_spec.add_task(tasks["split"])

if f"extract_{file_type}" in tasks:
job_spec.add_task(tasks[f"extract_{file_type}"])

if "store" in tasks:
job_spec.add_task(tasks["store"])

if "caption" in tasks:
job_spec.add_task(tasks["caption"])

if "dedup" in tasks:
job_spec.add_task(tasks["dedup"])

if "filter" in tasks:
job_spec.add_task(tasks["filter"])

if "embed" in tasks:
job_spec.add_task(tasks["embed"])

if "vdb_upload" in tasks:
job_spec.add_task(tasks["vdb_upload"])

job_id = client.add_job(job_spec)
job_ids.append(job_id)

return job_ids
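
For context, the per-file JobSpec construction that leaves this module along with the function above (condensed from the removed body; the file name is illustrative, and per the call-site change below the client now appears to own this logic via create_jobs_for_batch):

import time

from nv_ingest_client.primitives import JobSpec
from nv_ingest_client.util.file_processing.extract import extract_file_content

file_content, file_type = extract_file_content("file1.pdf")
job_spec = JobSpec(
    document_type=file_type.value,
    payload=file_content,
    source_id="file1.pdf",
    source_name="file1.pdf",
    extended_options={"tracing_options": {"trace": True, "ts_send": time.time_ns()}},
)
# Tasks ("split", "extract_<type>", "store", "caption", "dedup", "filter",
# "embed", "vdb_upload") are then attached via job_spec.add_task(...).
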


def generate_job_batch_for_iteration(
client: Any,
pbar: Any,
files: List[str],
tasks: Dict,
processed: int,
batch_size: int,
retry_job_ids: List[str],
fail_on_error: bool = False
client: Any,
pbar: Any,
files: List[str],
tasks: Dict,
processed: int,
batch_size: int,
retry_job_ids: List[str],
fail_on_error: bool = False,
) -> Tuple[List[str], Dict[str, str], int]:
"""
Generates a batch of job specifications for the current iteration of file processing. This function handles retrying failed jobs and creating new jobs for unprocessed files. The job specifications are then submitted for processing.
Generates a batch of job specifications for the current iteration of file processing.
This function handles retrying failed jobs and creating new jobs for unprocessed files.
The job specifications are then submitted for processing.
Parameters
----------
@@ -599,9 +482,9 @@ def generate_job_batch_for_iteration(

if (cur_job_count < batch_size) and (processed < len(files)):
new_job_count = min(batch_size - cur_job_count, len(files) - processed)
batch_files = files[processed: processed + new_job_count] # noqa: E203
batch_files = files[processed : processed + new_job_count] # noqa: E203

new_job_indices = create_job_specs_for_batch(batch_files, tasks, client)
new_job_indices = client.create_jobs_for_batch(batch_files, tasks)
if len(new_job_indices) != new_job_count:
missing_jobs = new_job_count - len(new_job_indices)
error_msg = f"Missing {missing_jobs} job specs -- this is likely due to bad reads or file corruption"
@@ -620,93 +503,14 @@ def generate_job_batch_for_iteration(
return job_indices, job_index_map_updates, processed


def handle_future_result(
future: concurrent.futures.Future,
futures_dict: Dict[concurrent.futures.Future, str],
) -> Dict[str, Any]:
"""
Handle the result of a completed future job, process annotations, and save the result.
This function processes the result of a future, extracts annotations (if any), logs them,
checks the validity of the ingest result, and optionally saves the result to the provided
output directory. If the result indicates a failure, a retry list of job IDs is prepared.
Parameters
----------
future : concurrent.futures.Future
A future object representing an asynchronous job. The result of this job will be
processed once it completes.
futures_dict : Dict[concurrent.futures.Future, str]
A dictionary mapping future objects to job IDs. The job ID associated with the
provided future is retrieved from this dictionary.
Returns
-------
Dict[str, Any]
Raises
------
RuntimeError
If the job result is invalid, this exception is raised with a description of the failure.
Notes
-----
- The `future.result()` is assumed to return a tuple where the first element is the actual
result (as a dictionary), and the second element (if present) can be ignored.
- Annotations in the result (if any) are logged for debugging purposes.
- The `check_ingest_result` function (assumed to be defined elsewhere) is used to validate
the result. If the result is invalid, a `RuntimeError` is raised.
- The function handles saving the result data to the specified output directory using the
`save_response_data` function.
Examples
--------
Suppose we have a future object representing a job, a dictionary of futures to job IDs,
and a directory for saving results:
>>> future = concurrent.futures.Future()
>>> futures_dict = {future: "job_12345"}
>>> job_id_map = {"job_12345": {...}}
>>> output_directory = "/path/to/save"
>>> result, retry_job_ids = handle_future_result(future, futures_dict, job_id_map, output_directory)
In this example, the function processes the completed job and saves the result to the
specified directory. If the job fails, it raises a `RuntimeError` and returns a list of
retry job IDs.
See Also
--------
check_ingest_result : Function to validate the result of the job.
save_response_data : Function to save the result to a directory.
"""

try:
result, _ = future.result()[0]
if ("annotations" in result) and result["annotations"]:
annotations = result["annotations"]
for key, value in annotations.items():
logger.debug(f"Annotation: {key} -> {json.dumps(value, indent=2)}")

failed, description = check_ingest_result(result)

if failed:
raise RuntimeError(f"{description}")
except Exception as e:
logger.debug(f"Error processing future result: {e}")
raise e

return result
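
handle_future_result is now imported from nv_ingest_client.util.processing (see the import change at the top of this file); a sketch of a caller, assuming the relocated helper keeps the signature and RuntimeError-on-failure behavior shown above:

from concurrent.futures import as_completed

from nv_ingest_client.client import NvIngestClient
from nv_ingest_client.util.processing import handle_future_result


def collect_results(client: NvIngestClient, job_ids, timeout: int = 10):
    """Fetch async job results and funnel each completed future through the helper."""
    results, retry_job_ids = [], []
    futures_dict = client.fetch_job_result_async(job_ids, timeout=timeout, data_only=False)
    for future in as_completed(futures_dict.keys()):
        job_id = futures_dict[future]
        try:
            results.append(handle_future_result(future, futures_dict))
        except RuntimeError:
            retry_job_ids.append(job_id)  # failed ingest result; the caller may resubmit
    return results, retry_job_ids
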


def create_and_process_jobs(
files: List[str],
client: NvIngestClient,
tasks: Dict[str, Any],
output_directory: str,
batch_size: int,
timeout: int = 10,
fail_on_error: bool = False,
files: List[str],
client: NvIngestClient,
tasks: Dict[str, Any],
output_directory: str,
batch_size: int,
timeout: int = 10,
fail_on_error: bool = False,
) -> Tuple[int, Dict[str, List[float]], int]:
"""
Process a list of files, creating and submitting jobs for each file, then fetch and handle the results.
@@ -807,7 +611,6 @@ def create_and_process_jobs(

futures_dict = client.fetch_job_result_async(job_ids, timeout=timeout, data_only=False)
for future in as_completed(futures_dict.keys()):

retry = False
job_id = futures_dict[future]
source_name = job_id_map[job_id]
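
Putting the pieces together, a hedged end-to-end sketch of this module's entry point after the change; parameter names come from the signature above, while the meaning of the three returned values is inferred from the Tuple[int, Dict[str, List[float]], int] annotation and is an assumption:

from nv_ingest_client.cli.util.processing import create_and_process_jobs
from nv_ingest_client.client import NvIngestClient
from nv_ingest_client.util.util import generate_matching_files

client = NvIngestClient()
files = list(generate_matching_files(["./data/**/*.pdf"]))

total_files, trace_times, total_pages = create_and_process_jobs(
    files=files,
    client=client,
    tasks={},                     # task name -> task object, as validated by the CLI
    output_directory="./results",
    batch_size=32,
    timeout=10,
    fail_on_error=False,
)
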
(Diff for the remaining 16 changed files is not shown here.)
