Source code for sdp.processors.inference.asr.nemo.lid_inference

import json
from pathlib import Path
from typing import Optional

import numpy as np
from tqdm import tqdm

from sdp.logging import logger
from sdp.processors.base_processor import BaseProcessor
from sdp.utils.common import load_manifest


class AudioLid(BaseProcessor):
    """Processor for language identification (LID) of audio files using a pre-trained LID model.

    Args:
        input_audio_key (str): The key in the dataset containing the path to the audio files
            for language identification.
        pretrained_model (str): The name of the pre-trained NeMo model used for language
            identification.
        output_lang_key (str): The key to store the identified language for each audio file.
        device (str): The device to run the model on (e.g., 'cuda', 'cpu'). If None, a GPU is
            selected automatically when available; otherwise the CPU is used.
        segment_duration (float): Duration in seconds of each randomly sampled segment.
            Default is np.inf.
        num_segments (int): Number of segments per file to use for the majority vote.
            Default is 1.
        random_seed (int): Seed for selecting the starting position of each segment.
            Default is None.
        **kwargs: Additional keyword arguments passed to the base class ``BaseProcessor``.
    """

    def __init__(
        self,
        input_audio_key: str,
        pretrained_model: str,
        output_lang_key: str,
        device: Optional[str],
        segment_duration: float = np.inf,
        num_segments: int = 1,
        random_seed: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_audio_key = input_audio_key
        self.pretrained_model = pretrained_model
        self.output_lang_key = output_lang_key
        self.segment_duration = segment_duration
        self.num_segments = num_segments
        self.random_seed = random_seed
        self.device = device

    def process(self):
        import nemo.collections.asr as nemo_asr

        # torch is imported after nemo so that users are prompted to install NeMo
        # (which pulls in torch) rather than installing torch on its own.
        import torch

        model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name=self.pretrained_model)

        if self.device is None:
            # Automatically select a GPU when available; fall back to CPU otherwise.
            if torch.cuda.is_available():
                model = model.cuda()
            else:
                model = model.cpu()
        else:
            model = model.to(self.device)

        manifest = load_manifest(Path(self.input_manifest_file))

        Path(self.output_manifest_file).parent.mkdir(exist_ok=True, parents=True)
        with Path(self.output_manifest_file).open('w') as f:
            for item in tqdm(manifest):
                audio_file = item[self.input_audio_key]

                try:
                    lang = model.get_label(
                        audio_file, self.segment_duration, self.num_segments, self.random_seed
                    )
                except Exception as e:
                    logger.warning(f"AudioLid failed on {audio_file}: {e}")
                    lang = None

                # Entries whose language could not be identified are written through unchanged.
                if lang:
                    item[self.output_lang_key] = lang
                f.write(json.dumps(item, ensure_ascii=False) + '\n')
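
A minimal usage sketch follows. It is not part of the module: the manifest paths, the manifest field names, and the 'langid_ambernet' checkpoint name are illustrative assumptions; any LID checkpoint loadable via EncDecSpeakerLabelModel.from_pretrained can be substituted.

# Hypothetical example: run AudioLid over a JSONL manifest whose entries carry
# an "audio_filepath" field, writing the detected language under "lang".
if __name__ == "__main__":
    processor = AudioLid(
        input_manifest_file="manifests/input_manifest.json",   # assumed path
        output_manifest_file="manifests/lid_manifest.json",    # assumed path
        input_audio_key="audio_filepath",                      # assumed manifest key
        pretrained_model="langid_ambernet",                    # assumed NeMo LID checkpoint name
        output_lang_key="lang",
        device=None,             # auto-select GPU if available
        segment_duration=20.0,   # majority vote over three random 20 s segments per file
        num_segments=3,
        random_seed=42,
    )
    processor.process()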
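
Each output line is the input entry serialized back to JSON, with the detected language added under output_lang_key when identification succeeded. A hypothetical output line (field names assumed as above):

{"audio_filepath": "/data/audio/utt_0001.wav", "duration": 4.8, "lang": "en"}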