Source code for nv_ingest_api.internal.primitives.nim.nim_client

# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import logging
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any
from typing import Optional
from typing import Tuple, Union

import numpy as np
import requests
import tritonclient.grpc as grpcclient

from nv_ingest_api.internal.primitives.tracing.tagging import traceable_func
from nv_ingest_api.util.string_processing import generate_url

logger = logging.getLogger(__name__)


class NimClient:
    """
    A client for interfacing with a model inference server using gRPC or HTTP protocols.
    """

    def __init__(
        self,
        model_interface,
        protocol: str,
        endpoints: Tuple[str, str],
        auth_token: Optional[str] = None,
        timeout: float = 120.0,
        max_retries: int = 5,
        max_429_retries: int = 5,
    ):
        """
        Initialize the NimClient with the specified model interface, protocol, and server endpoints.

        Parameters
        ----------
        model_interface : ModelInterface
            The model interface implementation to use.
        protocol : str
            The protocol to use ("grpc" or "http").
        endpoints : tuple
            A tuple containing the gRPC and HTTP endpoints.
        auth_token : str, optional
            Authorization token for HTTP requests (default: None).
        timeout : float, optional
            Timeout for HTTP requests in seconds (default: 120.0).
        max_retries : int, optional
            The maximum number of retries for non-429 server-side errors (default: 5).
        max_429_retries : int, optional
            The maximum number of retries specifically for 429 errors (default: 5).

        Raises
        ------
        ValueError
            If an invalid protocol is specified or if required endpoints are missing.
        """
        self.client = None
        self.model_interface = model_interface
        self.protocol = protocol.lower()
        self.auth_token = auth_token
        self.timeout = timeout  # Timeout for HTTP requests
        self.max_retries = max_retries
        self.max_429_retries = max_429_retries
        self._grpc_endpoint, self._http_endpoint = endpoints
        self._max_batch_sizes = {}
        self._lock = threading.Lock()

        if self.protocol == "grpc":
            if not self._grpc_endpoint:
                raise ValueError("gRPC endpoint must be provided for gRPC protocol")
            logger.debug(f"Creating gRPC client with {self._grpc_endpoint}")
            self.client = grpcclient.InferenceServerClient(url=self._grpc_endpoint)
        elif self.protocol == "http":
            if not self._http_endpoint:
                raise ValueError("HTTP endpoint must be provided for HTTP protocol")
            logger.debug(f"Creating HTTP client with {self._http_endpoint}")
            self.endpoint_url = generate_url(self._http_endpoint)
            self.headers = {"accept": "application/json", "content-type": "application/json"}
            if self.auth_token:
                self.headers["Authorization"] = f"Bearer {self.auth_token}"
        else:
            raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

    def _fetch_max_batch_size(self, model_name, model_version: str = "") -> int:
        """Fetch the maximum batch size from the Triton model configuration in a thread-safe manner."""
        if model_name == "yolox_ensemble":
            model_name = "yolox"

        if model_name in self._max_batch_sizes:
            return self._max_batch_sizes[model_name]

        with self._lock:
            # Double check, just in case another thread set the value while we were waiting
            if model_name in self._max_batch_sizes:
                return self._max_batch_sizes[model_name]

            if not self._grpc_endpoint:
                self._max_batch_sizes[model_name] = 1
                return 1

            try:
                client = self.client if self.client else grpcclient.InferenceServerClient(url=self._grpc_endpoint)
                model_config = client.get_model_config(model_name=model_name, model_version=model_version)
                self._max_batch_sizes[model_name] = model_config.config.max_batch_size
                logger.debug(f"Max batch size for model '{model_name}': {self._max_batch_sizes[model_name]}")
            except Exception as e:
                self._max_batch_sizes[model_name] = 1
                logger.warning(f"Failed to retrieve max batch size: {e}, defaulting to 1")

            return self._max_batch_sizes[model_name]

    def _process_batch(self, batch_input, *, batch_data, model_name, **kwargs):
        """
        Process a single batch input for inference using its corresponding batch_data.

        Parameters
        ----------
        batch_input : Any
            The input data for this batch.
        batch_data : Any
            The corresponding scratch-pad data for this batch as returned by format_input.
        model_name : str
            The model name for inference.
        kwargs : dict
            Additional parameters.

        Returns
        -------
        tuple
            A tuple (parsed_output, batch_data) for subsequent post-processing.
        """
        if self.protocol == "grpc":
            logger.debug("Performing gRPC inference for a batch...")
            response = self._grpc_infer(batch_input, model_name, **kwargs)
            logger.debug("gRPC inference received response for a batch")
        elif self.protocol == "http":
            logger.debug("Performing HTTP inference for a batch...")
            response = self._http_infer(batch_input)
            logger.debug("HTTP inference received response for a batch")
        else:
            raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

        parsed_output = self.model_interface.parse_output(
            response, protocol=self.protocol, data=batch_data, model_name=model_name, **kwargs
        )
        return parsed_output, batch_data
    def try_set_max_batch_size(self, model_name, model_version: str = ""):
        """Attempt to set the max batch size for the model if it is not already set, ensuring thread safety."""
        self._fetch_max_batch_size(model_name, model_version)
    @traceable_func(trace_name="{stage_name}::{model_name}")
    def infer(self, data: dict, model_name: str, **kwargs) -> Any:
        """
        Perform inference using the specified model and input data.

        Parameters
        ----------
        data : dict
            The input data for inference.
        model_name : str
            The model name.
        kwargs : dict
            Additional parameters for inference.

        Returns
        -------
        Any
            The processed inference results, coalesced in the same order as the input images.
        """
        try:
            # 1. Retrieve or default to the model's maximum batch size.
            batch_size = self._fetch_max_batch_size(model_name)
            max_requested_batch_size = kwargs.pop("max_batch_size", batch_size)
            force_requested_batch_size = kwargs.pop("force_max_batch_size", False)
            max_batch_size = (
                max(1, min(batch_size, max_requested_batch_size))
                if not force_requested_batch_size
                else max_requested_batch_size
            )

            # 2. Prepare data for inference.
            data = self.model_interface.prepare_data_for_inference(data)

            # 3. Format the input based on protocol.
            formatted_batches, formatted_batch_data = self.model_interface.format_input(
                data,
                protocol=self.protocol,
                max_batch_size=max_batch_size,
                model_name=model_name,
                **kwargs,
            )

            # Check for a custom maximum pool worker count, and remove it from kwargs.
            max_pool_workers = kwargs.pop("max_pool_workers", 16)

            # 4. Process each batch concurrently using a thread pool.
            #    We enumerate the batches so that we can later reassemble results in order.
            results = [None] * len(formatted_batches)
            with ThreadPoolExecutor(max_workers=max_pool_workers) as executor:
                future_to_idx = {}
                for idx, (batch, batch_data) in enumerate(zip(formatted_batches, formatted_batch_data)):
                    future = executor.submit(
                        self._process_batch, batch, batch_data=batch_data, model_name=model_name, **kwargs
                    )
                    future_to_idx[future] = idx

                for future in as_completed(future_to_idx.keys()):
                    idx = future_to_idx[future]
                    results[idx] = future.result()

            # 5. Process the parsed outputs for each batch using its corresponding batch_data.
            #    As the batches are in order, we coalesce their outputs accordingly.
            all_results = []
            for parsed_output, batch_data in results:
                batch_results = self.model_interface.process_inference_results(
                    parsed_output,
                    original_image_shapes=batch_data.get("original_image_shapes"),
                    protocol=self.protocol,
                    **kwargs,
                )
                if isinstance(batch_results, list):
                    all_results.extend(batch_results)
                else:
                    all_results.append(batch_results)

        except Exception as err:
            error_str = f"Error during NimClient inference [{self.model_interface.name()}, {self.protocol}]: {err}"
            logger.error(error_str)
            raise RuntimeError(error_str)

        return all_results
    def _grpc_infer(
        self, formatted_input: Union[list, list[np.ndarray]], model_name: str, **kwargs
    ) -> Union[list, list[np.ndarray]]:
        """
        Perform inference using the gRPC protocol.

        Parameters
        ----------
        formatted_input : np.ndarray or list of np.ndarray
            The input data formatted as a numpy array or a list of numpy arrays.
        model_name : str
            The name of the model to use for inference.

        Returns
        -------
        np.ndarray or list of np.ndarray
            The output of the model as a numpy array, or a list of numpy arrays when multiple
            outputs are requested.
        """
        if not isinstance(formatted_input, list):
            formatted_input = [formatted_input]

        parameters = kwargs.get("parameters", {})
        output_names = kwargs.get("output_names", ["output"])
        dtypes = kwargs.get("dtypes", ["FP32"])
        input_names = kwargs.get("input_names", ["input"])

        input_tensors = []
        for input_name, input_data, dtype in zip(input_names, formatted_input, dtypes):
            input_tensors.append(grpcclient.InferInput(input_name, input_data.shape, datatype=dtype))

        for idx, input_data in enumerate(formatted_input):
            input_tensors[idx].set_data_from_numpy(input_data)

        outputs = [grpcclient.InferRequestedOutput(output_name) for output_name in output_names]
        response = self.client.infer(
            model_name=model_name, parameters=parameters, inputs=input_tensors, outputs=outputs
        )
        logger.debug(f"gRPC inference response: {response}")

        if len(outputs) == 1:
            return response.as_numpy(outputs[0].name())
        else:
            return [response.as_numpy(output.name()) for output in outputs]

    def _http_infer(self, formatted_input: dict) -> dict:
        """
        Perform inference using the HTTP protocol, retrying timeouts and 5xx errors up to max_retries times.

        Parameters
        ----------
        formatted_input : dict
            The input data formatted as a dictionary.

        Returns
        -------
        dict
            The output of the model as a dictionary.

        Raises
        ------
        TimeoutError
            If the HTTP request times out repeatedly, up to the max retries.
        requests.RequestException
            For other HTTP-related errors that persist after max retries.
        """
        base_delay = 2.0
        attempt = 0
        retries_429 = 0

        while attempt < self.max_retries:
            try:
                response = requests.post(
                    self.endpoint_url, json=formatted_input, headers=self.headers, timeout=self.timeout
                )
                status_code = response.status_code

                # Check for server-side or rate-limit type errors
                # e.g. 5xx => server error, 429 => too many requests
                if status_code == 429:
                    retries_429 += 1
                    logger.warning(
                        f"Received HTTP 429 (Too Many Requests) from {self.model_interface.name()}. "
                        f"Attempt {retries_429} of {self.max_429_retries}."
                    )
                    if retries_429 >= self.max_429_retries:
                        logger.error("Max retries for HTTP 429 exceeded.")
                        response.raise_for_status()
                    else:
                        backoff_time = base_delay * (2**retries_429)
                        time.sleep(backoff_time)
                        continue  # Retry without incrementing the main attempt counter

                if status_code == 503 or (500 <= status_code < 600):
                    logger.warning(
                        f"Received HTTP {status_code} ({response.reason}) from "
                        f"{self.model_interface.name()}. Attempt {attempt + 1} of {self.max_retries}."
                    )
                    if attempt == self.max_retries - 1:
                        # No more retries left
                        logger.error(f"Max retries exceeded after receiving HTTP {status_code}.")
                        response.raise_for_status()  # raise the appropriate HTTPError
                    else:
                        # Exponential backoff
                        backoff_time = base_delay * (2**attempt)
                        time.sleep(backoff_time)
                        attempt += 1
                        continue
                else:
                    # Not in our "retry" category => just raise_for_status or return
                    response.raise_for_status()
                    logger.debug(f"HTTP inference response: {response.json()}")
                    return response.json()

            except requests.Timeout:
                # Treat timeouts similarly to 5xx => attempt a retry
                logger.warning(
                    f"HTTP request timed out after {self.timeout} seconds during {self.model_interface.name()} "
                    f"inference. Attempt {attempt + 1} of {self.max_retries}."
                )
                if attempt == self.max_retries - 1:
                    logger.error("Max retries exceeded after repeated timeouts.")
                    raise TimeoutError(
                        f"Repeated timeouts for {self.model_interface.name()} after {attempt + 1} attempts."
                    )
                # Exponential backoff
                backoff_time = base_delay * (2**attempt)
                time.sleep(backoff_time)
                attempt += 1

            except requests.HTTPError as http_err:
                # If we ended up here, it's a non-retryable 4xx or final 5xx after final attempt
                logger.error(f"HTTP request failed with status code {response.status_code}: {http_err}")
                raise

            except requests.RequestException as e:
                # ConnectionError or other non-HTTPError
                logger.error(f"HTTP request encountered a network issue: {e}")
                if attempt == self.max_retries - 1:
                    raise
                # Else retry on next loop iteration
                backoff_time = base_delay * (2**attempt)
                time.sleep(backoff_time)
                attempt += 1

        # If we exit the loop without returning, we've exhausted all attempts
        logger.error(f"Failed to get a successful response after {self.max_retries} retries.")
        raise Exception(f"Failed to get a successful response after {self.max_retries} retries.")
    def close(self):
        if self.protocol == "grpc" and hasattr(self.client, "close"):
            self.client.close()
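
For reference, a minimal usage sketch follows. It is not part of the module source: the model interface class, endpoint addresses, model name, and the layout of the input dict are placeholders that depend on the concrete ModelInterface implementation in use.

# Hypothetical usage sketch -- SomeModelInterface, the endpoint addresses, the
# model name, and the "images" key below are placeholders, not values defined
# by this module.
from nv_ingest_api.internal.primitives.nim.nim_client import NimClient

model_interface = SomeModelInterface()  # any concrete ModelInterface implementation

client = NimClient(
    model_interface=model_interface,
    protocol="http",  # or "grpc"
    endpoints=("localhost:8001", "http://localhost:8000/v1/infer"),  # (grpc, http)
    auth_token=None,
    timeout=120.0,
)
try:
    # The input dict must match what the chosen model interface expects in
    # prepare_data_for_inference / format_input.
    results = client.infer(data={"images": images}, model_name="my_model")
finally:
    client.close()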