Source code for sdp.processors.toloka.download_responses

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

from sdp.logging import logger
from sdp.processors.base_processor import BaseParallelProcessor, DataEntry

try:
    import toloka.client
    TOLOKA_AVAILABLE = True
except ImportError:
    TOLOKA_AVAILABLE = False
    toloka = None
    



[docs] class GetTolokaResults(BaseParallelProcessor): """Fetches and stores results from a specified Toloka pool based on user-configured conditions. This class connects to Toloka, retrieves task results from a specified pool, filters them by assignment status, and stores the results in the given output directory. Args: input_data_file (str): Path to the input data file containing API configurations. input_pool_file (str): Path to the input pool file containing pool configurations. output_dir (str): Directory where the results will be stored. status (str): Status filter for assignments to retrieve (default: 'ACCEPTED'). config_file (str): Path to a configuration file. Default: None. API_KEY (str): The API key for authenticating with Toloka's API. Default: None. platform (str): The Toloka environment to use ('PRODUCTION' or 'SANDBOX'). Default: None. pool_id (str): The ID of the Toloka pool to retrieve results from. Default: None. Returns: A set of task results from Toloka, stored in the specified output directory. """ def __init__( self, input_data_file: str, input_pool_file: str, output_dir: str, status: str = "ACCEPTED", config_file: str = None, API_KEY: str = None, platform: str = None, pool_id: str = None, **kwargs ): """ Constructs the necessary attributes for the GetTolokaResults class. Parameters: ---------- input_data_file : str The path to the input data file containing API configurations. input_pool_file : str The path to the input pool file containing pool configurations. output_dir : str The directory where the output results will be stored. status : str, optional The status filter for assignments to retrieve. Defaults to 'ACCEPTED'. config_file : str, optional The path to the configuration file. Defaults to None. API_KEY : str, optional The API key used to authenticate with Toloka's API. If not provided, it is retrieved from the environment. platform : str, optional Specifies the Toloka environment (e.g., 'PRODUCTION', 'SANDBOX'). If not provided, it is retrieved from the environment. pool_id : str, optional The ID of the pool from which results will be retrieved. Defaults to None. """ super().__init__(**kwargs) self.input_data_file = input_data_file self.input_pool_file = input_pool_file self.output_dir = output_dir self.status = status self.config_file = config_file self.API_KEY = API_KEY or os.getenv('TOLOKA_API_KEY') self.platform = platform or os.getenv('TOLOKA_PLATFORM') self.pool_id = pool_id if self.config_file: self.load_config() self.toloka_available = TOLOKA_AVAILABLE def load_config(self): """ Loads configuration data from the specified config file. This method attempts to read configuration details such as API key, platform, and pool ID from a JSON file. If the file is missing or improperly formatted, an appropriate error is logged. """ try: with open(self.config_file, 'r') as file: config = json.load(file) self.API_KEY = config.get('API_KEY', self.API_KEY) self.platform = config.get('platform', self.platform) self.pool_id = config.get('pool_id', self.pool_id) except FileNotFoundError: logger.error("Configuration file not found.") except json.JSONDecodeError: logger.error("Error decoding JSON from the configuration file.") def prepare(self): """ Prepares the class by loading API configuration, pool configuration, and initializing Toloka client. This method loads necessary configurations and initializes the Toloka client to interact with Toloka's API. """ if self.toloka_available != True: logger.warning("Toloka is currently not supported. DownloadResponses processor functionality will be limited.") if not self.API_KEY or not self.platform or not self.pool_id: try: with open(self.input_data_file, 'r') as file: data = json.loads(file.readline()) self.API_KEY = data.get("API_KEY", self.API_KEY) self.platform = data.get("platform", self.platform) except FileNotFoundError: logger.error("Data file not found.") except json.JSONDecodeError: logger.error("Error decoding JSON from the data file.") try: with open(self.input_pool_file, 'r') as file: data = json.loads(file.readline()) self.pool_id = data.get("pool_id", self.pool_id) except FileNotFoundError: logger.error("Pool file not found.") except json.JSONDecodeError: logger.error("Error decoding JSON from the pool file.") self.toloka_client = toloka.client.TolokaClient(self.API_KEY, self.platform) return super().prepare() def read_manifest(self): """ Retrieves and yields task information from Toloka based on the specified pool and assignment status. This method retrieves assignments from Toloka for a given pool and yields task information for each assignment that matches the specified status. Yields: ------ dict A dictionary containing task information such as task ID, text, attachment ID, status, etc. """ for assignment in self.toloka_client.get_assignments(pool_id=self.pool_id): if str(assignment.status) == 'Status.' + self.status: # ACCEPTED, ACTIVE, EXPIRED, REJECTED, SKIPPED, SUBMITTED if ( str(assignment.status) == 'Status.ACCEPTED' or str(assignment.status) == 'Status.REJECTED' or str(assignment.status) == 'Status.SUBMITTED' ): for task, solution in zip(assignment.tasks, assignment.solutions): suit_id = assignment.task_suite_id assignment_id = assignment.id user_id = assignment.user_id task_id = task.id text = task.input_values['text'] attachment_id = solution.output_values.get('audio_file', None) status = assignment.status task_info = { 'task_id': task_id, 'text': text, 'attachment_id': attachment_id, 'status': str(status), 'suit_id': suit_id, 'assignment_id': assignment_id, 'user_id': user_id, } yield task_info else: for task in assignment.tasks: suit_id = assignment.task_suite_id assignment_id = assignment.id user_id = assignment.user_id task_id = task.id text = task.input_values['text'] attachment_id = "" status = assignment.status task_info = { 'task_id': task_id, 'text': text, 'attachment_id': attachment_id, 'status': str(status), 'suit_id': suit_id, 'assignment_id': assignment_id, 'user_id': user_id, } yield task_info def process_dataset_entry(self, data_entry): """ Downloads and processes individual task results. This method takes a data entry, retrieves the corresponding attachment, and stores it in the specified output directory. The task information is then returned. Parameters: ---------- data_entry : dict A dictionary containing the data entry information. Returns: ------- list A list containing a DataEntry object with the task information. """ user_id = data_entry["user_id"] task_id = data_entry["task_id"] text = data_entry["text"] attachment_id = data_entry["attachment_id"] status = data_entry["status"] suit_id = data_entry["suit_id"] assignment_id = data_entry["assignment_id"] output_path = os.path.join(self.output_dir, attachment_id + '.wav') os.makedirs(os.path.dirname(output_path), exist_ok=True) if attachment_id != "": with open(output_path, 'wb') as attachment_file: self.toloka_client.download_attachment(attachment_id, out=attachment_file) task_info = { 'task_id': task_id, 'text': text, 'attachment_id': attachment_id, 'status': status, 'audio_filepath': output_path, 'suit_id': suit_id, 'assignment_id': assignment_id, 'user_id': user_id, } return [DataEntry(data=task_info)]