Refactor Yolox model interface and inference processing (#332)
edknv authored Jan 16, 2025
1 parent 5edceb5 commit 9ef8e0b
Showing 7 changed files with 147 additions and 272 deletions.
17 changes: 1 addition & 16 deletions src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
```diff
@@ -34,7 +34,6 @@
 from nv_ingest.util.image_processing.transforms import crop_image
 from nv_ingest.util.image_processing.transforms import numpy_to_base64
 from nv_ingest.util.nim.helpers import create_inference_client
-from nv_ingest.util.nim.helpers import get_version
 from nv_ingest.util.pdf.metadata_aggregators import Base64Image
 from nv_ingest.util.pdf.metadata_aggregators import CroppedImageWithContent
 from nv_ingest.util.pdf.metadata_aggregators import construct_image_metadata_from_pdf_image
@@ -64,22 +63,8 @@ def extract_tables_and_charts_using_image_ensemble(
 ) -> List[Tuple[int, object]]:  # List[Tuple[int, CroppedImageWithContent]]
     tables_and_charts = []
 
-    # Obtain yolox_version
-    # Assuming that the grpc endpoint is at index 0
-    yolox_http_endpoint = config.yolox_endpoints[1]
-    try:
-        yolox_version = get_version(yolox_http_endpoint)
-        if not yolox_version:
-            logger.warning(
-                "Failed to obtain yolox-page-elements version from the endpoint. Falling back to the latest version."
-            )
-            yolox_version = None  # Default to the latest version
-    except Exception:
-        logger.waring("Failed to get yolox-page-elements version after 30 seconds. Falling back to the latest version.")
-        yolox_version = None  # Default to the latest version
-
     try:
-        model_interface = yolox_utils.YoloxPageElementsModelInterface(yolox_version=yolox_version)
+        model_interface = yolox_utils.YoloxPageElementsModelInterface()
        yolox_client = create_inference_client(
             config.yolox_endpoints, model_interface, config.auth_token, config.yolox_infer_protocol
         )
```
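With the version probe gone, wiring up the YOLOX client reduces to two calls. A minimal sketch of the resulting call path, assuming the `yolox_utils` import alias used in this module and treating `config` as a placeholder for the extractor's config object:

```python
from nv_ingest.util.nim import yolox as yolox_utils  # assumed import path
from nv_ingest.util.nim.helpers import create_inference_client

# No get_version() round trip: the interface no longer takes a yolox_version.
model_interface = yolox_utils.YoloxPageElementsModelInterface()

# config is a placeholder for the extractor's config object seen in the hunk;
# yolox_endpoints is a (grpc_endpoint, http_endpoint) pair.
yolox_client = create_inference_client(
    config.yolox_endpoints,
    model_interface,
    config.auth_token,
    config.yolox_infer_protocol,  # "grpc" or "http"
)
```

This also removes a failure mode: the old code blocked on an HTTP version query before any inference could start, and its exception path logged through a misspelled `logger.waring` call.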
2 changes: 1 addition & 1 deletion src/nv_ingest/util/nim/cached.py
```diff
@@ -119,7 +119,7 @@ def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, An
         else:
             raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
 
-    def process_inference_results(self, output: Any, **kwargs) -> Any:
+    def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Any:
         """
         Process inference results for the Cached model.
```
2 changes: 1 addition & 1 deletion src/nv_ingest/util/nim/deplot.py
```diff
@@ -133,7 +133,7 @@ def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, An
         else:
             raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
 
-    def process_inference_results(self, output: Any, **kwargs) -> Any:
+    def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Any:
         """
         Process inference results for the Deplot model.
```
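The same one-line change lands in both `cached.py` and `deplot.py`: `process_inference_results` now receives the transport protocol alongside the raw output. The commit shows only the signatures, so the body below is a hypothetical illustration of why the parameter is useful, not the actual Cached/Deplot logic:

```python
from typing import Any


class ExampleModelInterface:
    """Hypothetical interface showing the new protocol-aware signature."""

    def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Any:
        if protocol == "grpc":
            # gRPC responses often arrive as raw arrays that still need decoding.
            return self._decode_grpc_output(output)
        elif protocol == "http":
            # HTTP responses are typically parsed from JSON already.
            return output
        raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

    def _decode_grpc_output(self, output: Any) -> Any:
        # Stand-in for model-specific tensor decoding.
        return output
```

Passing the protocol explicitly lets each interface keep grpc- and http-specific post-processing in one place instead of guessing from the output's shape.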
4 changes: 2 additions & 2 deletions src/nv_ingest/util/nim/helpers.py
```diff
@@ -75,7 +75,7 @@ def prepare_data_for_inference(self, data: dict):
         """
         raise NotImplementedError("Subclasses should implement this method")
 
-    def process_inference_results(self, output_array, **kwargs):
+    def process_inference_results(self, output_array, protocol: str, **kwargs):
         """
         Process the inference results from the model.
@@ -206,7 +206,7 @@ def infer(self, data: dict, model_name: str, **kwargs) -> Any:
             response, protocol=self.protocol, data=prepared_data, **kwargs
         )
         results = self.model_interface.process_inference_results(
-            parsed_output, original_image_shapes=data.get("original_image_shapes"), **kwargs
+            parsed_output, original_image_shapes=data.get("original_image_shapes"), protocol=self.protocol, **kwargs
         )
         return results
```
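Upstream of the model interfaces, the `helpers.py` hunks show where the protocol comes from: the base class gains the parameter, and `infer` forwards `self.protocol` at the call site. A condensed sketch of the updated `infer` flow, assembled from the visible context (`_issue_request` is a hypothetical stand-in for the request logic outside the hunk):

```python
from typing import Any


def infer(self, data: dict, model_name: str, **kwargs) -> Any:
    # Condensed from the hunk above; error handling and the actual
    # grpc/http request logic are elided.
    prepared_data = self.model_interface.prepare_data_for_inference(data)
    response = self._issue_request(prepared_data, model_name)  # hypothetical stand-in

    parsed_output = self.model_interface.parse_output(
        response, protocol=self.protocol, data=prepared_data, **kwargs
    )
    # New in this commit: the client's transport protocol is forwarded so
    # each model interface can post-process grpc vs. http output correctly.
    results = self.model_interface.process_inference_results(
        parsed_output,
        original_image_shapes=data.get("original_image_shapes"),
        protocol=self.protocol,
        **kwargs,
    )
    return results
```

Threading the protocol through the client, rather than having each interface re-detect it, keeps the dispatch decision in one place.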
(Diffs for the remaining 3 changed files are not shown.)
