# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import pandas as pd
from nv_ingest_api.internal.primitives.nim import NimClient
from nv_ingest_api.internal.primitives.nim.model_interface.ocr import OCRModelInterface
from nv_ingest_api.internal.primitives.nim.model_interface.ocr import get_ocr_model_name
from nv_ingest_api.internal.schemas.extract.extract_infographic_schema import (
InfographicExtractorSchema,
)
from nv_ingest_api.util.image_processing.transforms import base64_to_numpy
from nv_ingest_api.util.nim import create_inference_client
from nv_ingest_api.util.image_processing.table_and_chart import reorder_boxes
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Minimum width/height an image must have to be sent to OCR. They are compared
# against the decoded array's shape[1] (width) and shape[0] (height) in
# _filter_infographic_images; smaller images are marked as skipped instead.
PADDLE_MIN_WIDTH = 32
PADDLE_MIN_HEIGHT = 32
def _filter_infographic_images(
    base64_images: List[str],
) -> Tuple[List[str], List[int], List[Tuple[str, Optional[Any], Optional[Any]]]]:
    """
    Split base64-encoded images into those large enough for OCR and those to skip.

    Each image is decoded with ``base64_to_numpy`` and its height/width are
    compared against ``PADDLE_MIN_HEIGHT`` / ``PADDLE_MIN_WIDTH``.

    Parameters
    ----------
    base64_images : List[str]
        List of base64-encoded image strings.

    Returns
    -------
    Tuple[List[str], List[int], List[Tuple[str, Optional[Any], Optional[Any]]]]
        - valid_images: images that meet the minimum size, in input order.
        - valid_indices: positions of those images in ``base64_images``.
        - results: one slot per input image. Slots of undersized images hold
          ``(img, None, None)``; slots of valid images still hold the
          placeholder ``("", None, None)`` and are meant to be filled in by
          the caller.
    """
    placeholder: Tuple[str, Optional[Any], Optional[Any]] = ("", None, None)
    results: List[Tuple[str, Optional[Any], Optional[Any]]] = [placeholder] * len(base64_images)
    accepted: List[str] = []
    accepted_positions: List[int] = []

    for position, encoded in enumerate(base64_images):
        dims = base64_to_numpy(encoded).shape
        img_height = dims[0]
        img_width = dims[1]

        if img_width < PADDLE_MIN_WIDTH or img_height < PADDLE_MIN_HEIGHT:
            # Too small for OCR: record the image as skipped and move on.
            results[position] = (encoded, None, None)
            continue

        accepted.append(encoded)
        accepted_positions.append(position)

    return accepted, accepted_positions, results
def _update_infographic_metadata(
    base64_images: List[str],
    ocr_client: NimClient,
    ocr_model_name: str,
    worker_pool_size: int = 8,  # Not currently used
    trace_info: Optional[Dict] = None,
) -> List[Tuple[str, Optional[Any], Optional[Any]]]:
    """
    Filters base64-encoded images and uses OCR to extract infographic data.

    Images that meet the minimum size are sent to ``ocr_client.infer`` in a
    single call; images that do not are marked as skipped. When no image
    passes the size filter, the OCR service is not called at all.

    Parameters
    ----------
    base64_images : List[str]
        List of base64-encoded images.
    ocr_client : NimClient
        Client instance for OCR inference.
    ocr_model_name : str
        Name of the OCR model to invoke. Supported values are ``"paddle"``,
        ``"scene_text"`` and ``"scene_text_ensemble"``; the value selects the
        extra keyword arguments passed to ``ocr_client.infer``.
    worker_pool_size : int, optional
        Worker pool size (currently not used), by default 8. Accepted only so
        existing callers keep working.
    trace_info : Optional[Dict], optional
        Optional trace information for debugging; forwarded to ``ocr_client.infer``.

    Returns
    -------
    List[Tuple[str, Optional[Any], Optional[Any]]]
        List of tuples in the same order as base64_images. For a processed
        image the tuple is ``(base64_image, ocr_res[0], ocr_res[1])``, i.e. the
        first two elements of the OCR result (after ``reorder_boxes`` for
        non-paddle models). For a skipped image it is ``(base64_image, None, None)``.

    Raises
    ------
    ValueError
        If ``ocr_model_name`` is not supported, or if the number of OCR results
        does not match the number of images sent.
    Exception
        Any error raised by ``ocr_client.infer`` is logged and re-raised unchanged.
    """
    logger.debug("Running infographic extraction using protocol %s", ocr_client.protocol)

    valid_images, valid_indices, results = _filter_infographic_images(base64_images)

    infer_kwargs = dict(
        stage_name="infographic_extraction",
        max_batch_size=1 if ocr_client.protocol == "grpc" else 2,
        trace_info=trace_info,
    )
    if ocr_model_name == "paddle":
        infer_kwargs.update(
            model_name="paddle",
        )
    elif ocr_model_name == "scene_text":
        infer_kwargs.update(
            model_name=ocr_model_name,
            input_names=["input", "merge_levels"],
            dtypes=["FP32", "BYTES"],
            merge_level="paragraph",
        )
    elif ocr_model_name == "scene_text_ensemble":
        infer_kwargs.update(
            model_name=ocr_model_name,
            input_names=["INPUT_IMAGE_URLS", "MERGE_LEVELS"],
            output_names=["OUTPUT"],
            dtypes=["BYTES", "BYTES"],
            merge_level="paragraph",
        )
    else:
        raise ValueError(f"Unknown OCR model name: {ocr_model_name}")

    # Nothing passed the size filter: every slot in `results` is already marked
    # as skipped, so avoid a pointless (and possibly failing) empty OCR request.
    if not valid_images:
        return results

    try:
        ocr_results = ocr_client.infer({"base64_images": valid_images}, **infer_kwargs)
    except Exception as e:
        logger.error("Error calling ocr_client.infer: %s", e, exc_info=True)
        raise

    if len(ocr_results) != len(valid_images):
        raise ValueError(f"Expected {len(valid_images)} ocr results, got {len(ocr_results)}")

    is_paddle = ocr_model_name == "paddle"
    for original_index, ocr_res in zip(valid_indices, ocr_results):
        if is_paddle:
            # Log the image position rather than the full base64 payload, which
            # would otherwise bloat the log with the entire encoded image.
            logger.debug("OCR results for image at index %d: %s", original_index, ocr_res)
        else:
            # Each ocr_res is expected to be a tuple (text_predictions, bounding_boxes, conf_scores).
            ocr_res = reorder_boxes(*ocr_res)

        results[original_index] = (base64_images[original_index], ocr_res[0], ocr_res[1])

    return results
def _create_clients(
    ocr_endpoints: Tuple[str, str],
    ocr_protocol: str,
    auth_token: str,
) -> NimClient:
    """
    Build the inference client used for OCR.

    Parameters
    ----------
    ocr_endpoints : Tuple[str, str]
        Pair of endpoint strings passed through to ``create_inference_client``
        (presumably the gRPC and HTTP endpoints -- confirm against that helper).
    ocr_protocol : str
        Inference protocol to use; forwarded as ``infer_protocol``.
    auth_token : str
        Authentication token forwarded to the client factory.

    Returns
    -------
    NimClient
        Client created by ``create_inference_client`` with an ``OCRModelInterface``.
    """
    interface = OCRModelInterface()
    logger.debug(f"Inference protocols: ocr={ocr_protocol}")

    return create_inference_client(
        endpoints=ocr_endpoints,
        model_interface=interface,
        auth_token=auth_token,
        infer_protocol=ocr_protocol,
    )
def _meets_infographic_criteria(row: pd.Series) -> bool:
    """
    Determines if a DataFrame row meets the criteria for infographic extraction.

    A row qualifies if:
      - It contains a non-empty 'metadata' dictionary.
      - The 'content_metadata' in metadata is a dictionary with type
        "structured" and subtype "infographic".
      - The 'table_metadata' in metadata is not None.
      - The 'content' in metadata is not None or an empty string.

    Rows whose 'metadata' (or 'content_metadata') is missing or is not a
    dictionary -- for example NaN or None cells -- simply do not qualify; no
    exception is raised for them.

    Parameters
    ----------
    row : pd.Series
        A row from the DataFrame.

    Returns
    -------
    bool
        True if the row meets all criteria; False otherwise.
    """
    metadata = row.get("metadata", {})
    # The type check comes first: NaN is truthy, so a bare truthiness test
    # would let it through to the `.get` calls below.
    if not isinstance(metadata, dict) or not metadata:
        return False

    content_md = metadata.get("content_metadata")
    # The key may be absent or explicitly set to None; neither can qualify.
    if not isinstance(content_md, dict):
        return False

    return (
        content_md.get("type") == "structured"
        and content_md.get("subtype") == "infographic"
        and metadata.get("table_metadata") is not None
        and metadata.get("content") not in (None, "")
    )