# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import copy
import re
from datetime import datetime
from collections import defaultdict
from typing import Any, Dict, Generator, List, Optional, Union
import logging
import pandas as pd
from nv_ingest_api.internal.primitives.control_message_task import ControlMessageTask
logger = logging.getLogger(__name__)
[docs]
def remove_task_by_type(ctrl_msg, task: str):
    """
    Remove the first task of a given type from a control message.

    Tasks are scanned in the order yielded by ``ctrl_msg.get_tasks()``. The first task
    whose ``type`` equals ``task`` is selected, ``ctrl_msg.remove_task`` is called with
    that task's ``id``, and the ``properties`` of the task returned by that call are
    handed back to the caller.

    Parameters
    ----------
    ctrl_msg : IngestControlMessage
        The control message from which to remove the task.
    task : str
        The task type to remove.

    Returns
    -------
    dict
        The properties of the removed task.

    Raises
    ------
    ValueError
        If no task with the given type is found.

    Notes
    -----
    Removal is keyed by task id, not by task object. ``IngestControlMessage.remove_task``
    pops the first task stored under an id, so when several tasks share one id the task
    that is removed may not be the one that matched ``task``.
    """
    # Stop at the first match; None signals that no task of this type exists.
    match = next((t for t in ctrl_msg.get_tasks() if t.type == task), None)
    if match is None:
        err_msg = f"process_control_message: Task '{task}' not found in control message."
        logger.error(err_msg)
        raise ValueError(err_msg)

    return ctrl_msg.remove_task(match.id).properties
[docs]
def remove_all_tasks_by_type(ctrl_msg, task: str):
    """
    Remove every task of a given type from a control message.

    All tasks yielded by ``ctrl_msg.get_tasks()`` whose ``type`` equals ``task`` are
    collected first. ``ctrl_msg.remove_task`` is then called once per collected task,
    using that task's ``id``, and the ``properties`` of each returned task are gathered
    in the same order.

    Parameters
    ----------
    ctrl_msg : IngestControlMessage
        The control message from which to remove the tasks.
    task : str
        The task type to remove.

    Returns
    -------
    list[dict]
        A list of dictionaries of properties for all removed tasks.

    Raises
    ------
    ValueError
        If no tasks with the given type are found.

    Notes
    -----
    Removal is keyed by task id, not by task object. ``IngestControlMessage.remove_task``
    pops the first task stored under an id, so if a matching task shares its id with a
    task of another type, the task actually removed may not be of type ``task``.
    """
    # Materialize the matches before removing anything: get_tasks() is a generator over
    # the message's live task storage, which remove_task() mutates.
    matching_tasks = [t for t in ctrl_msg.get_tasks() if t.type == task]
    if not matching_tasks:
        err_msg = f"process_control_message: No tasks of type '{task}' found in control message."
        logger.error(err_msg)
        raise ValueError(err_msg)

    return [ctrl_msg.remove_task(matched.id).properties for matched in matching_tasks]
[docs]
class IngestControlMessage:
    """
    A control message class for ingesting tasks and managing associated metadata,
    timestamps, configuration, and payload.

    Tasks are stored grouped by their ``id``; several tasks may share one id and are
    kept in insertion order within that id.
    """

    def __init__(self):
        """
        Initialize a new IngestControlMessage instance.
        """
        # Maps task id -> tasks registered under that id, in insertion order.
        self._tasks: Dict[str, List["ControlMessageTask"]] = defaultdict(list)
        self._metadata: Dict[str, Any] = {}
        self._timestamps: Dict[str, datetime] = {}
        self._payload: Optional[pd.DataFrame] = None
        self._config: Dict[str, Any] = {}

    def add_task(self, task: "ControlMessageTask"):
        """
        Add a task to the control message. Multiple tasks with the same ID are supported.
        """
        self._tasks[task.id].append(task)

    def get_tasks(self) -> Generator["ControlMessageTask", None, None]:
        """
        Return all tasks as a generator.

        The generator walks the live task storage, so tasks must not be added or removed
        while it is being consumed.
        """
        for task_list in self._tasks.values():
            yield from task_list

    def has_task(self, task_id: str) -> bool:
        """
        Check if any tasks with the given ID exist.
        """
        # .get() avoids creating an empty entry in the defaultdict for unknown ids.
        return bool(self._tasks.get(task_id))

    def remove_task(self, task_id: str) -> "ControlMessageTask":
        """
        Remove and return the first task stored under the given ID.

        Raises
        ------
        RuntimeError
            If no task with the given ID exists.
        """
        pending = self._tasks.get(task_id)
        if not pending:
            raise RuntimeError(f"Attempted to remove non-existent task with id: {task_id}")

        task = pending.pop(0)
        # Drop the id entirely once its last task is gone.
        if not pending:
            del self._tasks[task_id]
        return task

    def config(self, config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Get or update the control message configuration.

        If 'config' is provided, it must be a dictionary; its entries are merged into the
        current configuration. In both cases a shallow copy of the resulting
        configuration is returned.

        Raises
        ------
        ValueError
            If the provided configuration is not a dictionary.
        """
        if config is not None:
            if not isinstance(config, dict):
                raise ValueError("Configuration must be provided as a dictionary.")
            self._config.update(config)
        return self._config.copy()

    def copy(self) -> "IngestControlMessage":
        """
        Create a deep copy of this control message.
        """
        return copy.deepcopy(self)

    def filter_timestamp(self, regex_filter: str) -> Dict[str, datetime]:
        """
        Retrieve timestamps whose keys match the regex filter.

        Matching uses ``re.search``, so the pattern may match anywhere in the key unless
        it is anchored.
        """
        pattern = re.compile(regex_filter)
        timestamps_snapshot = self._timestamps.copy()
        return {key: ts for key, ts in timestamps_snapshot.items() if pattern.search(key)}

    def get_timestamp(self, key: str, fail_if_nonexist: bool = False) -> Optional[datetime]:
        """
        Retrieve a timestamp for a given key.

        Returns None when the key is missing and 'fail_if_nonexist' is False.

        Raises
        ------
        KeyError
            If the key is not found and 'fail_if_nonexist' is True.
        """
        if key in self._timestamps:
            return self._timestamps[key]
        if fail_if_nonexist:
            raise KeyError(f"Timestamp for key '{key}' does not exist.")
        return None

    def get_timestamps(self) -> Dict[str, datetime]:
        """
        Retrieve all timestamps as a shallow copy of the internal mapping.
        """
        return self._timestamps.copy()

    def set_timestamp(self, key: str, timestamp: Any) -> None:
        """
        Set a timestamp for a given key. Accepts either a datetime object or an ISO format string.

        Raises
        ------
        ValueError
            If the provided timestamp is neither a datetime object nor a valid ISO format string.
        """
        if isinstance(timestamp, datetime):
            self._timestamps[key] = timestamp
            return

        if not isinstance(timestamp, str):
            raise ValueError("timestamp must be a datetime object or ISO format string")

        try:
            parsed = datetime.fromisoformat(timestamp)
        except ValueError as e:
            raise ValueError(f"Invalid timestamp format: {timestamp}") from e
        self._timestamps[key] = parsed

    def payload(self, payload: Optional[pd.DataFrame] = None) -> Optional[pd.DataFrame]:
        """
        Get or set the payload DataFrame.

        Called without an argument (or with None) this returns the current payload, which
        is None until one has been set. Because None selects the getter, this method
        cannot be used to reset the payload to None.

        Raises
        ------
        ValueError
            If the provided payload is not a pandas DataFrame.
        """
        if payload is None:
            return self._payload
        if not isinstance(payload, pd.DataFrame):
            raise ValueError("Payload must be a pandas DataFrame")
        self._payload = payload
        return self._payload