diff --git a/client/src/nv_ingest_client/cli/util/click.py b/client/src/nv_ingest_client/cli/util/click.py
index 2da3f03f..f7fd9c72 100644
--- a/client/src/nv_ingest_client/cli/util/click.py
+++ b/client/src/nv_ingest_client/cli/util/click.py
@@ -12,6 +12,13 @@ import click
 
 from nv_ingest_client.cli.util.processing import check_schema
+from nv_ingest_client.primitives.tasks.caption import CaptionTaskSchema
+from nv_ingest_client.primitives.tasks.chart_extraction import ChartExtractionSchema
+from nv_ingest_client.primitives.tasks.chart_extraction import ChartExtractionTask
+from nv_ingest_client.primitives.tasks.dedup import DedupTaskSchema
+from nv_ingest_client.primitives.tasks.embed import EmbedTaskSchema
+from nv_ingest_client.primitives.tasks.extract import ExtractTaskSchema
+from nv_ingest_client.primitives.tasks.filter import FilterTaskSchema
 from nv_ingest_client.primitives.tasks import CaptionTask
 from nv_ingest_client.primitives.tasks import DedupTask
 from nv_ingest_client.primitives.tasks import EmbedTask
@@ -20,13 +27,10 @@ from nv_ingest_client.primitives.tasks import SplitTask
 from nv_ingest_client.primitives.tasks import StoreTask
 from nv_ingest_client.primitives.tasks import VdbUploadTask
-from nv_ingest_client.primitives.tasks.caption import CaptionTaskSchema
-from nv_ingest_client.primitives.tasks.dedup import DedupTaskSchema
-from nv_ingest_client.primitives.tasks.embed import EmbedTaskSchema
-from nv_ingest_client.primitives.tasks.extract import ExtractTaskSchema
-from nv_ingest_client.primitives.tasks.filter import FilterTaskSchema
 from nv_ingest_client.primitives.tasks.split import SplitTaskSchema
 from nv_ingest_client.primitives.tasks.store import StoreTaskSchema
+from nv_ingest_client.primitives.tasks.table_extraction import TableExtractionSchema
+from nv_ingest_client.primitives.tasks.table_extraction import TableExtractionTask
 from nv_ingest_client.primitives.tasks.vdb_upload import VdbUploadTaskSchema
 from nv_ingest_client.util.util import generate_matching_files
@@ -104,36 +108,44 @@ def click_validate_task(ctx, param, value):
             if task_id == "split":
                 task_options = check_schema(SplitTaskSchema, options, task_id, json_options)
                 new_task_id = f"{task_id}"
-                new_task = SplitTask(**task_options.dict())
+                new_task = [(new_task_id, SplitTask(**task_options.dict()))]
             elif task_id == "extract":
                 task_options = check_schema(ExtractTaskSchema, options, task_id, json_options)
                 new_task_id = f"{task_id}_{task_options.document_type}"
-                new_task = ExtractTask(**task_options.dict())
+                new_task = [(new_task_id, ExtractTask(**task_options.dict()))]
+
+                if task_options.extract_tables:
+                    subtask_options = check_schema(TableExtractionSchema, {}, "table_data_extract", "{}")
+                    new_task.append(("table_data_extract", TableExtractionTask(**subtask_options.dict())))
+
+                if task_options.extract_charts:
+                    subtask_options = check_schema(ChartExtractionSchema, {}, "chart_data_extract", "{}")
+                    new_task.append(("chart_data_extract", ChartExtractionTask(**subtask_options.dict())))
+
             elif task_id == "store":
                 task_options = check_schema(StoreTaskSchema, options, task_id, json_options)
                 new_task_id = f"{task_id}"
-                new_task = StoreTask(**task_options.dict())
+                new_task = [(new_task_id, StoreTask(**task_options.dict()))]
             elif task_id == "caption":
                 task_options = check_schema(CaptionTaskSchema, options, task_id, json_options)
                 new_task_id = f"{task_id}"
-                new_task = CaptionTask(**task_options.dict())
+                new_task = [(new_task_id, CaptionTask(**task_options.dict()))]
             elif task_id == "dedup":
                 task_options = check_schema(DedupTaskSchema, options, task_id, json_options)
                 new_task_id = f"{task_id}"
-                new_task = DedupTask(**task_options.dict())
+                new_task = [(new_task_id, DedupTask(**task_options.dict()))]
             elif task_id == "filter":
                 task_options = check_schema(FilterTaskSchema, options, task_id, json_options)
                 new_task_id = f"{task_id}"
-                new_task = FilterTask(**task_options.dict())
+                new_task = [(new_task_id, FilterTask(**task_options.dict()))]
             elif task_id == "embed":
                 task_options = check_schema(EmbedTaskSchema, options, task_id, json_options)
                 new_task_id = f"{task_id}"
-                new_task = EmbedTask(**task_options.dict())
+                new_task = [(new_task_id, EmbedTask(**task_options.dict()))]
             elif task_id == "vdb_upload":
                 task_options = check_schema(VdbUploadTaskSchema, options, task_id, json_options)
                 new_task_id = f"{task_id}"
-                new_task = VdbUploadTask(**task_options.dict())
-
+                new_task = [(new_task_id, VdbUploadTask(**task_options.dict()))]
             else:
                 raise ValueError(f"Unsupported task type: {task_id}")
@@ -141,14 +153,14 @@ def click_validate_task(ctx, param, value):
                 raise ValueError(f"Duplicate task detected: {new_task_id}")
 
             logger.debug("Adding task: %s", new_task_id)
-            validated_tasks[new_task_id] = new_task
+            for task_name, task_instance in new_task:
+                validated_tasks[task_name] = task_instance
         except ValueError as e:
             validation_errors.append(str(e))
 
     if validation_errors:
         # Aggregate error messages with original values highlighted
         error_message = "\n".join(validation_errors)
-        # logger.error(error_message)
         raise click.BadParameter(error_message)
 
     return validated_tasks
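Note: a single `extract` CLI option can now fan out into multiple named tasks. A minimal sketch of the resulting `validated_tasks` mapping when both sub-extractions are enabled (placeholder strings stand in for the real task objects):

```python
# Sketch of the fan-out in click_validate_task: one "extract" option with
# extract_tables/extract_charts enabled yields three validated tasks.
new_task = [
    ("extract_pdf", "<ExtractTask>"),
    ("table_data_extract", "<TableExtractionTask>"),
    ("chart_data_extract", "<ChartExtractionTask>"),
]

validated_tasks = {}
for task_name, task_instance in new_task:
    validated_tasks[task_name] = task_instance

assert list(validated_tasks) == ["extract_pdf", "table_data_extract", "chart_data_extract"]
```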
"dedup": task_options = check_schema(DedupTaskSchema, options, task_id, json_options) new_task_id = f"{task_id}" - new_task = DedupTask(**task_options.dict()) + new_task = [(new_task_id, DedupTask(**task_options.dict()))] elif task_id == "filter": task_options = check_schema(FilterTaskSchema, options, task_id, json_options) new_task_id = f"{task_id}" - new_task = FilterTask(**task_options.dict()) + new_task = [(new_task_id, FilterTask(**task_options.dict()))] elif task_id == "embed": task_options = check_schema(EmbedTaskSchema, options, task_id, json_options) new_task_id = f"{task_id}" - new_task = EmbedTask(**task_options.dict()) + new_task = [(new_task_id, EmbedTask(**task_options.dict()))] elif task_id == "vdb_upload": task_options = check_schema(VdbUploadTaskSchema, options, task_id, json_options) new_task_id = f"{task_id}" - new_task = VdbUploadTask(**task_options.dict()) - + new_task = [(new_task_id, VdbUploadTask(**task_options.dict()))] else: raise ValueError(f"Unsupported task type: {task_id}") @@ -141,14 +153,14 @@ def click_validate_task(ctx, param, value): raise ValueError(f"Duplicate task detected: {new_task_id}") logger.debug("Adding task: %s", new_task_id) - validated_tasks[new_task_id] = new_task + for task_tuple in new_task: + validated_tasks[task_tuple[0]] = task_tuple[1] except ValueError as e: validation_errors.append(str(e)) if validation_errors: # Aggregate error messages with original values highlighted error_message = "\n".join(validation_errors) - # logger.error(error_message) raise click.BadParameter(error_message) return validated_tasks diff --git a/client/src/nv_ingest_client/cli/util/processing.py b/client/src/nv_ingest_client/cli/util/processing.py index 132a505c..f5f40992 100644 --- a/client/src/nv_ingest_client/cli/util/processing.py +++ b/client/src/nv_ingest_client/cli/util/processing.py @@ -130,7 +130,7 @@ def check_schema(schema: Type[BaseModel], options: dict, task_id: str, original_ def report_stage_statistics( - stage_elapsed_times: defaultdict(list), total_trace_elapsed: float, abs_elapsed: float + stage_elapsed_times: defaultdict, total_trace_elapsed: float, abs_elapsed: float ) -> None: """ Reports the statistics for each processing stage, including average, median, total time spent, diff --git a/client/src/nv_ingest_client/client/client.py b/client/src/nv_ingest_client/client/client.py index 45c9d0e9..b23cb473 100644 --- a/client/src/nv_ingest_client/client/client.py +++ b/client/src/nv_ingest_client/client/client.py @@ -60,13 +60,13 @@ class NvIngestClient: """ def __init__( - self, - message_client_allocator: Callable[..., RestClient] = RestClient, - message_client_hostname: Optional[str] = "localhost", - message_client_port: Optional[int] = 7670, - message_client_kwargs: Optional[Dict] = None, - msg_counter_id: Optional[str] = "nv-ingest-message-id", - worker_pool_size: int = 1, + self, + message_client_allocator: Callable[..., RestClient] = RestClient, + message_client_hostname: Optional[str] = "localhost", + message_client_port: Optional[int] = 7670, + message_client_kwargs: Optional[Dict] = None, + msg_counter_id: Optional[str] = "nv-ingest-message-id", + worker_pool_size: int = 1, ) -> None: """ Initializes the NvIngestClient with a client allocator, REST configuration, a message counter ID, @@ -149,9 +149,9 @@ def _pop_job_state(self, job_index: str) -> JobState: return job_state def _get_and_check_job_state( - self, - job_index: str, - required_state: Union[JobStateEnum, List[JobStateEnum]] = None, + self, + job_index: str, + 
required_state: Union[JobStateEnum, List[JobStateEnum]] = None, ) -> JobState: if required_state and not isinstance(required_state, list): required_state = [required_state] @@ -192,13 +192,13 @@ def add_job(self, job_spec: Union[BatchJobSpec, JobSpec]) -> str: raise ValueError(f"Unexpected type: {type(job_spec)}") def create_job( - self, - payload: str, - source_id: str, - source_name: str, - document_type: str = None, - tasks: Optional[list] = None, - extended_options: Optional[dict] = None, + self, + payload: str, + source_id: str, + source_name: str, + document_type: str = None, + tasks: Optional[list] = None, + extended_options: Optional[dict] = None, ) -> str: """ Creates a new job with the specified parameters and adds it to the job tracking dictionary. @@ -249,10 +249,10 @@ def add_task(self, job_index: str, task: Task) -> None: job_state.job_spec.add_task(task) def create_task( - self, - job_index: Union[str, int], - task_type: TaskType, - task_params: dict = None, + self, + job_index: Union[str, int], + task_type: TaskType, + task_params: dict = None, ) -> None: """ Creates a task of the specified type with given parameters and associates it with the existing job. @@ -345,12 +345,12 @@ def _fetch_job_result_wait(self, job_id: str, timeout: float = 60, data_only: bo # This is the direct Python approach function for retrieving jobs which handles the timeouts directly # in the function itself instead of expecting the user to handle it themselves def fetch_job_result( - self, - job_ids: List[str], - timeout: float = 100, - max_retries: Optional[int] = None, - retry_delay: float = 1, - verbose: bool = False, + self, + job_ids: List[str], + timeout: float = 100, + max_retries: Optional[int] = None, + retry_delay: float = 1, + verbose: bool = False, ) -> List[Tuple[Optional[Dict], str]]: """ Fetches job results for multiple job IDs concurrently with individual timeouts and retry logic. @@ -437,7 +437,7 @@ def _ensure_submitted(self, job_ids: List[str]): job_state.future = None def fetch_job_result_async( - self, job_ids: Union[str, List[str]], timeout: float = 10, data_only: bool = True + self, job_ids: Union[str, List[str]], timeout: float = 10, data_only: bool = True ) -> Dict[Future, str]: """ Fetches job results for a list or a single job ID asynchronously and returns a mapping of futures to job IDs. @@ -467,9 +467,9 @@ def fetch_job_result_async( return future_to_job_id def _submit_job( - self, - job_index: str, - job_queue_id: str, + self, + job_index: str, + job_queue_id: str, ) -> Optional[Dict]: """ Submits a job to a specified job queue and optionally waits for a response if blocking is True. 
@@ -514,7 +514,7 @@ class NvIngestClient:
             raise
 
     def submit_job(
-        self, job_indices: Union[str, List[str]], job_queue_id: str, batch_size: int = 10
+            self, job_indices: Union[str, List[str]], job_queue_id: str, batch_size: int = 10
     ) -> List[Union[Dict, None]]:
         if isinstance(job_indices, str):
             job_indices = [job_indices]
diff --git a/client/src/nv_ingest_client/nv_ingest_cli.py b/client/src/nv_ingest_client/nv_ingest_cli.py
index 1c382227..11986260 100644
--- a/client/src/nv_ingest_client/nv_ingest_cli.py
+++ b/client/src/nv_ingest_client/nv_ingest_cli.py
@@ -11,7 +11,6 @@
 import click
 import pkg_resources
 
-from nv_ingest_client.cli.util.click import ClientType
 from nv_ingest_client.cli.util.click import LogLevel
 from nv_ingest_client.cli.util.click import click_match_and_validate_files
 from nv_ingest_client.cli.util.click import click_validate_batch_size
diff --git a/client/src/nv_ingest_client/primitives/tasks/chart_extraction.py b/client/src/nv_ingest_client/primitives/tasks/chart_extraction.py
new file mode 100644
index 00000000..37c7612a
--- /dev/null
+++ b/client/src/nv_ingest_client/primitives/tasks/chart_extraction.py
@@ -0,0 +1,52 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+# pylint: disable=too-few-public-methods
+# pylint: disable=too-many-arguments
+
+import logging
+from typing import Dict
+
+from pydantic import BaseModel
+
+from .task_base import Task
+
+logger = logging.getLogger(__name__)
+
+
+class ChartExtractionSchema(BaseModel):
+    class Config:
+        extra = "forbid"
+
+
+class ChartExtractionTask(Task):
+    """
+    Object for chart extraction task
+    """
+
+    def __init__(self) -> None:
+        """
+        Setup Chart Extraction Task Config
+        """
+        super().__init__()
+
+    def __str__(self) -> str:
+        """
+        Returns a string with the object's config and run time state
+        """
+        info = ""
+        info += "chart extraction task\n"
+        return info
+
+    def to_dict(self) -> Dict:
+        """
+        Convert to a dict for submission to redis
+        """
+
+        task_properties = {
+            "params": {},
+        }
+
+        return {"type": "chart_data_extract", "task_properties": task_properties}
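Note: both new client tasks serialize to fixed payloads. A quick sanity check of the chart task's wire format (illustrative; the table task below mirrors it with type `table_data_extract`):

```python
from nv_ingest_client.primitives.tasks.chart_extraction import ChartExtractionTask

# The task carries no options today; its payload is a fixed params dict.
task = ChartExtractionTask()
assert task.to_dict() == {"type": "chart_data_extract", "task_properties": {"params": {}}}
```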
diff --git a/client/src/nv_ingest_client/primitives/tasks/table_extraction.py b/client/src/nv_ingest_client/primitives/tasks/table_extraction.py
new file mode 100644
index 00000000..5d1d7299
--- /dev/null
+++ b/client/src/nv_ingest_client/primitives/tasks/table_extraction.py
@@ -0,0 +1,52 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+# pylint: disable=too-few-public-methods
+# pylint: disable=too-many-arguments
+
+import logging
+from typing import Dict
+
+from pydantic import BaseModel
+
+from .task_base import Task
+
+logger = logging.getLogger(__name__)
+
+
+class TableExtractionSchema(BaseModel):
+    class Config:
+        extra = "forbid"
+
+
+class TableExtractionTask(Task):
+    """
+    Object for table extraction tasks
+    """
+
+    def __init__(self) -> None:
+        """
+        Setup Table Extraction Task Config
+        """
+        super().__init__()
+
+    def __str__(self) -> str:
+        """
+        Returns a string with the object's config and run time state
+        """
+        info = ""
+        info += "table extraction task\n"
+        return info
+
+    def to_dict(self) -> Dict:
+        """
+        Convert to a dict for submission to redis
+        """
+
+        task_properties = {
+            "params": {},
+        }
+
+        return {"type": "table_data_extract", "task_properties": task_properties}
diff --git a/client/src/nv_ingest_client/primitives/tasks/task_base.py b/client/src/nv_ingest_client/primitives/tasks/task_base.py
index aa33d452..5d4a65cd 100644
--- a/client/src/nv_ingest_client/primitives/tasks/task_base.py
+++ b/client/src/nv_ingest_client/primitives/tasks/task_base.py
@@ -24,6 +24,8 @@ class TaskType(Enum):
     TRANSFORM = auto()
     STORE = auto()
     VDB_UPLOAD = auto()
+    TABLE_DATA_EXTRACT = auto()
+    CHART_DATA_EXTRACT = auto()
 
 
 def is_valid_task_type(task_type_str: str) -> bool:
diff --git a/docker-compose.yaml b/docker-compose.yaml
index 772c2c78..0edf06ab 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -119,11 +119,11 @@ services:
   nv-ingest-ms-runtime:
     image: nvcr.io/ohlfw0olaadg/ea-participants/nv-ingest:24.08
     build:
-      context: ${NV_INGEST_ROOT}
+      context: ${NV_INGEST_ROOT:-.}
       dockerfile: "./Dockerfile"
       target: runtime
     volumes:
-      - ${DATASET_ROOT}:/workspace/data
+      - ${DATASET_ROOT:-./data}:/workspace/data
     ports:
      - "7670:7670"
    cap_add:
@@ -138,8 +138,8 @@ services:
      - DEPLOT_HEALTH_ENDPOINT=deplot:8000
      - DEPLOT_HTTP_ENDPOINT=http://deplot:8000/v1/chat/completions
      # build.nvidia.com hosted deplot
-      #- DEPLOT_HTTP_ENDPOINT=https://ai.api.nvidia.com/v1/vlm/google/deplot
      - DEPLOT_INFER_PROTOCOL=http
+      #- DEPLOT_HTTP_ENDPOINT=https://ai.api.nvidia.com/v1/vlm/google/deplot
      - DOUGHNUT_GRPC_TRITON=triton-doughnut:8001
      - INGEST_LOG_LEVEL=DEFAULT
      - MESSAGE_CLIENT_HOST=redis
diff --git a/requirements.txt b/requirements.txt
index 2c5bf10f..e0cd4296 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,33 +3,34 @@ charset-normalizer
 click
 dataclasses
 farm-haystack[ocr,inference,pdf,preprocessing,file-conversion]
+fastapi==0.109.1
 fastparquet==2024.2.0
 fsspec
+gunicorn==22.0.0
 minio~=7.2.5
 more_itertools
 nltk==3.9.1
-openai==1.40.6
+numpy
 olefile==0.47
 onnx==1.17.0
+openai==1.40.6
 opencv-python==4.10.0.84
 opentelemetry-api
 opentelemetry-exporter-otlp
 opentelemetry-instrumentation
-opentelemetry-instrumentation-fastapi
 opentelemetry-instrumentation-asgi
+opentelemetry-instrumentation-fastapi
 opentelemetry-sdk
 pandas~=1.5.3
 pydantic==1.10.14
 pyinstrument
 pypdfium2
 python-docx
+python-multipart
 python-pptx==0.6.23
 redis~=5.0.1
 setuptools==70.0.0
 tabulate
 torchvision==0.18.0
 unstructured-client==0.23.3
-fastapi==0.109.1
 uvicorn==0.24.0-post.1
-gunicorn==22.0.0
-python-multipart
diff --git a/src/nv_ingest/api/v1/ingest.py b/src/nv_ingest/api/v1/ingest.py
index 12b078ca..feec8b90 100644
--- a/src/nv_ingest/api/v1/ingest.py
+++ b/src/nv_ingest/api/v1/ingest.py
@@ -10,27 +10,26 @@
 
 # pylint: skip-file
 
+from io import BytesIO
+from typing import Annotated
 import base64
-import copy
 import json
-from io import BytesIO
 import logging
 import time
 import traceback
-from typing import Annotated
 
-from opentelemetry import trace
-from nv_ingest_client.primitives.jobs.job_spec import JobSpec
-from fastapi import File, UploadFile
 from fastapi import APIRouter
 from fastapi import Depends
+from fastapi import File, UploadFile
 from fastapi import HTTPException
-from nv_ingest_client.primitives.tasks.extract import ExtractTask
+from nv_ingest_client.primitives.jobs.job_spec import JobSpec
+from opentelemetry import trace
+from redis import RedisError
 
+from nv_ingest_client.primitives.tasks.extract import ExtractTask
 from nv_ingest.schemas.message_wrapper_schema import MessageWrapper
 from nv_ingest.service.impl.ingest.redis_ingest_service import RedisIngestService
 from nv_ingest.service.meta.ingest.ingest_service_meta import IngestServiceMeta
-from nv_ingest.schemas.ingest_job_schema import DocumentTypeEnum
 
 logger = logging.getLogger("uvicorn")
 tracer = trace.get_tracer(__name__)
@@ -128,13 +127,12 @@ async def submit_job(job_spec: MessageWrapper, ingest_service: INGEST_SERVICE_T)
         # will be able to trace across uvicorn -> morpheus
         current_trace_id = trace.get_current_span().get_span_context().trace_id
 
-        # Recreate the JobSpec to test what is going on ....
         job_spec_dict = json.loads(job_spec.payload)
         job_spec_dict['tracing_options']['trace_id'] = current_trace_id
         updated_job_spec = MessageWrapper(
             payload=json.dumps(job_spec_dict)
         )
-
+
         submitted_job_id = await ingest_service.submit_job(updated_job_spec)
         return submitted_job_id
     except Exception as ex:
diff --git a/src/nv_ingest/extraction_workflows/docx/docxreader.py b/src/nv_ingest/extraction_workflows/docx/docxreader.py
index 449aeb45..87f569cf 100644
--- a/src/nv_ingest/extraction_workflows/docx/docxreader.py
+++ b/src/nv_ingest/extraction_workflows/docx/docxreader.py
@@ -333,7 +333,7 @@ def _construct_image_metadata(self, image, para_idx, caption, base_unified_metad
 
     # For docx there is no bounding box. The paragraph that follows the image is typically
     # the caption. Add that para to the page nearby for now. fixme
-    bbox = (-1, -1, -1, -1)
+    bbox = (0, 0, 0, 0)
     page_nearby_blocks = {
         "text": {"content": [], "bbox": []},
         "images": {"content": [], "bbox": []},
diff --git a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
index 7ed52cd2..7a1de0f1 100644
--- a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
+++ b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
@@ -17,6 +17,8 @@
 # limitations under the License.
import logging +import traceback + from math import log from typing import List from typing import Optional @@ -30,17 +32,12 @@ from nv_ingest.schemas.metadata_schema import AccessLevelEnum from nv_ingest.schemas.metadata_schema import TextTypeEnum from nv_ingest.schemas.pdf_extractor_schema import PDFiumConfigSchema -from nv_ingest.util.image_processing.table_and_chart import join_cached_and_deplot_output from nv_ingest.util.image_processing.transforms import crop_image from nv_ingest.util.image_processing.transforms import numpy_to_base64 -from nv_ingest.util.nim.helpers import call_image_inference_model from nv_ingest.util.nim.helpers import create_inference_client -from nv_ingest.util.nim.helpers import get_version from nv_ingest.util.nim.helpers import perform_model_inference -from nv_ingest.util.nim.helpers import preprocess_image_for_paddle from nv_ingest.util.pdf.metadata_aggregators import Base64Image -from nv_ingest.util.pdf.metadata_aggregators import ImageChart -from nv_ingest.util.pdf.metadata_aggregators import ImageTable +from nv_ingest.util.pdf.metadata_aggregators import CroppedImageWithContent from nv_ingest.util.pdf.metadata_aggregators import construct_image_metadata from nv_ingest.util.pdf.metadata_aggregators import construct_table_and_chart_metadata from nv_ingest.util.pdf.metadata_aggregators import construct_text_metadata @@ -49,8 +46,6 @@ from nv_ingest.util.pdf.pdfium import pdfium_pages_to_numpy from nv_ingest.util.pdf.pdfium import pdfium_try_get_bitmap_as_numpy -PADDLE_MIN_WIDTH = 32 -PADDLE_MIN_HEIGHT = 32 YOLOX_MAX_BATCH_SIZE = 8 YOLOX_MAX_WIDTH = 1536 YOLOX_MAX_HEIGHT = 1536 @@ -64,18 +59,16 @@ def extract_tables_and_charts_using_image_ensemble( - pages: List[libpdfium.PdfPage], - config: PDFiumConfigSchema, - max_batch_size: int = YOLOX_MAX_BATCH_SIZE, - num_classes: int = YOLOX_NUM_CLASSES, - conf_thresh: float = YOLOX_CONF_THRESHOLD, - iou_thresh: float = YOLOX_IOU_THRESHOLD, - min_score: float = YOLOX_MIN_SCORE, - final_thresh: float = YOLOX_FINAL_SCORE, - extract_tables: bool = True, - extract_charts: bool = True, - trace_info: Optional[List] = None, -) -> List[Tuple[int, ImageTable]]: + pages: List[libpdfium.PdfPage], + config: PDFiumConfigSchema, + max_batch_size: int = YOLOX_MAX_BATCH_SIZE, + num_classes: int = YOLOX_NUM_CLASSES, + conf_thresh: float = YOLOX_CONF_THRESHOLD, + iou_thresh: float = YOLOX_IOU_THRESHOLD, + min_score: float = YOLOX_MIN_SCORE, + final_thresh: float = YOLOX_FINAL_SCORE, + trace_info: Optional[List] = None, +) -> List[Tuple[int, CroppedImageWithContent]]: """ Extract tables and charts from a series of document pages using an ensemble of image-based models. 
@@ -135,74 +128,49 @@ def extract_tables_and_charts_using_image_ensemble( """ tables_and_charts = [] - if not extract_tables and not extract_charts: - logger.debug("Nothing to do since both extract_tables and extract_charts are set to false.") - return tables_and_charts - - yolox_client = paddle_client = deplot_client = cached_client = None - paddle_version = None + yolox_client = None try: - yolox_client = create_inference_client(config.yolox_endpoints, config.auth_token, config.yolox_infer_protocol) - if extract_tables: - paddle_client = create_inference_client( - config.paddle_endpoints, config.auth_token, config.paddle_infer_protocol - ) - paddle_version = get_version(config.paddle_endpoints[1]) - if extract_charts: - cached_client = create_inference_client( - config.cached_endpoints, config.auth_token, config.cached_infer_protocol - ) - deplot_client = create_inference_client( - config.deplot_endpoints, config.auth_token, config.deplot_infer_protocol - ) + yolox_client = create_inference_client(config.yolox_endpoints, config.auth_token) batches = [] i = 0 while i < len(pages): batch_size = min(2 ** int(log(len(pages) - i, 2)), max_batch_size) - batches.append(pages[i : i + batch_size]) # noqa: E203 + batches.append(pages[i: i + batch_size]) # noqa: E203 i += batch_size - page_idx = 0 + page_index = 0 for batch in batches: original_images, _ = pdfium_pages_to_numpy( batch, scale_tuple=(YOLOX_MAX_WIDTH, YOLOX_MAX_HEIGHT), trace_info=trace_info ) + # original images is an implicitly indexed list of pages original_image_shapes = [image.shape for image in original_images] input_array = prepare_images_for_inference(original_images) output_array = perform_model_inference(yolox_client, "yolox", input_array, trace_info=trace_info) - results = process_inference_results( + + # Get back inference results + yolox_annotated_detections = process_inference_results( output_array, original_image_shapes, num_classes, conf_thresh, iou_thresh, min_score, final_thresh ) - for annotation_dict, original_image in zip(results, original_images): - handle_table_chart_extraction( + for annotation_dict, original_image in zip(yolox_annotated_detections, original_images): + extract_table_and_chart_images( annotation_dict, original_image, - page_idx, - paddle_client, - deplot_client, - cached_client, + page_index, tables_and_charts, - extract_tables=extract_tables, - extract_charts=extract_charts, - paddle_version=paddle_version, - trace_info=trace_info, ) - page_idx += 1 + page_index += 1 + except Exception as e: - logger.error(f"Error during table/chart extraction: {str(e)}") - raise + logger.error(f"Unhandled error during table/chart extraction: {str(e)}") + traceback.print_exc() + raise e finally: - if isinstance(paddle_client, grpcclient.InferenceServerClient): - paddle_client.close() - if isinstance(cached_client, grpcclient.InferenceServerClient): - cached_client.close() - if isinstance(deplot_client, grpcclient.InferenceServerClient): - deplot_client.close() if isinstance(yolox_client, grpcclient.InferenceServerClient): yolox_client.close() @@ -243,13 +211,13 @@ def prepare_images_for_inference(images: List[np.ndarray]) -> np.ndarray: def process_inference_results( - output_array: np.ndarray, - original_image_shapes: List[Tuple[int, int]], - num_classes: int, - conf_thresh: float, - iou_thresh: float, - min_score: float, - final_thresh: float, + output_array: np.ndarray, + original_image_shapes: List[Tuple[int, int]], + num_classes: int, + conf_thresh: float, + iou_thresh: float, + min_score: float, + 
final_thresh: float, ): """ Process the model output to generate detection results and expand bounding boxes. @@ -298,6 +266,7 @@ def process_inference_results( annotation_dicts = [yolox_utils.expand_chart_bboxes(annotation_dict) for annotation_dict in results] inference_results = [] + # Filter out bounding boxes below the final threshold for annotation_dict in annotation_dicts: new_dict = {} if "table" in annotation_dict: @@ -312,18 +281,11 @@ def process_inference_results( # Handle individual table/chart extraction and model inference -def handle_table_chart_extraction( - annotation_dict, - original_image, - page_idx, - paddle_client, - deplot_client, - cached_client, - tables_and_charts, - extract_tables=True, - extract_charts=True, - paddle_version=None, - trace_info=None, +def extract_table_and_chart_images( + annotation_dict, + original_image, + page_idx, + tables_and_charts, ): """ Handle the extraction of tables and charts from the inference results and run additional model inference. @@ -336,12 +298,6 @@ def handle_table_chart_extraction( The original image from which objects were detected. page_idx : int The index of the current page being processed. - paddle_client : grpcclient.InferenceServerClient - The gRPC client for the paddle model used to process tables. - deplot_client : grpcclient.InferenceServerClient - The gRPC client for the deplot model used to process charts. - cached_client : grpcclient.InferenceServerClient - The gRPC client for the cached model used to process charts. tables_and_charts : List[Tuple[int, ImageTable]] A list to which extracted tables and charts will be appended. @@ -356,8 +312,7 @@ def handle_table_chart_extraction( >>> annotation_dict = {"table": [], "chart": []} >>> original_image = np.random.rand(1536, 1536, 3) >>> tables_and_charts = [] - >>> handle_table_chart_extraction(annotation_dict, original_image, 0, paddle_client, deplot_client, cached_client, - tables_and_charts) + >>> extract_table_and_chart_images(annotation_dict, original_image, 0, tables_and_charts) """ width, height, *_ = original_image.shape @@ -370,54 +325,26 @@ def handle_table_chart_extraction( *bbox, _ = bboxes h1, w1, h2, w2 = bbox * np.array([height, width, height, width]) - if extract_tables and label == "table": - # PaddleOCR NIM enforces minimum dimensions for TRT engines. 
- cropped = crop_image( - original_image, - (h1, w1, h2, w2), - min_width=PADDLE_MIN_WIDTH, - min_height=PADDLE_MIN_HEIGHT, - ) - if cropped is None: - continue - - base64_img = numpy_to_base64(cropped) - - if isinstance(paddle_client, grpcclient.InferenceServerClient): - cropped = preprocess_image_for_paddle(cropped, paddle_version=paddle_version) - - table_content = call_image_inference_model(paddle_client, "paddle", cropped, trace_info=trace_info) - table_data = ImageTable( - content=table_content, image=base64_img, bbox=(w1, h1, w2, h2), max_width=width, max_height=height - ) - tables_and_charts.append((page_idx, table_data)) + cropped = crop_image(original_image, (h1, w1, h2, w2)) + base64_img = numpy_to_base64(cropped) - elif extract_charts and label == "chart": - cropped = crop_image(original_image, (h1, w1, h2, w2)) - if cropped is None: - continue - - base64_img = numpy_to_base64(cropped) - - deplot_result = call_image_inference_model(deplot_client, "deplot", cropped, trace_info=trace_info) - cached_result = call_image_inference_model(cached_client, "cached", cropped, trace_info=trace_info) - chart_content = join_cached_and_deplot_output(cached_result, deplot_result) - chart_data = ImageChart( - content=chart_content, image=base64_img, bbox=(w1, h1, w2, h2), max_width=width, max_height=height - ) - tables_and_charts.append((page_idx, chart_data)) + table_data = CroppedImageWithContent( + content="", image=base64_img, bbox=(w1, h1, w2, h2), max_width=width, + max_height=height, type_string=label + ) + tables_and_charts.append((page_idx, table_data)) # Define a helper function to use unstructured-io to extract text from a base64 -# encoded bytestram PDF +# encoded bytestream PDF def pdfium( - pdf_stream, - extract_text: bool, - extract_images: bool, - extract_tables: bool, - extract_charts: bool, - trace_info=None, - **kwargs, + pdf_stream, + extract_text: bool, + extract_images: bool, + extract_tables: bool, + extract_charts: bool, + trace_info=None, + **kwargs, ): """ Helper function to use pdfium to extract text from a bytestream PDF. @@ -526,12 +453,13 @@ def pdfium( if obj_type == "IMAGE": try: # Attempt to retrieve the image bitmap - image_numpy: np.ndarray = pdfium_try_get_bitmap_as_numpy(obj) + image_numpy: np.ndarray = pdfium_try_get_bitmap_as_numpy(obj) # noqa image_base64: str = numpy_to_base64(image_numpy) image_bbox = obj.get_pos() image_size = obj.get_size() image_data = Base64Image( - image=image_base64, bbox=image_bbox, width=image_size[0], height=image_size[1], max_width=page_width, max_height=page_height + image=image_base64, bbox=image_bbox, width=image_size[0], height=image_size[1], + max_width=page_width, max_height=page_height ) extracted_image_data = construct_image_metadata( @@ -544,7 +472,7 @@ def pdfium( extracted_data.append(extracted_image_data) except Exception as e: - logger.error(f"Error extracting image: {e}") + logger.error(f"Unhandled error extracting image: {e}") pass # Pdfium failed to extract the image associated with this object - corrupt or missing. 
 # Table and chart collection
@@ -569,11 +497,9 @@
 
     if extract_tables or extract_charts:
         for page_idx, table_and_charts in extract_tables_and_charts_using_image_ensemble(
-            pages,
-            pdfium_config,
-            extract_tables=extract_tables,
-            extract_charts=extract_charts,
-            trace_info=trace_info,
+                pages,
+                pdfium_config,
+                trace_info=trace_info,
         ):
             extracted_data.append(
                 construct_table_and_chart_metadata(
diff --git a/src/nv_ingest/extraction_workflows/pdf/yolox_utils.py b/src/nv_ingest/extraction_workflows/pdf/yolox_utils.py
index dfc0c632..e8458ea0 100644
--- a/src/nv_ingest/extraction_workflows/pdf/yolox_utils.py
+++ b/src/nv_ingest/extraction_workflows/pdf/yolox_utils.py
@@ -113,6 +113,7 @@ def postprocess_results(results, original_image_shapes, min_score=0.0):
 
         out.append(annotation_dict)
 
+    # {label: [[x1, y1, x2, y2, confidence], ...], ...}
     return out
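Note: for reference, the structure documented by the new comment in `postprocess_results` — one dict per page, keyed by label, with normalized corner coordinates plus a confidence score. The values below are made up:

```python
# Hypothetical output entry from postprocess_results for a single page.
annotation_dict = {
    "table": [[0.12, 0.30, 0.55, 0.48, 0.91]],  # [x1, y1, x2, y2, confidence]
    "chart": [[0.58, 0.10, 0.97, 0.42, 0.87]],
}
```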
""" + if logger.isEnabledFor(logging.DEBUG): + no_payload = copy.deepcopy(job) + no_payload["job_payload"]["content"] = ["[...]"] # Redact the payload for logging + logger.debug("Job: %s", json.dumps(no_payload, indent=2)) + validate_ingest_job(job) - # no_payload = copy.deepcopy(job) - # no_payload["job_payload"]["content"] = ["[...]"] # Redact the payload for logging - # logger.debug("Job: %s", json.dumps(no_payload, indent=2)) control_message = ControlMessage() try: diff --git a/src/nv_ingest/schemas/chart_extractor_schema.py b/src/nv_ingest/schemas/chart_extractor_schema.py new file mode 100644 index 00000000..ce6dd0cc --- /dev/null +++ b/src/nv_ingest/schemas/chart_extractor_schema.py @@ -0,0 +1,140 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import Optional, Tuple + +from pydantic import BaseModel, root_validator, validator + +logger = logging.getLogger(__name__) + + +class ChartExtractorConfigSchema(BaseModel): + """ + Configuration schema for chart extraction service endpoints and options. + + Parameters + ---------- + auth_token : Optional[str], default=None + Authentication token required for secure services. + + cached_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None) + A tuple containing the gRPC and HTTP services for the cached endpoint. + Either the gRPC or HTTP service can be empty, but not both. + + deplot_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None) + A tuple containing the gRPC and HTTP services for the deplot endpoint. + Either the gRPC or HTTP service can be empty, but not both. + + paddle_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None) + A tuple containing the gRPC and HTTP services for the paddle endpoint. + Either the gRPC or HTTP service can be empty, but not both. + + Methods + ------- + validate_endpoints(values) + Validates that at least one of the gRPC or HTTP services is provided for each endpoint. + + Raises + ------ + ValueError + If both gRPC and HTTP services are empty for any endpoint. + + Config + ------ + extra : str + Pydantic config option to forbid extra fields. + """ + + auth_token: Optional[str] = None + + cached_endpoints: Tuple[Optional[str], Optional[str]] = (None, None) + cached_infer_protocol: str = "" + + deplot_endpoints: Tuple[Optional[str], Optional[str]] = (None, None) + deplot_infer_protocol: str = "" + + ## NOTE: Paddle isn't currently called independently of the cached NIM, but will be in the future. + paddle_endpoints: Tuple[Optional[str], Optional[str]] = (None, None) + paddle_infer_protocol: str = "" + + @root_validator(pre=True) + def validate_endpoints(cls, values): + """ + Validates the gRPC and HTTP services for all endpoints. + + Ensures that at least one service (either gRPC or HTTP) is provided + for each endpoint in the configuration. + + Parameters + ---------- + values : dict + Dictionary containing the values of the attributes for the class. + + Returns + ------- + dict + The validated dictionary of values. + + Raises + ------ + ValueError + If both gRPC and HTTP services are empty for any endpoint. 
+ """ + + def clean_service(service): + """Set service to None if it's an empty string or contains only spaces or quotes.""" + if service is None or not service.strip() or service.strip(" \"'") == "": + return None + return service + + for endpoint_name in ["cached_endpoints", "deplot_endpoints", "paddle_endpoints"]: + grpc_service, http_service = values.get(endpoint_name, (None, None)) + grpc_service = clean_service(grpc_service) + http_service = clean_service(http_service) + + if not grpc_service and not http_service: + raise ValueError(f"Both gRPC and HTTP services cannot be empty for {endpoint_name}.") + + values[endpoint_name] = (grpc_service, http_service) + + return values + + class Config: + extra = "forbid" + + +class ChartExtractorSchema(BaseModel): + """ + Configuration schema for chart extraction processing settings. + + Parameters + ---------- + max_queue_size : int, default=1 + The maximum number of items allowed in the processing queue. + + n_workers : int, default=2 + The number of worker threads to use for processing. + + raise_on_failure : bool, default=False + A flag indicating whether to raise an exception if a failure occurs during chart extraction. + + stage_config : Optional[ChartExtractorConfigSchema], default=None + Configuration for the chart extraction stage, including cached, deplot, and paddle service endpoints. + """ + + max_queue_size: int = 1 + n_workers: int = 2 + raise_on_failure: bool = False + + stage_config: Optional[ChartExtractorConfigSchema] = None + + @validator('max_queue_size', 'n_workers', pre=True, always=True) + def check_positive(cls, v, field): + if v <= 0: + raise ValueError(f"{field.name} must be greater than 10.") + return v + + class Config: + extra = "forbid" diff --git a/src/nv_ingest/schemas/ingest_job_schema.py b/src/nv_ingest/schemas/ingest_job_schema.py index c151c1fb..8b2dc7ef 100644 --- a/src/nv_ingest/schemas/ingest_job_schema.py +++ b/src/nv_ingest/schemas/ingest_job_schema.py @@ -44,6 +44,8 @@ class TaskTypeEnum(str, Enum): split = "split" store = "store" vdb_upload = "vdb_upload" + table_data_extract = "table_data_extract" + chart_data_extract = "chart_data_extract" class FilterTypeEnum(str, Enum): @@ -128,6 +130,12 @@ class IngestTaskVdbUploadSchema(BaseModelNoExt): filter_errors: bool = True +class IngestTaskTableExtraction(BaseModelNoExt): + params: Dict = {} + +class IngestChartTableExtraction(BaseModelNoExt): + params: Dict = {} + class IngestTaskSchema(BaseModelNoExt): type: TaskTypeEnum task_properties: Union[ @@ -139,6 +147,8 @@ class IngestTaskSchema(BaseModelNoExt): IngestTaskDedupSchema, IngestTaskFilterSchema, IngestTaskVdbUploadSchema, + IngestTaskTableExtraction, + IngestChartTableExtraction ] raise_on_failure: bool = False @@ -155,6 +165,8 @@ def check_task_properties_type(cls, values): TaskTypeEnum.split: IngestTaskSplitSchema, TaskTypeEnum.store: IngestTaskStoreSchema, TaskTypeEnum.vdb_upload: IngestTaskVdbUploadSchema, + TaskTypeEnum.table_data_extract: IngestTaskTableExtraction, + TaskTypeEnum.chart_data_extract: IngestChartTableExtraction, }.get(task_type.lower()) # logger.debug(f"Checking task_properties type for task type '{task_type}'") diff --git a/src/nv_ingest/schemas/metadata_schema.py b/src/nv_ingest/schemas/metadata_schema.py index 65cb45d6..6a51c83d 100644 --- a/src/nv_ingest/schemas/metadata_schema.py +++ b/src/nv_ingest/schemas/metadata_schema.py @@ -4,6 +4,7 @@ from datetime import datetime +import logging from enum import Enum from typing import Any from typing import Dict @@ -17,6 +18,8 @@ 
diff --git a/src/nv_ingest/schemas/metadata_schema.py b/src/nv_ingest/schemas/metadata_schema.py
index 65cb45d6..6a51c83d 100644
--- a/src/nv_ingest/schemas/metadata_schema.py
+++ b/src/nv_ingest/schemas/metadata_schema.py
@@ -4,6 +4,7 @@
 
 from datetime import datetime
+import logging
 from enum import Enum
 from typing import Any
 from typing import Dict
@@ -17,6 +18,8 @@
 from nv_ingest.schemas.base_model_noext import BaseModelNoExt
 from nv_ingest.util.converters import datetools
 
+logger = logging.getLogger(__name__)
+
 
 # Do we want types and similar items to be enums or just strings?
 class SourceTypeEnum(str, Enum):
@@ -246,6 +249,8 @@ class TextMetadataSchema(BaseModelNoExt):
     text_location: tuple = (0, 0, 0, 0)
 
 
+from pydantic import validator
+
 class ImageMetadataSchema(BaseModelNoExt):
     image_type: Union[ImageTypeEnum, str]
     structured_image_type: ImageTypeEnum = ImageTypeEnum.image_type_1
@@ -257,6 +262,19 @@ class ImageMetadataSchema(BaseModelNoExt):
     width: int = 0
     height: int = 0
 
+    @validator("image_type", pre=True, always=True)
+    def validate_image_type(cls, v):
+        if not isinstance(v, (ImageTypeEnum, str)):
+            raise ValueError("image_type must be a string or ImageTypeEnum")
+        return v
+
+    @validator("width", "height", pre=True, always=True)
+    def clamp_non_negative(cls, v, field):
+        if v < 0:
+            logger.warning(f"{field.name} is negative; clamping to 0. Original value: {v}")
+            return 0
+        return v
+
 
 class TableMetadataSchema(BaseModelNoExt):
     caption: str = ""
@@ -267,6 +285,15 @@ class TableMetadataSchema(BaseModelNoExt):
     uploaded_image_uri: str = ""
 
 
+class ChartMetadataSchema(BaseModelNoExt):
+    caption: str = ""
+    table_format: TableFormatEnum
+    table_content: str = ""
+    table_location: tuple = (0, 0, 0, 0)
+    table_location_max_dimensions: tuple = (0, 0)
+    uploaded_image_uri: str = ""
+
+
 # TODO consider deprecating this in favor of info msg...
 class ErrorMetadataSchema(BaseModelNoExt):
     task: TaskTypeEnum
@@ -291,6 +318,7 @@ class MetadataSchema(BaseModelNoExt):
     text_metadata: Optional[TextMetadataSchema] = None
     image_metadata: Optional[ImageMetadataSchema] = None
     table_metadata: Optional[TableMetadataSchema] = None
+    chart_metadata: Optional[ChartMetadataSchema] = None
     error_metadata: Optional[ErrorMetadataSchema] = None
     info_message_metadata: Optional[InfoMessageMetadataSchema] = None
     debug_metadata: Optional[Dict[str, Any]] = None
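Note: the new validators reject non-string image types but clamp negative dimensions instead of failing. A sketch, assuming the schema's remaining fields keep their defaults:

```python
from nv_ingest.schemas.metadata_schema import ImageMetadataSchema

# A negative width is clamped to 0 with a warning rather than raising.
meta = ImageMetadataSchema(image_type="image", width=-5, height=10)
assert (meta.width, meta.height) == (0, 10)
```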
diff --git a/src/nv_ingest/schemas/pdf_extractor_schema.py b/src/nv_ingest/schemas/pdf_extractor_schema.py
index 85a1716b..9d0f538a 100644
--- a/src/nv_ingest/schemas/pdf_extractor_schema.py
+++ b/src/nv_ingest/schemas/pdf_extractor_schema.py
@@ -22,25 +22,10 @@ class PDFiumConfigSchema(BaseModel):
     auth_token : Optional[str], default=None
         Authentication token required for secure services.
 
-    cached_endpoints : Tuple[str, str]
-        A tuple containing the gRPC and HTTP services for the cached endpoint.
-        Either the gRPC or HTTP service can be empty, but not both.
-
-    deplot_endpoints : Tuple[str, str]
-        A tuple containing the gRPC and HTTP services for the deplot endpoint.
-        Either the gRPC or HTTP service can be empty, but not both.
-
-    paddle_endpoints : Tuple[str, str]
-        A tuple containing the gRPC and HTTP services for the paddle endpoint.
-        Either the gRPC or HTTP service can be empty, but not both.
-
     yolox_endpoints : Tuple[str, str]
         A tuple containing the gRPC and HTTP services for the yolox endpoint.
         Either the gRPC or HTTP service can be empty, but not both.
 
-    identify_nearby_objects : bool, default=False
-        A flag indicating whether to identify nearby objects during processing.
-
     Methods
     -------
     validate_endpoints(values)
@@ -59,18 +44,9 @@ class PDFiumConfigSchema(BaseModel):
 
     auth_token: Optional[str] = None
 
-    cached_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
-    deplot_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
-    paddle_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
     yolox_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
-
-    cached_infer_protocol: str = ""
-    deplot_infer_protocol: str = ""
-    paddle_infer_protocol: str = ""
     yolox_infer_protocol: str = ""
 
-    identify_nearby_objects: bool = False
-
     @root_validator(pre=True)
     def validate_endpoints(cls, values):
         """
@@ -98,7 +74,7 @@ def clean_service(service):
                 return None
             return service
 
-        for model_name in ["cached", "deplot", "paddle", "yolox"]:
+        for model_name in ["yolox"]:
             endpoint_name = f"{model_name}_endpoints"
             grpc_service, http_service = values.get(endpoint_name)
             grpc_service = clean_service(grpc_service)
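Note: with the table/chart endpoints removed, pdfium's config now validates only the yolox service pair. A minimal sketch (endpoint values are hypothetical):

```python
from nv_ingest.schemas.pdf_extractor_schema import PDFiumConfigSchema

# Hypothetical endpoints: gRPC set, HTTP empty — the validator accepts either one.
config = PDFiumConfigSchema(
    yolox_endpoints=("yolox:8001", None),
    yolox_infer_protocol="grpc",
)
```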
+ """ + + def clean_service(service): + """Set service to None if it's an empty string or contains only spaces or quotes.""" + if service is None or not service.strip() or service.strip(" \"'") == "": + return None + return service + + grpc_service, http_service = values.get("paddle_endpoints", (None, None)) + grpc_service = clean_service(grpc_service) + http_service = clean_service(http_service) + + if not grpc_service and not http_service: + raise ValueError("Both gRPC and HTTP services cannot be empty for paddle_endpoints.") + + values["paddle_endpoints"] = (grpc_service, http_service) + + return values + + class Config: + extra = "forbid" + + +class TableExtractorSchema(BaseModel): + """ + Configuration schema for the table extraction processing settings. + + Parameters + ---------- + max_queue_size : int, default=1 + The maximum number of items allowed in the processing queue. + + n_workers : int, default=2 + The number of worker threads to use for processing. + + raise_on_failure : bool, default=False + A flag indicating whether to raise an exception if a failure occurs during table extraction. + + stage_config : Optional[TableExtractorConfigSchema], default=None + Configuration for the table extraction stage, including yolox service endpoints. + """ + + max_queue_size: int = 1 + n_workers: int = 2 + raise_on_failure: bool = False + + @validator('max_queue_size', 'n_workers', pre=True, always=True) + def check_positive(cls, v, field): + if v <= 0: + raise ValueError(f"{field.name} must be greater than 10.") + return v + + stage_config: Optional[TableExtractorConfigSchema] = None + + class Config: + extra = "forbid" diff --git a/src/nv_ingest/stages/nim/__init__.py b/src/nv_ingest/stages/nim/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/nv_ingest/stages/nim/chart_extraction.py b/src/nv_ingest/stages/nim/chart_extraction.py new file mode 100644 index 00000000..46339228 --- /dev/null +++ b/src/nv_ingest/stages/nim/chart_extraction.py @@ -0,0 +1,191 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import logging +import functools +import pandas as pd +from typing import Any +from typing import Dict +from typing import Optional +from typing import Tuple + +import tritonclient.grpc as grpcclient +from morpheus.config import Config + +from nv_ingest.schemas.chart_extractor_schema import ChartExtractorSchema +from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage +from nv_ingest.util.image_processing.table_and_chart import join_cached_and_deplot_output +from nv_ingest.util.image_processing.transforms import base64_to_numpy +from nv_ingest.util.nim.helpers import call_image_inference_model, create_inference_client + +logger = logging.getLogger(f"morpheus.{__name__}") + + +def _update_metadata(row: pd.Series, cached_client: Any, deplot_client: Any, trace_info: Dict) -> Dict: + """ + Modifies the metadata of a row if the conditions for chart extraction are met. + + Parameters + ---------- + row : pd.Series + A row from the DataFrame containing metadata for the chart extraction. + + cached_client : Any + The client used to call the cached inference model. + + deplot_client : Any + The client used to call the deplot inference model. + + trace_info : Dict + Trace information used for logging or debugging. + + Returns + ------- + Dict + The modified metadata if conditions are met, otherwise the original metadata. 
diff --git a/src/nv_ingest/stages/nim/__init__.py b/src/nv_ingest/stages/nim/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/nv_ingest/stages/nim/chart_extraction.py b/src/nv_ingest/stages/nim/chart_extraction.py
new file mode 100644
index 00000000..46339228
--- /dev/null
+++ b/src/nv_ingest/stages/nim/chart_extraction.py
@@ -0,0 +1,191 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import functools
+import pandas as pd
+from typing import Any
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+
+import tritonclient.grpc as grpcclient
+from morpheus.config import Config
+
+from nv_ingest.schemas.chart_extractor_schema import ChartExtractorSchema
+from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage
+from nv_ingest.util.image_processing.table_and_chart import join_cached_and_deplot_output
+from nv_ingest.util.image_processing.transforms import base64_to_numpy
+from nv_ingest.util.nim.helpers import call_image_inference_model, create_inference_client
+
+logger = logging.getLogger(f"morpheus.{__name__}")
+
+
+def _update_metadata(row: pd.Series, cached_client: Any, deplot_client: Any, trace_info: Dict) -> Dict:
+    """
+    Modifies the metadata of a row if the conditions for chart extraction are met.
+
+    Parameters
+    ----------
+    row : pd.Series
+        A row from the DataFrame containing metadata for the chart extraction.
+
+    cached_client : Any
+        The client used to call the cached inference model.
+
+    deplot_client : Any
+        The client used to call the deplot inference model.
+
+    trace_info : Dict
+        Trace information used for logging or debugging.
+
+    Returns
+    -------
+    Dict
+        The modified metadata if conditions are met, otherwise the original metadata.
+
+    Raises
+    ------
+    ValueError
+        If critical information (such as metadata) is missing from the row.
+    """
+    metadata = row.get("metadata")
+    if metadata is None:
+        logger.error("Row does not contain 'metadata'.")
+        raise ValueError("Row does not contain 'metadata'.")
+
+    base64_image = metadata.get("content")
+    content_metadata = metadata.get("content_metadata", {})
+    chart_metadata = metadata.get("table_metadata")
+
+    # Only modify if content type is structured and subtype is 'chart' and chart_metadata exists
+    if (content_metadata.get("type") != "structured" or
+            content_metadata.get("subtype") != "chart" or
+            chart_metadata is None):
+        return metadata
+
+    # Modify chart metadata with the result from the inference model
+    try:
+        image_array = base64_to_numpy(base64_image)
+
+        deplot_result = call_image_inference_model(deplot_client, "deplot", image_array, trace_info=trace_info)
+        cached_result = call_image_inference_model(cached_client, "cached", image_array, trace_info=trace_info)
+        chart_content = join_cached_and_deplot_output(cached_result, deplot_result)
+
+        chart_metadata["table_content"] = chart_content
+    except Exception as e:
+        logger.error(f"Unhandled error calling image inference model: {e}", exc_info=True)
+        raise
+
+    return metadata
+
+
+def _extract_chart_data(df: pd.DataFrame, task_props: Dict[str, Any],
+                        validated_config: Any, trace_info: Optional[Dict] = None) -> Tuple[pd.DataFrame, Dict]:
+    """
+    Extracts chart data from a DataFrame.
+
+    Parameters
+    ----------
+    df : pd.DataFrame
+        DataFrame containing the content from which chart data is to be extracted.
+
+    task_props : Dict[str, Any]
+        Dictionary containing task properties and configurations.
+
+    validated_config : Any
+        The validated configuration object for chart extraction.
+
+    trace_info : Optional[Dict], optional
+        Optional trace information for debugging or logging. Defaults to None.
+
+    Returns
+    -------
+    Tuple[pd.DataFrame, Dict]
+        A tuple containing the updated DataFrame and the trace information.
+
+    Raises
+    ------
+    Exception
+        If any error occurs during the chart data extraction process.
+    """
+
+    _ = task_props  # unused
+
+    deplot_client = create_inference_client(
+        validated_config.stage_config.deplot_endpoints,
+        validated_config.stage_config.auth_token,
+        validated_config.stage_config.deplot_infer_protocol
+    )
+
+    cached_client = create_inference_client(
+        validated_config.stage_config.cached_endpoints,
+        validated_config.stage_config.auth_token,
+        validated_config.stage_config.cached_infer_protocol
+    )
+
+    if trace_info is None:
+        trace_info = {}
+        logger.debug("No trace_info provided. Initialized empty trace_info dictionary.")
+
+    try:
+        # Apply the _update_metadata function to each row in the DataFrame
+        df["metadata"] = df.apply(_update_metadata, axis=1, args=(cached_client, deplot_client, trace_info))
+
+        return df, trace_info
+
+    except Exception:
+        logger.error("Error occurred while extracting chart data.", exc_info=True)
+        raise
+    finally:
+        if isinstance(cached_client, grpcclient.InferenceServerClient):
+            cached_client.close()
+        if isinstance(deplot_client, grpcclient.InferenceServerClient):
+            deplot_client.close()
+
+
+def generate_chart_extractor_stage(
+    c: Config,
+    stage_config: Dict[str, Any],
+    task: str = "chart_data_extract",
+    task_desc: str = "chart_data_extraction",
+    pe_count: int = 1,
+):
+    """
+    Generates a multiprocessing stage to perform chart data extraction from PDF content.
+ + Parameters + ---------- + c : Config + Morpheus global configuration object. + + stage_config : Dict[str, Any] + Configuration parameters for the chart content extractor, passed as a dictionary + validated against the `ChartExtractorSchema`. + + task : str, optional + The task name for the stage worker function, defining the specific chart extraction process. + Default is "chart_data_extract". + + task_desc : str, optional + A descriptor used for latency tracing and logging during chart extraction. + Default is "chart_data_extraction". + + pe_count : int, optional + The number of process engines to use for chart data extraction. This value controls + how many worker processes will run concurrently. Default is 1. + + Returns + ------- + MultiProcessingBaseStage + A configured Morpheus stage with an applied worker function that handles chart data extraction + from PDF content. + """ + + validated_config = ChartExtractorSchema(**stage_config) + _wrapped_process_fn = functools.partial(_extract_chart_data, validated_config=validated_config) + + return MultiProcessingBaseStage( + c=c, pe_count=pe_count, task=task, task_desc=task_desc, process_fn=_wrapped_process_fn + ) diff --git a/src/nv_ingest/stages/nim/table_extraction.py b/src/nv_ingest/stages/nim/table_extraction.py new file mode 100644 index 00000000..9154e11d --- /dev/null +++ b/src/nv_ingest/stages/nim/table_extraction.py @@ -0,0 +1,182 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import logging +import functools +import pandas as pd +from typing import Any +from typing import Dict +from typing import Optional +from typing import Tuple + +import tritonclient.grpc as grpcclient +from morpheus.config import Config +from nv_ingest.schemas.table_extractor_schema import TableExtractorSchema +from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage +from nv_ingest.util.nim.helpers import call_image_inference_model, create_inference_client +from nv_ingest.util.image_processing.transforms import base64_to_numpy +from nv_ingest.util.image_processing.transforms import check_numpy_image_size + +logger = logging.getLogger(f"morpheus.{__name__}") + +PADDLE_MIN_WIDTH = 32 +PADDLE_MIN_HEIGHT = 32 + + +def _update_metadata(row: pd.Series, paddle_client: Any, trace_info: Dict) -> Dict: + """ + Modifies the metadata of a row if the conditions for table extraction are met. + + Parameters + ---------- + row : pd.Series + A row from the DataFrame containing metadata for the table extraction. + + paddle_client : Any + The client used to call the image inference model. + + trace_info : Dict + Trace information used for logging or debugging. + + Returns + ------- + Dict + The modified metadata if conditions are met, otherwise the original metadata. + + Raises + ------ + ValueError + If critical information (such as metadata) is missing from the row. 
+ """ + + metadata = row.get("metadata") + if metadata is None: + logger.error("Row does not contain 'metadata'.") + raise ValueError("Row does not contain 'metadata'.") + + base64_image = metadata.get("content") + content_metadata = metadata.get("content_metadata", {}) + table_metadata = metadata.get("table_metadata") + + # Only modify if content type is structured and subtype is 'table' and table_metadata exists + if ((content_metadata.get("type") != "structured") or + (content_metadata.get("subtype") != "table") or + (table_metadata is None)): + return metadata + + # Modify table metadata with the result from the inference model + try: + image_array = base64_to_numpy(base64_image) + paddle_result = "" + if check_numpy_image_size(image_array, PADDLE_MIN_WIDTH, PADDLE_MIN_HEIGHT): + paddle_result = call_image_inference_model(paddle_client, "paddle", image_array, trace_info=trace_info) + + table_metadata["table_content"] = paddle_result + except Exception as e: + logger.error(f"Unhandled error calling image inference model: {e}", exc_info=True) + raise + + return metadata + + +def _extract_table_data(df: pd.DataFrame, task_props: Dict[str, Any], + validated_config: Any, trace_info: Optional[Dict] = None) -> Tuple[pd.DataFrame, Dict]: + """ + Extracts table data from a DataFrame. + + Parameters + ---------- + df : pd.DataFrame + DataFrame containing the content from which table data is to be extracted. + + task_props : Dict[str, Any] + Dictionary containing task properties and configurations. + + validated_config : Any + The validated configuration object for table extraction. + + trace_info : Optional[Dict], optional + Optional trace information for debugging or logging. Defaults to None. + + Returns + ------- + Tuple[pd.DataFrame, Dict] + A tuple containing the updated DataFrame and the trace information. + + Raises + ------ + Exception + If any error occurs during the table data extraction process. + """ + + _ = task_props # unused + + paddle_client = create_inference_client( + validated_config.stage_config.paddle_endpoints, + validated_config.stage_config.auth_token, + validated_config.stage_config.paddle_infer_protocol + ) + + if trace_info is None: + trace_info = {} + logger.debug("No trace_info provided. Initialized empty trace_info dictionary.") + + try: + # Apply the _update_metadata function to each row in the DataFrame + df["metadata"] = df.apply(_update_metadata, axis=1, args=(paddle_client, trace_info)) + + return df, trace_info + + except Exception as e: + logger.error("Error occurred while extracting table data.", exc_info=True) + raise + finally: + if (isinstance(paddle_client, grpcclient.InferenceServerClient)): + paddle_client.close() + + +def generate_table_extractor_stage( + c: Config, + stage_config: Dict[str, Any], + task: str = "table_data_extract", + task_desc: str = "table_data_extraction", + pe_count: int = 1, +): + """ + Generates a multiprocessing stage to perform table data extraction from PDF content. + + Parameters + ---------- + c : Config + Morpheus global configuration object. + + stage_config : Dict[str, Any] + Configuration parameters for the table content extractor, passed as a dictionary + validated against the `TableExtractorSchema`. + + task : str, optional + The task name for the stage worker function, defining the specific table extraction process. + Default is "table_data_extract". + + task_desc : str, optional + A descriptor used for latency tracing and logging during table extraction. + Default is "table_data_extraction". 
+ + pe_count : int, optional + The number of process engines to use for table data extraction. This value controls + how many worker processes will run concurrently. Default is 1. + + Returns + ------- + MultiProcessingBaseStage + A configured Morpheus stage with an applied worker function that handles table data extraction + from PDF content. + """ + + validated_config = TableExtractorSchema(**stage_config) + _wrapped_process_fn = functools.partial(_extract_table_data, validated_config=validated_config) + + return MultiProcessingBaseStage( + c=c, pe_count=pe_count, task=task, task_desc=task_desc, process_fn=_wrapped_process_fn + ) diff --git a/src/nv_ingest/stages/pdf_extractor_stage.py b/src/nv_ingest/stages/pdf_extractor_stage.py index 072fce49..c5aaef85 100644 --- a/src/nv_ingest/stages/pdf_extractor_stage.py +++ b/src/nv_ingest/stages/pdf_extractor_stage.py @@ -61,7 +61,7 @@ def decode_and_extract( try: base64_content = base64_row["content"] except KeyError: - log_error_message = f"NO CONTENT FOUND IN ROW:\n{base64_row}" + log_error_message = f"Unhandled error processing row, no content was found:\n{base64_row}" logger.error(log_error_message) raise diff --git a/src/nv_ingest/util/image_processing/transforms.py b/src/nv_ingest/util/image_processing/transforms.py index f5e1dc11..d441db72 100644 --- a/src/nv_ingest/util/image_processing/transforms.py +++ b/src/nv_ingest/util/image_processing/transforms.py @@ -2,6 +2,7 @@ # All rights reserved. # SPDX-License-Identifier: Apache-2.0 +import base64 from io import BytesIO from math import ceil from math import floor @@ -10,6 +11,7 @@ import numpy as np from PIL import Image +from PIL import UnidentifiedImageError from nv_ingest.util.converters import bytetools @@ -77,6 +79,24 @@ def pad_image( return canvas, (pad_width, pad_height) +def check_numpy_image_size(image: np.ndarray, min_height: int, min_width: int) -> bool: + """ + Checks if the height and width of the image are larger than the specified minimum values. + + Parameters: + image (np.ndarray): The image array (assumed to be in shape (H, W, C) or (H, W)). + min_height (int): The minimum height required. + min_width (int): The minimum width required. + + Returns: + bool: True if the image dimensions are larger than or equal to the minimum size, False otherwise. + """ + # Check if the image has at least 2 dimensions + if image.ndim < 2: + raise ValueError("The input array does not have sufficient dimensions for an image.") + + height, width = image.shape[:2] + return height >= min_height and width >= min_width def crop_image( array: np.array, bbox: Tuple[int, int, int, int], min_width: int = 1, min_height: int = 1 @@ -232,3 +252,51 @@ def numpy_to_base64(array: np.ndarray) -> str: raise RuntimeError(f"Failed to encode image to base64: {e}") return base64_img + + +def base64_to_numpy(base64_string: str) -> np.ndarray: + """ + Convert a base64-encoded image string to a NumPy array. + + Parameters + ---------- + base64_string : str + Base64-encoded string representing an image. + + Returns + ------- + numpy.ndarray + NumPy array representation of the decoded image. + + Raises + ------ + ValueError + If the base64 string is invalid or cannot be decoded into an image. + ImportError + If required libraries are not installed. + + Examples + -------- + >>> base64_str = '/9j/4AAQSkZJRgABAQAAAQABAAD/2wBD...' 
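+    >>> # The literal above is truncated for display; any complete base64-encoded
+    >>> # image string works. One illustrative way to produce one is this module's
+    >>> # own helper: base64_str = numpy_to_base64(np.zeros((32, 32, 3), dtype=np.uint8))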
+    >>> img_array = base64_to_numpy(base64_str)
+    """
+    try:
+        # Decode the base64 string
+        image_data = base64.b64decode(base64_string)
+    except ValueError as e:
+        # binascii.Error (raised on malformed input) is a subclass of ValueError
+        raise ValueError("Invalid base64 string") from e
+
+    try:
+        # Convert the bytes into a BytesIO object
+        image_bytes = BytesIO(image_data)
+
+        # Open the image using PIL
+        image = Image.open(image_bytes)
+        image.load()
+    except UnidentifiedImageError as e:
+        raise ValueError("Unable to decode image from base64 string") from e
+
+    # Convert the image to a NumPy array
+    image_array = np.array(image)
+
+    return image_array
diff --git a/src/nv_ingest/util/nim/helpers.py b/src/nv_ingest/util/nim/helpers.py
index 7989bf7b..99c4a506 100644
--- a/src/nv_ingest/util/nim/helpers.py
+++ b/src/nv_ingest/util/nim/helpers.py
@@ -51,6 +51,7 @@ def create_inference_client(
     -------
     grpcclient.InferenceServerClient or dict
         A gRPC client if the gRPC endpoint is provided, otherwise a dictionary containing the HTTP client details.
+        `infer_protocol` selects the client type explicitly ("grpc" or "http") when both endpoints are configured.
     """
 
     grpc_endpoint, http_endpoint = endpoints
diff --git a/src/nv_ingest/util/pdf/metadata_aggregators.py b/src/nv_ingest/util/pdf/metadata_aggregators.py
index 0c2a3618..851a931b 100644
--- a/src/nv_ingest/util/pdf/metadata_aggregators.py
+++ b/src/nv_ingest/util/pdf/metadata_aggregators.py
@@ -10,7 +10,6 @@
 from typing import Dict
 from typing import List
 from typing import Tuple
-from typing import Union
 
 import pandas as pd
 import pypdfium2 as pdfium
@@ -26,22 +25,15 @@
 from nv_ingest.util.exception_handlers.pdf import pdfium_exception_handler
 
 
+# TODO(Devin): Shift to this, since there is no difference between ImageTable and ImageChart
 @dataclass
-class ImageTable:
-    content: str
-    image: str
-    bbox: Tuple[int, int, int, int]
-    max_width: int
-    max_height: int
-
-
-@dataclass
-class ImageChart:
+class CroppedImageWithContent:
     content: str
     image: str
     bbox: Tuple[int, int, int, int]
     max_width: int
     max_height: int
+    type_string: str
 
 
 @dataclass
@@ -139,16 +131,16 @@ def extract_pdf_metadata(doc: pdfium.PdfDocument, source_id: str) -> PDFMetadata
 
 def construct_text_metadata(
-        accumulated_text,
-        keywords,
-        page_idx,
-        block_idx,
-        line_idx,
-        span_idx,
-        page_count,
-        text_depth,
-        source_metadata,
-        base_unified_metadata,
+    accumulated_text,
+    keywords,
+    page_idx,
+    block_idx,
+    line_idx,
+    span_idx,
+    page_count,
+    text_depth,
+    source_metadata,
+    base_unified_metadata,
 ):
     extracted_text = " ".join(accumulated_text)
 
@@ -195,11 +187,11 @@
 
 def construct_image_metadata(
-        image_base64: Base64Image,
-        page_idx: int,
-        page_count: int,
-        source_metadata: Dict[str, Any],
-        base_unified_metadata: Dict[str, Any],
+    image_base64: Base64Image,
+    page_idx: int,
+    page_count: int,
+    source_metadata: Dict[str, Any],
+    base_unified_metadata: Dict[str, Any],
 ) -> List[Any]:
     """
     Extracts image data from a PdfImage object, converts it to a base64-encoded string,
@@ -254,7 +246,7 @@
         "caption": "",
         "text": "",
         "image_location": image_base64.bbox,
-        "image_location_max_dimensions": (image_base64.max_width, image_base64.max_height),
+        "image_location_max_dimensions": (max(image_base64.max_width, 0), max(image_base64.max_height, 0)),
         "height": image_base64.height,
     }
 
@@ -277,11 +269,11 @@
 # TODO(Devin): Disambiguate tables and charts, create two distinct processing methods
 @pdfium_exception_handler(descriptor="pdfium")
 def construct_table_and_chart_metadata(
-        table: Union[ImageTable, ImageChart],
-        page_idx: int,
-        page_count: int,
-        source_metadata: Dict,
-        base_unified_metadata: Dict,
+    structured_image: CroppedImageWithContent,
+    page_idx: int,
+    page_count: int,
+    source_metadata: Dict,
+    base_unified_metadata: Dict,
 ):
     """
     +--------------------------------+--------------------------+------------+---+
@@ -311,22 +303,25 @@ def construct_table_and_chart_metadata(
     +--------------------------------+--------------------------+------------+---+
     """
 
-    if isinstance(table, ImageTable):
-        content = table.image
-        structured_content_text = table.content
+    if structured_image.type_string in ("table",):
+        content = structured_image.image
+        structured_content_text = structured_image.content
         table_format = TableFormatEnum.IMAGE
         subtype = ContentSubtypeEnum.TABLE
         description = StdContentDescEnum.PDF_TABLE
+        meta_name = "table_metadata"
 
-    elif isinstance(table, ImageChart):
-        content = table.image
-        structured_content_text = table.content
+    elif structured_image.type_string in ("chart",):
+        content = structured_image.image
+        structured_content_text = structured_image.content
         table_format = TableFormatEnum.IMAGE
         subtype = ContentSubtypeEnum.CHART
         description = StdContentDescEnum.PDF_CHART
+        # TODO(Devin): Swap this to chart_metadata once the metadata schema changes are confirmed.
+        meta_name = "table_metadata"
 
     else:
-        raise ValueError("Unknown table/chart type.")
+        raise ValueError(f"Unknown table/chart type: {structured_image.type_string}")
 
     content_metadata = {
         "type": ContentTypeEnum.STRUCTURED,
@@ -341,12 +336,12 @@
         "subtype": subtype,
     }
 
-    table_metadata = {
+    structured_metadata = {
         "caption": "",
         "table_format": table_format,
         "table_content": structured_content_text,
-        "table_location": table.bbox,
-        "table_location_max_dimensions": (table.max_width, table.max_height),
+        "table_location": structured_image.bbox,
+        "table_location_max_dimensions": (structured_image.max_width, structured_image.max_height),
     }
 
     ext_unified_metadata = base_unified_metadata.copy()
@@ -356,7 +351,7 @@
             "content": content,
             "source_metadata": source_metadata,
             "content_metadata": content_metadata,
-            "table_metadata": table_metadata,
+            meta_name: structured_metadata,
         }
     )
 
diff --git a/src/nv_ingest/util/pipeline/__init__.py b/src/nv_ingest/util/pipeline/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/nv_ingest/util/pipeline/stage_builders.py b/src/nv_ingest/util/pipeline/stage_builders.py
new file mode 100644
index 00000000..ee8ae3fe
--- /dev/null
+++ b/src/nv_ingest/util/pipeline/stage_builders.py
@@ -0,0 +1,501 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import logging
+import os
+import typing
+
+import click
+from morpheus.messages import ControlMessage
+from morpheus.stages.general.linear_modules_source import LinearModuleSourceStage
+from morpheus.stages.general.linear_modules_stage import LinearModulesStage
+
+from nv_ingest.modules.injectors.metadata_injector import MetadataInjectorLoaderFactory
+from nv_ingest.modules.sinks.redis_task_sink import RedisTaskSinkLoaderFactory
+from nv_ingest.modules.sinks.vdb_task_sink import VDBTaskSinkLoaderFactory
+from nv_ingest.modules.sources.redis_task_source import RedisTaskSourceLoaderFactory
+from nv_ingest.modules.telemetry.job_counter import JobCounterLoaderFactory
+from nv_ingest.modules.telemetry.otel_meter import OpenTelemetryMeterLoaderFactory
+from nv_ingest.modules.telemetry.otel_tracer import OpenTelemetryTracerLoaderFactory
+from nv_ingest.modules.transforms.embed_extractions import EmbedExtractionsLoaderFactory
+from nv_ingest.modules.transforms.nemo_doc_splitter import NemoDocSplitterLoaderFactory
+from nv_ingest.stages.docx_extractor_stage import generate_docx_extractor_stage
+from nv_ingest.stages.filters import generate_dedup_stage
+from nv_ingest.stages.filters import generate_image_filter_stage
+from nv_ingest.stages.nim.chart_extraction import generate_chart_extractor_stage
+from nv_ingest.stages.nim.table_extraction import generate_table_extractor_stage
+from nv_ingest.stages.pdf_extractor_stage import generate_pdf_extractor_stage
+from nv_ingest.stages.pptx_extractor_stage import generate_pptx_extractor_stage
+from nv_ingest.stages.storages.image_storage_stage import ImageStorageStage
+from nv_ingest.stages.transforms.image_caption_extraction import generate_caption_extraction_stage
+
+logger = logging.getLogger(__name__)
+
+
+def validate_positive(ctx, param, value):
+    if value <= 0:
+        raise click.BadParameter("must be a positive integer")
+    return value
+
+
+def get_message_provider_config():
+    message_provider_host = os.environ.get("MESSAGE_CLIENT_HOST", "localhost")
+    message_provider_port = os.environ.get("MESSAGE_CLIENT_PORT", "6379")
+
+    logger.info(f"MESSAGE_CLIENT_HOST: {message_provider_host}")
+    logger.info(f"MESSAGE_CLIENT_PORT: {message_provider_port}")
+
+    return message_provider_host, message_provider_port
+
+
+def get_caption_classifier_service():
+    triton_service_caption_classifier = os.environ.get(
+        "CAPTION_CLASSIFIER_GRPC_TRITON",
+        "",
+    )
+    triton_service_caption_classifier_name = os.environ.get(
+        "CAPTION_CLASSIFIER_MODEL_NAME",
+        "",
+    )
+
+    logger.info(f"CAPTION_CLASSIFIER_GRPC_TRITON: {triton_service_caption_classifier}")
+
+    return triton_service_caption_classifier, triton_service_caption_classifier_name
+
+
+def get_table_detection_service(env_var_prefix):
+    prefix = env_var_prefix.upper()
+    grpc_endpoint = os.environ.get(
+        f"{prefix}_GRPC_ENDPOINT",
+        "",
+    )
+    http_endpoint = os.environ.get(
+        f"{prefix}_HTTP_ENDPOINT",
+        "",
+    )
+    auth_token = os.environ.get(
+        "NVIDIA_BUILD_API_KEY",
+        "",
+    ) or os.environ.get(
+        "NGC_API_KEY",
+        "",
+    )
+    infer_protocol = os.environ.get(
+        f"{prefix}_INFER_PROTOCOL",
+        "http" if http_endpoint else "grpc" if grpc_endpoint else "",
+    )
+
+    logger.info(f"{prefix}_GRPC_ENDPOINT: {grpc_endpoint}")
+    logger.info(f"{prefix}_HTTP_ENDPOINT: {http_endpoint}")
+    logger.info(f"{prefix}_INFER_PROTOCOL: {infer_protocol}")
+
+    return grpc_endpoint, http_endpoint, auth_token, infer_protocol
+
+
+def get_default_cpu_count():
+    # NV_INGEST_MAX_UTIL may be set in the environment; coerce it (or the
+    # sched_getaffinity fallback) to an int so callers always receive an integer.
+    default_cpu_count = int(os.environ.get("NV_INGEST_MAX_UTIL", max(1, len(os.sched_getaffinity(0)))))
+
+    return default_cpu_count
+
+
+def add_source_stage(pipe, morpheus_pipeline_config, ingest_config, message_provider_host, message_provider_port):
+    source_module_loader = RedisTaskSourceLoaderFactory.get_instance(
+        module_name="redis_listener",
+        module_config=ingest_config.get(
+            "redis_task_source",
+            {
+                "redis_client": {
+                    "host": message_provider_host,
+                    "port": message_provider_port,
+                }
+            },
+        ),
+    )
+    source_stage = pipe.add_stage(
+        LinearModuleSourceStage(
+            morpheus_pipeline_config,
+            source_module_loader,
+            output_type=ControlMessage,
+            output_port_name="output",
+        )
+    )
+
+    return source_stage
+
+
+def add_submitted_job_counter_stage(pipe, morpheus_pipeline_config, ingest_config):
+    submitted_job_counter_loader = JobCounterLoaderFactory.get_instance(
+        module_name="submitted_job_counter",
+        module_config=ingest_config.get(
+            "submitted_job_counter_module",
+            {
+                "name": "submitted_jobs",
+            },
+        ),
+    )
+    submitted_job_counter_stage = pipe.add_stage(
+        LinearModulesStage(
+            morpheus_pipeline_config,
+            submitted_job_counter_loader,
+            input_type=ControlMessage,
+            output_type=ControlMessage,
+            input_port_name="input",
+            output_port_name="output",
+        )
+    )
+
+    return submitted_job_counter_stage
+
+
+def add_metadata_injector_stage(pipe, morpheus_pipeline_config):
+    metadata_injector_loader = MetadataInjectorLoaderFactory.get_instance(
+        module_name="metadata_injection", module_config={}
+    )
+    metadata_injector_stage = pipe.add_stage(
+        LinearModulesStage(
+            morpheus_pipeline_config,
+            metadata_injector_loader,
+            input_type=ControlMessage,
+            output_type=ControlMessage,
+            input_port_name="input",
+            output_port_name="output",
+        )
+    )
+
+    return metadata_injector_stage
+
+
+def add_pdf_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count):
+    yolox_grpc, yolox_http, yolox_auth, yolox_protocol = get_table_detection_service("yolox")
+    pdf_content_extractor_config = ingest_config.get(
+        "pdf_content_extraction_module",
+        {
+            "pdfium_config": {
+                "yolox_endpoints": (yolox_grpc, yolox_http),
+                "yolox_infer_protocol": yolox_protocol,
+                "auth_token": yolox_auth,  # All auth tokens are the same for the moment
+            }
+        },
+    )
+    pdf_extractor_stage = pipe.add_stage(
+        generate_pdf_extractor_stage(
+            morpheus_pipeline_config,
+            pdf_content_extractor_config,
+            pe_count=8,
+            task="extract",
+            task_desc="pdf_content_extractor",
+        )
+    )
+
+    return pdf_extractor_stage
+
+
+def add_table_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count):
+    _, _, yolox_auth, _ = get_table_detection_service("yolox")
+    paddle_grpc, paddle_http, paddle_auth, paddle_protocol = get_table_detection_service("paddle")
+    table_content_extractor_config = ingest_config.get(
+        "table_content_extraction_module",
+        {
+            "stage_config": {
+                "paddle_endpoints": (paddle_grpc, paddle_http),
+                "paddle_infer_protocol": paddle_protocol,
+                "auth_token": yolox_auth,
+            }
+        },
+    )
+
+    table_extractor_stage = pipe.add_stage(
+        generate_table_extractor_stage(
+            morpheus_pipeline_config,
+            table_content_extractor_config,
+            pe_count=5,
+        )
+    )
+
+    return table_extractor_stage
+
+
+def add_chart_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count):
+    _, _, yolox_auth, _ = get_table_detection_service("yolox")
+
+    deplot_grpc, deplot_http, deplot_auth, deplot_protocol = get_table_detection_service("deplot")
+    cached_grpc, cached_http, cached_auth, cached_protocol = get_table_detection_service("cached")
+    # NOTE: Paddle isn't currently used directly by the chart extraction stage, but will be in the future.
+    paddle_grpc, paddle_http, paddle_auth, paddle_protocol = get_table_detection_service("paddle")
+    chart_content_extractor_config = ingest_config.get(
+        "chart_content_extraction_module",
+        {
+            "stage_config": {
+                "cached_endpoints": (cached_grpc, cached_http),
+                "cached_infer_protocol": cached_protocol,
+                "deplot_endpoints": (deplot_grpc, deplot_http),
+                "deplot_infer_protocol": deplot_protocol,
+                "paddle_endpoints": (paddle_grpc, paddle_http),
+                "paddle_infer_protocol": paddle_protocol,
+                "auth_token": yolox_auth,
+            }
+        },
+    )
+
+    chart_extractor_stage = pipe.add_stage(
+        generate_chart_extractor_stage(
+            morpheus_pipeline_config,
+            chart_content_extractor_config,
+            pe_count=5,
+        )
+    )
+
+    return chart_extractor_stage
+
+
+def add_docx_extractor_stage(pipe, morpheus_pipeline_config, default_cpu_count):
+    docx_extractor_stage = pipe.add_stage(
+        generate_docx_extractor_stage(
+            morpheus_pipeline_config,
+            pe_count=1,
+            task="extract",
+            task_desc="docx_content_extractor",
+        )
+    )
+    return docx_extractor_stage
+
+
+def add_pptx_extractor_stage(pipe, morpheus_pipeline_config, default_cpu_count):
+    pptx_extractor_stage = pipe.add_stage(
+        generate_pptx_extractor_stage(
+            morpheus_pipeline_config,
+            pe_count=1,
+            task="extract",
+            task_desc="pptx_content_extractor",
+        )
+    )
+    return pptx_extractor_stage
+
+
+def add_image_dedup_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count):
+    image_dedup_config = ingest_config.get("dedup_module", {})
+    image_dedup_stage = pipe.add_stage(
+        generate_dedup_stage(
+            morpheus_pipeline_config,
+            image_dedup_config,
+            pe_count=2,
+            task="dedup",
+            task_desc="dedup_images",
+        )
+    )
+    return image_dedup_stage
+
+
+def add_image_filter_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count):
+    image_filter_config = ingest_config.get("image_filter", {})
+    image_filter_stage = pipe.add_stage(
+        generate_image_filter_stage(
+            morpheus_pipeline_config,
+            image_filter_config,
+            pe_count=2,
+            task="filter",
+            task_desc="filter_images",
+        )
+    )
+    return image_filter_stage
+
+
+def add_nemo_splitter_stage(pipe, morpheus_pipeline_config, ingest_config):
+    nemo_splitter_loader = NemoDocSplitterLoaderFactory.get_instance(
+        module_name="nemo_doc_splitter",
+        module_config=ingest_config.get("text_splitting_module", {}),
+    )
+    nemo_splitter_stage = pipe.add_stage(
+        LinearModulesStage(
+            morpheus_pipeline_config,
+            nemo_splitter_loader,
+            input_type=ControlMessage,
+            output_type=ControlMessage,
+            input_port_name="input",
+            output_port_name="output",
+        )
+    )
+
+    return nemo_splitter_stage
+
+
+def add_image_caption_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count):
+    endpoint_url, model_name = get_caption_classifier_service()
+    image_caption_config = ingest_config.get(
+        "image_caption_extraction_module",
+        {
+            "caption_classifier_model_name": model_name,
+            "endpoint_url": endpoint_url,
+        },
+    )
+    image_caption_stage = pipe.add_stage(
+        generate_caption_extraction_stage(
+            morpheus_pipeline_config,
+            image_caption_config,
+            pe_count=2,
+            task="caption",
+            task_desc="caption_ext",
+        )
+    )
+
+    return image_caption_stage
+
+
+def add_embed_extractions_stage(pipe, morpheus_pipeline_config, ingest_config):
+    api_key = os.getenv("NGC_API_KEY", "ngc_api_key")
+    embedding_nim_endpoint = os.getenv("EMBEDDING_NIM_ENDPOINT", "http://embedding:8000/v1")
+
+    embed_extractions_loader = 
EmbedExtractionsLoaderFactory.get_instance( + module_name="embed_extractions", + module_config=ingest_config.get( + "embed_extractions_module", {"api_key": api_key, "embedding_nim_endpoint": embedding_nim_endpoint} + ), + ) + embed_extractions_stage = pipe.add_stage( + LinearModulesStage( + morpheus_pipeline_config, + embed_extractions_loader, + input_type=ControlMessage, + output_type=ControlMessage, + input_port_name="input", + output_port_name="output", + ) + ) + return embed_extractions_stage + + +def add_image_storage_stage(pipe, morpheus_pipeline_config): + image_storage_stage = pipe.add_stage(ImageStorageStage(morpheus_pipeline_config)) + + return image_storage_stage + + +def add_sink_stage(pipe, morpheus_pipeline_config, ingest_config, message_provider_host, message_provider_port): + sink_module_loader = RedisTaskSinkLoaderFactory.get_instance( + module_name="redis_task_sink", + module_config=ingest_config.get( + "redis_task_sink", + { + "redis_client": { + "host": message_provider_host, + "port": message_provider_port, + } + }, + ), + ) + sink_stage = pipe.add_stage( + LinearModulesStage( + morpheus_pipeline_config, + sink_module_loader, + input_type=typing.Any, + output_type=ControlMessage, + input_port_name="input", + output_port_name="output", + ) + ) + + return sink_stage + + +def add_otel_tracer_stage(pipe, morpheus_pipeline_config, ingest_config): + endpoint = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + + otel_tracer_loader = OpenTelemetryTracerLoaderFactory.get_instance( + module_name="otel_tracer", + module_config=ingest_config.get( + "otel_tracer_module", + { + "otel_endpoint": endpoint, + }, + ), + ) + otel_tracer_stage = pipe.add_stage( + LinearModulesStage( + morpheus_pipeline_config, + otel_tracer_loader, + input_type=ControlMessage, + output_type=ControlMessage, + input_port_name="input", + output_port_name="output", + ) + ) + return otel_tracer_stage + + +def add_otel_meter_stage(pipe, morpheus_pipeline_config, ingest_config, message_provider_host, message_provider_port): + endpoint = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + + otel_meter_loader = OpenTelemetryMeterLoaderFactory.get_instance( + module_name="otel_meter", + module_config=ingest_config.get( + "otel_meter_module", + { + "redis_client": { + "host": message_provider_host, + "port": message_provider_port, + }, + "otel_endpoint": endpoint, + }, + ), + ) + otel_meter_stage = pipe.add_stage( + LinearModulesStage( + morpheus_pipeline_config, + otel_meter_loader, + input_type=ControlMessage, + output_type=ControlMessage, + input_port_name="input", + output_port_name="output", + ) + ) + + return otel_meter_stage + + +def add_completed_job_counter_stage(pipe, morpheus_pipeline_config, ingest_config): + completed_job_counter_loader = JobCounterLoaderFactory.get_instance( + module_name="completed_job_counter", + module_config=ingest_config.get( + "completed_job_counter_module", + { + "name": "completed_jobs", + }, + ), + ) + completed_job_counter_stage = pipe.add_stage( + LinearModulesStage( + morpheus_pipeline_config, + completed_job_counter_loader, + input_type=ControlMessage, + output_type=ControlMessage, + input_port_name="input", + output_port_name="output", + ) + ) + return completed_job_counter_stage + + +def add_vdb_task_sink_stage(pipe, morpheus_pipeline_config, ingest_config): + milvus_endpoint = os.getenv("MILVUS_ENDPOINT", "http://milvus:19530") + + vdb_task_sink_loader = VDBTaskSinkLoaderFactory.get_instance( + module_name="vdb_task_sink", + 
module_config=ingest_config.get( + "vdb_task_sink_module", + { + "service_kwargs": { + "uri": milvus_endpoint, + } + }, + ), + ) + vdb_task_sink_stage = pipe.add_stage( + LinearModulesStage( + morpheus_pipeline_config, + vdb_task_sink_loader, + input_type=ControlMessage, + output_type=ControlMessage, + input_port_name="input", + output_port_name="output", + ) + ) + return vdb_task_sink_stage diff --git a/src/pipeline.py b/src/pipeline.py index af27e203..4d8289d7 100644 --- a/src/pipeline.py +++ b/src/pipeline.py @@ -4,44 +4,22 @@ import json -import logging -import math -import os -import typing from datetime import datetime import click from morpheus.config import Config from morpheus.config import CppConfig from morpheus.config import PipelineModes -from morpheus.messages import ControlMessage from morpheus.pipeline.pipeline import Pipeline -from morpheus.stages.general.linear_modules_source import LinearModuleSourceStage -from morpheus.stages.general.linear_modules_stage import LinearModulesStage from morpheus.utils.logger import configure_logging from pydantic import ValidationError -from nv_ingest.modules.injectors.metadata_injector import MetadataInjectorLoaderFactory -from nv_ingest.modules.sinks.redis_task_sink import RedisTaskSinkLoaderFactory -from nv_ingest.modules.sinks.vdb_task_sink import VDBTaskSinkLoaderFactory -from nv_ingest.modules.sources.redis_task_source import RedisTaskSourceLoaderFactory -from nv_ingest.modules.telemetry.job_counter import JobCounterLoaderFactory -from nv_ingest.modules.telemetry.otel_meter import OpenTelemetryMeterLoaderFactory -from nv_ingest.modules.telemetry.otel_tracer import OpenTelemetryTracerLoaderFactory -from nv_ingest.modules.transforms.embed_extractions import EmbedExtractionsLoaderFactory -from nv_ingest.modules.transforms.nemo_doc_splitter import NemoDocSplitterLoaderFactory from nv_ingest.schemas.ingest_pipeline_config_schema import IngestPipelineConfigSchema -from nv_ingest.stages.docx_extractor_stage import generate_docx_extractor_stage -from nv_ingest.stages.filters import generate_dedup_stage -from nv_ingest.stages.filters import generate_image_filter_stage -from nv_ingest.stages.pdf_extractor_stage import generate_pdf_extractor_stage -from nv_ingest.stages.pptx_extractor_stage import generate_pptx_extractor_stage -from nv_ingest.stages.storages.image_storage_stage import ImageStorageStage -from nv_ingest.stages.transforms.image_caption_extraction import generate_caption_extraction_stage from nv_ingest.util.converters.containers import merge_dict from nv_ingest.util.logging.configuration import LogLevel from nv_ingest.util.logging.configuration import configure_logging as configure_local_logging from nv_ingest.util.schema.schema_validator import validate_schema +from nv_ingest.util.pipeline.stage_builders import * logger = logging.getLogger(__name__) local_log_level = os.getenv("INGEST_LOG_LEVEL", "INFO") @@ -50,467 +28,66 @@ configure_local_logging(logger, local_log_level) -def validate_positive(ctx, param, value): - if value <= 0: - raise click.BadParameter("must be a positive integer") - return value - - -def get_message_provider_config(): - message_provider_host = os.environ.get("MESSAGE_CLIENT_HOST", "localhost") - message_provider_port = os.environ.get("MESSAGE_CLIENT_PORT", "6379") - - logger.info(f"MESSAGE_CLIENT_HOST: {message_provider_host}") - logger.info(f"MESSAGE_CLIENT_PORT: {message_provider_port}") - - return message_provider_host, message_provider_port - - -def get_caption_classifier_service(): - 
triton_service_caption_classifier = os.environ.get( - "CAPTION_CLASSIFIER_GRPC_TRITON", - "", - ) - triton_service_caption_classifier_name = os.environ.get( - "CAPTION_CLASSIFIER_MODEL_NAME", - "", - ) - - logger.info(f"CAPTION_CLASSIFIER_GRPC_TRITON: {triton_service_caption_classifier}") - - return triton_service_caption_classifier, triton_service_caption_classifier_name - - -def get_table_detection_service(env_var_prefix): - prefix = env_var_prefix.upper() - grpc_endpoint = os.environ.get( - f"{prefix}_GRPC_ENDPOINT", - "", - ) - http_endpoint = os.environ.get( - f"{prefix}_HTTP_ENDPOINT", - "", - ) - auth_token = os.environ.get( - "NVIDIA_BUILD_API_KEY", - "", - ) or os.environ.get( - "NGC_API_KEY", - "", - ) - infer_protocol = os.environ.get( - f"{prefix}_INFER_PROTOCOL", - "http" if http_endpoint else "grpc" if grpc_endpoint else "", - ) - - logger.info(f"{prefix}_GRPC_TRITON: {grpc_endpoint}") - logger.info(f"{prefix}_HTTP_TRITON: {http_endpoint}") - logger.info(f"{prefix}_INFER_PROTOCOL: {infer_protocol}") - - return grpc_endpoint, http_endpoint, auth_token, infer_protocol - - -def get_default_cpu_count(): - default_cpu_count = os.environ.get("NV_INGEST_MAX_UTIL", int(max(1, math.floor(len(os.sched_getaffinity(0)))))) - - return default_cpu_count - - -def add_source_stage(pipe, morpheus_pipeline_config, ingest_config, message_provider_host, message_provider_port): - source_module_loader = RedisTaskSourceLoaderFactory.get_instance( - module_name="redis_listener", - module_config=ingest_config.get( - "redis_task_source", - { - "redis_client": { - "host": message_provider_host, - "port": message_provider_port, - } - }, - ), - ) - source_stage = pipe.add_stage( - LinearModuleSourceStage( - morpheus_pipeline_config, - source_module_loader, - output_type=ControlMessage, - output_port_name="output", - ) - ) - - return source_stage - - -def add_submitted_job_counter_stage(pipe, morpheus_pipeline_config, ingest_config): - submitted_job_counter_loader = JobCounterLoaderFactory.get_instance( - module_name="submitted_job_counter", - module_config=ingest_config.get( - "submitted_job_counter_module", - { - "name": "submitted_jobs", - }, - ), - ) - submitted_job_counter_stage = pipe.add_stage( - LinearModulesStage( - morpheus_pipeline_config, - submitted_job_counter_loader, - input_type=ControlMessage, - output_type=ControlMessage, - input_port_name="input", - output_port_name="output", - ) - ) - - return submitted_job_counter_stage - - -def add_metadata_injector_stage(pipe, morpheus_pipeline_config): - metadata_injector_loader = MetadataInjectorLoaderFactory.get_instance( - module_name="metadata_injection", module_config={} - ) - metadata_injector_stage = pipe.add_stage( - LinearModulesStage( - morpheus_pipeline_config, - metadata_injector_loader, - input_type=ControlMessage, - output_type=ControlMessage, - input_port_name="input", - output_port_name="output", - ) - ) - - return metadata_injector_stage - - -def add_pdf_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count): - yolox_grpc, yolox_http, yolox_auth, yolox_protocol = get_table_detection_service("yolox") - paddle_grpc, paddle_http, paddle_auth, paddle_protocol = get_table_detection_service("paddle") - deplot_grpc, deplot_http, deplot_auth, deplot_protocol = get_table_detection_service("deplot") - cached_grpc, cached_http, cached_auth, cached_protocol = get_table_detection_service("cached") - pdf_content_extractor_config = ingest_config.get( - "pdf_content_extraction_module", - { - "pdfium_config": { - 
"cached_endpoints": (cached_grpc, cached_http), - "deplot_endpoints": (deplot_grpc, deplot_http), - "paddle_endpoints": (paddle_grpc, paddle_http), - "yolox_endpoints": (yolox_grpc, yolox_http), - "cached_infer_protocol": cached_protocol, - "deplot_infer_protocol": deplot_protocol, - "paddle_infer_protocol": paddle_protocol, - "yolox_infer_protocol": yolox_protocol, - "auth_token": yolox_auth, # All auth tokens are the same for the moment - } - }, - ) - pdf_extractor_stage = pipe.add_stage( - generate_pdf_extractor_stage( - morpheus_pipeline_config, - pdf_content_extractor_config, - pe_count=8, - task="extract", - task_desc="pdf_content_extractor", - ) - ) - - return pdf_extractor_stage - - -def add_docx_extractor_stage(pipe, morpheus_pipeline_config, default_cpu_count): - docx_extractor_stage = pipe.add_stage( - generate_docx_extractor_stage( - morpheus_pipeline_config, - pe_count=1, - task="extract", - task_desc="docx_content_extractor", - ) - ) - return docx_extractor_stage - - -def add_pptx_extractor_stage(pipe, morpheus_pipeline_config, default_cpu_count): - pptx_extractor_stage = pipe.add_stage( - generate_pptx_extractor_stage( - morpheus_pipeline_config, - pe_count=1, - task="extract", - task_desc="pptx_content_extractor", - ) - ) - return pptx_extractor_stage - - -def add_image_dedup_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count): - image_dedup_config = ingest_config.get("dedup_module", {}) - image_dedup_stage = pipe.add_stage( - generate_dedup_stage( - morpheus_pipeline_config, - image_dedup_config, - pe_count=2, - task="dedup", - task_desc="dedup_images", - ) - ) - return image_dedup_stage - - -def add_image_filter_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count): - image_filter_config = ingest_config.get("image_filter", {}) - image_filter_stage = pipe.add_stage( - generate_image_filter_stage( - morpheus_pipeline_config, - image_filter_config, - pe_count=2, - task="filter", - task_desc="filter_images", - ) - ) - return image_filter_stage - - -def add_nemo_splitter_stage(pipe, morpheus_pipeline_config, ingest_config): - nemo_splitter_loader = NemoDocSplitterLoaderFactory.get_instance( - module_name="nemo_doc_splitter", - module_config=ingest_config.get("text_splitting_module", {}), - ) - nemo_splitter_stage = pipe.add_stage( - LinearModulesStage( - morpheus_pipeline_config, - nemo_splitter_loader, - input_type=ControlMessage, - output_type=ControlMessage, - input_port_name="input", - output_port_name="output", - ) - ) - - return nemo_splitter_stage - - -def add_image_caption_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count): - endpoint_url, model_name = get_caption_classifier_service() - image_caption_config = ingest_config.get( - "image_caption_extraction_module", - { - "caption_classifier_model_name": model_name, - "endpoint_url": endpoint_url, - }, - ) - image_caption_stage = pipe.add_stage( - generate_caption_extraction_stage( - morpheus_pipeline_config, - image_caption_config, - pe_count=2, - task="caption", - task_desc="caption_ext", - ) - ) - - return image_caption_stage - - -def add_embed_extractions_stage(pipe, morpheus_pipeline_config, ingest_config): - api_key = os.getenv("NGC_API_KEY", "ngc_api_key") - embedding_nim_endpoint = os.getenv("EMBEDDING_NIM_ENDPOINT", "http://embedding:8000/v1") - - embed_extractions_loader = EmbedExtractionsLoaderFactory.get_instance( - module_name="embed_extractions", - module_config=ingest_config.get( - "embed_extractions_module", {"api_key": api_key, 
"embedding_nim_endpoint": embedding_nim_endpoint} - ), - ) - embed_extractions_stage = pipe.add_stage( - LinearModulesStage( - morpheus_pipeline_config, - embed_extractions_loader, - input_type=ControlMessage, - output_type=ControlMessage, - input_port_name="input", - output_port_name="output", - ) - ) - return embed_extractions_stage - - -def add_image_storage_stage(pipe, morpheus_pipeline_config): - image_storage_stage = pipe.add_stage(ImageStorageStage(morpheus_pipeline_config)) - - return image_storage_stage - - -def add_sink_stage(pipe, morpheus_pipeline_config, ingest_config, message_provider_host, message_provider_port): - sink_module_loader = RedisTaskSinkLoaderFactory.get_instance( - module_name="redis_task_sink", - module_config=ingest_config.get( - "redis_task_sink", - { - "redis_client": { - "host": message_provider_host, - "port": message_provider_port, - } - }, - ), - ) - sink_stage = pipe.add_stage( - LinearModulesStage( - morpheus_pipeline_config, - sink_module_loader, - input_type=typing.Any, - output_type=ControlMessage, - input_port_name="input", - output_port_name="output", - ) - ) - - return sink_stage - - -def add_otel_tracer_stage(pipe, morpheus_pipeline_config, ingest_config): - endpoint = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") - - otel_tracer_loader = OpenTelemetryTracerLoaderFactory.get_instance( - module_name="otel_tracer", - module_config=ingest_config.get( - "otel_tracer_module", - { - "otel_endpoint": endpoint, - }, - ), - ) - otel_tracer_stage = pipe.add_stage( - LinearModulesStage( - morpheus_pipeline_config, - otel_tracer_loader, - input_type=ControlMessage, - output_type=ControlMessage, - input_port_name="input", - output_port_name="output", - ) - ) - return otel_tracer_stage - - -def add_otel_meter_stage(pipe, morpheus_pipeline_config, ingest_config, message_provider_host, message_provider_port): - endpoint = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") - - otel_meter_loader = OpenTelemetryMeterLoaderFactory.get_instance( - module_name="otel_meter", - module_config=ingest_config.get( - "otel_meter_module", - { - "redis_client": { - "host": message_provider_host, - "port": message_provider_port, - }, - "otel_endpoint": endpoint, - }, - ), - ) - otel_meter_stage = pipe.add_stage( - LinearModulesStage( - morpheus_pipeline_config, - otel_meter_loader, - input_type=ControlMessage, - output_type=ControlMessage, - input_port_name="input", - output_port_name="output", - ) - ) - - return otel_meter_stage - - -def add_completed_job_counter_stage(pipe, morpheus_pipeline_config, ingest_config): - completed_job_counter_loader = JobCounterLoaderFactory.get_instance( - module_name="completed_job_counter", - module_config=ingest_config.get( - "completed_job_counter_module", - { - "name": "completed_jobs", - }, - ), - ) - completed_job_counter_stage = pipe.add_stage( - LinearModulesStage( - morpheus_pipeline_config, - completed_job_counter_loader, - input_type=ControlMessage, - output_type=ControlMessage, - input_port_name="input", - output_port_name="output", - ) - ) - return completed_job_counter_stage - - -def add_vdb_task_sink_stage(pipe, morpheus_pipeline_config, ingest_config): - milvus_endpoint = os.getenv("MILVUS_ENDPOINT", "http://milvus:19530") - - vdb_task_sink_loader = VDBTaskSinkLoaderFactory.get_instance( - module_name="vdb_task_sink", - module_config=ingest_config.get( - "vdb_task_sink_module", - { - "service_kwargs": { - "uri": milvus_endpoint, - } - }, - ), - ) - vdb_task_sink_stage = pipe.add_stage( - 
LinearModulesStage( - morpheus_pipeline_config, - vdb_task_sink_loader, - input_type=ControlMessage, - output_type=ControlMessage, - input_port_name="input", - output_port_name="output", - ) - ) - return vdb_task_sink_stage - - def setup_ingestion_pipeline( - pipe: Pipeline, morpheus_pipeline_config: Config, ingest_config: typing.Dict[str, typing.Any] + pipe: Pipeline, morpheus_pipeline_config: Config, ingest_config: typing.Dict[str, typing.Any] ): message_provider_host, message_provider_port = get_message_provider_config() default_cpu_count = get_default_cpu_count() - # Pre-processing stages + ######################################################################################################## + ## Insertion and Pre-processing stages + ######################################################################################################## source_stage = add_source_stage( pipe, morpheus_pipeline_config, ingest_config, message_provider_host, message_provider_port ) submitted_job_counter_stage = add_submitted_job_counter_stage(pipe, morpheus_pipeline_config, ingest_config) metadata_injector_stage = add_metadata_injector_stage(pipe, morpheus_pipeline_config) + ######################################################################################################## - # Primitive extraction + ######################################################################################################## + ## Primitive extraction + ######################################################################################################## pdf_extractor_stage = add_pdf_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count) docx_extractor_stage = add_docx_extractor_stage(pipe, morpheus_pipeline_config, default_cpu_count) pptx_extractor_stage = add_pptx_extractor_stage(pipe, morpheus_pipeline_config, default_cpu_count) + ######################################################################################################## - # Post-processing + ######################################################################################################## + ## Post-processing + ######################################################################################################## image_dedup_stage = add_image_dedup_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count) image_filter_stage = add_image_filter_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count) + table_extraction_stage = add_table_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count) + chart_extraction_stage = add_chart_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count) + ######################################################################################################## - # Transforms and data synthesis + ######################################################################################################## + ## Transforms and data synthesis + ######################################################################################################## nemo_splitter_stage = add_nemo_splitter_stage(pipe, morpheus_pipeline_config, ingest_config) embed_extractions_stage = add_embed_extractions_stage(pipe, morpheus_pipeline_config, ingest_config) + ######################################################################################################## - # Storage and output + ######################################################################################################## + ## Storage and output + 
######################################################################################################## image_storage_stage = add_image_storage_stage(pipe, morpheus_pipeline_config) sink_stage = add_sink_stage( pipe, morpheus_pipeline_config, ingest_config, message_provider_host, message_provider_port ) vdb_task_sink_stage = add_vdb_task_sink_stage(pipe, morpheus_pipeline_config, ingest_config) + ######################################################################################################## - # Telemetry (Note: everything after the sync stage is out of the hot path, please keep it that way) + ####################################################################################################### + ## Telemetry (Note: everything after the sync stage is out of the hot path, please keep it that way) ## + ####################################################################################################### otel_tracer_stage = add_otel_tracer_stage(pipe, morpheus_pipeline_config, ingest_config) otel_meter_stage = add_otel_meter_stage( pipe, morpheus_pipeline_config, ingest_config, message_provider_host, message_provider_port ) completed_job_counter_stage = add_completed_job_counter_stage(pipe, morpheus_pipeline_config, ingest_config) + ######################################################################################################## # Add edges pipe.add_edge(source_stage, submitted_job_counter_stage) @@ -520,7 +97,9 @@ def setup_ingestion_pipeline( pipe.add_edge(docx_extractor_stage, pptx_extractor_stage) pipe.add_edge(pptx_extractor_stage, image_dedup_stage) pipe.add_edge(image_dedup_stage, image_filter_stage) - pipe.add_edge(image_filter_stage, nemo_splitter_stage) + pipe.add_edge(image_filter_stage, table_extraction_stage) + pipe.add_edge(table_extraction_stage, chart_extraction_stage) + pipe.add_edge(chart_extraction_stage, nemo_splitter_stage) pipe.add_edge(nemo_splitter_stage, embed_extractions_stage) pipe.add_edge(embed_extractions_stage, image_storage_stage) pipe.add_edge(image_storage_stage, vdb_task_sink_stage) @@ -557,7 +136,8 @@ def pipeline(morpheus_pipeline_config, ingest_config) -> float: @click.command() @click.option( - "--ingest_config_path", type=str, envvar="NV_INGEST_CONFIG_PATH", help="Path to the JSON configuration file." + "--ingest_config_path", type=str, envvar="NV_INGEST_CONFIG_PATH", help="Path to the JSON configuration file.", + hidden=True ) @click.option("--use_cpp", is_flag=True, help="Use C++ backend.") @click.option("--pipeline_batch_size", default=256, type=int, help="Batch size for the pipeline.") @@ -586,16 +166,16 @@ def pipeline(morpheus_pipeline_config, ingest_config) -> float: help="Log level.", ) def cli( - ingest_config_path, - caption_batch_size, - use_cpp, - pipeline_batch_size, - enable_monitor, - feature_length, - num_threads, - model_max_batch_size, - mode, - log_level, + ingest_config_path, + caption_batch_size, + use_cpp, + pipeline_batch_size, + enable_monitor, + feature_length, + num_threads, + model_max_batch_size, + mode, + log_level, ): """ Command line interface for configuring and running the pipeline with specified options. 
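For context, the new table and chart extractor stages resolve their service endpoints from the PADDLE_*, DEPLOT_*, and CACHED_* environment variables (via get_table_detection_service) unless overridden in the JSON file passed through --ingest_config_path. A minimal sketch of such an override for the table stage, written as the Python dict the JSON deserializes into and validated against TableExtractorSchema (the hostnames here are assumptions, not shipped defaults):

    ingest_config = {
        "table_content_extraction_module": {
            "n_workers": 2,
            "raise_on_failure": False,
            "stage_config": {
                # (grpc_endpoint, http_endpoint) pair; an unused slot may be None
                "paddle_endpoints": ("grpc://paddle:8001", "http://paddle:8000"),
                "paddle_infer_protocol": "grpc",
            },
        },
    }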
@@ -617,7 +197,7 @@ def cli( log_level = "INFO" log_level = log_level_mapping.get(log_level.upper(), logging.INFO) - logging.basicConfig(level=log_level) + logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(message)s") configure_logging(log_level=log_level) CppConfig.set_should_use_cpp(use_cpp) diff --git a/test-requirements.txt b/test-requirements.txt index 8c82bf4b..fe60055f 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,3 +1,4 @@ +autoflake==2.3.1 black==23.11.0 flake8==7.0.0 isort==5.13.2 @@ -5,6 +6,5 @@ pre-commit==3.5.0 pytest==7.4.3 pytest-cov==4.1.0 pytest-mock -yapf==0.40.2 pytest-mock==3.14.0 -autoflake==2.3.1 +yapf==0.40.2 diff --git a/tests/functional/test_ingest_pipeline.py b/tests/functional/test_ingest_pipeline.py index da486206..82c6fb2a 100644 --- a/tests/functional/test_ingest_pipeline.py +++ b/tests/functional/test_ingest_pipeline.py @@ -8,7 +8,6 @@ import pytest from nv_ingest_client.client import NvIngestClient -from nv_ingest_client.message_clients.redis.redis_client import RedisClient # type: ignore from nv_ingest_client.primitives import JobSpec from nv_ingest_client.primitives.tasks import EmbedTask from nv_ingest_client.primitives.tasks import ExtractTask @@ -47,85 +46,3 @@ def remove_keys(data, keys_to_remove): return [remove_keys(item, keys_to_remove) for item in data] else: return data - - -@pytest.mark.skip(reason="Test environment is not running nv-ingest and redis services.") -def test_ingest_pipeline(): - client = NvIngestClient( - message_client_allocator=RedisClient, - message_client_hostname=_DEFAULT_REDIS_HOST, - message_client_port=_DEFAULT_REDIS_PORT, - message_client_kwargs=None, - msg_counter_id="nv-ingest-message-id", - worker_pool_size=1, - ) - - file_content, file_type = extract_file_content(_VALIDATION_PDF) - - job_spec = JobSpec( - document_type=file_type, - payload=file_content, - source_id=_VALIDATION_PDF, - source_name=_VALIDATION_PDF, - extended_options={ - "tracing_options": { - "trace": True, - "ts_send": time.time_ns(), - } - }, - ) - - extract_task = ExtractTask( - document_type=file_type, - extract_text=True, - extract_images=True, - extract_tables=True, - text_depth=_DEFAULT_EXTRACT_PAGE_DEPTH, - extract_tables_method=_DEFAULT_EXTRACT_TABLES_METHOD, - ) - - split_task = SplitTask( - split_by=_DEFAULT_SPLIT_BY, - split_length=_DEFAULT_SPLIT_LENGTH, - split_overlap=_DEFAULT_SPLIT_OVERLAP, - max_character_length=_DEFAULT_SPLIT_MAX_CHARACTER_LENGTH, - sentence_window_size=_DEFAULT_SPLIT_SENTENCE_WINDOW_SIZE, - ) - - embed_task = EmbedTask( - text=True, - tables=True, - ) - - job_spec.add_task(extract_task) - job_spec.add_task(split_task) - job_spec.add_task(embed_task) - job_id = client.add_job(job_spec) - - client.submit_job(job_id, _DEFAULT_TASK_QUEUE) - generated_metadata = client.fetch_job_result(job_id, timeout=_DEFAULT_JOB_TIMEOUT)[0][0] - - with open(_VALIDATION_JSON, "r") as f: - expected_metadata = json.load(f)[0][0] - - keys_to_remove = ["date_created", "last_modified", "table_content"] - generated_metadata_cleaned = remove_keys(generated_metadata, keys_to_remove) - expected_metadata_cleaned = remove_keys(expected_metadata, keys_to_remove) - - for extraction_idx in range(len(generated_metadata_cleaned)): - content_type = generated_metadata_cleaned[extraction_idx]["metadata"]["content_metadata"]["type"] - - if content_type == "text": - assert generated_metadata_cleaned[extraction_idx] == expected_metadata_cleaned[extraction_idx] - - elif content_type == "image": - assert 
generated_metadata_cleaned[extraction_idx] == expected_metadata_cleaned[extraction_idx]
-
-        elif content_type == "structured":
-            generated_embedding = generated_metadata_cleaned[extraction_idx]["metadata"]["embedding"]
-            expected_embedding = expected_metadata_cleaned[extraction_idx]["metadata"]["embedding"]
-            assert cosine_similarity([generated_embedding], [expected_embedding])[0] > 0.98
-
-            cleaned_generated_table_metadata = remove_keys(generated_metadata_cleaned, ["embedding", "table_content"])
-            cleaned_expected_table_metadata = remove_keys(expected_metadata_cleaned, ["embedding", "table_content"])
-            assert cleaned_generated_table_metadata == cleaned_expected_table_metadata
diff --git a/tests/nv_ingest/modules/sources/test_redis_task_source.py b/tests/nv_ingest/modules/sources/test_redis_task_source.py
index 2ccb84dd..5c67bd86 100644
--- a/tests/nv_ingest/modules/sources/test_redis_task_source.py
+++ b/tests/nv_ingest/modules/sources/test_redis_task_source.py
@@ -80,18 +80,19 @@ def test_process_message(job_payload, add_trace_tagging, trace_id, ts_send, ts_f
     payload = json.loads(job_payload)
 
     # Update tracing options based on parameters
+    job_id = "abc12345678910213123"
+    payload["job_id"] = job_id
     payload["tracing_options"] = {"trace": add_trace_tagging, "ts_send": int(ts_send.timestamp() * 1e9)}
     if trace_id is not None:
         payload["tracing_options"]["trace_id"] = trace_id
-    modified_payload = json.dumps(payload)
-    result = process_message(modified_payload, ts_fetched)
+
+    result = process_message(payload, ts_fetched)
 
     # Basic type check for the returned object
     assert isinstance(result, ControlMessage)
 
     # Check for correct handling of tracing options
-    assert result.get_metadata("response_channel") == f"response_{payload['job_id']}"
-    assert result.get_metadata("job_id") == payload["job_id"]
+    assert result.get_metadata("response_channel") == f"response_{job_id}"
     if add_trace_tagging:
         assert result.get_metadata("config::add_trace_tagging") is True
         assert result.get_timestamp(f"trace::entry::{MODULE_NAME}") is not None
diff --git a/tests/nv_ingest/schemas/test_chart_extractor_schema.py b/tests/nv_ingest/schemas/test_chart_extractor_schema.py
new file mode 100644
index 00000000..5deb16b5
--- /dev/null
+++ b/tests/nv_ingest/schemas/test_chart_extractor_schema.py
@@ -0,0 +1,112 @@
+import pytest
+from pydantic import ValidationError
+
+from nv_ingest.schemas.chart_extractor_schema import ChartExtractorConfigSchema
+from nv_ingest.schemas.chart_extractor_schema import ChartExtractorSchema
+
+
+# Test cases for ChartExtractorConfigSchema
+def test_valid_config_with_grpc_only():
+    config = ChartExtractorConfigSchema(
+        auth_token="valid_token",
+        cached_endpoints=("grpc://cached_service", None),
+        deplot_endpoints=("grpc://deplot_service", None),
+        paddle_endpoints=("grpc://paddle_service", None)
+    )
+    assert config.auth_token == "valid_token"
+    assert 
config.cached_endpoints == ("grpc://cached_service", None) + assert config.deplot_endpoints == ("grpc://deplot_service", None) + assert config.paddle_endpoints == ("grpc://paddle_service", None) + +def test_valid_config_with_http_only(): + config = ChartExtractorConfigSchema( + auth_token="valid_token", + cached_endpoints=(None, "http://cached_service"), + deplot_endpoints=(None, "http://deplot_service"), + paddle_endpoints=(None, "http://paddle_service") + ) + assert config.auth_token == "valid_token" + assert config.cached_endpoints == (None, "http://cached_service") + assert config.deplot_endpoints == (None, "http://deplot_service") + assert config.paddle_endpoints == (None, "http://paddle_service") + +def test_invalid_config_with_empty_services(): + with pytest.raises(ValidationError) as excinfo: + ChartExtractorConfigSchema( + cached_endpoints=(None, None), + deplot_endpoints=(None, None), + paddle_endpoints=(None, None) + ) + assert "Both gRPC and HTTP services cannot be empty" in str(excinfo.value) + +def test_valid_config_with_both_grpc_and_http(): + config = ChartExtractorConfigSchema( + auth_token="another_token", + cached_endpoints=("grpc://cached_service", "http://cached_service"), + deplot_endpoints=("grpc://deplot_service", "http://deplot_service"), + paddle_endpoints=("grpc://paddle_service", "http://paddle_service") + ) + assert config.auth_token == "another_token" + assert config.cached_endpoints == ("grpc://cached_service", "http://cached_service") + assert config.deplot_endpoints == ("grpc://deplot_service", "http://deplot_service") + assert config.paddle_endpoints == ("grpc://paddle_service", "http://paddle_service") + +def test_invalid_auth_token_none(): + config = ChartExtractorConfigSchema( + cached_endpoints=("grpc://cached_service", None), + deplot_endpoints=("grpc://deplot_service", None), + paddle_endpoints=("grpc://paddle_service", None) + ) + assert config.auth_token is None + +def test_invalid_endpoint_format(): + with pytest.raises(ValidationError): + ChartExtractorConfigSchema( + cached_endpoints=("invalid_endpoint", None), + deplot_endpoints=(None, "invalid_endpoint") + ) + +# Test cases for ChartExtractorSchema +def test_chart_extractor_schema_defaults(): + config = ChartExtractorSchema() + assert config.max_queue_size == 1 + assert config.n_workers == 2 + assert config.raise_on_failure is False + assert config.stage_config is None + +def test_chart_extractor_schema_with_custom_values(): + stage_config = ChartExtractorConfigSchema( + cached_endpoints=("grpc://cached_service", "http://cached_service"), + deplot_endpoints=("grpc://deplot_service", None), + paddle_endpoints=(None, "http://paddle_service") + ) + config = ChartExtractorSchema( + max_queue_size=10, + n_workers=5, + raise_on_failure=True, + stage_config=stage_config + ) + assert config.max_queue_size == 10 + assert config.n_workers == 5 + assert config.raise_on_failure is True + assert config.stage_config == stage_config + +def test_chart_extractor_schema_without_stage_config(): + config = ChartExtractorSchema( + max_queue_size=3, + n_workers=1, + raise_on_failure=False + ) + assert config.max_queue_size == 3 + assert config.n_workers == 1 + assert config.raise_on_failure is False + assert config.stage_config is None + +def test_invalid_chart_extractor_schema_negative_queue_size(): + with pytest.raises(ValidationError): + ChartExtractorSchema( + max_queue_size=-1 + ) + +def test_invalid_chart_extractor_schema_zero_workers(): + with pytest.raises(ValidationError): + ChartExtractorSchema( + 
n_workers=0 + ) \ No newline at end of file diff --git a/tests/nv_ingest/schemas/test_ingest_job_schema.py b/tests/nv_ingest/schemas/test_ingest_job_schema.py index bb570393..4f42867d 100644 --- a/tests/nv_ingest/schemas/test_ingest_job_schema.py +++ b/tests/nv_ingest/schemas/test_ingest_job_schema.py @@ -202,6 +202,18 @@ def test_multiple_task_types(): "params": {"filter": True}, }, }, + { + "type": "table_data_extract", + "task_properties":{ + "params": {}, + } + }, + { + "type": "chart_data_extract", + "task_properties":{ + "params": {}, + } + } ], } diff --git a/tests/nv_ingest/schemas/test_metadata_schema.py b/tests/nv_ingest/schemas/test_metadata_schema.py new file mode 100644 index 00000000..2f1e76f5 --- /dev/null +++ b/tests/nv_ingest/schemas/test_metadata_schema.py @@ -0,0 +1,188 @@ +import pytest +from pydantic import ValidationError +from datetime import datetime +from nv_ingest.schemas.metadata_schema import ( # Adjust the import as per your file structure + SourceMetadataSchema, + NearbyObjectsSchema, + ContentHierarchySchema, + ContentMetadataSchema, + TextMetadataSchema, + ImageMetadataSchema, + TableMetadataSchema, + ChartMetadataSchema, + ErrorMetadataSchema, + InfoMessageMetadataSchema, + TableFormatEnum, +) + + +# Test cases for SourceMetadataSchema +def test_source_metadata_schema_defaults(): + config = SourceMetadataSchema( + source_name="Test Source", + source_id="1234", + source_type="TestType" + ) + assert config.source_location == "" + assert config.collection_id == "" + assert config.partition_id == -1 + assert config.access_level == -1 + + +def test_source_metadata_schema_invalid_date(): + with pytest.raises(ValidationError): + SourceMetadataSchema( + source_name="Test Source", + source_id="1234", + source_type="TestType", + date_created="invalid_date" + ) + + +# Test cases for NearbyObjectsSchema +def test_nearby_objects_schema_defaults(): + config = NearbyObjectsSchema() + assert config.text.content == [] + assert config.images.content == [] + assert config.structured.content == [] + + +# Test cases for ContentHierarchySchema +def test_content_hierarchy_schema_defaults(): + config = ContentHierarchySchema() + assert config.page_count == -1 + assert config.page == -1 + assert config.block == -1 + assert config.line == -1 + assert config.span == -1 + + +def test_content_hierarchy_schema_with_nearby_objects(): + config = ContentHierarchySchema( + nearby_objects=NearbyObjectsSchema( + text={"content": ["sample text"]}, + images={"content": ["sample image"]} + ) + ) + assert config.nearby_objects.text.content == ["sample text"] + assert config.nearby_objects.images.content == ["sample image"] + + +# Test cases for ContentMetadataSchema +def test_content_metadata_schema_defaults(): + config = ContentMetadataSchema(type="text") + print(config) + assert config.description == "" + assert config.page_number == -1 + + +def test_content_metadata_schema_invalid_type(): + with pytest.raises(ValidationError): + ContentMetadataSchema(type="InvalidType") + + +# Test cases for TextMetadataSchema +def test_text_metadata_schema_defaults(): + config = TextMetadataSchema(text_type="document") + assert config.summary == "" + assert config.keywords == "" + assert config.language == "en" + assert config.text_location == (0, 0, 0, 0) + + +def test_text_metadata_schema_with_keywords(): + config = TextMetadataSchema(text_type="body", keywords=["keyword1", "keyword2"]) + assert config.keywords == ["keyword1", "keyword2"] + + +# Test cases for ImageMetadataSchema +def 
test_image_metadata_schema_defaults(): + config = ImageMetadataSchema(image_type="image") + assert config.caption == "" + assert config.width == 0 + assert config.height == 0 + + +def test_image_metadata_schema_invalid_type(): + with pytest.raises(ValidationError): + ImageMetadataSchema(image_type=3.14) # Using a float value + + +# Test cases for TableMetadataSchema +@pytest.mark.parametrize("table_format", ["html", "markdown", "latex", "image"]) +def test_table_metadata_schema_defaults(table_format): + config = TableMetadataSchema(table_format=table_format) + assert config.caption == "" + assert config.table_content == "" + + +def test_table_metadata_schema_with_location(): + config = TableMetadataSchema( + table_format="latex", + table_location=(1, 2, 3, 4) + ) + assert config.table_location == (1, 2, 3, 4) + + +@pytest.mark.parametrize("schema_class", [TableMetadataSchema, ChartMetadataSchema]) +@pytest.mark.parametrize("table_format", + [TableFormatEnum.HTML, TableFormatEnum.MARKDOWN, TableFormatEnum.LATEX, TableFormatEnum.IMAGE]) +def test_schema_valid_table_format(schema_class, table_format): + config = schema_class(table_format=table_format) + assert config.caption == "" + assert config.table_content == "" + + +def test_table_metadata_schema_invalid_table_format(): + with pytest.raises(ValidationError): + TableMetadataSchema(table_format="invalid_format") + + +# Test cases for ChartMetadataSchema +def test_chart_metadata_schema_defaults(): + config = ChartMetadataSchema(table_format="html") + assert config.caption == "" + assert config.table_content == "" + + +# Test cases for ErrorMetadataSchema +def test_error_metadata_schema_defaults(): + config = ErrorMetadataSchema( + task="embed", + status="error", + error_msg="An error occurred." + ) + assert config.source_id == "" + + +def test_error_metadata_schema_invalid_status(): + with pytest.raises(ValidationError): + ErrorMetadataSchema( + task="embed", # valid task, so the invalid status is what fails + status="InvalidStatus", + error_msg="An error occurred." + ) + + +# Test cases for InfoMessageMetadataSchema +def test_info_message_metadata_schema_defaults(): + config = InfoMessageMetadataSchema( + task="transform", + status="success", + message="This is an info message.", + filter=False + ) + assert config.filter is False + + +def test_info_message_metadata_schema_invalid_task(): + with pytest.raises(ValidationError): + InfoMessageMetadataSchema( + task="InvalidTaskType", + status="success", # valid status, so the invalid task is what fails + message="This is an info message."
+ ) diff --git a/tests/nv_ingest/schemas/test_table_extractor_schema.py b/tests/nv_ingest/schemas/test_table_extractor_schema.py new file mode 100644 index 00000000..72d7c057 --- /dev/null +++ b/tests/nv_ingest/schemas/test_table_extractor_schema.py @@ -0,0 +1,129 @@ +import pytest +from pydantic import ValidationError +from nv_ingest.schemas.table_extractor_schema import TableExtractorConfigSchema, \ + TableExtractorSchema + + +# Test cases for TableExtractorConfigSchema +def test_valid_config_with_grpc_only(): + config = TableExtractorConfigSchema( + auth_token="valid_token", + paddle_endpoints=("grpc://paddle_service", None) + ) + assert config.auth_token == "valid_token" + assert config.paddle_endpoints == ("grpc://paddle_service", None) + + +def test_valid_config_with_http_only(): + config = TableExtractorConfigSchema( + auth_token="valid_token", + paddle_endpoints=(None, "http://paddle_service") + ) + assert config.auth_token == "valid_token" + assert config.paddle_endpoints == (None, "http://paddle_service") + + +def test_valid_config_with_both_services(): + config = TableExtractorConfigSchema( + auth_token="valid_token", + paddle_endpoints=("grpc://paddle_service", "http://paddle_service") + ) + assert config.auth_token == "valid_token" + assert config.paddle_endpoints == ("grpc://paddle_service", "http://paddle_service") + + +def test_invalid_config_empty_endpoints(): + with pytest.raises(ValidationError) as exc_info: + TableExtractorConfigSchema( + paddle_endpoints=(None, None) + ) + assert "Both gRPC and HTTP services cannot be empty for paddle_endpoints" in str(exc_info.value) + + +def test_invalid_extra_fields(): + with pytest.raises(ValidationError) as exc_info: + TableExtractorConfigSchema( + auth_token="valid_token", + paddle_endpoints=("grpc://paddle_service", None), + extra_field="invalid" + ) + assert "extra fields not permitted" in str(exc_info.value) + + +def test_cleaning_empty_strings_in_endpoints(): + config = TableExtractorConfigSchema( + paddle_endpoints=(" ", "http://paddle_service") + ) + assert config.paddle_endpoints == (None, "http://paddle_service") + + config = TableExtractorConfigSchema( + paddle_endpoints=("grpc://paddle_service", "") + ) + assert config.paddle_endpoints == ("grpc://paddle_service", None) + + +def test_auth_token_is_none_by_default(): + config = TableExtractorConfigSchema( + paddle_endpoints=("grpc://paddle_service", "http://paddle_service") + ) + assert config.auth_token is None + + +# Test cases for TableExtractorSchema +def test_table_extractor_schema_defaults(): + config = TableExtractorSchema() + assert config.max_queue_size == 1 + assert config.n_workers == 2 + assert config.raise_on_failure is False + assert config.stage_config is None + + +def test_table_extractor_schema_with_custom_values(): + stage_config = TableExtractorConfigSchema( + paddle_endpoints=("grpc://paddle_service", "http://paddle_service") + ) + config = TableExtractorSchema( + max_queue_size=15, + n_workers=12, + raise_on_failure=True, + stage_config=stage_config + ) + assert config.max_queue_size == 15 + assert config.n_workers == 12 + assert config.raise_on_failure is True + assert config.stage_config == stage_config + + +def test_table_extractor_schema_without_stage_config(): + config = TableExtractorSchema( + max_queue_size=20, + n_workers=5, + raise_on_failure=True + ) + assert config.max_queue_size == 20 + assert config.n_workers == 5 + assert config.raise_on_failure is True + assert config.stage_config is None + + +def 
test_invalid_table_extractor_schema_negative_queue_size(): + with pytest.raises(ValidationError): + TableExtractorSchema( + max_queue_size=-5 + ) + + +def test_invalid_table_extractor_schema_zero_workers(): + with pytest.raises(ValidationError): + TableExtractorSchema( + n_workers=0 + ) + + +def test_invalid_extra_fields_in_table_extractor_schema(): + with pytest.raises(ValidationError): + TableExtractorSchema( + max_queue_size=10, + n_workers=5, + extra_field="invalid" + ) diff --git a/tests/nv_ingest/util/image_processing/test_transforms.py b/tests/nv_ingest/util/image_processing/test_transforms.py index b9b7e276..ad86c191 100644 --- a/tests/nv_ingest/util/image_processing/test_transforms.py +++ b/tests/nv_ingest/util/image_processing/test_transforms.py @@ -1,6 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest import numpy as np +from PIL import Image +import base64 +from io import BytesIO +from unittest import mock + +from nv_ingest.util.image_processing.transforms import numpy_to_base64, base64_to_numpy, check_numpy_image_size + + +# Helper function to create a base64-encoded string from an image +def create_base64_image(width, height, color="white"): + img = Image.new('RGB', (width, height), color=color) + buffered = BytesIO() + img.save(buffered, format="PNG") + return base64.b64encode(buffered.getvalue()).decode('utf-8') + + +# Fixture for a valid base64-encoded image string +@pytest.fixture +def valid_base64_image(): + return create_base64_image(64, 64) + -from nv_ingest.util.image_processing.transforms import numpy_to_base64 +# Fixture for a corrupted base64 string +@pytest.fixture +def corrupted_base64_image(): + return "not_a_valid_base64_string" + + +# Fixture for a base64 string that decodes but is not a valid image +@pytest.fixture +def non_image_base64(): + return base64.b64encode(b"This is not an image").decode('utf-8') def test_numpy_to_base64_valid_rgba_image(): @@ -25,3 +60,50 @@ assert isinstance(result, str) assert len(result) > 0 + + +# Tests for base64_to_numpy +def test_base64_to_numpy_valid(valid_base64_image): + img_array = base64_to_numpy(valid_base64_image) + assert isinstance(img_array, np.ndarray) + assert img_array.shape[0] == 64 # Height + assert img_array.shape[1] == 64 # Width + + +def test_base64_to_numpy_invalid_string(corrupted_base64_image): + with pytest.raises(ValueError, match="Invalid base64 string"): + base64_to_numpy(corrupted_base64_image) + + +def test_base64_to_numpy_non_image(non_image_base64): + with pytest.raises(ValueError, match="Unable to decode image from base64 string"): + base64_to_numpy(non_image_base64) + + +def test_base64_to_numpy_import_error(valid_base64_image): + # Simulate an ImportError from PIL by patching PIL.Image.open + with mock.patch("PIL.Image.open", side_effect=ImportError("PIL library not available")): + with pytest.raises(ImportError): + base64_to_numpy(valid_base64_image) + + +# Tests for check_numpy_image_size +def test_check_numpy_image_size_valid(): + img = np.zeros((100, 100, 3), dtype=np.uint8) + assert check_numpy_image_size(img, 50, 50) is True + + +def test_check_numpy_image_size_too_small_height(): + img = np.zeros((40, 100, 3), dtype=np.uint8) # Height less than min + assert check_numpy_image_size(img, 50, 50) is False + + +def test_check_numpy_image_size_too_small_width(): + img = np.zeros((100, 40, 3), dtype=np.uint8) # Width
less than min + assert check_numpy_image_size(img, 50, 50) is False + + +def test_check_numpy_image_size_invalid_dimensions(): + img = np.zeros((100,), dtype=np.uint8) # 1D array + with pytest.raises(ValueError, match="The input array does not have sufficient dimensions for an image."): + check_numpy_image_size(img, 50, 50) diff --git a/tests/stages/__init__.py b/tests/stages/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/stages/nims/__init__.py b/tests/stages/nims/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/stages/nims/test_chart_extraction.py b/tests/stages/nims/test_chart_extraction.py new file mode 100644 index 00000000..8b1d3c5b --- /dev/null +++ b/tests/stages/nims/test_chart_extraction.py @@ -0,0 +1,279 @@ +import pytest +import pandas as pd +from unittest.mock import Mock, patch +from nv_ingest.stages.nim.chart_extraction import _update_metadata, \ + _extract_chart_data +import requests + +MODULE_UNDER_TEST = "nv_ingest.stages.nim.chart_extraction" + + +# Sample data for testing +@pytest.fixture +def base64_encoded_image(): + # Create a simple image and encode it to base64 + from PIL import Image + from io import BytesIO + import base64 + + img = Image.new('RGB', (64, 64), color='white') + buffered = BytesIO() + img.save(buffered, format="PNG") + img_bytes = buffered.getvalue() + base64_str = base64.b64encode(img_bytes).decode('utf-8') + return base64_str + + +@pytest.fixture +def sample_dataframe(base64_encoded_image): + data = { + "metadata": [{ + "content": base64_encoded_image, + "content_metadata": { + "type": "structured", + "subtype": "chart" + }, + "table_metadata": { + "table_content": "original_content" + } + }] + } + df = pd.DataFrame(data) + return df + + +@pytest.fixture +def dataframe_missing_metadata(): + data = { + "other_data": ["no metadata here"] + } + df = pd.DataFrame(data) + return df + + +@pytest.fixture +def dataframe_non_chart(base64_encoded_image): + data = { + "metadata": [{ + "content": base64_encoded_image, + "content_metadata": { + "type": "text", # Not "structured" + "subtype": "paragraph" # Not "chart" + }, + "table_metadata": { + "table_content": "original_content" + } + }] + } + df = pd.DataFrame(data) + return df + + +# Common mock fixtures +@pytest.fixture +def mock_clients_and_requests(): + # Dummy clients as dictionaries with 'endpoint_url' and 'headers' + deplot_client = { + 'endpoint_url': 'http://deplot_endpoint_url', + 'headers': {'Authorization': 'Bearer mock_token'} + } + cached_client = { + 'endpoint_url': 'http://cached_endpoint_url', + 'headers': {'Authorization': 'Bearer mock_token'} + } + + # Mock response for requests.post (successful inference) + mock_response_deplot = Mock() + mock_response_deplot.raise_for_status = Mock() # Does nothing + mock_response_deplot.json.return_value = { + 'object': 'list', + 'data': [{ + 'index': 0, + 'content': 'deplot_result_content', + 'object': 'string' + }], + 'model': 'deplot', + 'usage': None + } + + mock_response_cached = Mock() + mock_response_cached.raise_for_status = Mock() # Does nothing + mock_response_cached.json.return_value = { + 'object': 'list', + 'data': [{ + 'index': 0, + 'content': 'cached_result_content', + 'object': 'string' + }], + 'model': 'cached', + 'usage': None + } + + # Patching create_inference_client and requests.post + with patch(f'{MODULE_UNDER_TEST}.create_inference_client') as mock_create_client, \ + patch('requests.post') as
mock_requests_post: + # Mock create_inference_client to return dummy clients + def side_effect_create_inference_client(endpoints, auth_token, protocol): + if 'deplot' in endpoints[0]: + return deplot_client + elif 'cached' in endpoints[0]: + return cached_client + else: + return None + + mock_create_client.side_effect = side_effect_create_inference_client + + # Mock requests.post to return different responses based on URL + def side_effect_requests_post(url, *args, **kwargs): + if 'deplot' in url: + return mock_response_deplot + elif 'cached' in url: + return mock_response_cached + else: + return Mock() + + mock_requests_post.side_effect = side_effect_requests_post + + yield deplot_client, cached_client, mock_create_client, mock_requests_post + + +@pytest.fixture +def mock_clients_and_requests_failure(): + # Dummy clients as dictionaries with 'endpoint_url' and 'headers' + deplot_client = { + 'endpoint_url': 'http://deplot_endpoint_url', + 'headers': {'Authorization': 'Bearer mock_token'} + } + cached_client = { + 'endpoint_url': 'http://cached_endpoint_url', + 'headers': {'Authorization': 'Bearer mock_token'} + } + + # Mock response for requests.post to raise an HTTPError + mock_response_failure = Mock() + mock_response_failure.raise_for_status.side_effect = requests.exceptions.HTTPError("Inference error") + mock_response_failure.json.return_value = {} + + # Patching create_inference_client and requests.post + with patch(f'{MODULE_UNDER_TEST}.create_inference_client') as mock_create_client, \ + patch('requests.post', return_value=mock_response_failure) as mock_requests_post: + # Mock create_inference_client to return dummy clients + def side_effect_create_inference_client(endpoints, auth_token, protocol): + if 'deplot' in endpoints[0]: + return deplot_client + elif 'cached' in endpoints[0]: + return cached_client + else: + return None + + mock_create_client.side_effect = side_effect_create_inference_client + + yield deplot_client, cached_client, mock_create_client, mock_requests_post + + +# Tests for _update_metadata +def test_update_metadata_missing_metadata(dataframe_missing_metadata, mock_clients_and_requests): + deplot_client, cached_client, _, _ = mock_clients_and_requests + + row = dataframe_missing_metadata.iloc[0] + trace_info = {} + with pytest.raises(ValueError, match="Row does not contain 'metadata'."): + _update_metadata(row, cached_client, deplot_client, trace_info) + + +def test_update_metadata_non_chart_content(dataframe_non_chart, mock_clients_and_requests): + deplot_client, cached_client, _, _ = mock_clients_and_requests + + row = dataframe_non_chart.iloc[0] + trace_info = {} + result = _update_metadata(row, cached_client, deplot_client, trace_info) + # The metadata should remain unchanged + assert result == row["metadata"] + + +@pytest.mark.xfail +def test_update_metadata_successful_update(sample_dataframe, mock_clients_and_requests): + deplot_client, cached_client, _, _ = mock_clients_and_requests + + row = sample_dataframe.iloc[0] + trace_info = {} + result = _update_metadata(row, cached_client, deplot_client, trace_info) + # The table_content should be updated with combined result + expected_content = 'Combined content: cached_result_content + deplot_result_content' + assert result["table_metadata"]["table_content"] == expected_content + + +@pytest.mark.xfail +def test_update_metadata_inference_failure(sample_dataframe, mock_clients_and_requests_failure): + deplot_client, cached_client, _, mock_requests_post = mock_clients_and_requests_failure + + row = 
sample_dataframe.iloc[0] + trace_info = {} + + with pytest.raises(RuntimeError, match="An error occurred during inference: Inference error"): + _update_metadata(row, cached_client, deplot_client, trace_info) + + # Verify that requests.post was called and raised an exception + assert mock_requests_post.call_count >= 1 # At least one call failed + + +@pytest.mark.xfail +def test_extract_chart_data_successful(sample_dataframe, mock_clients_and_requests): + deplot_client, cached_client, mock_create_client, mock_requests_post = mock_clients_and_requests + + validated_config = Mock() + validated_config.stage_config.deplot_endpoints = ("http://deplot_endpoint", None) + validated_config.stage_config.cached_endpoints = ("http://cached_endpoint", None) + validated_config.stage_config.auth_token = "mock_token" + validated_config.stage_config.deplot_infer_protocol = "mock_protocol" + validated_config.stage_config.cached_infer_protocol = "mock_protocol" + + trace_info = {} + + updated_df, trace_info_out = _extract_chart_data(sample_dataframe, {}, validated_config, trace_info) + + # Expected content from the combined results + expected_content = 'Combined content: cached_result_content + deplot_result_content' + assert updated_df.loc[0, 'metadata']['table_metadata']['table_content'] == expected_content + assert trace_info_out == trace_info + + # Verify that the mocked methods were called + assert mock_create_client.call_count == 2 # deplot and cached clients created + assert mock_requests_post.call_count == 2 # deplot and cached inference called + + +def test_extract_chart_data_missing_metadata(dataframe_missing_metadata, mock_clients_and_requests): + deplot_client, cached_client, _, _ = mock_clients_and_requests + + validated_config = Mock() + validated_config.stage_config.deplot_endpoints = ("http://deplot_endpoint", None) + validated_config.stage_config.cached_endpoints = ("http://cached_endpoint", None) + validated_config.stage_config.auth_token = "mock_token" + validated_config.stage_config.deplot_infer_protocol = "mock_protocol" + validated_config.stage_config.cached_infer_protocol = "mock_protocol" + + trace_info = {} + + with pytest.raises(ValueError, match="Row does not contain 'metadata'."): + _extract_chart_data(dataframe_missing_metadata, {}, validated_config, trace_info) + + +@pytest.mark.xfail +def test_extract_chart_data_inference_failure(sample_dataframe, mock_clients_and_requests_failure): + deplot_client, cached_client, mock_create_client, mock_requests_post = mock_clients_and_requests_failure + + validated_config = Mock() + validated_config.stage_config.deplot_endpoints = ("http://deplot_endpoint", None) + validated_config.stage_config.cached_endpoints = ("http://cached_endpoint", None) + validated_config.stage_config.auth_token = "mock_token" + validated_config.stage_config.deplot_infer_protocol = "mock_protocol" + validated_config.stage_config.cached_infer_protocol = "mock_protocol" + + trace_info = {} + + with pytest.raises(RuntimeError, match="An error occurred during inference: Inference error"): + _extract_chart_data(sample_dataframe, {}, validated_config, trace_info) + + # Verify that the mocked methods were called + assert mock_create_client.call_count == 2 + assert mock_requests_post.call_count >= 1 # At least one call failed diff --git a/tests/stages/nims/test_table_extraction.py b/tests/stages/nims/test_table_extraction.py new file mode 100644 index 00000000..6af24097 --- /dev/null +++ b/tests/stages/nims/test_table_extraction.py @@ -0,0 +1,339 @@ +import pytest +import 
pandas as pd +import base64 +import requests +from unittest.mock import Mock, patch +from io import BytesIO +from PIL import Image +from nv_ingest.stages.nim.table_extraction import _update_metadata, _extract_table_data + +# Constants for minimum image size +PADDLE_MIN_WIDTH = 32 +PADDLE_MIN_HEIGHT = 32 + +MODULE_UNDER_TEST = "nv_ingest.stages.nim.table_extraction" + + +# Fixture for common mock setup +@pytest.fixture +def mock_paddle_client_and_requests(): + # Dummy client as a dictionary with 'endpoint_url' and 'headers' + paddle_client = { + 'endpoint_url': 'http://mock_endpoint_url', + 'headers': {'Authorization': 'Bearer mock_token'} + } + + # Mock response for requests.post + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.json.return_value = { + 'object': 'list', + 'data': [{ + 'index': 0, + 'content': ('Chart 1 This chart shows some gadgets, and some very fictitious costs ' + 'Gadgets and their cost $160.00 $140.00 $120.00 $100.00 $80.00 $60.00 ' + '$40.00 $20.00 $- Hammer Powerdrill Bluetooth speaker Minifridge Premium ' + 'desk fan Cost'), + 'object': 'string' + }], + 'model': 'paddleocr', + 'usage': None + } + + # Patching create_inference_client and requests.post + with patch(f'{MODULE_UNDER_TEST}.create_inference_client', return_value=paddle_client) as mock_create_client, \ + patch('requests.post', return_value=mock_response) as mock_requests_post: + yield paddle_client, mock_create_client, mock_requests_post + + +# Fixture for common mock setup (inference failure) +@pytest.fixture +def mock_paddle_client_and_requests_failure(): + # Dummy client as a dictionary with 'endpoint_url' and 'headers' + paddle_client = { + 'endpoint_url': 'http://mock_endpoint_url', + 'headers': {'Authorization': 'Bearer mock_token'} + } + + # Mock response for requests.post to raise an HTTPError + mock_response = Mock() + mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("Inference error") + mock_response.json.return_value = {} + + # Patching create_inference_client and requests.post + with patch(f'{MODULE_UNDER_TEST}.create_inference_client', return_value=paddle_client) as mock_create_client, \ + patch('requests.post', return_value=mock_response) as mock_requests_post: + yield paddle_client, mock_create_client, mock_requests_post + + +# Fixture to create a sample image and encode it in base64 +@pytest.fixture +def base64_encoded_image(): + # Create a simple image using PIL + img = Image.new('RGB', (64, 64), color='white') + buffered = BytesIO() + img.save(buffered, format="PNG") + img_bytes = buffered.getvalue() + # Encode the image to base64 + base64_str = base64.b64encode(img_bytes).decode('utf-8') + return base64_str + + +# Fixture for a small image (below minimum size) +@pytest.fixture +def base64_encoded_small_image(): + img = Image.new('RGB', (16, 16), color='white') # Smaller than minimum size + buffered = BytesIO() + img.save(buffered, format="PNG") + img_bytes = buffered.getvalue() + base64_str = base64.b64encode(img_bytes).decode('utf-8') + return base64_str + + +# Fixture for a sample DataFrame +@pytest.fixture +def sample_dataframe(base64_encoded_image): + data = { + "metadata": [{ + "content": base64_encoded_image, + "content_metadata": { + "type": "structured", + "subtype": "table" + }, + "table_metadata": { + "table_content": "" + } + }] + } + df = pd.DataFrame(data) + return df + + +# Fixture for DataFrame with missing metadata +@pytest.fixture +def dataframe_missing_metadata(): + data = { + "other_data": ["no metadata here"] 
+ } + df = pd.DataFrame(data) + return df + + +# Fixture for DataFrame where content_metadata doesn't meet conditions +@pytest.fixture +def dataframe_non_table(base64_encoded_image): + data = { + "metadata": [{ + "content": base64_encoded_image, + "content_metadata": { + "type": "text", # Not "structured" + "subtype": "paragraph" # Not "table" + }, + "table_metadata": { + "table_content": "" + } + }] + } + df = pd.DataFrame(data) + return df + + +# Dummy paddle client that simulates the external service +class DummyPaddleClient: + def infer(self, *args, **kwargs): + return "{'object': 'list', 'data': [{'index': 0, 'content': 'Chart 1 This chart shows some gadgets, and some very fictitious costs Gadgets and their cost $160.00 $140.00 $120.00 $100.00 $80.00 $60.00 $40.00 $20.00 $- Hammer Powerdrill Bluetooth speaker Minifridge Premium desk fan Cost', 'object': 'string'}], 'model': 'paddleocr', 'usage': None}" + + def close(self): + pass + + +# Tests for _update_metadata +def test_update_metadata_missing_metadata(): + row = pd.Series({ + "other_data": "not metadata" + }) + paddle_client = DummyPaddleClient() + trace_info = {} + with pytest.raises(ValueError, match="Row does not contain 'metadata'."): + _update_metadata(row, paddle_client, trace_info) + + +def test_update_metadata_non_table_content(dataframe_non_table): + row = dataframe_non_table.iloc[0] + paddle_client = DummyPaddleClient() + trace_info = {} + result = _update_metadata(row, paddle_client, trace_info) + # The metadata should remain unchanged + assert result == row["metadata"] + + +def test_update_metadata_image_too_small(base64_encoded_small_image): + row = pd.Series({ + "metadata": { + "content": base64_encoded_small_image, + "content_metadata": { + "type": "structured", + "subtype": "table" + }, + "table_metadata": { + "table_content": "" + } + } + }) + paddle_client = DummyPaddleClient() + trace_info = {} + result = _update_metadata(row, paddle_client, trace_info) + # Since the image is too small, table_content should remain unchanged + assert result["table_metadata"]["table_content"] == "" + + +def test_update_metadata_successful_update(sample_dataframe, mock_paddle_client_and_requests): + paddle_client, mock_create_client, mock_requests_post = mock_paddle_client_and_requests + + row = sample_dataframe.iloc[0] + trace_info = {} + result = _update_metadata(row, paddle_client, trace_info) + + # Expected content from the mocked response + expected_content = ('Chart 1 This chart shows some gadgets, and some very fictitious costs ' + 'Gadgets and their cost $160.00 $140.00 $120.00 $100.00 $80.00 $60.00 ' + '$40.00 $20.00 $- Hammer Powerdrill Bluetooth speaker Minifridge Premium ' + 'desk fan Cost') + + # The table_content should be updated with expected_content + assert result["table_metadata"]["table_content"] == expected_content + + # Verify that requests.post was called + mock_requests_post.assert_called_once() + + +def test_update_metadata_inference_failure(sample_dataframe, mock_paddle_client_and_requests_failure): + paddle_client, mock_create_client, mock_requests_post = mock_paddle_client_and_requests_failure + + row = sample_dataframe.iloc[0] + trace_info = {} + + with pytest.raises(RuntimeError, match="HTTP request failed: Inference error"): + _update_metadata(row, paddle_client, trace_info) + + # Verify that requests.post was called and raised an exception + mock_requests_post.assert_called_once() + + +# Tests for _extract_table_data +def test_extract_table_data_successful(sample_dataframe, 
mock_paddle_client_and_requests): + paddle_client, mock_create_client, mock_requests_post = mock_paddle_client_and_requests + + validated_config = Mock() + validated_config.stage_config.paddle_endpoints = "mock_endpoint" + validated_config.stage_config.auth_token = "mock_token" + validated_config.stage_config.paddle_infer_protocol = "mock_protocol" + + trace_info = {} + + updated_df, trace_info_out = _extract_table_data(sample_dataframe, {}, validated_config, trace_info) + + # Expected content from the mocked response + expected_content = ('Chart 1 This chart shows some gadgets, and some very fictitious costs ' + 'Gadgets and their cost $160.00 $140.00 $120.00 $100.00 $80.00 $60.00 ' + '$40.00 $20.00 $- Hammer Powerdrill Bluetooth speaker Minifridge Premium ' + 'desk fan Cost') + assert updated_df.loc[0, 'metadata']['table_metadata']['table_content'] == expected_content + assert trace_info_out == trace_info + + # Verify that the mocked methods were called + mock_create_client.assert_called_once() + mock_requests_post.assert_called_once() + + +def test_extract_table_data_missing_metadata(dataframe_missing_metadata, mock_paddle_client_and_requests): + paddle_client, mock_create_client, mock_requests_post = mock_paddle_client_and_requests + + validated_config = Mock() + validated_config.stage_config.paddle_endpoints = "mock_endpoint" + validated_config.stage_config.auth_token = "mock_token" + validated_config.stage_config.paddle_infer_protocol = "mock_protocol" + + trace_info = {} + + with pytest.raises(ValueError, match="Row does not contain 'metadata'."): + _extract_table_data(dataframe_missing_metadata, {}, validated_config, trace_info) + + # Verify that the mocked methods were called + mock_create_client.assert_called_once() + # Since metadata is missing, requests.post should not be called + mock_requests_post.assert_not_called() + + +def test_extract_table_data_inference_failure(sample_dataframe, mock_paddle_client_and_requests_failure): + paddle_client, mock_create_client, mock_requests_post = mock_paddle_client_and_requests_failure + + validated_config = Mock() + validated_config.stage_config.paddle_endpoints = "mock_endpoint" + validated_config.stage_config.auth_token = "mock_token" + validated_config.stage_config.paddle_infer_protocol = "mock_protocol" + + trace_info = {} + + with pytest.raises(RuntimeError, match="HTTP request failed: Inference error"): + _extract_table_data(sample_dataframe, {}, validated_config, trace_info) + + # Verify that create_inference_client was called + mock_create_client.assert_called_once() + # Verify that requests.post was called and raised an exception + mock_requests_post.assert_called_once() + + +def test_extract_table_data_image_too_small(base64_encoded_small_image): + data = { + "metadata": [{ + "content": base64_encoded_small_image, + "content_metadata": { + "type": "structured", + "subtype": "table" + }, + "table_metadata": { + "table_content": "" + } + }] + } + df = pd.DataFrame(data) + + validated_config = Mock() + validated_config.stage_config.paddle_endpoints = "mock_endpoint" + validated_config.stage_config.auth_token = "mock_token" + validated_config.stage_config.paddle_infer_protocol = "mock_protocol" + + # Dummy client as a dictionary with 'endpoint_url' and 'headers' + paddle_client = { + 'endpoint_url': 'http://mock_endpoint_url', + 'headers': {'Authorization': 'Bearer mock_token'} + } + trace_info = {} + + def mock_create_inference_client(endpoints, auth_token, protocol): + return paddle_client + + # Mock response to simulate 
requests.post behavior + mock_response = Mock() + mock_response.raise_for_status = Mock() # Does nothing + mock_response.json.return_value = { + 'object': 'list', + 'data': [{ + 'index': 0, + 'content': ('Chart 1 This chart shows some gadgets, and some very fictitious costs ' + 'Gadgets and their cost $160.00 $140.00 $120.00 $100.00 $80.00 $60.00 ' + '$40.00 $20.00 $- Hammer Powerdrill Bluetooth speaker Minifridge Premium ' + 'desk fan Cost'), + 'object': 'string' + }], + 'model': 'paddleocr', + 'usage': None + } + + with patch(f'{MODULE_UNDER_TEST}.create_inference_client', side_effect=mock_create_inference_client), \ + patch('requests.post', return_value=mock_response): + updated_df, _ = _extract_table_data(df, {}, validated_config, trace_info) + + # The table_content should remain unchanged because the image is too small + assert updated_df.loc[0, 'metadata']['table_metadata']['table_content'] == ""
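Note: the stage tests above all share one mocking seam: create_inference_client is patched to hand back a plain dict standing in for a client, and requests.post is patched to return a Mock whose .json() yields a canned payload. A minimal, self-contained sketch of that pattern follows (illustrative only; call_paddle is a hypothetical stand-in for the HTTP call the nim stages make, not an nv_ingest API, and the payload shape is assumed for the example):

import requests
from unittest.mock import Mock, patch


def call_paddle(endpoint_url, headers, payload):
    # Hypothetical stand-in for a stage's HTTP call; the real request and
    # response schemas belong to the nim services, not to this sketch.
    response = requests.post(endpoint_url, headers=headers, json=payload)
    response.raise_for_status()
    return response.json()["data"][0]["content"]


def test_call_paddle_with_canned_response():
    # Same pattern as the fixtures above: a Mock response with a no-op
    # raise_for_status and a canned .json() payload, patched over requests.post.
    mock_response = Mock()
    mock_response.raise_for_status = Mock()
    mock_response.json.return_value = {"data": [{"index": 0, "content": "ocr text"}]}
    with patch("requests.post", return_value=mock_response) as mock_post:
        content = call_paddle("http://mock_endpoint", {"Authorization": "Bearer token"}, {})
    assert content == "ocr text"
    mock_post.assert_called_once()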