# Source code for nv_ingest_client.cli.util.click

# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0


import json
import logging
import os
import random
from enum import Enum
from pprint import pprint
from typing import Union, List, Any, Dict

import click

from nv_ingest_api.internal.enums.common import PipelinePhase
from nv_ingest_api.util.introspection.function_inspect import infer_udf_function_name
from nv_ingest_client.util.processing import check_schema
from nv_ingest_client.primitives.tasks import CaptionTask
from nv_ingest_client.primitives.tasks import DedupTask
from nv_ingest_client.primitives.tasks import EmbedTask
from nv_ingest_client.primitives.tasks import ExtractTask
from nv_ingest_client.primitives.tasks import FilterTask
from nv_ingest_client.primitives.tasks import InfographicExtractionTask
from nv_ingest_client.primitives.tasks import SplitTask
from nv_ingest_client.primitives.tasks import StoreEmbedTask
from nv_ingest_client.primitives.tasks import StoreTask
from nv_ingest_client.primitives.tasks import UDFTask
from nv_ingest_api.internal.schemas.meta.ingest_job_schema import IngestTaskCaptionSchema
from nv_ingest_api.internal.schemas.meta.ingest_job_schema import IngestTaskDedupSchema
from nv_ingest_api.internal.schemas.meta.ingest_job_schema import IngestTaskEmbedSchema
from nv_ingest_api.internal.schemas.meta.ingest_job_schema import IngestTaskExtractSchema
from nv_ingest_api.internal.schemas.meta.ingest_job_schema import IngestTaskFilterSchema
from nv_ingest_api.internal.schemas.meta.ingest_job_schema import IngestTaskInfographicExtraction
from nv_ingest_api.internal.schemas.meta.ingest_job_schema import IngestTaskSplitSchema
from nv_ingest_api.internal.schemas.meta.ingest_job_schema import IngestTaskStoreEmbedSchema
from nv_ingest_api.internal.schemas.meta.ingest_job_schema import IngestTaskStoreSchema
from nv_ingest_api.internal.schemas.meta.ingest_job_schema import IngestTaskUDFSchema
from nv_ingest_client.util.util import generate_matching_files

logger = logging.getLogger(__name__)


class LogLevel(str, Enum):
    """
    Enum for specifying logging levels.

    Attributes
    ----------
    DEBUG : str
        Debug logging level.
    INFO : str
        Informational logging level.
    WARNING : str
        Warning logging level.
    ERROR : str
        Error logging level.
    CRITICAL : str
        Critical logging level.
    """

    DEBUG = "DEBUG"
    INFO = "INFO"
    WARNING = "WARNING"
    ERROR = "ERROR"
    CRITICAL = "CRITICAL"

class ClientType(str, Enum):
    """
    Enum for specifying client types.

    Attributes
    ----------
    REST : str
        Represents a REST client.
    REDIS : str
        Represents a Redis client.
    KAFKA : str
        Represents a Kafka client.
    """

    REST = "REST"
    REDIS = "REDIS"
    KAFKA = "KAFKA"

def debug_print_click_options(ctx: click.Context) -> None:
    """
    Retrieves all options from the Click context and pretty prints them.

    Parameters
    ----------
    ctx : click.Context
        The Click context object from which to retrieve the command options.
    """
    click_options: Dict[str, Any] = {}
    for param in ctx.command.params:
        if isinstance(param, click.Option):
            value = ctx.params[param.name]
            click_options[param.name] = value
    pprint(click_options)
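
# Illustrative usage (a sketch, not part of the original module; the command and
# option names are hypothetical): attach the helper inside a command body to dump
# its resolved options before running.
#
#     @click.command()
#     @click.option("--batch_size", default=64)
#     @click.pass_context
#     def main(ctx: click.Context, batch_size: int) -> None:
#         debug_print_click_options(ctx)  # prints {'batch_size': 64}
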
def click_validate_file_exists(
    ctx: click.Context, param: click.Parameter, value: Union[str, List[str], None]
) -> List[str]:
    """
    Validates that the given file(s) exist.

    Parameters
    ----------
    ctx : click.Context
        The Click context.
    param : click.Parameter
        The parameter associated with the file option.
    value : Union[str, List[str], None]
        A file path or a list of file paths.

    Returns
    -------
    List[str]
        A list of validated file paths.

    Raises
    ------
    click.BadParameter
        If any file path does not exist.
    """
    if not value:
        return []

    if isinstance(value, str):
        value = [value]
    else:
        value = list(value)

    for filepath in value:
        if not os.path.exists(filepath):
            raise click.BadParameter(f"File does not exist: {filepath}")

    return value
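
# Illustrative wiring (a sketch, not part of the original module; the option name
# is hypothetical): the validator is designed to be attached as a Click callback.
#
#     @click.option("--doc", multiple=True, callback=click_validate_file_exists)
#
# A single string or a tuple of paths is normalized to a list, and a missing path
# aborts with click.BadParameter before the command body executes.
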
# Define a union type for all supported task types.
TaskType = Union[
    CaptionTask,
    DedupTask,
    EmbedTask,
    ExtractTask,
    FilterTask,
    InfographicExtractionTask,
    SplitTask,
    StoreEmbedTask,
    StoreTask,
    UDFTask,
]

def parse_task_options(task_id: str, options_str: str) -> Dict[str, Any]:
    """
    Parse the task options string as JSON.

    Parameters
    ----------
    task_id : str
        The identifier of the task for which options are being parsed.
    options_str : str
        The string containing JSON options.

    Returns
    -------
    Dict[str, Any]
        The parsed options as a dictionary.

    Raises
    ------
    ValueError
        If the JSON string is not well formatted. The error message will indicate
        the task, the error details (e.g., expected property format), and show the
        input that was provided.
    """
    try:
        options = json.loads(options_str)

        # Convert string boolean values to actual booleans for extract tasks
        if task_id == "extract":
            boolean_fields = [
                "extract_text",
                "extract_images",
                "extract_tables",
                "extract_charts",
                "extract_infographics",
                "extract_page_as_image",
            ]
            for field in boolean_fields:
                if field in options:
                    value = options[field]
                    if isinstance(value, str):
                        if value.lower() in ("true", "1", "yes", "on"):
                            options[field] = True
                        elif value.lower() in ("false", "0", "no", "off"):
                            options[field] = False
                        else:
                            raise ValueError(
                                f"Invalid boolean value for {field}: '{value}'. "
                                "Use true/false, 1/0, yes/no, or on/off."
                            )

        return options
    except json.JSONDecodeError as e:
        error_message = (
            f"Invalid JSON format for task '{task_id}': {e.msg} at line {e.lineno} "
            f"column {e.colno} (char {e.pos}). Input was: {options_str}"
        )
        raise ValueError(error_message)
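
# Illustrative behavior (derived from the code above, not separately documented):
# string booleans are coerced only for the "extract" task.
#
#     parse_task_options("extract", '{"extract_text": "yes"}')  # -> {"extract_text": True}
#     parse_task_options("split", '{"extract_text": "yes"}')    # -> string value left unchanged
#     parse_task_options("extract", "{bad json}")               # -> raises ValueError
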
def click_validate_task(ctx: click.Context, param: click.Parameter, value: List[str]) -> Dict[str, TaskType]:
    """
    Validates and processes task definitions provided as strings.

    Each task definition should be in the format "<task_id>:<json_options>". If the
    separator ':' is missing, an empty JSON options dictionary is assumed. The function
    uses a schema check (via check_schema) for validation and instantiates the
    corresponding task. Multiple tasks of the same type are allowed; when a task ID
    repeats, its entries are aggregated into a list.

    Parameters
    ----------
    ctx : click.Context
        The Click context.
    param : click.Parameter
        The parameter associated with the task option.
    value : List[str]
        A list of task strings to validate.

    Returns
    -------
    Dict[str, TaskType]
        A dictionary mapping task IDs to their corresponding task objects.

    Raises
    ------
    click.BadParameter
        If any task fails validation (including malformed JSON options).
    """
    validated_tasks: Dict[str, TaskType] = {}
    validation_errors: List[str] = []

    for task_str in value:
        task_split = task_str.split(":", 1)
        if len(task_split) != 2:
            task_id, json_options = task_str, "{}"
        else:
            task_id, json_options = task_split

        try:
            options: Dict[str, Any] = parse_task_options(task_id, json_options)

            if task_id == "split":
                task_options = check_schema(IngestTaskSplitSchema, options, task_id, json_options)
                new_task_id = f"{task_id}"
                new_task = [(new_task_id, SplitTask(**task_options.model_dump()))]
            elif task_id == "extract":
                # Map CLI parameters to API schema structure
                method = options.pop("extract_method", None)
                if method is None:
                    method = "pdfium"  # Default fallback

                # Build params dict for API schema
                params = {k: v for k, v in options.items() if k != "document_type"}

                # Validate with API schema
                api_options = {
                    "document_type": options.get("document_type"),
                    "method": method,
                    "params": params,
                }
                task_options = check_schema(IngestTaskExtractSchema, api_options, task_id, json_options)
                new_task_id = f"{task_id}_{task_options.document_type.value}"

                # Create ExtractTask with original CLI parameters
                extract_task_params = {
                    "document_type": task_options.document_type,
                    "extract_method": task_options.method,
                    **task_options.params,
                }

                # Start with the main extract task
                new_task = [(new_task_id, ExtractTask(**extract_task_params))]

                # Add ChartExtractionTask if extract_charts is True
                if task_options.params.get("extract_charts", False):
                    from nv_ingest_client.primitives.tasks import ChartExtractionTask

                    chart_task_id = "chart_data_extract"
                    chart_params = {"params": {}}  # ChartExtractionTask takes a params dict
                    new_task.append((chart_task_id, ChartExtractionTask(chart_params)))

                # Add TableExtractionTask if extract_tables is True
                if task_options.params.get("extract_tables", False):
                    from nv_ingest_client.primitives.tasks import TableExtractionTask

                    table_task_id = "table_data_extract"
                    new_task.append((table_task_id, TableExtractionTask()))
            elif task_id == "store":
                task_options = check_schema(IngestTaskStoreSchema, options, task_id, json_options)
                new_task_id = f"{task_id}"
                new_task = [(new_task_id, StoreTask(**task_options.model_dump()))]
            elif task_id == "store_embedding":
                task_options = check_schema(IngestTaskStoreEmbedSchema, options, task_id, json_options)
                new_task_id = f"{task_id}"
                new_task = [(new_task_id, StoreEmbedTask(**task_options.model_dump()))]
            elif task_id == "caption":
                task_options = check_schema(IngestTaskCaptionSchema, options, task_id, json_options)
                new_task_id = f"{task_id}"
                # Extract individual parameters from API schema for CaptionTask constructor
                caption_params = {
                    "api_key": task_options.api_key,
                    "endpoint_url": task_options.endpoint_url,
                    "prompt": task_options.prompt,
                    "model_name": task_options.model_name,
                }
                new_task = [(new_task_id, CaptionTask(**caption_params))]
            elif task_id == "dedup":
                task_options = check_schema(IngestTaskDedupSchema, options, task_id, json_options)
                new_task_id = f"{task_id}"
                # Extract individual parameters from API schema for DedupTask constructor
                dedup_params = {
                    "content_type": task_options.content_type,
                    "filter": task_options.params.filter,
                }
                new_task = [(new_task_id, DedupTask(**dedup_params))]
            elif task_id == "filter":
                task_options = check_schema(IngestTaskFilterSchema, options, task_id, json_options)
                new_task_id = f"{task_id}"
                # Extract individual parameters from API schema for FilterTask constructor
                filter_params = {
                    "content_type": task_options.content_type,
                    "min_size": task_options.params.min_size,
                    "max_aspect_ratio": task_options.params.max_aspect_ratio,
                    "min_aspect_ratio": task_options.params.min_aspect_ratio,
                    "filter": task_options.params.filter,
                }
                new_task = [(new_task_id, FilterTask(**filter_params))]
            elif task_id == "embed":
                task_options = check_schema(IngestTaskEmbedSchema, options, task_id, json_options)
                new_task_id = f"{task_id}"
                new_task = [(new_task_id, EmbedTask(**task_options.model_dump()))]
            elif task_id == "infographic":
                task_options = check_schema(IngestTaskInfographicExtraction, options, task_id, json_options)
                new_task_id = f"{task_id}"
                new_task = [(new_task_id, InfographicExtractionTask(**task_options.model_dump()))]
            elif task_id == "udf":
                # Validate mutual exclusivity of target_stage and phase
                has_target_stage = "target_stage" in options and options["target_stage"] is not None
                has_phase = "phase" in options and options["phase"] is not None

                if has_target_stage and has_phase:
                    raise ValueError(
                        "UDF task cannot specify both 'target_stage' and 'phase'. Please specify only one."
                    )
                elif not has_target_stage and not has_phase:
                    raise ValueError("UDF task must specify either 'target_stage' or 'phase'.")

                # Pre-process UDF task options to convert phase names to integers
                if "phase" in options and isinstance(options["phase"], str):
                    # Convert phase string to integer using the same logic as UDFTask
                    phase_str = options["phase"].upper()
                    phase_aliases = {
                        "PRE_PROCESSING": PipelinePhase.PRE_PROCESSING,
                        "PREPROCESSING": PipelinePhase.PRE_PROCESSING,
                        "PRE": PipelinePhase.PRE_PROCESSING,
                        "EXTRACTION": PipelinePhase.EXTRACTION,
                        "EXTRACT": PipelinePhase.EXTRACTION,
                        "POST_PROCESSING": PipelinePhase.POST_PROCESSING,
                        "POSTPROCESSING": PipelinePhase.POST_PROCESSING,
                        "POST": PipelinePhase.POST_PROCESSING,
                        "MUTATION": PipelinePhase.MUTATION,
                        "MUTATE": PipelinePhase.MUTATION,
                        "TRANSFORM": PipelinePhase.TRANSFORM,
                        "RESPONSE": PipelinePhase.RESPONSE,
                        "RESP": PipelinePhase.RESPONSE,
                    }
                    if phase_str in phase_aliases:
                        options["phase"] = phase_aliases[phase_str].value
                    else:
                        raise ValueError(f"Invalid phase name: {options['phase']}")

                # Try to infer udf_function_name if not provided
                if "udf_function_name" not in options or not options["udf_function_name"]:
                    udf_function = options.get("udf_function", "")
                    if udf_function:
                        inferred_name = infer_udf_function_name(udf_function)
                        if inferred_name:
                            options["udf_function_name"] = inferred_name
                            logger.info(f"Inferred UDF function name: {inferred_name}")
                        else:
                            raise ValueError(
                                f"Could not infer UDF function name from '{udf_function}'. "
                                "Please specify 'udf_function_name' explicitly."
                            )

                task_options = check_schema(IngestTaskUDFSchema, options, task_id, json_options)
                new_task_id = f"{task_id}"
                new_task = [(new_task_id, UDFTask(**task_options.model_dump()))]
            else:
                raise ValueError(f"Unsupported task type: {task_id}")

            # Multiple tasks of the same type are allowed; duplicates are aggregated into a list
            if new_task_id in validated_tasks:
                logger.debug(f"Multiple tasks detected for {new_task_id}, storing as list")

            logger.debug("Adding task: %s", new_task_id)
            for task_tuple in new_task:
                if task_tuple[0] in validated_tasks:
                    # Convert single task to list if needed, then append
                    existing_task = validated_tasks[task_tuple[0]]
                    if not isinstance(existing_task, list):
                        validated_tasks[task_tuple[0]] = [existing_task]
                    validated_tasks[task_tuple[0]].append(task_tuple[1])
                else:
                    validated_tasks[task_tuple[0]] = task_tuple[1]
        except ValueError as e:
            validation_errors.append(str(e))

    if validation_errors:
        error_message = "\n".join(validation_errors)
        raise click.BadParameter(error_message)

    return validated_tasks
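
# Illustrative round trip (a sketch, not part of the original module; the exact
# task IDs depend on the document_type enum value, here assumed to be "pdf"):
#
#     click_validate_task(ctx, param, ['extract:{"document_type": "pdf", "extract_tables": "true"}'])
#     # -> {"extract_pdf": ExtractTask(...), "table_data_extract": TableExtractionTask()}
#
# because extract_tables=True appends a follow-on table-extraction task.
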
def click_validate_batch_size(ctx: click.Context, param: click.Parameter, value: int) -> int:
    """
    Validates that the batch size is at least 1.

    Parameters
    ----------
    ctx : click.Context
        The Click context.
    param : click.Parameter
        The parameter associated with the batch size option.
    value : int
        The batch size value provided.

    Returns
    -------
    int
        The validated batch size.

    Raises
    ------
    click.BadParameter
        If the batch size is less than 1.
    """
    if value < 1:
        raise click.BadParameter("Batch size must be >= 1.")
    return value
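
# Illustrative wiring (not part of the original module; the option name is hypothetical):
#
#     @click.option("--batch_size", default=256, callback=click_validate_batch_size)
#
# so that --batch_size 0 fails fast with "Batch size must be >= 1."
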
def pre_process_dataset(dataset_json: str, shuffle_dataset: bool) -> List[str]:
    """
    Loads a dataset from a JSON file and optionally shuffles the list of files.

    Parameters
    ----------
    dataset_json : str
        The path to the dataset JSON file.
    shuffle_dataset : bool
        Whether to shuffle the dataset before processing.

    Returns
    -------
    List[str]
        The list of file paths from the dataset. If 'shuffle_dataset' is True,
        the list will be shuffled.

    Raises
    ------
    click.BadParameter
        If the dataset file is not found or if its contents are not valid JSON.
    """
    try:
        with open(dataset_json, "r") as f:
            file_source = json.load(f)
    except FileNotFoundError:
        raise click.BadParameter(f"Dataset JSON file not found: {dataset_json}")
    except json.JSONDecodeError:
        raise click.BadParameter(f"Invalid JSON format in file: {dataset_json}")

    file_source = file_source.get("sampled_files", [])
    if shuffle_dataset:
        random.shuffle(file_source)

    return file_source
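
# Illustrative dataset layout (inferred from the code above, not a documented
# contract): the JSON file is expected to carry a "sampled_files" key.
#
#     {"sampled_files": ["data/a.pdf", "data/b.pdf"]}
#
#     pre_process_dataset("dataset.json", shuffle_dataset=True)
#     # -> the same paths, in random order; other top-level keys are ignored
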
def click_match_and_validate_files(ctx: click.Context, param: click.Parameter, value: List[str]) -> List[str]:
    """
    Matches and validates files based on the provided file source patterns.

    Parameters
    ----------
    ctx : click.Context
        The Click context.
    param : click.Parameter
        The parameter associated with the file matching option.
    value : List[str]
        A list of file source patterns to match against.

    Returns
    -------
    List[str]
        A list of matching file paths. If no files match, an empty list is returned.
    """
    if not value:
        return []

    matching_files = list(generate_matching_files(value))
    if not matching_files:
        logger.warning("No files found matching the specified patterns.")
        return []

    return matching_files
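
# Illustrative usage (a sketch, not part of the original module; the glob pattern
# is hypothetical):
#
#     click_match_and_validate_files(ctx, param, ["data/*.pdf"])
#     # -> matching paths from generate_matching_files, or [] (with a warning) if none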