# Source code for sdp.processors.datasets.mls.restore_pc

# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import json
import os
import re
import string
import sys
from glob import glob
from pathlib import Path
from typing import Optional

import regex
from joblib import Parallel, delayed
from tqdm import tqdm

from sdp.logging import logger
from sdp.processors.base_processor import BaseProcessor
from sdp.utils.common import download_file, extract_archive

# NOTE(review): raises the interpreter-wide recursion limit at import time; the reason is not
# visible in this file (possibly deep recursion inside one of the third-party libraries used
# here) — confirm it is still needed.
sys.setrecursionlimit(1000000)

# Value written to ``restored_text_field`` for manifest lines whose text could not be restored.
NA = "n/a"
# Archive with the original Librivox book texts used to restore MLS transcripts.
MLS_TEXT_URL = "https://dl.fbaipublicfiles.com/mls/lv_text.tar.gz"


def abbreviations(text):
    """Expand a few colloquial contractions that appear in the Librivox book texts.

    The substitutions are plain substring replacements applied one after another in
    the order listed below, so each one sees the output of the previous ones.

    Args:
        text: input text.

    Returns:
        The text with the contractions expanded.
    """
    substitutions = (
        ("Cap'n", "Captain"),
        ("cap'n", "captain"),
        ("o'shot", "o shot"),
        ("o' shot", "o shot"),
        ("on'y", "only"),
        ("on' y", "only"),
        (" 'a ", " a "),
        (" 'em ", " em "),
        ("gen'leman", "gentleman"),
    )
    for contraction, expansion in substitutions:
        text = text.replace(contraction, expansion)
    return text


def process(text):
    """Clean a raw Project Gutenberg / Librivox book text before alignment.

    Verbalizes a few URL/file-name fragments, normalizes quotes, dashes and other
    typographic symbols, splits dash- or dot-joined tokens, removes thousands
    separators between digits and collapses runs of spaces. The result is then
    trimmed to start at the first alphabetic character and cut before the Project
    Gutenberg end-of-book marker, if present.

    Args:
        text: raw book text.

    Returns:
        The cleaned text. An empty string is returned when the text contains no
        alphabetic character at all.
    """
    # verbalize URL / file-name fragments; the leading space keeps "dot" a separate word,
    # consistent with the "www dot gutenberg dot org" verbalization
    text = (
        text.replace("www.gutenberg.org", "www dot gutenberg dot org")
        .replace(".txt", " dot txt")
        .replace(".zip", " dot zip")
    )

    # NOTE: the order of these replacements matters (e.g. "--" is first padded with
    # spaces and later collapsed to "-"), so keep it unchanged.
    text = (
        text.replace("’", "'")
        .replace("_", " ")
        .replace("\n", " ")
        .replace("\t", " ")
        .replace("…", "...")
        .replace("»", '"')
        .replace("«", '"')
        .replace("\\", "")
        .replace("”", '"')
        .replace("„", '"')
        .replace("´", "'")
        .replace("-- --", "--")
        .replace("--", " -- ")
        .replace(". . .", "...")
        .replace("’", "'")
        .replace("“", '"')
        .replace("“", '"')
        .replace("‘", "'")
        .replace("_", " ")
        .replace("*", " ")
        .replace("—", "-")
        .replace("- -", "--")
        .replace("•", " ")
        .replace("^", " ")
        .replace(">", " ")
        .replace("■", " ")
        .replace("/", " ")
        .replace("––––", "...")
        .replace("W⸺", "W")
        .replace("`", "'")
        .replace("<", " ")
        .replace("{", " ")
        .replace("Good-night", "Good night")
        .replace("good-night", "good night")
        .replace("good-bye", "goodbye")
        .replace("Good-bye", "Goodbye")
        .replace(" !", "!")
        .replace(" ?", "?")
        .replace(" ,", ",")
        .replace(" .", ".")
        .replace(" ;", ";")
        .replace(" :", ":")
        .replace("!!", "!")
        .replace("--", "-")
        .replace("“", '"')
        .replace(", , ", ", ")
        .replace("=", " ")
        .replace("l,000", "1,000")
        .replace("–", "-")
    )
    # remove dash in between the words
    text = re.sub(r"([A-Za-z0-9]+)(-)([A-Za-z0-9]+)", r"\g<1> \g<3>", text)
    # add a space after a dot glued to the following word
    text = re.sub(r"([A-Za-z0-9]+)(\.)([A-Za-z]+)", r"\g<1>\g<2> \g<3>", text)
    text = re.sub(r"([A-Za-z]+)(\.)([A-Za-z0-9]+)", r"\g<1>\g<2> \g<3>", text)

    def __fix_space(text):
        # remove commas between digits (thousands separators)
        text = re.sub(r"([0-9]+)(,)(\d\d\d)", r"\g<1>\g<3>", text)
        # add a space after a comma glued to the following word
        text = re.sub(r"([A-Za-z]+)(,)([A-Za-z0-9]+)", r"\g<1>\g<2> \g<3>", text)
        return text

    # repeated because substitutions can expose new overlapping matches (e.g. "1,000,000")
    for _ in range(3):
        text = __fix_space(text)

    text = re.sub(r" +", " ", text)

    # make sure the text starts with an alpha; bounded so that a text without any
    # alphabetic character yields "" instead of raising IndexError
    start_idx = 0
    while start_idx < len(text) and not text[start_idx].isalpha():
        start_idx += 1

    # cut the Project Gutenberg footer; the second marker wins when both are present
    end_idx = len(text)
    end_text = "END OF THIS PROJECT GUTENBERG"
    if end_text in text:
        end_idx = text.find(end_text)

    end_text = "End of the Project Gutenberg"
    if end_text in text:
        end_idx = text.find(end_text)

    return text[start_idx:end_idx]


def read_text(text_f, encoding="utf-8"):
    """Return the full contents of the text file ``text_f``.

    Args:
        text_f: path to the text file.
        encoding: encoding used to decode the file. Defaults to UTF-8 so that the
            result (the book texts contain non-ASCII quotes and dashes) does not
            depend on the platform's locale encoding.

    Returns:
        The decoded file contents as a single string.
    """
    with open(text_f, "r", encoding=encoding) as f:
        return f.read()


def remove_punctuation(text: str, remove_spaces=True, do_lower=True, exclude=None, remove_accents=False):
    """Replace punctuation marks in ``text`` with spaces and optionally normalize it further.

    Args:
        text: input text.
        remove_spaces: if True, remove all spaces (and non-breaking spaces) after
            punctuation has been replaced.
        do_lower: if True, lower-case the result.
        exclude: optional iterable of punctuation marks to keep. Dashes are treated
            like any other mark: they are replaced by a space unless "-" is excluded.
        remove_accents: if True, replace a fixed set of lower-case accented vowels
            (á é í ó ú à è ù â ê î ô û) with their unaccented form.

    Returns:
        The processed text, stripped of leading/trailing whitespace.
    """
    all_punct_marks = string.punctuation + "¿¡⸘"

    if exclude is not None:
        for p in exclude:
            all_punct_marks = all_punct_marks.replace(p, "")

    # Every mark is escaped: string.punctuation contains "\\", "]", "^" and "-", which are
    # special inside a regex character class (unescaped, "-" forms ranges and "\\]"
    # swallowed the backslash so it was never removed).
    if all_punct_marks:
        text = re.sub("[" + "".join(re.escape(p) for p in all_punct_marks) + "]", " ", text)

    text = re.sub(r" +", " ", text)
    if remove_spaces:
        text = text.replace(" ", "").replace("\u00A0", "").strip()

    if do_lower:
        text = text.lower()

    if remove_accents:
        text = text.replace("á", "a")
        text = text.replace("é", "e")
        text = text.replace("í", "i")
        text = text.replace("ó", "o")
        text = text.replace("ú", "u")
        text = text.replace("à", "a")
        text = text.replace("è", "e")
        text = text.replace("ù", "u")
        text = text.replace("â", "a")
        text = text.replace("ê", "e")
        text = text.replace("î", "i")
        text = text.replace("ô", "o")
        text = text.replace("û", "u")

    return text.strip()


def recover_lines(manifest, processed_text, output_dir, restored_text_field):
    """Restore punctuation and capitalization for every line of a speaker/book submanifest.

    For each manifest entry, the ``"text"`` field is split into words and its first (up to)
    four words are searched in the lower-cased ``processed_text``, allowing any run of
    non-letters between them. For every hit, the candidate span of book words starting at
    the hit is extended word by word until its last word equals the transcript's last word
    (ignoring apostrophes/punctuation) and :func:`is_valid` accepts the span. The accepted
    span, with the book's original casing and punctuation, is the restored line.

    The result is written to ``{output_dir}/{basename(manifest)}``: every input entry is
    copied with ``restored_text_field`` set to the restored text, or to ``NA`` when no match
    was found. Nothing is done if that output file already exists.

    Args:
        manifest: path to a JSON-lines submanifest whose entries have a ``"text"`` field.
        processed_text: the pre-processed book text to search in.
        output_dir: directory where the restored submanifest is written.
        restored_text_field: key under which the restored text is stored.
    """
    manifest_recovered = f"{output_dir}/{os.path.basename(manifest)}"
    if os.path.exists(manifest_recovered):
        # already restored by a previous run
        return

    with open(manifest, "r") as f:
        lines = [json.loads(line)["text"] for line in f]

    logger.debug(f"processing {manifest}")
    logger.debug(f"processing - {len(lines)} lines")

    # NOTE(review): this is used as a *character* offset into processed_text (finditer /
    # match_idx below) but is assigned a *word* index (raw_text_pointer) after a match, so in
    # practice every search restarts near the beginning of the book. Making it a character
    # offset would turn this into a sequential search — confirm intent before changing, as it
    # alters which occurrence of a repeated phrase gets matched.
    last_found_start_idx = 0
    recovered_lines = {}
    # how many words past len(cur_line) to keep extending a candidate span before giving up
    right_offset = 20

    for idx, cur_line in enumerate(lines):
        cur_line = abbreviations(cur_line).split()
        if not cur_line:
            # empty transcript: nothing to anchor a search on, so the entry stays NA
            continue

        # anchor pattern: the first (up to) four words separated by any run of non-letters;
        # re.escape keeps transcript characters from being interpreted as regex syntax
        pattern = re.compile("[^A-Za-z]+".join(re.escape(word) for word in cur_line[:4]))
        end_match_found = False

        for m in pattern.finditer(processed_text[last_found_start_idx:].lower()):
            if end_match_found:
                break
            match_idx = m.start() + last_found_start_idx
            processed_text_list = processed_text[match_idx:].split()
            # start a few words before the expected end, in case some dash-separated words
            # were split into multiple words in cur_line; never below 1 so that the
            # `raw_text_pointer - 1` index cannot wrap around to the end of the book
            raw_text_pointer = max(len(cur_line) - 3, 1)
            stop_end_search = False
            while not end_match_found and raw_text_pointer <= len(processed_text_list) and not stop_end_search:
                if cur_line[-1].replace("'", "") == remove_punctuation(
                    processed_text_list[raw_text_pointer - 1],
                    remove_spaces=True,
                    do_lower=True,
                    remove_accents=False,
                ):
                    # processed text could contain apostrophes that are parts of quotes, let's remove them
                    # from the processed text as well
                    if "'" not in cur_line[-1] and "'" in processed_text_list[raw_text_pointer - 1]:
                        processed_text_list[raw_text_pointer - 1] = processed_text_list[
                            raw_text_pointer - 1
                        ].replace("'", "")
                    recovered_line = " ".join(processed_text_list[:raw_text_pointer])
                    raw_text_pointer += 1
                    if is_valid(" ".join(cur_line), recovered_line):
                        recovered_lines[idx] = recovered_line
                        end_match_found = True
                        last_found_start_idx = raw_text_pointer
                else:
                    raw_text_pointer += 1
                    if raw_text_pointer > (len(cur_line) + right_offset):
                        stop_end_search = True

    logger.debug(
        f"recovered {len(recovered_lines)} lines out of {len(lines)} -- {round(len(recovered_lines)/max(len(lines), 1)*100, 2)}% -- {os.path.basename(manifest)}"
    )

    with open(manifest_recovered, "w") as f_out, open(manifest, "r") as f_in:
        for idx, line in enumerate(f_in):
            line = json.loads(line)
            if idx in recovered_lines:
                line[restored_text_field] = recovered_lines[idx]
            else:
                line[restored_text_field] = NA
            f_out.write(json.dumps(line, ensure_ascii=False) + "\n")


def split_text_into_sentences(text: str):
    """
    Split text into sentences.

    A sentence boundary is a whitespace character preceded by ".", "?" or "!" (optionally
    followed by a closing "”"). No split happens after a capitalized word or a single
    capital letter followed by "." (meant to protect abbreviations such as "Mr." — note
    this also suppresses splits after sentence-final capitalized words), after patterns
    like "e.g.", or when the next token looks like a numbered item ("1.", "2a.").

    Args:
        text: text

    Returns list of sentences
    """
    # TODO: should this be filled up and exposed as a parameter?
    lower_case_unicode = ""
    upper_case_unicode = ""

    # end of quoted speech - to be able to split sentences by full stop: move a closing
    # quote in front of the sentence-final mark (e.g. `said."` -> `said". `)
    text = re.sub(r"([\.\?\!])([\"\'])", r"\g<2>\g<1> ", text)

    # remove extra space
    text = re.sub(r" +", " ", text)

    # remove space in the middle of the lower case abbreviation to avoid splitting into separate sentences
    matches = re.findall(rf"[a-z{lower_case_unicode}]\.\s[a-z{lower_case_unicode}]\.", text)
    for match in matches:
        text = text.replace(match, match.replace(". ", "."))

    # Read and split transcript by utterance (roughly, sentences).
    # The lookbehind alternatives are separated by "|"; previously "\?”\!”" was missing a "|",
    # so neither `?”` nor `!”` ended a sentence.
    split_pattern = (
        rf"(?<!\w\.\w.)(?<![A-Z{upper_case_unicode}][a-z{lower_case_unicode}]+\.)"
        rf"(?<![A-Z{upper_case_unicode}]\.)(?<=\.|\?|\!|\.”|\?”|\!”)\s(?![0-9]+[a-z]*\.)"
    )
    # the third-party `regex` module is used because it supports variable-length lookbehind
    sentences = regex.split(split_pattern, text)
    return sentences


def normalize_text(text_f: str, normalizer: Optional['Normalizer'] = None):
    """
    Pre-process and normalize the text in the ``text_f`` file.

    The raw text is cleaned with :func:`process` and :func:`abbreviations`, split into
    sentences, and consecutive sentences are grouped into segments; a new segment is
    started once the current one reaches ``max_len`` characters. Each segment is then
    normalized.

    Args:
        text_f: path to .txt file to normalize
        normalizer: optional NeMo text ``Normalizer``. When given, it is used both to split
            the text into sentences and to normalize each segment; when None, the local
            sentence splitter is used and digits are deleted from each segment.

    Returns:
        The processed text: all segments joined with a single space.
    """
    raw_text = read_text(text_f)
    processed_text = abbreviations(process(raw_text))
    if normalizer is not None:
        processed_text_list = normalizer.split_text_into_sentences(processed_text)
    else:
        processed_text_list = split_text_into_sentences(processed_text)

    # Group consecutive sentences into segments of roughly max_len characters. Every sentence
    # ends up in exactly one segment (the sentence that triggered a flush used to be dropped).
    processed_text_list_merged = []
    last_segment = ""
    max_len = 7500
    for text in processed_text_list:
        if len(last_segment) >= max_len:
            processed_text_list_merged.append(last_segment.strip())
            last_segment = ""
        last_segment += " " + text
    if last_segment:
        processed_text_list_merged.append(last_segment.strip())

    for i, segment in enumerate(tqdm(processed_text_list_merged)):
        if normalizer is not None:
            processed_text_list_merged[i] = normalizer.normalize(
                text=segment, punct_post_process=True, punct_pre_process=True
            )
        else:
            # no normalizer available: digits are removed rather than verbalized
            processed_text_list_merged[i] = re.sub(r"\d", r"", segment)
    processed_text = " ".join(processed_text_list_merged)
    return processed_text


# NOTE(review): mid-file third-party import; `dmp` is not referenced by any code visible in
# this file — confirm it is still used elsewhere before removing it (or move the import to
# the top of the module).
import diff_match_patch as dmp_module

# Module-level diff_match_patch instance.
dmp = dmp_module.diff_match_patch()
# NOTE(review): per the diff_match_patch docs, 0 means "no time limit" for diff computation — confirm.
dmp.Diff_Timeout = 0


def is_valid(line, recovered_line):
    """Return True when ``recovered_line`` equals ``line`` up to casing, spacing,
    punctuation and a fixed set of accented vowels.

    ``line`` is first passed through :func:`abbreviations`, so contractions expanded in
    the book text still compare equal.
    """

    def _canonical(text):
        # strip punctuation and spaces, lower-case, drop accents
        return remove_punctuation(text, remove_spaces=True, do_lower=True, remove_accents=True)

    return _canonical(abbreviations(line)) == _canonical(recovered_line)


def process_book(book_manifest, texts_dir, submanifests_dir, output_dir, restored_text_field, normalizer):
    """Restore P&C for every ``{book_id}_*.json`` submanifest of a single Librivox book.

    The book text ``{texts_dir}/{book_id}.txt`` is normalized once and then used to restore
    every submanifest of that book into ``output_dir``. The book is skipped when all of its
    submanifests already have a restored counterpart in ``output_dir``. Any exception
    raised while processing the book is logged and the book is skipped; this is best-effort
    so that one bad book does not abort the parallel run.

    Args:
        book_manifest: the book id (or a file name whose basename before the first "." is
            the book id).
        texts_dir: directory with the Librivox ``{book_id}.txt`` texts.
        submanifests_dir: directory with the ``{book_id}_{reader_id}.json`` submanifests.
        output_dir: directory where restored submanifests are written.
        restored_text_field: key under which the restored text is stored.
        normalizer: optional NeMo text normalizer (None to skip normalization).
    """
    book_id = os.path.basename(book_manifest).split(".")[0]
    text_f = f"{texts_dir}/{book_id}.txt"
    manifests = glob(f"{submanifests_dir}/{book_id}_*.json")
    logger.info(f"{book_id} -- {len(manifests)} manifests")

    # only continue (i.e. do not make early 'return') if there are {book_id}_{spk_id}.json files in submanifests_dir
    # that are not in output dir - else return early.
    # os.path.splitext removes the ".json" suffix; str.strip(".json") would strip any of the
    # characters ".jsno" from both ends of the name.
    for book_id_spk_id in [os.path.splitext(os.path.basename(x))[0] for x in manifests]:
        if not os.path.exists(os.path.join(output_dir, f"{book_id_spk_id}.json")):
            logger.info(f"Did not find {book_id_spk_id} in {output_dir} => will process this book")
            break
    else:
        return

    try:
        processed_text = normalize_text(text_f, normalizer)
        # re-run abbreviations since new are being added
        processed_text = abbreviations(processed_text)
        for manifest in manifests:
            recover_lines(
                manifest=manifest,
                processed_text=processed_text,
                output_dir=output_dir,
                restored_text_field=restored_text_field,
            )
    except Exception as e:
        # best effort: log the reason and skip this book (KeyboardInterrupt/SystemExit still propagate)
        logger.info(f"{text_f} failed: {e!r}")


class RestorePCForMLS(BaseProcessor):
    """Recovers original text from the MLS Librivox texts.

    This processor can be used to restore punctuation and capitalization for the
    MLS data. Uses the original data in https://dl.fbaipublicfiles.com/mls/lv_text.tar.gz.
    Saves recovered text in ``restored_text_field`` field.
    If text was not recovered, ``restored_text_field`` will be equal to ``n/a``.

    Args:
        language_long (str): the full name of the language, used for
            choosing the folder of the contents of
            "https://dl.fbaipublicfiles.com/mls/lv_text.tar.gz".
            E.g., "english", "spanish", "italian", etc.
        language_short (str or None): the short name of the language, used for
            specifying the normalizer we want to use. E.g., "en", "es", "it", etc.
            If set to None, we will not try to normalize the provided Librivox text.
        lv_text_dir (str): the directory where the contents of
            https://dl.fbaipublicfiles.com/mls/lv_text.tar.gz will be saved.
        submanifests_dir (str): the directory where submanifests (one for each
            combo of speaker + book) will be stored.
        restored_submanifests_dir (str): the directory where restored
            submanifests (one for each combo of speaker + book) will be stored.
        restored_text_field (str): the field where the recovered text will be stored.
        n_jobs (int): number of jobs to use for parallel processing. Defaults to -1.
        show_conversion_breakdown (bool): whether to show how much of each
            submanifest was restored. Defaults to True.

    Returns:
        All the same data as in the input manifest with an additional key::

            <restored_text_field>: <restored text or n/a if match was not found>

        Entries are grouped by submanifest (sorted by name) and ordered by sample id
        within each. Entries of submanifests for which no restored file exists (e.g. the
        book text could not be processed) are not written.
    """

    def __init__(
        self,
        language_long: str,
        language_short: Optional[str],
        lv_text_dir: str,
        submanifests_dir: str,
        restored_submanifests_dir: str,
        restored_text_field: str,
        n_jobs: int = -1,
        show_conversion_breakdown: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.language_long = language_long
        self.language_short = language_short
        self.lv_text_dir = Path(lv_text_dir)
        self.submanifests_dir = Path(submanifests_dir)
        self.restored_submanifests_dir = Path(restored_submanifests_dir)
        self.restored_text_field = restored_text_field
        self.n_jobs = n_jobs
        self.show_conversion_breakdown = show_conversion_breakdown

    def process(self):
        """Main processing happens here.

        * Download & extract lv_text.
        * Create submanifests.
        * Restore P&C to submanifests.
        * Group back submanifests into a single manifest
        """
        os.makedirs(self.lv_text_dir, exist_ok=True)

        # Download & extract lv_text.
        download_file(MLS_TEXT_URL, str(self.lv_text_dir))
        lv_text_data_folder = extract_archive(
            str(self.lv_text_dir / os.path.basename(MLS_TEXT_URL)), str(self.lv_text_dir)
        )

        # Create submanifests (one per "{book_id}_{reader_id}" pair).
        submanifest_names, original_manifest_duration = self._create_submanifests()

        # Restore P&C to submanifests.
        os.makedirs(str(self.restored_submanifests_dir), exist_ok=True)
        normalizer = self._build_normalizer()

        # TODO: rename to maybe books_ids_in_datasplit
        # sorted so that the scheduling order is deterministic across runs
        books_ids_in_submanifests = sorted({name.split("_")[0] for name in submanifest_names})
        Parallel(n_jobs=self.n_jobs)(
            delayed(process_book)(
                book_id,
                str(Path(lv_text_data_folder) / self.language_long),
                str(self.submanifests_dir),
                str(self.restored_submanifests_dir),
                self.restored_text_field,
                normalizer,
            )
            for book_id in tqdm(books_ids_in_submanifests)
        )

        self._log_restoration_stats(submanifest_names, original_manifest_duration)

        # Group back restored submanifests into a single manifest.
        self._combine_restored_submanifests(submanifest_names)

    def _create_submanifests(self):
        """Split the input manifest into one submanifest per ``{book_id}_{reader_id}`` pair.

        Audio file stems are expected to be ``{reader_id}_{book_id}_{sample_id}``. Each
        submanifest keeps the original JSON lines, ordered by ``sample_id``.

        Returns:
            A tuple ``(names, total_duration)``: the sorted submanifest names (file names
            without ``.json``) and the sum of the ``duration`` fields of the input manifest.
        """
        os.makedirs(self.submanifests_dir, exist_ok=True)
        data = collections.defaultdict(dict)
        total_duration = 0
        with open(self.input_manifest_file, "r") as f:
            for line in tqdm(f):
                item = json.loads(line)
                reader_id, lv_book_id, sample_id = Path(item["audio_filepath"]).stem.split("_")
                data[f"{lv_book_id}_{reader_id}"][sample_id] = line
                total_duration += item["duration"]

        for key, samples in data.items():
            with open(f"{self.submanifests_dir}/{key}.json", "w") as f_out:
                for sample_id in sorted(samples):
                    f_out.write(samples[sample_id])

        return sorted(data), total_duration

    def _build_normalizer(self):
        """Return a NeMo ``Normalizer`` for ``language_short``, or None if not requested/unsupported."""
        if not self.language_short:
            logger.info(
                "`language_short` was not specified, so will not normalize the Librivox"
                " text before attempting to restore punctuation and capitalization."
            )
            return None

        # imported lazily so that nemo_text_processing is only required when normalization is requested
        from nemo_text_processing.text_normalization.normalize import Normalizer

        try:
            return Normalizer(
                input_case="cased",
                lang=self.language_short,
                cache_dir="CACHE_DIR",
                overwrite_cache=False,
                post_process=True,
            )
        except NotImplementedError:
            # some languages don't support text normalization
            logger.info(
                f"Could not find NeMo Normalizer for language {self.language_short}, so"
                " will not normalize the Librivox text before attempting to restore punctuation"
                " and capitalization."
            )
            return None

    def _log_restoration_stats(self, submanifest_names, original_manifest_duration):
        """Log how much duration was restored, overall and (optionally) per submanifest."""
        logger.info(
            f"duration ORIGINAL total (for current datasplit): {round(original_manifest_duration / 60 / 60, 2)} hrs"
        )

        filename_to_sub_manifest_durs = collections.defaultdict(float)
        filename_to_restored_sub_manifest_durs = collections.defaultdict(float)
        for name in submanifest_names:
            filename = f"{name}.json"

            # duration in submanifests
            with open(os.path.join(self.submanifests_dir, filename), "r") as f:
                for line in f:
                    filename_to_sub_manifest_durs[filename] += json.loads(line)["duration"]

            # duration in restored_submanifests (only entries that were actually restored);
            # a missing restored submanifest counts as 0
            restored_manifest = os.path.join(self.restored_submanifests_dir, filename)
            if os.path.exists(restored_manifest):
                with open(restored_manifest, "r") as f:
                    for line in f:
                        line = json.loads(line)
                        if line[self.restored_text_field] != NA:
                            filename_to_restored_sub_manifest_durs[filename] += line["duration"]

        if self.show_conversion_breakdown:
            for filename, orig_dur in filename_to_sub_manifest_durs.items():
                restored_dur = filename_to_restored_sub_manifest_durs[filename]
                pc_restored = 100 * restored_dur / orig_dur if orig_dur else 0.0
                logger.info(
                    f"{filename}: {orig_dur/60:.2f} mins -> {restored_dur/60:.2f} mins\t({pc_restored:.2f}% restored)"
                )

        sub_manifest_duration = sum(filename_to_sub_manifest_durs.values())
        restored_manifest_duration = sum(filename_to_restored_sub_manifest_durs.values())
        restored_pc = restored_manifest_duration / sub_manifest_duration * 100 if sub_manifest_duration else 0.0
        logger.info("duration in submanifests (for current datasplit): %.2f hrs", sub_manifest_duration / 60 / 60)
        logger.info(
            "duration restored (for current datasplit): %.2f hrs (%.2f%%), lost: %.2f hrs",
            restored_manifest_duration / 60 / 60,
            restored_pc,
            (sub_manifest_duration - restored_manifest_duration) / 60 / 60,
        )

    def _combine_restored_submanifests(self, submanifest_names):
        """Concatenate the existing restored submanifests (in the given order) into the output manifest."""
        logger.info(
            "Combining restored manifest for current datasplit into single manifest at %s", self.output_manifest_file
        )
        with open(self.output_manifest_file, 'w') as fout:
            for name in submanifest_names:
                manifest = os.path.join(self.restored_submanifests_dir, f"{name}.json")
                if os.path.exists(manifest):
                    with open(manifest, "r") as fin:
                        fout.writelines(fin)