transformer_tts, miscellaneous fixes

This commit is contained in:
chenfeiyu 2020-12-01 18:13:30 +08:00
parent 598d813908
commit 9cb5c03069
6 changed files with 124 additions and 72 deletions

View File

@ -44,6 +44,7 @@ class English(Phonetics):
def __call__(self, sentence): def __call__(self, sentence):
return self.numericalize(self.phoneticize(sentence)) return self.numericalize(self.phoneticize(sentence))
@property
def vocab_size(self): def vocab_size(self):
return len(self.vocab) return len(self.vocab)
@ -88,6 +89,7 @@ class Chinese(Phonetics):
def __call__(self, sentence): def __call__(self, sentence):
return self.numericalize(self.phoneticize(sentence)) return self.numericalize(self.phoneticize(sentence))
@property
def vocab_size(self): def vocab_size(self):
return len(self.vocab) return len(self.vocab)

View File

@ -1,9 +1,11 @@
import math import math
from tqdm import trange
import paddle import paddle
from paddle import nn from paddle import nn
from paddle.nn import functional as F from paddle.nn import functional as F
from paddle.nn import initializer as I from paddle.nn import initializer as I
import parakeet
from parakeet.modules.attention import _split_heads, _concat_heads, drop_head, scaled_dot_product_attention from parakeet.modules.attention import _split_heads, _concat_heads, drop_head, scaled_dot_product_attention
from parakeet.modules.transformer import PositionwiseFFN from parakeet.modules.transformer import PositionwiseFFN
from parakeet.modules import masking from parakeet.modules import masking
@ -111,8 +113,6 @@ class TransformerEncoderLayer(nn.Layer):
def _forward_mha(self, x, mask, drop_n_heads): def _forward_mha(self, x, mask, drop_n_heads):
# PreLN scheme: Norm -> SubLayer -> Dropout -> Residual # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
if mask is not None:
mask = paddle.unsqueeze(mask, 1)
x_in = x x_in = x
x = self.layer_norm1(x) x = self.layer_norm1(x)
context_vector, attn_weights = self.self_mha(x, x, x, mask, drop_n_heads) context_vector, attn_weights = self.self_mha(x, x, x, mask, drop_n_heads)
@ -131,21 +131,13 @@ class TransformerEncoderLayer(nn.Layer):
""" """
Args: Args:
x (Tensor): shape(batch_size, time_steps, d_model), the decoder input. x (Tensor): shape(batch_size, time_steps, d_model), the decoder input.
mask (Tensor): shape(batch_size, time_steps), the padding mask. mask (Tensor): shape(batch_size, 1, time_steps), the padding mask.
Returns: Returns:
(x, attn_weights) (x, attn_weights)
x (Tensor): shape(batch_size, time_steps, d_model), the decoded. x (Tensor): shape(batch_size, time_steps, d_model), the decoded.
attn_weights (Tensor), shape(batch_size, n_heads, time_steps, time_steps), self attention. attn_weights (Tensor), shape(batch_size, n_heads, time_steps, time_steps), self attention.
""" """
# # pre norm
# x_in = x
# x = self.layer_norm1(x)
# context_vector, attn_weights = self.self_mha(x, x, x, paddle.unsqueeze(mask, 1), drop_n_heads)
# x = x_in + context_vector # here, the order can be tuned
# # pre norm
# x = x + self.ffn(self.layer_norm2(x))
x, attn_weights = self._forward_mha(x, mask, drop_n_heads) x, attn_weights = self._forward_mha(x, mask, drop_n_heads)
x = self._forward_ffn(x) x = self._forward_ffn(x)
return x, attn_weights return x, attn_weights
@ -188,7 +180,7 @@ class TransformerDecoderLayer(nn.Layer):
# PreLN scheme: Norm -> SubLayer -> Dropout -> Residual # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
q_in = q q_in = q
q = self.layer_norm2(q) q = self.layer_norm2(q)
context_vector, attn_weights = self.cross_mha(q, k, v, paddle.unsqueeze(mask, 1), drop_n_heads) context_vector, attn_weights = self.cross_mha(q, k, v, mask, drop_n_heads)
context_vector = q_in + F.dropout(context_vector, self.dropout, training=self.training) context_vector = q_in + F.dropout(context_vector, self.dropout, training=self.training)
return context_vector, attn_weights return context_vector, attn_weights
@ -206,7 +198,7 @@ class TransformerDecoderLayer(nn.Layer):
q (Tensor): shape(batch_size, time_steps_q, d_model), the decoder input. q (Tensor): shape(batch_size, time_steps_q, d_model), the decoder input.
k (Tensor): shape(batch_size, time_steps_k, d_model), keys. k (Tensor): shape(batch_size, time_steps_k, d_model), keys.
v (Tensor): shape(batch_size, time_steps_k, d_model), values v (Tensor): shape(batch_size, time_steps_k, d_model), values
encoder_mask (Tensor): shape(batch_size, time_steps_k) encoder padding mask. encoder_mask (Tensor): shape(batch_size, 1, time_steps_k) encoder padding mask.
decoder_mask (Tensor): shape(batch_size, time_steps_q, time_steps_q) or broadcastable shape, decoder padding mask. decoder_mask (Tensor): shape(batch_size, time_steps_q, time_steps_q) or broadcastable shape, decoder padding mask.
Returns: Returns:
@ -215,20 +207,6 @@ class TransformerDecoderLayer(nn.Layer):
self_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_q), decoder self attention. self_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_q), decoder self attention.
cross_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_k), decoder-encoder cross attention. cross_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_k), decoder-encoder cross attention.
""" """
# # pre norm
# q_in = q
# q = self.layer_norm1(q)
# context_vector, self_attn_weights = self.self_mha(q, q, q, decoder_mask, drop_n_heads)
# q = q_in + context_vector
# # pre norm
# q_in = q
# q = self.layer_norm2(q)
# context_vector, cross_attn_weights = self.cross_mha(q, k, v, paddle.unsqueeze(encoder_mask, 1), drop_n_heads)
# q = q_in + context_vector
# # pre norm
# q = q + self.ffn(self.layer_norm3(q))
q, self_attn_weights = self._forward_self_mha(q, decoder_mask, drop_n_heads) q, self_attn_weights = self._forward_self_mha(q, decoder_mask, drop_n_heads)
q, cross_attn_weights = self._forward_cross_mha(q, k, v, encoder_mask, drop_n_heads) q, cross_attn_weights = self._forward_cross_mha(q, k, v, encoder_mask, drop_n_heads)
q = self._forward_ffn(q) q = self._forward_ffn(q)
@ -245,7 +223,7 @@ class TransformerEncoder(nn.LayerList):
""" """
Args: Args:
x (Tensor): shape(batch_size, time_steps, feature_size), the input tensor. x (Tensor): shape(batch_size, time_steps, feature_size), the input tensor.
mask (Tensor): shape(batch_size, time_steps), the mask. mask (Tensor): shape(batch_size, 1, time_steps), the mask.
drop_n_heads (int, optional): how many heads to drop. Defaults to 0. drop_n_heads (int, optional): how many heads to drop. Defaults to 0.
Returns: Returns:
@ -273,7 +251,7 @@ class TransformerDecoder(nn.LayerList):
q (Tensor): shape(batch_size, time_steps_q, d_model) q (Tensor): shape(batch_size, time_steps_q, d_model)
k (Tensor): shape(batch_size, time_steps_k, d_encoder) k (Tensor): shape(batch_size, time_steps_k, d_encoder)
v (Tensor): shape(batch_size, time_steps_k, k_encoder) v (Tensor): shape(batch_size, time_steps_k, k_encoder)
encoder_mask (Tensor): shape(batch_size, time_steps_k) encoder_mask (Tensor): shape(batch_size, 1, time_steps_k)
decoder_mask (Tensor): shape(batch_size, time_steps_q, time_steps_q) decoder_mask (Tensor): shape(batch_size, time_steps_q, time_steps_q)
drop_n_heads (int, optional): [description]. Defaults to 0. drop_n_heads (int, optional): [description]. Defaults to 0.
@ -290,21 +268,21 @@ class TransformerDecoder(nn.LayerList):
class MLPPreNet(nn.Layer): class MLPPreNet(nn.Layer):
def __init__(self, d_input, d_hidden, d_output): def __init__(self, d_input, d_hidden, d_output, dropout):
# (lin + relu + dropout) * n + last projection # (lin + relu + dropout) * n + last projection
super(MLPPreNet, self).__init__() super(MLPPreNet, self).__init__()
self.lin1 = nn.Linear(d_input, d_hidden) self.lin1 = nn.Linear(d_input, d_hidden)
self.lin2 = nn.Linear(d_hidden, d_hidden) self.lin2 = nn.Linear(d_hidden, d_hidden)
# self.lin3 = nn.Linear(d_output, d_output) self.lin3 = nn.Linear(d_output, d_output)
self.dropout = dropout
def forward(self, x, dropout): def forward(self, x, dropout):
# the original code said also use dropout in inference l1 = F.dropout(F.relu(self.lin1(x)), self.dropout, training=self.training)
l1 = F.dropout(F.relu(self.lin1(x)), dropout, training=self.training) l2 = F.dropout(F.relu(self.lin2(l1)), self.dropout, training=self.training)
l2 = F.dropout(F.relu(self.lin2(l1)), dropout, training=self.training) l3 = self.lin3(l2)
#l3 = self.lin3(l2) return l3
return l2
# NOTE: not used in
class CNNPreNet(nn.Layer): class CNNPreNet(nn.Layer):
def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers, def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers,
dropout=0.): dropout=0.):
@ -347,7 +325,6 @@ class CNNPostNet(nn.Layer):
# NOTE: it can also be a non-causal conv # NOTE: it can also be a non-causal conv
def forward(self, x): def forward(self, x):
# why not use pre norms
x_in = x x_in = x
for i, layer in enumerate(self.convs): for i, layer in enumerate(self.convs):
x = layer(x) x = layer(x)
@ -358,33 +335,60 @@ class CNNPostNet(nn.Layer):
class TransformerTTS(nn.Layer): class TransformerTTS(nn.Layer):
def __init__(self, vocab_size, padding_idx, d_encoder, d_decoder, d_mel, n_heads, d_ffn, def __init__(self,
encoder_layers, decoder_layers, d_prenet, d_postnet, postnet_layers, frontend: parakeet.frontend.Phonetics,
postnet_kernel_size, max_reduction_factor, dropout): d_encoder: int,
d_decoder: int,
d_mel: int,
n_heads: int,
d_ffn: int,
encoder_layers: int,
decoder_layers: int,
d_prenet: int,
d_postnet: int,
postnet_layers: int,
postnet_kernel_size: int,
max_reduction_factor: int,
decoder_prenet_dropout: float,
dropout: float):
super(TransformerTTS, self).__init__() super(TransformerTTS, self).__init__()
# text frontend (text normalization and g2p)
self.frontend = frontend
# encoder # encoder
self.encoder_prenet = nn.Embedding(vocab_size, d_encoder, padding_idx, weight_attr=I.Uniform(-0.05, 0.05)) self.encoder_prenet = nn.Embedding(
# self.encoder_prenet = CNNPreNet(d_encoder, d_encoder, d_encoder, 5, 3, dropout) frontend.vocab_size, d_encoder,
self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder) # it may be extended later padding_idx=frontend.vocab.padding_index,
self.encoder_pe_scalar = self.create_parameter([1], attr=I.Constant(1.)) weight_attr=I.Uniform(-0.05, 0.05))
self.encoder = TransformerEncoder(d_encoder, n_heads, d_ffn, encoder_layers, dropout) # position encoding matrix may be extended later
self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder)
self.encoder_pe_scalar = self.create_parameter(
[1], attr=I.Constant(1.))
self.encoder = TransformerEncoder(
d_encoder, n_heads, d_ffn, encoder_layers, dropout)
# decoder # decoder
self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder) self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder, dropout)
self.decoder_pe = pe.positional_encoding(0, 1000, d_decoder) # it may be extended later self.decoder_pe = pe.positional_encoding(0, 1000, d_decoder)
self.decoder_pe_scalar = self.create_parameter([1], attr=I.Constant(1.)) self.decoder_pe_scalar = self.create_parameter(
self.decoder = TransformerDecoder(d_decoder, n_heads, d_ffn, decoder_layers, dropout, d_encoder=d_encoder) [1], attr=I.Constant(1.))
self.decoder = TransformerDecoder(
d_decoder, n_heads, d_ffn, decoder_layers, dropout,
d_encoder=d_encoder)
self.final_proj = nn.Linear(d_decoder, max_reduction_factor * d_mel) self.final_proj = nn.Linear(d_decoder, max_reduction_factor * d_mel)
self.decoder_postnet = CNNPostNet(d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers) self.decoder_postnet = CNNPostNet(
d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers)
self.stop_conditioner = nn.Linear(d_mel, 3) self.stop_conditioner = nn.Linear(d_mel, 3)
# specs # specs
self.padding_idx = padding_idx self.padding_idx = frontend.vocab.padding_index
self.d_encoder = d_encoder self.d_encoder = d_encoder
self.d_decoder = d_decoder self.d_decoder = d_decoder
self.d_mel = d_mel self.d_mel = d_mel
self.max_r = max_reduction_factor self.max_r = max_reduction_factor
self.dropout = dropout self.dropout = dropout
self.decoder_prenet_dropout = decoder_prenet_dropout
# start and end: though it is only used in predict # start and end: though it is only used in predict
# it can also be used in training # it can also be used in training
@ -395,13 +399,19 @@ class TransformerTTS(nn.Layer):
# mutables # mutables
self.r = max_reduction_factor # set it every call self.r = max_reduction_factor # set it every call
self.decoder_prenet_dropout = 0.0
self.drop_n_heads = 0 self.drop_n_heads = 0
def forward(self, text, mel, stop): def forward(self, text, mel):
encoded, encoder_attention_weights, encoder_mask = self.encode(text) encoded, encoder_attention_weights, encoder_mask = self.encode(text)
mel_output, mel_intermediate, cross_attention_weights, stop_logits = self.decode(encoded, mel, encoder_mask) mel_output, mel_intermediate, cross_attention_weights, stop_logits = self.decode(encoded, mel, encoder_mask)
return mel_output, mel_intermediate, encoder_attention_weights, cross_attention_weights, stop_logits outputs = {
"mel_output": mel_output,
"mel_intermediate": mel_intermediate,
"encoder_attention_weights": encoder_attention_weights,
"cross_attention_weights": cross_attention_weights,
"stop_logits": stop_logits,
}
return outputs
def encode(self, text): def encode(self, text):
T_enc = text.shape[-1] T_enc = text.shape[-1]
@ -414,7 +424,8 @@ class TransformerTTS(nn.Layer):
x = F.dropout(x, self.dropout, training=self.training) x = F.dropout(x, self.dropout, training=self.training)
# TODO(chenfeiyu): unsqueeze a decoder_time_steps=1 for the mask # TODO(chenfeiyu): unsqueeze a decoder_time_steps=1 for the mask
encoder_padding_mask = masking.id_mask(text, self.padding_idx, dtype=x.dtype) encoder_padding_mask = paddle.unsqueeze(
masking.id_mask(text, self.padding_idx, dtype=x.dtype), 1)
x, attention_weights = self.encoder(x, encoder_padding_mask, self.drop_n_heads) x, attention_weights = self.encoder(x, encoder_padding_mask, self.drop_n_heads)
return x, attention_weights, encoder_padding_mask return x, attention_weights, encoder_padding_mask
@ -453,21 +464,26 @@ class TransformerTTS(nn.Layer):
return mel_output, mel_intermediate, cross_attention_weights, stop_logits return mel_output, mel_intermediate, cross_attention_weights, stop_logits
def predict(self, input, max_length=1000, verbose=True): def predict(self, input, raw_input=True, max_length=1000, verbose=True):
"""[summary] """Predict log scale magnitude mel spectrogram from text input.
Args: Args:
input (Tensor): shape (T), dtype int, input text sequencce. input (Tensor): shape (T), dtype int, input text sequencce.
max_length (int, optional): max decoder steps. Defaults to 1000. max_length (int, optional): max decoder steps. Defaults to 1000.
verbose (bool, optional): display progress bar. Defaults to True. verbose (bool, optional): display progress bar. Defaults to True.
""" """
text_input = paddle.unsqueeze(input, 0) # (1, T) if raw_input:
text_ids = paddle.to_tensor(self.frontend(input))
text_input = paddle.unsqueeze(text_ids, 0) # (1, T)
else:
text_input = input
decoder_input = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C) decoder_input = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C)
decoder_output = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C) decoder_output = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C)
# encoder the text sequence # encoder the text sequence
encoder_output, encoder_attentions, encoder_padding_mask = self.encode(text_input) encoder_output, encoder_attentions, encoder_padding_mask = self.encode(text_input)
for _ in range(int(max_length // self.r) + 1): for _ in trange(int(max_length // self.r) + 1):
mel_output, _, cross_attention_weights, stop_logits = self.decode( mel_output, _, cross_attention_weights, stop_logits = self.decode(
encoder_output, decoder_input, encoder_padding_mask) encoder_output, decoder_input, encoder_padding_mask)
@ -477,18 +493,22 @@ class TransformerTTS(nn.Layer):
decoder_output = paddle.concat([decoder_output, mel_output[:, -self.r:, :]], 1) decoder_output = paddle.concat([decoder_output, mel_output[:, -self.r:, :]], 1)
# stop condition? # stop condition?
if paddle.argmax(stop_logits[:, -1, :]) == self.stop_prob_index: if paddle.any(paddle.argmax(stop_logits[0, :, :], axis=-1) == self.stop_prob_index):
if verbose: if verbose:
print("Hits stop condition.") print("Hits stop condition.")
break break
mel_output = decoder_output[:, 1:, :]
return decoder_output[:, 1:, :], encoder_attentions, cross_attention_weights outputs = {
"mel_output": mel_output,
"encoder_attention_weights": encoder_attentions,
"cross_attention_weights": cross_attention_weights,
}
return outputs
def set_constants(self, reduction_factor, drop_n_heads, decoder_prenet_dropout): def set_constants(self, reduction_factor, drop_n_heads):
# TODO(chenfeiyu): make a good design for these hyperparameter settings
self.r = reduction_factor self.r = reduction_factor
self.drop_n_heads = drop_n_heads self.drop_n_heads = drop_n_heads
self.decoder_prenet_dropout = decoder_prenet_dropout
class TransformerTTSLoss(nn.Layer): class TransformerTTSLoss(nn.Layer):
@ -505,12 +525,14 @@ class TransformerTTSLoss(nn.Layer):
mel_len = mask.shape[-1] mel_len = mask.shape[-1]
last_position = F.one_hot(mask.sum(-1).astype("int64") - 1, num_classes=mel_len) last_position = F.one_hot(mask.sum(-1).astype("int64") - 1, num_classes=mel_len)
mask2 = mask + last_position.scale(self.stop_loss_scale - 1).astype(mask.dtype) mask2 = mask + last_position.scale(self.stop_loss_scale - 1).astype(mask.dtype)
stop_loss = L.masked_softmax_with_cross_entropy(stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1)) stop_loss = L.masked_softmax_with_cross_entropy(
stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
loss = mel_loss1 + mel_loss2 + stop_loss loss = mel_loss1 + mel_loss2 + stop_loss
details = dict( losses = dict(
loss=loss, # total loss
mel_loss1=mel_loss1, # ouput mel loss mel_loss1=mel_loss1, # ouput mel loss
mel_loss2=mel_loss2, # intermediate mel loss mel_loss2=mel_loss2, # intermediate mel loss
stop_loss=stop_loss # stop prob loss stop_loss=stop_loss # stop prob loss
) )
return loss, details return losses

View File

@ -11,3 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from . import io, layer_tools, scheduler, display

View File

@ -26,6 +26,14 @@ def pack_attention_images(attention_weights, rotate=False):
img = np.block([[total[i, j] for j in range(cols)] for i in range(rows)]) img = np.block([[total[i, j] for j in range(cols)] for i in range(rows)])
return img return img
def add_attention_plots(writer, tag, attention_weights, global_step):
attns = [attn[0].numpy() for attn in attention_weights]
for i, attn in enumerate(attns):
img = pack_attention_images(attn)
writer.add_image(f"{tag}/{i}",
cm.plasma(img),
global_step=global_step,
dataformats="HWC")
def min_max_normalize(v): def min_max_normalize(v):
return (v - v.min()) / (v.max() - v.min()) return (v - v.min()) / (v.max() - v.min())

View File

@ -132,13 +132,13 @@ def load_parameters(model,
k].dtype: k].dtype:
model_dict[k] = v.astype(state_dict[k].numpy().dtype) model_dict[k] = v.astype(state_dict[k].numpy().dtype)
model.set_dict(model_dict) model.set_state_dict(model_dict)
print("[checkpoint] Rank {}: loaded model from {}.pdparams".format( print("[checkpoint] Rank {}: loaded model from {}.pdparams".format(
local_rank, checkpoint_path)) local_rank, checkpoint_path))
if optimizer and optimizer_dict: if optimizer and optimizer_dict:
optimizer.set_dict(optimizer_dict) optimizer.set_state_dict(optimizer_dict)
print("[checkpoint] Rank {}: loaded optimizer state from {}.pdopt". print("[checkpoint] Rank {}: loaded optimizer state from {}.pdopt".
format(local_rank, checkpoint_path)) format(local_rank, checkpoint_path))

View File

@ -0,0 +1,18 @@
import paddle
from paddle import distributed as dist
from functools import wraps
def rank_zero_only(func):
local_rank = dist.get_rank()
@wraps(func)
def wrapper(*args, **kwargs):
if local_rank != 0:
return
result = func(*args, **kwargs)
return result
return wrapper