Merge pull request #61 from iclementine/reborn

add examples: transformer_tts, waveflow, wavenet
This commit is contained in:
Feiyu Chan 2020-12-18 19:53:23 +08:00 committed by GitHub
commit 949dfa2f3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 1481 additions and 0 deletions

View File

@ -0,0 +1,55 @@
from yacs.config import CfgNode as CN
_C = CN()
_C.data = CN(
dict(
batch_size=16, # batch size
valid_size=64, # the first N examples are reserved for validation
sample_rate=22050, # Hz, sample rate
n_fft=1024, # fft frame size
win_length=1024, # window size
hop_length=256, # hop size between ajacent frame
f_max=8000, # Hz, max frequency when converting to mel
d_mel=80, # mel bands
padding_idx=0, # text embedding's padding index
mel_start_value=0.5, # value for starting frame
mel_end_value=-0.5, # # value for ending frame
)
)
_C.model = CN(
dict(
d_encoder=512, # embedding & encoder's internal size
d_decoder=256, # decoder's internal size
n_heads=4, # actually it can differ at each layer
d_ffn=1024, # encoder_d_ffn & decoder_d_ffn
encoder_layers=4, # number of transformer encoder layer
decoder_layers=4, # number of transformer decoder layer
d_prenet=256, # decprenet's hidden size (d_mel=>d_prenet=>d_decoder)
d_postnet=256, # decoder postnet(cnn)'s internal channel
postnet_layers=5, # decoder postnet(cnn)'s layer
postnet_kernel_size=5, # decoder postnet(cnn)'s kernel size
max_reduction_factor=10, # max_reduction factor
dropout=0.1, # global droput probability
stop_loss_scale=8.0, # scaler for stop _loss
decoder_prenet_dropout=0.5, # decoder prenet dropout probability
)
)
_C.training = CN(
dict(
lr=1e-4, # learning rate
drop_n_heads=[[0, 0], [15000, 1]],
reduction_factor=[[0, 10], [80000, 4], [200000, 2]],
plot_interval=1000, # plot attention and spectrogram
valid_interval=1000, # validation
save_interval=10000, # checkpoint
max_iteration=900000, # max iteration to train
)
)
def get_cfg_defaults():
"""Get a yacs CfgNode object with default values for my_project."""
# Return a clone so that the defaults will not be altered
# This is for the "local variable" use pattern
return _C.clone()

View File

@ -0,0 +1,88 @@
import os
from pathlib import Path
import pickle
import numpy as np
from paddle.io import Dataset, DataLoader
from parakeet.data.batch import batch_spec, batch_text_id
from parakeet.data import dataset
class LJSpeech(Dataset):
"""A simple dataset adaptor for the processed ljspeech dataset."""
def __init__(self, root):
self.root = Path(root).expanduser()
records = []
with open(self.root / "metadata.pkl", 'rb') as f:
metadata = pickle.load(f)
for mel_name, text, phonemes, ids in metadata:
mel_name = self.root / "mel" / (mel_name + ".npy")
records.append((mel_name, text, phonemes, ids))
self.records = records
def __getitem__(self, i):
mel_name, _, _, ids = self.records[i]
mel = np.load(mel_name)
return ids, mel
def __len__(self):
return len(self.records)
# decorate mel & create stop probability
class Transform(object):
def __init__(self, start_value, end_value):
self.start_value = start_value
self.end_value = end_value
def __call__(self, example):
ids, mel = example # ids already have <s> and </s>
ids = np.array(ids, dtype=np.int64)
# add start and end frame
mel = np.pad(mel,
[(0, 0), (1, 1)],
mode='constant',
constant_values=[(0, 0), (self.start_value, self.end_value)])
stop_labels = np.ones([mel.shape[1]], dtype=np.int64)
stop_labels[-1] = 2
# actually this thing can also be done within the model
return ids, mel, stop_labels
class LJSpeechCollector(object):
"""A simple callable to batch LJSpeech examples."""
def __init__(self, padding_idx=0, padding_value=0.):
self.padding_idx = padding_idx
self.padding_value = padding_value
def __call__(self, examples):
ids = [example[0] for example in examples]
mels = [example[1] for example in examples]
stop_probs = [example[2] for example in examples]
ids = batch_text_id(ids, pad_id=self.padding_idx)
mels = batch_spec(mels, pad_value=self.padding_value)
stop_probs = batch_text_id(stop_probs, pad_id=self.padding_idx)
return ids, np.transpose(mels, [0, 2, 1]), stop_probs
def create_dataloader(config, source_path):
lj = LJSpeech(source_path)
transform = Transform(config.data.mel_start_value, config.data.mel_end_value)
lj = dataset.TransformDataset(lj, transform)
valid_set, train_set = dataset.split(lj, config.data.valid_size)
data_collator = LJSpeechCollector(padding_idx=config.data.padding_idx)
train_loader = DataLoader(
train_set,
batch_size=config.data.batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator)
valid_loader = DataLoader(
valid_set,
batch_size=config.data.batch_size,
shuffle=False,
drop_last=False,
collate_fn=data_collator)
return train_loader, valid_loader

View File

@ -0,0 +1,82 @@
import os
import tqdm
import pickle
import argparse
import numpy as np
from pathlib import Path
from parakeet.datasets import LJSpeechMetaData
from parakeet.audio import AudioProcessor, LogMagnitude
from parakeet.frontend import English
from config import get_cfg_defaults
def create_dataset(config, source_path, target_path, verbose=False):
# create output dir
target_path = Path(target_path).expanduser()
mel_path = target_path / "mel"
os.makedirs(mel_path, exist_ok=True)
meta_data = LJSpeechMetaData(source_path)
frontend = English()
processor = AudioProcessor(
sample_rate=config.data.sample_rate,
n_fft=config.data.n_fft,
n_mels=config.data.d_mel,
win_length=config.data.win_length,
hop_length=config.data.hop_length,
f_max=config.data.f_max)
normalizer = LogMagnitude()
records = []
for (fname, text, _) in tqdm.tqdm(meta_data):
wav = processor.read_wav(fname)
mel = processor.mel_spectrogram(wav)
mel = normalizer.transform(mel)
phonemes = frontend.phoneticize(text)
ids = frontend.numericalize(phonemes)
mel_name = os.path.splitext(os.path.basename(fname))[0]
# save mel spectrogram
records.append((mel_name, text, phonemes, ids))
np.save(mel_path / mel_name, mel)
if verbose:
print("save mel spectrograms into {}".format(mel_path))
# save meta data as pickle archive
with open(target_path / "metadata.pkl", 'wb') as f:
pickle.dump(records, f)
if verbose:
print("saved metadata into {}".format(target_path / "metadata.pkl"))
# also save meta data into text format for inspection
with open(target_path / "metadata.txt", 'wt') as f:
for mel_name, text, phonemes, _ in records:
phoneme_str = "|".join(phonemes)
f.write("{}\t{}\t{}\n".format(mel_name, text, phoneme_str))
if verbose:
print("saved metadata into {}".format(target_path / "metadata.txt"))
print("Done.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="create dataset")
parser.add_argument("--config", type=str, metavar="FILE", help="extra config to overwrite the default config")
parser.add_argument("--input", type=str, help="path of the ljspeech dataset")
parser.add_argument("--output", type=str, help="path to save output dataset")
parser.add_argument("--opts", nargs=argparse.REMAINDER,
help="options to overwrite --config file and the default config, passing in KEY VALUE pairs"
)
parser.add_argument("-v", "--verbose", action="store_true", help="print msg")
config = get_cfg_defaults()
args = parser.parse_args()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config.data)
create_dataset(config, args.input, args.output, args.verbose)

View File

@ -0,0 +1,64 @@
import argparse
import time
from pathlib import Path
import numpy as np
import paddle
import parakeet
from parakeet.frontend import English
from parakeet.models.transformer_tts import TransformerTTS
from parakeet.utils import scheduler
from parakeet.training.cli import default_argument_parser
from parakeet.utils.display import add_attention_plots
from config import get_cfg_defaults
@paddle.fluid.dygraph.no_grad
def main(config, args):
paddle.set_device(args.device)
# model
frontend = English()
model = TransformerTTS.from_pretrained(
frontend, config, args.checkpoint_path)
model.eval()
# inputs
input_path = Path(args.input).expanduser()
with open(input_path, "rt") as f:
sentences = f.readlines()
output_dir = Path(args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
for i, sentence in enumerate(sentences):
outputs = model.predict(sentence, verbose=args.verbose)
mel_output = outputs["mel_output"]
# cross_attention_weights = outputs["cross_attention_weights"]
mel_output = mel_output.T #(C, T)
np.save(str(output_dir / f"sentence_{i}"), mel_output)
if args.verbose:
print("spectrogram saved at {}".format(output_dir / f"sentence_{i}.npy"))
if __name__ == "__main__":
config = get_cfg_defaults()
parser = argparse.ArgumentParser(description="generate mel spectrogram with TransformerTTS.")
parser.add_argument("--config", type=str, metavar="FILE", help="extra config to overwrite the default config")
parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load.")
parser.add_argument("--input", type=str, help="path of the text sentences")
parser.add_argument("--output", type=str, help="path to save outputs")
parser.add_argument("--device", type=str, default="cpu", help="device type to use.")
parser.add_argument("--opts", nargs=argparse.REMAINDER, help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
parser.add_argument("-v", "--verbose", action="store_true", help="print msg")
args = parser.parse_args()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
print(args)
main(config, args)

View File

@ -0,0 +1,202 @@
import time
import logging
from pathlib import Path
import numpy as np
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader, DistributedBatchSampler
from tensorboardX import SummaryWriter
from collections import defaultdict
import parakeet
from parakeet.data import dataset
from parakeet.frontend import English
from parakeet.models.transformer_tts import TransformerTTS, TransformerTTSLoss
from parakeet.utils import scheduler, checkpoint, mp_tools, display
from parakeet.training.cli import default_argument_parser
from parakeet.training.experiment import ExperimentBase
from config import get_cfg_defaults
from ljspeech import LJSpeech, LJSpeechCollector, Transform
class Experiment(ExperimentBase):
def setup_model(self):
config = self.config
frontend = English()
model = TransformerTTS(
frontend,
d_encoder=config.model.d_encoder,
d_decoder=config.model.d_decoder,
d_mel=config.data.d_mel,
n_heads=config.model.n_heads,
d_ffn=config.model.d_ffn,
encoder_layers=config.model.encoder_layers,
decoder_layers=config.model.decoder_layers,
d_prenet=config.model.d_prenet,
d_postnet=config.model.d_postnet,
postnet_layers=config.model.postnet_layers,
postnet_kernel_size=config.model.postnet_kernel_size,
max_reduction_factor=config.model.max_reduction_factor,
decoder_prenet_dropout=config.model.decoder_prenet_dropout,
dropout=config.model.dropout)
if self.parallel:
model = paddle.DataParallel(model)
optimizer = paddle.optimizer.Adam(
learning_rate=config.training.lr,
beta1=0.9,
beta2=0.98,
epsilon=1e-9,
parameters=model.parameters()
)
criterion = TransformerTTSLoss(config.model.stop_loss_scale)
drop_n_heads = scheduler.StepWise(config.training.drop_n_heads)
reduction_factor = scheduler.StepWise(config.training.reduction_factor)
self.model = model
self.optimizer = optimizer
self.criterion = criterion
self.drop_n_heads = drop_n_heads
self.reduction_factor = reduction_factor
def setup_dataloader(self):
args = self.args
config = self.config
ljspeech_dataset = LJSpeech(args.data)
transform = Transform(config.data.mel_start_value, config.data.mel_end_value)
ljspeech_dataset = dataset.TransformDataset(ljspeech_dataset, transform)
valid_set, train_set = dataset.split(ljspeech_dataset, config.data.valid_size)
batch_fn = LJSpeechCollector(padding_idx=config.data.padding_idx)
if not self.parallel:
train_loader = DataLoader(
train_set,
batch_size=config.data.batch_size,
shuffle=True,
drop_last=True,
collate_fn=batch_fn)
else:
sampler = DistributedBatchSampler(
train_set,
batch_size=config.data.batch_size,
num_replicas=dist.get_world_size(),
rank=dist.get_rank(),
shuffle=True,
drop_last=True)
train_loader = DataLoader(
train_set, batch_sampler=sampler, collate_fn=batch_fn)
valid_loader = DataLoader(
valid_set, batch_size=config.data.batch_size, collate_fn=batch_fn)
self.train_loader = train_loader
self.valid_loader = valid_loader
def compute_outputs(self, text, mel, stop_label):
model_core = self.model._layers if self.parallel else self.model
model_core.set_constants(
self.reduction_factor(self.iteration),
self.drop_n_heads(self.iteration))
# TODO(chenfeiyu): we can combine these 2 slices
mel_input = mel[:,:-1, :]
reduced_mel_input = mel_input[:, ::model_core.r, :]
outputs = self.model(text, reduced_mel_input)
return outputs
def compute_losses(self, inputs, outputs):
_, mel, stop_label = inputs
mel_target = mel[:, 1:, :]
stop_label_target = stop_label[:, 1:]
mel_output = outputs["mel_output"]
mel_intermediate = outputs["mel_intermediate"]
stop_logits = outputs["stop_logits"]
time_steps = mel_target.shape[1]
losses = self.criterion(
mel_output[:,:time_steps, :],
mel_intermediate[:,:time_steps, :],
mel_target,
stop_logits[:,:time_steps, :],
stop_label_target)
return losses
def train_batch(self):
start = time.time()
batch = self.read_batch()
data_loader_time = time.time() - start
self.optimizer.clear_grad()
self.model.train()
text, mel, stop_label = batch
outputs = self.compute_outputs(text, mel, stop_label)
losses = self.compute_losses(batch, outputs)
loss = losses["loss"]
loss.backward()
self.optimizer.step()
iteration_time = time.time() - start
losses_np = {k: float(v) for k, v in losses.items()}
# logging
msg = "Rank: {}, ".format(dist.get_rank())
msg += "step: {}, ".format(self.iteration)
msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time, iteration_time)
msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses_np.items())
self.logger.info(msg)
if dist.get_rank() == 0:
for k, v in losses_np.items():
self.visualizer.add_scalar(f"train_loss/{k}", v, self.iteration)
@mp_tools.rank_zero_only
@paddle.no_grad()
def valid(self):
valid_losses = defaultdict(list)
for i, batch in enumerate(self.valid_loader):
text, mel, stop_label = batch
outputs = self.compute_outputs(text, mel, stop_label)
losses = self.compute_losses(batch, outputs)
for k, v in losses.items():
valid_losses[k].append(float(v))
if i < 2:
attention_weights = outputs["cross_attention_weights"]
display.add_multi_attention_plots(
self.visualizer,
f"valid_sentence_{i}_cross_attention_weights",
attention_weights,
self.iteration)
# write visual log
valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
for k, v in valid_losses.items():
self.visualizer.add_scalar(f"valid/{k}", v, self.iteration)
def main_sp(config, args):
exp = Experiment(config, args)
exp.setup()
exp.run()
def main(config, args):
if args.nprocs > 1 and args.device == "gpu":
dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
else:
main_sp(config, args)
if __name__ == "__main__":
config = get_cfg_defaults()
parser = default_argument_parser()
args = parser.parse_args()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
print(args)
main(config, args)

View File

@ -0,0 +1,43 @@
from yacs.config import CfgNode as CN
_C = CN()
_C.data = CN(
dict(
batch_size=8, # batch size
valid_size=16, # the first N examples are reserved for validation
sample_rate=22050, # Hz, sample rate
n_fft=1024, # fft frame size
win_length=1024, # window size
hop_length=256, # hop size between ajacent frame
f_max=8000, # Hz, max frequency when converting to mel
n_mels=80, # mel bands
clip_frames=65, # mel clip frames
)
)
_C.model = CN(
dict(
upsample_factors=[16, 16],
n_flows=8, # number of flows in WaveFlow
n_layers=8, # number of conv block in each flow
n_group=16, # folding factor of audio and spectrogram
channels=128, # resiaudal channel in each flow
kernel_size=[3, 3], # kernel size in each conv block
sigma=1.0, # stddev of the random noise
)
)
_C.training = CN(
dict(
lr=2e-4, # learning rates
valid_interval=1000, # validation
save_interval=10000, # checkpoint
max_iteration=3000000, # max iteration to train
)
)
def get_cfg_defaults():
"""Get a yacs CfgNode object with default values for my_project."""
# Return a clone so that the defaults will not be altered
# This is for the "local variable" use pattern
return _C.clone()

View File

@ -0,0 +1,78 @@
import os
from pathlib import Path
import pickle
import numpy as np
import pandas
from paddle.io import Dataset, DataLoader
from parakeet.data.batch import batch_spec, batch_wav
from parakeet.data import dataset
from parakeet.audio import AudioProcessor
class LJSpeech(Dataset):
"""A simple dataset adaptor for the processed ljspeech dataset."""
def __init__(self, root):
self.root = Path(root).expanduser()
meta_data = pandas.read_csv(
str(self.root / "metadata.csv"),
sep="\t",
header=None,
names=["fname", "frames", "samples"]
)
records = []
for row in meta_data.itertuples() :
mel_path = str(self.root / "mel" / (row.fname + ".npy"))
wav_path = str(self.root / "wav" / (row.fname + ".npy"))
records.append((mel_path, wav_path))
self.records = records
def __getitem__(self, i):
mel_name, wav_name = self.records[i]
mel = np.load(mel_name)
wav = np.load(wav_name)
return mel, wav
def __len__(self):
return len(self.records)
class LJSpeechCollector(object):
"""A simple callable to batch LJSpeech examples."""
def __init__(self, padding_value=0.):
self.padding_value = padding_value
def __call__(self, examples):
mels = [example[0] for example in examples]
wavs = [example[1] for example in examples]
mels = batch_spec(mels, pad_value=self.padding_value)
wavs = batch_wav(wavs, pad_value=self.padding_value)
return mels, wavs
class LJSpeechClipCollector(object):
def __init__(self, clip_frames=65, hop_length=256):
self.clip_frames = clip_frames
self.hop_length = hop_length
def __call__(self, examples):
mels = []
wavs = []
for example in examples:
mel_clip, wav_clip = self.clip(example)
mels.append(mel_clip)
wavs.append(wav_clip)
mels = np.stack(mels)
wavs = np.stack(wavs)
return mels, wavs
def clip(self, example):
mel, wav = example
frames = mel.shape[-1]
start = np.random.randint(0, frames - self.clip_frames)
mel_clip = mel[:, start: start + self.clip_frames]
wav_clip = wav[start * self.hop_length: (start + self.clip_frames) * self.hop_length]
return mel_clip, wav_clip

View File

@ -0,0 +1,138 @@
import os
import tqdm
import csv
import argparse
import numpy as np
import librosa
from pathlib import Path
import pandas as pd
from paddle.io import Dataset
from parakeet.data import batch_spec, batch_wav
from parakeet.datasets import LJSpeechMetaData
from parakeet.audio import AudioProcessor, LogMagnitude
from config import get_cfg_defaults
class Transform(object):
def __init__(self, sample_rate, n_fft, win_length, hop_length, n_mels):
self.sample_rate = sample_rate
self.n_fft = n_fft
self.win_length = win_length
self.hop_length = hop_length
self.n_mels = n_mels
self.spec_normalizer = LogMagnitude(min=1e-5)
def __call__(self, example):
wav_path, _, _ = example
sr = self.sample_rate
n_fft = self.n_fft
win_length = self.win_length
hop_length = self.hop_length
n_mels = self.n_mels
wav, loaded_sr = librosa.load(wav_path, sr=None)
assert loaded_sr == sr, "sample rate does not match, resampling applied"
# Pad audio to the right size.
frames = int(np.ceil(float(wav.size) / hop_length))
fft_padding = (n_fft - hop_length) // 2 # sound
desired_length = frames * hop_length + fft_padding * 2
pad_amount = (desired_length - wav.size) // 2
if wav.size % 2 == 0:
wav = np.pad(wav, (pad_amount, pad_amount), mode='reflect')
else:
wav = np.pad(wav, (pad_amount, pad_amount + 1), mode='reflect')
# Normalize audio.
wav = wav / np.abs(wav).max() * 0.999
# Compute mel-spectrogram.
# Turn center to False to prevent internal padding.
spectrogram = librosa.core.stft(
wav,
hop_length=hop_length,
win_length=win_length,
n_fft=n_fft,
center=False)
spectrogram_magnitude = np.abs(spectrogram)
# Compute mel-spectrograms.
mel_filter_bank = librosa.filters.mel(sr=sr,
n_fft=n_fft,
n_mels=n_mels)
mel_spectrogram = np.dot(mel_filter_bank, spectrogram_magnitude)
mel_spectrogram = mel_spectrogram
# log scale mel_spectrogram.
mel_spectrogram = self.spec_normalizer.transform(mel_spectrogram)
# Extract the center of audio that corresponds to mel spectrograms.
audio = wav[fft_padding:-fft_padding]
assert mel_spectrogram.shape[1] * hop_length == audio.size
# there is no clipping here
return audio, mel_spectrogram
def create_dataset(config, input_dir, output_dir, verbose=True):
input_dir = Path(input_dir).expanduser()
dataset = LJSpeechMetaData(input_dir)
output_dir = Path(output_dir).expanduser()
output_dir.mkdir(exist_ok=True)
transform = Transform(
config.sample_rate,
config.n_fft,
config.win_length,
config.hop_length,
config.n_mels)
file_names = []
for example in tqdm.tqdm(dataset):
fname, _, _ = example
base_name = os.path.splitext(os.path.basename(fname))[0]
wav_dir = output_dir / "wav"
mel_dir = output_dir / "mel"
wav_dir.mkdir(exist_ok=True)
mel_dir.mkdir(exist_ok=True)
audio, mel = transform(example)
np.save(str(wav_dir / base_name), audio)
np.save(str(mel_dir / base_name), mel)
file_names.append((base_name, mel.shape[-1], audio.shape[-1]))
meta_data = pd.DataFrame.from_records(file_names)
meta_data.to_csv(str(output_dir / "metadata.csv"), sep="\t", index=None, header=None)
print("saved meta data in to {}".format(os.path.join(output_dir, "metadata.csv")))
print("Done!")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="create dataset")
parser.add_argument("--config", type=str, metavar="FILE", help="extra config to overwrite the default config")
parser.add_argument("--input", type=str, help="path of the ljspeech dataset")
parser.add_argument("--output", type=str, help="path to save output dataset")
parser.add_argument("--opts", nargs=argparse.REMAINDER,
help="options to overwrite --config file and the default config, passing in KEY VALUE pairs"
)
parser.add_argument("-v", "--verbose", action="store_true", help="print msg")
config = get_cfg_defaults()
args = parser.parse_args()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
if args.verbose:
print(config.data)
print(args)
create_dataset(config.data, args.input, args.output, args.verbose)

View File

@ -0,0 +1,52 @@
import argparse
import numpy as np
import soundfile as sf
import os
from pathlib import Path
import paddle
import parakeet
from parakeet.models.waveflow import UpsampleNet, WaveFlow, ConditionalWaveFlow
from parakeet.utils import layer_tools, checkpoint
from config import get_cfg_defaults
def main(config, args):
paddle.set_device(args.device)
model = ConditionalWaveFlow.from_pretrained(config, args.checkpoint_path)
layer_tools.recursively_remove_weight_norm(model)
model.eval()
mel_dir = Path(args.input).expanduser()
output_dir = Path(args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
for file_path in mel_dir.iterdir():
mel = np.load(str(file_path))
audio = model.predict(mel)
audio_path = output_dir / (os.path.splitext(file_path.name)[0] + ".wav")
sf.write(audio_path, audio, config.data.sample_rate)
print("[synthesize] {} -> {}".format(file_path, audio_path))
if __name__ == "__main__":
config = get_cfg_defaults()
parser = argparse.ArgumentParser(description="generate mel spectrogram with TransformerTTS.")
parser.add_argument("--config", type=str, metavar="FILE", help="extra config to overwrite the default config")
parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load.")
parser.add_argument("--input", type=str, help="path of directory containing mel spectrogram (in .npy format)")
parser.add_argument("--output", type=str, help="path to save outputs")
parser.add_argument("--device", type=str, default="cpu", help="device type to use.")
parser.add_argument("--opts", nargs=argparse.REMAINDER, help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
parser.add_argument("-v", "--verbose", action="store_true", help="print msg")
args = parser.parse_args()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
print(args)
main(config, args)

147
examples/waveflow/train.py Normal file
View File

@ -0,0 +1,147 @@
import time
from pathlib import Path
import numpy as np
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader, DistributedBatchSampler
from tensorboardX import SummaryWriter
from collections import defaultdict
import parakeet
from parakeet.data import dataset
from parakeet.models.waveflow import UpsampleNet, WaveFlow, ConditionalWaveFlow, WaveFlowLoss
from parakeet.audio import AudioProcessor
from parakeet.utils import scheduler, mp_tools
from parakeet.training.cli import default_argument_parser
from parakeet.training.experiment import ExperimentBase
from parakeet.utils.mp_tools import rank_zero_only
from config import get_cfg_defaults
from ljspeech import LJSpeech, LJSpeechClipCollector, LJSpeechCollector
class Experiment(ExperimentBase):
def setup_model(self):
config = self.config
model = ConditionalWaveFlow(
upsample_factors=config.model.upsample_factors,
n_flows=config.model.n_flows,
n_layers=config.model.n_layers,
n_group=config.model.n_group,
channels=config.model.channels,
n_mels=config.data.n_mels,
kernel_size=config.model.kernel_size)
if self.parallel > 1:
model = paddle.DataParallel(model)
optimizer = paddle.optimizer.Adam(config.training.lr, parameters=model.parameters())
criterion = WaveFlowLoss(sigma=config.model.sigma)
self.model = model
self.optimizer = optimizer
self.criterion = criterion
def setup_dataloader(self):
config = self.config
args = self.args
ljspeech_dataset = LJSpeech(args.data)
valid_set, train_set = dataset.split(ljspeech_dataset, config.data.valid_size)
batch_fn = LJSpeechClipCollector(config.data.clip_frames, config.data.hop_length)
if not self.parallel:
train_loader = DataLoader(
train_set,
batch_size=config.data.batch_size,
shuffle=True,
drop_last=True,
collate_fn=batch_fn)
else:
sampler = DistributedBatchSampler(
train_set,
batch_size=config.data.batch_size,
num_replicas=dist.get_world_size(),
rank=dist.get_rank(),
shuffle=True,
drop_last=True)
train_loader = DataLoader(
train_set, batch_sampler=sampler, collate_fn=batch_fn)
valid_batch_fn = LJSpeechCollector()
valid_loader = DataLoader(
valid_set, batch_size=1, collate_fn=valid_batch_fn)
self.train_loader = train_loader
self.valid_loader = valid_loader
def compute_outputs(self, mel, wav):
# model_core = model._layers if isinstance(model, paddle.DataParallel) else model
z, log_det_jocobian = self.model(wav, mel)
return z, log_det_jocobian
def compute_losses(self, outputs):
loss = self.criterion(outputs)
return loss
def train_batch(self):
start = time.time()
batch = self.read_batch()
data_loader_time = time.time() - start
self.model.train()
self.optimizer.clear_grad()
mel, wav = batch
outputs = self.compute_outputs(mel, wav)
loss = self.compute_losses(outputs)
loss.backward()
self.optimizer.step()
iteration_time = time.time() - start
loss_value = float(loss)
msg = "Rank: {}, ".format(dist.get_rank())
msg += "step: {}, ".format(self.iteration)
msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time, iteration_time)
msg += "loss: {:>.6f}".format(loss_value)
self.logger.info(msg)
self.visualizer.add_scalar("train/loss", loss_value, global_step=self.iteration)
@mp_tools.rank_zero_only
@paddle.no_grad()
def valid(self):
valid_iterator = iter(self.valid_loader)
valid_losses = []
mel, wav = next(valid_iterator)
outputs = self.compute_outputs(mel, wav)
loss = self.compute_losses(outputs)
valid_losses.append(float(loss))
valid_loss = np.mean(valid_losses)
self.visualizer.add_scalar("valid/loss", valid_loss, global_step=self.iteration)
def main_sp(config, args):
exp = Experiment(config, args)
exp.setup()
exp.run()
def main(config, args):
if args.nprocs > 1 and args.device == "gpu":
dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
else:
main_sp(config, args)
if __name__ == "__main__":
config = get_cfg_defaults()
parser = default_argument_parser()
args = parser.parse_args()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
print(args)
main(config, args)

View File

@ -0,0 +1,47 @@
from yacs.config import CfgNode as CN
_C = CN()
_C.data = CN(
dict(
batch_size=8, # batch size
valid_size=16, # the first N examples are reserved for validation
sample_rate=22050, # Hz, sample rate
n_fft=2048, # fft frame size
win_length=1024, # window size
hop_length=256, # hop size between ajacent frame
# f_max=8000, # Hz, max frequency when converting to mel
n_mels=80, # mel bands
train_clip_seconds=0.5, # audio clip length(in seconds)
)
)
_C.model = CN(
dict(
upsample_factors=[16, 16],
n_stack=3,
n_loop=10,
filter_size=2,
residual_channels=128, # resiaudal channel in each flow
loss_type="mog",
output_dim=3, # single gaussian
log_scale_min=-9.0,
)
)
_C.training = CN(
dict(
lr=1e-3, # learning rates
anneal_rate=0.5, # learning rate decay rate
anneal_interval=200000, # decrese lr by annel_rate every anneal_interval steps
valid_interval=1000, # validation
save_interval=10000, # checkpoint
max_iteration=3000000, # max iteration to train
gradient_max_norm=100.0 # global norm of gradients
)
)
def get_cfg_defaults():
"""Get a yacs CfgNode object with default values for my_project."""
# Return a clone so that the defaults will not be altered
# This is for the "local variable" use pattern
return _C.clone()

View File

@ -0,0 +1,138 @@
import os
from pathlib import Path
import pickle
import numpy as np
import pandas
from paddle.io import Dataset, DataLoader
from parakeet.data.batch import batch_spec, batch_wav
from parakeet.data import dataset
from parakeet.audio import AudioProcessor
class LJSpeech(Dataset):
"""A simple dataset adaptor for the processed ljspeech dataset."""
def __init__(self, root):
self.root = Path(root).expanduser()
meta_data = pandas.read_csv(
str(self.root / "metadata.csv"),
sep="\t",
header=None,
names=["fname", "frames", "samples"]
)
records = []
for row in meta_data.itertuples() :
mel_path = str(self.root / "mel" / (row.fname + ".npy"))
wav_path = str(self.root / "wav" / (row.fname + ".npy"))
records.append((mel_path, wav_path))
self.records = records
def __getitem__(self, i):
mel_name, wav_name = self.records[i]
mel = np.load(mel_name)
wav = np.load(wav_name)
return mel, wav
def __len__(self):
return len(self.records)
class LJSpeechCollector(object):
"""A simple callable to batch LJSpeech examples."""
def __init__(self, padding_value=0.):
self.padding_value = padding_value
def __call__(self, examples):
batch_size = len(examples)
mels = [example[0] for example in examples]
wavs = [example[1] for example in examples]
mels = batch_spec(mels, pad_value=self.padding_value)
wavs = batch_wav(wavs, pad_value=self.padding_value)
audio_starts = np.zeros((batch_size,), dtype=np.int64)
return mels, wavs, audio_starts
class LJSpeechClipCollector(object):
def __init__(self, clip_frames=65, hop_length=256):
self.clip_frames = clip_frames
self.hop_length = hop_length
def __call__(self, examples):
mels = []
wavs = []
starts = []
for example in examples:
mel, wav_clip, start = self.clip(example)
mels.append(mel)
wavs.append(wav_clip)
starts.append(start)
mels = batch_spec(mels)
wavs = np.stack(wavs)
starts = np.array(starts, dtype=np.int64)
return mels, wavs, starts
def clip(self, example):
mel, wav = example
frames = mel.shape[-1]
start = np.random.randint(0, frames - self.clip_frames)
wav_clip = wav[start * self.hop_length: (start + self.clip_frames) * self.hop_length]
return mel, wav_clip, start
class DataCollector(object):
def __init__(self,
context_size,
sample_rate,
hop_length,
train_clip_seconds,
valid=False):
frames_per_second = sample_rate // hop_length
train_clip_frames = int(
np.ceil(train_clip_seconds * frames_per_second))
context_frames = context_size // hop_length
self.num_frames = train_clip_frames + context_frames
self.sample_rate = sample_rate
self.hop_length = hop_length
self.valid = valid
def random_crop(self, sample):
audio, mel_spectrogram = sample
audio_frames = int(audio.size) // self.hop_length
max_start_frame = audio_frames - self.num_frames
assert max_start_frame >= 0, "audio is too short to be cropped"
frame_start = np.random.randint(0, max_start_frame)
# frame_start = 0 # norandom
frame_end = frame_start + self.num_frames
audio_start = frame_start * self.hop_length
audio_end = frame_end * self.hop_length
audio = audio[audio_start:audio_end]
return audio, mel_spectrogram, audio_start
def __call__(self, samples):
# transform them first
if self.valid:
samples = [(audio, mel_spectrogram, 0)
for audio, mel_spectrogram in samples]
else:
samples = [self.random_crop(sample) for sample in samples]
# batch them
audios = [sample[0] for sample in samples]
audio_starts = [sample[2] for sample in samples]
mels = [sample[1] for sample in samples]
mels = batch_spec(mels)
if self.valid:
audios = batch_wav(audios, dtype=np.float32)
else:
audios = np.array(audios, dtype=np.float32)
audio_starts = np.array(audio_starts, dtype=np.int64)
return audios, mels, audio_starts

View File

@ -0,0 +1,139 @@
import os
import tqdm
import csv
import argparse
import numpy as np
import librosa
from pathlib import Path
import pandas as pd
from paddle.io import Dataset
from parakeet.data import batch_spec, batch_wav
from parakeet.datasets import LJSpeechMetaData
from parakeet.audio import AudioProcessor
from parakeet.audio.spec_normalizer import UnitMagnitude
from config import get_cfg_defaults
class Transform(object):
def __init__(self, sample_rate, n_fft, win_length, hop_length, n_mels):
self.sample_rate = sample_rate
self.n_fft = n_fft
self.win_length = win_length
self.hop_length = hop_length
self.n_mels = n_mels
self.spec_normalizer = UnitMagnitude(min=1e-5)
def __call__(self, example):
wav_path, _, _ = example
sr = self.sample_rate
n_fft = self.n_fft
win_length = self.win_length
hop_length = self.hop_length
n_mels = self.n_mels
wav, loaded_sr = librosa.load(wav_path, sr=None)
assert loaded_sr == sr, "sample rate does not match, resampling applied"
# Pad audio to the right size.
frames = int(np.ceil(float(wav.size) / hop_length))
fft_padding = (n_fft - hop_length) // 2 # sound
desired_length = frames * hop_length + fft_padding * 2
pad_amount = (desired_length - wav.size) // 2
if wav.size % 2 == 0:
wav = np.pad(wav, (pad_amount, pad_amount), mode='reflect')
else:
wav = np.pad(wav, (pad_amount, pad_amount + 1), mode='reflect')
# Normalize audio.
wav = wav / np.abs(wav).max() * 0.999
# Compute mel-spectrogram.
# Turn center to False to prevent internal padding.
spectrogram = librosa.core.stft(
wav,
hop_length=hop_length,
win_length=win_length,
n_fft=n_fft,
center=False)
spectrogram_magnitude = np.abs(spectrogram)
# Compute mel-spectrograms.
mel_filter_bank = librosa.filters.mel(sr=sr,
n_fft=n_fft,
n_mels=n_mels)
mel_spectrogram = np.dot(mel_filter_bank, spectrogram_magnitude)
mel_spectrogram = mel_spectrogram
# log scale mel_spectrogram.
mel_spectrogram = self.spec_normalizer.transform(mel_spectrogram)
# Extract the center of audio that corresponds to mel spectrograms.
audio = wav[fft_padding:-fft_padding]
assert mel_spectrogram.shape[1] * hop_length == audio.size
# there is no clipping here
return audio, mel_spectrogram
def create_dataset(config, input_dir, output_dir, verbose=True):
input_dir = Path(input_dir).expanduser()
dataset = LJSpeechMetaData(input_dir)
output_dir = Path(output_dir).expanduser()
output_dir.mkdir(exist_ok=True)
transform = Transform(
config.sample_rate,
config.n_fft,
config.win_length,
config.hop_length,
config.n_mels)
file_names = []
for example in tqdm.tqdm(dataset):
fname, _, _ = example
base_name = os.path.splitext(os.path.basename(fname))[0]
wav_dir = output_dir / "wav"
mel_dir = output_dir / "mel"
wav_dir.mkdir(exist_ok=True)
mel_dir.mkdir(exist_ok=True)
audio, mel = transform(example)
np.save(str(wav_dir / base_name), audio)
np.save(str(mel_dir / base_name), mel)
file_names.append((base_name, mel.shape[-1], audio.shape[-1]))
meta_data = pd.DataFrame.from_records(file_names)
meta_data.to_csv(str(output_dir / "metadata.csv"), sep="\t", index=None, header=None)
print("saved meta data in to {}".format(os.path.join(output_dir, "metadata.csv")))
print("Done!")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="create dataset")
parser.add_argument("--config", type=str, metavar="FILE", help="extra config to overwrite the default config")
parser.add_argument("--input", type=str, help="path of the ljspeech dataset")
parser.add_argument("--output", type=str, help="path to save output dataset")
parser.add_argument("--opts", nargs=argparse.REMAINDER,
help="options to overwrite --config file and the default config, passing in KEY VALUE pairs"
)
parser.add_argument("-v", "--verbose", action="store_true", help="print msg")
config = get_cfg_defaults()
args = parser.parse_args()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
if args.verbose:
print(config.data)
print(args)
create_dataset(config.data, args.input, args.output, args.verbose)

View File

@ -0,0 +1,51 @@
import argparse
import numpy as np
import soundfile as sf
import os
from pathlib import Path
import paddle
import parakeet
from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWaveNet
from parakeet.utils import layer_tools, checkpoint
from config import get_cfg_defaults
def main(config, args):
paddle.set_device(args.device)
model = ConditionalWaveNet.from_pretrained(config, args.checkpoint_path)
layer_tools.recursively_remove_weight_norm(model)
model.eval()
mel_dir = Path(args.input).expanduser()
output_dir = Path(args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
for file_path in mel_dir.iterdir():
mel = np.load(str(file_path))
audio = model.predict(mel)
audio_path = output_dir / (os.path.splitext(file_path.name)[0] + ".wav")
sf.write(audio_path, audio, config.data.sample_rate)
print("[synthesize] {} -> {}".format(file_path, audio_path))
if __name__ == "__main__":
config = get_cfg_defaults()
parser = argparse.ArgumentParser(description="generate mel spectrogram with TransformerTTS.")
parser.add_argument("--config", type=str, metavar="FILE", help="extra config to overwrite the default config")
parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load.")
parser.add_argument("--input", type=str, help="path of directory containing mel spectrogram (in .npy format)")
parser.add_argument("--output", type=str, help="path to save outputs")
parser.add_argument("--device", type=str, default="cpu", help="device type to use.")
parser.add_argument("--opts", nargs=argparse.REMAINDER, help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
parser.add_argument("-v", "--verbose", action="store_true", help="print msg")
args = parser.parse_args()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
print(args)
main(config, args)

157
examples/wavenet/train.py Normal file
View File

@ -0,0 +1,157 @@
import time
from pathlib import Path
import math
import numpy as np
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader, DistributedBatchSampler
from tensorboardX import SummaryWriter
from collections import defaultdict
import parakeet
from parakeet.data import dataset
from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWaveNet
from parakeet.audio import AudioProcessor
from parakeet.utils import scheduler, mp_tools
from parakeet.training.cli import default_argument_parser
from parakeet.training.experiment import ExperimentBase
from parakeet.utils.mp_tools import rank_zero_only
from config import get_cfg_defaults
from ljspeech import LJSpeech, LJSpeechClipCollector, LJSpeechCollector
class Experiment(ExperimentBase):
def setup_model(self):
config = self.config
model = ConditionalWaveNet(
upsample_factors=config.model.upsample_factors,
n_stack=config.model.n_stack,
n_loop=config.model.n_loop,
residual_channels=config.model.residual_channels,
output_dim=config.model.output_dim,
n_mels=config.data.n_mels,
filter_size=config.model.filter_size,
loss_type=config.model.loss_type,
log_scale_min=config.model.log_scale_min)
if self.parallel > 1:
model = paddle.DataParallel(model)
lr_scheduler = paddle.optimizer.lr.StepDecay(
config.training.lr,
config.training.anneal_interval,
config.training.anneal_rate)
optimizer = paddle.optimizer.Adam(
lr_scheduler,
parameters=model.parameters(),
grad_clip=paddle.nn.ClipGradByGlobalNorm(config.training.gradient_max_norm))
self.model = model
self.model_core = model._layer if self.parallel else model
self.optimizer = optimizer
def setup_dataloader(self):
config = self.config
args = self.args
ljspeech_dataset = LJSpeech(args.data)
valid_set, train_set = dataset.split(ljspeech_dataset, config.data.valid_size)
# convolutional net's causal padding size
context_size = config.model.n_stack \
* sum([(config.model.filter_size - 1) * 2**i for i in range(config.model.n_loop)]) \
+ 1
context_frames = context_size // config.data.hop_length
# frames used to compute loss
frames_per_second = config.data.sample_rate // config.data.hop_length
train_clip_frames = math.ceil(config.data.train_clip_seconds * frames_per_second)
num_frames = train_clip_frames + context_frames
batch_fn = LJSpeechClipCollector(num_frames, config.data.hop_length)
if not self.parallel:
train_loader = DataLoader(
train_set,
batch_size=config.data.batch_size,
shuffle=True,
drop_last=True,
collate_fn=batch_fn)
else:
sampler = DistributedBatchSampler(
train_set,
batch_size=config.data.batch_size,
shuffle=True,
drop_last=True)
train_loader = DataLoader(
train_set, batch_sampler=sampler, collate_fn=batch_fn)
valid_batch_fn = LJSpeechCollector()
valid_loader = DataLoader(
valid_set, batch_size=1, collate_fn=valid_batch_fn)
self.train_loader = train_loader
self.valid_loader = valid_loader
def train_batch(self):
start = time.time()
batch = self.read_batch()
data_loader_time = time.time() - start
self.model.train()
self.optimizer.clear_grad()
mel, wav, audio_starts = batch
y = self.model(wav, mel, audio_starts)
loss = self.model.loss(y, wav)
loss.backward()
self.optimizer.step()
iteration_time = time.time() - start
loss_value = float(loss)
msg = "Rank: {}, ".format(dist.get_rank())
msg += "step: {}, ".format(self.iteration)
msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time, iteration_time)
msg += "loss: {:>.6f}".format(loss_value)
self.logger.info(msg)
self.visualizer.add_scalar("train/loss", loss_value, global_step=self.iteration)
@mp_tools.rank_zero_only
@paddle.no_grad()
def valid(self):
valid_iterator = iter(self.valid_loader)
valid_losses = []
mel, wav, audio_starts = next(valid_iterator)
y = self.model(wav, mel, audio_starts)
loss = self.model.loss(y, wav)
valid_losses.append(float(loss))
valid_loss = np.mean(valid_losses)
self.visualizer.add_scalar("valid/loss", valid_loss, global_step=self.iteration)
def main_sp(config, args):
exp = Experiment(config, args)
exp.setup()
exp.run()
def main(config, args):
if args.nprocs > 1 and args.device == "gpu":
dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
else:
main_sp(config, args)
if __name__ == "__main__":
config = get_cfg_defaults()
parser = default_argument_parser()
args = parser.parse_args()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
print(args)
main(config, args)