# Source code for sdp.processors.inference.asr.utils.rttm

import os
from typing import Dict

import soundfile as sf

from sdp.logging import logger
from sdp.processors.base_processor import BaseParallelProcessor, DataEntry


class GetRttmSegments(BaseParallelProcessor):
    """This processor computes audio cut points from RTTM (Rich Transcription Time Marked) files.

    The RTTM file referenced by ``rttm_key`` is read and the onset (4th field) of every
    record is collected. The onsets are scanned in time order and cut points are placed
    so that consecutive cut points are no more than ``duration_threshold`` seconds apart:
    the processor prefers to cut at an RTTM onset and falls back to fixed-size splits when
    the gap between onsets is itself longer than the threshold. The list always starts at
    0 and, as long as no RTTM onset exceeds the file duration, ends at the file duration,
    so consecutive pairs of cut points describe segments that cover the whole file.

    Args:
        rttm_key (str): The key in the manifest that contains the path to the RTTM file.
        output_file_key (str, optional): The key in the data entry where the sorted list of
            cut points (in seconds) will be stored. Defaults to "audio_segments".
        duration_key (str, optional): The key in the data entry that contains the total
            duration of the audio file. Defaults to "duration".
        duration_threshold (float, optional): The maximum length of a segment. Gaps longer
            than this are divided into smaller pieces. Must be positive.
            Defaults to 20.0 seconds.

    Returns:
        A list containing a single ``DataEntry`` whose ``output_file_key`` holds the
        sorted, de-duplicated list of cut points.
    """

    def __init__(
        self,
        rttm_key: str,
        output_file_key: str = "audio_segments",
        duration_key: str = "duration",
        duration_threshold: float = 20.0,
        **kwargs,
    ):
        # A non-positive threshold would make the splitting loop never terminate.
        if duration_threshold <= 0:
            raise ValueError(f"duration_threshold must be positive, got {duration_threshold}")
        super().__init__(**kwargs)
        self.rttm_key = rttm_key
        self.duration_threshold = duration_threshold
        self.duration_key = duration_key
        self.output_file_key = output_file_key

    @staticmethod
    def _read_rttm_starts(rttm_file):
        """Return the onset (4th whitespace-separated field) of every record in ``rttm_file``."""
        starts = []
        with open(rttm_file, "r", encoding="utf-8") as f:
            for line in f:
                # split() (not split(" ")) tolerates tabs and repeated spaces.
                fields = line.split()
                # Skip blank lines and ";;" comment lines instead of crashing on them.
                if not fields or fields[0].startswith(";;"):
                    continue
                starts.append(float(fields[3]))
        return starts

    def _split_until(self, slices, last_slice, end):
        """Append cut points every ``duration_threshold`` seconds after ``last_slice``, then ``end``.

        ``end`` itself is appended exactly (not recomputed from a duration) and returned
        as the new last cut point. ``slices`` is mutated in place.
        """
        cut = last_slice + self.duration_threshold
        while cut < end:
            slices.append(cut)
            cut += self.duration_threshold
        slices.append(end)
        return end

    def split_long_segment(self, slices, duration, last_slice):
        """Split ``duration`` seconds starting at ``last_slice`` into threshold-sized pieces.

        Cut points are appended to ``slices`` in place; the last one is
        ``last_slice + duration`` (also for durations shorter than the threshold).

        Returns:
            tuple: ``(slices, new_last_slice)``.
        """
        new_last_slice = self._split_until(slices, last_slice, last_slice + duration)
        return slices, new_last_slice

    def process_dataset_entry(self, data_entry: Dict):
        file_duration = data_entry[self.duration_key]
        rttm_file = data_entry[self.rttm_key]

        # The greedy scan below assumes onsets in time order; RTTM does not guarantee it.
        starts = sorted(self._read_rttm_starts(rttm_file))
        starts.append(file_duration)

        threshold = self.duration_threshold
        slices = [0]
        last_slice, last_start, last_duration = 0, 0, 0
        for start in starts:
            duration = start - last_slice
            if duration > threshold:
                if last_duration < threshold:
                    # The chunk from the last cut to the previous onset is shorter than the
                    # threshold, so close it there (on an RTTM onset).
                    slices.append(last_start)
                    last_slice = last_start
                    duration = start - last_slice
                # Cut whatever remains up to this onset into threshold-sized pieces.
                # (Previously one branch appended ``start`` directly, leaving over-long pieces.)
                last_slice = self._split_until(slices, last_slice, start)
            last_start = start
            last_duration = duration

        # The final gap (last cut -> end of file) is within the threshold but was never
        # closed; without this the tail of the audio (or a whole short file) is dropped.
        if last_slice < file_duration:
            slices.append(file_duration)

        data_entry[self.output_file_key] = sorted(set(slices))
        return [DataEntry(data=data_entry)]
class SplitAudioFile(BaseParallelProcessor):
    """Cut each input audio file into WAV pieces at the boundary timestamps of the entry.

    The audio referenced by ``input_file_key`` is read once; every consecutive pair of
    timestamps in ``segments_key`` (seconds) defines one piece, written to
    ``<splited_audio_dir>/<input file stem>/<start*100>-<end*100>.wav`` (centisecond
    values truncated to int). If that file already exists it is not rewritten; its
    duration is read back instead (``-1.0`` and a warning if it cannot be read).

    Args:
        splited_audio_dir (str): Directory under which the pieces are written.
        segments_key (str, optional): Manifest key holding the list of boundary
            timestamps. Defaults to "audio_segments".
        duration_key (str, optional): Manifest key that receives each piece's duration.
            Defaults to "duration".
        input_file_key (str, optional): Manifest key holding the source audio path.
            Defaults to "source_filepath".
        output_file_key (str, optional): Manifest key that receives each piece's path.
            Defaults to "audio_filepath".

    Returns:
        One data entry per piece: a copy of the input entry with ``output_file_key`` and
        ``duration_key`` set for that piece.
    """

    def __init__(
        self,
        splited_audio_dir: str,
        segments_key: str = "audio_segments",
        duration_key: str = "duration",
        input_file_key: str = "source_filepath",
        output_file_key: str = "audio_filepath",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.splited_audio_dir = splited_audio_dir
        self.segments_key = segments_key
        self.duration_key = duration_key
        self.input_file_key = input_file_key
        self.output_file_key = output_file_key

    def write_segment(self, data, samplerate, start_sec, end_sec, input_file):
        """Write ``data`` between ``start_sec`` and ``end_sec`` to a WAV file (unless cached).

        Returns:
            tuple: ``(path, duration_in_seconds)``; the duration is ``-1.0`` when an
            existing file cannot be read.
        """
        stem = os.path.splitext(os.path.basename(input_file))[0]
        file_name = str(int(start_sec * 100)) + "-" + str(int(end_sec * 100)) + ".wav"
        target = os.path.join(self.splited_audio_dir, stem, file_name)

        if os.path.isfile(target):
            # Produced by an earlier run: only its duration is needed.
            try:
                cached, cached_rate = sf.read(target)
                seconds = cached.shape[0] / cached_rate
            except Exception as e:
                logger.warning(str(e) + " file: " + target)
                seconds = -1.0
            return target, seconds

        first_sample = int(start_sec * samplerate)
        stop_sample = int(end_sec * samplerate)
        piece = data[first_sample:stop_sample]
        os.makedirs(os.path.dirname(target), exist_ok=True)
        sf.write(target, piece, samplerate)
        return target, len(piece) / samplerate

    def process_dataset_entry(self, data_entry: Dict):
        boundaries = data_entry[self.segments_key]
        source_path = data_entry[self.input_file_key]
        audio, rate = sf.read(source_path)

        results = []
        # Each neighbouring pair of boundaries is one piece.
        for seg_start, seg_end in zip(boundaries[:-1], boundaries[1:]):
            path, seconds = self.write_segment(audio, rate, seg_start, seg_end, source_path)
            # The input entry is updated in place and a snapshot is emitted per piece.
            data_entry[self.output_file_key] = path
            data_entry[self.duration_key] = seconds
            results.append(DataEntry(data=data_entry.copy()))
        return results