From f375792c515b8a5f61e9cceff9d472e0ae3b9aec Mon Sep 17 00:00:00 2001 From: lfchener Date: Wed, 9 Dec 2020 12:42:41 +0000 Subject: [PATCH] add tacotron2.py and a new frontend for en --- parakeet/__init__.py | 2 +- parakeet/frontend/normalizer/numbers.py | 85 ++++- parakeet/frontend/phonectic.py | 103 +++++- parakeet/models/tacotron2.py | 424 ++++++++++++++++++++++++ 4 files changed, 595 insertions(+), 19 deletions(-) create mode 100644 parakeet/models/tacotron2.py diff --git a/parakeet/__init__.py b/parakeet/__init__.py index 2358408..3bfc0dc 100644 --- a/parakeet/__init__.py +++ b/parakeet/__init__.py @@ -14,4 +14,4 @@ __version__ = "0.2.0" -from parakeet import audio, data, datastes, frontend, models, modules, training, utils +from parakeet import audio, data, datasets, frontend, models, modules, training, utils diff --git a/parakeet/frontend/normalizer/numbers.py b/parakeet/frontend/normalizer/numbers.py index ef7343c..9d2d42c 100644 --- a/parakeet/frontend/normalizer/numbers.py +++ b/parakeet/frontend/normalizer/numbers.py @@ -1,3 +1,84 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # number expansion is not that easy -import num2words -import inflect \ No newline at end of file +import inflect +import re + +_inflect = inflect.engine() +_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') +_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') +_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') +_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') +_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') +_number_re = re.compile(r'[0-9]+') + + +def _remove_commas(m): + return m.group(1).replace(',', '') + + +def _expand_decimal_point(m): + return m.group(1).replace('.', ' point ') + + +def _expand_dollars(m): + match = m.group(1) + parts = match.split('.') + if len(parts) > 2: + return match + ' dollars' # Unexpected format + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = 'dollar' if dollars == 1 else 'dollars' + cent_unit = 'cent' if cents == 1 else 'cents' + return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) + elif dollars: + dollar_unit = 'dollar' if dollars == 1 else 'dollars' + return '%s %s' % (dollars, dollar_unit) + elif cents: + cent_unit = 'cent' if cents == 1 else 'cents' + return '%s %s' % (cents, cent_unit) + else: + return 'zero dollars' + + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + +def _expand_number(m): + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return 'two thousand' + elif num > 2000 and num < 2010: + return 'two thousand ' + _inflect.number_to_words(num % 100) + elif num % 100 == 0: + return _inflect.number_to_words(num // 100) + ' hundred' + else: + return _inflect.number_to_words( + num, andword='', zero='oh', group=2).replace(', ', ' ') + else: + return _inflect.number_to_words(num, andword='') + + +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_pounds_re, r'\1 pounds', text) + text = re.sub(_dollars_re, _expand_dollars, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text diff --git a/parakeet/frontend/phonectic.py b/parakeet/frontend/phonectic.py index bf4c852..f814681 100644 --- a/parakeet/frontend/phonectic.py +++ b/parakeet/frontend/phonectic.py @@ -1,34 +1,53 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from abc import ABC, abstractmethod from typing import Union from g2p_en import G2p from g2pM import G2pM +import re +import unicodedata +from builtins import str as unicode from parakeet.frontend import Vocab from opencc import OpenCC from parakeet.frontend.punctuation import get_punctuations +from parakeet.frontend.normalizer.numbers import normalize_numbers -__all__ = ["Phonetics", "English", "Chinese"] +__all__ = ["Phonetics", "English", "EnglishCharacter", "Chinese"] class Phonetics(ABC): @abstractmethod def __call__(self, sentence): pass - + @abstractmethod def phoneticize(self, sentence): pass - + @abstractmethod def numericalize(self, phonemes): pass + class English(Phonetics): def __init__(self): self.backend = G2p() self.phonemes = list(self.backend.phonemes) self.punctuations = get_punctuations("en") self.vocab = Vocab(self.phonemes + self.punctuations) - + def phoneticize(self, sentence): start = self.vocab.start_symbol end = self.vocab.end_symbol @@ -36,17 +55,67 @@ class English(Phonetics): + self.backend(sentence) \ + ([] if end is None else [end]) return phonemes - + def numericalize(self, phonemes): - ids = [self.vocab.lookup(item) for item in phonemes if item in self.vocab.stoi] + ids = [ + self.vocab.lookup(item) for item in phonemes + if item in self.vocab.stoi + ] return ids - + def reverse(self, ids): return [self.vocab.reverse(i) for i in ids] - + def __call__(self, sentence): return self.numericalize(self.phoneticize(sentence)) - + + @property + def vocab_size(self): + return len(self.vocab) + + +class EnglishCharacter(Phonetics): + def __init__(self): + self.backend = G2p() + self.phonemes = list(self.backend.graphemes) + self.punctuations = get_punctuations("en") + self.vocab = Vocab(self.phonemes + self.punctuations) + + def _prepocessing(self, text): + # preprocessing + text = unicode(text) + text = normalize_numbers(text) + text = ''.join( + char for char in unicodedata.normalize('NFD', text) + if unicodedata.category(char) != 'Mn') # Strip accents + text = text.lower() + text = re.sub(r"[^ a-z'.,?!\-]", "", text) + text = text.replace("i.e.", "that is") + text = text.replace("e.g.", "for example") + return text + + def phoneticize(self, sentence): + start = self.vocab.start_symbol + end = self.vocab.end_symbol + + chars = ([] if start is None else [start]) \ + + _prepocessing(sentence) \ + + ([] if end is None else [end]) + return chars + + def numericalize(self, chars): + ids = [ + self.vocab.lookup(item) for item in chars + if item in self.vocab.stoi + ] + return ids + + def reverse(self, ids): + return [self.vocab.reverse(i) for i in ids] + + def __call__(self, sentence): + return self.numericalize(self.phoneticize(sentence)) + @property def vocab_size(self): return len(self.vocab) @@ -59,9 +128,11 @@ class Chinese(Phonetics): self.phonemes = self._get_all_syllables() self.punctuations = get_punctuations("cn") self.vocab = Vocab(self.phonemes + self.punctuations) - + def _get_all_syllables(self): - all_syllables = set([syllable for k, v in self.backend.cedict.items() for syllable in v]) + all_syllables = set([ + syllable for k, v in self.backend.cedict.items() for syllable in v + ]) return list(all_syllables) def phoneticize(self, sentence): @@ -73,7 +144,7 @@ class Chinese(Phonetics): + phonemes \ + ([] if end is None else [end]) return self._filter_symbols(phonemes) - + def _filter_symbols(self, phonemes): cleaned_phonemes = [] for item in phonemes: @@ -84,17 +155,17 @@ class Chinese(Phonetics): if char in self.vocab.stoi: cleaned_phonemes.append(char) return cleaned_phonemes - + def numericalize(self, phonemes): ids = [self.vocab.lookup(item) for item in phonemes] return ids - + def __call__(self, sentence): return self.numericalize(self.phoneticize(sentence)) - + @property def vocab_size(self): return len(self.vocab) - + def reverse(self, ids): return [self.vocab.reverse(i) for i in ids] diff --git a/parakeet/models/tacotron2.py b/parakeet/models/tacotron2.py new file mode 100644 index 0000000..194e068 --- /dev/null +++ b/parakeet/models/tacotron2.py @@ -0,0 +1,424 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import paddle +from paddle import nn +from paddle.nn import functional as F +from parakeet.modules.conv import Conv1dBatchNorm +from parakeet.modules.attention import LocationSensitiveAttention +from parakeet.modules import masking + +__all__ = ["Tacotron2", "Tacotron2Loss"] + + +class DecoderPreNet(nn.Layer): + def __init__(self, + d_input: int, + d_hidden: int, + d_output: int, + dropout_rate: int=0.2): + super().__init__() + + self.linear1 = nn.Linear(d_input, d_hidden, bias_attr=False) + self.linear2 = nn.Linear(d_hidden, d_output, bias_attr=False) + + def forward(self, x): + x = F.dropout(F.relu(self.linear1(x)), self.dropout_rate) + output = F.dropout(F.relu(self.linear2(x)), self.dropout_rate) + return output + + +class DecoderPostNet(nn.Layer): + def __init__(self, + d_mels: int=80, + d_hidden: int=512, + kernel_size: int=5, + padding: int=0, + num_layers: int=5, + dropout=0.1): + super().__init__() + self.dropout = dropout + + self.conv_batchnorms = nn.LayerList() + k = math.sqrt(1.0 / (d_mels * kernel_size)) + self.conv_batchnorms.append( + Conv1dBatchNorm( + d_mels, + d_hidden, + kernel_size=kernel_size, + padding=padding, + bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform( + low=-k, high=k)), + data_format='NLC')) + + k = math.sqrt(1.0 / (d_hidden * kernel_size)) + self.conv_batchnorms.extend([ + Conv1dBatchNorm( + d_hidden, + d_hidden, + kernel_size=kernel_size, + padding=padding, + bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform( + low=-k, high=k)), + data_format='NLC') for i in range(1, num_layers - 1) + ]) + + self.conv_batchnorms.append( + Conv1dBatchNorm( + d_hidden, + d_mels, + kernel_size=kernel_size, + padding=padding, + bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform( + low=-k, high=k)), + data_format='NLC')) + + def forward(self, input): + for i in range(len(self.conv_batchnorms) - 1): + input = F.dropout( + F.tanh(self.conv_batchnorms[i](input), self.dropout)) + input = F.dropout(self.conv_batchnorms[-1](input), self.dropout) + return input + + +class Tacotron2Encoder(nn.Layer): + def __init__(self, + d_hidden: int, + conv_layers: int, + kernel_size: int, + p_dropout: float): + super().__init__() + + k = math.sqrt(1.0 / (d_hidden * kernel_size)) + self.conv_batchnorms = paddle.nn.LayerList([ + Conv1dBatchNorm( + d_hidden, + d_hidden, + kernel_size, + stride=1, + padding=int((kernel_size - 1) / 2), + bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform( + low=-k, high=k)), + data_format='NLC') for i in range(conv_layers) + ]) + self.p_dropout = p_dropout + + self.hidden_size = int(d_hidden / 2) + self.lstm = nn.LSTM( + d_hidden, self.hidden_size, direction="bidirectional") + + def forward(self, x, input_lens=None): + for conv_batchnorm in conv_batchnorms: + x = F.dropout(F.relu(conv_batchnorm(x)), + self.p_dropout) #(B, T, C) + + output, _ = self.lstm(inputs=x, sequence_length=input_lens) + return output + + +class Tacotron2Decoder(nn.Layer): + def __init__(self, + d_mels: int, + reduction_factor: int, + d_encoder: int, + d_prenet: int, + d_attention_rnn: int, + d_decoder_rnn: int, + d_attention: int, + attention_filters: int, + attention_kernel_size: int, + p_prenet_dropout: float, + p_attention_dropout: float, + p_decoder_dropout: float): + super().__init__() + self.d_mels = d_mels + self.reduction_factor = reduction_factor + self.d_encoder = d_encoder + self.d_attention_rnn = d_attention_rnn + self.d_decoder_rnn = d_decoder_rnn + self.p_attention_dropout = p_attention_dropout + self.p_decoder_dropout = p_decoder_dropout + + self.prenet = DecoderPreNet( + d_mels * reduction_factor, + d_prenet, + d_prenet, + dropout_rate=p_prenet_dropout) + + self.attention_rnn = nn.LSTMCell(d_prenet + d_encoder, d_attention_rnn) + + self.attention_layer = LocationSensitiveAttention( + d_attention_rnn, d_encoder, d_attention, attention_filters, + attention_kernel_size) + self.decoder_rnn = nn.LSTMCell(d_attention_rnn + d_encoder, + d_decoder_rnn) + self.linear_projection = nn.Linear(d_decoder_rnn + d_encoder, + d_mels * reduction_factor) + self.stop_layer = nn.Linear(d_decoder_rnn + d_encoder, 1) + + def _initialize_decoder_states(self, key): + batch_size = key.shape[0] + MAX_TIME = key.shape[1] + + self.attention_hidden = paddle.zeros( + shape=[batch_size, self.d_attention_rnn], dtype=key.dtype) + self.attention_cell = paddle.zeros( + shape=[batch_size, self.d_attention_rnn], dtype=key.dtype) + + self.decoder_hidden = paddle.zeros( + shape=[batch_size, self.d_decoder_rnn], dtype=key.dtype) + self.decoder_cell = paddle.zeros( + shape=[batch_size, self.d_decoder_rnn], dtype=key.dtype) + + self.attention_weights = paddle.zeros( + shape=[batch_size, MAX_TIME], dtype=key.dtype) + self.attention_weights_cum = paddle.zeros( + shape=[batch_size, MAX_TIME], dtype=key.dtype) + self.attention_context = paddle.zeros( + shape=[batch_size, self.d_encoder], dtype=key.dtype) + + self.key = key #[B, T, C] + self.processed_key = self.attention_layer.key_layer(key) #[B, T, C] + + def _decode(self, query): + cell_input = paddle.concat([query, self.attention_context], axis=-1) + + # The first lstm layer + _, (self.attention_hidden, self.attention_cell) = self.attention_rnn( + cell_input, (self.attention_hidden, self.attention_cell)) + self.attention_hidden = F.dropout(self.attention_hidden, + self.p_attention_dropout) + + # Loaction sensitive attention + attention_weights_cat = paddle.stack( + [self.attention_weights, self.attention_weights_cum], axis=-1) + self.attention_context, self.attention_weights = self.attention_layer( + self.attention_hidden, self.processed_key, self.key, + attention_weights_cat, self.mask) + self.attention_weights_cum += self.attention_weights + + # The second lasm layer + decoder_input = paddle.concat( + [self.attention_hidden, self.attention_context], axis=-1) + _, (self.decoder_hidden, self.decoder_cell) = self.decoder_rnn( + decoder_input, (self.decoder_hidden, self.decoder_cell)) + self.decoder_hidden = F.dropout( + self.decoder_hidden, p=self.p_decoder_dropout) + + # decode output one step + decoder_hidden_attention_context = paddle.concat( + [self.decoder_hidden, self.attention_context], axis=-1) + decoder_output = self.linear_projection( + decoder_hidden_attention_context) + stop_logit = self.stop_layer(decoder_hidden_attention_context) + return decoder_output, stop_logit, self.attention_weights + + def forward(self, key, query, mask): + query = paddle.reshape( + query, + [query.shape[0], query.shape[1] // self.reduction_factor, -1]) + query = paddle.concat( + [ + paddle.zeros( + shape=[ + query.shape[0], 1, + query.shape[-1] * self.reduction_factor + ], + dtype=query.dtype), query + ], + axis=1) + query = self.prenet(query) + + self._initialize_decoder_states(key) + self.mask = mask + + mel_outputs, stop_logits, alignments = [], [], [] + while len(mel_outputs) < query.shape[ + 1] - 1: # Ignore the last time step + query = query[:, len(mel_outputs), :] + mel_output, stop_logit, attention_weights = self._decode(query) + mel_outputs += [mel_output] + stop_logits += [stop_logit] + alignments += [attention_weights] + + alignments = paddle.stack(alignments, axis=1) + stop_logits = paddle.concat(stop_logits, axis=1) + mel_outputs = paddle.stack(mel_outputs, axis=1) + + return mel_outputs, stop_logits, alignments + + def infer(self, key, stop_threshold=0.5, max_decoder_steps=1000): + decoder_input = paddle.zeros( + shape=[key.shape[0], self.d_mels * self.reduction_factor], + dtype=key.dtype) #[B, C] + + self.initialize_decoder_states(key) + self.mask = None + + mel_outputs, stop_logits, alignments = [], [], [] + while True: + decoder_input = self.prenet(decoder_input) + mel_output, stop_logit, alignment = self.decode(decoder_input) + + mel_outputs += [mel_output] + stop_logits += [stop_logit] + alignments += [alignment] + + if F.sigmoid(stop_logit) > stop_threshold: + break + elif len(mel_outputs) == max_decoder_steps: + print("Warning! Reached max decoder steps!!!") + break + + decoder_input = mel_output + + alignments = paddle.stack(alignments, axis=1) + stop_logits = paddle.concat(stop_logits, axis=1) + mel_outputs = paddle.stack(mel_outputs, axis=1) + + return mel_outputs, stop_logits, alignments + + +class Tacotron2(nn.Layer): + """ + Tacotron2 module for end-to-end text-to-speech (E2E-TTS). + + This is a module of Spectrogram prediction network in Tacotron2 described + in `Natural TTS Synthesis + by Conditioning WaveNet on Mel Spectrogram Predictions`_, + which converts the sequence of characters + into the sequence of mel spectrogram. + + .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`: + https://arxiv.org/abs/1712.05884 + """ + + def __init__(self, + frontend: parakeet.frontend.Phonetics, + d_mels: int=80, + d_embedding: int=512, + encoder_conv_layers: int=3, + d_encoder: int=512, + encoder_kernel_size: int=5, + d_prenet: int=256, + d_attention_rnn: int=1024, + d_decoder_rnn: int=1024, + attention_filters: int=32, + attention_kernel_size: int=31, + d_attention: int=128, + d_postnet: int=512, + postnet_kernel_size: int=5, + postnet_conv_layers: int=5, + reduction_factor: int=1, + p_encoder_dropout: float=0.5, + p_prenet_dropout: float=0.5, + p_attention_dropout: float=0.1, + p_decoder_dropout: float=0.1, + p_postnet_dropout: float=0.5): + super().__init__() + + std = math.sqrt(2.0 / (frontend.vocab_size + d_embedding)) + val = math.sqrt(3.0) * std # uniform bounds for std + self.embedding = nn.Embedding( + frontend.vocab_size, + d_embedding, + weight_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform( + low=-val, high=val))) + self.encoder = Tacotron2Encoder(d_encoder, encoder_conv_layers, + encoder_kernel_size, p_encoder_dropout) + self.decoder = Tacotron2Decoder( + d_mels, reduction_factor, d_encoder, d_prenet, d_attention_rnn, + d_decoder_rnn, d_attention, attention_filters, + attention_kernel_size, p_prenet_dropout, p_attention_dropout, + p_decoder_dropout) + self.postnet = DecoderPostNet( + d_mels=d_mels, + d_hidden=d_postnet, + kernel_size=postnet_kernel_size, + padding=int((postnet_kernel_size - 1) / 2), + num_layers=postnet_conv_layers, + dropout=p_postnet_dropout) + + def forward(self, text_inputs, mels, text_lens, output_lens=None): + embedded_inputs = self.embedding(text_inputs) + encoder_outputs = self.encoder(embedded_inputs, text_lens) + + mask = paddle.tensor.unsqueeze( + paddle.fluid.layers.sequence_mask( + x=text_lens, dtype=encoder_outputs.dtype), [-1]) + mel_outputs, stop_logits, alignments = self.decoder( + encoder_outputs, mels, mask=mask) + + mel_outputs_postnet = self.postnet(mel_outputs) + mel_outputs_postnet = mel_outputs + mel_outputs_postnet + + if output_lens is not None: + mask = paddle.tensor.unsqueeze( + paddle.fluid.layers.sequence_mask(x=output_lens), + [-1]) #[B, T, 1] + mel_outputs = mel_outputs * mask #[B, T, C] + mel_outputs_postnet = mel_outputs_postnet * mask #[B, T, C] + stop_logits = stop_logits * mask[:, :, 0] + (1 - mask[:, :, 0] + ) * 1e3 #[B, T] + outputs = { + "mel_output": mel_outputs, + "mel_outputs_postnet": mel_outputs_postnet, + "stop_logits": stop_logits, + "alignments": alignments + } + + return outputs + + def infer(self, text_inputs, stop_threshold=0.5, max_decoder_steps=1000): + embedded_inputs = self.embedding(text_inputs) + encoder_outputs = self.encoder(embedded_inputs) + mel_outputs, stop_logits, alignments = self.decoder.inference( + encoder_outputs, + stop_threshold=stop_threshold, + max_decoder_steps=max_decoder_steps) + + mel_outputs_postnet = self.postnet(mel_outputs) + mel_outputs_postnet = mel_outputs + mel_outputs_postnet + + outputs = { + "mel_output": mel_outputs, + "mel_outputs_postnet": mel_outputs_postnet, + "stop_logits": stop_logits, + "alignments": alignments + } + + return outputs + + def predict(self, text): + # TODO(lifuchen): implement predict function to product mel from texts + pass + + +class Tacotron2Loss(nn.Layer): + def __init__(self): + super().__init__() + + def forward(self, mel_outputs, mel_outputs_postnet, stop_logits, + mel_targets, stop_tokens): + mel_loss = paddle.nn.MSELoss()(mel_outputs, mel_targets) + post_mel_loss = paddle.nn.MSELoss()(mel_outputs_postnet, mel_targets) + stop_loss = paddle.nn.BCEWithLogitsLoss()(stop_logits, stop_tokens) + total_loss = mel_loss + post_mel_loss + stop_loss + losses = dict( + loss=total_loss, + mel_loss=mel_loss, + post_mel_loss=post_mel_loss, + stop_loss=stop_loss) + return losses