Source code for nv_ingest_client.primitives.tasks.embed

# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# pylint: disable=too-few-public-methods
# pylint: disable=too-many-arguments

import logging
from typing import Any
from typing import Dict
from typing import Optional

from nv_ingest_api.internal.schemas.meta.ingest_job_schema import IngestTaskEmbedSchema

from .task_base import Task

logger = logging.getLogger(__name__)


class EmbedTask(Task):
    """
    Client-side description of an ``embed`` task.

    Collects the embedding configuration (endpoint, model, credentials, error
    filtering flag and optional per-element-type modality settings), validates
    it through ``IngestTaskEmbedSchema``, and serializes it via ``to_dict``.
    """

    def __init__(
        self,
        endpoint_url: Optional[str] = None,
        model_name: Optional[str] = None,
        api_key: Optional[str] = None,
        text: Optional[bool] = None,
        tables: Optional[bool] = None,
        filter_errors: bool = False,
        text_elements_modality: Optional[str] = None,
        image_elements_modality: Optional[str] = None,
        structured_elements_modality: Optional[str] = None,
        audio_elements_modality: Optional[str] = None,
    ) -> None:
        """
        Validate and store the embedding configuration.

        Parameters
        ----------
        endpoint_url : Optional[str], optional
            URL of the embedding endpoint. Defaults to None.
        model_name : Optional[str], optional
            Name of the embedding model. Defaults to None.
        api_key : Optional[str], optional
            API key for the embedding service. Defaults to None.
        text : Optional[bool], optional
            Deprecated and ignored; a warning is logged when it is not None.
        tables : Optional[bool], optional
            Deprecated and ignored; a warning is logged when it is not None.
        filter_errors : bool, optional
            Flag indicating whether errors should be filtered. Defaults to False.
        text_elements_modality : Optional[str], optional
            Modality setting for text elements. Defaults to None.
        image_elements_modality : Optional[str], optional
            Modality setting for image elements. Defaults to None.
        structured_elements_modality : Optional[str], optional
            Modality setting for structured elements. Defaults to None.
        audio_elements_modality : Optional[str], optional
            Modality setting for audio elements. Defaults to None.

        Notes
        -----
        All non-deprecated arguments are passed through ``IngestTaskEmbedSchema``;
        the values stored on the instance are the schema's validated values.
        The set of accepted modality strings is defined by that schema
        (not visible here).
        """
        super().__init__()

        # The deprecated flags are still accepted so old call sites keep
        # working, but they have no effect beyond emitting a warning.
        deprecated_flags = (
            (
                text,
                "'text' parameter is deprecated and will be ignored. Future versions will remove this argument.",
            ),
            (
                tables,
                "'tables' parameter is deprecated and will be ignored. Future versions will remove this argument.",
            ),
        )
        for flag_value, warning_message in deprecated_flags:
            if flag_value is not None:
                logger.warning(warning_message)

        # Delegate validation to the shared API schema and keep its results.
        schema = IngestTaskEmbedSchema(
            endpoint_url=endpoint_url,
            model_name=model_name,
            api_key=api_key,
            filter_errors=filter_errors,
            text_elements_modality=text_elements_modality,
            image_elements_modality=image_elements_modality,
            structured_elements_modality=structured_elements_modality,
            audio_elements_modality=audio_elements_modality,
        )

        self._endpoint_url = schema.endpoint_url
        self._model_name = schema.model_name
        self._api_key = schema.api_key
        self._filter_errors = schema.filter_errors
        self._text_elements_modality = schema.text_elements_modality
        self._image_elements_modality = schema.image_elements_modality
        self._structured_elements_modality = schema.structured_elements_modality
        self._audio_elements_modality = schema.audio_elements_modality

    def _configured_modalities(self) -> Dict[str, Any]:
        """
        Return the truthy modality settings keyed by their property name.

        The keys appear in a fixed order: text, image, structured, audio.
        Settings that are None or otherwise falsy are omitted.
        """
        candidates = {
            "text_elements_modality": self._text_elements_modality,
            "image_elements_modality": self._image_elements_modality,
            "structured_elements_modality": self._structured_elements_modality,
            "audio_elements_modality": self._audio_elements_modality,
        }
        return {name: setting for name, setting in candidates.items() if setting}

    def __str__(self) -> str:
        """
        Render a human-readable summary of the task configuration.

        The output starts with an "Embed Task:" header line. It then lists the
        endpoint URL and model name when set, and a "[redacted]" placeholder
        instead of the real API key when one is set. The filter_errors flag is
        always listed, followed by each configured element modality. Every line
        ends with a newline.

        Returns
        -------
        str
            The multi-line summary.
        """
        lines = ["Embed Task:"]
        if self._endpoint_url:
            lines.append(f" endpoint_url: {self._endpoint_url}")
        if self._model_name:
            lines.append(f" model_name: {self._model_name}")
        if self._api_key:
            # Never echo the secret itself.
            lines.append(" api_key: [redacted]")
        lines.append(f" filter_errors: {self._filter_errors}")
        for name, setting in self._configured_modalities().items():
            lines.append(f" {name}: {setting}")
        return "".join(f"{line}\n" for line in lines)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize the task into the dictionary used for job submission.

        Returns
        -------
        Dict[str, Any]
            ``{"type": "embed", "task_properties": {...}}``.
            ``task_properties`` always contains ``filter_errors``. It also
            contains ``endpoint_url``, ``model_name``, ``api_key`` and each
            element modality, but only when that value is truthy. Unlike
            ``__str__``, the API key is included unredacted.
        """
        properties: Dict[str, Any] = {"filter_errors": self._filter_errors}
        for name, setting in (
            ("endpoint_url", self._endpoint_url),
            ("model_name", self._model_name),
            ("api_key", self._api_key),
        ):
            if setting:
                properties[name] = setting
        properties.update(self._configured_modalities())
        return {"type": "embed", "task_properties": properties}