Source code for sdp.processors.modify_manifest.data_to_data

# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import os
import re
from typing import Dict, List, Optional
import tempfile
import shutil
import requests
import wget
import tarfile
from glob import glob
import yaml

import soundfile
import torchaudio
from docx import Document
from tqdm import tqdm
import json
import librosa
import numpy as np
from pathlib import Path

from sdp.logging import logger
from sdp.processors.base_processor import (
    BaseParallelProcessor,
    BaseProcessor,
    DataEntry,
)
from sdp.utils.common import ffmpeg_convert
from sdp.utils.edit_spaces import add_start_end_spaces, remove_extra_spaces
from sdp.utils.get_diff import get_diff_with_subs_grouped
from sdp.utils.metrics_computation import get_wer
from sdp.utils.apply_operators import evaluate_expression


class GetAudioDuration(BaseParallelProcessor):
    """Processor that computes the duration of the file in ``audio_filepath_key``
    (using soundfile) and saves the duration in ``duration_key``. If there is an
    error computing the duration, the value at ``duration_key`` will be set to -1.0.

    Args:
        audio_filepath_key (str): Key to get the path to the audio file.
        duration_key (str): Key under which the audio duration will be stored.

    Returns:
        All the same fields as in the input manifest plus ``duration_key``.
    """

    def __init__(
        self,
        audio_filepath_key: str,
        duration_key: str,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.audio_filepath_key = audio_filepath_key
        self.duration_key = duration_key

    def process_dataset_entry(self, data_entry):
        audio_filepath = data_entry[self.audio_filepath_key]
        try:
            data, samplerate = soundfile.read(audio_filepath)
            data_entry[self.duration_key] = data.shape[0] / samplerate
        except Exception as e:
            logger.warning(str(e) + " file: " + audio_filepath)
            data_entry[self.duration_key] = -1.0
        return [DataEntry(data=data_entry)]

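# Illustrative usage sketch (not part of the original module): a pipeline config entry
# for GetAudioDuration, following the YAML style used elsewhere in this file. The
# manifest paths below are placeholders.
#
#   - _target_: sdp.processors.modify_manifest.data_to_data.GetAudioDuration
#     input_manifest_file: ${workspace_dir}/manifest.json
#     output_manifest_file: ${workspace_dir}/manifest_with_durations.json
#     audio_filepath_key: audio_filepath
#     duration_key: duration
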
class ReadTxtLines(BaseParallelProcessor):
    """The text file specified in ``input_file_key`` will be read, and each non-empty
    line in it will be added as a separate entry in the output manifest, saved in the
    field ``text_key``.

    Args:
        input_file_key (str): The key in the manifest containing the path to the input txt file.
        text_key (str): The key to store the read text lines in the manifest.
        **kwargs: Additional keyword arguments to be passed to the base class
            `BaseParallelProcessor`.
    """

    def __init__(
        self,
        input_file_key: str,
        text_key: str,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_file_key = input_file_key
        self.text_key = text_key

    def process_dataset_entry(self, data_entry):
        fname = data_entry[self.input_file_key]
        data_list = []
        with open(fname, "r") as f:
            for line in f:
                line = line.strip()
                if line:
                    data = data_entry.copy()
                    data[self.text_key] = line
                    data_list.append(DataEntry(data=data))
        return data_list

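# Illustrative usage sketch (not part of the original module): ReadTxtLines expands each
# manifest entry into one entry per non-empty line of the referenced text file. The field
# names and paths below are placeholders.
#
#   - _target_: sdp.processors.modify_manifest.data_to_data.ReadTxtLines
#     input_manifest_file: ${workspace_dir}/files_manifest.json
#     output_manifest_file: ${workspace_dir}/lines_manifest.json
#     input_file_key: source_filepath
#     text_key: text
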
class CountNumWords(BaseParallelProcessor):
    """A processor that counts the number of words in the `text_key` field of each dataset
    entry and stores the result in `num_words_key`.

    Before counting, the text is optionally cleaned using a custom `alphabet`:

    - If `alphabet` is provided, all characters not in the alphabet are removed.
    - Consecutive whitespace characters are collapsed into a single space.
    - The number of resulting space-separated tokens is counted as the number of words.

    Args:
        text_key (str): The key in the input data entry containing the text to be analyzed.
        num_words_key (str): The key under which the word count will be stored in the output
            entry. Defaults to "num_words".
        alphabet (str, optional): A string of allowed characters (e.g., lowercase letters).
            All characters not in this set will be removed before counting.
            If not provided, no filtering is applied.
        **kwargs: Additional arguments passed to the BaseParallelProcessor.

    Returns:
        A manifest where each entry is the original data entry with an added field
        `num_words_key` (default: `"num_words"`), indicating the number of words in the
        `text_key` field.
    """

    def __init__(
        self,
        text_key: str,
        num_words_key: str = "num_words",
        alphabet: str = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.text_key = text_key
        self.num_words_key = num_words_key
        self.pattern = None
        if alphabet:
            self.pattern = re.compile("[^" + alphabet + "]")

    def process_dataset_entry(self, data_entry):
        text = data_entry[self.text_key]
        cleaned_string = text
        if self.pattern:
            cleaned_string = self.pattern.sub("", cleaned_string).strip()
        cleaned_string = re.sub("\\s+", " ", cleaned_string).strip()
        words = cleaned_string.split()
        num_words = len(words)
        data_entry[self.num_words_key] = num_words
        return [DataEntry(data=data_entry)]

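# Illustrative usage sketch (not part of the original module): counting words after
# restricting the text to a lowercase Latin alphabet. The alphabet string is an assumption;
# note the trailing space in it, so that word boundaries survive the character filtering.
#
#   - _target_: sdp.processors.modify_manifest.data_to_data.CountNumWords
#     input_manifest_file: ${workspace_dir}/manifest.json
#     output_manifest_file: ${workspace_dir}/manifest_with_num_words.json
#     text_key: text
#     num_words_key: num_words
#     alphabet: "abcdefghijklmnopqrstuvwxyz "
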
class SplitLineBySentence(BaseParallelProcessor):
    """Processor for splitting lines of text into sentences based on a specified pattern.
    One line containing N sentences will be transformed into N lines containing one sentence each.

    Args:
        text_key (str): The field containing the text lines in the dataset.
        end_pattern (str): The regular expression pattern to identify sentence boundaries.
        **kwargs: Additional keyword arguments to be passed to the base class
            `BaseParallelProcessor`.
    """

    def __init__(
        self,
        text_key: str,
        end_pattern: str,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.text_key = text_key
        self.pattern = re.compile(end_pattern)

    def process_dataset_entry(self, data_entry):
        line = data_entry[self.text_key]
        data_list = []
        start = 0
        ends = [m.start() for m in self.pattern.finditer(line)]
        if ends:
            for end in ends:
                sent = line[start : end + 1].strip()
                # if sent and sent[0].isupper():
                data = data_entry.copy()
                data[self.text_key] = sent
                data_list.append(DataEntry(data=data))
                start = end + 1
            if start < len(line):
                pass
        else:
            data = data_entry.copy()
            data[self.text_key] = line.strip()
            data_list.append(DataEntry(data=data))
        return data_list

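# Illustrative usage sketch (not part of the original module): splitting lines on
# sentence-final punctuation. The end_pattern regex below is an assumption.
#
#   - _target_: sdp.processors.modify_manifest.data_to_data.SplitLineBySentence
#     input_manifest_file: ${workspace_dir}/lines_manifest.json
#     output_manifest_file: ${workspace_dir}/sentences_manifest.json
#     text_key: text
#     end_pattern: "[.?!]"
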
class InsIfASRInsertion(BaseParallelProcessor):
    """Processor that adds substrings to transcription if they are present in ASR predictions.

    Will insert substrings into ``data[self.text_key]`` if it is present at that location in
    ``data[self.pred_text_key]``. It is useful if words are systematically missing from
    ground truth transcriptions.

    Args:
        insert_words (list[str]): list of strings that will be inserted into
            ``data[self.text_key]`` if there is an insertion (containing only that string)
            in ``data[self.pred_text_key]``.
        text_key (str): a string indicating which key of the data entries should be used to
            find the utterance transcript. Defaults to "text".
        pred_text_key (str): a string indicating which key of the data entries should be used
            to access the ASR predictions. Defaults to "pred_text".

    .. note::
        Because this processor looks for an exact match in the insertion, we recommend
        including variations with different spaces in ``insert_words``, e.g.
        ``[' nemo', 'nemo ', ' nemo ']``.

    Returns:
        The same data as in the input manifest with ``<text_key>`` field changed.
    """

    def __init__(
        self,
        insert_words: List[str],
        text_key: str = "text",
        pred_text_key: str = "pred_text",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.insert_words = insert_words
        self.text_key = text_key
        self.pred_text_key = pred_text_key

    def process_dataset_entry(self, data_entry) -> List:
        insert_word_counter = collections.defaultdict(int)
        for insert_word in self.insert_words:
            if not insert_word in data_entry[self.pred_text_key]:
                break
            orig_words, pred_words = (
                data_entry[self.text_key],
                data_entry[self.pred_text_key],
            )
            diff = get_diff_with_subs_grouped(orig_words, pred_words)

            if len(diff) > 0:  # ie if there are differences between text and pred_text
                new_sent = ""

                for diff_entry in diff:
                    if diff_entry[0] == 0:  # no change
                        new_sent += diff_entry[1]
                    elif diff_entry[0] == -1:  # deletion in original string
                        new_sent += diff_entry[1]
                    elif diff_entry[0] == 1:  # insertion in original string
                        if diff_entry[1] == insert_word:
                            new_sent += insert_word
                            insert_word_counter[insert_word] += 1
                    elif isinstance(diff_entry, tuple):  # i.e. diff is a substitution
                        new_sent += diff_entry[0][1]
                    else:
                        raise ValueError(f"unexpected item in diff_entry: {diff_entry}")

                new_sent = " ".join(new_sent.split())  # remove any extra spaces
                data_entry[self.text_key] = new_sent

        return [DataEntry(data=data_entry, metrics=insert_word_counter)]

    def finalize(self, metrics):
        total_counter = collections.defaultdict(int)
        for counter in metrics:
            for word, count in counter.items():
                total_counter[word] += count
        logger.info("Num of words that were inserted")
        for word, count in total_counter.items():
            logger.info(f"{word} {count}")
        super().finalize(metrics)

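# Illustrative usage sketch (not part of the original module): inserting a word into the
# ground-truth text when the ASR prediction contains it as an insertion. As noted in the
# docstring above, space variations are listed explicitly because matching is exact.
#
#   - _target_: sdp.processors.modify_manifest.data_to_data.InsIfASRInsertion
#     input_manifest_file: ${workspace_dir}/manifest_with_pred_text.json
#     output_manifest_file: ${workspace_dir}/manifest_ins_fixed.json
#     insert_words: [" nemo", "nemo ", " nemo "]
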
class SubIfASRSubstitution(BaseParallelProcessor):
    """Processor that substitutes substrings to transcription if they are present in ASR predictions.

    Will convert a substring in ``data[self.text_key]`` to a substring in
    ``data[self.pred_text_key]`` if both are located in the same place (ie are part of a
    'substitution' operation) and if the substrings correspond to key-value pairs in
    ``sub_words``. This is useful if words are systematically incorrect in ground truth
    transcriptions.

    Before starting to look for substitution, this processor adds spaces at the beginning
    and end of ``data[self.text_key]`` and ``data[self.pred_text_key]``, to ensure that an
    argument like ``sub_words = {"nmo ": "nemo "}`` would cause a substitution to be made
    even if the original ``data[self.text_key]`` ends with ``"nmo"`` and
    ``data[self.pred_text_key]`` ends with ``"nemo"``.

    Args:
        sub_words (dict): dictionary where a key is a string that might be in
            ``data[self.text_key]`` and the value is the string that might be in
            ``data[self.pred_text_key]``. If both are located in the same place
            (i.e. are part of a 'substitution' operation) then the key string will be
            converted to the value string in ``data[self.text_key]``.
        text_key (str): a string indicating which key of the data entries should be used to
            find the utterance transcript. Defaults to "text".
        pred_text_key (str): a string indicating which key of the data entries should be used
            to access the ASR predictions. Defaults to "pred_text".

    .. note::
        This processor looks for exact string matches of substitutions, so you may need to
        be careful with spaces in ``sub_words``. E.g. it is recommended to do
        ``sub_words = {"nmo ": "nemo "}`` instead of ``sub_words = {"nmo" : "nemo"}``.

    Returns:
        The same data as in the input manifest with ``<text_key>`` field changed.
    """

    def __init__(
        self,
        sub_words: Dict,
        text_key: str = "text",
        pred_text_key: str = "pred_text",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.sub_words = sub_words
        self.text_key = text_key
        self.pred_text_key = pred_text_key

    def process_dataset_entry(self, data_entry) -> List:
        sub_word_counter = collections.defaultdict(int)
        data_entry[self.text_key] = add_start_end_spaces(data_entry[self.text_key])
        data_entry[self.pred_text_key] = add_start_end_spaces(data_entry[self.pred_text_key])
        for original_word, new_word in self.sub_words.items():
            if not original_word in data_entry[self.text_key]:
                break
            orig_words, pred_words = (
                data_entry[self.text_key],
                data_entry[self.pred_text_key],
            )
            diff = get_diff_with_subs_grouped(orig_words, pred_words)

            if len(diff) > 0:  # ie if there are differences between text and pred_text
                new_sent = ""

                for diff_entry in diff:
                    if diff_entry[0] == 0:  # no change
                        new_sent += diff_entry[1]
                    elif diff_entry[0] == -1:  # deletion in original string
                        new_sent += diff_entry[1]
                    elif diff_entry[0] == 1:  # insertion in original string
                        # don't make changes
                        pass
                    elif isinstance(diff_entry, tuple):  # substitution
                        if diff_entry[0][1] == original_word and diff_entry[1][1] == new_word:
                            # ie. substitution is one we want to use to change the original text
                            new_sent += new_word
                            sub_word_counter[original_word] += 1
                        else:  # ie. substitution is one we want to ignore
                            new_sent += diff_entry[0][1]
                    else:
                        raise ValueError(f"unexpected item in diff_entry: {diff_entry}")

                new_sent = add_start_end_spaces(new_sent)
                data_entry[self.text_key] = new_sent

        data_entry[self.text_key] = remove_extra_spaces(data_entry[self.text_key])
        data_entry[self.pred_text_key] = remove_extra_spaces(data_entry[self.pred_text_key])

        return [DataEntry(data=data_entry, metrics=sub_word_counter)]

    def finalize(self, metrics):
        total_counter = collections.defaultdict(int)
        for counter in metrics:
            for word, count in counter.items():
                total_counter[word] += count
        logger.info("Num of words that were substituted")
        for word, count in total_counter.items():
            logger.info(f"{word} {count}")
        super().finalize(metrics)

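# Illustrative usage sketch (not part of the original module): replacing a systematic
# ground-truth misspelling when the ASR prediction contains the correct form at the same
# position. The trailing spaces follow the recommendation in the docstring above.
#
#   - _target_: sdp.processors.modify_manifest.data_to_data.SubIfASRSubstitution
#     input_manifest_file: ${workspace_dir}/manifest_with_pred_text.json
#     output_manifest_file: ${workspace_dir}/manifest_sub_fixed.json
#     sub_words:
#       "nmo ": "nemo "
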
# TODO: replace with generic regex
class SubMakeLowercase(BaseParallelProcessor):
    """Processor to convert text to lowercase.

    Args:
        text_key (str): a string indicating which key of the data entries should be used to
            find the utterance transcript. Defaults to "text".

    Returns:
        The same data as in the input manifest with ``<text_key>`` field changed.
    """

    def __init__(
        self,
        text_key: str = "text",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.text_key = text_key

    def process_dataset_entry(self, data_entry) -> List:
        data_entry[self.text_key] = data_entry[self.text_key].lower()
        return [DataEntry(data=data_entry)]

    def finalize(self, metrics):
        logger.info("Made all letters lowercase")
        super().finalize(metrics)

class SubRegex(BaseParallelProcessor):
    """Applies a sequence of regex substitutions to the specified text field in each data entry.

    This processor performs regex-based substitutions as defined in either a provided list of
    regex parameter dictionaries or a YAML configuration file. Each substitution is applied in
    the order specified.

    Before substitutions are applied, a space is temporarily added to the beginning and end of
    the text to improve regex match consistency. After all substitutions, leading/trailing
    spaces and repeated spaces are removed.

    Args:
        regex_params_list (List[Dict], optional): A list of dictionaries specifying the regex
            substitutions. Each dictionary must include:

            - "pattern": A regex pattern to match.
            - "repl": A replacement string.
            - "count" (optional): Maximum number of replacements to make. Defaults to 0 (replace all).

        regex_params_yaml (str, optional): Path to a YAML file that defines the same list of
            dictionaries as `regex_params_list`. Either `regex_params_list` or
            `regex_params_yaml` must be provided. If both are provided, `regex_params_yaml`
            takes precedence.
        text_key (str): The key in each data entry whose value will be modified. Defaults to "text".
        **kwargs: Additional arguments passed to the BaseParallelProcessor.

    Example YAML format for `regex_params_yaml`::

        # regex_params.yaml
        - {"pattern": "♩", "repl": " "}
        - {"pattern": "♭", "repl": " "}
        - {"pattern": "\\|", "repl": " "}
        - {"pattern": ":", "repl": " "}
        - {"pattern": "-", "repl": " "}
        - {"pattern": "[^ €₽₴$£%?!',.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzАБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЬЮЯабвгдежзийклмнопрстуфхцчшщъьюя]", "repl": ""}
        - {"pattern": "\\s+\\.", "repl": "."}
        - {"pattern": "\\?+", "repl": "?"}
        - {"pattern": "\\.+", "repl": "."}

    Returns:
        The same data as in the input manifest with ``<text_key>`` field changed.
    """

    def __init__(
        self,
        regex_params_list: List[Dict] = None,
        regex_params_yaml: str = None,
        text_key: str = "text",
        **kwargs,
    ):
        super().__init__(**kwargs)
        if not regex_params_list and not regex_params_yaml:
            raise ValueError(f'One of `regex_params_list` or `regex_params_yaml` should be provided.')

        self.regex_params_list = regex_params_list
        if regex_params_yaml:
            with open(regex_params_yaml, 'r') as regex_params_file:
                self.regex_params_list = yaml.safe_load(regex_params_file)

        self.text_key = text_key

        # verify all dicts in regex_params_list have "pattern" and "repl" keys
        for regex_params_dict in self.regex_params_list:
            if not "pattern" in regex_params_dict.keys():
                raise ValueError(
                    f"Need to have key 'pattern' in all entries of `regex_params_list`: {self.regex_params_list}"
                )
            if not "repl" in regex_params_dict.keys():
                raise ValueError(
                    f"Need to have key 'repl' in all entries of `regex_params_list`: {self.regex_params_list}"
                )
    def process_dataset_entry(self, data_entry) -> List:
        """Replaces each found regex match with a given string."""
        replace_word_counter = collections.defaultdict(int)

        text_in = data_entry[self.text_key]

        text_in = add_start_end_spaces(text_in)
        for regex_params in self.regex_params_list:
            text_out = re.sub(
                pattern=regex_params["pattern"],
                repl=regex_params["repl"],
                string=text_in,
                # note: this count param is the maximum number of pattern occurrences to be replaced.
                count=regex_params.get("count", 0),
            )

            if text_in != text_out:
                replace_word_counter[regex_params["pattern"]] += 1
            text_in = text_out

        text_out = remove_extra_spaces(text_out)

        data_entry[self.text_key] = text_out

        return [DataEntry(data=data_entry, metrics=replace_word_counter)]
    def finalize(self, metrics):
        """Reports how many substitutions were made for each pattern."""
        total_counter = collections.defaultdict(int)
        for counter in metrics:
            for word, count in counter.items():
                total_counter[word] += count
        logger.info("Number of utterances which applied substitutions for the following patterns:")
        total_counter_sorted = dict(sorted(total_counter.items(), key=lambda x: x[1], reverse=True))
        for word, count in total_counter_sorted.items():
            logger.info(f"{word} {count}")
        super().finalize(metrics)

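# Illustrative usage sketch (not part of the original module): an inline regex_params_list
# that collapses dashes to spaces and squeezes repeated punctuation, mirroring the YAML
# example in the SubRegex docstring above.
#
#   - _target_: sdp.processors.modify_manifest.data_to_data.SubRegex
#     input_manifest_file: ${workspace_dir}/manifest.json
#     output_manifest_file: ${workspace_dir}/manifest_regex_cleaned.json
#     text_key: text
#     regex_params_list:
#       - {"pattern": "-", "repl": " "}
#       - {"pattern": "\\?+", "repl": "?"}
#       - {"pattern": "\\.+", "repl": "."}
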
class NormalizeText(BaseParallelProcessor):
    """This processor applies text normalization (TN) to the text. I.e. converts text from
    written form into its verbalized form. E.g., "$123" is converted to
    "one hundred and twenty-three dollars".

    Args:
        input_text_key (str): the text field that will be the input to the Normalizer.
            Defaults to: text.
        input_language (str): language specifying the text normalization rules in ISO 639
            Set 1 format. E.g., "en", "es", "it", etc. Defaults to: "en" (English).
        input_case (str): input text capitalization, set to `cased` if text contains capital
            letters. This flag affects normalization rules applied to the text. Note,
            `lower_cased` won't lower case input. Defaults to: cased.
        output_text_key (str): the text field that will be the output from the Normalizer.
            Defaults to: text.

    Returns:
        This processor normalizes the text in the `input_text_key` field and saves the
        normalized text in the `output_text_key` field.

    Raises:
        `NotImplementedError`: when TN is not implemented for the requested language.
    """

    def __init__(
        self,
        input_text_key: str = "text",
        input_language: str = "en",
        input_case: str = "cased",
        output_text_key: str = "text",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_text_key = input_text_key
        self.output_text_key = output_text_key
        self.input_case = input_case
        self.input_language = input_language

    def prepare(self):
        from nemo_text_processing.text_normalization.normalize import Normalizer

        try:
            self.normalizer = Normalizer(input_case=self.input_case, lang=self.input_language)
        except NotImplementedError as e:
            logger.error("Failed to run text normalization: %s", repr(e))

    def process_dataset_entry(self, data_entry):
        data_entry[self.output_text_key] = self.normalizer.normalize(data_entry[self.input_text_key])
        return [DataEntry(data=data_entry)]

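# Illustrative usage sketch (not part of the original module): normalizing English text into
# its verbalized form. This relies on the optional nemo_text_processing dependency imported
# in prepare(); paths below are placeholders.
#
#   - _target_: sdp.processors.modify_manifest.data_to_data.NormalizeText
#     input_manifest_file: ${workspace_dir}/manifest.json
#     output_manifest_file: ${workspace_dir}/manifest_normalized.json
#     input_text_key: text
#     input_language: en
#     input_case: cased
#     output_text_key: text
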
class InverseNormalizeText(BaseParallelProcessor):
    """This processor applies inverse text normalization (ITN) to the text. I.e. transforms
    spoken forms of numbers, dates, etc. into their written equivalents. E.g.,
    "one hundred and twenty-three dollars" is converted to "$123".

    Args:
        input_text_key (str): the text field that will be the input to the InverseNormalizer.
            Defaults to: text.
        input_language (str): language specifying the text normalization rules in ISO 639
            Set 1 format. E.g., "en", "es", "it", etc. Defaults to: "en" (English).
        input_case (str): input text capitalization, set to `cased` if text contains capital
            letters. This flag affects normalization rules applied to the text. Note,
            `lower_cased` won't lower case input. Defaults to: cased.
        output_text_key (str): the text field that will be the output from the
            InverseNormalizer. Defaults to: text.
        verbose (bool): whether to run the InverseNormalizer in verbose mode. Defaults to False.

    Returns:
        This processor inverse normalizes the text in the `input_text_key` field and saves
        the inverse normalized text in the `output_text_key` field.

    Raises:
        `NotImplementedError`: when ITN is not implemented for the requested language.
    """

    def __init__(
        self,
        input_text_key: str = "text",
        input_language: str = "en",
        input_case: str = "cased",
        output_text_key: str = "text",
        verbose: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_text_key = input_text_key
        self.output_text_key = output_text_key
        self.input_case = input_case
        self.input_language = input_language
        self.verbose = verbose

    def prepare(self):
        from nemo_text_processing.inverse_text_normalization.inverse_normalize import InverseNormalizer

        try:
            self.inverse_normalizer = InverseNormalizer(input_case=self.input_case, lang=self.input_language)
        except NotImplementedError as e:
            logger.error("Failed to run text inverse normalization: %s", repr(e))

    def process_dataset_entry(self, data_entry):
        data_entry[self.output_text_key] = self.inverse_normalizer.inverse_normalize(
            data_entry[self.input_text_key], verbose=self.verbose
        )
        return [DataEntry(data=data_entry)]

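# Illustrative usage sketch (not part of the original module): applying ITN so that spoken
# numbers in ASR output become digits; also requires nemo_text_processing. The field names
# below are assumptions.
#
#   - _target_: sdp.processors.modify_manifest.data_to_data.InverseNormalizeText
#     input_manifest_file: ${workspace_dir}/manifest.json
#     output_manifest_file: ${workspace_dir}/manifest_itn.json
#     input_text_key: pred_text
#     input_language: en
#     input_case: cased
#     output_text_key: pred_text_itn
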
class CopyManifestData(BaseParallelProcessor):
    """This processor copies files specified in the manifest to a new location.

    It is useful for creating a consolidated dataset by gathering files from different
    sources into a single directory.

    Args:
        copy_path (str): The destination directory where files will be copied.
        source_filepath (str): The key in the manifest that contains the path to the file
            to be copied. Default: "audio_path".

    Returns:
        The same data as in the input manifest, but the files referenced in the manifest
        will have been copied to the specified destination directory.

    Example:
        .. code-block:: yaml

            - _target_: sdp.processors.modify_manifest.data_to_data.CopyManifestData
              input_manifest_file: ${workspace_dir}/dataset.json
              output_manifest_file: ${workspace_dir}/dataset_copied.json
              copy_path: ${workspace_dir}/consolidated_data
              source_filepath: "audio_filepath"
    """

    def __init__(
        self,
        copy_path: str,
        source_filepath: str = "audio_path",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_field = source_filepath
        self.copy_path = copy_path

    def prepare(self):
        os.makedirs(self.copy_path, exist_ok=True)

    def process_dataset_entry(self, data_entry):
        fname = data_entry[self.input_field]
        dest_file_path = os.path.join(self.copy_path, os.path.basename(fname))
        shutil.copy(fname, dest_file_path)
        data_entry[self.input_field] = dest_file_path
        return [DataEntry(data=data_entry)]


class ReadDocxLines(BaseParallelProcessor):
    """Processor for reading text lines from a docx file and updating the manifest.

    Args:
        source_filepath (str): The field containing the file path in the manifest.
        text_key (str): The field to store the read text lines in the manifest.
        **kwargs: Additional keyword arguments to be passed to the base class
            `BaseParallelProcessor`.
    """

    def __init__(
        self,
        source_filepath: str,
        text_key: str,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_field = source_filepath
        self.output_field = text_key

    def process_dataset_entry(self, data_entry):
        fname = data_entry[self.input_field]

        # Skip hidden files and directories (e.g., .DS_Store, ._filename)
        if os.path.basename(fname).startswith('.'):
            logger.warning(f"Skipping hidden file: {fname}")
            return []

        data_list = []
        try:
            doc = Document(fname)
            for para in doc.paragraphs:
                line = para.text.strip()
                if line:
                    data = data_entry.copy()
                    data[self.output_field] = line
                    data_list.append(DataEntry(data=data))
        except Exception as e:
            logger.error(f"Error reading document {fname}: {e}")
        return data_list


class ExtractFromBrackets(BaseParallelProcessor):
    """A class for extracting text contained within specified bracket types from strings,
    handling nested brackets.

    Example Input::

        data_entry = {
            "text": "This is a [test] string with [multiple [nested] brackets]."
        }

    Example Output::

        [
            {"text": "test"},
            {"text": "multiple [nested] brackets"}
        ]

    Explanation:
        - It extracts "test" from the first occurrence of brackets.
        - It extracts "multiple [nested] brackets" from the second occurrence, handling
          nested brackets correctly.

    Attributes:
        brackets (List[str]): A list where each element is a pair of strings representing
            the opening and closing brackets.
        text_key (str): The key in the input data from which to extract text, defaults to "text".
    """

    def __init__(
        self,
        brackets: List[str],
        text_key: str = "text",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.brackets = brackets
        self.text_key = text_key

    def extract_text_within_brackets(self, text, brackets):
        """Extracts text within the specified brackets, including handling nested brackets.

        Args:
            text (str): The string from which to extract text.
            brackets (tuple[str, str]): A tuple containing the opening and closing bracket.

        Returns:
            List[str]: A list of strings, each representing a segment of text found within
            the outermost brackets, including any nested brackets content.
        """
        open_bracket, close_bracket = brackets
        depth = 0
        buffer = ""
        sentences = []
        for char in text:
            if char == open_bracket:
                if depth > 0:
                    buffer += char  # Add to buffer if already inside brackets
                depth += 1
            elif char == close_bracket:
                depth -= 1
                if depth == 0:  # Exiting outermost brackets
                    if buffer:
                        sentences.append(buffer)
                        buffer = ""  # Reset buffer for next possible extraction
                elif depth > 0:
                    buffer += char  # Still inside nested brackets, continue adding
            elif depth > 0:
                buffer += char  # Add characters inside brackets to buffer

        return sentences

    def process_dataset_entry(self, data_entry) -> List:
        data: list[dict] = []
        sentences = []
        text_in = data_entry[self.text_key]
        for bracket in self.brackets:
            sentences.extend(self.extract_text_within_brackets(text_in, bracket))

        for sentence in sentences:
            new_entry = data_entry.copy()
            new_entry[self.text_key] = sentence
            # new_entry["ORIGINAL TEXT"] = text_in  # for testing
            data.append(new_entry)

        data_list = []
        for data_point in data:
            data_list.append(DataEntry(data=data_point))

        return data_list


class GetWER(BaseParallelProcessor):
    """This processor calculates Word Error Rate (WER) between predicted text and ground truth text.

    It computes the WER for each entry in the manifest and adds the result as a new field.

    Args:
        text_key (str): Key for the ground truth text field in the manifest. Default: "text".
        pred_text_key (str): Key for the predicted text field in the manifest. Default: "pred_text".

    Returns:
        The same data as in the input manifest with an additional 'wer' field containing the
        calculated Word Error Rate between the specified text fields.
    """

    def __init__(
        self,
        text_key: str = "text",
        pred_text_key: str = "pred_text",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.text_key = text_key
        self.pred_text_key = pred_text_key

    def process_dataset_entry(self, data_entry) -> List:
        data_entry['wer'] = get_wer(data_entry[self.text_key], data_entry[self.pred_text_key])
        return [DataEntry(data=data_entry)]


class MakeSentence(BaseParallelProcessor):
    """This processor formats text strings into proper sentences.

    It capitalizes the first character of the text (if enabled) and appends an end symbol
    if the text does not already end with punctuation.

    Args:
        text_key (str): The key in the manifest containing the text to be processed. Default: "text".
        end_symbol (str): The punctuation symbol to add at the end of the text if it doesn't
            already have one. Default: ":".
        make_uppercase (bool): Whether to capitalize the first character of the text. Default: True.

    Returns:
        The same data as in the input manifest with the text field modified to have proper
        sentence formatting.

    Example:
        .. code-block:: yaml

            - _target_: sdp.processors.modify_manifest.data_to_data.MakeSentence
              input_manifest_file: ${workspace_dir}/dataset.json
              output_manifest_file: ${workspace_dir}/dataset_formatted.json
              text_key: "transcript"
              end_symbol: "."
              make_uppercase: true
    """

    def __init__(
        self,
        text_key: str = "text",
        end_symbol: str = ":",
        make_uppercase: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.make_uppercase = make_uppercase
        self.text_key = text_key
        self.end_symbol = end_symbol

    def process_dataset_entry(self, data_entry) -> List:
        if self.make_uppercase:
            data_entry[self.text_key] = data_entry[self.text_key][0].upper() + data_entry[self.text_key][1:]

        # Append end_symbol only if the text ends with a letter (i.e. has no final punctuation)
        if data_entry[self.text_key][-1].isalpha():
            data_entry[self.text_key] += self.end_symbol
        return [DataEntry(data=data_entry)]


class ASRFileCheck(BaseProcessor):
    """This processor validates audio files in the manifest and identifies corrupted files.

    It attempts to load each audio file using the torchaudio library and moves corrupted
    files to a specified directory.

    Args:
        audio_filepath_key (str): The key in the manifest that contains the path to the
            audio file. Default: "audio_filepath".
        corrupted_audio_dir (str): The directory where corrupted audio files will be moved.
        workspace_dir (str, optional): The base directory for resolving relative paths.
            Default: None.

    Returns:
        A manifest with corrupted audio files removed.
    """

    def __init__(
        self,
        audio_filepath_key: str = "audio_filepath",
        corrupted_audio_dir: str = None,
        workspace_dir: str = None,
        **kwargs,
    ):
        """Constructs the necessary attributes for the ASRFileCheck class.

        Parameters
        ----------
        audio_filepath_key : str, optional
            The key in the manifest entries used to retrieve the path to the audio file.
            Defaults to 'audio_filepath'.
        corrupted_audio_dir : str
            The directory where corrupted audio files will be moved. This is required.
        workspace_dir : str, optional
            The base directory where audio files are stored. If provided, audio file paths
            will be resolved relative to this directory. Defaults to None.
        """
        super().__init__(**kwargs)
        self.audio_filepath_key = audio_filepath_key

        if corrupted_audio_dir is None:
            raise ValueError(
                "corrupted_audio_dir parameter is required. Please specify a directory to move corrupted files."
            )

        self.corrupted_audio_dir = corrupted_audio_dir
        self.workspace_dir = workspace_dir
        self.failed_files = []

    def process(self):
        """Check each file listed in the manifest to ensure it can be loaded with torchaudio.

        This method reads through the manifest file, attempts to load each audio file using
        torchaudio, and moves corrupted files. A new manifest file is created with only the
        valid entries.

        Specific errors handled:
        - FileNotFoundError: File doesn't exist
        - RuntimeError: File format issues or codec problems
        - Other exceptions: General issues with file loading
        """
        from sdp.logging import logger

        # Debug print to show workspace_dir
        logger.info(f"ASRFileCheck workspace_dir: {self.workspace_dir}")

        with open(self.input_manifest_file, 'r') as f:
            lines = f.readlines()

        entries = []
        total_lines = len(lines)

        # Ensure the corrupted files directory exists
        os.makedirs(self.corrupted_audio_dir, exist_ok=True)

        for idx in tqdm(range(total_lines), desc="Checking Audio Files"):
            line = lines[idx]
            entry = json.loads(line)
            audio_path = entry[self.audio_filepath_key]

            # Debug print first file path
            if idx == 0:
                logger.info(f"First audio_path from manifest: {audio_path}")

            # If workspace_dir is provided, join it with audio_path to get absolute path
            if self.workspace_dir is not None:
                full_audio_path = os.path.join(self.workspace_dir, audio_path)
            else:
                full_audio_path = audio_path

            # Debug print first full path
            if idx == 0:
                logger.info(f"First full_audio_path: {full_audio_path}")
                logger.info(f"Path exists: {os.path.exists(full_audio_path)}")

            try:
                # Attempt to load the audio file to check if it is corrupted
                torchaudio.load(full_audio_path)
                entries.append(entry)  # File is good, append to entries list
            except FileNotFoundError:
                logger.warning(f"File not found: {full_audio_path}")
                self.failed_files.append(audio_path)
            except RuntimeError as e:
                logger.warning(f"Audio format error in {audio_path}: {e}")
                self.failed_files.append(audio_path)

                # Move the corrupted audio file
                if os.path.exists(full_audio_path):
                    dest_path = os.path.join(self.corrupted_audio_dir, os.path.basename(audio_path))
                    os.rename(full_audio_path, dest_path)
                    logger.info(f"Moved corrupted file to: {dest_path}")
            except Exception as e:
                logger.warning(f"Unknown error loading {audio_path}: {e}")
                self.failed_files.append(audio_path)

                # Move the corrupted audio file
                if os.path.exists(full_audio_path):
                    dest_path = os.path.join(self.corrupted_audio_dir, os.path.basename(audio_path))
                    os.rename(full_audio_path, dest_path)
                    logger.info(f"Moved corrupted file to: {dest_path}")

        # Output non-corrupted entries to a new manifest file
        with open(self.output_manifest_file, 'w', encoding='utf-8') as f_out:
            for entry in entries:
                json.dump(entry, f_out, ensure_ascii=False)
                f_out.write("\n")

        if self.failed_files:
            logger.warning(f"Failed to process {len(self.failed_files)} files.")
            logger.debug(f"Failed files: {self.failed_files}")

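# Illustrative usage sketch (not part of the original module): ASRFileCheck is a
# BaseProcessor, so it reads the whole input manifest and writes a filtered one. Paths
# below are placeholders.
#
#   - _target_: sdp.processors.modify_manifest.data_to_data.ASRFileCheck
#     input_manifest_file: ${workspace_dir}/manifest.json
#     output_manifest_file: ${workspace_dir}/manifest_checked.json
#     audio_filepath_key: audio_filepath
#     corrupted_audio_dir: ${workspace_dir}/corrupted_audio
#     workspace_dir: ${workspace_dir}
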
class ListToEntries(BaseParallelProcessor):
    """A dataset processor that transforms a single entry containing a list of items into
    multiple entries, one for each item in the list.

    This is useful when a manifest field (e.g., "segments") contains a list of sub-entries,
    and you want to flatten these into individual records for further processing.

    Args:
        field_with_list (str): The name of the field in the input entry that contains a list.
        output_field (str, optional): The name of the output field to assign to items in the
            list if they are not dictionaries. Required if the list contains primitive types
            (e.g., strings).
        **kwargs: Additional arguments passed to the BaseParallelProcessor.

    Raises:
        TypeError: If the specified list field is not of type list.
        ValueError: If the list items are not dictionaries and `output_field` is not provided.

    Returns:
        A manifest where each entry corresponds to one item in the original list from the
        input entry. This effectively transforms a single input entry containing a list of
        items into multiple standalone entries, each suitable for further dataset processing.

    .. admonition:: Example 1 (list of dicts)

        .. code-block:: yaml

            - _target_: sdp.processors.ListToEntries
              input_manifest_file: ${workspace_dir}/input_manifest.json
              output_manifest_file: ${workspace_dir}/output_manifest.json
              field_with_list: "segments"

        Input::

            {
                "audio_filepath": "sample.wav",
                "segments": [
                    {"start": 0.0, "end": 1.5, "text": "Hello"},
                    {"start": 1.6, "end": 3.0, "text": "World"}
                ]
            }

        Output::

            [
                {
                    "audio_filepath": "sample.wav",
                    "start": 0.0,
                    "end": 1.5,
                    "text": "Hello"
                },
                {
                    "audio_filepath": "sample.wav",
                    "start": 1.6,
                    "end": 3.0,
                    "text": "World"
                }
            ]

    .. admonition:: Example 2 (list of primitives)

        .. code-block:: yaml

            - _target_: sdp.processors.ListToEntries
              input_manifest_file: ${workspace_dir}/input_manifest.json
              output_manifest_file: ${workspace_dir}/output_manifest.json
              field_with_list: "text_chunks"
              output_field: "text"

        Input::

            {
                "audio_filepath": "sample.wav",
                "text_chunks": [
                    "Hello",
                    "World"
                ]
            }

        Output::

            [
                {
                    "audio_filepath": "sample.wav",
                    "text": "Hello"
                },
                {
                    "audio_filepath": "sample.wav",
                    "text": "World"
                }
            ]
    """

    def __init__(self, field_with_list: str, output_field: str = None, **kwargs):
        super().__init__(**kwargs)
        self.field_with_list = field_with_list
        self.output_field = output_field

    def process_dataset_entry(self, data_entry):
        _entries = []

        # Check that the target field is actually a list
        if not isinstance(data_entry[self.field_with_list], list):
            raise TypeError(f'Values of {self.field_with_list} field should be list type only: {data_entry}')

        # Remove the list field from the entry and get the list of items
        items_list = data_entry.pop(self.field_with_list)

        # If items are not dicts, output_field must be specified to store the item
        if not isinstance(items_list[0], dict) and not self.output_field:
            raise ValueError(f'Type of items in items list `{self.field_with_list}` is not dict ({type(items_list[0])}). In this case `output_field` should be provided.')

        # Expand the list into multiple entries
        for item in items_list:
            _entry = data_entry.copy()

            # If item is a dict, merge its keys; otherwise, store it in `output_field`
            if isinstance(item, dict):
                _entry.update(item)
            else:
                _entry[self.output_field] = item

            _entry = DataEntry(_entry)
            _entries.append(_entry)

        return _entries

class LambdaExpression(BaseParallelProcessor):
    """A dataset processor that evaluates a Python expression on each data entry and either
    stores the result in a new field or uses it as a filtering condition.

    This processor is useful for dynamic field computation or conditional filtering of
    entries based on configurable expressions. It leverages ``evaluate_expression``, which
    safely evaluates expressions using the abstract syntax tree (AST).

    Filtering behavior:
        If ``filter=True``, the expression is evaluated for each entry. Only entries for
        which the expression evaluates to ``True`` are kept; all others are filtered out
        (removed from the output). If ``filter=False``, the result of the expression is
        stored in the field specified by ``new_field`` for each entry (no filtering occurs).

    Examples::

        # Example 1: Filtering entries where the duration is greater than 5.0 seconds
        LambdaExpression(
            new_field="keep",  # This field is ignored when filter=True
            expression="entry['duration'] > 5.0",
            lambda_param_name="entry",
            filter=True
        )
        # Only entries with duration > 5.0 will be kept in the output manifest.

        # Example 2: Adding a new field with the number of words in the text
        LambdaExpression(
            new_field="num_words",
            expression="len(entry['text'].split())",
            lambda_param_name="entry",
            filter=False
        )
        # Each entry will have a new field 'num_words' with the word count of the 'text' field.

    Supported operations:
        The expression supports a safe subset of Python operations, including:

        - Arithmetic: ``+``, ``-``, ``*``, ``/``, ``//``, ``%``, ``**``
        - Comparisons: ``==``, ``!=``, ``<``, ``<=``, ``>``, ``>=``, ``is``, ``is not``
        - Logical: ``and``, ``or``, ``not``
        - Bitwise: ``|``, ``&``, ``^``, ``~``, ``<<``, ``>>``
        - Indexing and slicing: ``entry['key']``, ``entry[0]``, ``entry[1:3]``
        - Conditional (ternary) expressions: ``a if cond else b``
        - List and dict literals: ``[a, b]``, ``{k: v}``
        - Attribute access: ``entry.attr``
        - Function calls (limited): ``max``, ``min``, ``len``, ``sum``, ``abs``, ``sorted``

        For the full list, see the ``OPERATORS`` and ``SAFE_FUNCTIONS`` in
        :mod:`sdp.utils.apply_operators`.

        See also: https://docs.python.org/3/library/operator.html

    Args:
        new_field (str): The name of the field to store the result of the expression
            (ignored if filter=True).
        expression (str): A Python expression to evaluate. It can reference fields of the
            data entry using the name specified by ``lambda_param_name`` (default: 'entry').
        lambda_param_name (str, optional): The name to refer to the current data entry in
            the expression. Default is "entry".
        filter (bool, optional): If True, the expression result is treated as a condition.
            The entry is kept only if the result is ``True``. Default is ``False``.
        **kwargs: Additional keyword arguments passed to the ``BaseParallelProcessor`` class.

    Returns:
        str: A line-delimited JSON manifest, where each line is a processed entry.
        The result may contain fewer entries than the input if ``filter=True``.
    """

    def __init__(
        self,
        new_field: str,
        expression: str,
        lambda_param_name: str = "entry",
        filter: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.new_field = new_field
        self.expression = expression
        self.lambda_param_name = lambda_param_name
        self.filter = filter

    def process_dataset_entry(self, data_entry) -> List[DataEntry]:
        """Process a single data entry by evaluating the expression.

        If `filter` is True, the entry is only retained if the expression evaluates to True.
        Otherwise, the result is stored in `new_field`.
        """
        value = evaluate_expression(self.expression, data_entry, self.lambda_param_name)
        if self.filter:
            if value is not True:
                return []
        data_entry[self.new_field] = value
        return [DataEntry(data=data_entry)]

    def finalize(self, metrics):
        super().finalize(metrics)

class EstimateBandwidth(BaseParallelProcessor):
    """Adds estimated bandwidth to each utterance in the input manifest file.

    Args:
        audio_dir (str): Root directory where audio files are stored.
        input_audio_key (str): Manifest key with relative audio paths.
        output_bandwidth_key (str): Manifest key to store estimated bandwidth in.
        max_seconds (float): The maximum length of audio to use for bandwidth estimation.
            By default, uses the first 30 seconds.
        sample_rate (int): Sample rate to resample audio to before doing bandwidth estimation.
            Defaults to 44100, upsampling the input audio as needed.
        n_fft (int): Number of FFT bins to use for bandwidth estimation. Defaults to 512.
        hop_length (int): Audio frame hop length to use for bandwidth estimation.
            Defaults to 441, corresponding to 0.01 seconds for 44100 sample rate.
        top_db (float): top_db threshold to use for bandwidth estimation.
        frequency_threshold (float): Bandwidth estimation finds the highest frequency with
            mean power spectrum that is within 'frequency_threshold' dB of its peak power.
            Defaults to -50 dB.

    Returns:
        This processor estimates the bandwidth of the audio file in the `input_audio_key`
        field and saves the estimate in the `output_bandwidth_key` field.

    Example:
        .. code-block:: yaml

            - _target_: sdp.processors.EstimateBandwidth
              input_manifest_file: ${workspace_dir}/manifest.json
              output_manifest_file: ${workspace_dir}/manifest_bandwidth.json
              audio_dir: ${workspace_dir}/audio_22khz
              max_workers: 8
    """

    def __init__(
        self,
        audio_dir: str,
        input_audio_key: str = "audio_filepath",
        output_bandwidth_key: str = "bandwidth",
        max_seconds: float = 30.0,
        sample_rate: int = 44100,
        n_fft: int = 512,
        hop_length: int = 441,
        top_db: float = 100.0,
        frequency_threshold: float = -50.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.audio_directory = Path(audio_dir)
        self.input_audio_key = input_audio_key
        self.output_bandwidth_key = output_bandwidth_key
        self.max_seconds = max_seconds
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.top_db = top_db
        self.frequency_threshold = frequency_threshold

    def _estimate_bandwidth(self, audio, sample_rate):
        spec = librosa.stft(y=audio, n_fft=self.n_fft, hop_length=self.hop_length, window="blackmanharris")
        power_spec = np.abs(spec) ** 2
        power_spec = np.mean(power_spec, axis=1)
        power_spec = librosa.power_to_db(power_spec, ref=self.n_fft, top_db=self.top_db)

        bandwidth = 0
        peak = np.max(power_spec)
        freq_width = sample_rate / self.n_fft
        for idx in range(len(power_spec) - 1, -1, -1):
            if power_spec[idx] - peak > self.frequency_threshold:
                bandwidth = idx * freq_width
                break

        return bandwidth

    def process_dataset_entry(self, data_entry):
        audio_filename = data_entry[self.input_audio_key]
        audio_file = self.audio_directory / audio_filename
        audio, sr = librosa.load(path=audio_file, sr=self.sample_rate, duration=self.max_seconds)
        bandwidth = self._estimate_bandwidth(audio=audio, sample_rate=sr)
        data_entry[self.output_bandwidth_key] = int(bandwidth)
        return [DataEntry(data=data_entry)]

class CharacterHistogramLangValidator(BaseParallelProcessor):
    """A processor that filters text based on character histogram similarity to trusted data
    in the target language.

    This processor computes the ratio of characters in a given text that are found in a
    reference character histogram for a specific language. If this ratio is below a certain
    threshold, the text is likely mislabeled or noisy.

    Histograms are sourced from the NLLB paper (https://arxiv.org/pdf/2207.04672), see page 30
    for methodology. This technique is a lightweight language ID filter, designed to catch
    mismatches between text content and claimed language.

    Reference implementation:
    https://github.com/facebookresearch/fairseq/blob/main/examples/m2m_100/process_data/clean_histogram.py

    Args:
        text_field (str): Key in the data entry containing the text to evaluate.
        lang_field (str, optional): Key in the data entry that identifies the language.
            Required if `lang` is not provided.
        lang (str, optional): Language code to use for all entries (overrides `lang_field`).
            Required if `lang_field` is not provided.
        threshold (float): Threshold ratio to determine if text matches the histogram.
            Used only externally (not enforced in this processor).
        cache_dir (str, optional): Directory where histograms are downloaded and cached.
        threshold_char (str): Character used to truncate the histogram file (default is ']').
        output_score_field (str): Key name under which the computed character match ratio
            will be stored.
        **kwargs: Additional keyword arguments passed to `BaseParallelProcessor`.

    Raises:
        ValueError: If both `lang` and `lang_field` are provided, or if neither is provided.
            Also raised if the histogram for the specified language is missing.

    Returns:
        A manifest where each entry includes the additional field `output_score_field` with
        the character match ratio.

    Example::

        {
            "text": "hello world",
            "lang": "en",
            "hist_token_ratio": 0.95
        }
    """

    HISTOGRAMS_URL = 'https://dl.fbaipublicfiles.com/m2m_100/histograms.tar.gz'

    def __init__(
        self,
        text_field: str,
        lang_field: str = None,
        lang: str = None,
        threshold: float = 0.8,
        cache_dir: str = None,
        threshold_char: str = "]",
        output_score_field: str = "hist_token_ratio",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.text_field = text_field

        # Ensure exactly one of `lang` or `lang_field` is provided
        if lang_field is None and lang is None:
            raise ValueError("One of the arguments `lang` or `lang_field` must be provided.")
        if lang_field is not None and lang is not None:
            raise ValueError(
                f"Both `lang` ({lang}) and `lang_field` ({lang_field}) are provided, which makes the source of language ambiguous. Please provide only one of them."
            )

        self.lang_field = lang_field
        self.lang = lang
        self.threshold = threshold
        self.cache_dir = cache_dir
        self.threshold_char = threshold_char
        self.output_score_field = output_score_field
        self.histograms = dict()

    def _read_hist(self, lang: str):
        """Read and parse the histogram file for a given language, stopping at the threshold character."""
        hist_file = os.path.join(self.cache_dir, lang)
        chars = []
        with open(hist_file) as hist:
            for line in hist:
                char = line[0]
                chars.append(char)
                if char == self.threshold_char:
                    break
        self.histograms[lang] = set(chars)

    def _download_histograms(self):
        """Download and extract histogram files into the cache directory."""
        logger.info('Downloading histograms collection..')
        response = requests.get(self.HISTOGRAMS_URL)

        if response.status_code != 200:
            raise requests.exceptions.RequestException(
                f"Failed to download model file. Status code: {response.status_code}"
            )

        if self.cache_dir is None:
            self.cache_dir = tempfile.mkdtemp()
        os.makedirs(self.cache_dir, exist_ok=True)

        histograms_tarfile = wget.download(self.HISTOGRAMS_URL, out=self.cache_dir)
        with tarfile.open(histograms_tarfile, "r:gz") as tar:
            tar.extractall(path=self.cache_dir)

        # Flatten subdirectories into the main cache_dir
        histograms_filepaths = glob(f'{self.cache_dir}/checkpoint/edunov/cc60_multilingual/clean_hists/*')
        for histogram_filepath in histograms_filepaths:
            shutil.move(histogram_filepath, os.path.join(self.cache_dir, os.path.basename(histogram_filepath)))

        os.remove(histograms_tarfile)
        shutil.rmtree(f'{self.cache_dir}/checkpoint/edunov/cc60_multilingual/clean_hists/')
        logger.info(f'Histograms have been downloaded to {self.cache_dir}.')

    def prepare(self):
        """Ensure histograms are available and read them into memory."""
        if (
            self.cache_dir is None
            or not os.path.exists(self.cache_dir)
            or not os.path.isdir(self.cache_dir)
            or len(os.listdir(self.cache_dir)) == 0
        ):
            self._download_histograms()

        logger.info('Reading histograms...')
        available_langs = os.listdir(self.cache_dir)
        if self.lang is not None:
            if self.lang in available_langs:
                self._read_hist(self.lang)
            else:
                raise ValueError(
                    f"Invalid value for `lang`: {self.lang}. Please provide one of the following: {available_langs}"
                )
            logger.info(f'Histogram for `{self.lang}` has been read.')
        else:
            for lang in tqdm(available_langs):
                self._read_hist(lang)
            logger.info(f'Histograms have been read.')

    def process_dataset_entry(self, data_entry):
        """Compute and attach the character histogram match ratio for a given text entry.

        Args:
            data_entry (dict): A dictionary containing at least `text_field` and either
                `lang_field` or a preset `lang`.

        Returns:
            List[DataEntry]: A list with one updated `DataEntry` including the character
            match ratio field.
        """
        # Determine language for this entry
        lang = self.lang if self.lang is not None else data_entry[self.lang_field]

        if lang not in self.histograms:
            raise ValueError(f'lang `{lang}` is not supported.')

        # Compute how many characters match the histogram
        text = data_entry[self.text_field].strip()
        cnt = len([c for c in text if c in self.histograms[lang]])
        token_ratio = cnt / len(text) if len(text) > 0 else 0.0

        # Store the ratio in the data entry
        data_entry[self.output_score_field] = token_ratio

        return [DataEntry(data=data_entry)]

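# Illustrative usage sketch (not part of the original module): scoring entries against the
# English histogram. Since this processor only attaches the score, the threshold can be
# applied downstream, e.g. with the LambdaExpression processor above. The "en" language
# code and paths are assumptions.
#
#   - _target_: sdp.processors.modify_manifest.data_to_data.CharacterHistogramLangValidator
#     input_manifest_file: ${workspace_dir}/manifest.json
#     output_manifest_file: ${workspace_dir}/manifest_hist_scored.json
#     text_field: text
#     lang: en
#     cache_dir: ${workspace_dir}/histograms
#     output_score_field: hist_token_ratio
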