# Source code for sdp.processors.nemo.asr_inference
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
from pathlib import Path
from sdp.processors.base_processor import BaseProcessor
# Note that we do not re-use base parallel implementation, since the ASR
# inference is already run in batches.
# TODO: actually, it might still be beneficial to have another level of
# parallelization, but that needs to be tested.
[docs]
class ASRInference(BaseProcessor):
    """This processor performs ASR inference on each utterance of the input manifest.

    ASR predictions will be saved in the ``pred_text`` key.

    The inference itself is not done in-process: :meth:`process` launches the
    ``transcribe_speech.py`` script as a subprocess and passes it the model
    name, the input/output manifest paths and the batch size.

    Args:
        pretrained_model (str): the name of the pretrained NeMo ASR model
            which will be used to do inference.
        batch_size (int): the batch size to use for ASR inference. Defaults to 32.
        **kwargs: forwarded unchanged to ``BaseProcessor.__init__``.

    Returns:
        The same data as in the input manifest with an additional field
        ``pred_text`` containing ASR model's predictions.
    """

    def __init__(
        self,
        pretrained_model: str,
        batch_size: int = 32,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # The transcription script is located relative to this file:
        # <parent of this file's directory>/nemo/transcribe_speech.py
        self.script_path = Path(__file__).parents[1] / "nemo" / "transcribe_speech.py"
        self.pretrained_model = pretrained_model
        self.batch_size = batch_size

    def process(self):
        """This will add "pred_text" key into the output manifest.

        Creates the output directory if needed and runs the transcription
        script on ``self.input_manifest_file``, writing the result to
        ``self.output_manifest_file`` (both attributes are presumably set by
        ``BaseProcessor`` -- they are not assigned in this class).

        Raises:
            subprocess.CalledProcessError: if the transcription script exits
                with a non-zero return code.
        """
        # dirname() is "" for a bare file name; os.makedirs("") would raise
        # FileNotFoundError, so only create the directory when there is one.
        output_dir = os.path.dirname(self.output_manifest_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Pass the arguments as a list (no shell) so that paths or model names
        # containing spaces or shell metacharacters reach the script intact
        # and are never interpreted by a shell.
        command = [
            "python",
            str(self.script_path),
            f"pretrained_name={self.pretrained_model}",
            f"dataset_manifest={self.input_manifest_file}",
            f"output_filename={self.output_manifest_file}",
            f"batch_size={self.batch_size}",
        ]
        subprocess.run(command, check=True)