# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import json
import os
import re
from operator import eq, ge, gt, le, lt, ne
from typing import List, Union
from sdp.logging import logger
from sdp.processors.base_processor import BaseParallelProcessor, DataEntry
from sdp.utils.edit_spaces import add_start_end_spaces, remove_extra_spaces
from sdp.utils.get_diff import get_diff, get_diff_with_subs_grouped
from sdp.utils.metrics_computation import (
get_cer,
get_charrate,
get_wer,
get_wmr,
get_wordrate,
)
class PreserveByValue(BaseParallelProcessor):
"""
    Processor that keeps only the dataset entries whose ``input_value_key`` field satisfies
    the chosen comparison against ``target_value``; all other entries are dropped.
Args:
input_value_key (str): The field in the dataset entries to be evaluated.
target_value (Union[int, str]): The value to compare with the input field.
operator (str): (Optional) The operator to apply for comparison. Options: "lt" (less than), "le" (less than or equal to), "eq" (equal to), "ne" (not equal to), "ge" (greater than or equal to), "gt" (greater than). Defaults to "eq".
**kwargs: Additional keyword arguments to be passed to the base class `BaseParallelProcessor`.
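
    Example:
        A minimal sketch that keeps only entries whose ``duration`` field is
        at least 3 seconds (the ``BaseParallelProcessor`` manifest arguments
        are omitted for brevity)::

            PreserveByValue(
                input_value_key="duration",
                target_value=3,
                operator="ge",
            )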
"""
def __init__(
self,
input_value_key: str,
target_value: Union[int, str],
operator: str = "eq",
**kwargs,
):
super().__init__(**kwargs)
self.input_value_key = input_value_key
self.target_value = target_value
if operator == "lt":
self.operator = lt
elif operator == "le":
self.operator = le
elif operator == "eq":
self.operator = eq
elif operator == "ne":
self.operator = ne
elif operator == "ge":
self.operator = ge
elif operator == "gt":
self.operator = gt
else:
raise ValueError(
'Operator must be one from the list: "lt" (less than), "le" (less than or equal to), "eq" (equal to), "ne" (not equal to), "ge" (greater than or equal to), "gt" (greater than)'
)
def process_dataset_entry(self, data_entry):
input_value = data_entry[self.input_value_key]
target = self.target_value
if self.operator(input_value, target):
return [DataEntry(data=data_entry)]
else:
return [DataEntry(data=None)]
class DropHighLowCharrate(BaseParallelProcessor):
"""Drops utterances if their character rate is too low or too high.
Character rate = ``(num of characters in self.text_key) / (duration of audio)``.
A too-low or too-high character rate often implies that the ground
truth transcription might be inaccurate.
Args:
high_charrate_threshold (float): upper character rate threshold.
If the character rate of an utterance is higher than this number,
the utterance will be dropped.
low_charrate_threshold (float): lower character rate threshold.
If the character rate of an utterance is lower than this number,
the utterance will be dropped.
text_key (str): a string indicating which key of the data entries
should be used to find the utterance transcript. Defaults to "text".
Returns:
The same data as in the input manifest with some entries dropped.
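
    Example:
        By the formula above, a 10-second clip whose transcript has 150
        characters has a character rate of 15 chars/sec. A sketch that keeps
        only utterances between 1 and 21 chars/sec::

            DropHighLowCharrate(
                high_charrate_threshold=21.0,
                low_charrate_threshold=1.0,
            )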
"""
def __init__(
self,
high_charrate_threshold: float,
low_charrate_threshold: float,
text_key: str = "text",
**kwargs,
):
super().__init__(**kwargs)
self.high_charrate_threshold = high_charrate_threshold
self.low_charrate_threshold = low_charrate_threshold
self.text_key = text_key
def process_dataset_entry(self, data_entry) -> List:
"""Drops utterances based on the provided thresholds."""
charrate = get_charrate(data_entry[self.text_key], data_entry["duration"])
if charrate > self.high_charrate_threshold:
return [DataEntry(data=None, metrics=(0, 1))]
elif charrate < self.low_charrate_threshold:
return [DataEntry(data=None, metrics=(1, 0))]
return [DataEntry(data=data_entry, metrics=(0, 0))]
def finalize(self, metrics):
"""Will report how many utterances were dropped for each threshold."""
high_drop_counter = 0
low_drop_counter = 0
for dropped_low, dropped_high in metrics:
low_drop_counter += dropped_low
high_drop_counter += dropped_high
logger.info(
"Num of utterances that were dropped due to char rate > %f: %d",
self.high_charrate_threshold,
high_drop_counter,
)
logger.info(
"Num of utterances that were dropped due to char rate < %f: %d",
self.low_charrate_threshold,
low_drop_counter,
)
super().finalize(metrics)
class DropHighLowWordrate(BaseParallelProcessor):
"""Drops utterances if their word rate is too low or too high.
Word rate = ``(num of words in self.text_key) / (duration of audio)``.
A too-low or too-high word rate often implies that the ground
truth transcription might be inaccurate.
Args:
high_wordrate_threshold (float): upper word rate threshold.
If the word rate of an utterance is higher than this number,
the utterance will be dropped.
low_wordrate_threshold (float): lower word rate threshold.
If the word rate of an utterance is lower than this number,
the utterance will be dropped.
text_key (str): a string indicating which key of the data entries
should be used to find the utterance transcript. Defaults to "text".
Returns:
The same data as in the input manifest with some entries dropped.
"""
def __init__(
self,
high_wordrate_threshold: float,
low_wordrate_threshold: float,
text_key: str = "text",
**kwargs,
):
super().__init__(**kwargs)
self.high_wordrate_threshold = high_wordrate_threshold
self.low_wordrate_threshold = low_wordrate_threshold
self.text_key = text_key
def process_dataset_entry(self, data_entry) -> List:
wordrate = get_wordrate(data_entry[self.text_key], data_entry["duration"])
if wordrate > self.high_wordrate_threshold:
return [DataEntry(data=None, metrics=(0, 1))]
elif wordrate < self.low_wordrate_threshold:
return [DataEntry(data=None, metrics=(1, 0))]
return [DataEntry(data=data_entry, metrics=(0, 0))]
def finalize(self, metrics):
high_drop_counter = 0
low_drop_counter = 0
for dropped_low, dropped_high in metrics:
low_drop_counter += dropped_low
high_drop_counter += dropped_high
logger.info(
"Num of utterances that were dropped due to word rate > %f: %d",
self.high_wordrate_threshold,
high_drop_counter,
)
logger.info(
"Num of utterances that were dropped due to word rate < %f: %d",
self.low_wordrate_threshold,
low_drop_counter,
)
super().finalize(metrics)
class DropHighLowDuration(BaseParallelProcessor):
"""Drops utterances if their duration is too low or too high.
Args:
high_duration_threshold (float): upper duration threshold (in seconds).
If the duration of an utterance's audio is higher than this number,
the utterance will be dropped.
low_duration_threshold (float): lower duration threshold (in seconds).
If the duration of an utterance's audio is lower than this number,
the utterance will be dropped.
duration_key (str): a string indicating which key of the data entries
should be used to find the utterance duration. Defaults to "duration".
Returns:
The same data as in the input manifest with some entries dropped.
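
    Example:
        A sketch that drops clips shorter than 1 second or longer than
        20 seconds::

            DropHighLowDuration(
                high_duration_threshold=20.0,
                low_duration_threshold=1.0,
            )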
"""
def __init__(
self,
high_duration_threshold: float,
low_duration_threshold: float,
duration_key: str = "duration",
**kwargs,
):
super().__init__(**kwargs)
self.high_duration_threshold = high_duration_threshold
self.low_duration_threshold = low_duration_threshold
self.high_drop_counter = 0
self.low_drop_counter = 0
self.duration_key = duration_key
def process_dataset_entry(self, data_entry) -> List:
duration = data_entry[self.duration_key]
if duration > self.high_duration_threshold:
return [DataEntry(data=None, metrics=(0, 1))]
elif duration < self.low_duration_threshold:
return [DataEntry(data=None, metrics=(1, 0))]
return [DataEntry(data=data_entry, metrics=(0, 0))]
def finalize(self, metrics):
high_drop_counter = 0
low_drop_counter = 0
for dropped_low, dropped_high in metrics:
low_drop_counter += dropped_low
high_drop_counter += dropped_high
logger.info(
"Num of utterances that were dropped due to duration > %f: %d",
self.high_duration_threshold,
high_drop_counter,
)
logger.info(
"Num of utterances that were dropped due to duration < %f: %d",
self.low_duration_threshold,
low_drop_counter,
)
super().finalize(metrics)
class DropIfNoneOfRegexMatch(BaseParallelProcessor):
"""Drops utterances if ``data[self.text_key]`` does not match any of ``regex_patterns``.
    Before applying the regex checks, a space character is added to the
    beginning and end of ``data[self.text_key]``. After the regex checks,
    assuming the utterance isn't dropped, the extra spaces are removed.
    This includes the spaces at the beginning and end of the text, as well
    as any double spaces ``"  "``.
Args:
regex_patterns (list[str]): If ``data_entry[self.text_key]`` does not
match any of the regex patterns in the list, that utterance
will be dropped.
text_key (str): a string indicating which key of the data entries
should be used to find the utterance transcript. Defaults to "text".
Returns:
The same data as in the input manifest with some entries dropped.
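
    Example:
        A sketch that keeps only utterances containing at least one digit or
        at least one lowercase Latin letter; everything else is dropped::

            DropIfNoneOfRegexMatch(
                regex_patterns=["[0-9]", "[a-z]"],
            )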
"""
def __init__(
self,
regex_patterns: List[str],
text_key: str = "text",
**kwargs,
):
super().__init__(**kwargs)
self.regex_patterns = regex_patterns
self.text_key = text_key
def process_dataset_entry(self, data_entry) -> List:
data_entry[self.text_key] = add_start_end_spaces(data_entry[self.text_key])
for regex_pattern in self.regex_patterns:
if re.search(regex_pattern, data_entry[self.text_key]):
break
        else:  # will only reach this if none of the regexes match
return [DataEntry(data=None, metrics=1)]
# will reach this part of code if at least one of the regexes matches
data_entry[self.text_key] = remove_extra_spaces(data_entry[self.text_key])
return [DataEntry(data=data_entry, metrics=0)]
def finalize(self, metrics):
        total_counter = sum(metrics)
        logger.info(
            "Num of utterances that were dropped due to not matching any of the specified regex patterns: %d",
            total_counter,
        )
super().finalize(metrics)
class DropNonAlphabet(BaseParallelProcessor):
"""Drops utterances if they contain characters that are not in the ``alphabet``.
Args:
alphabet (str): a string containing all of the characters in our alphabet.
If an utterance contains at least one character that is not in the
``alphabet``, then that utterance will be dropped.
text_key (str): a string indicating which key of the data entries
should be used to find the utterance transcript. Defaults to "text".
.. note::
Don't forget to include spaces in your alphabet, unless you
want to make sure none of the utterances contain spaces.
Returns:
The same data as in the input manifest with some entries dropped.
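
    Example:
        A sketch for a lowercase-English corpus; note the leading space in
        the alphabet string, per the note above::

            DropNonAlphabet(
                alphabet=" abcdefghijklmnopqrstuvwxyz",
            )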
"""
def __init__(
self,
alphabet: str,
text_key: str = "text",
**kwargs,
):
super().__init__(**kwargs)
self.alphabet = alphabet
self.text_key = text_key
def process_dataset_entry(self, data_entry) -> List:
drop_this_utt = False
non_alphabet_counter = collections.defaultdict(int)
for char in data_entry[self.text_key]:
if char not in self.alphabet:
drop_this_utt = True
non_alphabet_counter[char] += 1
if drop_this_utt:
return [DataEntry(data=None, metrics=non_alphabet_counter)]
return [DataEntry(data=data_entry, metrics=non_alphabet_counter)]
def finalize(self, metrics):
total_counter = collections.defaultdict(int)
for counter in metrics:
for char, value in counter.items():
total_counter[char] += value
logger.info("Num of non-alphabet characters")
for char, count in total_counter.items():
logger.info(f"{char}: {count}")
super().finalize(metrics)
class DropASRErrorBeginningEnd(BaseParallelProcessor):
"""Drops utterances if there is a sufficiently long ASR mismatch
at the beginning or end of the utterance.
Args:
beginning_error_char_threshold (int): if there is an insertion or deletion at
the beginning of the utterance that has more characters than this number,
then the utterance will be dropped.
If there is a substitution at the beginning of the utterance, then the
utterance will be dropped if
``abs(len(deletion) - len(insertion)) > beginning_error_char_threshold``.
end_error_char_threshold (int): if there is an insertion or deletion at
the end of the utterance that has more characters than this number,
then the utterance will be dropped.
If there is a substitution at the end of the utterance, then the
utterance will be dropped if
``abs(len(deletion) - len(insertion)) > end_error_char_threshold``.
text_key (str): a string indicating which key of the data entries
should be used to find the utterance transcript. Defaults to "text".
pred_text_key (str): a string indicating which key of the data entries
should be used to access the ASR predictions. Defaults to "pred_text".
Returns:
The same data as in the input manifest with some entries dropped.
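
    Example:
        A sketch that tolerates mismatches of up to 10 characters at the
        start of an utterance and up to 20 characters at the end::

            DropASRErrorBeginningEnd(
                beginning_error_char_threshold=10,
                end_error_char_threshold=20,
            )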
"""
def __init__(
self,
beginning_error_char_threshold: int,
end_error_char_threshold: int,
text_key: str = "text",
pred_text_key: str = "pred_text",
**kwargs,
):
super().__init__(**kwargs)
self.beginning_error_char_threshold = beginning_error_char_threshold
self.end_error_char_threshold = end_error_char_threshold
self.text_key = text_key
self.pred_text_key = pred_text_key
def process_dataset_entry(self, data_entry) -> List:
orig_words, pred_words = data_entry[self.text_key], data_entry[self.pred_text_key]
diff = get_diff_with_subs_grouped(orig_words, pred_words)
if len(diff) > 0: # i.e. if there are differences between text and pred_text
first_diff_entry = diff[0]
if first_diff_entry[0] == 1 or first_diff_entry[0] == -1: # i.e. diff is purely an insertion or deletion
if len(first_diff_entry[1]) > self.beginning_error_char_threshold:
return [DataEntry(data=None, metrics=(1, 0))]
elif first_diff_entry[0] != 0: # i.e. diff should be a tuple representing substitution
len_deletion = len(first_diff_entry[0][1])
len_insertion = len(first_diff_entry[1][1])
if abs(len_deletion - len_insertion) > self.beginning_error_char_threshold:
return [DataEntry(data=None, metrics=(1, 0))]
last_diff_entry = diff[-1]
if last_diff_entry[0] == 1 or last_diff_entry[0] == -1: # i.e. diff is purely an insertion or deletion
if len(last_diff_entry[1]) > self.end_error_char_threshold:
return [DataEntry(data=None, metrics=(0, 1))]
elif last_diff_entry[0] != 0: # i.e. diff should be a tuple representing substitution
len_deletion = len(last_diff_entry[0][1])
len_insertion = len(last_diff_entry[1][1])
if abs(len_deletion - len_insertion) > self.end_error_char_threshold:
return [DataEntry(data=None, metrics=(0, 1))]
return [DataEntry(data=data_entry, metrics=(0, 0))]
def finalize(self, metrics):
beginning_drop_counter = 0
end_drop_counter = 0
for dropped_beginning, dropped_end in metrics:
beginning_drop_counter += dropped_beginning
end_drop_counter += dropped_end
logger.info(
"Num of utterances that were dropped due to asr insertions/deletions at the beginning: %d",
beginning_drop_counter,
)
logger.info(
"Num of utterances that were dropped due to asr insertions/deletions at the end: %d",
end_drop_counter,
)
super().finalize(metrics)
# TODO: needs unification with above class in some way
class DropASRError(BaseParallelProcessor):
"""Drops utterances if there is a sufficiently long ASR mismatch anywhere in the utterance.
Args:
consecutive_words_threshold (int): will drop if there is a mismatch of
at least this many words in a row.
text_key (str): a string indicating which key of the data entries
should be used to find the utterance transcript. Defaults to "text".
pred_text_key (str): a string indicating which key of the data entries
should be used to access the ASR predictions. Defaults to "pred_text".
Returns:
The same data as in the input manifest with some entries dropped.
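
    Example:
        A sketch that drops an utterance whenever ``text`` and ``pred_text``
        disagree on 5 or more consecutive words::

            DropASRError(
                consecutive_words_threshold=5,
            )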
"""
def __init__(
self,
consecutive_words_threshold: int,
text_key: str = "text",
pred_text_key: str = "pred_text",
**kwargs,
):
super().__init__(**kwargs)
self.consecutive_words_threshold = consecutive_words_threshold
self.text_key = text_key
self.pred_text_key = pred_text_key
def process_dataset_entry(self, data_entry) -> List:
orig_words, pred_words = data_entry[self.text_key], data_entry[self.pred_text_key]
diffs = get_diff(orig_words, pred_words)
for diff_entry in diffs:
if diff_entry[0] == 0:
continue
if len(diff_entry[1].split()) >= self.consecutive_words_threshold:
return []
return [DataEntry(data=data_entry)]
class DropHighCER(BaseParallelProcessor):
"""Drops utterances if there is a sufficiently high character-error-rate (CER).
CER is measured between ``data[self.text_key]`` and ``data[self.pred_text_key]``.
.. note::
We only drop the utterance if ``CER > threshold`` (i.e. strictly greater
than) so that if we set the threshold to 0, we will not remove
utterances with ``CER == 0``.
Args:
cer_threshold (float): CER threshold above which the utterance will be dropped.
text_key (str): a string indicating which key of the data entries
should be used to find the utterance transcript. Defaults to "text".
pred_text_key (str): a string indicating which key of the data entries
should be used to access the ASR predictions. Defaults to "pred_text".
Returns:
The same data as in the input manifest with some entries dropped.
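
    Example:
        Assuming CER is reported on a 0-100 scale (as WMR is in this module),
        a sketch that drops utterances with more than 30% character errors::

            DropHighCER(
                cer_threshold=30.0,
            )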
"""
def __init__(
self,
cer_threshold: float,
text_key: str = "text",
pred_text_key: str = "pred_text",
**kwargs,
):
super().__init__(**kwargs)
self.cer_threshold = cer_threshold
self.text_key = text_key
self.pred_text_key = pred_text_key
def process_dataset_entry(self, data_entry) -> List:
cer = get_cer(data_entry[self.text_key], data_entry[self.pred_text_key])
if cer > self.cer_threshold:
return [DataEntry(data=None, metrics=1)]
else:
return [DataEntry(data=data_entry, metrics=0)]
def finalize(self, metrics):
drop_counter = 0
for dropped in metrics:
drop_counter += dropped
logger.info(
"Num of utterances that were dropped due to CER > %d: %d",
self.cer_threshold,
drop_counter,
)
super().finalize(metrics)
class DropHighWER(BaseParallelProcessor):
"""Drops utterances if there is a sufficiently high word-error-rate (WER).
WER is measured between ``data[self.text_key]`` and ``data[self.pred_text_key]``.
.. note::
We only drop the utterance if ``WER > threshold`` (i.e. strictly greater
than) so that if we set the threshold to 0, we will not remove
utterances with ``WER == 0``.
Args:
wer_threshold (float): WER threshold above which the utterance will be dropped.
text_key (str): a string indicating which key of the data entries
should be used to find the utterance transcript. Defaults to "text".
pred_text_key (str): a string indicating which key of the data entries
should be used to access the ASR predictions. Defaults to "pred_text".
Returns:
The same data as in the input manifest with some entries dropped.
"""
def __init__(
self,
wer_threshold: float,
text_key: str = "text",
pred_text_key: str = "pred_text",
**kwargs,
):
super().__init__(**kwargs)
self.wer_threshold = wer_threshold
self.text_key = text_key
self.pred_text_key = pred_text_key
def process_dataset_entry(self, data_entry) -> List:
wer = get_wer(data_entry[self.text_key], data_entry[self.pred_text_key])
if wer > self.wer_threshold:
return [DataEntry(data=None, metrics=1)]
else:
return [DataEntry(data=data_entry, metrics=0)]
def finalize(self, metrics):
drop_counter = 0
for dropped in metrics:
drop_counter += dropped
logger.info(
"Num of utterances that were dropped due to WER > %d: %d",
self.wer_threshold,
drop_counter,
)
super().finalize(metrics)
class DropLowWordMatchRate(BaseParallelProcessor):
"""Drops utterances if there is a sufficiently low word-match-rate (WMR).
WMR is measured between ``data[self.text_key]`` and ``data[self.pred_text_key]``.
.. note::
We only drop the utterance if ``WMR < threshold`` (i.e. strictly lower
than) so that if we set the threshold to 100, we will not remove
utterances with ``WMR == 100``.
Args:
wmr_threshold (float): WMR threshold below which the utterance will be dropped.
text_key (str): a string indicating which key of the data entries
should be used to find the utterance transcript. Defaults to "text".
pred_text_key (str): a string indicating which key of the data entries
should be used to access the ASR predictions. Defaults to "pred_text".
Returns:
The same data as in the input manifest with some entries dropped.
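
    Example:
        Since WMR is on a 0-100 scale (see the note above), a sketch that
        drops utterances where fewer than 75% of the words match::

            DropLowWordMatchRate(
                wmr_threshold=75.0,
            )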
"""
def __init__(
self,
wmr_threshold: float,
text_key: str = "text",
pred_text_key: str = "pred_text",
**kwargs,
):
super().__init__(**kwargs)
self.wmr_threshold = wmr_threshold
self.text_key = text_key
self.pred_text_key = pred_text_key
def process_dataset_entry(self, data_entry) -> List:
orig_words, pred_words = data_entry[self.text_key], data_entry[self.pred_text_key]
wmr = get_wmr(orig_words, pred_words)
if wmr < self.wmr_threshold:
return [DataEntry(data=None, metrics=1)]
else:
return [DataEntry(data=data_entry, metrics=0)]
def finalize(self, metrics):
drop_counter = 0
for dropped in metrics:
drop_counter += dropped
logger.info(
"Num of utterances that were dropped due to WMR < %d: %d",
self.wmr_threshold,
drop_counter,
)
super().finalize(metrics)
class DropIfRegexMatch(BaseParallelProcessor):
"""Drops utterances if text matches a regex pattern.
    Before applying the regex checks, a space character is added to the
    beginning and end of ``data[self.text_key]``. After the regex checks,
    assuming the utterance isn't dropped, the extra spaces are removed.
    This includes the spaces at the beginning and end of the text, as well
    as any double spaces ``"  "``.
Args:
        regex_patterns (list[str]): a list of regex patterns, traversed in
            order. If ``data_entry[self.text_key]`` matches any of them,
            the entry will be dropped.
text_key (str): a string indicating which key of the data entries
should be used to find the utterance transcript. Defaults to "text".
Returns:
The same data as in the input manifest with some entries dropped.
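
    Example:
        A sketch that drops utterances containing digits, e.g. transcripts
        with unverbalized numbers::

            DropIfRegexMatch(
                regex_patterns=["[0-9]"],
            )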
"""
def __init__(
self,
regex_patterns: List[str],
text_key: str = "text",
**kwargs,
):
super().__init__(**kwargs)
self.regex_patterns = regex_patterns
self.text_key = text_key
def process_dataset_entry(self, data_entry) -> List:
drop_counter = collections.defaultdict(int)
data_entry[self.text_key] = add_start_end_spaces(data_entry[self.text_key])
        for regex_pattern in self.regex_patterns:
            num_matches = sum(1 for _ in re.finditer(regex_pattern, data_entry[self.text_key]))
            if num_matches:
                drop_counter[regex_pattern] += num_matches
                return [DataEntry(data=None, metrics=drop_counter)]
data_entry[self.text_key] = remove_extra_spaces(data_entry[self.text_key])
return [DataEntry(data=data_entry, metrics=drop_counter)]
def finalize(self, metrics):
total_counter = collections.defaultdict(int)
for counter in metrics:
for attribute, value in counter.items():
total_counter[attribute] += value
logger.info("Regex matches that were dropped in attribute")
for attribute, matches in total_counter.items():
logger.info(f"{attribute}, {matches}")
super().finalize(metrics)
class DropOnAttribute(BaseParallelProcessor):
"""Drops utterances if attribute is set to True/False.
Args:
key (str): which key to use for dropping utterances.
drop_if_false (bool): whether to drop if value is False. Defaults
to dropping if True.
Returns:
The same data as in the input manifest with some entries dropped.
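
    Example:
        Assuming an earlier processor has written a boolean ``is_valid``
        field (a hypothetical key), a sketch that drops every entry where
        it is ``False``::

            DropOnAttribute(
                key="is_valid",
                drop_if_false=True,
            )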
"""
def __init__(
self,
key: str,
drop_if_false: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.key = key
self.drop_if_false = drop_if_false
def process_dataset_entry(self, data_entry) -> List:
if data_entry[self.key] is not self.drop_if_false:
return [DataEntry(data=None, metrics=1)]
return [DataEntry(data=data_entry, metrics=0)]
def finalize(self, metrics):
total_counter = 0
for counter in metrics:
total_counter += counter
logger.info("Dropped %d utterances", total_counter)
super().finalize(metrics)
class DropIfSubstringInInsertion(BaseParallelProcessor):
"""Drops utterances if a substring matches an ASR insertion.
Insertions are checked between ``data[self.text_key]`` and
``data[self.pred_text_key]``.
.. note::
We check for exact matches, so you need to be mindful of spaces, e.g.
you may wish to do ``substrings_in_insertion = ["nemo "]`` instead
of ``substrings_in_insertion = ["nemo"]``.
Args:
substrings_in_insertion (list[str]): a list of strings which might be
inserted in predicted ASR text. If the insertion matches a
string exactly, the utterance will be dropped.
text_key (str): a string indicating which key of the data entries
should be used to find the utterance transcript. Defaults to "text".
pred_text_key (str): a string indicating which key of the data entries
should be used to access the ASR predictions. Defaults to "pred_text".
Returns:
The same data as in the input manifest with some entries dropped.
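
    Example:
        A sketch that drops utterances where the ASR output inserted the
        word "nemo" (note the trailing space, per the note above)::

            DropIfSubstringInInsertion(
                substrings_in_insertion=["nemo "],
            )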
"""
def __init__(
self,
substrings_in_insertion: List[str],
text_key: str = "text",
pred_text_key: str = "pred_text",
**kwargs,
):
super().__init__(**kwargs)
self.substrings_in_insertion = substrings_in_insertion
self.text_key = text_key
self.pred_text_key = pred_text_key
def process_dataset_entry(self, data_entry) -> List:
for substring_in_insertion in self.substrings_in_insertion:
if substring_in_insertion in data_entry[self.pred_text_key]:
orig_words, pred_words = data_entry[self.text_key], data_entry[self.pred_text_key]
diff = get_diff_with_subs_grouped(orig_words, pred_words)
for diff_entry in diff:
                    if diff_entry[0] == 1:  # i.e. an insertion in pred_text relative to text
if substring_in_insertion in diff_entry[1]:
return [DataEntry(data=None, metrics=diff_entry[1])]
return [DataEntry(data=data_entry, metrics="")]
def finalize(self, metrics):
total_counter = collections.defaultdict(int)
for diff_entry in metrics:
if diff_entry:
total_counter[diff_entry] += 1
logger.info("Some of the insertions that cause the utterance to be dropped:")
total_counter_sorted = dict(sorted(total_counter.items(), key=lambda x: x[1], reverse=True))
for insertion, count in total_counter_sorted.items():
logger.info(f"{insertion}, {count}")
super().finalize(metrics)
class DropRepeatedFields(BaseParallelProcessor):
"""Drops utterances from the current manifest if their text fields are present in other manifests.
This class processes multiple manifest files and removes entries from the current manifest if the text field
matches any entry in the other manifests. It allows for optional punctuation removal from the text fields
before performing the check.
.. note::
        It is better to process the Test/Dev/Train manifests first, and only then ``Other.tsv``.
Args:
manifests_paths (list[str]): List of paths to the manifest files to check against.
current_manifest_file (str): Path to the current manifest file to be processed.
        punctuations (str): (Optional) String of punctuation characters to be removed from the text fields before checking for duplicates. Defaults to None.
text_key (str): The key in the manifest entries that contains the text field. Defaults to "text".
Returns:
The same data as in the input manifest with some entries dropped.
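
    Example:
        A sketch that drops entries of a training manifest whose text already
        appears in the dev or test manifests (paths are placeholders)::

            DropRepeatedFields(
                manifests_paths=["<dev_manifest>.json", "<test_manifest>.json"],
                current_manifest_file="<train_manifest>.json",
                punctuations=".,?!",
            )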
"""
    def __init__(
        self,
        manifests_paths: List[str],
        current_manifest_file: str,
        punctuations: str = None,
        text_key: str = "text",
        **kwargs,
    ):
        super().__init__(**kwargs)
self.manifests_paths = manifests_paths
self.current_manifest_file = current_manifest_file
self.text_key = text_key
self.punctuations = punctuations
self.text_set = set()
self.load_data()
def load_data(self):
if self.current_manifest_file in self.manifests_paths:
self.manifests_paths.remove(self.current_manifest_file)
for path in self.manifests_paths:
if os.path.exists(path):
with open(path, "rt", encoding="utf8") as fin:
for line in fin:
line_dict = json.loads(line)
line_text = line_dict[self.text_key]
if self.punctuations is not None and len(self.punctuations) > 0:
line_text = self.remove_punctuation(line_text)
self.text_set.add(line_text)
    def remove_punctuation(self, text):
        # escape the characters so that e.g. "]", "^" or "-" cannot break the character class
        return re.sub(f"[{re.escape(self.punctuations)}]", "", text)
def process_dataset_entry(self, data_entry) -> List:
text_for_check = data_entry[self.text_key]
if self.punctuations is not None and len(self.punctuations) > 0:
text_for_check = self.remove_punctuation(text_for_check)
if text_for_check in self.text_set:
return [DataEntry(data=None, metrics=1)]
return [DataEntry(data=data_entry, metrics=0)]
def finalize(self, metrics: List):
total_counter = 0
for counter in metrics:
total_counter += counter
logger.info("Dropped %d utterances", total_counter)
super().finalize(metrics)