# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from pathlib import Path
from typing import Optional

from tqdm import tqdm

from sdp.logging import logger
from sdp.processors.base_processor import BaseProcessor
from sdp.utils.common import load_manifest
class ASRTransformers(BaseProcessor):
"""
Processor to transcribe using ASR Transformers model from HuggingFace.
Args:
pretrained_model (str): name of pretrained model on HuggingFace.
output_text_key (str): Key to save transcription result.
input_audio_key (str): Key to read audio file. Defaults to "audio_filepath".
input_duration_key (str): Audio duration key. Defaults to "duration".
device (str): Inference device.
batch_size (int): Inference batch size. Defaults to 1.
chunk_length_s (int): Length of the chunks (in seconds) into which the input audio should be divided.
Note: Some models perform the chunking on their own (for instance, Whisper chunks into 30s segments also by maintaining the context of the previous chunks).
torch_dtype (str): Tensor data type. Default to "float32"
max_new_tokens (Optional[int]): The maximum number of new tokens to generate.
If not specified, there is no hard limit on the number of tokens generated, other than model-specific constraints.
"""
def __init__(
self,
pretrained_model: str,
output_text_key: str,
input_audio_key: str = "audio_filepath",
input_duration_key: str = "duration",
        device: Optional[str] = None,
batch_size: int = 1,
chunk_length_s: int = 0,
torch_dtype: str = "float32",
generate_task: str = "transcribe",
generate_language: str = "english",
max_new_tokens: Optional[int] = None,
**kwargs,
):
super().__init__(**kwargs)
        try:
            import torch
            from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
        except ImportError:
            raise ImportError("Need to install transformers: pip install accelerate transformers")
logger.warning("This is an example processor, for demonstration only. Do not use it for production purposes.")
self.pretrained_model = pretrained_model
self.input_audio_key = input_audio_key
self.output_text_key = output_text_key
self.input_duration_key = input_duration_key
self.device = device
self.batch_size = batch_size
self.chunk_length_s = chunk_length_s
self.generate_task = generate_task
self.generate_language = generate_language
self.max_new_tokens = max_new_tokens
if torch_dtype == "float32":
self.torch_dtype = torch.float32
elif torch_dtype == "float16":
self.torch_dtype = torch.float16
else:
raise NotImplementedError(torch_dtype + " is not implemented!")
if self.device is None:
if torch.cuda.is_available():
self.device = "cuda:0"
else:
self.device = "cpu"
self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
self.pretrained_model, torch_dtype=self.torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
self.model.to(self.device)
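        # Pin the generation language so multilingual models (e.g. Whisper) skip auto-detection.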
self.model.generation_config.language = self.generate_language
processor = AutoProcessor.from_pretrained(self.pretrained_model)
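        # Wrap everything in a HuggingFace ASR pipeline; return_timestamps=True is
        # needed for long-form (chunked) decoding with Whisper-style models.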
self.pipe = pipeline(
"automatic-speech-recognition",
model=self.model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
            max_new_tokens=self.max_new_tokens,
chunk_length_s=self.chunk_length_s,
batch_size=self.batch_size,
return_timestamps=True,
torch_dtype=self.torch_dtype,
device=self.device,
)
    def process(self):
        json_list = load_manifest(Path(self.input_manifest_file))
        # Sort by duration (longest first) so each batch groups similar-length audio,
        # which reduces padding overhead during batched inference.
        json_list_sorted = sorted(json_list, key=lambda d: d[self.input_duration_key], reverse=True)

        Path(self.output_manifest_file).parent.mkdir(exist_ok=True, parents=True)
        with Path(self.output_manifest_file).open("w") as f:
            # Step through the manifest with a stride of batch_size so that the final
            # partial batch is still processed when the manifest length is not a
            # multiple of batch_size.
            for start_index in tqdm(range(0, len(json_list_sorted), self.batch_size)):
                batch = json_list_sorted[start_index : start_index + self.batch_size]
                audio_files = [item[self.input_audio_key] for item in batch]
                results = self.pipe(
                    audio_files, generate_kwargs={"language": self.generate_language, "task": self.generate_task}
                )
                for i, item in enumerate(batch):
                    item[self.output_text_key] = results[i]["text"]
                    f.write(json.dumps(item, ensure_ascii=False) + "\n")
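

# A minimal usage sketch (illustration only; the manifest paths below are
# hypothetical). Each line of the input manifest is a JSON object with
# "audio_filepath" and "duration" fields; the transcription is written under
# the key given by output_text_key:
#
#     processor = ASRTransformers(
#         pretrained_model="openai/whisper-tiny",
#         output_text_key="text",
#         input_manifest_file="/path/to/input_manifest.json",
#         output_manifest_file="/path/to/output_manifest.json",
#         batch_size=4,
#         torch_dtype="float16",
#     )
#     processor.process()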