Source code for sdp.processors.nemo.pc_inference
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from pathlib import Path
from typing import Dict, List, Optional, Union
from sdp.processors.base_processor import BaseProcessor


def load_manifest(manifest: Path) -> List[Dict[str, Union[str, float]]]:
    """Read an NDJSON manifest (one JSON entry per line) into a list of dicts."""
    result = []
    with manifest.open() as f:
        for line in f:
            data = json.loads(line)
            result.append(data)
    return result


class PCInference(BaseProcessor):
"""Adds predictions of a text-based punctuation and capitalization (P&C) model.
Operates on the text in the ``input_text_field``, and saves predictions in
the ``output_text_field``.
Args:
input_text_field (str): the text field that will be the input to the P&C model.
output_text_field (str): the text field where the output of the PC model
will be saved.
batch_size (int): the batch sized used by the P&C model.
device (str): the device used by the P&C model. Can be skipped to auto-select.
pretrained_name (str): the pretrained_name of the P&C model.
model_path (str): the model path to the P&C model.
.. note::
Either ``pretrained_name`` or ``model_path`` have to be specified.
Returns:
The same data as in the input manifest with an additional field
<output_text_field> containing P&C model's predictions.
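
    A minimal usage sketch of calling the processor directly from Python (the
    manifest paths and field names below are placeholders, and
    ``punctuation_en_bert`` is just one of the pretrained P&C model names
    available in NeMo)::

        PCInference(
            input_manifest_file="manifest_without_pc.json",
            output_manifest_file="manifest_with_pc.json",
            input_text_field="text",
            output_text_field="text_pc",
            pretrained_name="punctuation_en_bert",
            batch_size=32,
        ).process()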
"""
def __init__(
self,
input_text_field: str,
output_text_field: str,
batch_size: int,
device: Optional[str] = None,
pretrained_name: Optional[str] = None,
model_path: Optional[str] = None,
**kwargs,
):
super().__init__(**kwargs)
self.pretrained_name = pretrained_name
self.model_path = model_path
self.input_text_field = input_text_field
self.output_text_field = output_text_field
self.device = device
self.batch_size = batch_size
# verify self.pretrained_name/model_path
if self.pretrained_name is None and self.model_path is None:
raise ValueError("pretrained_name and model_path cannot both be None")
if self.pretrained_name is not None and self.model_path is not None:
raise ValueError("pretrained_name and model_path cannot both be specified")
def process(self):
        import torch  # imported inside process() so that users install NeMo (which pulls in a compatible torch) rather than torch on its own
from nemo.collections.nlp.models import PunctuationCapitalizationModel
if self.pretrained_name:
model = PunctuationCapitalizationModel.from_pretrained(self.pretrained_name)
else:
model = PunctuationCapitalizationModel.restore_from(self.model_path)
if self.device is None:
if torch.cuda.is_available():
model = model.cuda()
else:
model = model.cpu()
else:
model = model.to(self.device)
manifest = load_manifest(Path(self.input_manifest_file))
        texts = [item[self.input_text_field] for item in manifest]
processed_texts = model.add_punctuation_capitalization(
texts,
batch_size=self.batch_size,
)
Path(self.output_manifest_file).parent.mkdir(exist_ok=True, parents=True)
with Path(self.output_manifest_file).open('w') as f:
for item, t in zip(manifest, processed_texts):
item[self.output_text_field] = t
f.write(json.dumps(item, ensure_ascii=False) + '\n')