# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Union
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import numpy as np
import pandas as pd
from nv_ingest_api.internal.schemas.extract.extract_chart_schema import ChartExtractorSchema
from nv_ingest_api.internal.schemas.meta.ingest_job_schema import IngestTaskChartExtraction
from nv_ingest_api.util.image_processing.table_and_chart import join_yolox_graphic_elements_and_ocr_output
from nv_ingest_api.util.image_processing.table_and_chart import process_yolox_graphic_elements
from nv_ingest_api.internal.primitives.nim.model_interface.ocr import OCRModelInterface
from nv_ingest_api.internal.primitives.nim.model_interface.ocr import get_ocr_model_name
from nv_ingest_api.internal.primitives.nim import NimClient
from nv_ingest_api.internal.primitives.nim.model_interface.yolox import YoloxGraphicElementsModelInterface
from nv_ingest_api.util.image_processing.transforms import base64_to_numpy
from nv_ingest_api.util.nim import create_inference_client
logger = logging.getLogger(f"ray.{__name__}")

# Minimum pixel dimensions an image must reach on BOTH axes (width >= PADDLE_MIN_WIDTH and
# height >= PADDLE_MIN_HEIGHT) before it is submitted for chart extraction; smaller images
# are skipped and reported with no chart content.
PADDLE_MIN_WIDTH = 32
PADDLE_MIN_HEIGHT = 32
def _filter_valid_chart_images(
base64_images: List[str],
) -> Tuple[List[str], List[np.ndarray], List[int], List[Tuple[str, Optional[Dict]]]]:
"""
Filter base64-encoded images based on minimum dimensions for chart extraction.
Returns:
- valid_images: Base64 strings meeting size requirements.
- valid_arrays: Corresponding numpy arrays.
- valid_indices: Original indices of valid images.
- results: Initial results list where invalid images are set to (img, None).
"""
results: List[Tuple[str, Optional[Dict]]] = [("", None)] * len(base64_images)
valid_images: List[str] = []
valid_arrays: List[np.ndarray] = []
valid_indices: List[int] = []
for i, img in enumerate(base64_images):
array = base64_to_numpy(img)
height, width = array.shape[0], array.shape[1]
if width >= PADDLE_MIN_WIDTH and height >= PADDLE_MIN_HEIGHT:
valid_images.append(img)
valid_arrays.append(array)
valid_indices.append(i)
else:
# Image is too small; mark as skipped.
results[i] = (img, None)
return valid_images, valid_arrays, valid_indices, results
def _run_chart_inference(
yolox_client: Any,
ocr_client: Any,
ocr_model_name: str,
valid_arrays: List[np.ndarray],
valid_images: List[str],
trace_info: Dict,
) -> Tuple[List[Any], List[Any]]:
"""
Run concurrent inference for chart extraction using YOLOX and Paddle.
Returns a tuple of (yolox_results, ocr_results).
"""
data_yolox = {"images": valid_arrays}
data_ocr = {"base64_images": valid_images}
future_yolox_kwargs = dict(
data=data_yolox,
model_name="yolox_ensemble",
stage_name="chart_extraction",
input_names=["INPUT_IMAGES", "THRESHOLDS"],
dtypes=["BYTES", "FP32"],
output_names=["OUTPUT"],
trace_info=trace_info,
max_batch_size=8,
)
future_ocr_kwargs = dict(
data=data_ocr,
stage_name="chart_extraction",
max_batch_size=1 if ocr_client.protocol == "grpc" else 2,
trace_info=trace_info,
)
if ocr_model_name == "paddle":
future_ocr_kwargs.update(
model_name="paddle",
)
elif ocr_model_name == "scene_text":
future_ocr_kwargs.update(
model_name=ocr_model_name,
input_names=["input", "merge_levels"],
dtypes=["FP32", "BYTES"],
merge_level="paragraph",
)
elif ocr_model_name == "scene_text_ensemble":
future_ocr_kwargs.update(
model_name=ocr_model_name,
input_names=["INPUT_IMAGE_URLS", "MERGE_LEVELS"],
output_names=["OUTPUT"],
dtypes=["BYTES", "BYTES"],
merge_level="paragraph",
)
else:
raise ValueError(f"Unknown OCR model name: {ocr_model_name}")
with ThreadPoolExecutor(max_workers=2) as executor:
future_yolox = executor.submit(yolox_client.infer, **future_yolox_kwargs)
future_ocr = executor.submit(ocr_client.infer, **future_ocr_kwargs)
try:
yolox_results = future_yolox.result()
except Exception as e:
logger.error(f"Error calling yolox_client.infer: {e}", exc_info=True)
raise
try:
ocr_results = future_ocr.result()
except Exception as e:
logger.error(f"Error calling ocr_client.infer: {e}", exc_info=True)
raise
return yolox_results, ocr_results
def _validate_chart_inference_results(
yolox_results: Any, ocr_results: Any, valid_arrays: List[Any], valid_images: List[str]
) -> Tuple[List[Any], List[Any]]:
"""
Ensure inference results are lists and have expected lengths.
Raises:
ValueError if results do not match expected types or lengths.
"""
if not (isinstance(yolox_results, list) and isinstance(ocr_results, list)):
raise ValueError("Expected list results from both yolox_client and ocr_client infer calls.")
if len(yolox_results) != len(valid_arrays):
raise ValueError(f"Expected {len(valid_arrays)} yolox results, got {len(yolox_results)}")
if len(ocr_results) != len(valid_images):
raise ValueError(f"Expected {len(valid_images)} ocr results, got {len(ocr_results)}")
return yolox_results, ocr_results
def _merge_chart_results(
base64_images: List[str],
valid_indices: List[int],
yolox_results: List[Any],
ocr_results: List[Any],
initial_results: List[Tuple[str, Optional[Dict]]],
) -> List[Tuple[str, Optional[Dict]]]:
"""
Merge inference results into the initial results list using the original indices.
For each valid image, processes the results from both inference calls and updates the
corresponding entry in the results list.
"""
for idx, (yolox_res, ocr_res) in enumerate(zip(yolox_results, ocr_results)):
# Unpack ocr result into bounding boxes and text predictions.
bounding_boxes, text_predictions, _ = ocr_res
yolox_elements = join_yolox_graphic_elements_and_ocr_output(yolox_res, bounding_boxes, text_predictions)
chart_content = process_yolox_graphic_elements(yolox_elements)
original_index = valid_indices[idx]
initial_results[original_index] = (base64_images[original_index], chart_content)
return initial_results
def _update_chart_metadata(
    base64_images: List[str],
    yolox_client: Any,
    ocr_client: Any,
    ocr_model_name: str,
    trace_info: Dict,
    worker_pool_size: int = 8,  # Not currently used.
) -> List[Tuple[str, Optional[Dict]]]:
    """
    Extract chart content from base64-encoded images using YOLOX graphic-element detection and OCR.

    Processing steps:
    1. Filter the images by minimum size.
    2. Call the YOLOX and OCR clients concurrently for the images that passed the filter.
    3. Validate both outputs.
    4. Merge the joined chart content back by each image's original position.

    Parameters:
        base64_images: base64-encoded chart images.
        yolox_client: client whose ``infer`` runs the YOLOX graphic-elements model.
        ocr_client: client whose ``infer`` runs the OCR model.
        ocr_model_name: one of "paddle", "scene_text", "scene_text_ensemble".
        trace_info: forwarded to both inference calls.
        worker_pool_size: accepted for backward compatibility; not used.

    Returns:
        One ``(original_image_str, chart_content)`` tuple per input image, in input order.
        Images that do not meet the minimum size requirements get ``None`` as chart content.

    Raises:
        ValueError: for an unknown ``ocr_model_name`` or malformed inference results. Both
            checks happen only when at least one image passes the size filter.
        Exceptions raised by either client's ``infer`` propagate to the caller.
    """
    logger.debug("Running chart extraction using updated concurrency handling.")

    # Initialize results with placeholders and filter valid images.
    valid_images, valid_arrays, valid_indices, results = _filter_valid_chart_images(base64_images)

    # Nothing passed the size filter (or the input was empty). Every entry in `results` is
    # already the skipped marker (img, None), so don't send empty batches to either service.
    if not valid_images:
        return results

    # Run concurrent inference only for valid images.
    yolox_results, ocr_results = _run_chart_inference(
        yolox_client=yolox_client,
        ocr_client=ocr_client,
        ocr_model_name=ocr_model_name,
        valid_arrays=valid_arrays,
        valid_images=valid_images,
        trace_info=trace_info,
    )

    # Validate that the returned inference results are lists of the expected length.
    yolox_results, ocr_results = _validate_chart_inference_results(
        yolox_results, ocr_results, valid_arrays, valid_images
    )

    # Merge the inference results into the results list.
    return _merge_chart_results(base64_images, valid_indices, yolox_results, ocr_results, results)
def _create_clients(
    yolox_endpoints: Tuple[str, str],
    yolox_protocol: str,
    ocr_endpoints: Tuple[str, str],
    ocr_protocol: str,
    auth_token: str,
) -> Tuple[NimClient, NimClient]:
    """
    Build the inference clients used for chart extraction.

    Each client is created with ``create_inference_client``. The YOLOX client is paired with a
    ``YoloxGraphicElementsModelInterface`` and the OCR client with an ``OCRModelInterface``.
    Both share the same ``auth_token``.

    Returns:
        ``(yolox_client, ocr_client)``.
    """
    yolox_interface = YoloxGraphicElementsModelInterface()
    ocr_interface = OCRModelInterface()

    logger.debug(f"Inference protocols: yolox={yolox_protocol}, ocr={ocr_protocol}")

    def build(endpoints, interface, protocol):
        # Only the endpoints, model interface and protocol differ between the two clients.
        return create_inference_client(
            endpoints=endpoints,
            model_interface=interface,
            auth_token=auth_token,
            infer_protocol=protocol,
        )

    yolox_client = build(yolox_endpoints, yolox_interface, yolox_protocol)
    ocr_client = build(ocr_endpoints, ocr_interface, ocr_protocol)
    return yolox_client, ocr_client