Source code for sdp.processors.toloka.create_task_set
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from typing import List, Optional
from sdp.logging import logger
from sdp.processors.base_processor import BaseParallelProcessor, DataEntry
# Toloka is an optional dependency: import it if it is installed, otherwise
# record that it is missing so the processor can report the problem at run
# time instead of failing when this module is imported.
try:
    import toloka.client
    import toloka.client.project.template_builder

    TOLOKA_AVAILABLE = True
except ImportError:
    TOLOKA_AVAILABLE = False
    # Keep the name bound so later references to ``toloka`` do not raise NameError.
    toloka = None
[docs]
class CreateTolokaTaskSet(BaseParallelProcessor):
    """Creates a set of tasks in a Toloka pool based on user-provided configurations and input data.

    This class reads data from a manifest file, loads the target pool configuration,
    and uses Toloka's API to create and upload tasks into the specified pool. The
    manifest entries that were uploaded are then written to the output manifest.

    The Toloka credentials are read from the ``TOLOKA_API_KEY`` and ``TOLOKA_PLATFORM``
    environment variables when the processor is constructed.

    Args:
        input_data_file (str): Path to the input data file containing API configurations.
            The value is stored on the instance; no method of this class reads it.
        input_pool_file (str): Path to the input pool file containing pool configurations.
            It must be a JSON object with a non-empty ``pool_id`` key.
        limit (float): Percentage of tasks to load from the manifest file. Default: 100.
        **kwargs: Forwarded unchanged to ``BaseParallelProcessor.__init__``.

    Returns:
        A set of tasks created and uploaded to the specified Toloka pool.

    Raises:
        ValueError: If ``TOLOKA_API_KEY`` or ``TOLOKA_PLATFORM`` is not set in the environment.
    """

    def __init__(
        self,
        input_data_file: str,
        input_pool_file: str,
        limit: float = 100,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_data_file = input_data_file
        self.input_pool_file = input_pool_file
        self.limit = limit
        # Filled in by load_pool_config().
        self.pool_id = None
        self.toloka_available = TOLOKA_AVAILABLE

        # Get API key and platform from environment variables
        self.API_KEY = os.getenv('TOLOKA_API_KEY')
        if not self.API_KEY:
            raise ValueError("TOLOKA_API_KEY environment variable is not set")
        self.platform = os.getenv('TOLOKA_PLATFORM')
        if not self.platform:
            raise ValueError("TOLOKA_PLATFORM environment variable is not set")

        # Created lazily in prepare() so that construction does not need the Toloka package.
        self.toloka_client = None

    def prepare(self):
        """Loads the pool configuration and initializes the Toloka client.

        Raises:
            ImportError: If the Toloka client library could not be imported.
            ValueError: If the pool configuration cannot be loaded (see ``load_pool_config``).
        """
        if not self.toloka_available:
            logger.warning("Toloka is currently not supported. CreateTaskSet processor functionality will be limited.")
            # Without the library ``toloka`` is None, so building the client below would
            # fail with an opaque AttributeError; fail with an explicit error instead.
            raise ImportError("The Toloka client library could not be imported. Cannot create a Toloka client.")
        self.load_pool_config()
        self.toloka_client = toloka.client.TolokaClient(self.API_KEY, self.platform)

    def load_pool_config(self):
        """Loads pool configuration data from the input pool file.

        Reads the JSON pool configuration from ``self.input_pool_file`` and stores its
        ``pool_id`` value in ``self.pool_id``.

        Raises:
            ValueError: If the file does not exist, is not valid JSON, or does not
                contain a pool ID.
        """
        # Keep the try body limited to the operations that raise the handled exceptions.
        try:
            with open(self.input_pool_file, 'r', encoding='utf-8') as file:
                pool_config = json.load(file)
        except FileNotFoundError as err:
            raise ValueError(f"Pool configuration file {self.input_pool_file} not found.") from err
        except json.JSONDecodeError as err:
            raise ValueError(
                f"Error decoding JSON from the pool configuration file {self.input_pool_file}."
            ) from err

        # Valid JSON that is not an object (e.g. a list) has no ``get``; treat it as "no pool ID".
        self.pool_id = pool_config.get('pool_id') if isinstance(pool_config, dict) else None
        if not self.pool_id:
            raise ValueError("No pool ID found in the pool configuration file.")

    def read_manifest(self) -> List[dict]:
        """Reads the leading portion of the input manifest selected by ``self.limit``.

        The manifest is read as UTF-8 JSON lines; blank lines are ignored. ``self.limit``
        is a percentage of the non-blank lines. At least one entry is returned when the
        manifest is not empty, and never more entries than the manifest contains.

        Returns:
            List[dict]: The parsed manifest entries, in file order. Empty if the manifest
            has no non-blank lines.
        """
        logger.info("Reading manifest...")
        with open(self.input_manifest_file, "rt", encoding="utf-8") as fin:
            total_lines = sum(1 for line in fin if line.strip())
            if total_lines == 0:
                return []

            # Multiply before dividing: e.g. 100 * (29 / 100) evaluates to 28.99...,
            # which would truncate to 28 instead of 29.
            requested = int(total_lines * self.limit / 100)
            # Clamp to [1, total_lines] so a limit above 100 cannot read past the end of the file.
            lines_to_read = min(total_lines, max(1, requested))

            fin.seek(0)
            entries = []
            for line in fin:
                if not line.strip():
                    continue
                entries.append(json.loads(line))
                if len(entries) == lines_to_read:
                    break
        return entries

    def process(self):
        """Creates Toloka tasks based on manifest data and adds them to the specified pool.

        Reads the input manifest, creates one task per entry from the entry's ``"text"``
        field, and submits the tasks to the configured pool. The entries are then written
        to the output manifest. If the manifest has no entries, no API call is made and
        an empty output manifest is written.

        Raises:
            ImportError: If the Toloka client library could not be imported.
            ValueError: If no pool ID is available or if the Toloka API call fails.
        """
        logger.info("Processing tasks...")
        self.prepare()

        if not self.pool_id:
            raise ValueError("No pool ID available. Cannot create tasks.")

        entries = self.read_manifest()

        if entries:
            tasks = [
                toloka.client.Task(input_values={'text': data_entry["text"]}, pool_id=self.pool_id)
                for data_entry in entries
            ]
            try:
                self.toloka_client.create_tasks(tasks, allow_defaults=True)
            except Exception as e:
                logger.error(f"Error creating tasks: {e}")
                # Chain the original exception so the underlying API error stays in the traceback.
                raise ValueError(f"Failed to create tasks: {e}") from e
            else:
                logger.info(f"Created {len(tasks)} tasks.")
        else:
            logger.warning("Input manifest is empty. No tasks were created.")

        # Write the manifest data to the output file
        with open(self.output_manifest_file, "wt", encoding='utf-8') as fout:
            for entry in entries:
                fout.write(json.dumps(entry, ensure_ascii=False) + "\n")