# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Optional, Tuple, Union
import numpy as np
import requests
import tritonclient.grpc as grpcclient
from nv_ingest_api.internal.primitives.tracing.tagging import traceable_func
from nv_ingest_api.util.string_processing import generate_url
logger = logging.getLogger(__name__)
class NimClient:
"""
A client for interfacing with a model inference server using gRPC or HTTP protocols.
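Examples
--------
Illustrative sketch only: ``YoloxModelInterface`` and the endpoint addresses are hypothetical placeholders, and the expected keys in the ``data`` dict are defined by the concrete model interface.
>>> client = NimClient(
...     model_interface=YoloxModelInterface(),
...     protocol="grpc",
...     endpoints=("localhost:8001", "http://localhost:8000/v1/infer"),
... )
>>> results = client.infer({"images": images}, model_name="yolox")
>>> client.close()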
"""
def __init__(
self,
model_interface,
protocol: str,
endpoints: Tuple[str, str],
auth_token: Optional[str] = None,
timeout: float = 120.0,
max_retries: int = 5,
max_429_retries: int = 5,
):
"""
Initialize the NimClient with the specified model interface, protocol, and server endpoints.
Parameters
----------
model_interface : ModelInterface
The model interface implementation to use.
protocol : str
The protocol to use ("grpc" or "http").
endpoints : tuple
A tuple containing the gRPC and HTTP endpoints.
auth_token : str, optional
Authorization token for HTTP requests (default: None).
timeout : float, optional
Timeout for HTTP requests in seconds (default: 120.0).
max_retries : int, optional
The maximum number of retries for non-429 server-side errors (default: 5).
max_429_retries : int, optional
The maximum number of retries specifically for HTTP 429 errors (default: 5).
Raises
------
ValueError
If an invalid protocol is specified or if required endpoints are missing.
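Examples
--------
Hypothetical endpoint values; only the endpoint that matches the chosen protocol has to be provided.
>>> endpoints = ("localhost:8001", "http://localhost:8000/v1/infer")
>>> client = NimClient(model_interface, "grpc", endpoints, auth_token=None, timeout=120.0)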
"""
self.client = None
self.model_interface = model_interface
self.protocol = protocol.lower()
self.auth_token = auth_token
self.timeout = timeout # Timeout for HTTP requests
self.max_retries = max_retries
self.max_429_retries = max_429_retries
self._grpc_endpoint, self._http_endpoint = endpoints
self._max_batch_sizes = {}
self._lock = threading.Lock()
if self.protocol == "grpc":
if not self._grpc_endpoint:
raise ValueError("gRPC endpoint must be provided for gRPC protocol")
logger.debug(f"Creating gRPC client with {self._grpc_endpoint}")
self.client = grpcclient.InferenceServerClient(url=self._grpc_endpoint)
elif self.protocol == "http":
if not self._http_endpoint:
raise ValueError("HTTP endpoint must be provided for HTTP protocol")
logger.debug(f"Creating HTTP client with {self._http_endpoint}")
self.endpoint_url = generate_url(self._http_endpoint)
self.headers = {"accept": "application/json", "content-type": "application/json"}
if self.auth_token:
self.headers["Authorization"] = f"Bearer {self.auth_token}"
else:
raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
def _fetch_max_batch_size(self, model_name, model_version: str = "") -> int:
"""Fetch the maximum batch size from the Triton model configuration in a thread-safe manner."""
if model_name == "yolox_ensemble":
model_name = "yolox"
if model_name in self._max_batch_sizes:
return self._max_batch_sizes[model_name]
with self._lock:
# Double check, just in case another thread set the value while we were waiting
if model_name in self._max_batch_sizes:
return self._max_batch_sizes[model_name]
if not self._grpc_endpoint:
self._max_batch_sizes[model_name] = 1
return 1
try:
client = self.client if self.client else grpcclient.InferenceServerClient(url=self._grpc_endpoint)
model_config = client.get_model_config(model_name=model_name, model_version=model_version)
self._max_batch_sizes[model_name] = model_config.config.max_batch_size
logger.debug(f"Max batch size for model '{model_name}': {self._max_batch_sizes[model_name]}")
except Exception as e:
self._max_batch_sizes[model_name] = 1
logger.warning(f"Failed to retrieve max batch size: {e}, defaulting to 1")
return self._max_batch_sizes[model_name]
def _process_batch(self, batch_input, *, batch_data, model_name, **kwargs):
"""
Process a single batch input for inference using its corresponding batch_data.
Parameters
----------
batch_input : Any
The input data for this batch.
batch_data : Any
The corresponding scratch-pad data for this batch as returned by format_input.
model_name : str
The model name for inference.
kwargs : dict
Additional parameters.
Returns
-------
tuple
A tuple (parsed_output, batch_data) for subsequent post-processing.
"""
if self.protocol == "grpc":
logger.debug("Performing gRPC inference for a batch...")
response = self._grpc_infer(batch_input, model_name, **kwargs)
logger.debug("gRPC inference received response for a batch")
elif self.protocol == "http":
logger.debug("Performing HTTP inference for a batch...")
response = self._http_infer(batch_input)
logger.debug("HTTP inference received response for a batch")
else:
raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
parsed_output = self.model_interface.parse_output(
response, protocol=self.protocol, data=batch_data, model_name=model_name, **kwargs
)
return parsed_output, batch_data
def try_set_max_batch_size(self, model_name, model_version: str = ""):
"""Attempt to set the max batch size for the model if it is not already set, ensuring thread safety."""
self._fetch_max_batch_size(model_name, model_version)
@traceable_func(trace_name="{stage_name}::{model_name}")
def infer(self, data: dict, model_name: str, **kwargs) -> Any:
"""
Perform inference using the specified model and input data.
Parameters
----------
data : dict
The input data for inference.
model_name : str
The model name.
kwargs : dict
Additional parameters for inference. ``max_batch_size``, ``force_max_batch_size``, and ``max_pool_workers`` are consumed here; all remaining kwargs are forwarded to the model interface.
Returns
-------
Any
The processed inference results, coalesced in the same order as the input images.
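Examples
--------
Minimal sketch; the required keys in ``data`` are defined by the concrete model interface, and ``"yolox"`` is shown purely as an illustration.
>>> results = client.infer({"images": images}, model_name="yolox", max_pool_workers=8)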
"""
try:
# 1. Retrieve or default to the model's maximum batch size.
batch_size = self._fetch_max_batch_size(model_name)
max_requested_batch_size = kwargs.pop("max_batch_size", batch_size)
force_requested_batch_size = kwargs.pop("force_max_batch_size", False)
max_batch_size = (
max(1, min(batch_size, max_requested_batch_size))
if not force_requested_batch_size
else max_requested_batch_size
)
# 2. Prepare data for inference.
data = self.model_interface.prepare_data_for_inference(data)
# 3. Format the input based on protocol.
formatted_batches, formatted_batch_data = self.model_interface.format_input(
data,
protocol=self.protocol,
max_batch_size=max_batch_size,
model_name=model_name,
**kwargs,
)
# Check for a custom maximum pool worker count, and remove it from kwargs.
max_pool_workers = kwargs.pop("max_pool_workers", 16)
# 4. Process each batch concurrently using a thread pool.
# We enumerate the batches so that we can later reassemble results in order.
results = [None] * len(formatted_batches)
with ThreadPoolExecutor(max_workers=max_pool_workers) as executor:
future_to_idx = {}
for idx, (batch, batch_data) in enumerate(zip(formatted_batches, formatted_batch_data)):
future = executor.submit(
self._process_batch, batch, batch_data=batch_data, model_name=model_name, **kwargs
)
future_to_idx[future] = idx
for future in as_completed(future_to_idx.keys()):
idx = future_to_idx[future]
results[idx] = future.result()
# 5. Process the parsed outputs for each batch using its corresponding batch_data.
# As the batches are in order, we coalesce their outputs accordingly.
all_results = []
for parsed_output, batch_data in results:
batch_results = self.model_interface.process_inference_results(
parsed_output,
original_image_shapes=batch_data.get("original_image_shapes"),
protocol=self.protocol,
**kwargs,
)
if isinstance(batch_results, list):
all_results.extend(batch_results)
else:
all_results.append(batch_results)
except Exception as err:
error_str = f"Error during NimClient inference [{self.model_interface.name()}, {self.protocol}]: {err}"
logger.error(error_str)
raise RuntimeError(error_str) from err
return all_results
def _grpc_infer(
self, formatted_input: Union[np.ndarray, list[np.ndarray]], model_name: str, **kwargs
) -> Union[np.ndarray, list[np.ndarray]]:
"""
Perform inference using the gRPC protocol.
Parameters
----------
formatted_input : np.ndarray or list of np.ndarray
The input data formatted as a numpy array, or a list of numpy arrays when the model takes multiple inputs.
model_name : str
The name of the model to use for inference.
kwargs : dict
Optional overrides: ``input_names``, ``dtypes``, ``output_names``, and ``parameters``.
Returns
-------
np.ndarray or list of np.ndarray
The model output as a numpy array, or a list of numpy arrays when multiple outputs are requested.
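Examples
--------
Minimal sketch; assumes ``batch`` is a numpy array already shaped for the model, and the kwargs shown mirror the defaults used when they are omitted.
>>> output = client._grpc_infer(batch, "yolox", input_names=["input"], dtypes=["FP32"], output_names=["output"])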
"""
if not isinstance(formatted_input, list):
formatted_input = [formatted_input]
parameters = kwargs.get("parameters", {})
output_names = kwargs.get("output_names", ["output"])
dtypes = kwargs.get("dtypes", ["FP32"])
input_names = kwargs.get("input_names", ["input"])
input_tensors = []
for input_name, input_data, dtype in zip(input_names, formatted_input, dtypes):
input_tensors.append(grpcclient.InferInput(input_name, input_data.shape, datatype=dtype))
for idx, input_data in enumerate(formatted_input):
input_tensors[idx].set_data_from_numpy(input_data)
outputs = [grpcclient.InferRequestedOutput(output_name) for output_name in output_names]
response = self.client.infer(
model_name=model_name, parameters=parameters, inputs=input_tensors, outputs=outputs
)
logger.debug(f"gRPC inference response: {response}")
if len(outputs) == 1:
return response.as_numpy(outputs[0].name())
else:
return [response.as_numpy(output.name()) for output in outputs]
def _http_infer(self, formatted_input: dict) -> dict:
"""
Perform inference using the HTTP protocol, retrying timeouts and 5xx responses up to ``max_retries`` times and HTTP 429 responses up to ``max_429_retries`` times, with exponential backoff.
Parameters
----------
formatted_input : dict
The input data formatted as a dictionary.
Returns
-------
dict
The output of the model as a dictionary.
Raises
------
TimeoutError
If the HTTP request times out repeatedly, up to the max retries.
requests.RequestException
For other HTTP-related errors that persist after max retries.
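Notes
-----
Back-off sketch with the hard-coded ``base_delay`` of 2.0 seconds: timeouts and 5xx responses are retried after sleeping 2, 4, 8, and 16 seconds across the default ``max_retries`` of 5 attempts, while HTTP 429 responses are retried after sleeping 4, 8, 16, and 32 seconds across the default ``max_429_retries`` of 5.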
"""
base_delay = 2.0
attempt = 0
retries_429 = 0
while attempt < self.max_retries:
try:
response = requests.post(
self.endpoint_url, json=formatted_input, headers=self.headers, timeout=self.timeout
)
status_code = response.status_code
# Check for server-side or rate-limit type errors
# e.g. 5xx => server error, 429 => too many requests
if status_code == 429:
retries_429 += 1
logger.warning(
f"Received HTTP 429 (Too Many Requests) from {self.model_interface.name()}. "
f"Attempt {retries_429} of {self.max_429_retries}."
)
if retries_429 >= self.max_429_retries:
logger.error("Max retries for HTTP 429 exceeded.")
response.raise_for_status()
else:
backoff_time = base_delay * (2**retries_429)
time.sleep(backoff_time)
continue # Retry without incrementing the main attempt counter
if 500 <= status_code < 600:
logger.warning(
f"Received HTTP {status_code} ({response.reason}) from "
f"{self.model_interface.name()}. Attempt {attempt + 1} of {self.max_retries}."
)
if attempt == self.max_retries - 1:
# No more retries left
logger.error(f"Max retries exceeded after receiving HTTP {status_code}.")
response.raise_for_status() # raise the appropriate HTTPError
else:
# Exponential backoff
backoff_time = base_delay * (2**attempt)
time.sleep(backoff_time)
attempt += 1
continue
else:
# Not in our "retry" category => just raise_for_status or return
response.raise_for_status()
logger.debug(f"HTTP inference response: {response.json()}")
return response.json()
except requests.Timeout:
# Treat timeouts similarly to 5xx => attempt a retry
logger.warning(
f"HTTP request timed out after {self.timeout} seconds during {self.model_interface.name()} "
f"inference. Attempt {attempt + 1} of {self.max_retries}."
)
if attempt == self.max_retries - 1:
logger.error("Max retries exceeded after repeated timeouts.")
raise TimeoutError(
f"Repeated timeouts for {self.model_interface.name()} after {attempt + 1} attempts."
)
# Exponential backoff
backoff_time = base_delay * (2**attempt)
time.sleep(backoff_time)
attempt += 1
except requests.HTTPError as http_err:
# If we ended up here, it's a non-retryable 4xx or final 5xx after final attempt
logger.error(f"HTTP request failed with status code {response.status_code}: {http_err}")
raise
except requests.RequestException as e:
# ConnectionError or other non-HTTPError
logger.error(f"HTTP request encountered a network issue: {e}")
if attempt == self.max_retries - 1:
raise
# Else retry on next loop iteration
backoff_time = base_delay * (2**attempt)
time.sleep(backoff_time)
attempt += 1
# If we exit the loop without returning, we've exhausted all attempts
logger.error(f"Failed to get a successful response after {self.max_retries} retries.")
raise Exception(f"Failed to get a successful response after {self.max_retries} retries.")
def close(self):
"""Close the underlying gRPC client, if one was created."""
if self.protocol == "grpc" and hasattr(self.client, "close"):
self.client.close()