# Source code for sdp.processors.toloka.accept_if

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from collections import defaultdict
from typing import Optional

from sdp.logging import logger
from sdp.processors.base_processor import BaseParallelProcessor, DataEntry

try:
    import toloka.client
    import toloka.client.project.template_builder
    TOLOKA_AVAILABLE = True
except ImportError:
    TOLOKA_AVAILABLE = False
    toloka = None
    

from tqdm import tqdm


class AcceptIfWERLess(BaseParallelProcessor):
    """This processor accepts Toloka assignments if the Word Error Rate (WER) is below a threshold.

    It reads the input manifest, counts for every assignment the entries whose
    ``wer`` is below ``threshold`` and whose ``status`` is ``Status.SUBMITTED``,
    and accepts the assignments that collected enough such entries.

    Args:
        input_data_file (str): Path to the input data file containing API configurations.
        input_pool_file (str): Path to the input pool file containing pool configurations.
        threshold (float): The WER threshold below which assignments are accepted. Default: 75.
        config_file (str): Path to the configuration file. Default: None.
        API_KEY (str): The API key for authenticating with Toloka's API. Falls back to the
            ``TOLOKA_API_KEY`` environment variable. Default: None.
        platform (str): The Toloka platform to use. Falls back to the ``TOLOKA_PLATFORM``
            environment variable. Default: None.
        pool_id (str): The ID of the Toloka pool. Default: None.

    Returns:
        A manifest with accepted assignments from Toloka based on the WER threshold.

    Example:
        .. code-block:: yaml

            - _target_: sdp.processors.toloka.accept_if.AcceptIfWERLess
              input_manifest_file: ${workspace_dir}/result_manifest_pred_clean.json
              output_manifest_file: ${workspace_dir}/result_manifest_pred_review.json
              input_data_file: ${workspace_dir}/data_file.json
              input_pool_file: ${workspace_dir}/taskpool.json
              threshold: 50
    """

    # Minimum number of qualifying manifest entries an assignment needs to be accepted.
    # NOTE(review): the original inline comment said "should be >= 3 and <= 5";
    # no upper bound is enforced -- confirm the intended rule.
    _MIN_QUALIFYING_ENTRIES = 3

    def __init__(
        self,
        input_data_file: str,
        input_pool_file: str,
        threshold: float = 75,
        config_file: Optional[str] = None,
        API_KEY: Optional[str] = None,
        platform: Optional[str] = None,
        pool_id: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_data_file = input_data_file
        self.input_pool_file = input_pool_file
        self.threshold = threshold
        self.config_file = config_file
        # Explicit arguments win over environment variables.
        self.API_KEY = API_KEY or os.getenv('TOLOKA_API_KEY')
        self.platform = platform or os.getenv('TOLOKA_PLATFORM')
        self.pool_id = pool_id
        if self.config_file:
            # Keys present in the config file override the values set above.
            self.load_config()
        self.toloka_available = TOLOKA_AVAILABLE

    def load_config(self):
        """
        Loads configuration data from the specified config file.

        Reads ``API_KEY``, ``platform`` and ``pool_id`` from the JSON document in
        ``self.config_file``; keys that are absent keep their current value.
        If the file is missing or is not valid JSON, an error is logged and the
        current values are left untouched.
        """
        try:
            with open(self.config_file, 'r', encoding='utf-8') as file:
                config = json.load(file)
        except FileNotFoundError:
            logger.error("Configuration file not found.")
            return
        except json.JSONDecodeError:
            logger.error("Error decoding JSON from the configuration file.")
            return

        self.API_KEY = config.get('API_KEY', self.API_KEY)
        self.platform = config.get('platform', self.platform)
        self.pool_id = config.get('pool_id', self.pool_id)

    @staticmethod
    def _read_first_json_line(path: str, label: str) -> dict:
        """Return the JSON object on the first line of ``path``; log and return ``{}`` on failure."""
        try:
            with open(path, 'r', encoding='utf-8') as file:
                return json.loads(file.readline())
        except FileNotFoundError:
            logger.error(f"{label.capitalize()} file not found.")
        except json.JSONDecodeError:
            logger.error(f"Error decoding JSON from the {label} file.")
        return {}

    def prepare(self):
        """
        Prepares the class by loading API configuration, pool configuration, and initializing Toloka client.

        When any of ``API_KEY``, ``platform`` or ``pool_id`` is still unset, the
        first line of ``input_data_file`` (for ``API_KEY`` / ``platform``) and of
        ``input_pool_file`` (for ``pool_id``) is read as JSON and used to fill
        them in. The Toloka client is then created and stored in
        ``self.toloka_client``.

        Raises:
            ImportError: If the ``toloka`` package could not be imported.
        """
        if not self.toloka_available:
            # Without the package, `toloka` is None and the client construction
            # below would fail with an opaque AttributeError; fail clearly instead.
            raise ImportError(
                "The `toloka` package (toloka-kit) is required to use AcceptIfWERLess but could not be imported."
            )

        if not self.API_KEY or not self.platform or not self.pool_id:
            api_settings = self._read_first_json_line(self.input_data_file, "data")
            self.API_KEY = api_settings.get("API_KEY", self.API_KEY)
            self.platform = api_settings.get("platform", self.platform)

            pool_settings = self._read_first_json_line(self.input_pool_file, "pool")
            self.pool_id = pool_settings.get("pool_id", self.pool_id)

        self.toloka_client = toloka.client.TolokaClient(self.API_KEY, self.platform)

    def process(self):
        """
        Accepts Toloka assignments if their Word Error Rate (WER) is below the specified threshold.

        Reads the input manifest line by line, counts per ``assignment_id`` the
        entries with ``wer`` below ``self.threshold`` and status
        ``Status.SUBMITTED``, then accepts every assignment that reached the
        minimum count and logs how many were accepted.
        """
        qualifying_counts = defaultdict(int)
        self.prepare()

        with open(self.input_manifest_file, 'r', encoding='utf-8') as file:
            for line in file:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline) in the manifest.
                    continue
                data_entry = json.loads(line)
                if data_entry["wer"] < self.threshold and str(data_entry["status"]) == "Status.SUBMITTED":
                    qualifying_counts[data_entry["assignment_id"]] += 1

        accepted = 0
        for assignment_id, count in tqdm(qualifying_counts.items()):
            if count >= self._MIN_QUALIFYING_ENTRIES:
                self.toloka_client.accept_assignment(assignment_id=assignment_id, public_comment='Well done!')
                accepted += 1

        logger.info(f"Number of accepted task suits: {accepted} of {len(qualifying_counts)}")