hide fastspeech, deepvoice3, clarinet temporarily till they are updated
This commit is contained in:
parent
3ca037453e
commit
e87bfb7d05
|
@ -1,158 +0,0 @@
|
|||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
from paddle import distribution as D
|
||||
|
||||
from parakeet.models.wavenet import WaveNet, UpsampleNet, crop
|
||||
|
||||
__all__ = ["Clarinet"]
|
||||
|
||||
class ParallelWaveNet(nn.LayerList):
|
||||
def __init__(self, n_loops, n_layers, residual_channels, condition_dim,
|
||||
filter_size):
|
||||
"""ParallelWaveNet, an inverse autoregressive flow model, it contains several flows(WaveNets).
|
||||
|
||||
Args:
|
||||
n_loops (List[int]): `n_loop` for each flow.
|
||||
n_layers (List[int]): `n_layer` for each flow.
|
||||
residual_channels (int): `residual_channels` for every flow.
|
||||
condition_dim (int): `condition_dim` for every flow.
|
||||
filter_size (int): `filter_size` for every flow.
|
||||
"""
|
||||
super(ParallelWaveNet, self).__init__()
|
||||
for n_loop, n_layer in zip(n_loops, n_layers):
|
||||
# teacher's log_scale_min does not matter herem, -100 is a dummy value
|
||||
self.append(
|
||||
WaveNet(n_loop, n_layer, residual_channels, 3, condition_dim,
|
||||
filter_size, "mog", -100.0))
|
||||
|
||||
def forward(self, z, condition=None):
|
||||
"""Transform a random noise sampled from a standard Gaussian distribution into sample from the target distribution. And output the mean and log standard deviation of the output distribution.
|
||||
|
||||
Args:
|
||||
z (Variable): shape(B, T), random noise sampled from a standard gaussian disribution.
|
||||
condition (Variable, optional): shape(B, F, T), dtype float, the upsampled condition. Defaults to None.
|
||||
|
||||
Returns:
|
||||
(z, out_mu, out_log_std)
|
||||
z (Variable): shape(B, T), dtype float, transformed noise, it is the synthesized waveform.
|
||||
out_mu (Variable): shape(B, T), dtype float, means of the output distributions.
|
||||
out_log_std (Variable): shape(B, T), dtype float, log standard deviations of the output distributions.
|
||||
"""
|
||||
for i, flow in enumerate(self):
|
||||
theta = flow(z, condition) # w, mu, log_std [0: T]
|
||||
w, mu, log_std = paddle.chunk(theta, 3, axis=-1) # (B, T, 1) for each
|
||||
mu = paddle.squeeze(mu, -1) #[0: T]
|
||||
log_std = paddle.squeeze(log_std, -1) #[0: T]
|
||||
z = z * paddle.exp(log_std) + mu #[0: T]
|
||||
|
||||
if i == 0:
|
||||
out_mu = mu
|
||||
out_log_std = log_std
|
||||
else:
|
||||
out_mu = out_mu * paddle.exp(log_std) + mu
|
||||
out_log_std += log_std
|
||||
|
||||
return z, out_mu, out_log_std
|
||||
|
||||
|
||||
# Gaussian IAF model
|
||||
class Clarinet(nn.Layer):
|
||||
def __init__(self, encoder, teacher, student, stft,
|
||||
min_log_scale=-6.0, lmd=4.0):
|
||||
"""Clarinet model. Conditional Parallel WaveNet.
|
||||
|
||||
Args:
|
||||
encoder (UpsampleNet): an UpsampleNet to upsample mel spectrogram.
|
||||
teacher (WaveNet): a WaveNet, the teacher.
|
||||
student (ParallelWaveNet): a ParallelWaveNet model, the student.
|
||||
stft (STFT): a STFT model to perform differentiable stft transform.
|
||||
min_log_scale (float, optional): used only for computing loss, the minimal value of log standard deviation of the output distribution of both the teacher and the student . Defaults to -6.0.
|
||||
lmd (float, optional): weight for stft loss. Defaults to 4.0.
|
||||
"""
|
||||
super(Clarinet, self).__init__()
|
||||
self.encoder = encoder
|
||||
self.teacher = teacher
|
||||
self.student = student
|
||||
self.stft = stft
|
||||
|
||||
self.lmd = lmd
|
||||
self.min_log_scale = min_log_scale
|
||||
|
||||
def forward(self, audio, mel, audio_start, clip_kl=True):
|
||||
"""Compute loss of Clarinet model.
|
||||
|
||||
Args:
|
||||
audio (Variable): shape(B, T_audio), dtype flaot32, ground truth waveform.
|
||||
mel (Variable): shape(B, F, T_mel), dtype flaot32, condition(mel spectrogram here).
|
||||
audio_start (Variable): shape(B, ), dtype int64, audio starts positions.
|
||||
clip_kl (bool, optional): whether to clip kl_loss by maximum=100. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Dict(str, Variable)
|
||||
loss (Variable): shape(1, ), dtype flaot32, total loss.
|
||||
kl (Variable): shape(1, ), dtype flaot32, kl divergence between the teacher's output distribution and student's output distribution.
|
||||
regularization (Variable): shape(1, ), dtype flaot32, a regularization term of the KL divergence.
|
||||
spectrogram_frame_loss (Variable): shape(1, ), dytpe: float, stft loss, the L1-distance of the magnitudes of the spectrograms of the ground truth waveform and synthesized waveform.
|
||||
"""
|
||||
batch_size, audio_length = audio.shape # audio clip's length
|
||||
|
||||
z = paddle.randn(audio.shape)
|
||||
condition = self.encoder(mel) # (B, C, T)
|
||||
condition_slice = crop(condition, audio_start, audio_length)
|
||||
|
||||
x, s_means, s_scales = self.student(z, condition_slice) # all [0: T]
|
||||
s_means = s_means[:, 1:] # (B, T-1), time steps [1: T]
|
||||
s_scales = s_scales[:, 1:] # (B, T-1), time steps [1: T]
|
||||
s_clipped_scales = paddle.clip(s_scales, self.min_log_scale, 100.)
|
||||
|
||||
# teacher outputs single gaussian
|
||||
y = self.teacher(x[:, :-1], condition_slice[:, :, 1:])
|
||||
_, t_means, t_scales = paddle.chunk(y, 3, axis=-1) # time steps [1: T]
|
||||
t_means = paddle.squeeze(t_means, [-1]) # (B, T-1), time steps [1: T]
|
||||
t_scales = paddle.squeeze(t_scales, [-1]) # (B, T-1), time steps [1: T]
|
||||
t_clipped_scales = paddle.clip(t_scales, self.min_log_scale, 100.)
|
||||
|
||||
s_distribution = D.Normal(s_means, paddle.exp(s_clipped_scales))
|
||||
t_distribution = D.Normal(t_means, paddle.exp(t_clipped_scales))
|
||||
|
||||
# kl divergence loss, so we only need to sample once? no MC
|
||||
kl = s_distribution.kl_divergence(t_distribution)
|
||||
if clip_kl:
|
||||
kl = paddle.clip(kl, -100., 10.)
|
||||
# context size dropped
|
||||
kl = paddle.reduce_mean(kl[:, self.teacher.context_size:])
|
||||
# major diff here
|
||||
regularization = F.mse_loss(t_scales[:, self.teacher.context_size:],
|
||||
s_scales[:, self.teacher.context_size:])
|
||||
|
||||
# introduce information from real target
|
||||
spectrogram_frame_loss = F.mse_loss(
|
||||
self.stft.magnitude(audio), self.stft.magnitude(x))
|
||||
loss = kl + self.lmd * regularization + spectrogram_frame_loss
|
||||
loss_dict = {
|
||||
"loss": loss,
|
||||
"kl_divergence": kl,
|
||||
"regularization": regularization,
|
||||
"stft_loss": spectrogram_frame_loss
|
||||
}
|
||||
return loss_dict
|
||||
|
||||
@paddle.no_grad()
|
||||
def synthesis(self, mel):
|
||||
"""Synthesize waveform using the encoder and the student network.
|
||||
|
||||
Args:
|
||||
mel (Variable): shape(B, F, T_mel), the condition(mel spectrogram here).
|
||||
|
||||
Returns:
|
||||
Variable: shape(B, T_audio), the synthesized waveform. (T_audio = T_mel * upscale_factor, where upscale_factor is the `upscale_factor` of the encoder.)
|
||||
"""
|
||||
condition = self.encoder(mel)
|
||||
samples_shape = (condition.shape[0], condition.shape[-1])
|
||||
z = paddle.randn(samples_shape)
|
||||
x, s_means, s_scales = self.student(z, condition)
|
||||
return x
|
||||
|
||||
|
||||
# TODO(chenfeiyu): ClariNetLoss
|
|
@ -1,465 +0,0 @@
|
|||
import math
|
||||
import numpy as np
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
from paddle.nn import initializer as I
|
||||
|
||||
from parakeet.modules import positional_encoding as pe
|
||||
|
||||
__all__ = ["SpectraNet"]
|
||||
|
||||
class ConvBlock(nn.Layer):
|
||||
def __init__(self, in_channel, kernel_size, causal=False, has_bias=False,
|
||||
bias_dim=None, keep_prob=1.):
|
||||
super(ConvBlock, self).__init__()
|
||||
self.causal = causal
|
||||
self.keep_prob = keep_prob
|
||||
self.in_channel = in_channel
|
||||
self.has_bias = has_bias
|
||||
|
||||
std = math.sqrt(4 * keep_prob / (kernel_size * in_channel))
|
||||
padding = "valid" if causal else "same"
|
||||
conv = nn.Conv1D(in_channel, 2 * in_channel, (kernel_size, ),
|
||||
padding=padding,
|
||||
data_format="NLC",
|
||||
weight_attr=I.Normal(scale=std))
|
||||
self.conv = nn.utils.weight_norm(conv)
|
||||
if has_bias:
|
||||
std = math.sqrt(1 / bias_dim)
|
||||
self.bias_affine = nn.Linear(bias_dim, 2 * in_channel,
|
||||
weight_attr=I.Normal(scale=std))
|
||||
|
||||
def forward(self, input, bias=None, padding=None):
|
||||
"""
|
||||
input: input feature (B, T, C)
|
||||
padding: only used when using causal conv, we pad mannually
|
||||
"""
|
||||
input_dropped = F.dropout(input, 1. - self.keep_prob, training=self.training)
|
||||
if self.causal:
|
||||
assert padding is not None
|
||||
input_dropped = paddle.concat([padding, input_dropped], axis=1)
|
||||
hidden = self.conv(input_dropped)
|
||||
|
||||
if self.has_bias:
|
||||
assert bias is not None
|
||||
transformed_bias = F.softsign(self.bias_affine(bias))
|
||||
hidden_embedded = hidden + paddle.unsqueeze(transformed_bias, 1)
|
||||
else:
|
||||
hidden_embedded = hidden
|
||||
|
||||
# glu
|
||||
content, gate = paddle.chunk(hidden, 2, axis=-1)
|
||||
content = hidden_embedded[:, :, :self.in_channel]
|
||||
hidden = F.sigmoid(gate) * content
|
||||
|
||||
# # residual
|
||||
hidden = paddle.scale(input + hidden, math.sqrt(0.5))
|
||||
return hidden
|
||||
|
||||
|
||||
class AffineBlock1(nn.Layer):
|
||||
def __init__(self, in_channel, out_channel, has_bias=False, bias_dim=0):
|
||||
super(AffineBlock1, self).__init__()
|
||||
std = math.sqrt(1.0 / in_channel)
|
||||
affine = nn.Linear(in_channel, out_channel, weight_attr=I.Normal(scale=std))
|
||||
self.affine = nn.utils.weight_norm(affine, dim=-1)
|
||||
if has_bias:
|
||||
std = math.sqrt(1 / bias_dim)
|
||||
self.bias_affine = nn.Linear(bias_dim, out_channel,
|
||||
weight_attr=I.Normal(scale=std))
|
||||
|
||||
self.has_bias = has_bias
|
||||
self.bias_dim = bias_dim
|
||||
|
||||
def forward(self, input, bias=None):
|
||||
"""
|
||||
input -> (affine + weight_norm) ->hidden
|
||||
bias -> (affine) -> softsign -> transformed_bis
|
||||
hidden += transformed_bias
|
||||
"""
|
||||
hidden = self.affine(input)
|
||||
if self.has_bias:
|
||||
assert bias is not None
|
||||
transformed_bias = F.softsign(self.bias_affine(bias))
|
||||
hidden += paddle.unsqueeze(transformed_bias, 1)
|
||||
return hidden
|
||||
|
||||
|
||||
class AffineBlock2(nn.Layer):
|
||||
def __init__(self, in_channel, out_channel,
|
||||
has_bias=False, bias_dim=0, dropout=False, keep_prob=1.):
|
||||
super(AffineBlock2, self).__init__()
|
||||
if has_bias:
|
||||
std = math.sqrt(1 / bias_dim)
|
||||
self.bias_affine = nn.Linear(bias_dim, in_channel, weight_attr=I.Normal(scale=std))
|
||||
std = math.sqrt(1.0 / in_channel)
|
||||
affine = nn.Linear(in_channel, out_channel, weight_attr=I.Normal(scale=std))
|
||||
self.affine = nn.utils.weight_norm(affine, dim=-1)
|
||||
|
||||
self.has_bias = has_bias
|
||||
self.bias_dim = bias_dim
|
||||
self.dropout = dropout
|
||||
self.keep_prob = keep_prob
|
||||
|
||||
def forward(self, input, bias=None):
|
||||
"""
|
||||
input -> (dropout) ->hidden
|
||||
bias -> (affine) -> softsign -> transformed_bis
|
||||
hidden += transformed_bias
|
||||
hidden -> (affine + weight_norm) -> relu -> hidden
|
||||
"""
|
||||
hidden = input
|
||||
if self.dropout:
|
||||
hidden = F.dropout(hidden, 1. - self.keep_prob, training=self.training)
|
||||
if self.has_bias:
|
||||
assert bias is not None
|
||||
transformed_bias = F.softsign(self.bias_affine(bias))
|
||||
hidden += paddle.unsqueeze(transformed_bias, 1)
|
||||
hidden = F.relu(self.affine(hidden))
|
||||
return hidden
|
||||
|
||||
|
||||
class Encoder(nn.Layer):
|
||||
def __init__(self, layers, in_channels, encoder_dim, kernel_size,
|
||||
has_bias=False, bias_dim=0, keep_prob=1.):
|
||||
super(Encoder, self).__init__()
|
||||
self.pre_affine = AffineBlock1(in_channels, encoder_dim, has_bias, bias_dim)
|
||||
self.convs = nn.LayerList([
|
||||
ConvBlock(encoder_dim, kernel_size, False, has_bias, bias_dim, keep_prob) \
|
||||
for _ in range(layers)])
|
||||
self.post_affine = AffineBlock1(encoder_dim, in_channels, has_bias, bias_dim)
|
||||
|
||||
def forward(self, char_embed, speaker_embed=None):
|
||||
hidden = self.pre_affine(char_embed, speaker_embed)
|
||||
for layer in self.convs:
|
||||
hidden = layer(hidden, speaker_embed)
|
||||
hidden = self.post_affine(hidden, speaker_embed)
|
||||
keys = hidden
|
||||
values = paddle.scale(char_embed + hidden, math.sqrt(0.5))
|
||||
return keys, values
|
||||
|
||||
|
||||
class AttentionBlock(nn.Layer):
|
||||
def __init__(self, attention_dim, input_dim, position_encoding_weight=1.,
|
||||
position_rate=1., reduction_factor=1, has_bias=False, bias_dim=0,
|
||||
keep_prob=1.):
|
||||
super(AttentionBlock, self).__init__()
|
||||
# positional encoding
|
||||
omega_default = position_rate / reduction_factor
|
||||
self.omega_default = omega_default
|
||||
# multispeaker case
|
||||
if has_bias:
|
||||
std = math.sqrt(1.0 / bias_dim)
|
||||
self.q_pos_affine = nn.Linear(bias_dim, 1, weight_attr=I.Normal(scale=std))
|
||||
self.k_pos_affine = nn.Linear(bias_dim, 1, weight_attr=I.Normal(scale=std))
|
||||
self.omega_initial = self.create_parameter(shape=[1],
|
||||
attr=I.Constant(value=omega_default))
|
||||
|
||||
# mind the fact that q, k, v have the same feature dimension
|
||||
# so we can init k_affine and q_affine's weight as the same matrix
|
||||
# to get a better init attention
|
||||
dtype = self.omega_initial.numpy().dtype
|
||||
init_weight = np.random.normal(size=(input_dim, attention_dim),
|
||||
scale=np.sqrt(1. / input_dim)).astype(dtype)
|
||||
# TODO(chenfeiyu): to report an issue, there is no such initializer
|
||||
#initializer = paddle.fluid.initializer.NumpyArrayInitializer(init_weight)
|
||||
# 3 affine transformation to project q, k, v into attention_dim
|
||||
q_affine = nn.Linear(input_dim, attention_dim)
|
||||
self.q_affine = nn.utils.weight_norm(q_affine, dim=-1)
|
||||
k_affine = nn.Linear(input_dim, attention_dim)
|
||||
self.k_affine = nn.utils.weight_norm(k_affine, dim=-1)
|
||||
|
||||
# better to use this, since NumpyInitializer does not support float64
|
||||
self.q_affine.weight.set_value(init_weight)
|
||||
self.k_affine.weight.set_value(init_weight)
|
||||
|
||||
std = np.sqrt(1.0 / input_dim)
|
||||
v_affine = nn.Linear(input_dim, attention_dim, weight_attr=I.Normal(scale=std))
|
||||
self.v_affine = nn.utils.weight_norm(v_affine, dim=-1)
|
||||
|
||||
std = np.sqrt(1.0 / attention_dim)
|
||||
out_affine = nn.Linear(attention_dim, input_dim, weight_attr=I.Normal(scale=std))
|
||||
self.out_affine = nn.utils.weight_norm(out_affine, dim=-1)
|
||||
|
||||
self.keep_prob = keep_prob
|
||||
self.has_bias = has_bias
|
||||
self.bias_dim = bias_dim
|
||||
self.attention_dim = attention_dim
|
||||
self.position_encoding_weight = position_encoding_weight
|
||||
|
||||
def forward(self, q, k, v, lengths, speaker_embed, start_index,
|
||||
force_monotonic=False, prev_coeffs=None, window=None):
|
||||
dtype = self.omega_initial.dtype
|
||||
# add position encoding as an inductive bias
|
||||
if self.has_bias: # multi-speaker model
|
||||
omega_q = 2 * F.sigmoid(
|
||||
paddle.squeeze(self.q_pos_affine(speaker_embed), -1))
|
||||
omega_k = 2 * self.omega_initial * F.sigmoid(paddle.squeeze(
|
||||
self.k_pos_affine(speaker_embed), -1))
|
||||
else: # single-speaker case
|
||||
batch_size = q.shape[0]
|
||||
omega_q = paddle.ones((batch_size, ), dtype=dtype)
|
||||
omega_k = paddle.ones((batch_size, ), dtype=dtype) * self.omega_default
|
||||
q += self.position_encoding_weight * pe.scalable_positional_encoding(start_index, q.shape[1], q.shape[-1], omega_q)
|
||||
k += self.position_encoding_weight * pe.scalable_positional_encoding(0, k.shape[1], k.shape[-1], omega_k)
|
||||
|
||||
|
||||
q, k, v = self.q_affine(q), self.k_affine(k), self.v_affine(v)
|
||||
activations = paddle.matmul(q, k, transpose_y=True)
|
||||
activations /= math.sqrt(self.attention_dim)
|
||||
|
||||
if self.training:
|
||||
# mask the <pad> parts from the encoder
|
||||
mask = paddle.fluid.layers.sequence_mask(lengths, dtype=dtype)
|
||||
attn_bias = paddle.scale(1. - mask, -1000)
|
||||
activations += paddle.unsqueeze(attn_bias, 1)
|
||||
elif force_monotonic:
|
||||
assert window is not None
|
||||
backward_step, forward_step = window
|
||||
T_enc = k.shape[1]
|
||||
batch_size, T_dec, _ = q.shape
|
||||
|
||||
# actually T_dec = 1 here
|
||||
alpha = paddle.fill_constant((batch_size, T_dec), value=0, dtype="int64") \
|
||||
if prev_coeffs is None \
|
||||
else paddle.argmax(prev_coeffs, axis=-1)
|
||||
backward = paddle.fluid.layers.sequence_mask(alpha - backward_step, maxlen=T_enc, dtype="bool")
|
||||
forward = paddle.fluid.layers.sequence_mask(alpha + forward_step, maxlen=T_enc, dtype="bool")
|
||||
mask = paddle.cast(paddle.logical_xor(backward, forward), activations.dtype)
|
||||
# print("mask's shape:", mask.shape)
|
||||
attn_bias = paddle.scale(1. - mask, -1000)
|
||||
activations += attn_bias
|
||||
|
||||
# softmax
|
||||
coefficients = F.softmax(activations, axis=-1)
|
||||
# context vector
|
||||
coefficients = F.dropout(coefficients, 1. - self.keep_prob, training=self.training)
|
||||
contexts = paddle.matmul(coefficients, v)
|
||||
# context normalization
|
||||
enc_lengths = paddle.cast(paddle.unsqueeze(lengths, axis=[1, 2]), contexts.dtype)
|
||||
contexts *= paddle.sqrt(enc_lengths)
|
||||
# out affine
|
||||
contexts = self.out_affine(contexts)
|
||||
return contexts, coefficients
|
||||
|
||||
|
||||
class Decoder(nn.Layer):
|
||||
def __init__(self, in_channels, reduction_factor, prenet_sizes,
|
||||
layers, kernel_size, attention_dim,
|
||||
position_encoding_weight=1., omega=1.,
|
||||
has_bias=False, bias_dim=0, keep_prob=1.):
|
||||
super(Decoder, self).__init__()
|
||||
# prenet-mind the difference of AffineBlock2 and AffineBlock1
|
||||
c_in = in_channels
|
||||
self.prenet = nn.LayerList()
|
||||
for i, c_out in enumerate(prenet_sizes):
|
||||
affine = AffineBlock2(c_in, c_out, has_bias, bias_dim, dropout=(i!=0), keep_prob=keep_prob)
|
||||
self.prenet.append(affine)
|
||||
c_in = c_out
|
||||
|
||||
# causal convolutions + multihop attention
|
||||
decoder_dim = prenet_sizes[-1]
|
||||
self.causal_convs = nn.LayerList()
|
||||
self.attention_blocks = nn.LayerList()
|
||||
for i in range(layers):
|
||||
conv = ConvBlock(decoder_dim, kernel_size, True, has_bias, bias_dim, keep_prob)
|
||||
attn = AttentionBlock(attention_dim, decoder_dim, position_encoding_weight, omega, reduction_factor, has_bias, bias_dim, keep_prob)
|
||||
self.causal_convs.append(conv)
|
||||
self.attention_blocks.append(attn)
|
||||
|
||||
# output mel spectrogram
|
||||
output_dim = reduction_factor * in_channels # r * mel_dim
|
||||
std = math.sqrt(1.0 / decoder_dim)
|
||||
out_affine = nn.Linear(decoder_dim, output_dim, weight_attr=I.Normal(scale=std))
|
||||
self.out_affine = nn.utils.weight_norm(out_affine, dim=-1)
|
||||
if has_bias:
|
||||
std = math.sqrt(1 / bias_dim)
|
||||
self.out_sp_affine = nn.Linear(bias_dim, output_dim, weight_attr=I.Normal(scale=std))
|
||||
|
||||
self.has_bias = has_bias
|
||||
self.kernel_size = kernel_size
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.decoder_dim = decoder_dim
|
||||
self.reduction_factor = reduction_factor
|
||||
self.out_channels = output_dim
|
||||
|
||||
def forward(self, inputs, keys, values, lengths, start_index, speaker_embed=None,
|
||||
state=None, force_monotonic_attention=None, coeffs=None, window=(0, 4)):
|
||||
hidden = inputs
|
||||
for layer in self.prenet:
|
||||
hidden = layer(hidden, speaker_embed)
|
||||
|
||||
attentions = [] # every layer of (B, T_dec, T_enc) attention
|
||||
final_state = [] # layers * (B, (k-1)d, C_dec)
|
||||
batch_size = inputs.shape[0]
|
||||
causal_padding_shape = (batch_size, self.kernel_size - 1, self.decoder_dim)
|
||||
|
||||
for i in range(len(self.causal_convs)):
|
||||
if state is None:
|
||||
padding = paddle.zeros(causal_padding_shape, dtype=inputs.dtype)
|
||||
else:
|
||||
padding = state[i]
|
||||
new_state = paddle.concat([padding, hidden], axis=1) # => to be used next step
|
||||
# causal conv, (B, T, C)
|
||||
hidden = self.causal_convs[i](hidden, speaker_embed, padding=padding)
|
||||
# attn
|
||||
prev_coeffs = None if coeffs is None else coeffs[i]
|
||||
force_monotonic = False if force_monotonic_attention is None else force_monotonic_attention[i]
|
||||
context, attention = self.attention_blocks[i](
|
||||
hidden, keys, values, lengths, speaker_embed,
|
||||
start_index, force_monotonic, prev_coeffs, window)
|
||||
# residual connextion (B, T_dec, C_dec)
|
||||
hidden = paddle.scale(hidden + context, math.sqrt(0.5))
|
||||
|
||||
attentions.append(attention) # layers * (B, T_dec, T_enc)
|
||||
# new state: shift a step, layers * (B, T, C)
|
||||
new_state = new_state[:, -(self.kernel_size - 1):, :]
|
||||
final_state.append(new_state)
|
||||
|
||||
# predict mel spectrogram (B, 1, T_dec, r * C_in)
|
||||
decoded = self.out_affine(hidden)
|
||||
if self.has_bias:
|
||||
decoded *= F.sigmoid(paddle.unsqueeze(self.out_sp_affine(speaker_embed), 1))
|
||||
return decoded, hidden, attentions, final_state
|
||||
|
||||
|
||||
class PostNet(nn.Layer):
|
||||
def __init__(self, layers, in_channels, postnet_dim, kernel_size, out_channels, upsample_factor, has_bias=False, bias_dim=0, keep_prob=1.):
|
||||
super(PostNet, self).__init__()
|
||||
self.pre_affine = AffineBlock1(in_channels, postnet_dim, has_bias, bias_dim)
|
||||
self.convs = nn.LayerList([
|
||||
ConvBlock(postnet_dim, kernel_size, False, has_bias, bias_dim, keep_prob) for _ in range(layers)
|
||||
])
|
||||
std = math.sqrt(1.0 / postnet_dim)
|
||||
post_affine = nn.Linear(postnet_dim, out_channels, weight_attr=I.Normal(scale=std))
|
||||
self.post_affine = nn.utils.weight_norm(post_affine, dim=-1)
|
||||
self.upsample_factor = upsample_factor
|
||||
|
||||
def forward(self, hidden, speaker_embed=None):
|
||||
hidden = self.pre_affine(hidden, speaker_embed)
|
||||
batch_size, time_steps, channels = hidden.shape # pylint: disable=unused-variable
|
||||
# NOTE: paddle.expand can only expand dimension whose size is 1
|
||||
hidden = paddle.expand(paddle.unsqueeze(hidden, 2), [-1, -1, self.upsample_factor, -1])
|
||||
hidden = paddle.reshape(hidden, [batch_size, -1, channels])
|
||||
for layer in self.convs:
|
||||
hidden = layer(hidden, speaker_embed)
|
||||
spec = self.post_affine(hidden)
|
||||
return spec
|
||||
|
||||
|
||||
class SpectraNet(nn.Layer):
|
||||
def __init__(self, char_embedding, speaker_embedding, encoder, decoder, postnet):
|
||||
super(SpectraNet, self).__init__()
|
||||
self.char_embedding = char_embedding
|
||||
self.speaker_embedding = speaker_embedding
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
self.postnet = postnet
|
||||
|
||||
def forward(self, text, text_lengths, speakers=None, mel=None, frame_lengths=None,
|
||||
force_monotonic_attention=None, window=None):
|
||||
# encode
|
||||
text_embed = self.char_embedding(text)# no stress embedding here
|
||||
speaker_embed = F.softsign(self.speaker_embedding(speakers)) if self.speaker_embedding is not None else None
|
||||
keys, values = self.encoder(text_embed, speaker_embed)
|
||||
|
||||
if mel is not None:
|
||||
return self.teacher_forced_train(keys, values, text_lengths, speaker_embed, mel)
|
||||
else:
|
||||
return self.inference(keys, values, text_lengths, speaker_embed, force_monotonic_attention, window)
|
||||
|
||||
def teacher_forced_train(self, keys, values, text_lengths, speaker_embed, mel):
|
||||
# build decoder inputs by shifting over by one frame and add all zero <start> frame
|
||||
# the mel input is downsampled by a reduction factor
|
||||
batch_size = mel.shape[0]
|
||||
mel_input = paddle.reshape(mel, (batch_size, -1, self.decoder.reduction_factor, self.decoder.in_channels))
|
||||
zero_frame = paddle.zeros((batch_size, 1, self.decoder.in_channels), dtype=mel.dtype)
|
||||
# downsample mel input as a regularization
|
||||
mel_input = paddle.concat([zero_frame, mel_input[:, :-1, -1, :]], axis=1)
|
||||
|
||||
# decoder
|
||||
decoded, hidden, attentions, final_state = self.decoder(mel_input, keys, values, text_lengths, 0, speaker_embed)
|
||||
attentions = paddle.stack(attentions) # (N, B, T_dec, T_encs)
|
||||
# unfold frames
|
||||
decoded = paddle.reshape(decoded, (batch_size, -1, self.decoder.in_channels))
|
||||
# postnet
|
||||
refined = self.postnet(hidden, speaker_embed)
|
||||
return decoded, refined, attentions, final_state
|
||||
|
||||
def spec_loss(self, decoded, input, num_frames=None):
|
||||
if num_frames is None:
|
||||
l1_loss = paddle.mean(paddle.abs(decoded - input))
|
||||
else:
|
||||
# mask the <pad> part of the decoder
|
||||
num_channels = decoded.shape[-1]
|
||||
l1_loss = paddle.abs(decoded - input)
|
||||
mask = paddle.fluid.layers.sequence_mask(num_frames, dtype=decoded.dtype)
|
||||
l1_loss *= paddle.unsqueeze(mask, axis=-1)
|
||||
l1_loss = paddle.sum(l1_loss) / paddle.scale(paddle.sum(mask), num_channels)
|
||||
return l1_loss
|
||||
|
||||
@paddle.no_grad()
|
||||
def inference(self, keys, values, text_lengths, speaker_embed,
|
||||
force_monotonic_attention, window):
|
||||
MAX_STEP = 500
|
||||
|
||||
# layer index of the first monotonic attention
|
||||
num_monotonic_attention_layers = sum(force_monotonic_attention)
|
||||
first_mono_attention_layer = 0
|
||||
if num_monotonic_attention_layers > 0:
|
||||
for i, item in enumerate(force_monotonic_attention):
|
||||
if item:
|
||||
first_mono_attention_layer = i
|
||||
break
|
||||
|
||||
# stop cond (if would be more complicated to support minibatch autoregressive decoding)
|
||||
# so we only supports batch_size == 0 in inference
|
||||
def should_continue(i, mel_input, outputs, hidden, attention, state, coeffs):
|
||||
T_enc = coeffs.shape[-1]
|
||||
attn_peak = paddle.argmax(coeffs[first_mono_attention_layer, 0, 0]) \
|
||||
if num_monotonic_attention_layers > 0 \
|
||||
else paddle.fill_constant([1], "int64", value=0)
|
||||
return i < MAX_STEP and paddle.reshape(attn_peak, [1]) < T_enc - 1
|
||||
|
||||
def loop_body(i, mel_input, outputs, hiddens, attentions, state=None, coeffs=None):
|
||||
# state is None coeffs is None for the first step
|
||||
decoded, hidden, new_coeffs, new_state = self.decoder(
|
||||
mel_input, keys, values, text_lengths, i, speaker_embed,
|
||||
state, force_monotonic_attention, coeffs, window)
|
||||
new_coeffs = paddle.stack(new_coeffs) # (N, B, T_dec=1, T_enc)
|
||||
|
||||
attentions.append(new_coeffs) # (N, B, T_dec=1, T_enc)
|
||||
outputs.append(decoded) # (B, T_dec=1, rC_mel)
|
||||
hiddens.append(hidden) # (B, T_dec=1, C_dec)
|
||||
|
||||
# slice the last frame out of r generated frames to be used as the input for the next step
|
||||
batch_size = mel_input.shape[0]
|
||||
frames = paddle.reshape(decoded, [batch_size, -1, self.decoder.reduction_factor, self.decoder.in_channels])
|
||||
input_frame = frames[:, :, -1, :]
|
||||
return (i + 1, input_frame, outputs, hiddens, attentions, new_state, new_coeffs)
|
||||
|
||||
i = 0
|
||||
batch_size = keys.shape[0]
|
||||
input_frame = paddle.zeros((batch_size, 1, self.decoder.in_channels), dtype=keys.dtype)
|
||||
outputs = []
|
||||
hiddens = []
|
||||
attentions = []
|
||||
loop_state = loop_body(i, input_frame, outputs, hiddens, attentions)
|
||||
|
||||
while should_continue(*loop_state):
|
||||
loop_state = loop_body(*loop_state)
|
||||
|
||||
outputs, hiddens, attention = loop_state[2], loop_state[3], loop_state[4]
|
||||
# concat decoder timesteps
|
||||
outputs = paddle.concat(outputs, axis=1)
|
||||
hiddens = paddle.concat(hiddens, axis=1)
|
||||
attention = paddle.concat(attention, axis=2)
|
||||
|
||||
# unfold frames
|
||||
outputs = paddle.reshape(outputs, (batch_size, -1, self.decoder.in_channels))
|
||||
|
||||
refined = self.postnet(hiddens, speaker_embed)
|
||||
return outputs, refined, attention
|
|
@ -1,13 +0,0 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
|
@ -1,113 +0,0 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import paddle.fluid.dygraph as dg
|
||||
import paddle.fluid as fluid
|
||||
from parakeet.models.transformer_tts.utils import *
|
||||
from parakeet.models.fastspeech.fft_block import FFTBlock
|
||||
|
||||
|
||||
class Decoder(dg.Layer):
|
||||
def __init__(self,
|
||||
len_max_seq,
|
||||
n_layers,
|
||||
n_head,
|
||||
d_k,
|
||||
d_q,
|
||||
d_model,
|
||||
d_inner,
|
||||
fft_conv1d_kernel,
|
||||
fft_conv1d_padding,
|
||||
dropout=0.1):
|
||||
"""Decoder layer of FastSpeech.
|
||||
|
||||
Args:
|
||||
len_max_seq (int): the max mel len of sequence.
|
||||
n_layers (int): the layers number of FFTBlock.
|
||||
n_head (int): the head number of multihead attention.
|
||||
d_k (int): the dim of key in multihead attention.
|
||||
d_q (int): the dim of query in multihead attention.
|
||||
d_model (int): the dim of hidden layer in multihead attention.
|
||||
d_inner (int): the dim of hidden layer in ffn.
|
||||
fft_conv1d_kernel (int): the conv kernel size in FFTBlock.
|
||||
fft_conv1d_padding (int): the conv padding size in FFTBlock.
|
||||
dropout (float, optional): dropout probability of FFTBlock. Defaults to 0.1.
|
||||
"""
|
||||
super(Decoder, self).__init__()
|
||||
|
||||
n_position = len_max_seq + 1
|
||||
self.n_head = n_head
|
||||
self.pos_inp = get_sinusoid_encoding_table(
|
||||
n_position, d_model, padding_idx=0)
|
||||
self.position_enc = dg.Embedding(
|
||||
size=[n_position, d_model],
|
||||
padding_idx=0,
|
||||
param_attr=fluid.ParamAttr(
|
||||
initializer=fluid.initializer.NumpyArrayInitializer(
|
||||
self.pos_inp),
|
||||
trainable=False))
|
||||
self.layer_stack = [
|
||||
FFTBlock(
|
||||
d_model,
|
||||
d_inner,
|
||||
n_head,
|
||||
d_k,
|
||||
d_q,
|
||||
fft_conv1d_kernel,
|
||||
fft_conv1d_padding,
|
||||
dropout=dropout) for _ in range(n_layers)
|
||||
]
|
||||
for i, layer in enumerate(self.layer_stack):
|
||||
self.add_sublayer('fft_{}'.format(i), layer)
|
||||
|
||||
def forward(self, enc_seq, enc_pos):
|
||||
"""
|
||||
Compute decoder outputs.
|
||||
|
||||
Args:
|
||||
enc_seq (Variable): shape(B, T_mel, C), dtype float32,
|
||||
the output of length regulator, where T_mel means the timesteps of input spectrum.
|
||||
enc_pos (Variable): shape(B, T_mel), dtype int64,
|
||||
the spectrum position.
|
||||
|
||||
Returns:
|
||||
dec_output (Variable): shape(B, T_mel, C), the decoder output.
|
||||
dec_slf_attn_list (list[Variable]): len(n_layers), the decoder self attention list.
|
||||
"""
|
||||
dec_slf_attn_list = []
|
||||
if fluid.framework._dygraph_tracer()._train_mode:
|
||||
slf_attn_mask = get_dec_attn_key_pad_mask(enc_pos, self.n_head,
|
||||
enc_seq.dtype)
|
||||
|
||||
else:
|
||||
len_q = enc_seq.shape[1]
|
||||
slf_attn_mask = layers.triu(
|
||||
layers.ones(
|
||||
shape=[len_q, len_q], dtype=enc_seq.dtype),
|
||||
diagonal=1)
|
||||
slf_attn_mask = layers.cast(
|
||||
slf_attn_mask != 0, dtype=enc_seq.dtype) * -1e30
|
||||
|
||||
non_pad_mask = get_non_pad_mask(enc_pos, 1, enc_seq.dtype)
|
||||
|
||||
# -- Forward
|
||||
dec_output = enc_seq + self.position_enc(enc_pos)
|
||||
|
||||
for dec_layer in self.layer_stack:
|
||||
dec_output, dec_slf_attn = dec_layer(
|
||||
dec_output,
|
||||
non_pad_mask=non_pad_mask,
|
||||
slf_attn_mask=slf_attn_mask)
|
||||
dec_slf_attn_list += [dec_slf_attn]
|
||||
|
||||
return dec_output, dec_slf_attn_list
|
|
@ -1,109 +0,0 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import paddle.fluid.dygraph as dg
|
||||
import paddle.fluid as fluid
|
||||
from parakeet.models.transformer_tts.utils import *
|
||||
from parakeet.models.fastspeech.fft_block import FFTBlock
|
||||
|
||||
|
||||
class Encoder(dg.Layer):
|
||||
def __init__(self,
|
||||
n_src_vocab,
|
||||
len_max_seq,
|
||||
n_layers,
|
||||
n_head,
|
||||
d_k,
|
||||
d_q,
|
||||
d_model,
|
||||
d_inner,
|
||||
fft_conv1d_kernel,
|
||||
fft_conv1d_padding,
|
||||
dropout=0.1):
|
||||
"""Encoder layer of FastSpeech.
|
||||
|
||||
Args:
|
||||
n_src_vocab (int): the number of source vocabulary.
|
||||
len_max_seq (int): the max mel len of sequence.
|
||||
n_layers (int): the layers number of FFTBlock.
|
||||
n_head (int): the head number of multihead attention.
|
||||
d_k (int): the dim of key in multihead attention.
|
||||
d_q (int): the dim of query in multihead attention.
|
||||
d_model (int): the dim of hidden layer in multihead attention.
|
||||
d_inner (int): the dim of hidden layer in ffn.
|
||||
fft_conv1d_kernel (int): the conv kernel size in FFTBlock.
|
||||
fft_conv1d_padding (int): the conv padding size in FFTBlock.
|
||||
dropout (float, optional): dropout probability of FFTBlock. Defaults to 0.1.
|
||||
"""
|
||||
super(Encoder, self).__init__()
|
||||
n_position = len_max_seq + 1
|
||||
self.n_head = n_head
|
||||
|
||||
self.src_word_emb = dg.Embedding(
|
||||
size=[n_src_vocab, d_model],
|
||||
padding_idx=0,
|
||||
param_attr=fluid.initializer.Normal(
|
||||
loc=0.0, scale=1.0))
|
||||
self.pos_inp = get_sinusoid_encoding_table(
|
||||
n_position, d_model, padding_idx=0)
|
||||
self.position_enc = dg.Embedding(
|
||||
size=[n_position, d_model],
|
||||
param_attr=fluid.ParamAttr(
|
||||
initializer=fluid.initializer.NumpyArrayInitializer(
|
||||
self.pos_inp),
|
||||
trainable=False))
|
||||
self.layer_stack = [
|
||||
FFTBlock(
|
||||
d_model,
|
||||
d_inner,
|
||||
n_head,
|
||||
d_k,
|
||||
d_q,
|
||||
fft_conv1d_kernel,
|
||||
fft_conv1d_padding,
|
||||
dropout=dropout) for _ in range(n_layers)
|
||||
]
|
||||
for i, layer in enumerate(self.layer_stack):
|
||||
self.add_sublayer('fft_{}'.format(i), layer)
|
||||
|
||||
def forward(self, character, text_pos):
|
||||
"""
|
||||
Encode text sequence.
|
||||
|
||||
Args:
|
||||
character (Variable): shape(B, T_text), dtype float32, the input text characters,
|
||||
where T_text means the timesteps of input characters,
|
||||
text_pos (Variable): shape(B, T_text), dtype int64, the input text position.
|
||||
|
||||
Returns:
|
||||
enc_output (Variable): shape(B, T_text, C), the encoder output.
|
||||
enc_slf_attn_list (list[Variable]): len(n_layers), the encoder self attention list.
|
||||
"""
|
||||
enc_slf_attn_list = []
|
||||
|
||||
# -- Forward
|
||||
enc_output = self.src_word_emb(character) + self.position_enc(
|
||||
text_pos) #(N, T, C)
|
||||
|
||||
slf_attn_mask = get_attn_key_pad_mask(text_pos, self.n_head,
|
||||
enc_output.dtype)
|
||||
non_pad_mask = get_non_pad_mask(text_pos, 1, enc_output.dtype)
|
||||
|
||||
for enc_layer in self.layer_stack:
|
||||
enc_output, enc_slf_attn = enc_layer(
|
||||
enc_output,
|
||||
non_pad_mask=non_pad_mask,
|
||||
slf_attn_mask=slf_attn_mask)
|
||||
enc_slf_attn_list += [enc_slf_attn]
|
||||
|
||||
return enc_output, enc_slf_attn_list
|
|
@ -1,133 +0,0 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
import numpy as np
|
||||
import paddle.fluid.dygraph as dg
|
||||
import paddle.fluid as fluid
|
||||
from parakeet.g2p.text.symbols import symbols
|
||||
from parakeet.models.transformer_tts.utils import *
|
||||
from parakeet.models.transformer_tts.post_convnet import PostConvNet
|
||||
from parakeet.models.fastspeech.length_regulator import LengthRegulator
|
||||
from parakeet.models.fastspeech.encoder import Encoder
|
||||
from parakeet.models.fastspeech.decoder import Decoder
|
||||
|
||||
|
||||
class FastSpeech(dg.Layer):
|
||||
def __init__(self, cfg, num_mels=80):
|
||||
"""FastSpeech model.
|
||||
|
||||
Args:
|
||||
cfg: the yaml configs used in FastSpeech model.
|
||||
num_mels (int, optional): the number of mel bands when calculating mel spectrograms. Defaults to 80.
|
||||
|
||||
"""
|
||||
super(FastSpeech, self).__init__()
|
||||
|
||||
self.encoder = Encoder(
|
||||
n_src_vocab=len(symbols) + 1,
|
||||
len_max_seq=cfg['max_seq_len'],
|
||||
n_layers=cfg['encoder_n_layer'],
|
||||
n_head=cfg['encoder_head'],
|
||||
d_k=cfg['hidden_size'] // cfg['encoder_head'],
|
||||
d_q=cfg['hidden_size'] // cfg['encoder_head'],
|
||||
d_model=cfg['hidden_size'],
|
||||
d_inner=cfg['encoder_conv1d_filter_size'],
|
||||
fft_conv1d_kernel=cfg['fft_conv1d_filter'],
|
||||
fft_conv1d_padding=cfg['fft_conv1d_padding'],
|
||||
dropout=0.1)
|
||||
self.length_regulator = LengthRegulator(
|
||||
input_size=cfg['hidden_size'],
|
||||
out_channels=cfg['duration_predictor_output_size'],
|
||||
filter_size=cfg['duration_predictor_filter_size'],
|
||||
dropout=cfg['dropout'])
|
||||
self.decoder = Decoder(
|
||||
len_max_seq=cfg['max_seq_len'],
|
||||
n_layers=cfg['decoder_n_layer'],
|
||||
n_head=cfg['decoder_head'],
|
||||
d_k=cfg['hidden_size'] // cfg['decoder_head'],
|
||||
d_q=cfg['hidden_size'] // cfg['decoder_head'],
|
||||
d_model=cfg['hidden_size'],
|
||||
d_inner=cfg['decoder_conv1d_filter_size'],
|
||||
fft_conv1d_kernel=cfg['fft_conv1d_filter'],
|
||||
fft_conv1d_padding=cfg['fft_conv1d_padding'],
|
||||
dropout=0.1)
|
||||
self.weight = fluid.ParamAttr(
|
||||
initializer=fluid.initializer.XavierInitializer())
|
||||
k = math.sqrt(1.0 / cfg['hidden_size'])
|
||||
self.bias = fluid.ParamAttr(initializer=fluid.initializer.Uniform(
|
||||
low=-k, high=k))
|
||||
self.mel_linear = dg.Linear(
|
||||
cfg['hidden_size'],
|
||||
num_mels * cfg['outputs_per_step'],
|
||||
param_attr=self.weight,
|
||||
bias_attr=self.bias, )
|
||||
self.postnet = PostConvNet(
|
||||
n_mels=num_mels,
|
||||
num_hidden=512,
|
||||
filter_size=5,
|
||||
padding=int(5 / 2),
|
||||
num_conv=5,
|
||||
outputs_per_step=cfg['outputs_per_step'],
|
||||
use_cudnn=True,
|
||||
dropout=0.1,
|
||||
batchnorm_last=True)
|
||||
|
||||
def forward(self,
|
||||
character,
|
||||
text_pos,
|
||||
mel_pos=None,
|
||||
length_target=None,
|
||||
alpha=1.0):
|
||||
"""
|
||||
Compute mel output from text character.
|
||||
|
||||
Args:
|
||||
character (Variable): shape(B, T_text), dtype float32, the input text characters,
|
||||
where T_text means the timesteps of input characters,
|
||||
text_pos (Variable): shape(B, T_text), dtype int64, the input text position.
|
||||
mel_pos (Variable, optional): shape(B, T_mel), dtype int64, the spectrum position,
|
||||
where T_mel means the timesteps of input spectrum,
|
||||
length_target (Variable, optional): shape(B, T_text), dtype int64,
|
||||
the duration of phoneme compute from pretrained transformerTTS. Defaults to None.
|
||||
alpha (float32, optional): The hyperparameter to determine the length of the expanded sequence
|
||||
mel, thereby controlling the voice speed. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
mel_output (Variable): shape(B, T_mel, C), the mel output before postnet.
|
||||
mel_output_postnet (Variable): shape(B, T_mel, C), the mel output after postnet.
|
||||
duration_predictor_output (Variable): shape(B, T_text), the duration of phoneme compute with duration predictor.
|
||||
enc_slf_attn_list (List[Variable]): len(enc_n_layers), the encoder self attention list.
|
||||
dec_slf_attn_list (List[Variable]): len(dec_n_layers), the decoder self attention list.
|
||||
"""
|
||||
|
||||
encoder_output, enc_slf_attn_list = self.encoder(character, text_pos)
|
||||
if fluid.framework._dygraph_tracer()._train_mode:
|
||||
length_regulator_output, duration_predictor_output = self.length_regulator(
|
||||
encoder_output, target=length_target, alpha=alpha)
|
||||
decoder_output, dec_slf_attn_list = self.decoder(
|
||||
length_regulator_output, mel_pos)
|
||||
|
||||
mel_output = self.mel_linear(decoder_output)
|
||||
mel_output_postnet = self.postnet(mel_output) + mel_output
|
||||
|
||||
return mel_output, mel_output_postnet, duration_predictor_output, enc_slf_attn_list, dec_slf_attn_list
|
||||
else:
|
||||
length_regulator_output, decoder_pos = self.length_regulator(
|
||||
encoder_output, alpha=alpha)
|
||||
decoder_output, _ = self.decoder(length_regulator_output,
|
||||
decoder_pos)
|
||||
mel_output = self.mel_linear(decoder_output)
|
||||
mel_output_postnet = self.postnet(mel_output) + mel_output
|
||||
|
||||
return mel_output, mel_output_postnet
|
|
@ -1,84 +0,0 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import math
|
||||
import paddle.fluid.dygraph as dg
|
||||
import paddle.fluid.layers as layers
|
||||
import paddle.fluid as fluid
|
||||
from parakeet.modules.multihead_attention import MultiheadAttention
|
||||
from parakeet.modules.ffn import PositionwiseFeedForward
|
||||
|
||||
|
||||
class FFTBlock(dg.Layer):
|
||||
def __init__(self,
|
||||
d_model,
|
||||
d_inner,
|
||||
n_head,
|
||||
d_k,
|
||||
d_q,
|
||||
filter_size,
|
||||
padding,
|
||||
dropout=0.2):
|
||||
"""Feed forward structure based on self-attention.
|
||||
|
||||
Args:
|
||||
d_model (int): the dim of hidden layer in multihead attention.
|
||||
d_inner (int): the dim of hidden layer in ffn.
|
||||
n_head (int): the head number of multihead attention.
|
||||
d_k (int): the dim of key in multihead attention.
|
||||
d_q (int): the dim of query in multihead attention.
|
||||
filter_size (int): the conv kernel size.
|
||||
padding (int): the conv padding size.
|
||||
dropout (float, optional): dropout probability. Defaults to 0.2.
|
||||
"""
|
||||
super(FFTBlock, self).__init__()
|
||||
self.slf_attn = MultiheadAttention(
|
||||
d_model,
|
||||
d_k,
|
||||
d_q,
|
||||
num_head=n_head,
|
||||
is_bias=True,
|
||||
dropout=dropout,
|
||||
is_concat=False)
|
||||
self.pos_ffn = PositionwiseFeedForward(
|
||||
d_model,
|
||||
d_inner,
|
||||
filter_size=filter_size,
|
||||
padding=padding,
|
||||
dropout=dropout)
|
||||
|
||||
def forward(self, enc_input, non_pad_mask, slf_attn_mask=None):
|
||||
"""
|
||||
Feed forward block of FastSpeech
|
||||
|
||||
Args:
|
||||
enc_input (Variable): shape(B, T, C), dtype float32, the embedding characters input,
|
||||
where T means the timesteps of input.
|
||||
non_pad_mask (Variable): shape(B, T, 1), dtype int64, the mask of sequence.
|
||||
slf_attn_mask (Variable, optional): shape(B, len_q, len_k), dtype int64, the mask of self attention,
|
||||
where len_q means the sequence length of query and len_k means the sequence length of key. Defaults to None.
|
||||
|
||||
Returns:
|
||||
output (Variable): shape(B, T, C), the output after self-attention & ffn.
|
||||
slf_attn (Variable): shape(B * n_head, T, T), the self attention.
|
||||
"""
|
||||
output, slf_attn = self.slf_attn(
|
||||
enc_input, enc_input, enc_input, mask=slf_attn_mask)
|
||||
|
||||
output *= non_pad_mask
|
||||
|
||||
output = self.pos_ffn(output)
|
||||
output *= non_pad_mask
|
||||
|
||||
return output, slf_attn
|
|
@ -1,181 +0,0 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import math
|
||||
import parakeet.models.fastspeech.utils
|
||||
import paddle.fluid.dygraph as dg
|
||||
import paddle.fluid.layers as layers
|
||||
import paddle.fluid as fluid
|
||||
from parakeet.modules.customized import Conv1D
|
||||
|
||||
|
||||
class LengthRegulator(dg.Layer):
|
||||
def __init__(self, input_size, out_channels, filter_size, dropout=0.1):
|
||||
"""Length Regulator block in FastSpeech.
|
||||
|
||||
Args:
|
||||
input_size (int): the channel number of input.
|
||||
out_channels (int): the output channel number.
|
||||
filter_size (int): the filter size of duration predictor.
|
||||
dropout (float, optional): dropout probability. Defaults to 0.1.
|
||||
"""
|
||||
super(LengthRegulator, self).__init__()
|
||||
self.duration_predictor = DurationPredictor(
|
||||
input_size=input_size,
|
||||
out_channels=out_channels,
|
||||
filter_size=filter_size,
|
||||
dropout=dropout)
|
||||
|
||||
def LR(self, x, duration_predictor_output):
|
||||
output = []
|
||||
batch_size = x.shape[0]
|
||||
for i in range(batch_size):
|
||||
output.append(
|
||||
self.expand(x[i:i + 1], duration_predictor_output[i:i + 1]))
|
||||
output = self.pad(output)
|
||||
return output
|
||||
|
||||
def pad(self, input_ele):
|
||||
max_len = max([input_ele[i].shape[0] for i in range(len(input_ele))])
|
||||
out_list = []
|
||||
for i in range(len(input_ele)):
|
||||
pad_len = max_len - input_ele[i].shape[0]
|
||||
one_batch_padded = layers.pad(input_ele[i], [0, pad_len, 0, 0],
|
||||
pad_value=0.0)
|
||||
out_list.append(one_batch_padded)
|
||||
out_padded = layers.stack(out_list)
|
||||
return out_padded
|
||||
|
||||
def expand(self, batch, predicted):
|
||||
out = []
|
||||
time_steps = batch.shape[1]
|
||||
fertilities = predicted.numpy()
|
||||
batch = layers.squeeze(batch, [0])
|
||||
|
||||
for i in range(time_steps):
|
||||
if fertilities[0, i] == 0:
|
||||
continue
|
||||
out.append(
|
||||
layers.expand(batch[i:i + 1, :], [int(fertilities[0, i]), 1]))
|
||||
out = layers.concat(out, axis=0)
|
||||
return out
|
||||
|
||||
def forward(self, x, alpha=1.0, target=None):
|
||||
"""
|
||||
Compute length of mel from encoder output use TransformerTTS attention
|
||||
|
||||
Args:
|
||||
x (Variable): shape(B, T, C), dtype float32, the encoder output.
|
||||
alpha (float32, optional): the hyperparameter to determine the length of
|
||||
the expanded sequence mel, thereby controlling the voice speed. Defaults to 1.0.
|
||||
target (Variable, optional): shape(B, T_text), dtype int64, the duration of phoneme compute from pretrained transformerTTS.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
output (Variable): shape(B, T, C), the output after exppand.
|
||||
duration_predictor_output (Variable): shape(B, T, C), the output of duration predictor.
|
||||
"""
|
||||
duration_predictor_output = self.duration_predictor(x)
|
||||
if fluid.framework._dygraph_tracer()._train_mode:
|
||||
output = self.LR(x, target)
|
||||
return output, duration_predictor_output
|
||||
else:
|
||||
duration_predictor_output = duration_predictor_output * alpha
|
||||
duration_predictor_output = layers.ceil(duration_predictor_output)
|
||||
output = self.LR(x, duration_predictor_output)
|
||||
mel_pos = dg.to_variable(np.arange(1, output.shape[1] + 1)).astype(
|
||||
np.int64)
|
||||
mel_pos = layers.unsqueeze(mel_pos, [0])
|
||||
return output, mel_pos
|
||||
|
||||
|
||||
class DurationPredictor(dg.Layer):
|
||||
def __init__(self, input_size, out_channels, filter_size, dropout=0.1):
|
||||
"""Duration Predictor block in FastSpeech.
|
||||
|
||||
Args:
|
||||
input_size (int): the channel number of input.
|
||||
out_channels (int): the output channel number.
|
||||
filter_size (int): the filter size.
|
||||
dropout (float, optional): dropout probability. Defaults to 0.1.
|
||||
"""
|
||||
super(DurationPredictor, self).__init__()
|
||||
self.input_size = input_size
|
||||
self.out_channels = out_channels
|
||||
self.filter_size = filter_size
|
||||
self.dropout = dropout
|
||||
|
||||
k = math.sqrt(1.0 / self.input_size)
|
||||
self.conv1 = Conv1D(
|
||||
num_channels=self.input_size,
|
||||
num_filters=self.out_channels,
|
||||
filter_size=self.filter_size,
|
||||
padding=1,
|
||||
param_attr=fluid.ParamAttr(
|
||||
initializer=fluid.initializer.XavierInitializer()),
|
||||
bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Uniform(
|
||||
low=-k, high=k)))
|
||||
#data_format='NTC')
|
||||
k = math.sqrt(1.0 / self.out_channels)
|
||||
self.conv2 = Conv1D(
|
||||
num_channels=self.out_channels,
|
||||
num_filters=self.out_channels,
|
||||
filter_size=self.filter_size,
|
||||
padding=1,
|
||||
param_attr=fluid.ParamAttr(
|
||||
initializer=fluid.initializer.XavierInitializer()),
|
||||
bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Uniform(
|
||||
low=-k, high=k)))
|
||||
#data_format='NTC')
|
||||
self.layer_norm1 = dg.LayerNorm(self.out_channels)
|
||||
self.layer_norm2 = dg.LayerNorm(self.out_channels)
|
||||
|
||||
self.weight = fluid.ParamAttr(
|
||||
initializer=fluid.initializer.XavierInitializer())
|
||||
k = math.sqrt(1.0 / self.out_channels)
|
||||
self.bias = fluid.ParamAttr(initializer=fluid.initializer.Uniform(
|
||||
low=-k, high=k))
|
||||
|
||||
self.linear = dg.Linear(
|
||||
self.out_channels, 1, param_attr=self.weight, bias_attr=self.bias)
|
||||
|
||||
def forward(self, encoder_output):
|
||||
"""
|
||||
Predict the duration of each character.
|
||||
|
||||
Args:
|
||||
encoder_output (Variable): shape(B, T, C), dtype float32, the encoder output.
|
||||
|
||||
Returns:
|
||||
out (Variable): shape(B, T, C), the output of duration predictor.
|
||||
"""
|
||||
# encoder_output.shape(N, T, C)
|
||||
out = layers.transpose(encoder_output, [0, 2, 1])
|
||||
out = self.conv1(out)
|
||||
out = layers.transpose(out, [0, 2, 1])
|
||||
out = layers.dropout(
|
||||
layers.relu(self.layer_norm1(out)),
|
||||
self.dropout,
|
||||
dropout_implementation='upscale_in_train')
|
||||
out = layers.transpose(out, [0, 2, 1])
|
||||
out = self.conv2(out)
|
||||
out = layers.transpose(out, [0, 2, 1])
|
||||
out = layers.dropout(
|
||||
layers.relu(self.layer_norm2(out)),
|
||||
self.dropout,
|
||||
dropout_implementation='upscale_in_train')
|
||||
out = layers.relu(self.linear(out))
|
||||
out = layers.squeeze(out, axes=[-1])
|
||||
|
||||
return out
|
|
@ -1,46 +0,0 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_alignment(attn_probs, mel_lens, n_head):
|
||||
max_F = 0
|
||||
assert attn_probs[0].shape[0] % n_head == 0
|
||||
batch_size = int(attn_probs[0].shape[0] // n_head)
|
||||
for i in range(len(attn_probs)):
|
||||
multi_attn = attn_probs[i].numpy()
|
||||
for j in range(n_head):
|
||||
attn = multi_attn[j * batch_size:(j + 1) * batch_size]
|
||||
F = score_F(attn)
|
||||
if max_F < F:
|
||||
max_F = F
|
||||
max_attn = attn
|
||||
alignment = compute_duration(max_attn, mel_lens)
|
||||
return alignment, max_attn
|
||||
|
||||
|
||||
def score_F(attn):
|
||||
max = np.max(attn, axis=-1)
|
||||
mean = np.mean(max)
|
||||
return mean
|
||||
|
||||
|
||||
def compute_duration(attn, mel_lens):
|
||||
alignment = np.zeros([attn.shape[2]])
|
||||
#for i in range(attn.shape[0]):
|
||||
for j in range(mel_lens):
|
||||
max_index = np.argmax(attn[0, j])
|
||||
alignment[max_index] += 1
|
||||
|
||||
return alignment
|
Loading…
Reference in New Issue