refine code
This commit is contained in:
parent
8c22397b55
commit
0e18d60057
|
@ -0,0 +1,71 @@
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from pprint import pprint
|
||||||
|
|
||||||
|
import jsonargparse
|
||||||
|
import numpy as np
|
||||||
|
import paddle.fluid.dygraph as dg
|
||||||
|
from paddle import fluid
|
||||||
|
|
||||||
|
import utils
|
||||||
|
from waveflow import WaveFlow
|
||||||
|
|
||||||
|
|
||||||
|
def _str2bool(value):
    """Convert a command-line or config value into a real boolean.

    ``argparse`` calls the ``type`` callable on the raw string, and
    ``bool("False")`` is ``True`` because any non-empty string is truthy.
    This converter parses the usual textual spellings instead.

    Args:
        value: a ``bool`` (returned unchanged) or a value whose string form
            is one of true/t/yes/y/1 or false/f/no/n/0 (case-insensitive).

    Returns:
        The parsed boolean.

    Raises:
        ValueError: if the value is not a recognised boolean spelling.
            ``argparse`` turns this into a normal usage error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise ValueError("expected a boolean value, got {!r}".format(value))


def add_options_to_parser(parser):
    """Register the benchmark script's command-line options on ``parser``.

    Adds ``--model``, ``--name``, ``--root``, ``--use_gpu``, ``--iteration``
    and ``--checkpoint``.

    Args:
        parser: an argument parser exposing an argparse-compatible
            ``add_argument`` method; it is modified in place.
    """
    parser.add_argument('--model', type=str, default='waveflow',
                        help="general name of the model")
    parser.add_argument('--name', type=str,
                        help="specific name of the training model")
    parser.add_argument('--root', type=str,
                        help="root path of the LJSpeech dataset")

    # Use an explicit converter rather than ``type=bool`` so that
    # ``--use_gpu False`` really disables the GPU.
    parser.add_argument('--use_gpu', type=_str2bool, default=True,
                        help="option to use gpu training")

    parser.add_argument('--iteration', type=int, default=None,
                        help=("which iteration of checkpoint to load, "
                              "default to load the latest checkpoint"))
    parser.add_argument('--checkpoint', type=str, default=None,
                        help="path of the checkpoint to load")
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark(config):
    """Build a WaveFlow model from ``config`` and run its benchmark.

    Prints the configuration, derives the checkpoint directory
    ``runs/<model>/<name>/checkpoint``, selects the device, seeds the
    Python, NumPy and fluid program random generators, then builds the
    model with ``training=False`` and calls ``model.benchmark()``.

    Args:
        config: parsed configuration namespace. This function reads
            ``model``, ``name``, ``use_gpu`` and ``seed``; the whole
            namespace is also handed to ``WaveFlow``.

    Raises:
        ValueError: if ``config.name`` is ``None``, since the checkpoint
            directory cannot be derived without it.
    """
    pprint(jsonargparse.namespace_to_dict(config))

    # Fail early with a readable message instead of the TypeError that
    # os.path.join raises when handed None.
    if config.name is None:
        raise ValueError(
            "--name is required to locate the checkpoint directory "
            "under runs/<model>/<name>/checkpoint")

    # Get checkpoint directory path.
    run_dir = os.path.join("runs", config.model, config.name)
    checkpoint_dir = os.path.join(run_dir, "checkpoint")

    # Configure device: GPU 0 when requested, otherwise CPU.
    place = fluid.CUDAPlace(0) if config.use_gpu else fluid.CPUPlace()

    with dg.guard(place):
        # Fix random seed.
        seed = config.seed
        random.seed(seed)
        np.random.seed(seed)
        fluid.default_startup_program().random_seed = seed
        fluid.default_main_program().random_seed = seed
        print("Random Seed: ", seed)

        # Build model.
        model = WaveFlow(config, checkpoint_dir)
        model.build(training=False)

        # Run model inference.
        model.benchmark()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Create parser. The description names what this script actually does:
    # it benchmarks WaveFlow inference rather than synthesizing with WaveNet.
    parser = jsonargparse.ArgumentParser(
        description="Benchmark inference speed of the WaveFlow model",
        formatter_class='default_argparse')
    add_options_to_parser(parser)
    utils.add_config_options_to_parser(parser)

    # Parse argument from both command line and yaml config file.
    # For conflicting updates to the same field,
    # the preceding update will be overwritten by the following one.
    config = parser.parse_args()
    benchmark(config)
|
|
@ -1,24 +0,0 @@
|
||||||
valid_size: 16
|
|
||||||
segment_length: 16000
|
|
||||||
sample_rate: 22050
|
|
||||||
fft_window_shift: 256
|
|
||||||
fft_window_size: 1024
|
|
||||||
fft_size: 1024
|
|
||||||
mel_bands: 80
|
|
||||||
mel_fmin: 0.0
|
|
||||||
mel_fmax: 8000.0
|
|
||||||
|
|
||||||
seed: 123
|
|
||||||
learning_rate: 0.0002
|
|
||||||
batch_size: 8
|
|
||||||
test_every: 2000
|
|
||||||
save_every: 5000
|
|
||||||
max_iterations: 2000000
|
|
||||||
|
|
||||||
sigma: 1.0
|
|
||||||
n_flows: 8
|
|
||||||
n_group: 16
|
|
||||||
n_layers: 8
|
|
||||||
n_channels: 64
|
|
||||||
kernel_h: 3
|
|
||||||
kernel_w: 3
|
|
|
@ -4,7 +4,6 @@ import librosa
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from paddle import fluid
|
from paddle import fluid
|
||||||
|
|
||||||
import utils
|
|
||||||
from parakeet.datasets import ljspeech
|
from parakeet.datasets import ljspeech
|
||||||
from parakeet.data import dataset
|
from parakeet.data import dataset
|
||||||
from parakeet.data.batch import SpecBatcher, WavBatcher
|
from parakeet.data.batch import SpecBatcher, WavBatcher
|
||||||
|
@ -12,8 +11,6 @@ from parakeet.data.datacargo import DataCargo
|
||||||
from parakeet.data.sampler import DistributedSampler, BatchSampler
|
from parakeet.data.sampler import DistributedSampler, BatchSampler
|
||||||
from scipy.io.wavfile import read
|
from scipy.io.wavfile import read
|
||||||
|
|
||||||
MAX_WAV_VALUE = 32768.0
|
|
||||||
|
|
||||||
|
|
||||||
class Dataset(ljspeech.LJSpeech):
|
class Dataset(ljspeech.LJSpeech):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
|
@ -78,10 +75,9 @@ class Subset(dataset.Dataset):
|
||||||
audio = np.pad(audio, (0, segment_length - audio.shape[0]),
|
audio = np.pad(audio, (0, segment_length - audio.shape[0]),
|
||||||
mode='constant', constant_values=0)
|
mode='constant', constant_values=0)
|
||||||
|
|
||||||
# Normalize audio.
|
# Normalize audio to the [-1, 1] range.
|
||||||
audio = audio.astype(np.float32) / MAX_WAV_VALUE
|
audio = audio.astype(np.float32) / 32768.0
|
||||||
mel = self.get_mel(audio)
|
mel = self.get_mel(audio)
|
||||||
#print("mel = {}, dtype {}, shape {}".format(mel, mel.dtype, mel.shape))
|
|
||||||
|
|
||||||
return audio, mel
|
return audio, mel
|
||||||
|
|
||||||
|
|
|
@ -1,3 +0,0 @@
|
||||||
paddlepaddle-gpu==1.6.1.post97
|
|
||||||
tensorboardX==1.9
|
|
||||||
librosa==0.7.1
|
|
|
@ -14,8 +14,6 @@ import slurm
|
||||||
import utils
|
import utils
|
||||||
from waveflow import WaveFlow
|
from waveflow import WaveFlow
|
||||||
|
|
||||||
MAXIMUM_SAVE_TIME = 10 * 60
|
|
||||||
|
|
||||||
|
|
||||||
def add_options_to_parser(parser):
|
def add_options_to_parser(parser):
|
||||||
parser.add_argument('--model', type=str, default='waveflow',
|
parser.add_argument('--model', type=str, default='waveflow',
|
||||||
|
@ -35,8 +33,6 @@ def add_options_to_parser(parser):
|
||||||
"default to load the latest checkpoint"))
|
"default to load the latest checkpoint"))
|
||||||
parser.add_argument('--checkpoint', type=str, default=None,
|
parser.add_argument('--checkpoint', type=str, default=None,
|
||||||
help="path of the checkpoint to load")
|
help="path of the checkpoint to load")
|
||||||
parser.add_argument('--slurm', type=bool, default=False,
|
|
||||||
help="whether you are using slurm to submit training jobs")
|
|
||||||
|
|
||||||
|
|
||||||
def train(config):
|
def train(config):
|
||||||
|
@ -85,13 +81,6 @@ def train(config):
|
||||||
else:
|
else:
|
||||||
iteration = int(config.checkpoint.split('/')[-1].split('-')[-1])
|
iteration = int(config.checkpoint.split('/')[-1].split('-')[-1])
|
||||||
|
|
||||||
# Get restart command if using slurm.
|
|
||||||
if config.slurm:
|
|
||||||
resume_command, death_time = slurm.restart_command()
|
|
||||||
if rank == 0:
|
|
||||||
print("Restart command:", " ".join(resume_command))
|
|
||||||
done = False
|
|
||||||
|
|
||||||
while iteration < config.max_iterations:
|
while iteration < config.max_iterations:
|
||||||
# Run one single training step.
|
# Run one single training step.
|
||||||
model.train_step(iteration)
|
model.train_step(iteration)
|
||||||
|
@ -102,20 +91,6 @@ def train(config):
|
||||||
# Run validation step.
|
# Run validation step.
|
||||||
model.valid_step(iteration)
|
model.valid_step(iteration)
|
||||||
|
|
||||||
# Check whether reaching the time limit.
|
|
||||||
if config.slurm:
|
|
||||||
done = (death_time is not None and death_time - time.time() <
|
|
||||||
MAXIMUM_SAVE_TIME)
|
|
||||||
|
|
||||||
if rank == 0 and done:
|
|
||||||
print("Saving progress before exiting.")
|
|
||||||
model.save(iteration)
|
|
||||||
|
|
||||||
print("Running restart command:", " ".join(resume_command))
|
|
||||||
# Submit restart command.
|
|
||||||
subprocess.check_call(resume_command)
|
|
||||||
break
|
|
||||||
|
|
||||||
if rank == 0 and iteration % config.save_every == 0:
|
if rank == 0 and iteration % config.save_every == 0:
|
||||||
# Save parameters.
|
# Save parameters.
|
||||||
model.save(iteration)
|
model.save(iteration)
|
||||||
|
|
|
@ -57,27 +57,6 @@ def add_config_options_to_parser(parser):
|
||||||
parser.add_argument('--config', action=jsonargparse.ActionConfigFile)
|
parser.add_argument('--config', action=jsonargparse.ActionConfigFile)
|
||||||
|
|
||||||
|
|
||||||
def pad_to_size(array, length, pad_with=0.0):
    """
    Pad an array on the first (length) axis to a given length.

    Args:
        array: NumPy array or array-like; only its first axis is extended.
        length: target size of the first axis.
        pad_with: constant value used to fill the added entries.

    Returns:
        A new array whose first axis has size ``length``; all other axes
        are left unchanged.

    Raises:
        AssertionError: if ``array`` is already longer than ``length``.
    """
    # Accept plain sequences as well as ndarrays.
    array = np.asarray(array)

    padding = length - array.shape[0]
    assert padding >= 0, "Padding required was less than zero"

    # Pad only at the end of axis 0; every other axis gets (0, 0).
    paddings = [(0, 0)] * array.ndim
    paddings[0] = (0, padding)

    return np.pad(array, paddings, mode='constant', constant_values=pad_with)
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_context_size(config):
    """Compute the context size and store it on ``config``.

    The dilation pattern ``config.dilation_block`` is repeated cyclically
    until ``config.layers`` values have been taken; the context size is
    the sum of those dilations plus one.

    Args:
        config: object with ``dilation_block`` (iterable of ints) and
            ``layers`` (int) attributes; ``context_size`` is set on it.

    Returns:
        The computed context size (also stored as ``config.context_size``).
    """
    dilations = list(
        itertools.islice(
            itertools.cycle(config.dilation_block), config.layers))
    config.context_size = sum(dilations) + 1
    print("Context size is", config.context_size)
    return config.context_size
|
|
||||||
|
|
||||||
|
|
||||||
def load_latest_checkpoint(checkpoint_dir, rank=0):
|
def load_latest_checkpoint(checkpoint_dir, rank=0):
|
||||||
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
|
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
|
||||||
# Create checkpoint index file if not exist.
|
# Create checkpoint index file if not exist.
|
||||||
|
|
|
@ -2,11 +2,10 @@ import itertools
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
#import librosa
|
|
||||||
from scipy.io.wavfile import write
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import paddle.fluid.dygraph as dg
|
import paddle.fluid.dygraph as dg
|
||||||
from paddle import fluid
|
from paddle import fluid
|
||||||
|
from scipy.io.wavfile import write
|
||||||
|
|
||||||
import utils
|
import utils
|
||||||
from data import LJSpeech
|
from data import LJSpeech
|
||||||
|
@ -29,18 +28,6 @@ class WaveFlow():
|
||||||
self.trainloader = dataset.trainloader
|
self.trainloader = dataset.trainloader
|
||||||
self.validloader = dataset.validloader
|
self.validloader = dataset.validloader
|
||||||
|
|
||||||
# if self.rank == 0:
|
|
||||||
# for i, (audios, mels) in enumerate(self.validloader()):
|
|
||||||
# print("audios {}, mels {}".format(audios.dtype, mels.dtype))
|
|
||||||
# print("{}: rank {}, audios {}, mels {}".format(
|
|
||||||
# i, self.rank, audios.shape, mels.shape))
|
|
||||||
#
|
|
||||||
# for i, (audios, mels) in enumerate(self.trainloader):
|
|
||||||
# print("{}: rank {}, audios {}, mels {}".format(
|
|
||||||
# i, self.rank, audios.shape, mels.shape))
|
|
||||||
#
|
|
||||||
# exit()
|
|
||||||
|
|
||||||
waveflow = WaveFlowModule("waveflow", config)
|
waveflow = WaveFlowModule("waveflow", config)
|
||||||
|
|
||||||
# Dry run once to create and initialize all necessary parameters.
|
# Dry run once to create and initialize all necessary parameters.
|
||||||
|
@ -96,8 +83,6 @@ class WaveFlow():
|
||||||
else:
|
else:
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
current_lr = self.optimizer._learning_rate
|
|
||||||
|
|
||||||
self.optimizer.minimize(loss, parameter_list=self.waveflow.parameters())
|
self.optimizer.minimize(loss, parameter_list=self.waveflow.parameters())
|
||||||
self.waveflow.clear_gradients()
|
self.waveflow.clear_gradients()
|
||||||
|
|
||||||
|
@ -113,7 +98,6 @@ class WaveFlow():
|
||||||
|
|
||||||
tb = self.tb_logger
|
tb = self.tb_logger
|
||||||
tb.add_scalar("Train-Loss-Rank-0", loss_val, iteration)
|
tb.add_scalar("Train-Loss-Rank-0", loss_val, iteration)
|
||||||
tb.add_scalar("Learning-Rate", current_lr, iteration)
|
|
||||||
|
|
||||||
@dg.no_grad
|
@dg.no_grad
|
||||||
def valid_step(self, iteration):
|
def valid_step(self, iteration):
|
||||||
|
@ -161,34 +145,44 @@ class WaveFlow():
|
||||||
if sample is not None:
|
if sample is not None:
|
||||||
mels_list = [mels_list[sample]]
|
mels_list = [mels_list[sample]]
|
||||||
|
|
||||||
audio_times = []
|
|
||||||
inf_times = []
|
|
||||||
for sample, mel in enumerate(mels_list):
|
for sample, mel in enumerate(mels_list):
|
||||||
filename = "{}/valid_{}.wav".format(output, sample)
|
filename = "{}/valid_{}.wav".format(output, sample)
|
||||||
print("Synthesize sample {}, save as {}".format(sample, filename))
|
print("Synthesize sample {}, save as {}".format(sample, filename))
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
audio = self.waveflow.synthesize(mel)
|
audio = self.waveflow.synthesize(mel, sigma=self.config.sigma)
|
||||||
syn_time = time.time() - start_time
|
syn_time = time.time() - start_time
|
||||||
|
|
||||||
audio_time = audio.shape[0] / 22050
|
audio = audio[0]
|
||||||
print("audio time {}, synthesis time {}, speedup: {}".format(
|
audio_time = audio.shape[0] / self.config.sample_rate
|
||||||
audio_time, syn_time, audio_time / syn_time))
|
print("audio time {:.4f}, synthesis time {:.4f}".format(
|
||||||
|
audio_time, syn_time))
|
||||||
|
|
||||||
#librosa.output.write_wav(filename, syn_audio,
|
# Denormalize audio from [-1, 1] to [-32768, 32768] int16 range.
|
||||||
# sr=config.sample_rate)
|
|
||||||
audio = audio.numpy() * 32768.0
|
audio = audio.numpy() * 32768.0
|
||||||
audio = audio.astype('int16')
|
audio = audio.astype('int16')
|
||||||
write(filename, config.sample_rate, audio)
|
write(filename, config.sample_rate, audio)
|
||||||
|
|
||||||
audio_times.append(audio_time)
|
@dg.no_grad
|
||||||
inf_times.append(syn_time)
|
def benchmark(self):
|
||||||
|
self.waveflow.eval()
|
||||||
|
|
||||||
total_audio = sum(audio_times)
|
mels_list = [mels for _, mels in self.validloader()]
|
||||||
total_inf = sum(inf_times)
|
mel = fluid.layers.concat(mels_list, axis=2)
|
||||||
|
mel = mel[:, :, :864]
|
||||||
|
batch_size = 8
|
||||||
|
mel = fluid.layers.expand(mel, [batch_size, 1, 1])
|
||||||
|
|
||||||
print("Total audio: {}, total inf time {}, speedup: {}".format(
|
for i in range(10):
|
||||||
total_audio, total_inf, total_audio / total_inf))
|
start_time = time.time()
|
||||||
|
audio = self.waveflow.synthesize(mel, sigma=self.config.sigma)
|
||||||
|
print("audio.shape = ", audio.shape)
|
||||||
|
syn_time = time.time() - start_time
|
||||||
|
|
||||||
|
audio_time = audio.shape[1] * batch_size / self.config.sample_rate
|
||||||
|
print("audio time {:.4f}, synthesis time {:.4f}".format(
|
||||||
|
audio_time, syn_time))
|
||||||
|
print("{} X real-time".format(audio_time / syn_time))
|
||||||
|
|
||||||
def save(self, iteration):
|
def save(self, iteration):
|
||||||
utils.save_latest_parameters(self.checkpoint_dir, iteration,
|
utils.save_latest_parameters(self.checkpoint_dir, iteration,
|
||||||
|
|
|
@ -23,7 +23,6 @@ def set_param_attr(layer, c_in=1):
|
||||||
|
|
||||||
def unfold(x, n_group):
    """Split the last axis of ``x`` into ``[length // n_group, n_group]``.

    Args:
        x: object exposing a ``shape`` sequence and accepted by
            ``fluid.layers.reshape``.
        n_group: size of the new trailing axis.

    Returns:
        The result of ``fluid.layers.reshape`` with shape
        ``x.shape[:-1] + [length // n_group, n_group]``.
    """
    length = x.shape[-1]
    # NOTE(review): floor division is used, so this presumably expects
    # ``length`` to be a multiple of ``n_group`` — confirm at call sites.
    # ``list(...)`` keeps the concatenation valid if ``shape`` is a tuple.
    new_shape = list(x.shape[:-1]) + [length // n_group, n_group]
    return fluid.layers.reshape(x, new_shape)
|
||||||
|
|
||||||
|
@ -192,13 +191,53 @@ class Flow(dg.Layer):
|
||||||
|
|
||||||
return self.end(output)
|
return self.end(output)
|
||||||
|
|
||||||
|
def infer(self, audio, mel, queues):
|
||||||
|
audio = self.start(audio)
|
||||||
|
|
||||||
def debug(x, msg):
|
for i in range(self.n_layers):
|
||||||
y = x.numpy()
|
dilation_h = self.dilation_h_list[i]
|
||||||
print(msg + " :\n", y)
|
dilation_w = 2 ** i
|
||||||
print("shape: ", y.shape)
|
|
||||||
print("dtype: ", y.dtype)
|
state_size = dilation_h * (self.kernel_h - 1)
|
||||||
print("")
|
queue = queues[i]
|
||||||
|
|
||||||
|
if len(queue) == 0:
|
||||||
|
for j in range(state_size):
|
||||||
|
queue.append(fluid.layers.zeros_like(audio))
|
||||||
|
|
||||||
|
state = queue[0:state_size]
|
||||||
|
state = fluid.layers.concat([*state, audio], axis=2)
|
||||||
|
|
||||||
|
queue.pop(0)
|
||||||
|
queue.append(audio)
|
||||||
|
|
||||||
|
# Pad height dim (n_group): causal convolution
|
||||||
|
# Pad width dim (time): dilated non-causal convolution
|
||||||
|
pad_top, pad_bottom = 0, 0
|
||||||
|
pad_left = int((self.kernel_w-1) * dilation_w / 2)
|
||||||
|
pad_right = int((self.kernel_w-1) * dilation_w / 2)
|
||||||
|
state = fluid.layers.pad2d(state,
|
||||||
|
paddings=[pad_top, pad_bottom, pad_left, pad_right])
|
||||||
|
|
||||||
|
hidden = self.in_layers[i](state)
|
||||||
|
cond_hidden = self.cond_layers[i](mel)
|
||||||
|
in_acts = hidden + cond_hidden
|
||||||
|
out_acts = fluid.layers.tanh(in_acts[:, :self.n_channels, :]) * \
|
||||||
|
fluid.layers.sigmoid(in_acts[:, self.n_channels:, :])
|
||||||
|
res_skip_acts = self.res_skip_layers[i](out_acts)
|
||||||
|
|
||||||
|
if i < self.n_layers - 1:
|
||||||
|
audio += res_skip_acts[:, :self.n_channels, :, :]
|
||||||
|
skip_acts = res_skip_acts[:, self.n_channels:, :, :]
|
||||||
|
else:
|
||||||
|
skip_acts = res_skip_acts
|
||||||
|
|
||||||
|
if i == 0:
|
||||||
|
output = skip_acts
|
||||||
|
else:
|
||||||
|
output += skip_acts
|
||||||
|
|
||||||
|
return self.end(output)
|
||||||
|
|
||||||
|
|
||||||
class WaveFlowModule(dg.Layer):
|
class WaveFlowModule(dg.Layer):
|
||||||
|
@ -206,7 +245,9 @@ class WaveFlowModule(dg.Layer):
|
||||||
super(WaveFlowModule, self).__init__(name_scope)
|
super(WaveFlowModule, self).__init__(name_scope)
|
||||||
self.n_flows = config.n_flows
|
self.n_flows = config.n_flows
|
||||||
self.n_group = config.n_group
|
self.n_group = config.n_group
|
||||||
|
self.n_layers = config.n_layers
|
||||||
assert self.n_group % 2 == 0
|
assert self.n_group % 2 == 0
|
||||||
|
assert self.n_flows % 2 == 0
|
||||||
|
|
||||||
self.conditioner = Conditioner(self.full_name())
|
self.conditioner = Conditioner(self.full_name())
|
||||||
self.flows = []
|
self.flows = []
|
||||||
|
@ -215,14 +256,16 @@ class WaveFlowModule(dg.Layer):
|
||||||
self.flows.append(flow)
|
self.flows.append(flow)
|
||||||
self.add_sublayer("flow_{}".format(i), flow)
|
self.add_sublayer("flow_{}".format(i), flow)
|
||||||
|
|
||||||
self.perms = [[15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
|
self.perms = []
|
||||||
[15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
|
half = self.n_group // 2
|
||||||
[15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
|
for i in range(self.n_flows):
|
||||||
[15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
|
perm = list(range(self.n_group))
|
||||||
[7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8],
|
if i < self.n_flows // 2:
|
||||||
[7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8],
|
perm = perm[::-1]
|
||||||
[7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8],
|
else:
|
||||||
[7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8]]
|
perm[:half] = reversed(perm[:half])
|
||||||
|
perm[half:] = reversed(perm[half:])
|
||||||
|
self.perms.append(perm)
|
||||||
|
|
||||||
def forward(self, audio, mel):
|
def forward(self, audio, mel):
|
||||||
mel = self.conditioner(mel)
|
mel = self.conditioner(mel)
|
||||||
|
@ -266,19 +309,13 @@ class WaveFlowModule(dg.Layer):
|
||||||
return z, log_s_list
|
return z, log_s_list
|
||||||
|
|
||||||
def synthesize(self, mel, sigma=1.0):
|
def synthesize(self, mel, sigma=1.0):
|
||||||
#debug(mel, "mel")
|
|
||||||
mel = self.conditioner.infer(mel)
|
mel = self.conditioner.infer(mel)
|
||||||
#debug(mel, "mel after conditioner")
|
|
||||||
|
|
||||||
# From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group]
|
# From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group]
|
||||||
mel = fluid.layers.transpose(unfold(mel, self.n_group), [0, 1, 3, 2])
|
mel = fluid.layers.transpose(unfold(mel, self.n_group), [0, 1, 3, 2])
|
||||||
#debug(mel, "after group")
|
|
||||||
|
|
||||||
audio = fluid.layers.gaussian_random(
|
audio = fluid.layers.gaussian_random(
|
||||||
shape=[mel.shape[0], 1, mel.shape[2], mel.shape[3]], std=sigma)
|
shape=[mel.shape[0], 1, mel.shape[2], mel.shape[3]], std=sigma)
|
||||||
|
|
||||||
#debug(audio, "audio")
|
|
||||||
|
|
||||||
for i in reversed(range(self.n_flows)):
|
for i in reversed(range(self.n_flows)):
|
||||||
# Permute over the height dimension.
|
# Permute over the height dimension.
|
||||||
audio_slices = [audio[:, :, j, :] for j in self.perms[i]]
|
audio_slices = [audio[:, :, j, :] for j in self.perms[i]]
|
||||||
|
@ -287,34 +324,28 @@ class WaveFlowModule(dg.Layer):
|
||||||
mel = fluid.layers.stack(mel_slices, axis=2)
|
mel = fluid.layers.stack(mel_slices, axis=2)
|
||||||
|
|
||||||
audio_list = []
|
audio_list = []
|
||||||
audio_0 = audio[:, :, :1, :]
|
audio_0 = audio[:, :, 0:1, :]
|
||||||
audio_list.append(audio_0)
|
audio_list.append(audio_0)
|
||||||
|
audio_h = audio_0
|
||||||
|
queues = [[] for _ in range(self.n_layers)]
|
||||||
|
|
||||||
for h in range(1, self.n_group):
|
for h in range(1, self.n_group):
|
||||||
# inputs: [bs, 1, h, time/n_group]
|
inputs = audio_h
|
||||||
inputs = fluid.layers.concat(audio_list, axis=2)
|
conds = mel[:, :, h:(h+1), :]
|
||||||
conds = mel[:, :, 1:(h+1), :]
|
outputs = self.flows[i].infer(inputs, conds, queues)
|
||||||
outputs = self.flows[i](inputs, conds)
|
|
||||||
|
|
||||||
log_s = outputs[:, :1, (h-1):h, :]
|
log_s = outputs[:, 0:1, :, :]
|
||||||
b = outputs[:, 1:, (h-1):h, :]
|
b = outputs[:, 1:, :, :]
|
||||||
audio_h = (audio[:, :, h:(h+1), :] - b) / fluid.layers.exp(log_s)
|
audio_h = (audio[:, :, h:(h+1), :] - b) / \
|
||||||
|
fluid.layers.exp(log_s)
|
||||||
audio_list.append(audio_h)
|
audio_list.append(audio_h)
|
||||||
|
|
||||||
audio = fluid.layers.concat(audio_list, axis=2)
|
audio = fluid.layers.concat(audio_list, axis=2)
|
||||||
#print("audio.shape =", audio.shape)
|
|
||||||
|
|
||||||
# Assume batch size = 1
|
# audio: [bs, n_group, time/n_group]
|
||||||
# audio: [n_group, time/n_group]
|
audio = fluid.layers.squeeze(audio, [1])
|
||||||
audio = fluid.layers.squeeze(audio, [0, 1])
|
# audio: [bs, time]
|
||||||
# audio: [time]
|
|
||||||
audio = fluid.layers.reshape(
|
audio = fluid.layers.reshape(
|
||||||
fluid.layers.transpose(audio, [1, 0]), [-1])
|
fluid.layers.transpose(audio, [0, 2, 1]), [audio.shape[0], -1])
|
||||||
#print("audio.shape =", audio.shape)
|
|
||||||
|
|
||||||
return audio
|
return audio
|
||||||
|
|
||||||
def start_new_sequence(self):
|
|
||||||
for layer in self.sublayers():
|
|
||||||
if isinstance(layer, conv.Conv1D):
|
|
||||||
layer.start_new_sequence()
|
|
||||||
|
|
Loading…
Reference in New Issue