# Copyright (c) 2018 NVIDIA Corporation
import os

import numpy as np
from scipy.io.wavfile import write

from .encoder_decoder import EncoderDecoderModel
def save_audio(signal, logdir, step, sampling_rate, mode):
    """Write ``signal`` to ``<logdir>/sample_step<step>_<mode>.wav``.

    The samples are cast to float32 before writing, so the result is a
    32-bit floating-point WAV file.

    Args:
      signal: array-like of audio samples.
      logdir: directory to write into, absolute or relative to the current
        working directory. An empty string means the current directory.
      step: step number embedded in the file name.
      sampling_rate: sample rate in Hz stored in the WAV header.
      mode: tag embedded in the file name (e.g. "train" or "eval").
    """
    signal = np.float32(signal)
    # os.path.join handles absolute, relative and empty logdirs. The old
    # ``logdir[0] != '/'`` check raised IndexError for an empty logdir, and
    # prefixing "./" broke non-POSIX absolute paths such as "C:\\logs".
    file_name = os.path.join(
        logdir, 'sample_step{}_{}.wav'.format(step, mode))
    write(file_name, sampling_rate, signal)
class Text2SpeechWavenet(EncoderDecoderModel):
    """Text-to-speech model built on ``EncoderDecoderModel``.

    Audio samples are written as 22050 Hz WAV files into
    ``params["logdir"]``: one per ``maybe_print_logs`` call (mode "train")
    and one per ``finalize_evaluation`` call (mode "eval").
    """

    @staticmethod
    def get_required_params():
        """Return the required config keys, identical to the base class's."""
        # No extra required parameters are added; hand back a fresh copy of
        # the base-class dictionary.
        required = dict(EncoderDecoderModel.get_required_params())
        return required

    def __init__(self, params, mode="train", hvd=None):
        """Forward all arguments unchanged to ``EncoderDecoderModel``."""
        super(Text2SpeechWavenet, self).__init__(params, mode=mode, hvd=hvd)

    def _save_audio_sample(self, signal, step, mode):
        """Write ``signal`` into the configured logdir at 22050 Hz."""
        save_audio(
            signal,
            self.params["logdir"],
            step,
            sampling_rate=22050,
            mode=mode,
        )

    def maybe_print_logs(self, input_values, output_values, training_step):
        """Save the last entry of ``output_values[1]`` as a "train" sample.

        ``output_values[1]`` is presumably the batch of generated waveforms
        (TODO confirm against the decoder). ``input_values`` is unused.

        Returns:
          An empty dict.
        """
        generated = output_values[1]
        self._save_audio_sample(generated[-1], training_step, "train")
        return {}

    def evaluate(self, input_values, output_values):
        """Return the last entry of ``output_values[1]`` for this batch."""
        generated = output_values[1]
        return generated[-1]

    def finalize_evaluation(self, results_per_batch, training_step=None):
        """Save the first batch's result as an "eval" sample.

        Only ``results_per_batch[0]`` is written (presumably the value that
        ``evaluate`` returned for the first batch); later batches are ignored.

        Returns:
          An empty dict.
        """
        first = results_per_batch[0]
        self._save_audio_sample(first, training_step, "eval")
        return {}

    def infer(self, input_values, output_values):
        """Return the last entry of ``output_values[1]`` for this batch."""
        generated = output_values[1]
        return generated[-1]

    def finalize_inference(self, results_per_batch, output_file):
        """Do nothing with the inference results and return an empty dict.

        ``results_per_batch`` and ``output_file`` are both ignored.
        """
        # NOTE(review): inference output is discarded and nothing is written
        # to output_file; confirm this is intended.
        return {}