Training with multi-GPU

This commit is contained in:
lifuchen 2019-12-17 06:23:34 +00:00 committed by chenfeiyu
parent 8a9bbc2634
commit 9fe6ad11f0
10 changed files with 393 additions and 150 deletions

View File

@ -0,0 +1 @@
from .audio import AudioProcessor

261
parakeet/audio/audio.py Normal file
View File

@ -0,0 +1,261 @@
import librosa
import soundfile as sf
import numpy as np
import scipy.io
import scipy.signal
class AudioProcessor(object):
    """Waveform <-> (mel) spectrogram conversion utilities.

    The processor bundles STFT settings, dB scaling, optional normalization,
    pre-emphasis, Griffin-Lim inversion and a few quantization helpers.

    Args:
        sample_rate (int): sampling rate of the audio.
        num_mels (int): number of bands of the mel spectrogram.
        min_level_db (float): minimum level in dB (floor used when converting
            amplitudes to dB and when normalizing).
        ref_level_db (float): reference level in dB subtracted from the
            dB-scaled spectrogram.
        n_fft (int): number of samples in a frame for the STFT.
        win_length (int): window length; defaults to ``n_fft``.
        hop_length (int): number of samples between neighboring frames;
            defaults to a quarter of ``win_length``.
        power (float): exponent applied to the amplitude before Griffin-Lim.
        preemphasis (float): pre-emphasis coefficient; ``None`` or ``0``
            disables pre-emphasis.
        signal_norm (bool): whether spectrograms are normalized.
        symmetric_norm (bool): normalize into ``[-max_norm, max_norm]``
            instead of ``[0, max_norm]``.
        max_norm (float): normalization bound; defaults to ``1.0``.
        mel_fmin (int): minimum frequency of the mel filter bank.
        mel_fmax (int): maximum frequency of the mel filter bank.
        clip_norm (bool): clip normalized values into the target range.
        griffin_lim_iters (int): number of Griffin-Lim iterations.
        do_trim_silence (bool): trim silence when loading a wav file.
        sound_norm (bool): rescale loaded audio so that its peak is 0.9.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 sample_rate=None,
                 num_mels=None,
                 min_level_db=None,
                 ref_level_db=None,
                 n_fft=None,
                 win_length=None,
                 hop_length=None,
                 power=None,
                 preemphasis=None,
                 signal_norm=None,
                 symmetric_norm=False,
                 max_norm=None,
                 mel_fmin=None,
                 mel_fmax=None,
                 clip_norm=True,
                 griffin_lim_iters=None,
                 do_trim_silence=False,
                 sound_norm=False,
                 **kwargs):
        self.sample_rate = sample_rate
        self.num_mels = num_mels
        self.min_level_db = min_level_db
        self.ref_level_db = ref_level_db

        # STFT related settings.
        self.n_fft = n_fft
        self.win_length = win_length or n_fft
        # Hop length defaults to 1/4 of the window length. It must be an
        # integer number of samples, so use integer division rather than
        # multiplying by 0.25 (which would yield a float).
        self.hop_length = hop_length or int(self.win_length // 4)

        self.power = power
        # ``None`` (the default) means "no pre-emphasis"; float(None) would
        # raise, so map it to 0.0 explicitly.
        self.preemphasis = 0.0 if preemphasis is None else float(preemphasis)
        self.griffin_lim_iters = griffin_lim_iters
        self.signal_norm = signal_norm
        self.symmetric_norm = symmetric_norm

        # Mel transform related settings.
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax

        self.max_norm = 1.0 if max_norm is None else float(max_norm)
        self.clip_norm = clip_norm
        self.do_trim_silence = do_trim_silence
        self.sound_norm = sound_norm

        (self.num_freq, self.frame_length_ms,
         self.frame_shift_ms) = self._stft_parameters()

    def _stft_parameters(self):
        """Return ``(num_freq, frame_length_ms, frame_shift_ms)``.

        Frame length and frame shift are expressed in milliseconds, derived
        from ``win_length`` / ``hop_length`` (in samples) and ``sample_rate``.
        """
        frame_length_ms = self.win_length * 1000. / self.sample_rate
        frame_shift_ms = self.hop_length * 1000. / self.sample_rate
        num_freq = 1 + self.n_fft // 2
        return num_freq, frame_length_ms, frame_shift_ms

    def __repr__(self):
        """Return a multi-line listing of every instance attribute."""
        cls_name_str = self.__class__.__name__
        members = vars(self)
        dict_str = "\n".join(
            [" {}: {},".format(k, v) for k, v in members.items()])
        repr_str = "{}(\n{})\n".format(cls_name_str, dict_str)
        return repr_str

    def save_wav(self, path, wav):
        """Save ``wav`` to ``path`` as 16-bit integers via scipy.

        The signal is rescaled so that its peak maps to 32767; the peak is
        floored at 0.01 to avoid blowing up near-silent signals.
        """
        # Import the submodule explicitly: ``import scipy.io`` alone does not
        # guarantee that ``scipy.io.wavfile`` is loaded on every scipy version.
        from scipy.io import wavfile

        wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
        wavfile.write(path, self.sample_rate, wav_norm.astype(np.int16))

    def load_wav(self, path, sr=None):
        """Load a wav file, optionally trim silence and rescale it.

        Args:
            path: path of the audio file.
            sr: currently unused; the file is always loaded at its native
                sampling rate, which must equal ``self.sample_rate``.

        Returns:
            The loaded (and possibly trimmed / rescaled) waveform.
        """
        x, sr = librosa.load(path, sr=None)
        assert self.sample_rate == sr, "audio sample rate: {}Hz != processor sample rate: {}Hz".format(sr, self.sample_rate)
        if self.do_trim_silence:
            try:
                x = self.trim_silence(x)
            except ValueError:
                print(" [!] File cannot be trimmed for silence - {}".format(path))
        if self.sound_norm:
            # Scale by the absolute peak so negative peaks are handled too,
            # and skip all-zero signals to avoid a division by zero.
            peak = np.abs(x).max()
            if peak > 0:
                x = x / peak * 0.9
        return x

    def trim_silence(self, wav):
        """Trim silent parts with a 60 dB threshold and a 0.01 s margin."""
        margin = int(self.sample_rate * 0.01)
        wav = wav[margin: -margin]
        trimed_wav = librosa.effects.trim(
            wav,
            top_db=60,
            frame_length=self.win_length,
            hop_length=self.hop_length)[0]
        return trimed_wav

    def apply_preemphasis(self, x):
        """Apply the pre-emphasis filter ``y[n] = x[n] - c * x[n-1]``.

        Raises:
            RuntimeError: if the pre-emphasis coefficient is 0.
        """
        if self.preemphasis == 0.:
            raise RuntimeError(" !! Preemphasis coefficient should be positive. ")
        return scipy.signal.lfilter([1., -self.preemphasis], [1.], x)

    def apply_inv_preemphasis(self, x):
        """Apply the inverse of :meth:`apply_preemphasis`.

        Raises:
            RuntimeError: if the pre-emphasis coefficient is 0.
        """
        if self.preemphasis == 0.:
            raise RuntimeError(" !! Preemphasis coefficient should be positive. ")
        return scipy.signal.lfilter([1.], [1., -self.preemphasis], x)

    def _amplitude_to_db(self, x):
        """Convert amplitudes to dB, flooring them at ``min_level_db``."""
        amplitude_min = np.exp(self.min_level_db / 20 * np.log(10))
        return 20 * np.log10(np.maximum(amplitude_min, x))

    @staticmethod
    def _db_to_amplitude(x):
        """Convert dB values back to amplitudes."""
        return np.power(10., 0.05 * x)

    def _linear_to_mel(self, spectrogram):
        """Project a linear-frequency spectrogram onto the mel basis."""
        _mel_basis = self._build_mel_basis()
        return np.dot(_mel_basis, spectrogram)

    def _mel_to_linear(self, mel_spectrogram):
        """Approximate a linear spectrogram with the mel basis pseudo-inverse.

        The result is floored at 1e-10.
        """
        inv_mel_basis = np.linalg.pinv(self._build_mel_basis())
        return np.maximum(1e-10, np.dot(inv_mel_basis, mel_spectrogram))

    def _build_mel_basis(self):
        """Return the mel filter bank for the configured STFT settings."""
        if self.mel_fmax is not None:
            assert self.mel_fmax <= self.sample_rate // 2
        # ``sr`` and ``n_fft`` are passed by keyword: this works on every
        # librosa version, including those where they are keyword-only.
        return librosa.filters.mel(
            sr=self.sample_rate,
            n_fft=self.n_fft,
            n_mels=self.num_mels,
            fmin=self.mel_fmin,
            fmax=self.mel_fmax)

    def _normalize(self, S):
        """Map dB values into ``[0, max_norm]`` or ``[-max_norm, max_norm]``.

        Returns ``S`` unchanged when ``signal_norm`` is falsy.
        """
        if not self.signal_norm:
            return S

        S_norm = (S - self.min_level_db) / (-self.min_level_db)
        if self.symmetric_norm:
            S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm
            if self.clip_norm:
                S_norm = np.clip(S_norm, -self.max_norm, self.max_norm)
            return S_norm

        S_norm = self.max_norm * S_norm
        if self.clip_norm:
            S_norm = np.clip(S_norm, 0, self.max_norm)
        return S_norm

    def _denormalize(self, S):
        """Invert :meth:`_normalize`.

        Returns ``S`` unchanged when ``signal_norm`` is falsy.
        """
        if not self.signal_norm:
            return S

        S_denorm = S
        if self.symmetric_norm:
            if self.clip_norm:
                S_denorm = np.clip(S_denorm, -self.max_norm, self.max_norm)
            S_denorm = (S_denorm + self.max_norm) * (-self.min_level_db) / (2 * self.max_norm) + self.min_level_db
            return S_denorm

        if self.clip_norm:
            S_denorm = np.clip(S_denorm, 0, self.max_norm)
        S_denorm = S_denorm * (-self.min_level_db) / self.max_norm + self.min_level_db
        return S_denorm

    def _stft(self, y):
        """Short-time Fourier transform with the configured parameters."""
        return librosa.stft(
            y=y,
            n_fft=self.n_fft,
            win_length=self.win_length,
            hop_length=self.hop_length)

    def _istft(self, S):
        """Inverse short-time Fourier transform."""
        return librosa.istft(
            S, hop_length=self.hop_length, win_length=self.win_length)

    def spectrogram(self, y):
        """Compute the linear (amplitude) spectrogram.

        preemphasis -> stft -> mag -> amplitude_to_db -> minus_ref_level_db
        -> normalize
        """
        if self.preemphasis:
            D = self._stft(self.apply_preemphasis(y))
        else:
            D = self._stft(y)
        S = self._amplitude_to_db(np.abs(D)) - self.ref_level_db
        return self._normalize(S)

    def melspectrogram(self, y):
        """Compute the mel spectrogram.

        preemphasis -> stft -> mag -> mel_scale -> amplitude_to_db
        -> minus_ref_level_db -> normalize
        """
        if self.preemphasis:
            D = self._stft(self.apply_preemphasis(y))
        else:
            D = self._stft(y)
        S = self._amplitude_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
        return self._normalize(S)

    def inv_spectrogram(self, spectrogram):
        """Convert a linear spectrogram back to a waveform with Griffin-Lim."""
        S = self._denormalize(spectrogram)
        S = self._db_to_amplitude(S + self.ref_level_db)
        if self.preemphasis:
            return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
        return self._griffin_lim(S ** self.power)

    def inv_melspectrogram(self, mel_spectrogram):
        """Convert a mel spectrogram back to a waveform with Griffin-Lim."""
        S = self._denormalize(mel_spectrogram)
        S = self._db_to_amplitude(S + self.ref_level_db)
        # The input is already on the mel scale, so it has to be mapped back
        # to linear frequency bins before phase reconstruction.
        S = self._mel_to_linear(S)
        if self.preemphasis:
            return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
        return self._griffin_lim(S ** self.power)

    def out_linear_to_mel(self, linear_spec):
        """Convert a normalized linear spectrogram to a normalized mel one."""
        S = self._denormalize(linear_spec)
        S = self._db_to_amplitude(S + self.ref_level_db)
        S = self._linear_to_mel(np.abs(S))
        S = self._amplitude_to_db(S) - self.ref_level_db
        mel = self._normalize(S)
        return mel

    def _griffin_lim(self, S):
        """Estimate a waveform from magnitudes ``S`` by Griffin-Lim.

        Starts from random phases and runs ``griffin_lim_iters`` rounds of
        stft/istft phase refinement.
        """
        angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
        # ``np.complex`` was only an alias of the builtin ``complex`` and has
        # been removed from NumPy; use the builtin directly.
        S_complex = np.abs(S).astype(complex)
        y = self._istft(S_complex * angles)
        for _ in range(self.griffin_lim_iters):
            angles = np.exp(1j * np.angle(self._stft(y)))
            y = self._istft(S_complex * angles)
        return y

    @staticmethod
    def mulaw_encode(wav, qc):
        """Mu-law compand ``wav`` and quantize it to ``2 ** qc`` levels.

        NOTE: the input is not clipped; presumably ``wav`` is expected to lie
        in [-1, 1] -- confirm against callers.
        """
        mu = 2 ** qc - 1
        signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1. + mu)
        # Quantize signal to the specified number of levels.
        signal = (signal + 1) / 2 * mu + 0.5
        return np.floor(signal)

    @staticmethod
    def mulaw_decode(wav, qc):
        """Recover a waveform from mu-law companded values."""
        mu = 2 ** qc - 1
        x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
        return x

    @staticmethod
    def encode_16bits(x):
        """Scale ``x`` by ``2 ** 15`` and clip it into the int16 range."""
        return np.clip(x * 2**15, -2**15, 2**15 - 1).astype(np.int16)

    @staticmethod
    def quantize(x, bits):
        """Linearly map ``x`` from [-1, 1] to [0, 2 ** bits - 1]."""
        return (x + 1.) * (2**bits - 1) / 2

    @staticmethod
    def dequantize(x, bits):
        """Linearly map ``x`` from [0, 2 ** bits - 1] back to [-1, 1]."""
        return 2 * x / (2**bits - 1) - 1

View File

@ -2,7 +2,8 @@ from .sampler import SequentialSampler, RandomSampler, BatchSampler
class DataCargo(object): class DataCargo(object):
def __init__(self, dataset, batch_size=1, sampler=None, def __init__(self, dataset, batch_size=1, sampler=None,
shuffle=False, batch_sampler=None, drop_last=False): shuffle=False, batch_sampler=None, collate_fn=None,
drop_last=False):
self.dataset = dataset self.dataset = dataset
if batch_sampler is not None: if batch_sampler is not None:
@ -21,13 +22,20 @@ class DataCargo(object):
sampler = RandomSampler(dataset) sampler = RandomSampler(dataset)
else: else:
sampler = SequentialSampler(dataset) sampler = SequentialSampler(dataset)
# auto_collation without custom batch_sampler
batch_sampler = BatchSampler(sampler, batch_size, drop_last) batch_sampler = BatchSampler(sampler, batch_size, drop_last)
else:
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.batch_sampler = batch_sampler
if collate_fn is None:
collate_fn = dataset._batch_examples
self.collate_fn = collate_fn
self.batch_size = batch_size self.batch_size = batch_size
self.drop_last = drop_last self.drop_last = drop_last
self.sampler = sampler self.sampler = sampler
self.batch_sampler = batch_sampler
def __iter__(self): def __iter__(self):
return DataIterator(self) return DataIterator(self)
@ -57,6 +65,7 @@ class DataIterator(object):
self._index_sampler = loader._index_sampler self._index_sampler = loader._index_sampler
self._sampler_iter = iter(self._index_sampler) self._sampler_iter = iter(self._index_sampler)
self.collate_fn = loader.collate_fn
def __iter__(self): def __iter__(self):
return self return self
@ -64,7 +73,7 @@ class DataIterator(object):
def __next__(self): def __next__(self):
index = self._next_index() # may raise StopIteration, TODO(chenfeiyu): use dynamic batch size index = self._next_index() # may raise StopIteration, TODO(chenfeiyu): use dynamic batch size
minibatch = [self._dataset[i] for i in index] # we can abstract it, too to use dynamic batch size minibatch = [self._dataset[i] for i in index] # we can abstract it, too to use dynamic batch size
minibatch = self._dataset._batch_examples(minibatch) # list[Example] -> Batch minibatch = self.collate_fn(minibatch)
return minibatch return minibatch
def _next_index(self): def _next_index(self):

View File

@ -20,7 +20,7 @@ epochs: 10000
lr: 0.001 lr: 0.001
save_step: 500 save_step: 500
use_gpu: True use_gpu: True
use_data_parallel: False use_data_parallel: True
data_path: ../../../dataset/LJSpeech-1.1 data_path: ../../../dataset/LJSpeech-1.1
save_path: ./checkpoint save_path: ./checkpoint

View File

@ -21,7 +21,7 @@ lr: 0.001
save_step: 500 save_step: 500
image_step: 2000 image_step: 2000
use_gpu: True use_gpu: True
use_data_parallel: False use_data_parallel: True
data_path: ../../../dataset/LJSpeech-1.1 data_path: ../../../dataset/LJSpeech-1.1
save_path: ./checkpoint save_path: ./checkpoint

View File

@ -0,0 +1,29 @@
from pathlib import Path
import numpy as np
from paddle import fluid
from parakeet.data.sampler import DistributedSampler
from parakeet.data.datacargo import DataCargo
from preprocess import batch_examples, LJSpeech, batch_examples_postnet
class LJSpeechLoader:
    """Build a per-rank batch reader over the LJSpeech dataset.

    After construction, ``self.reader`` holds a
    ``fluid.io.DataLoader.from_generator`` loader whose batch generator is a
    ``DataCargo`` fed by a ``DistributedSampler``.

    Args:
        config: configuration object; ``use_gpu``, ``data_path`` and
            ``batch_size`` are read from it.
        nranks: number of ranks; ``config.batch_size`` must be divisible by it.
        rank: index of the current rank; also used as the CUDA device index
            when ``config.use_gpu`` is truthy.
        is_postnet: collate with ``batch_examples_postnet`` instead of
            ``batch_examples``.
    """

    def __init__(self, config, nranks, rank, is_postnet=False):
        device = fluid.CUDAPlace(rank) if config.use_gpu else fluid.CPUPlace()

        corpus = LJSpeech(Path(config.data_path))
        rank_sampler = DistributedSampler(len(corpus), nranks, rank)

        # The global batch is split evenly between the ranks.
        assert config.batch_size % nranks == 0
        per_rank_batch = config.batch_size // nranks

        # Only the collate function differs between the two model variants.
        collate = batch_examples_postnet if is_postnet else batch_examples
        cargo = DataCargo(
            corpus,
            sampler=rank_sampler,
            batch_size=per_rank_batch,
            shuffle=True,
            collate_fn=collate,
            drop_last=True)

        loader = fluid.io.DataLoader.from_generator(
            capacity=32,
            iterable=True,
            use_double_buffer=True,
            return_list=True)
        self.reader = loader
        self.reader.set_batch_generator(cargo, device)

View File

@ -130,7 +130,7 @@ class EncoderPrenet(dg.Layer):
self.projection = FC(self.full_name(), num_hidden, num_hidden) self.projection = FC(self.full_name(), num_hidden, num_hidden)
def forward(self, x): def forward(self, x):
x = self.embedding(fluid.layers.unsqueeze(x, axes=[-1])) #(batch_size, seq_len, embending_size) x = self.embedding(x) #(batch_size, seq_len, embending_size)
x = layers.transpose(x,[0,2,1]) x = layers.transpose(x,[0,2,1])
x = layers.dropout(layers.relu(self.batch_norm1(self.conv1(x))), 0.2) x = layers.dropout(layers.relu(self.batch_norm1(self.conv1(x))), 0.2)
x = layers.dropout(layers.relu(self.batch_norm2(self.conv2(x))), 0.2) x = layers.dropout(layers.relu(self.batch_norm2(self.conv2(x))), 0.2)
@ -211,8 +211,9 @@ class ScaledDotProductAttention(dg.Layer):
# Mask key to ignore padding # Mask key to ignore padding
if mask is not None: if mask is not None:
attention = attention * mask attention = attention * mask
mask = (mask == 0).astype(float) * (-2 ** 32 + 1) mask = (mask == 0).astype(np.float32) * (-2 ** 32 + 1)
attention = attention + mask attention = attention + mask
attention = layers.softmax(attention) attention = layers.softmax(attention)
# Mask query to ignore padding # Mask query to ignore padding

View File

@ -7,9 +7,9 @@ class Encoder(dg.Layer):
def __init__(self, name_scope, embedding_size, num_hidden, config): def __init__(self, name_scope, embedding_size, num_hidden, config):
super(Encoder, self).__init__(name_scope) super(Encoder, self).__init__(name_scope)
self.num_hidden = num_hidden self.num_hidden = num_hidden
param = fluid.ParamAttr(name='alpha') param = fluid.ParamAttr(name='alpha',
self.alpha = self.create_parameter(param, shape=(1, ), dtype='float32', initializer=fluid.initializer.Constant(value=1.0))
default_initializer = fluid.initializer.ConstantInitializer(value=1.0)) self.alpha = self.create_parameter(param, shape=(1, ), dtype='float32')
self.pos_inp = get_sinusoid_encoding_table(1024, self.num_hidden, padding_idx=0) self.pos_inp = get_sinusoid_encoding_table(1024, self.num_hidden, padding_idx=0)
self.pos_emb = dg.Embedding(name_scope=self.full_name(), self.pos_emb = dg.Embedding(name_scope=self.full_name(),
size=[1024, num_hidden], size=[1024, num_hidden],
@ -31,8 +31,8 @@ class Encoder(dg.Layer):
def forward(self, x, positional): def forward(self, x, positional):
if fluid.framework._dygraph_tracer()._train_mode: if fluid.framework._dygraph_tracer()._train_mode:
query_mask = (positional != 0).astype(float) query_mask = (positional != 0).astype(np.float32)
mask = (positional != 0).astype(float) mask = (positional != 0).astype(np.float32)
mask = fluid.layers.expand(fluid.layers.unsqueeze(mask,[1]), [1,x.shape[1], 1]) mask = fluid.layers.expand(fluid.layers.unsqueeze(mask,[1]), [1,x.shape[1], 1])
else: else:
query_mask, mask = None, None query_mask, mask = None, None
@ -42,7 +42,7 @@ class Encoder(dg.Layer):
# Get positional encoding # Get positional encoding
positional = self.pos_emb(fluid.layers.unsqueeze(positional, axes=[-1])) positional = self.pos_emb(positional)
x = positional * self.alpha + x #(N, T, C) x = positional * self.alpha + x #(N, T, C)
@ -102,14 +102,14 @@ class Decoder(dg.Layer):
if fluid.framework._dygraph_tracer()._train_mode: if fluid.framework._dygraph_tracer()._train_mode:
#zeros = np.zeros(positional.shape, dtype=np.float32) #zeros = np.zeros(positional.shape, dtype=np.float32)
m_mask = (positional != 0).astype(float) m_mask = (positional != 0).astype(np.float32)
mask = np.repeat(np.expand_dims(m_mask.numpy() == 0, axis=1), decoder_len, axis=1) mask = np.repeat(np.expand_dims(m_mask.numpy() == 0, axis=1), decoder_len, axis=1)
mask = mask + np.repeat(np.expand_dims(np.triu(np.ones([decoder_len, decoder_len]), 1), axis=0) ,batch_size, axis=0) mask = mask + np.repeat(np.expand_dims(np.triu(np.ones([decoder_len, decoder_len]), 1), axis=0) ,batch_size, axis=0)
mask = fluid.layers.cast(dg.to_variable(mask == 0), np.float32) mask = fluid.layers.cast(dg.to_variable(mask == 0), np.float32)
# (batch_size, decoder_len, decoder_len) # (batch_size, decoder_len, decoder_len)
zero_mask = fluid.layers.expand(fluid.layers.unsqueeze((c_mask != 0).astype(float), axes=2), [1,1,decoder_len]) zero_mask = fluid.layers.expand(fluid.layers.unsqueeze((c_mask != 0).astype(np.float32), axes=2), [1,1,decoder_len])
# (batch_size, decoder_len, seq_len) # (batch_size, decoder_len, seq_len)
zero_mask = fluid.layers.transpose(zero_mask, [0,2,1]) zero_mask = fluid.layers.transpose(zero_mask, [0,2,1])
@ -125,7 +125,7 @@ class Decoder(dg.Layer):
query = self.linear(query) query = self.linear(query)
# Get position embedding # Get position embedding
positional = self.pos_emb(fluid.layers.unsqueeze(positional, axes=[-1])) positional = self.pos_emb(positional)
query = positional * self.alpha + query query = positional * self.alpha + query
#positional dropout #positional dropout

View File

@ -1,13 +1,12 @@
from network import * from network import *
from preprocess import batch_examples_postnet, LJSpeech
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
import os import os
from tqdm import tqdm from tqdm import tqdm
from parakeet.data.datacargo import DataCargo
from pathlib import Path from pathlib import Path
import jsonargparse import jsonargparse
from parse import add_config_options_to_parser from parse import add_config_options_to_parser
from pprint import pprint from pprint import pprint
from data import LJSpeechLoader
class MyDataParallel(dg.parallel.DataParallel): class MyDataParallel(dg.parallel.DataParallel):
""" """
@ -27,21 +26,15 @@ class MyDataParallel(dg.parallel.DataParallel):
object.__getattribute__(self, "_sub_layers")["_layers"], key) object.__getattribute__(self, "_sub_layers")["_layers"], key)
def main(): def main(cfg):
parser = jsonargparse.ArgumentParser(description="Train postnet model", formatter_class='default_argparse')
add_config_options_to_parser(parser) local_rank = dg.parallel.Env().local_rank if cfg.use_data_parallel else 0
cfg = parser.parse_args('-c ./config/train_postnet.yaml'.split()) nranks = dg.parallel.Env().nranks if cfg.use_data_parallel else 1
local_rank = dg.parallel.Env().local_rank
if local_rank == 0: if local_rank == 0:
# Print the whole config setting. # Print the whole config setting.
pprint(jsonargparse.namespace_to_dict(cfg)) pprint(jsonargparse.namespace_to_dict(cfg))
LJSPEECH_ROOT = Path(cfg.data_path)
dataset = LJSpeech(LJSPEECH_ROOT)
dataloader = DataCargo(dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=batch_examples_postnet, drop_last=True)
global_step = 0 global_step = 0
place = (fluid.CUDAPlace(dg.parallel.Env().dev_id) place = (fluid.CUDAPlace(dg.parallel.Env().dev_id)
if cfg.use_data_parallel else fluid.CUDAPlace(0) if cfg.use_data_parallel else fluid.CUDAPlace(0)
@ -50,35 +43,10 @@ def main():
if not os.path.exists(cfg.log_dir): if not os.path.exists(cfg.log_dir):
os.mkdir(cfg.log_dir) os.mkdir(cfg.log_dir)
path = os.path.join(cfg.log_dir,'postnet') path = os.path.join(cfg.log_dir,'postnet')
writer = SummaryWriter(path)
with dg.guard(place): writer = SummaryWriter(path) if local_rank == 0 else None
# dataloader
input_fields = {
'names': ['mel', 'mag'],
'shapes':
[[cfg.batch_size, None, 80], [cfg.batch_size, None, 257]],
'dtypes': ['float32', 'float32'],
'lod_levels': [0, 0]
}
inputs = [ with dg.guard(place):
fluid.data(
name=input_fields['names'][i],
shape=input_fields['shapes'][i],
dtype=input_fields['dtypes'][i],
lod_level=input_fields['lod_levels'][i])
for i in range(len(input_fields['names']))
]
reader = fluid.io.DataLoader.from_generator(
feed_list=inputs,
capacity=32,
iterable=True,
use_double_buffer=True,
return_list=True)
model = ModelPostNet('postnet', cfg) model = ModelPostNet('postnet', cfg)
model.train() model.train()
@ -94,9 +62,10 @@ def main():
strategy = dg.parallel.prepare_context() strategy = dg.parallel.prepare_context()
model = MyDataParallel(model, strategy) model = MyDataParallel(model, strategy)
reader = LJSpeechLoader(cfg, nranks, local_rank, is_postnet=True).reader()
for epoch in range(cfg.epochs): for epoch in range(cfg.epochs):
reader.set_batch_generator(dataloader, place) pbar = tqdm(reader)
pbar = tqdm(reader())
for i, data in enumerate(pbar): for i, data in enumerate(pbar):
pbar.set_description('Processing at epoch %d'%epoch) pbar.set_description('Processing at epoch %d'%epoch)
mel, mag = data mel, mag = data
@ -109,27 +78,30 @@ def main():
loss = layers.mean(layers.abs(layers.elementwise_sub(mag_pred, mag))) loss = layers.mean(layers.abs(layers.elementwise_sub(mag_pred, mag)))
if cfg.use_data_parallel: if cfg.use_data_parallel:
loss = model.scale_loss(loss) loss = model.scale_loss(loss)
loss.backward()
writer.add_scalars('training_loss',{
'loss':loss.numpy(),
}, global_step)
loss.backward()
if cfg.use_data_parallel:
model.apply_collective_grads() model.apply_collective_grads()
else:
loss.backward()
optimizer.minimize(loss, grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(1)) optimizer.minimize(loss, grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(1))
model.clear_gradients() model.clear_gradients()
if global_step % cfg.save_step == 0:
if not os.path.exists(cfg.save_path):
os.mkdir(cfg.save_path)
save_path = os.path.join(cfg.save_path,'postnet/%d' % global_step)
dg.save_dygraph(model.state_dict(), save_path)
dg.save_dygraph(optimizer.state_dict(), save_path)
if local_rank==0:
writer.add_scalars('training_loss',{
'loss':loss.numpy(),
}, global_step)
if global_step % cfg.save_step == 0:
if not os.path.exists(cfg.save_path):
os.mkdir(cfg.save_path)
save_path = os.path.join(cfg.save_path,'postnet/%d' % global_step)
dg.save_dygraph(model.state_dict(), save_path)
dg.save_dygraph(optimizer.state_dict(), save_path)
if local_rank==0:
writer.close()
if __name__ == '__main__': if __name__ == '__main__':
main() parser = jsonargparse.ArgumentParser(description="Train postnet model", formatter_class='default_argparse')
add_config_options_to_parser(parser)
cfg = parser.parse_args('-c ./config/train_postnet.yaml'.split())
main(cfg)

View File

@ -1,16 +1,15 @@
from preprocess import batch_examples, LJSpeech
import os import os
from tqdm import tqdm from tqdm import tqdm
import paddle.fluid.dygraph as dg import paddle.fluid.dygraph as dg
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
from network import * from network import *
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from parakeet.data.datacargo import DataCargo
from pathlib import Path from pathlib import Path
import jsonargparse import jsonargparse
from parse import add_config_options_to_parser from parse import add_config_options_to_parser
from pprint import pprint from pprint import pprint
from matplotlib import cm from matplotlib import cm
from data import LJSpeechLoader
class MyDataParallel(dg.parallel.DataParallel): class MyDataParallel(dg.parallel.DataParallel):
""" """
@ -30,21 +29,14 @@ class MyDataParallel(dg.parallel.DataParallel):
object.__getattribute__(self, "_sub_layers")["_layers"], key) object.__getattribute__(self, "_sub_layers")["_layers"], key)
def main(): def main(cfg):
parser = jsonargparse.ArgumentParser(description="Train TransformerTTS model", formatter_class='default_argparse') local_rank = dg.parallel.Env().local_rank if cfg.use_data_parallel else 0
add_config_options_to_parser(parser) nranks = dg.parallel.Env().nranks if cfg.use_data_parallel else 1
cfg = parser.parse_args('-c ./config/train_transformer.yaml'.split())
local_rank = dg.parallel.Env().local_rank
if local_rank == 0: if local_rank == 0:
# Print the whole config setting. # Print the whole config setting.
pprint(jsonargparse.namespace_to_dict(cfg)) pprint(jsonargparse.namespace_to_dict(cfg))
LJSPEECH_ROOT = Path(cfg.data_path)
dataset = LJSpeech(LJSPEECH_ROOT)
dataloader = DataCargo(dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=batch_examples, drop_last=True)
global_step = 0 global_step = 0
place = (fluid.CUDAPlace(dg.parallel.Env().dev_id) place = (fluid.CUDAPlace(dg.parallel.Env().dev_id)
if cfg.use_data_parallel else fluid.CUDAPlace(0) if cfg.use_data_parallel else fluid.CUDAPlace(0)
@ -57,39 +49,13 @@ def main():
writer = SummaryWriter(path) if local_rank == 0 else None writer = SummaryWriter(path) if local_rank == 0 else None
with dg.guard(place): with dg.guard(place):
if cfg.use_data_parallel:
strategy = dg.parallel.prepare_context()
# dataloader
input_fields = {
'names': ['character', 'mel', 'mel_input', 'pos_text', 'pos_mel', 'text_len'],
'shapes':
[[cfg.batch_size, None], [cfg.batch_size, None, 80], [cfg.batch_size, None, 80], [cfg.batch_size, 1], [cfg.batch_size, 1], [cfg.batch_size, 1]],
'dtypes': ['float32', 'float32', 'float32', 'int64', 'int64', 'int64'],
'lod_levels': [0, 0, 0, 0, 0, 0]
}
inputs = [
fluid.data(
name=input_fields['names'][i],
shape=input_fields['shapes'][i],
dtype=input_fields['dtypes'][i],
lod_level=input_fields['lod_levels'][i])
for i in range(len(input_fields['names']))
]
reader = fluid.io.DataLoader.from_generator(
feed_list=inputs,
capacity=32,
iterable=True,
use_double_buffer=True,
return_list=True)
model = Model('transtts', cfg) model = Model('transtts', cfg)
model.train() model.train()
optimizer = fluid.optimizer.AdamOptimizer(learning_rate=dg.NoamDecay(1/(4000 *( cfg.lr ** 2)), 4000)) optimizer = fluid.optimizer.AdamOptimizer(learning_rate=dg.NoamDecay(1/(4000 *( cfg.lr ** 2)), 4000))
reader = LJSpeechLoader(cfg, nranks, local_rank).reader()
if cfg.checkpoint_path is not None: if cfg.checkpoint_path is not None:
model_dict, opti_dict = fluid.dygraph.load_dygraph(cfg.checkpoint_path) model_dict, opti_dict = fluid.dygraph.load_dygraph(cfg.checkpoint_path)
model.set_dict(model_dict) model.set_dict(model_dict)
@ -97,11 +63,11 @@ def main():
print("load checkpoint!!!") print("load checkpoint!!!")
if cfg.use_data_parallel: if cfg.use_data_parallel:
strategy = dg.parallel.prepare_context()
model = MyDataParallel(model, strategy) model = MyDataParallel(model, strategy)
for epoch in range(cfg.epochs): for epoch in range(cfg.epochs):
reader.set_batch_generator(dataloader, place) pbar = tqdm(reader)
pbar = tqdm(reader())
for i, data in enumerate(pbar): for i, data in enumerate(pbar):
pbar.set_description('Processing at epoch %d'%epoch) pbar.set_description('Processing at epoch %d'%epoch)
character, mel, mel_input, pos_text, pos_mel, text_length = data character, mel, mel_input, pos_text, pos_mel, text_length = data
@ -114,40 +80,41 @@ def main():
post_mel_loss = layers.mean(layers.abs(layers.elementwise_sub(postnet_pred, mel))) post_mel_loss = layers.mean(layers.abs(layers.elementwise_sub(postnet_pred, mel)))
loss = mel_loss + post_mel_loss loss = mel_loss + post_mel_loss
if local_rank==0:
writer.add_scalars('training_loss', {
'mel_loss':mel_loss.numpy(),
'post_mel_loss':post_mel_loss.numpy(),
}, global_step)
writer.add_scalars('alphas', {
'encoder_alpha':model.encoder.alpha.numpy(),
'decoder_alpha':model.decoder.alpha.numpy(),
}, global_step)
writer.add_scalar('learning_rate', optimizer._learning_rate.step().numpy(), global_step)
if global_step % cfg.image_step == 1:
for i, prob in enumerate(attn_probs):
for j in range(4):
x = np.uint8(cm.viridis(prob.numpy()[j*16]) * 255)
writer.add_image('Attention_enc_%d_0'%global_step, x, i*4+j, dataformats="HWC")
for i, prob in enumerate(attn_enc):
for j in range(4):
x = np.uint8(cm.viridis(prob.numpy()[j*16]) * 255)
writer.add_image('Attention_enc_%d_0'%global_step, x, i*4+j, dataformats="HWC")
for i, prob in enumerate(attn_dec):
for j in range(4):
x = np.uint8(cm.viridis(prob.numpy()[j*16]) * 255)
writer.add_image('Attention_dec_%d_0'%global_step, x, i*4+j, dataformats="HWC")
if cfg.use_data_parallel: if cfg.use_data_parallel:
loss = model.scale_loss(loss) loss = model.scale_loss(loss)
loss.backward()
writer.add_scalars('training_loss', {
'mel_loss':mel_loss.numpy(),
'post_mel_loss':post_mel_loss.numpy(),
}, global_step)
writer.add_scalars('alphas', {
'encoder_alpha':model.encoder.alpha.numpy(),
'decoder_alpha':model.decoder.alpha.numpy(),
}, global_step)
writer.add_scalar('learning_rate', optimizer._learning_rate.step().numpy(), global_step)
if global_step % cfg.image_step == 1:
for i, prob in enumerate(attn_probs):
for j in range(4):
x = np.uint8(cm.viridis(prob.numpy()[j*16]) * 255)
writer.add_image('Attention_enc_%d_0'%global_step, x, i*4+j, dataformats="HWC")
for i, prob in enumerate(attn_enc):
for j in range(4):
x = np.uint8(cm.viridis(prob.numpy()[j*16]) * 255)
writer.add_image('Attention_enc_%d_0'%global_step, x, i*4+j, dataformats="HWC")
for i, prob in enumerate(attn_dec):
for j in range(4):
x = np.uint8(cm.viridis(prob.numpy()[j*16]) * 255)
writer.add_image('Attention_dec_%d_0'%global_step, x, i*4+j, dataformats="HWC")
loss.backward()
if cfg.use_data_parallel:
model.apply_collective_grads() model.apply_collective_grads()
else:
loss.backward()
optimizer.minimize(loss, grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(1)) optimizer.minimize(loss, grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(1))
model.clear_gradients() model.clear_gradients()
@ -163,4 +130,7 @@ def main():
if __name__ =='__main__': if __name__ =='__main__':
main() parser = jsonargparse.ArgumentParser(description="Train TransformerTTS model", formatter_class='default_argparse')
add_config_options_to_parser(parser)
cfg = parser.parse_args('-c ./config/train_transformer.yaml'.split())
main(cfg)