add working synthesis code
This commit is contained in:
parent
f6f0a2ca21
commit
8c22397b55
|
@ -79,17 +79,13 @@ class Subset(dataset.Dataset):
|
|||
mode='constant', constant_values=0)
|
||||
|
||||
# Normalize audio.
|
||||
audio = audio / MAX_WAV_VALUE
|
||||
audio = audio.astype(np.float32) / MAX_WAV_VALUE
|
||||
mel = self.get_mel(audio)
|
||||
#print("mel = {}, dtype {}, shape {}".format(mel, mel.dtype, mel.shape))
|
||||
|
||||
return audio, mel
|
||||
|
||||
def _batch_examples(self, batch):
|
||||
audio_batch = []
|
||||
mel_batch = []
|
||||
for audio, mel in batch:
|
||||
audio_batch
|
||||
|
||||
audios = [sample[0] for sample in batch]
|
||||
mels = [sample[1] for sample in batch]
|
||||
|
||||
|
|
|
@ -8,11 +8,11 @@ import paddle.fluid.dygraph as dg
|
|||
from paddle import fluid
|
||||
|
||||
import utils
|
||||
from wavenet import WaveNet
|
||||
from waveflow import WaveFlow
|
||||
|
||||
|
||||
def add_options_to_parser(parser):
|
||||
parser.add_argument('--model', type=str, default='wavenet',
|
||||
parser.add_argument('--model', type=str, default='waveflow',
|
||||
help="general name of the model")
|
||||
parser.add_argument('--name', type=str,
|
||||
help="specific name of the training model")
|
||||
|
@ -30,7 +30,7 @@ def add_options_to_parser(parser):
|
|||
|
||||
parser.add_argument('--output', type=str, default="./syn_audios",
|
||||
help="path to write synthesized audio files")
|
||||
parser.add_argument('--sample', type=int,
|
||||
parser.add_argument('--sample', type=int, default=None,
|
||||
help="which of the valid samples to synthesize audio")
|
||||
|
||||
|
||||
|
@ -54,7 +54,7 @@ def synthesize(config):
|
|||
print("Random Seed: ", seed)
|
||||
|
||||
# Build model.
|
||||
model = WaveNet(config, checkpoint_dir)
|
||||
model = WaveFlow(config, checkpoint_dir)
|
||||
model.build(training=False)
|
||||
|
||||
# Obtain the current iteration.
|
||||
|
|
|
@ -2,7 +2,8 @@ import itertools
|
|||
import os
|
||||
import time
|
||||
|
||||
import librosa
|
||||
#import librosa
|
||||
from scipy.io.wavfile import write
|
||||
import numpy as np
|
||||
import paddle.fluid.dygraph as dg
|
||||
from paddle import fluid
|
||||
|
@ -156,17 +157,38 @@ class WaveFlow():
|
|||
output = "{}/{}/iter-{}".format(config.output, config.name, iteration)
|
||||
os.makedirs(output, exist_ok=True)
|
||||
|
||||
filename = "{}/valid_{}.wav".format(output, sample)
|
||||
print("Synthesize sample {}, save as {}".format(sample, filename))
|
||||
mels_list = [mels for _, mels in self.validloader()]
|
||||
if sample is not None:
|
||||
mels_list = [mels_list[sample]]
|
||||
|
||||
mels_list = [mels for _, mels, _ in self.validloader()]
|
||||
start_time = time.time()
|
||||
syn_audio = self.waveflow.synthesize(mels_list[sample])
|
||||
syn_time = time.time() - start_time
|
||||
print("audio shape {}, synthesis time {}".format(
|
||||
syn_audio.shape, syn_time))
|
||||
librosa.output.write_wav(filename, syn_audio,
|
||||
sr=config.sample_rate)
|
||||
audio_times = []
|
||||
inf_times = []
|
||||
for sample, mel in enumerate(mels_list):
|
||||
filename = "{}/valid_{}.wav".format(output, sample)
|
||||
print("Synthesize sample {}, save as {}".format(sample, filename))
|
||||
|
||||
start_time = time.time()
|
||||
audio = self.waveflow.synthesize(mel)
|
||||
syn_time = time.time() - start_time
|
||||
|
||||
audio_time = audio.shape[0] / 22050
|
||||
print("audio time {}, synthesis time {}, speedup: {}".format(
|
||||
audio_time, syn_time, audio_time / syn_time))
|
||||
|
||||
#librosa.output.write_wav(filename, syn_audio,
|
||||
# sr=config.sample_rate)
|
||||
audio = audio.numpy() * 32768.0
|
||||
audio = audio.astype('int16')
|
||||
write(filename, config.sample_rate, audio)
|
||||
|
||||
audio_times.append(audio_time)
|
||||
inf_times.append(syn_time)
|
||||
|
||||
total_audio = sum(audio_times)
|
||||
total_inf = sum(inf_times)
|
||||
|
||||
print("Total audio: {}, total inf time {}, speedup: {}".format(
|
||||
total_audio, total_inf, total_audio / total_inf))
|
||||
|
||||
def save(self, iteration):
|
||||
utils.save_latest_parameters(self.checkpoint_dir, iteration,
|
||||
|
|
|
@ -75,6 +75,16 @@ class Conditioner(dg.Layer):
|
|||
|
||||
return fluid.layers.squeeze(x, [1])
|
||||
|
||||
def infer(self, x):
    """Run ``x`` through every layer of ``self.upsample_conv2d`` for inference.

    A singleton axis is inserted at position 1 before the layers run and
    squeezed out again before returning. After each layer the trailing
    ``filter_size[1] - stride[1]`` steps of the last axis are dropped and a
    leaky ReLU (alpha=0.4) is applied.

    Args:
        x: input tensor (presumably a mel spectrogram shaped
            [bs, mel_bands, time] -- confirm against the caller).

    Returns:
        The processed tensor with the temporary axis 1 removed.
    """
    x = fluid.layers.unsqueeze(x, 1)
    for layer in self.upsample_conv2d:
        x = layer(x)
        # Trim conv artifacts.
        time_cutoff = layer._filter_size[1] - layer._stride[1]
        # Only slice for a positive cutoff: ``x[..., :-0]`` is ``x[..., :0]``,
        # which would silently produce an empty tensor when the filter
        # width equals the stride.
        if time_cutoff > 0:
            x = x[:, :, :, :-time_cutoff]
        x = fluid.layers.leaky_relu(x, alpha=0.4)

    return fluid.layers.squeeze(x, [1])
|
||||
|
||||
|
||||
class Flow(dg.Layer):
|
||||
def __init__(self, name_scope, config):
|
||||
|
@ -183,6 +193,14 @@ class Flow(dg.Layer):
|
|||
return self.end(output)
|
||||
|
||||
|
||||
def debug(x, msg):
    """Print the values, shape and dtype of ``x`` under the label ``msg``.

    Args:
        x: any object exposing a ``numpy()`` method; its result is what
            gets printed and inspected.
        msg: label printed in front of the values.
    """
    values = x.numpy()
    print(msg + " :\n", values)
    # Emit the metadata lines in a fixed order, then a blank separator.
    for label, detail in (("shape: ", values.shape),
                          ("dtype: ", values.dtype)):
        print(label, detail)
    print("")
|
||||
|
||||
|
||||
class WaveFlowModule(dg.Layer):
|
||||
def __init__(self, name_scope, config):
|
||||
super(WaveFlowModule, self).__init__(name_scope)
|
||||
|
@ -217,7 +235,7 @@ class WaveFlowModule(dg.Layer):
|
|||
if mel.shape[2] > pruned_len:
|
||||
mel = mel[:, :, :pruned_len]
|
||||
|
||||
# From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group]
|
||||
# From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group]
|
||||
mel = fluid.layers.transpose(unfold(mel, self.n_group), [0, 1, 3, 2])
|
||||
# From [bs, time] to [bs, n_group, time/n_group]
|
||||
audio = fluid.layers.transpose(unfold(audio, self.n_group), [0, 2, 1])
|
||||
|
@ -247,8 +265,54 @@ class WaveFlowModule(dg.Layer):
|
|||
|
||||
return z, log_s_list
|
||||
|
||||
def synthesize(self, mel, sigma=1.0):
    """Generate a waveform from ``mel`` by running the flows in reverse.

    This is the single effective definition of ``synthesize``; an earlier
    ``pass`` stub with the same name was shadowed by it and is dropped.

    Args:
        mel: mel spectrogram passed to ``self.conditioner.infer``. The
            final squeeze over axes 0 and 1 assumes batch size 1.
        sigma: standard deviation of the Gaussian noise the flows start
            from.

    Returns:
        A 1-D tensor of audio samples, [time].
    """
    cond = self.conditioner.infer(mel)

    # From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group]
    cond = fluid.layers.transpose(unfold(cond, self.n_group), [0, 1, 3, 2])

    # Start from noise with one channel and the same height/width as the
    # grouped conditioner.
    audio = fluid.layers.gaussian_random(
        shape=[cond.shape[0], 1, cond.shape[2], cond.shape[3]], std=sigma)

    for flow_idx in reversed(range(self.n_flows)):
        # Permute over the height dimension.
        audio = fluid.layers.stack(
            [audio[:, :, row, :] for row in self.perms[flow_idx]], axis=2)
        cond = fluid.layers.stack(
            [cond[:, :, row, :] for row in self.perms[flow_idx]], axis=2)

        # Row 0 is kept as is; each later row is inverted from the rows
        # already recovered above it.
        rows = [audio[:, :, :1, :]]
        for h in range(1, self.n_group):
            # history: [bs, 1, h, time/n_group]
            history = fluid.layers.concat(rows, axis=2)
            outputs = self.flows[flow_idx](history, cond[:, :, 1:(h + 1), :])

            # Channel 0 carries log-scale, the remaining channels the
            # shift; only the newest row (h-1) of the output is used.
            log_s = outputs[:, :1, (h - 1):h, :]
            shift = outputs[:, 1:, (h - 1):h, :]
            rows.append(
                (audio[:, :, h:(h + 1), :] - shift) / fluid.layers.exp(log_s))

        audio = fluid.layers.concat(rows, axis=2)

    # Assume batch size = 1
    # audio: [n_group, time/n_group]
    audio = fluid.layers.squeeze(audio, [0, 1])
    # audio: [time]
    return fluid.layers.reshape(
        fluid.layers.transpose(audio, [1, 0]), [-1])
|
||||
|
||||
def start_new_sequence(self):
|
||||
for layer in self.sublayers():
|
||||
|
|
Loading…
Reference in New Issue