# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from pathlib import Path
from typing import Dict, List, Union, Optional
import pandas as pd
from tqdm import tqdm
from sdp.processors.base_processor import (
BaseParallelProcessor,
BaseProcessor,
DataEntry,
)
from sdp.utils.common import load_manifest
[docs]
class CombineSources(BaseParallelProcessor):
"""Can be used to create a single field from two alternative sources.
E.g.::
_target_: sdp.processors.CombineSources
sources:
- field: text_pc
origin_label: original
- field: text_pc_pred
origin_label: synthetic
- field: text
origin_label: no_pc
target: text
will populate the ``text`` field with data from ``text_pc`` field if it's
present and not equal to ``n/a`` (can be customized). If ``text_pc`` is
not available, it will populate ``text`` from ``text_pc_pred`` field,
following the same rules. If both are not available, it will fall back to
the ``text`` field itself. In all cases it will specify which source was
used in the ``text_origin`` field by using the label from the
``origin_label`` field.. If non of the sources is available,
it will populate both the target and the origin fields with ``n/a``.
Args:
sources (list[dict]): list of the sources to use in order of preference.
Each element in the list should be in the following format::
{
field: <which field to take the data from>
origin_label: <what to write in the "<target>_origin"
}
target (str): target field that we are populating.
na_indicator (str): if any source field has text equal to the
``na_indicator`` it will be considered as not available. If none
of the sources are present, this will also be used as the value
for the target and origin fields. Defaults to ``n/a``.
Returns:
The same data as in the input manifest enhanced with the following fields::
<target>: <populated with data from either <source1> or <source2> \
or with <na_indicator> if none are available>
<target>_origin: <label that marks where the data came from>
"""
def __init__(
self,
sources: List[Dict[str, str]],
target: str,
na_indicator: str = "n/a",
**kwargs,
):
super().__init__(**kwargs)
self.sources = sources
self.target = target
self.na_indicator = na_indicator
def process_dataset_entry(self, data_entry: Dict):
for source_dict in self.sources:
if data_entry.get(source_dict["field"], self.na_indicator) != self.na_indicator:
data_entry[self.target] = data_entry[source_dict["field"]]
data_entry[f"{self.target}_origin"] = source_dict["origin_label"]
break # breaking out on the first present label
else: # going here if no break was triggered
data_entry[self.target] = self.na_indicator
data_entry[f"{self.target}_origin"] = self.na_indicator
return [DataEntry(data=data_entry)]
[docs]
class AddConstantFields(BaseParallelProcessor):
"""This processor adds constant fields to all manifest entries.
E.g., can be useful to add fixed ``label: <language>`` field for downstream
language identification model training.
Args:
fields: dictionary with any additional information to add. E.g.::
fields = {
"label": "en",
"metadata": "mcv-11.0-2022-09-21",
}
Returns:
The same data as in the input manifest with added fields
as specified in the ``fields`` input dictionary.
"""
def __init__(
self,
fields: Dict,
**kwargs,
):
super().__init__(**kwargs)
self.fields = fields
def process_dataset_entry(self, data_entry: Dict):
data_entry.update(self.fields)
return [DataEntry(data=data_entry)]
[docs]
class DuplicateFields(BaseParallelProcessor):
"""This processor duplicates fields in all manifest entries.
It is useful for when you want to do downstream processing of a variant
of the entry. E.g. make a copy of "text" called "text_no_pc", and
remove punctuation from "text_no_pc" in downstream processors.
Args:
duplicate_fields (dict): dictionary where keys are the original
fields to be copied and their values are the new names of
the duplicate fields.
Returns:
The same data as in the input manifest with duplicated fields
as specified in the ``duplicate_fields`` input dictionary.
Example:
.. code-block:: yaml
- _target_: sdp.processors.modify_manifest.common.DuplicateFields
input_manifest_file: ${workspace_dir}/test1.json
output_manifest_file: ${workspace_dir}/test2.json
duplicate_fields: {"text":"answer"}
"""
def __init__(
self,
duplicate_fields: Dict,
**kwargs,
):
super().__init__(**kwargs)
self.duplicate_fields = duplicate_fields
def process_dataset_entry(self, data_entry: Dict):
for field_src, field_tgt in self.duplicate_fields.items():
if not field_src in data_entry:
raise ValueError(f"Expected field {field_src} in data_entry {data_entry} but there isn't one.")
data_entry[field_tgt] = data_entry[field_src]
return [DataEntry(data=data_entry)]
[docs]
class RenameFields(BaseParallelProcessor):
"""This processor renames fields in all manifest entries.
Args:
rename_fields: dictionary where keys are the fields to be
renamed and their values are the new names of the fields.
Returns:
The same data as in the input manifest with renamed fields
as specified in the ``rename_fields`` input dictionary.
"""
def __init__(
self,
rename_fields: Dict,
**kwargs,
):
super().__init__(**kwargs)
self.rename_fields = rename_fields
def process_dataset_entry(self, data_entry: Dict):
for field_src, field_tgt in self.rename_fields.items():
if not field_src in data_entry:
raise ValueError(f"Expected field {field_src} in data_entry {data_entry} but there isn't one.")
data_entry[field_tgt] = data_entry[field_src]
del data_entry[field_src]
return [DataEntry(data=data_entry)]
[docs]
class SplitOnFixedDuration(BaseParallelProcessor):
"""This processor splits audio into a fixed length segments.
It does not actually create different audio files, but simply adds
corresponding ``offset`` and ``duration`` fields. These fields can
be automatically processed by NeMo to split audio on the fly during
training.
Args:
segment_duration (float): fixed desired duration of each segment.
drop_last (bool): whether to drop the last segment if total duration is
not divisible by desired segment duration. If False, the last
segment will be of a different length which is ``< segment_duration``.
Defaults to True.
drop_text (bool): whether to drop text from entries as it is most likely
inaccurate after the split on duration. Defaults to True.
Returns:
The same data as in the input manifest but all audio that's longer
than the ``segment_duration`` will be duplicated multiple times with
additional ``offset`` and ``duration`` fields. If ``drop_text=True``
will also drop ``text`` field from all entries.
"""
def __init__(
self,
segment_duration: float,
drop_last: bool = True,
drop_text: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self.segment_duration = segment_duration
self.drop_last = drop_last
self.drop_text = drop_text
def process_dataset_entry(self, data_entry: Dict):
total_duration = data_entry["duration"]
total_segments = int(total_duration // self.segment_duration)
output = [None] * total_segments
for segment_idx in range(total_segments):
modified_entry = data_entry.copy() # shallow copy should be good enough
modified_entry["duration"] = self.segment_duration
modified_entry["offset"] = segment_idx * self.segment_duration
if self.drop_text:
modified_entry.pop("text", None)
output[segment_idx] = DataEntry(data=modified_entry)
remainder = total_duration - self.segment_duration * total_segments
if not self.drop_last and remainder > 0:
modified_entry = data_entry.copy()
modified_entry["duration"] = remainder
modified_entry["offset"] = self.segment_duration * total_segments
if self.drop_text:
modified_entry.pop("text", None)
output.append(DataEntry(data=modified_entry))
return output
[docs]
class ChangeToRelativePath(BaseParallelProcessor):
"""This processor changes the audio filepaths to be relative.
Args:
base_dir: typically a folder where manifest file is going to be
stored. All passes will be relative to that folder.
Returns:
The same data as in the input manifest with ``audio_filepath`` key
changed to contain relative path to the ``base_dir``.
"""
def __init__(
self,
base_dir: str,
**kwargs,
):
super().__init__(**kwargs)
self.base_dir = base_dir
def process_dataset_entry(self, data_entry: Dict):
data_entry["audio_filepath"] = os.path.relpath(data_entry["audio_filepath"], self.base_dir)
return [DataEntry(data=data_entry)]
[docs]
class SortManifest(BaseProcessor):
"""Processor which will sort the manifest by some specified attribute.
Args:
attribute_sort_by (str): the attribute by which the manifest will be sorted.
descending (bool): if set to False, attribute will be in ascending order.
If True, attribute will be in descending order. Defaults to True.
Returns:
The same entries as in the input manifest, but sorted based
on the provided parameters.
"""
def __init__(
self,
attribute_sort_by: str,
descending: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self.attribute_sort_by = attribute_sort_by
self.descending = descending
def process(self):
with open(self.input_manifest_file, "rt", encoding="utf8") as fin:
dataset_entries = [json.loads(line) for line in fin.readlines()]
dataset_entries = sorted(dataset_entries, key=lambda x: x[self.attribute_sort_by], reverse=self.descending)
with open(self.output_manifest_file, "wt", encoding="utf8") as fout:
for line in dataset_entries:
fout.write(json.dumps(line, ensure_ascii=False) + "\n")
[docs]
class KeepOnlySpecifiedFields(BaseProcessor):
"""Saves a copy of a manifest but only with a subset of the fields.
Typically will be the final processor to save only relevant fields
in the desired location.
Args:
fields_to_keep (list[str]): list of the fields in the input manifest
that we want to retain. The output file will only contain these
fields.
Returns:
The same data as in input manifest, but re-saved in the new location
with only ``fields_to_keep`` fields retained.
"""
def __init__(self, fields_to_keep: List[str], **kwargs):
super().__init__(**kwargs)
self.fields_to_keep = fields_to_keep
def process(self):
with open(self.input_manifest_file, "rt", encoding="utf8") as fin, open(
self.output_manifest_file, "wt", encoding="utf8"
) as fout:
for line in tqdm(fin):
line = json.loads(line)
new_line = {field: line[field] for field in self.fields_to_keep}
fout.write(json.dumps(new_line, ensure_ascii=False) + "\n")
[docs]
class ApplyInnerJoin(BaseProcessor):
"""Applies inner join to two manifests, i.e. creates a manifest from records that have matching values in both manifests.
For more information, please refer to the Pandas merge function documentation:
https://pandas.pydata.org/docs/reference/api/pandas.merge.html#pandas.merge
Args:
column_id (Union[str, List[str], None]): Field names to join on. These must be found in both manifests.
If `column_id` is None then this defaults to the intersection of the columns in both manifests.
Defaults to None.
left_manifest_file (Optional[str]): path to the left manifest. Defaults to `input_manifest_file`.
right_manifest_file (str): path to the right manifest.
Returns:
Inner join of two manifests.
"""
def __init__(
self,
left_manifest_file: Optional[str],
right_manifest_file: str,
column_id: Union[str, List[str], None] = None,
**kwargs,
):
super().__init__(**kwargs)
self.left_manifest_file = left_manifest_file if left_manifest_file != None else self.input_manifest_file
self.right_manifest_file = right_manifest_file
self.column_id = column_id
def process(self):
m1 = pd.DataFrame.from_records(load_manifest(Path(self.left_manifest_file)))
m2 = pd.DataFrame.from_records(load_manifest(Path(self.right_manifest_file)))
m3 = pd.merge(m1, m2, on=self.column_id, how="inner")
with open(self.output_manifest_file, "wt", encoding="utf8") as fout:
for _, line in m3.iterrows():
fout.write(json.dumps(dict(line), ensure_ascii=False) + "\n")