From 9fe6ad11f0dad288b7a312c43fe2a94d035bfb6a Mon Sep 17 00:00:00 2001 From: lifuchen Date: Tue, 17 Dec 2019 06:23:34 +0000 Subject: [PATCH] Training with multi-GPU --- parakeet/audio/__init__.py | 1 + parakeet/audio/audio.py | 261 ++++++++++++++++++ parakeet/data/datacargo.py | 17 +- .../transformerTTS/config/train_postnet.yaml | 2 +- .../config/train_transformer.yaml | 2 +- parakeet/models/transformerTTS/data.py | 29 ++ parakeet/models/transformerTTS/module.py | 5 +- parakeet/models/transformerTTS/network.py | 18 +- .../models/transformerTTS/train_postnet.py | 86 ++---- .../transformerTTS/train_transformer.py | 122 +++----- 10 files changed, 393 insertions(+), 150 deletions(-) create mode 100644 parakeet/audio/__init__.py create mode 100644 parakeet/audio/audio.py create mode 100644 parakeet/models/transformerTTS/data.py diff --git a/parakeet/audio/__init__.py b/parakeet/audio/__init__.py new file mode 100644 index 0000000..6212dee --- /dev/null +++ b/parakeet/audio/__init__.py @@ -0,0 +1 @@ +from .audio import AudioProcessor \ No newline at end of file diff --git a/parakeet/audio/audio.py b/parakeet/audio/audio.py new file mode 100644 index 0000000..b29dbf2 --- /dev/null +++ b/parakeet/audio/audio.py @@ -0,0 +1,261 @@ +import librosa +import soundfile as sf +import numpy as np +import scipy.io +import scipy.signal + +class AudioProcessor(object): + def __init__(self, + sample_rate=None, # int, sampling rate + num_mels=None, # int, bands of mel spectrogram + min_level_db=None, # float, minimum level db + ref_level_db=None, # float, reference level dbn + n_fft=None, # int: number of samples in a frame for stft + win_length=None, # int: the same meaning with n_fft + hop_length=None, # int: number of samples between neighboring frame + power=None, # float:power to raise before griffin-lim + preemphasis=None, # float: preemphasis coefficident + signal_norm=None, # + symmetric_norm=False, # bool, apply clip norm in [-max_norm, max_form] + max_norm=None, # float, max norm + mel_fmin=None, # int: mel spectrogram's minimum frequency + mel_fmax=None, # int: mel spectrogram's maximum frequency + clip_norm=True, # bool: clip spectrogram's norm + griffin_lim_iters=None, # int: + do_trim_silence=False, # bool: trim silience + sound_norm=False, + **kwargs): + self.sample_rate = sample_rate + self.num_mels = num_mels + self.min_level_db = min_level_db + self.ref_level_db = ref_level_db + + # stft related + self.n_fft = n_fft + self.win_length = win_length or n_fft + # hop length defaults to 1/4 window_length + self.hop_length = hop_length or 0.25 * self.win_length + + self.power = power + self.preemphasis = float(preemphasis) + + self.griffin_lim_iters = griffin_lim_iters + self.signal_norm = signal_norm + self.symmetric_norm = symmetric_norm + + # mel transform related + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + + self.max_norm = 1.0 if max_norm is None else float(max_norm) + self.clip_norm = clip_norm + self.do_trim_silence = do_trim_silence + + self.sound_norm = sound_norm + self.num_freq, self.frame_length_ms, self.frame_shift_ms = self._stft_parameters() + + def _stft_parameters(self): + """compute frame length and hop length in ms""" + frame_length_ms = self.win_length * 1. / self.sample_rate + frame_shift_ms = self.hop_length * 1. / self.sample_rate + num_freq = 1 + self.n_fft // 2 + return num_freq, frame_length_ms, frame_shift_ms + + def __repr__(self): + """object repr""" + cls_name_str = self.__class__.__name__ + members = vars(self) + dict_str = "\n".join([" {}: {},".format(k, v) for k, v in members.items()]) + repr_str = "{}(\n{})\n".format(cls_name_str, dict_str) + return repr_str + + def save_wav(self, path, wav): + """save audio with scipy.io.wavfile in 16bit integers""" + wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav)))) + scipy.io.wavfile.write(path, self.sample_rate, wav_norm.as_type(np.int16)) + + def load_wav(self, path, sr=None): + """load wav -> trim_silence -> rescale""" + + x, sr = librosa.load(path, sr=None) + assert self.sample_rate == sr, "audio sample rate: {}Hz != processor sample rate: {}Hz".format(sr, self.sample_rate) + if self.do_trim_silence: + try: + x = self.trim_silence(x) + except ValueError: + print(" [!] File cannot be trimmed for silence - {}".format(path)) + if self.sound_norm: + x = x / x.max() * 0.9 # why 0.9 ? + return x + + def trim_silence(self, wav): + """Trim soilent parts with a threshold and 0.01s margin""" + margin = int(self.sample_rate * 0.01) + wav = wav[margin: -margin] + trimed_wav = librosa.effects.trim(wav, top_db=60, frame_length=self.win_length, hop_length=self.hop_length)[0] + return trimed_wav + + def apply_preemphasis(self, x): + if self.preemphasis == 0.: + raise RuntimeError(" !! Preemphasis coefficient should be positive. ") + return scipy.signal.lfilter([1., -self.preemphasis], [1.], x) + + def apply_inv_preemphasis(self, x): + if self.preemphasis == 0.: + raise RuntimeError(" !! Preemphasis coefficient should be positive. ") + return scipy.signal.lfilter([1.], [1., -self.preemphasis], x) + + def _amplitude_to_db(self, x): + amplitude_min = np.exp(self.min_level_db / 20 * np.log(10)) + return 20 * np.log10(np.maximum(amplitude_min, x)) + + @staticmethod + def _db_to_amplitude(x): + return np.power(10., 0.05 * x) + + def _linear_to_mel(self, spectrogram): + _mel_basis = self._build_mel_basis() + return np.dot(_mel_basis, spectrogram) + + def _mel_to_linear(self, mel_spectrogram): + inv_mel_basis = np.linalg.pinv(self._build_mel_basis()) + return np.maximum(1e-10, np.dot(inv_mel_basis, mel_spectrogram)) + + def _build_mel_basis(self): + """return mel basis for mel scale""" + if self.mel_fmax is not None: + assert self.mel_fmax <= self.sample_rate // 2 + return librosa.filters.mel( + self.sample_rate, + self.n_fft, + n_mels=self.num_mels, + fmin=self.mel_fmin, + fmax=self.mel_fmax) + + def _normalize(self, S): + """put values in [0, self.max_norm] or [-self.max_norm, self,max_norm]""" + if self.signal_norm: + S_norm = (S - self.min_level_db) / (-self.min_level_db) + if self.symmetric_norm: + S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm + if self.clip_norm: + S_norm = np.clip(S_norm, -self.max_norm, self.max_norm) + return S_norm + else: + S_norm = self.max_norm * S_norm + if self.clip_norm: + S_norm = np.clip(S_norm, 0, self.max_norm) + return S_norm + else: + return S + + def _denormalize(self, S): + """denormalize values""" + S_denorm = S + if self.signal_norm: + if self.symmetric_norm: + if self.clip_norm: + S_denorm = np.clip(S_denorm, -self.max_norm, self.max_norm) + S_denorm = (S_denorm + self.max_norm) * (-self.min_level_db) / (2 * self.max_norm) + self.min_level_db + return S_denorm + else: + if self.clip_norm: + S_denorm = np.clip(S_denorm, 0, self.max_norm) + S_denorm = S_denorm * (-self.min_level_db)/ self.max_norm + self.min_level_db + return S_denorm + else: + return S + + def _stft(self, y): + return librosa.stft( + y=y, + n_fft=self.n_fft, + win_length=self.win_length, + hop_length=self.hop_length) + + def _istft(self, S): + return librosa.istft(S, hop_length=self.hop_length, win_length=self.win_length) + + def spectrogram(self, y): + """compute linear spectrogram(amplitude) + preemphasis -> stft -> mag -> amplitude_to_db -> minus_ref_level_db -> normalize + """ + if self.preemphasis: + D = self._stft(self.apply_preemphasis(y)) + else: + D = self._stft(y) + S = self._amplitude_to_db(np.abs(D)) - self.ref_level_db + return self._normalize(S) + + def melspectrogram(self, y): + """compute linear spectrogram(amplitude) + preemphasis -> stft -> mag -> mel_scale -> amplitude_to_db -> minus_ref_level_db -> normalize + """ + if self.preemphasis: + D = self._stft(self.apply_preemphasis(y)) + else: + D = self._stft(y) + S = self._amplitude_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db + return self._normalize(S) + + def inv_spectrogram(self, spectrogram): + """convert spectrogram back to waveform using griffin_lim in librosa""" + S = self._denormalize(spectrogram) + S = self._db_to_amplitude(S + self.ref_level_db) + if self.preemphasis: + return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power)) + return self._griffin_lim(S ** self.power) + + def inv_melspectrogram(self, mel_spectrogram): + S = self._denormalize(mel_spectrogram) + S = self._db_to_amplitude(S + self.ref_level_db) + S = self._linear_to_mel(np.abs(S)) + if self.preemphasis: + return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power)) + return self._griffin_lim(S ** self.power) + + def out_linear_to_mel(self, linear_spec): + """convert output linear spec to mel spec""" + S = self._denormalize(linear_spec) + S = self._db_to_amplitude(S + self.ref_level_db) + S = self._linear_to_mel(np.abs(S)) + S = self._amplitude_to_db(S) - self.ref_level_db + mel = self._normalize(S) + return mel + + def _griffin_lim(self, S): + angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) + S_complex = np.abs(S).astype(np.complex) + y = self._istft(S_complex * angles) + for _ in range(self.griffin_lim_iters): + angles = np.exp(1j * np.angle(self._stft(y))) + y = self._istft(S_complex * angles) + return y + + @staticmethod + def mulaw_encode(wav, qc): + mu = 2 ** qc - 1 + # wav_abs = np.minimum(np.abs(wav), 1.0) + signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1. + mu) + # Quantize signal to the specified number of levels. + signal = (signal + 1) / 2 * mu + 0.5 + return np.floor(signal,) + + @staticmethod + def mulaw_decode(wav, qc): + """Recovers waveform from quantized values.""" + mu = 2 ** qc - 1 + x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1) + return x + + @staticmethod + def encode_16bits(x): + return np.clip(x * 2**15, -2**15, 2**15 - 1).astype(np.int16) + + @staticmethod + def quantize(x, bits): + return (x + 1.) * (2**bits - 1) / 2 + + @staticmethod + def dequantize(x, bits): + return 2 * x / (2**bits - 1) - 1 diff --git a/parakeet/data/datacargo.py b/parakeet/data/datacargo.py index 1d7d8d5..e087a4f 100644 --- a/parakeet/data/datacargo.py +++ b/parakeet/data/datacargo.py @@ -2,7 +2,8 @@ from .sampler import SequentialSampler, RandomSampler, BatchSampler class DataCargo(object): def __init__(self, dataset, batch_size=1, sampler=None, - shuffle=False, batch_sampler=None, drop_last=False): + shuffle=False, batch_sampler=None, collate_fn=None, + drop_last=False): self.dataset = dataset if batch_sampler is not None: @@ -21,13 +22,20 @@ class DataCargo(object): sampler = RandomSampler(dataset) else: sampler = SequentialSampler(dataset) - # auto_collation without custom batch_sampler batch_sampler = BatchSampler(sampler, batch_size, drop_last) + else: + batch_sampler = BatchSampler(sampler, batch_size, drop_last) + + self.batch_sampler = batch_sampler + + if collate_fn is None: + collate_fn = dataset._batch_examples + self.collate_fn = collate_fn self.batch_size = batch_size self.drop_last = drop_last self.sampler = sampler - self.batch_sampler = batch_sampler + def __iter__(self): return DataIterator(self) @@ -57,6 +65,7 @@ class DataIterator(object): self._index_sampler = loader._index_sampler self._sampler_iter = iter(self._index_sampler) + self.collate_fn = loader.collate_fn def __iter__(self): return self @@ -64,7 +73,7 @@ class DataIterator(object): def __next__(self): index = self._next_index() # may raise StopIteration, TODO(chenfeiyu): use dynamic batch size minibatch = [self._dataset[i] for i in index] # we can abstract it, too to use dynamic batch size - minibatch = self._dataset._batch_examples(minibatch) # list[Example] -> Batch + minibatch = self.collate_fn(minibatch) return minibatch def _next_index(self): diff --git a/parakeet/models/transformerTTS/config/train_postnet.yaml b/parakeet/models/transformerTTS/config/train_postnet.yaml index 90ac94e..5753ab1 100644 --- a/parakeet/models/transformerTTS/config/train_postnet.yaml +++ b/parakeet/models/transformerTTS/config/train_postnet.yaml @@ -20,7 +20,7 @@ epochs: 10000 lr: 0.001 save_step: 500 use_gpu: True -use_data_parallel: False +use_data_parallel: True data_path: ../../../dataset/LJSpeech-1.1 save_path: ./checkpoint diff --git a/parakeet/models/transformerTTS/config/train_transformer.yaml b/parakeet/models/transformerTTS/config/train_transformer.yaml index 17db190..3e56a4f 100644 --- a/parakeet/models/transformerTTS/config/train_transformer.yaml +++ b/parakeet/models/transformerTTS/config/train_transformer.yaml @@ -21,7 +21,7 @@ lr: 0.001 save_step: 500 image_step: 2000 use_gpu: True -use_data_parallel: False +use_data_parallel: True data_path: ../../../dataset/LJSpeech-1.1 save_path: ./checkpoint diff --git a/parakeet/models/transformerTTS/data.py b/parakeet/models/transformerTTS/data.py new file mode 100644 index 0000000..f432640 --- /dev/null +++ b/parakeet/models/transformerTTS/data.py @@ -0,0 +1,29 @@ +from pathlib import Path +import numpy as np +from paddle import fluid +from parakeet.data.sampler import DistributedSampler +from parakeet.data.datacargo import DataCargo +from preprocess import batch_examples, LJSpeech, batch_examples_postnet + +class LJSpeechLoader: + def __init__(self, config, nranks, rank, is_postnet=False): + place = fluid.CUDAPlace(rank) if config.use_gpu else fluid.CPUPlace() + + LJSPEECH_ROOT = Path(config.data_path) + dataset = LJSpeech(LJSPEECH_ROOT) + sampler = DistributedSampler(len(dataset), nranks, rank) + + assert config.batch_size % nranks == 0 + each_bs = config.batch_size // nranks + if is_postnet: + dataloader = DataCargo(dataset, sampler=sampler, batch_size=each_bs, shuffle=True, collate_fn=batch_examples_postnet, drop_last=True) + else: + dataloader = DataCargo(dataset, sampler=sampler, batch_size=each_bs, shuffle=True, collate_fn=batch_examples, drop_last=True) + + self.reader = fluid.io.DataLoader.from_generator( + capacity=32, + iterable=True, + use_double_buffer=True, + return_list=True) + self.reader.set_batch_generator(dataloader, place) + diff --git a/parakeet/models/transformerTTS/module.py b/parakeet/models/transformerTTS/module.py index 76bdffb..f83bff5 100644 --- a/parakeet/models/transformerTTS/module.py +++ b/parakeet/models/transformerTTS/module.py @@ -130,7 +130,7 @@ class EncoderPrenet(dg.Layer): self.projection = FC(self.full_name(), num_hidden, num_hidden) def forward(self, x): - x = self.embedding(fluid.layers.unsqueeze(x, axes=[-1])) #(batch_size, seq_len, embending_size) + x = self.embedding(x) #(batch_size, seq_len, embending_size) x = layers.transpose(x,[0,2,1]) x = layers.dropout(layers.relu(self.batch_norm1(self.conv1(x))), 0.2) x = layers.dropout(layers.relu(self.batch_norm2(self.conv2(x))), 0.2) @@ -211,8 +211,9 @@ class ScaledDotProductAttention(dg.Layer): # Mask key to ignore padding if mask is not None: attention = attention * mask - mask = (mask == 0).astype(float) * (-2 ** 32 + 1) + mask = (mask == 0).astype(np.float32) * (-2 ** 32 + 1) attention = attention + mask + attention = layers.softmax(attention) # Mask query to ignore padding diff --git a/parakeet/models/transformerTTS/network.py b/parakeet/models/transformerTTS/network.py index ff25ad2..3d356dc 100644 --- a/parakeet/models/transformerTTS/network.py +++ b/parakeet/models/transformerTTS/network.py @@ -7,9 +7,9 @@ class Encoder(dg.Layer): def __init__(self, name_scope, embedding_size, num_hidden, config): super(Encoder, self).__init__(name_scope) self.num_hidden = num_hidden - param = fluid.ParamAttr(name='alpha') - self.alpha = self.create_parameter(param, shape=(1, ), dtype='float32', - default_initializer = fluid.initializer.ConstantInitializer(value=1.0)) + param = fluid.ParamAttr(name='alpha', + initializer=fluid.initializer.Constant(value=1.0)) + self.alpha = self.create_parameter(param, shape=(1, ), dtype='float32') self.pos_inp = get_sinusoid_encoding_table(1024, self.num_hidden, padding_idx=0) self.pos_emb = dg.Embedding(name_scope=self.full_name(), size=[1024, num_hidden], @@ -31,8 +31,8 @@ class Encoder(dg.Layer): def forward(self, x, positional): if fluid.framework._dygraph_tracer()._train_mode: - query_mask = (positional != 0).astype(float) - mask = (positional != 0).astype(float) + query_mask = (positional != 0).astype(np.float32) + mask = (positional != 0).astype(np.float32) mask = fluid.layers.expand(fluid.layers.unsqueeze(mask,[1]), [1,x.shape[1], 1]) else: query_mask, mask = None, None @@ -42,7 +42,7 @@ class Encoder(dg.Layer): # Get positional encoding - positional = self.pos_emb(fluid.layers.unsqueeze(positional, axes=[-1])) + positional = self.pos_emb(positional) x = positional * self.alpha + x #(N, T, C) @@ -102,14 +102,14 @@ class Decoder(dg.Layer): if fluid.framework._dygraph_tracer()._train_mode: #zeros = np.zeros(positional.shape, dtype=np.float32) - m_mask = (positional != 0).astype(float) + m_mask = (positional != 0).astype(np.float32) mask = np.repeat(np.expand_dims(m_mask.numpy() == 0, axis=1), decoder_len, axis=1) mask = mask + np.repeat(np.expand_dims(np.triu(np.ones([decoder_len, decoder_len]), 1), axis=0) ,batch_size, axis=0) mask = fluid.layers.cast(dg.to_variable(mask == 0), np.float32) # (batch_size, decoder_len, decoder_len) - zero_mask = fluid.layers.expand(fluid.layers.unsqueeze((c_mask != 0).astype(float), axes=2), [1,1,decoder_len]) + zero_mask = fluid.layers.expand(fluid.layers.unsqueeze((c_mask != 0).astype(np.float32), axes=2), [1,1,decoder_len]) # (batch_size, decoder_len, seq_len) zero_mask = fluid.layers.transpose(zero_mask, [0,2,1]) @@ -125,7 +125,7 @@ class Decoder(dg.Layer): query = self.linear(query) # Get position embedding - positional = self.pos_emb(fluid.layers.unsqueeze(positional, axes=[-1])) + positional = self.pos_emb(positional) query = positional * self.alpha + query #positional dropout diff --git a/parakeet/models/transformerTTS/train_postnet.py b/parakeet/models/transformerTTS/train_postnet.py index 6e32f9c..8beeece 100644 --- a/parakeet/models/transformerTTS/train_postnet.py +++ b/parakeet/models/transformerTTS/train_postnet.py @@ -1,13 +1,12 @@ from network import * -from preprocess import batch_examples_postnet, LJSpeech from tensorboardX import SummaryWriter import os from tqdm import tqdm -from parakeet.data.datacargo import DataCargo from pathlib import Path import jsonargparse from parse import add_config_options_to_parser from pprint import pprint +from data import LJSpeechLoader class MyDataParallel(dg.parallel.DataParallel): """ @@ -27,21 +26,15 @@ class MyDataParallel(dg.parallel.DataParallel): object.__getattribute__(self, "_sub_layers")["_layers"], key) -def main(): - parser = jsonargparse.ArgumentParser(description="Train postnet model", formatter_class='default_argparse') - add_config_options_to_parser(parser) - cfg = parser.parse_args('-c ./config/train_postnet.yaml'.split()) +def main(cfg): + + local_rank = dg.parallel.Env().local_rank if cfg.use_data_parallel else 0 + nranks = dg.parallel.Env().nranks if cfg.use_data_parallel else 1 - local_rank = dg.parallel.Env().local_rank - if local_rank == 0: # Print the whole config setting. pprint(jsonargparse.namespace_to_dict(cfg)) - LJSPEECH_ROOT = Path(cfg.data_path) - dataset = LJSpeech(LJSPEECH_ROOT) - dataloader = DataCargo(dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=batch_examples_postnet, drop_last=True) - global_step = 0 place = (fluid.CUDAPlace(dg.parallel.Env().dev_id) if cfg.use_data_parallel else fluid.CUDAPlace(0) @@ -50,35 +43,10 @@ def main(): if not os.path.exists(cfg.log_dir): os.mkdir(cfg.log_dir) path = os.path.join(cfg.log_dir,'postnet') - writer = SummaryWriter(path) - with dg.guard(place): - # dataloader - input_fields = { - 'names': ['mel', 'mag'], - 'shapes': - [[cfg.batch_size, None, 80], [cfg.batch_size, None, 257]], - 'dtypes': ['float32', 'float32'], - 'lod_levels': [0, 0] - } + writer = SummaryWriter(path) if local_rank == 0 else None - inputs = [ - fluid.data( - name=input_fields['names'][i], - shape=input_fields['shapes'][i], - dtype=input_fields['dtypes'][i], - lod_level=input_fields['lod_levels'][i]) - for i in range(len(input_fields['names'])) - ] - - reader = fluid.io.DataLoader.from_generator( - feed_list=inputs, - capacity=32, - iterable=True, - use_double_buffer=True, - return_list=True) - - + with dg.guard(place): model = ModelPostNet('postnet', cfg) model.train() @@ -94,9 +62,10 @@ def main(): strategy = dg.parallel.prepare_context() model = MyDataParallel(model, strategy) + reader = LJSpeechLoader(cfg, nranks, local_rank, is_postnet=True).reader() + for epoch in range(cfg.epochs): - reader.set_batch_generator(dataloader, place) - pbar = tqdm(reader()) + pbar = tqdm(reader) for i, data in enumerate(pbar): pbar.set_description('Processing at epoch %d'%epoch) mel, mag = data @@ -109,27 +78,30 @@ def main(): loss = layers.mean(layers.abs(layers.elementwise_sub(mag_pred, mag))) if cfg.use_data_parallel: loss = model.scale_loss(loss) - - writer.add_scalars('training_loss',{ - 'loss':loss.numpy(), - }, global_step) - - loss.backward() - if cfg.use_data_parallel: + loss.backward() model.apply_collective_grads() + else: + loss.backward() optimizer.minimize(loss, grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(1)) model.clear_gradients() - - if global_step % cfg.save_step == 0: - if not os.path.exists(cfg.save_path): - os.mkdir(cfg.save_path) - save_path = os.path.join(cfg.save_path,'postnet/%d' % global_step) - dg.save_dygraph(model.state_dict(), save_path) - dg.save_dygraph(optimizer.state_dict(), save_path) - + if local_rank==0: + writer.add_scalars('training_loss',{ + 'loss':loss.numpy(), + }, global_step) + if global_step % cfg.save_step == 0: + if not os.path.exists(cfg.save_path): + os.mkdir(cfg.save_path) + save_path = os.path.join(cfg.save_path,'postnet/%d' % global_step) + dg.save_dygraph(model.state_dict(), save_path) + dg.save_dygraph(optimizer.state_dict(), save_path) + if local_rank==0: + writer.close() if __name__ == '__main__': - main() \ No newline at end of file + parser = jsonargparse.ArgumentParser(description="Train postnet model", formatter_class='default_argparse') + add_config_options_to_parser(parser) + cfg = parser.parse_args('-c ./config/train_postnet.yaml'.split()) + main(cfg) \ No newline at end of file diff --git a/parakeet/models/transformerTTS/train_transformer.py b/parakeet/models/transformerTTS/train_transformer.py index 0cdbf37..065be6d 100644 --- a/parakeet/models/transformerTTS/train_transformer.py +++ b/parakeet/models/transformerTTS/train_transformer.py @@ -1,16 +1,15 @@ -from preprocess import batch_examples, LJSpeech import os from tqdm import tqdm import paddle.fluid.dygraph as dg import paddle.fluid.layers as layers from network import * from tensorboardX import SummaryWriter -from parakeet.data.datacargo import DataCargo from pathlib import Path import jsonargparse from parse import add_config_options_to_parser from pprint import pprint from matplotlib import cm +from data import LJSpeechLoader class MyDataParallel(dg.parallel.DataParallel): """ @@ -30,21 +29,14 @@ class MyDataParallel(dg.parallel.DataParallel): object.__getattribute__(self, "_sub_layers")["_layers"], key) -def main(): - parser = jsonargparse.ArgumentParser(description="Train TransformerTTS model", formatter_class='default_argparse') - add_config_options_to_parser(parser) - cfg = parser.parse_args('-c ./config/train_transformer.yaml'.split()) - - local_rank = dg.parallel.Env().local_rank +def main(cfg): + local_rank = dg.parallel.Env().local_rank if cfg.use_data_parallel else 0 + nranks = dg.parallel.Env().nranks if cfg.use_data_parallel else 1 if local_rank == 0: # Print the whole config setting. pprint(jsonargparse.namespace_to_dict(cfg)) - - LJSPEECH_ROOT = Path(cfg.data_path) - dataset = LJSpeech(LJSPEECH_ROOT) - dataloader = DataCargo(dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=batch_examples, drop_last=True) global_step = 0 place = (fluid.CUDAPlace(dg.parallel.Env().dev_id) if cfg.use_data_parallel else fluid.CUDAPlace(0) @@ -57,39 +49,13 @@ def main(): writer = SummaryWriter(path) if local_rank == 0 else None with dg.guard(place): - if cfg.use_data_parallel: - strategy = dg.parallel.prepare_context() - - # dataloader - input_fields = { - 'names': ['character', 'mel', 'mel_input', 'pos_text', 'pos_mel', 'text_len'], - 'shapes': - [[cfg.batch_size, None], [cfg.batch_size, None, 80], [cfg.batch_size, None, 80], [cfg.batch_size, 1], [cfg.batch_size, 1], [cfg.batch_size, 1]], - 'dtypes': ['float32', 'float32', 'float32', 'int64', 'int64', 'int64'], - 'lod_levels': [0, 0, 0, 0, 0, 0] - } - - inputs = [ - fluid.data( - name=input_fields['names'][i], - shape=input_fields['shapes'][i], - dtype=input_fields['dtypes'][i], - lod_level=input_fields['lod_levels'][i]) - for i in range(len(input_fields['names'])) - ] - - reader = fluid.io.DataLoader.from_generator( - feed_list=inputs, - capacity=32, - iterable=True, - use_double_buffer=True, - return_list=True) - model = Model('transtts', cfg) model.train() optimizer = fluid.optimizer.AdamOptimizer(learning_rate=dg.NoamDecay(1/(4000 *( cfg.lr ** 2)), 4000)) - + + reader = LJSpeechLoader(cfg, nranks, local_rank).reader() + if cfg.checkpoint_path is not None: model_dict, opti_dict = fluid.dygraph.load_dygraph(cfg.checkpoint_path) model.set_dict(model_dict) @@ -97,11 +63,11 @@ def main(): print("load checkpoint!!!") if cfg.use_data_parallel: + strategy = dg.parallel.prepare_context() model = MyDataParallel(model, strategy) - + for epoch in range(cfg.epochs): - reader.set_batch_generator(dataloader, place) - pbar = tqdm(reader()) + pbar = tqdm(reader) for i, data in enumerate(pbar): pbar.set_description('Processing at epoch %d'%epoch) character, mel, mel_input, pos_text, pos_mel, text_length = data @@ -114,40 +80,41 @@ def main(): post_mel_loss = layers.mean(layers.abs(layers.elementwise_sub(postnet_pred, mel))) loss = mel_loss + post_mel_loss + if local_rank==0: + writer.add_scalars('training_loss', { + 'mel_loss':mel_loss.numpy(), + 'post_mel_loss':post_mel_loss.numpy(), + }, global_step) + + writer.add_scalars('alphas', { + 'encoder_alpha':model.encoder.alpha.numpy(), + 'decoder_alpha':model.decoder.alpha.numpy(), + }, global_step) + + writer.add_scalar('learning_rate', optimizer._learning_rate.step().numpy(), global_step) + + if global_step % cfg.image_step == 1: + for i, prob in enumerate(attn_probs): + for j in range(4): + x = np.uint8(cm.viridis(prob.numpy()[j*16]) * 255) + writer.add_image('Attention_enc_%d_0'%global_step, x, i*4+j, dataformats="HWC") + + for i, prob in enumerate(attn_enc): + for j in range(4): + x = np.uint8(cm.viridis(prob.numpy()[j*16]) * 255) + writer.add_image('Attention_enc_%d_0'%global_step, x, i*4+j, dataformats="HWC") + + for i, prob in enumerate(attn_dec): + for j in range(4): + x = np.uint8(cm.viridis(prob.numpy()[j*16]) * 255) + writer.add_image('Attention_dec_%d_0'%global_step, x, i*4+j, dataformats="HWC") + if cfg.use_data_parallel: loss = model.scale_loss(loss) - - writer.add_scalars('training_loss', { - 'mel_loss':mel_loss.numpy(), - 'post_mel_loss':post_mel_loss.numpy(), - }, global_step) - - writer.add_scalars('alphas', { - 'encoder_alpha':model.encoder.alpha.numpy(), - 'decoder_alpha':model.decoder.alpha.numpy(), - }, global_step) - - writer.add_scalar('learning_rate', optimizer._learning_rate.step().numpy(), global_step) - - if global_step % cfg.image_step == 1: - for i, prob in enumerate(attn_probs): - for j in range(4): - x = np.uint8(cm.viridis(prob.numpy()[j*16]) * 255) - writer.add_image('Attention_enc_%d_0'%global_step, x, i*4+j, dataformats="HWC") - - for i, prob in enumerate(attn_enc): - for j in range(4): - x = np.uint8(cm.viridis(prob.numpy()[j*16]) * 255) - writer.add_image('Attention_enc_%d_0'%global_step, x, i*4+j, dataformats="HWC") - - for i, prob in enumerate(attn_dec): - for j in range(4): - x = np.uint8(cm.viridis(prob.numpy()[j*16]) * 255) - writer.add_image('Attention_dec_%d_0'%global_step, x, i*4+j, dataformats="HWC") - - loss.backward() - if cfg.use_data_parallel: + loss.backward() model.apply_collective_grads() + else: + loss.backward() optimizer.minimize(loss, grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(1)) model.clear_gradients() @@ -163,4 +130,7 @@ def main(): if __name__ =='__main__': - main() \ No newline at end of file + parser = jsonargparse.ArgumentParser(description="Train TransformerTTS model", formatter_class='default_argparse') + add_config_options_to_parser(parser) + cfg = parser.parse_args('-c ./config/train_transformer.yaml'.split()) + main(cfg) \ No newline at end of file