From 14235cd114c73c75fac12b2f10d827a741112e3b Mon Sep 17 00:00:00 2001 From: lifuchen Date: Fri, 19 Jun 2020 03:46:10 +0000 Subject: [PATCH 1/4] modified synthesis of transformer_tts & fastspeech --- examples/fastspeech/synthesis.py | 198 ++++++++++++++-------- examples/fastspeech/synthesis.sh | 13 +- examples/transformer_tts/synthesis.py | 181 +++++++++++++++----- examples/transformer_tts/synthesis.sh | 14 +- examples/transformer_tts/train_vocoder.py | 2 +- 5 files changed, 285 insertions(+), 123 deletions(-) diff --git a/examples/fastspeech/synthesis.py b/examples/fastspeech/synthesis.py index de726bd..c12742a 100644 --- a/examples/fastspeech/synthesis.py +++ b/examples/fastspeech/synthesis.py @@ -28,6 +28,8 @@ from parakeet.models.fastspeech.fastspeech import FastSpeech from parakeet.models.transformer_tts.utils import * from parakeet.models.wavenet import WaveNet, UpsampleNet from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet +from parakeet.modules import weight_norm +from parakeet.models.waveflow import WaveFlowModule from parakeet.utils.layer_tools import freeze from parakeet.utils import io @@ -35,7 +37,13 @@ from parakeet.utils import io def add_config_options_to_parser(parser): parser.add_argument("--config", type=str, help="path of the config file") parser.add_argument( - "--config_clarinet", type=str, help="path of the clarinet config file") + "--vocoder", + type=str, + default="griffinlim", + choices=['griffinlim', 'clarinet', 'waveflow'], + help="vocoder method") + parser.add_argument( + "--config_vocoder", type=str, help="path of the vocoder config file") parser.add_argument("--use_gpu", type=int, default=0, help="device to use") parser.add_argument( "--alpha", @@ -47,9 +55,9 @@ def add_config_options_to_parser(parser): parser.add_argument( "--checkpoint", type=str, help="fastspeech checkpoint to synthesis") parser.add_argument( - "--checkpoint_clarinet", + "--checkpoint_vocoder", type=str, - help="clarinet checkpoint to synthesis") + help="vocoder checkpoint to synthesis") parser.add_argument( "--output", @@ -83,46 +91,62 @@ def synthesis(text_input, args): pos_text = np.arange(1, text.shape[1] + 1) pos_text = np.expand_dims(pos_text, axis=0) - text = dg.to_variable(text) - pos_text = dg.to_variable(pos_text) + text = dg.to_variable(text).astype(np.int64) + pos_text = dg.to_variable(pos_text).astype(np.int64) _, mel_output_postnet = model(text, pos_text, alpha=args.alpha) - result = np.exp(mel_output_postnet.numpy()) - mel_output_postnet = fluid.layers.transpose( - fluid.layers.squeeze(mel_output_postnet, [0]), [1, 0]) - mel_output_postnet = np.exp(mel_output_postnet.numpy()) - basis = librosa.filters.mel(cfg['audio']['sr'], cfg['audio']['n_fft'], - cfg['audio']['num_mels']) - inv_basis = np.linalg.pinv(basis) - spec = np.maximum(1e-10, np.dot(inv_basis, mel_output_postnet)) + if args.vocoder == 'griffinlim': + #synthesis use griffin-lim + wav = synthesis_with_griffinlim( + mel_output_postnet, + sr=cfg['audio']['sr'], + n_fft=cfg['audio']['n_fft'], + num_mels=cfg['audio']['num_mels'], + power=cfg['audio']['power'], + hop_length=cfg['audio']['hop_length'], + win_length=cfg['audio']['win_length']) + elif args.vocoder == 'clarinet': + # synthesis use clarinet + wav = synthesis_with_clarinet(mel_output_postnet, args.config_vocoder, + args.checkpoint_vocoder, place) + elif args.vocoder == 'waveflow': + wav = synthesis_with_waveflow(mel_output_postnet, args, + args.checkpoint_vocoder, place) + else: + print( + 'vocoder error, we only support griffinlim, clarinet 
and waveflow, but recevied %s.' + % args.vocoder) - # synthesis use clarinet - wav_clarinet = synthesis_with_clarinet( - args.config_clarinet, args.checkpoint_clarinet, result, place) - writer.add_audio(text_input + '(clarinet)', wav_clarinet, 0, + writer.add_audio(text_input + '(' + args.vocoder + ')', wav, 0, cfg['audio']['sr']) if not os.path.exists(os.path.join(args.output, 'samples')): os.mkdir(os.path.join(args.output, 'samples')) - write( - os.path.join(os.path.join(args.output, 'samples'), 'clarinet.wav'), - cfg['audio']['sr'], wav_clarinet) - - #synthesis use griffin-lim - wav = librosa.core.griffinlim( - spec**cfg['audio']['power'], - hop_length=cfg['audio']['hop_length'], - win_length=cfg['audio']['win_length']) - writer.add_audio(text_input + '(griffin-lim)', wav, 0, cfg['audio']['sr']) write( os.path.join( - os.path.join(args.output, 'samples'), 'grinffin-lim.wav'), + os.path.join(args.output, 'samples'), args.vocoder + '.wav'), cfg['audio']['sr'], wav) print("Synthesis completed !!!") writer.close() -def synthesis_with_clarinet(config_path, checkpoint, mel_spectrogram, place): +def synthesis_with_griffinlim(mel_output, sr, n_fft, num_mels, power, + hop_length, win_length): + mel_output = fluid.layers.transpose( + fluid.layers.squeeze(mel_output, [0]), [1, 0]) + mel_output = np.exp(mel_output.numpy()) + basis = librosa.filters.mel(sr, n_fft, num_mels) + inv_basis = np.linalg.pinv(basis) + spec = np.maximum(1e-10, np.dot(inv_basis, mel_output)) + + wav = librosa.core.griffinlim( + spec**power, hop_length=hop_length, win_length=win_length) + + return wav + + +def synthesis_with_clarinet(mel_output, config_path, checkpoint, place): + mel_spectrogram = np.exp(mel_output.numpy()) with open(config_path, 'rt') as f: config = yaml.safe_load(f) @@ -136,62 +160,86 @@ def synthesis_with_clarinet(config_path, checkpoint, mel_spectrogram, place): # only batch=1 for validation is enabled - with dg.guard(place): - # conditioner(upsampling net) - conditioner_config = config["conditioner"] - upsampling_factors = conditioner_config["upsampling_factors"] - upsample_net = UpsampleNet(upscale_factors=upsampling_factors) - freeze(upsample_net) + fluid.enable_dygraph(place) + # conditioner(upsampling net) + conditioner_config = config["conditioner"] + upsampling_factors = conditioner_config["upsampling_factors"] + upsample_net = UpsampleNet(upscale_factors=upsampling_factors) + freeze(upsample_net) - residual_channels = teacher_config["residual_channels"] - loss_type = teacher_config["loss_type"] - output_dim = teacher_config["output_dim"] - log_scale_min = teacher_config["log_scale_min"] - assert loss_type == "mog" and output_dim == 3, \ - "the teacher wavenet should be a wavenet with single gaussian output" + residual_channels = teacher_config["residual_channels"] + loss_type = teacher_config["loss_type"] + output_dim = teacher_config["output_dim"] + log_scale_min = teacher_config["log_scale_min"] + assert loss_type == "mog" and output_dim == 3, \ + "the teacher wavenet should be a wavenet with single gaussian output" - teacher = WaveNet(n_loop, n_layer, residual_channels, output_dim, - n_mels, filter_size, loss_type, log_scale_min) - # load & freeze upsample_net & teacher - freeze(teacher) + teacher = WaveNet(n_loop, n_layer, residual_channels, output_dim, n_mels, + filter_size, loss_type, log_scale_min) + # load & freeze upsample_net & teacher + freeze(teacher) - student_config = config["student"] - n_loops = student_config["n_loops"] - n_layers = student_config["n_layers"] - 
student_residual_channels = student_config["residual_channels"] - student_filter_size = student_config["filter_size"] - student_log_scale_min = student_config["log_scale_min"] - student = ParallelWaveNet(n_loops, n_layers, student_residual_channels, - n_mels, student_filter_size) + student_config = config["student"] + n_loops = student_config["n_loops"] + n_layers = student_config["n_layers"] + student_residual_channels = student_config["residual_channels"] + student_filter_size = student_config["filter_size"] + student_log_scale_min = student_config["log_scale_min"] + student = ParallelWaveNet(n_loops, n_layers, student_residual_channels, + n_mels, student_filter_size) - stft_config = config["stft"] - stft = STFT( - n_fft=stft_config["n_fft"], - hop_length=stft_config["hop_length"], - win_length=stft_config["win_length"]) + stft_config = config["stft"] + stft = STFT( + n_fft=stft_config["n_fft"], + hop_length=stft_config["hop_length"], + win_length=stft_config["win_length"]) - lmd = config["loss"]["lmd"] - model = Clarinet(upsample_net, teacher, student, stft, - student_log_scale_min, lmd) - io.load_parameters(model=model, checkpoint_path=checkpoint) + lmd = config["loss"]["lmd"] + model = Clarinet(upsample_net, teacher, student, stft, + student_log_scale_min, lmd) + io.load_parameters(model=model, checkpoint_path=checkpoint) - if not os.path.exists(args.output): - os.makedirs(args.output) - model.eval() + if not os.path.exists(args.output): + os.makedirs(args.output) + model.eval() - # Rescale mel_spectrogram. - min_level, ref_level = 1e-5, 20 # hard code it - mel_spectrogram = 20 * np.log10(np.maximum(min_level, mel_spectrogram)) - mel_spectrogram = mel_spectrogram - ref_level - mel_spectrogram = np.clip((mel_spectrogram + 100) / 100, 0, 1) + # Rescale mel_spectrogram. + min_level, ref_level = 1e-5, 20 # hard code it + mel_spectrogram = 20 * np.log10(np.maximum(min_level, mel_spectrogram)) + mel_spectrogram = mel_spectrogram - ref_level + mel_spectrogram = np.clip((mel_spectrogram + 100) / 100, 0, 1) - mel_spectrogram = dg.to_variable(mel_spectrogram) - mel_spectrogram = fluid.layers.transpose(mel_spectrogram, [0, 2, 1]) + mel_spectrogram = dg.to_variable(mel_spectrogram) + mel_spectrogram = fluid.layers.transpose(mel_spectrogram, [0, 2, 1]) - wav_var = model.synthesis(mel_spectrogram) - wav_np = wav_var.numpy()[0] + wav_var = model.synthesis(mel_spectrogram) + wav_np = wav_var.numpy()[0] - return wav_np + return wav_np + + +def synthesis_with_waveflow(mel_output, args, checkpoint, place): + #mel_output = np.exp(mel_output.numpy()) + mel_output = mel_output.numpy() + + fluid.enable_dygraph(place) + args.config = args.config_vocoder + args.use_fp16 = False + config = io.add_yaml_config_to_args(args) + + mel_spectrogram = dg.to_variable(mel_output) + mel_spectrogram = fluid.layers.transpose(mel_spectrogram, [0, 2, 1]) + + # Build model. + waveflow = WaveFlowModule(config) + io.load_parameters(model=waveflow, checkpoint_path=checkpoint) + for layer in waveflow.sublayers(): + if isinstance(layer, weight_norm.WeightNormWrapper): + layer.remove_weight_norm() + + # Run model inference. 
+ wav = waveflow.synthesize(mel_spectrogram, sigma=config.sigma) + return wav.numpy()[0] if __name__ == '__main__': diff --git a/examples/fastspeech/synthesis.sh b/examples/fastspeech/synthesis.sh index a6a0347..79a62d0 100644 --- a/examples/fastspeech/synthesis.sh +++ b/examples/fastspeech/synthesis.sh @@ -1,13 +1,20 @@ # train model +CUDA_VISIBLE_DEVICES=0 \ python -u synthesis.py \ --use_gpu=1 \ --alpha=1.0 \ ---checkpoint='./checkpoint/fastspeech/step-120000' \ +--checkpoint='./checkpoint/fastspeech1024/step-160000' \ --config='configs/ljspeech.yaml' \ ---config_clarine='../clarinet/configs/config.yaml' \ ---checkpoint_clarinet='../clarinet/checkpoint/step-500000' \ --output='./synthesis' \ +--vocoder='waveflow' \ +--config_vocoder='../waveflow/checkpoint/waveflow_res64_ljspeech_ckpt_1.0/waveflow_ljspeech.yaml' \ +--checkpoint_vocoder='../waveflow/checkpoint/waveflow_res64_ljspeech_ckpt_1.0/step-3020000' \ +#--vocoder='clarinet' \ +#--config_vocoder='../clarinet/configs/clarinet_ljspeech.yaml' \ +#--checkpoint_vocoder='../clarinet/checkpoint/step-500000' \ + + if [ $? -ne 0 ]; then echo "Failed in synthesis!" diff --git a/examples/transformer_tts/synthesis.py b/examples/transformer_tts/synthesis.py index 7d7f965..d4e17bf 100644 --- a/examples/transformer_tts/synthesis.py +++ b/examples/transformer_tts/synthesis.py @@ -28,6 +28,10 @@ from parakeet.models.transformer_tts.utils import * from parakeet import audio from parakeet.models.transformer_tts import Vocoder from parakeet.models.transformer_tts import TransformerTTS +from parakeet.modules import weight_norm +from parakeet.models.waveflow import WaveFlowModule +from parakeet.modules.weight_norm import WeightNormWrapper +from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet from parakeet.utils import io @@ -44,6 +48,14 @@ def add_config_options_to_parser(parser): "--checkpoint_transformer", type=str, help="transformer_tts checkpoint to synthesis") + parser.add_argument( + "--vocoder", + type=str, + default="griffinlim", + choices=['griffinlim', 'wavenet', 'waveflow'], + help="vocoder method") + parser.add_argument( + "--config_vocoder", type=str, help="path of the vocoder config file") parser.add_argument( "--checkpoint_vocoder", type=str, @@ -82,31 +94,32 @@ def synthesis(text_input, args): model=model, checkpoint_path=args.checkpoint_transformer) model.eval() - with fluid.unique_name.guard(): - model_vocoder = Vocoder( - cfg['train']['batch_size'], cfg['vocoder']['hidden_size'], - cfg['audio']['num_mels'], cfg['audio']['n_fft']) - # Load parameters. 
- global_step = io.load_parameters( - model=model_vocoder, checkpoint_path=args.checkpoint_vocoder) - model_vocoder.eval() # init input text = np.asarray(text_to_sequence(text_input)) - text = fluid.layers.unsqueeze(dg.to_variable(text), [0]) + text = fluid.layers.unsqueeze(dg.to_variable(text).astype(np.int64), [0]) mel_input = dg.to_variable(np.zeros([1, 1, 80])).astype(np.float32) pos_text = np.arange(1, text.shape[1] + 1) - pos_text = fluid.layers.unsqueeze(dg.to_variable(pos_text), [0]) + pos_text = fluid.layers.unsqueeze( + dg.to_variable(pos_text).astype(np.int64), [0]) pbar = tqdm(range(args.max_len)) for i in pbar: pos_mel = np.arange(1, mel_input.shape[1] + 1) - pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0]) + pos_mel = fluid.layers.unsqueeze( + dg.to_variable(pos_mel).astype(np.int64), [0]) mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model( text, mel_input, pos_text, pos_mel) mel_input = fluid.layers.concat( [mel_input, postnet_pred[:, -1:, :]], axis=1) - - mag_pred = model_vocoder(postnet_pred) + global_step = 0 + for i, prob in enumerate(attn_probs): + for j in range(4): + x = np.uint8(cm.viridis(prob.numpy()[j]) * 255) + writer.add_image( + 'Attention_%d_0' % global_step, + x, + i * 4 + j, + dataformats="HWC") _ljspeech_processor = audio.AudioProcessor( sample_rate=cfg['audio']['sr'], @@ -122,45 +135,130 @@ def synthesis(text_input, args): symmetric_norm=False, max_norm=1., mel_fmin=0, - mel_fmax=None, + mel_fmax=8000, clip_norm=True, griffin_lim_iters=60, do_trim_silence=False, sound_norm=False) + if args.vocoder == 'griffinlim': + #synthesis use griffin-lim + wav = synthesis_with_griffinlim(postnet_pred, _ljspeech_processor) + elif args.vocoder == 'wavenet': + # synthesis use wavenet + wav = synthesis_with_wavenet(postnet_pred, args) + elif args.vocoder == 'waveflow': + # synthesis use waveflow + wav = synthesis_with_waveflow(postnet_pred, args, + args.checkpoint_vocoder, + _ljspeech_processor, place) + else: + print( + 'vocoder error, we only support griffinlim, cbhg and waveflow, but recevied %s.' 
+ % args.vocoder) + + writer.add_audio(text_input + '(' + args.vocoder + ')', wav, 0, + cfg['audio']['sr']) + if not os.path.exists(os.path.join(args.output, 'samples')): + os.mkdir(os.path.join(args.output, 'samples')) + write( + os.path.join( + os.path.join(args.output, 'samples'), args.vocoder + '.wav'), + cfg['audio']['sr'], wav) + print("Synthesis completed !!!") + writer.close() + + +def synthesis_with_griffinlim(mel_output, _ljspeech_processor): + # synthesis with griffin-lim + mel_output = fluid.layers.transpose( + fluid.layers.squeeze(mel_output, [0]), [1, 0]) + mel_output = np.exp(mel_output.numpy()) + basis = librosa.filters.mel(22050, 1024, 80, fmin=0, fmax=8000) + inv_basis = np.linalg.pinv(basis) + spec = np.maximum(1e-10, np.dot(inv_basis, mel_output)) + + wav = librosa.core.griffinlim(spec**1.2, hop_length=256, win_length=1024) + + return wav + + +def synthesis_with_wavenet(mel_output, args): + with open(args.config_vocoder, 'rt') as f: + config = yaml.safe_load(f) + n_mels = config["data"]["n_mels"] + model_config = config["model"] + filter_size = model_config["filter_size"] + upsampling_factors = model_config["upsampling_factors"] + encoder = UpsampleNet(upsampling_factors) + + n_loop = model_config["n_loop"] + n_layer = model_config["n_layer"] + residual_channels = model_config["residual_channels"] + output_dim = model_config["output_dim"] + loss_type = model_config["loss_type"] + log_scale_min = model_config["log_scale_min"] + decoder = WaveNet(n_loop, n_layer, residual_channels, output_dim, n_mels, + filter_size, loss_type, log_scale_min) + + model = ConditionalWavenet(encoder, decoder) + + # load model parameters + iteration = io.load_parameters( + model, checkpoint_path=args.checkpoint_vocoder) + + for layer in model.sublayers(): + if isinstance(layer, WeightNormWrapper): + layer.remove_weight_norm() + mel_output = fluid.layers.transpose(mel_output, [0, 2, 1]) + wav = model.synthesis(mel_output) + return wav.numpy()[0] + + +def synthesis_with_cbhg(mel_output, _ljspeech_processor, cfg): + with fluid.unique_name.guard(): + model_vocoder = Vocoder( + cfg['train']['batch_size'], cfg['vocoder']['hidden_size'], + cfg['audio']['num_mels'], cfg['audio']['n_fft']) + # Load parameters. 
+ global_step = io.load_parameters( + model=model_vocoder, checkpoint_path=args.checkpoint_vocoder) + model_vocoder.eval() + mag_pred = model_vocoder(mel_output) # synthesis with cbhg wav = _ljspeech_processor.inv_spectrogram( fluid.layers.transpose(fluid.layers.squeeze(mag_pred, [0]), [1, 0]) .numpy()) - global_step = 0 - for i, prob in enumerate(attn_probs): - for j in range(4): - x = np.uint8(cm.viridis(prob.numpy()[j]) * 255) - writer.add_image( - 'Attention_%d_0' % global_step, - x, - i * 4 + j, - dataformats="HWC") + return wav - writer.add_audio(text_input + '(cbhg)', wav, 0, cfg['audio']['sr']) - if not os.path.exists(os.path.join(args.output, 'samples')): - os.mkdir(os.path.join(args.output, 'samples')) - write( - os.path.join(os.path.join(args.output, 'samples'), 'cbhg.wav'), - cfg['audio']['sr'], wav) +def synthesis_with_waveflow(mel_output, args, checkpoint, _ljspeech_processor, + place): + mel_output = fluid.layers.transpose( + fluid.layers.squeeze(mel_output, [0]), [1, 0]) + mel_output = mel_output.numpy() + #mel_output = (mel_output - mel_output.min())/(mel_output.max() - mel_output.min()) + #mel_output = 5 * mel_output - 4 + #mel_output = np.log(10) * mel_output - # synthesis with griffin-lim - wav = _ljspeech_processor.inv_melspectrogram( - fluid.layers.transpose( - fluid.layers.squeeze(postnet_pred, [0]), [1, 0]).numpy()) - writer.add_audio(text_input + '(griffin)', wav, 0, cfg['audio']['sr']) + fluid.enable_dygraph(place) + args.config = args.config_vocoder + args.use_fp16 = False + config = io.add_yaml_config_to_args(args) - write( - os.path.join(os.path.join(args.output, 'samples'), 'griffin.wav'), - cfg['audio']['sr'], wav) - print("Synthesis completed !!!") - writer.close() + mel_spectrogram = dg.to_variable(mel_output) + mel_spectrogram = fluid.layers.unsqueeze(mel_spectrogram, [0]) + + # Build model. + waveflow = WaveFlowModule(config) + io.load_parameters(model=waveflow, checkpoint_path=checkpoint) + for layer in waveflow.sublayers(): + if isinstance(layer, weight_norm.WeightNormWrapper): + layer.remove_weight_norm() + + # Run model inference. + wav = waveflow.synthesize(mel_spectrogram, sigma=config.sigma) + return wav.numpy()[0] if __name__ == '__main__': @@ -169,5 +267,6 @@ if __name__ == '__main__': args = parser.parse_args() # Print the whole config setting. pprint(vars(args)) - synthesis("Parakeet stands for Paddle PARAllel text-to-speech toolkit.", - args) + synthesis( + "Life was like a box of chocolates, you never know what you're gonna get.", + args) diff --git a/examples/transformer_tts/synthesis.sh b/examples/transformer_tts/synthesis.sh index 39312f8..a282d70 100644 --- a/examples/transformer_tts/synthesis.sh +++ b/examples/transformer_tts/synthesis.sh @@ -2,12 +2,20 @@ # train model CUDA_VISIBLE_DEVICES=0 \ python -u synthesis.py \ ---max_len=300 \ ---use_gpu=1 \ +--max_len=400 \ +--use_gpu=0 \ --output='./synthesis' \ --config='configs/ljspeech.yaml' \ --checkpoint_transformer='./checkpoint/transformer/step-120000' \ ---checkpoint_vocoder='./checkpoint/vocoder/step-100000' \ +--vocoder='wavenet' \ +--config_vocoder='../wavenet/config.yaml' \ +--checkpoint_vocoder='../wavenet/step-2450000' \ +#--vocoder='waveflow' \ +#--config_vocoder='../waveflow/checkpoint/waveflow_res64_ljspeech_ckpt_1.0/waveflow_ljspeech.yaml' \ +#--checkpoint_vocoder='../waveflow/checkpoint/waveflow_res64_ljspeech_ckpt_1.0/step-3020000' \ +#--vocoder='cbhg' \ +#--config_vocoder='configs/ljspeech.yaml' \ +#--checkpoint_vocoder='checkpoint/cbhg/step-100000' \ if [ $? 
-ne 0 ]; then echo "Failed in training!" diff --git a/examples/transformer_tts/train_vocoder.py b/examples/transformer_tts/train_vocoder.py index 30f31aa..37e9398 100644 --- a/examples/transformer_tts/train_vocoder.py +++ b/examples/transformer_tts/train_vocoder.py @@ -98,7 +98,7 @@ def main(args): local_rank, is_vocoder=True).reader() - for epoch in range(cfg['train']['max_epochs']): + for epoch in range(cfg['train']['max_iteration']): pbar = tqdm(reader) for i, data in enumerate(pbar): pbar.set_description('Processing at epoch %d' % epoch) From aaae1008547af6cab62af8a7721d79d69e5956f5 Mon Sep 17 00:00:00 2001 From: lifuchen Date: Tue, 23 Jun 2020 12:52:58 +0000 Subject: [PATCH 2/4] modified data preprocessing and synthesis of transformer_tts and fastspeech --- examples/fastspeech/README.md | 15 +- .../fastspeech/alignments/get_alignments.py | 47 +++---- examples/fastspeech/configs/ljspeech.yaml | 5 +- examples/fastspeech/data.py | 33 ++--- examples/fastspeech/synthesis.py | 113 +++------------- examples/fastspeech/synthesis.sh | 9 +- examples/fastspeech/train.sh | 2 +- examples/transformer_tts/README.md | 55 ++------ .../transformer_tts/configs/ljspeech.yaml | 12 +- examples/transformer_tts/data.py | 65 +++++---- examples/transformer_tts/synthesis.py | 128 ++++-------------- examples/transformer_tts/synthesis.sh | 13 +- examples/transformer_tts/train_transformer.py | 13 +- .../models/fastspeech/length_regulator.py | 12 +- parakeet/models/transformer_tts/utils.py | 6 +- 15 files changed, 168 insertions(+), 360 deletions(-) diff --git a/examples/fastspeech/README.md b/examples/fastspeech/README.md index a50c39b..865f68a 100644 --- a/examples/fastspeech/README.md +++ b/examples/fastspeech/README.md @@ -87,7 +87,7 @@ python train.py \ --use_gpu=1 \ --data=${DATAPATH} \ --alignments_path=${ALIGNMENTS_PATH} \ ---output='./experiment' \ +--output=${OUTPUTPATH} \ --config='configs/ljspeech.yaml' \ ``` @@ -105,7 +105,7 @@ python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog tr --use_gpu=1 \ --data=${DATAPATH} \ --alignments_path=${ALIGNMENTS_PATH} \ ---output='./experiment' \ +--output=${OUTPUTPATH} \ --config='configs/ljspeech.yaml' \ ``` @@ -123,14 +123,13 @@ After training the FastSpeech, audio can be synthesized by running ``synthesis.p python synthesis.py \ --use_gpu=1 \ --alpha=1.0 \ ---checkpoint='./checkpoint/fastspeech/step-120000' \ +--checkpoint=${CHECKPOINTPATH} \ --config='configs/ljspeech.yaml' \ ---config_clarine='../clarinet/configs/config.yaml' \ ---checkpoint_clarinet='../clarinet/checkpoint/step-500000' \ ---output='./synthesis' \ +--output=${OUTPUTPATH} \ +--vocoder='griffinlim' \ ``` -We use Clarinet to synthesis wav, so it necessary for you to prepare a pre-trained [Clarinet checkpoint](https://paddlespeech.bj.bcebos.com/Parakeet/clarinet_ljspeech_ckpt_1.0.zip). +We currently support two vocoders, ``griffinlim`` and ``waveflow``. You can set ``--vocoder`` to use one of them. If you want to use ``waveflow`` as your vocoder, you need to set ``--config_vocoder`` and ``--checkpoint_vocoder`` which are the path of the config and checkpoint of vocoder. You can download the pretrain model of ``waveflow`` from [here](https://github.com/PaddlePaddle/Parakeet#vocoders). Or you can run the script file directly. @@ -141,3 +140,5 @@ sh synthesis.sh For more help on arguments ``python synthesis.py --help``. + +Then you can find the synthesized audio files in ``${OUTPUTPATH}/samples``. 
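As a usage sketch of the options added above: a FastSpeech synthesis run with the default Griffin-Lim vocoder (which needs no vocoder config or checkpoint) could look like the command below. The flag names come from this patch's `synthesis.py`; the checkpoint and output paths are placeholders to be replaced with local paths.

```bash
python -u synthesis.py \
--use_gpu=1 \
--alpha=1.0 \
--checkpoint='./checkpoint/fastspeech/step-162000' \
--config='configs/ljspeech.yaml' \
--output='./synthesis' \
--vocoder='griffinlim'
```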
diff --git a/examples/fastspeech/alignments/get_alignments.py b/examples/fastspeech/alignments/get_alignments.py index d31bafc..8a46ff2 100644 --- a/examples/fastspeech/alignments/get_alignments.py +++ b/examples/fastspeech/alignments/get_alignments.py @@ -27,7 +27,6 @@ from collections import OrderedDict import paddle.fluid as fluid import paddle.fluid.dygraph as dg from parakeet.models.transformer_tts.utils import * -from parakeet import audio from parakeet.models.transformer_tts import TransformerTTS from parakeet.models.fastspeech.utils import get_alignment from parakeet.utils import io @@ -78,25 +77,6 @@ def alignments(args): header=None, quoting=csv.QUOTE_NONE, names=["fname", "raw_text", "normalized_text"]) - ljspeech_processor = audio.AudioProcessor( - sample_rate=cfg['audio']['sr'], - num_mels=cfg['audio']['num_mels'], - min_level_db=cfg['audio']['min_level_db'], - ref_level_db=cfg['audio']['ref_level_db'], - n_fft=cfg['audio']['n_fft'], - win_length=cfg['audio']['win_length'], - hop_length=cfg['audio']['hop_length'], - power=cfg['audio']['power'], - preemphasis=cfg['audio']['preemphasis'], - signal_norm=True, - symmetric_norm=False, - max_norm=1., - mel_fmin=0, - mel_fmax=None, - clip_norm=True, - griffin_lim_iters=60, - do_trim_silence=False, - sound_norm=False) pbar = tqdm(range(len(table))) alignments = OrderedDict() @@ -107,11 +87,26 @@ def alignments(args): text = fluid.layers.unsqueeze(dg.to_variable(text), [0]) pos_text = np.arange(1, text.shape[1] + 1) pos_text = fluid.layers.unsqueeze(dg.to_variable(pos_text), [0]) - wav = ljspeech_processor.load_wav( - os.path.join(args.data, 'wavs', fname + ".wav")) - mel_input = ljspeech_processor.melspectrogram(wav).astype( - np.float32) - mel_input = np.transpose(mel_input, axes=(1, 0)) + + # load + wav, _ = librosa.load( + str(os.path.join(args.data, 'wavs', fname + ".wav"))) + + spec = librosa.stft( + y=wav, + n_fft=cfg['audio']['n_fft'], + win_length=cfg['audio']['win_length'], + hop_length=cfg['audio']['hop_length']) + mag = np.abs(spec) + mel = librosa.filters.mel(sr=cfg['audio']['sr'], + n_fft=cfg['audio']['n_fft'], + n_mels=cfg['audio']['num_mels'], + fmin=cfg['audio']['fmin'], + fmax=cfg['audio']['fmax']) + mel = np.matmul(mel, mag) + mel = np.log(np.maximum(mel, 1e-5)) + + mel_input = np.transpose(mel, axes=(1, 0)) mel_input = fluid.layers.unsqueeze(dg.to_variable(mel_input), [0]) mel_lens = mel_input.shape[1] @@ -125,7 +120,7 @@ def alignments(args): alignment, _ = get_alignment(attn_probs, mel_lens, network_cfg['decoder_num_head']) alignments[fname] = alignment - with open(args.output + '.txt', "wb") as f: + with open(args.output + '.pkl', "wb") as f: pickle.dump(alignments, f) diff --git a/examples/fastspeech/configs/ljspeech.yaml b/examples/fastspeech/configs/ljspeech.yaml index 96b0d54..32bdd42 100644 --- a/examples/fastspeech/configs/ljspeech.yaml +++ b/examples/fastspeech/configs/ljspeech.yaml @@ -1,10 +1,13 @@ audio: num_mels: 80 #the number of mel bands when calculating mel spectrograms. - n_fft: 2048 #the number of fft components. + n_fft: 1024 #the number of fft components. sr: 22050 #the sampling rate of audio data file. hop_length: 256 #the number of samples to advance between frames. win_length: 1024 #the length (width) of the window function. + preemphasis: 0.97 power: 1.2 #the power to raise before griffin-lim. + fmin: 0 + fmax: 8000 network: encoder_n_layer: 6 #the number of FFT Block in encoder. 
diff --git a/examples/fastspeech/data.py b/examples/fastspeech/data.py index da1ffec..b920035 100644 --- a/examples/fastspeech/data.py +++ b/examples/fastspeech/data.py @@ -42,12 +42,7 @@ class LJSpeechLoader: LJSPEECH_ROOT = Path(data_path) metadata = LJSpeechMetaData(LJSPEECH_ROOT, alignments_path) - transformer = LJSpeech( - sr=config['sr'], - n_fft=config['n_fft'], - num_mels=config['num_mels'], - win_length=config['win_length'], - hop_length=config['hop_length']) + transformer = LJSpeech(config) dataset = TransformDataset(metadata, transformer) dataset = CacheDataset(dataset) @@ -96,18 +91,16 @@ class LJSpeechMetaData(DatasetMixin): class LJSpeech(object): - def __init__(self, - sr=22050, - n_fft=2048, - num_mels=80, - win_length=1024, - hop_length=256): + def __init__(self, cfg): super(LJSpeech, self).__init__() - self.sr = sr - self.n_fft = n_fft - self.num_mels = num_mels - self.win_length = win_length - self.hop_length = hop_length + self.sr = cfg['sr'] + self.n_fft = cfg['n_fft'] + self.num_mels = cfg['num_mels'] + self.win_length = cfg['win_length'] + self.hop_length = cfg['hop_length'] + self.preemphasis = cfg['preemphasis'] + self.fmin = cfg['fmin'] + self.fmax = cfg['fmax'] def __call__(self, metadatum): """All the code for generating an Example from a metadatum. If you want a @@ -125,7 +118,11 @@ class LJSpeech(object): win_length=self.win_length, hop_length=self.hop_length) mag = np.abs(spec) - mel = librosa.filters.mel(self.sr, self.n_fft, n_mels=self.num_mels) + mel = librosa.filters.mel(self.sr, + self.n_fft, + n_mels=self.num_mels, + fmin=self.fmin, + fmax=self.fmax) mel = np.matmul(mel, mag) mel = np.log(np.maximum(mel, 1e-5)) phonemes = np.array( diff --git a/examples/fastspeech/synthesis.py b/examples/fastspeech/synthesis.py index c12742a..96eceb5 100644 --- a/examples/fastspeech/synthesis.py +++ b/examples/fastspeech/synthesis.py @@ -40,7 +40,7 @@ def add_config_options_to_parser(parser): "--vocoder", type=str, default="griffinlim", - choices=['griffinlim', 'clarinet', 'waveflow'], + choices=['griffinlim', 'waveflow'], help="vocoder method") parser.add_argument( "--config_vocoder", type=str, help="path of the vocoder config file") @@ -98,24 +98,13 @@ def synthesis(text_input, args): if args.vocoder == 'griffinlim': #synthesis use griffin-lim - wav = synthesis_with_griffinlim( - mel_output_postnet, - sr=cfg['audio']['sr'], - n_fft=cfg['audio']['n_fft'], - num_mels=cfg['audio']['num_mels'], - power=cfg['audio']['power'], - hop_length=cfg['audio']['hop_length'], - win_length=cfg['audio']['win_length']) - elif args.vocoder == 'clarinet': - # synthesis use clarinet - wav = synthesis_with_clarinet(mel_output_postnet, args.config_vocoder, - args.checkpoint_vocoder, place) + wav = synthesis_with_griffinlim(mel_output_postnet, cfg['audio']) elif args.vocoder == 'waveflow': wav = synthesis_with_waveflow(mel_output_postnet, args, args.checkpoint_vocoder, place) else: print( - 'vocoder error, we only support griffinlim, clarinet and waveflow, but recevied %s.' + 'vocoder error, we only support griffinlim and waveflow, but recevied %s.' 
% args.vocoder) writer.add_audio(text_input + '(' + args.vocoder + ')', wav, 0, @@ -130,105 +119,34 @@ def synthesis(text_input, args): writer.close() -def synthesis_with_griffinlim(mel_output, sr, n_fft, num_mels, power, - hop_length, win_length): +def synthesis_with_griffinlim(mel_output, cfg): mel_output = fluid.layers.transpose( fluid.layers.squeeze(mel_output, [0]), [1, 0]) mel_output = np.exp(mel_output.numpy()) - basis = librosa.filters.mel(sr, n_fft, num_mels) + basis = librosa.filters.mel(cfg['sr'], + cfg['n_fft'], + cfg['num_mels'], + fmin=cfg['fmin'], + fmax=cfg['fmax']) inv_basis = np.linalg.pinv(basis) spec = np.maximum(1e-10, np.dot(inv_basis, mel_output)) wav = librosa.core.griffinlim( - spec**power, hop_length=hop_length, win_length=win_length) + spec**cfg['power'], + hop_length=cfg['hop_length'], + win_length=cfg['win_length']) return wav -def synthesis_with_clarinet(mel_output, config_path, checkpoint, place): - mel_spectrogram = np.exp(mel_output.numpy()) - with open(config_path, 'rt') as f: - config = yaml.safe_load(f) - - data_config = config["data"] - n_mels = data_config["n_mels"] - - teacher_config = config["teacher"] - n_loop = teacher_config["n_loop"] - n_layer = teacher_config["n_layer"] - filter_size = teacher_config["filter_size"] - - # only batch=1 for validation is enabled - - fluid.enable_dygraph(place) - # conditioner(upsampling net) - conditioner_config = config["conditioner"] - upsampling_factors = conditioner_config["upsampling_factors"] - upsample_net = UpsampleNet(upscale_factors=upsampling_factors) - freeze(upsample_net) - - residual_channels = teacher_config["residual_channels"] - loss_type = teacher_config["loss_type"] - output_dim = teacher_config["output_dim"] - log_scale_min = teacher_config["log_scale_min"] - assert loss_type == "mog" and output_dim == 3, \ - "the teacher wavenet should be a wavenet with single gaussian output" - - teacher = WaveNet(n_loop, n_layer, residual_channels, output_dim, n_mels, - filter_size, loss_type, log_scale_min) - # load & freeze upsample_net & teacher - freeze(teacher) - - student_config = config["student"] - n_loops = student_config["n_loops"] - n_layers = student_config["n_layers"] - student_residual_channels = student_config["residual_channels"] - student_filter_size = student_config["filter_size"] - student_log_scale_min = student_config["log_scale_min"] - student = ParallelWaveNet(n_loops, n_layers, student_residual_channels, - n_mels, student_filter_size) - - stft_config = config["stft"] - stft = STFT( - n_fft=stft_config["n_fft"], - hop_length=stft_config["hop_length"], - win_length=stft_config["win_length"]) - - lmd = config["loss"]["lmd"] - model = Clarinet(upsample_net, teacher, student, stft, - student_log_scale_min, lmd) - io.load_parameters(model=model, checkpoint_path=checkpoint) - - if not os.path.exists(args.output): - os.makedirs(args.output) - model.eval() - - # Rescale mel_spectrogram. 
- min_level, ref_level = 1e-5, 20 # hard code it - mel_spectrogram = 20 * np.log10(np.maximum(min_level, mel_spectrogram)) - mel_spectrogram = mel_spectrogram - ref_level - mel_spectrogram = np.clip((mel_spectrogram + 100) / 100, 0, 1) - - mel_spectrogram = dg.to_variable(mel_spectrogram) - mel_spectrogram = fluid.layers.transpose(mel_spectrogram, [0, 2, 1]) - - wav_var = model.synthesis(mel_spectrogram) - wav_np = wav_var.numpy()[0] - - return wav_np - - def synthesis_with_waveflow(mel_output, args, checkpoint, place): - #mel_output = np.exp(mel_output.numpy()) - mel_output = mel_output.numpy() fluid.enable_dygraph(place) args.config = args.config_vocoder args.use_fp16 = False config = io.add_yaml_config_to_args(args) - mel_spectrogram = dg.to_variable(mel_output) - mel_spectrogram = fluid.layers.transpose(mel_spectrogram, [0, 2, 1]) + mel_spectrogram = fluid.layers.transpose(mel_output, [0, 2, 1]) # Build model. waveflow = WaveFlowModule(config) @@ -247,5 +165,6 @@ if __name__ == '__main__': add_config_options_to_parser(parser) args = parser.parse_args() pprint(vars(args)) - synthesis("Simple as this proposition is, it is necessary to be stated,", - args) + synthesis( + "Don't argue with the people of strong determination, because they may change the fact!", + args) diff --git a/examples/fastspeech/synthesis.sh b/examples/fastspeech/synthesis.sh index 79a62d0..a94376f 100644 --- a/examples/fastspeech/synthesis.sh +++ b/examples/fastspeech/synthesis.sh @@ -4,15 +4,12 @@ CUDA_VISIBLE_DEVICES=0 \ python -u synthesis.py \ --use_gpu=1 \ --alpha=1.0 \ ---checkpoint='./checkpoint/fastspeech1024/step-160000' \ +--checkpoint='./checkpoint/fastspeech/step-162000' \ --config='configs/ljspeech.yaml' \ --output='./synthesis' \ --vocoder='waveflow' \ ---config_vocoder='../waveflow/checkpoint/waveflow_res64_ljspeech_ckpt_1.0/waveflow_ljspeech.yaml' \ ---checkpoint_vocoder='../waveflow/checkpoint/waveflow_res64_ljspeech_ckpt_1.0/step-3020000' \ -#--vocoder='clarinet' \ -#--config_vocoder='../clarinet/configs/clarinet_ljspeech.yaml' \ -#--checkpoint_vocoder='../clarinet/checkpoint/step-500000' \ +--config_vocoder='../waveflow/checkpoint/waveflow_res128_ljspeech_ckpt_1.0/waveflow_ljspeech.yaml' \ +--checkpoint_vocoder='../waveflow/checkpoint/waveflow_res128_ljspeech_ckpt_1.0/step-2000000' \ diff --git a/examples/fastspeech/train.sh b/examples/fastspeech/train.sh index 96cbb2e..97d5516 100644 --- a/examples/fastspeech/train.sh +++ b/examples/fastspeech/train.sh @@ -3,7 +3,7 @@ export CUDA_VISIBLE_DEVICES=0 python -u train.py \ --use_gpu=1 \ --data='../../dataset/LJSpeech-1.1' \ ---alignments_path='./alignments/alignments.txt' \ +--alignments_path='./alignments/alignments.pkl' \ --output='./experiment' \ --config='configs/ljspeech.yaml' \ #--checkpoint='./checkpoint/fastspeech/step-120000' \ diff --git a/examples/transformer_tts/README.md b/examples/transformer_tts/README.md index 0be870c..b449c6a 100644 --- a/examples/transformer_tts/README.md +++ b/examples/transformer_tts/README.md @@ -56,7 +56,7 @@ TransformerTTS model can be trained by running ``train_transformer.py``. 
python train_transformer.py \ --use_gpu=1 \ --data=${DATAPATH} \ ---output='./experiment' \ +--output=${OUTPUTPATH} \ --config='configs/ljspeech.yaml' \ ``` @@ -73,7 +73,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog train_transformer.py \ --use_gpu=1 \ --data=${DATAPATH} \ ---output='./experiment' \ +--output=${OUTPUTPATH} \ --config='configs/ljspeech.yaml' \ ``` @@ -85,61 +85,28 @@ For more help on arguments ``python train_transformer.py --help``. -## Train Vocoder - -Vocoder model can be trained by running ``train_vocoder.py``. - -```bash -python train_vocoder.py \ ---use_gpu=1 \ ---data=${DATAPATH} \ ---output='./vocoder' \ ---config='configs/ljspeech.yaml' \ -``` - -Or you can run the script file directly. - -```bash -sh train_vocoder.sh -``` - -If you want to train on multiple GPUs, you must start training in the following way. - -```bash -CUDA_VISIBLE_DEVICES=0,1,2,3 -python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog train_vocoder.py \ ---use_gpu=1 \ ---data=${DATAPATH} \ ---output='./vocoder' \ ---config='configs/ljspeech.yaml' \ -``` - -If you wish to resume from an existing model, See [Saving-&-Loading](#Saving-&-Loading) for details of checkpoint loading. - -For more help on arguments - -``python train_vocoder.py --help``. - ## Synthesis -After training the TransformerTTS and vocoder model, audio can be synthesized by running ``synthesis.py``. +After training the TransformerTTS, audio can be synthesized by running ``synthesis.py``. ```bash python synthesis.py \ ---max_len=300 \ ---use_gpu=1 \ ---output='./synthesis' \ +--use_gpu=0 \ +--output=${OUTPUTPATH} \ --config='configs/ljspeech.yaml' \ ---checkpoint_transformer='./checkpoint/transformer/step-120000' \ ---checkpoint_vocoder='./checkpoint/vocoder/step-100000' \ +--checkpoint_transformer=${CHECKPOINTPATH} \ +--vocoder='griffinlim' \ ``` +We currently support two vocoders, ``griffinlim`` and ``waveflow``. You can set ``--vocoder`` to use one of them. If you want to use ``waveflow`` as your vocoder, you need to set ``--config_vocoder`` and ``--checkpoint_vocoder`` which are the path of the config and checkpoint of vocoder. You can download the pretrain model of ``waveflow`` from [here](https://github.com/PaddlePaddle/Parakeet#vocoders). + Or you can run the script file directly. ```bash sh synthesis.sh ``` - For more help on arguments ``python synthesis.py --help``. + +Then you can find the synthesized audio files in ``${OUTPUTPATH}/samples``. 
diff --git a/examples/transformer_tts/configs/ljspeech.yaml b/examples/transformer_tts/configs/ljspeech.yaml index f5aabf9..963a230 100644 --- a/examples/transformer_tts/configs/ljspeech.yaml +++ b/examples/transformer_tts/configs/ljspeech.yaml @@ -1,13 +1,13 @@ audio: num_mels: 80 - n_fft: 2048 + n_fft: 1024 sr: 22050 preemphasis: 0.97 - hop_length: 256 #275 - win_length: 1024 #1102 + hop_length: 256 + win_length: 1024 power: 1.2 - min_level_db: -100 - ref_level_db: 20 + fmin: 0 + fmax: 8000 network: hidden_size: 256 @@ -17,7 +17,7 @@ network: decoder_num_head: 4 decoder_n_layers: 3 outputs_per_step: 1 - stop_token: False + stop_loss_weight: 8 vocoder: hidden_size: 256 diff --git a/examples/transformer_tts/data.py b/examples/transformer_tts/data.py index 42be552..acaad60 100644 --- a/examples/transformer_tts/data.py +++ b/examples/transformer_tts/data.py @@ -19,7 +19,6 @@ import csv from paddle import fluid from parakeet import g2p -from parakeet import audio from parakeet.data.sampler import * from parakeet.data.datacargo import DataCargo from parakeet.data.batch import TextIDBatcher, SpecBatcher @@ -98,25 +97,14 @@ class LJSpeech(object): def __init__(self, config): super(LJSpeech, self).__init__() self.config = config - self._ljspeech_processor = audio.AudioProcessor( - sample_rate=config['sr'], - num_mels=config['num_mels'], - min_level_db=config['min_level_db'], - ref_level_db=config['ref_level_db'], - n_fft=config['n_fft'], - win_length=config['win_length'], - hop_length=config['hop_length'], - power=config['power'], - preemphasis=config['preemphasis'], - signal_norm=True, - symmetric_norm=False, - max_norm=1., - mel_fmin=0, - mel_fmax=None, - clip_norm=True, - griffin_lim_iters=60, - do_trim_silence=False, - sound_norm=False) + self.sr = config['sr'] + self.n_mels = config['num_mels'] + self.preemphasis = config['preemphasis'] + self.n_fft = config['n_fft'] + self.win_length = config['win_length'] + self.hop_length = config['hop_length'] + self.fmin = config['fmin'] + self.fmax = config['fmax'] def __call__(self, metadatum): """All the code for generating an Example from a metadatum. 
If you want a @@ -127,14 +115,26 @@ class LJSpeech(object): """ fname, raw_text, normalized_text = metadatum - # load -> trim -> preemphasis -> stft -> magnitude -> mel_scale -> logscale -> normalize - wav = self._ljspeech_processor.load_wav(str(fname)) - mag = self._ljspeech_processor.spectrogram(wav).astype(np.float32) - mel = self._ljspeech_processor.melspectrogram(wav).astype(np.float32) - phonemes = np.array( + # load + wav, _ = librosa.load(str(fname)) + + spec = librosa.stft( + y=wav, + n_fft=self.n_fft, + win_length=self.win_length, + hop_length=self.hop_length) + mag = np.abs(spec) + mel = librosa.filters.mel(sr=self.sr, + n_fft=self.n_fft, + n_mels=self.n_mels, + fmin=self.fmin, + fmax=self.fmax) + mel = np.matmul(mel, mag) + mel = np.log(np.maximum(mel, 1e-5)) + + characters = np.array( g2p.en.text_to_sequence(normalized_text), dtype=np.int64) - return (mag, mel, phonemes - ) # maybe we need to implement it as a map in the future + return (mag, mel, characters) def batch_examples(batch): @@ -144,6 +144,7 @@ def batch_examples(batch): text_lens = [] pos_texts = [] pos_mels = [] + stop_tokens = [] for data in batch: _, mel, text = data mel_inputs.append( @@ -155,6 +156,8 @@ def batch_examples(batch): pos_mels.append(np.arange(1, mel.shape[1] + 1)) mels.append(mel) texts.append(text) + stop_token = np.append(np.zeros([mel.shape[1] - 1], np.float32), 1.0) + stop_tokens.append(stop_token) # Sort by text_len in descending order texts = [ @@ -182,18 +185,24 @@ def batch_examples(batch): for i, _ in sorted( zip(pos_mels, text_lens), key=lambda x: x[1], reverse=True) ] + stop_tokens = [ + i + for i, _ in sorted( + zip(stop_tokens, text_lens), key=lambda x: x[1], reverse=True) + ] text_lens = sorted(text_lens, reverse=True) # Pad sequence with largest len of the batch texts = TextIDBatcher(pad_id=0)(texts) #(B, T) pos_texts = TextIDBatcher(pad_id=0)(pos_texts) #(B,T) pos_mels = TextIDBatcher(pad_id=0)(pos_mels) #(B,T) + stop_tokens = TextIDBatcher(pad_id=1, dtype=np.float32)(pos_mels) mels = np.transpose( SpecBatcher(pad_value=0.)(mels), axes=(0, 2, 1)) #(B,T,num_mels) mel_inputs = np.transpose( SpecBatcher(pad_value=0.)(mel_inputs), axes=(0, 2, 1)) #(B,T,num_mels) - return (texts, mels, mel_inputs, pos_texts, pos_mels) + return (texts, mels, mel_inputs, pos_texts, pos_mels, stop_tokens) def batch_examples_vocoder(batch): diff --git a/examples/transformer_tts/synthesis.py b/examples/transformer_tts/synthesis.py index d4e17bf..9464638 100644 --- a/examples/transformer_tts/synthesis.py +++ b/examples/transformer_tts/synthesis.py @@ -25,23 +25,25 @@ import paddle.fluid as fluid import paddle.fluid.dygraph as dg from parakeet.g2p.en import text_to_sequence from parakeet.models.transformer_tts.utils import * -from parakeet import audio -from parakeet.models.transformer_tts import Vocoder from parakeet.models.transformer_tts import TransformerTTS -from parakeet.modules import weight_norm from parakeet.models.waveflow import WaveFlowModule from parakeet.modules.weight_norm import WeightNormWrapper -from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet from parakeet.utils import io def add_config_options_to_parser(parser): parser.add_argument("--config", type=str, help="path of the config file") parser.add_argument("--use_gpu", type=int, default=0, help="device to use") + parser.add_argument( + "--stop_threshold", + type=float, + default=0.5, + help="The threshold of stop token which indicates the time step should stop generate spectrum or not." 
+ ) parser.add_argument( "--max_len", type=int, - default=200, + default=1000, help="The max length of audio when synthsis.") parser.add_argument( @@ -52,7 +54,7 @@ def add_config_options_to_parser(parser): "--vocoder", type=str, default="griffinlim", - choices=['griffinlim', 'wavenet', 'waveflow'], + choices=['griffinlim', 'waveflow'], help="vocoder method") parser.add_argument( "--config_vocoder", type=str, help="path of the vocoder config file") @@ -102,13 +104,14 @@ def synthesis(text_input, args): pos_text = fluid.layers.unsqueeze( dg.to_variable(pos_text).astype(np.int64), [0]) - pbar = tqdm(range(args.max_len)) - for i in pbar: + for i in range(args.max_len): pos_mel = np.arange(1, mel_input.shape[1] + 1) pos_mel = fluid.layers.unsqueeze( dg.to_variable(pos_mel).astype(np.int64), [0]) mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model( text, mel_input, pos_text, pos_mel) + if stop_preds.numpy()[0, -1] > args.stop_threshold: + break mel_input = fluid.layers.concat( [mel_input, postnet_pred[:, -1:, :]], axis=1) global_step = 0 @@ -121,40 +124,16 @@ def synthesis(text_input, args): i * 4 + j, dataformats="HWC") - _ljspeech_processor = audio.AudioProcessor( - sample_rate=cfg['audio']['sr'], - num_mels=cfg['audio']['num_mels'], - min_level_db=cfg['audio']['min_level_db'], - ref_level_db=cfg['audio']['ref_level_db'], - n_fft=cfg['audio']['n_fft'], - win_length=cfg['audio']['win_length'], - hop_length=cfg['audio']['hop_length'], - power=cfg['audio']['power'], - preemphasis=cfg['audio']['preemphasis'], - signal_norm=True, - symmetric_norm=False, - max_norm=1., - mel_fmin=0, - mel_fmax=8000, - clip_norm=True, - griffin_lim_iters=60, - do_trim_silence=False, - sound_norm=False) - if args.vocoder == 'griffinlim': #synthesis use griffin-lim - wav = synthesis_with_griffinlim(postnet_pred, _ljspeech_processor) - elif args.vocoder == 'wavenet': - # synthesis use wavenet - wav = synthesis_with_wavenet(postnet_pred, args) + wav = synthesis_with_griffinlim(postnet_pred, cfg['audio']) elif args.vocoder == 'waveflow': # synthesis use waveflow wav = synthesis_with_waveflow(postnet_pred, args, - args.checkpoint_vocoder, - _ljspeech_processor, place) + args.checkpoint_vocoder, place) else: print( - 'vocoder error, we only support griffinlim, cbhg and waveflow, but recevied %s.' + 'vocoder error, we only support griffinlim and waveflow, but recevied %s.' 
% args.vocoder) writer.add_audio(text_input + '(' + args.vocoder + ')', wav, 0, @@ -169,91 +148,42 @@ def synthesis(text_input, args): writer.close() -def synthesis_with_griffinlim(mel_output, _ljspeech_processor): +def synthesis_with_griffinlim(mel_output, cfg): # synthesis with griffin-lim mel_output = fluid.layers.transpose( fluid.layers.squeeze(mel_output, [0]), [1, 0]) mel_output = np.exp(mel_output.numpy()) - basis = librosa.filters.mel(22050, 1024, 80, fmin=0, fmax=8000) + basis = librosa.filters.mel(cfg['sr'], + cfg['n_fft'], + cfg['num_mels'], + fmin=cfg['fmin'], + fmax=cfg['fmax']) inv_basis = np.linalg.pinv(basis) spec = np.maximum(1e-10, np.dot(inv_basis, mel_output)) - wav = librosa.core.griffinlim(spec**1.2, hop_length=256, win_length=1024) + wav = librosa.core.griffinlim( + spec**cfg['power'], + hop_length=cfg['hop_length'], + win_length=cfg['win_length']) return wav -def synthesis_with_wavenet(mel_output, args): - with open(args.config_vocoder, 'rt') as f: - config = yaml.safe_load(f) - n_mels = config["data"]["n_mels"] - model_config = config["model"] - filter_size = model_config["filter_size"] - upsampling_factors = model_config["upsampling_factors"] - encoder = UpsampleNet(upsampling_factors) - - n_loop = model_config["n_loop"] - n_layer = model_config["n_layer"] - residual_channels = model_config["residual_channels"] - output_dim = model_config["output_dim"] - loss_type = model_config["loss_type"] - log_scale_min = model_config["log_scale_min"] - decoder = WaveNet(n_loop, n_layer, residual_channels, output_dim, n_mels, - filter_size, loss_type, log_scale_min) - - model = ConditionalWavenet(encoder, decoder) - - # load model parameters - iteration = io.load_parameters( - model, checkpoint_path=args.checkpoint_vocoder) - - for layer in model.sublayers(): - if isinstance(layer, WeightNormWrapper): - layer.remove_weight_norm() - mel_output = fluid.layers.transpose(mel_output, [0, 2, 1]) - wav = model.synthesis(mel_output) - return wav.numpy()[0] - - -def synthesis_with_cbhg(mel_output, _ljspeech_processor, cfg): - with fluid.unique_name.guard(): - model_vocoder = Vocoder( - cfg['train']['batch_size'], cfg['vocoder']['hidden_size'], - cfg['audio']['num_mels'], cfg['audio']['n_fft']) - # Load parameters. - global_step = io.load_parameters( - model=model_vocoder, checkpoint_path=args.checkpoint_vocoder) - model_vocoder.eval() - mag_pred = model_vocoder(mel_output) - # synthesis with cbhg - wav = _ljspeech_processor.inv_spectrogram( - fluid.layers.transpose(fluid.layers.squeeze(mag_pred, [0]), [1, 0]) - .numpy()) - return wav - - -def synthesis_with_waveflow(mel_output, args, checkpoint, _ljspeech_processor, - place): - mel_output = fluid.layers.transpose( - fluid.layers.squeeze(mel_output, [0]), [1, 0]) - mel_output = mel_output.numpy() - #mel_output = (mel_output - mel_output.min())/(mel_output.max() - mel_output.min()) - #mel_output = 5 * mel_output - 4 - #mel_output = np.log(10) * mel_output - +def synthesis_with_waveflow(mel_output, args, checkpoint, place): fluid.enable_dygraph(place) args.config = args.config_vocoder args.use_fp16 = False config = io.add_yaml_config_to_args(args) - mel_spectrogram = dg.to_variable(mel_output) + mel_spectrogram = fluid.layers.transpose( + fluid.layers.squeeze(mel_output, [0]), [1, 0]) mel_spectrogram = fluid.layers.unsqueeze(mel_spectrogram, [0]) # Build model. 
waveflow = WaveFlowModule(config) io.load_parameters(model=waveflow, checkpoint_path=checkpoint) for layer in waveflow.sublayers(): - if isinstance(layer, weight_norm.WeightNormWrapper): + if isinstance(layer, WeightNormWrapper): layer.remove_weight_norm() # Run model inference. @@ -268,5 +198,5 @@ if __name__ == '__main__': # Print the whole config setting. pprint(vars(args)) synthesis( - "Life was like a box of chocolates, you never know what you're gonna get.", + "Life was like a box of chocolates, you never know what you're gonna get.", args) diff --git a/examples/transformer_tts/synthesis.sh b/examples/transformer_tts/synthesis.sh index a282d70..1ceee83 100644 --- a/examples/transformer_tts/synthesis.sh +++ b/examples/transformer_tts/synthesis.sh @@ -2,20 +2,13 @@ # train model CUDA_VISIBLE_DEVICES=0 \ python -u synthesis.py \ ---max_len=400 \ --use_gpu=0 \ --output='./synthesis' \ --config='configs/ljspeech.yaml' \ --checkpoint_transformer='./checkpoint/transformer/step-120000' \ ---vocoder='wavenet' \ ---config_vocoder='../wavenet/config.yaml' \ ---checkpoint_vocoder='../wavenet/step-2450000' \ -#--vocoder='waveflow' \ -#--config_vocoder='../waveflow/checkpoint/waveflow_res64_ljspeech_ckpt_1.0/waveflow_ljspeech.yaml' \ -#--checkpoint_vocoder='../waveflow/checkpoint/waveflow_res64_ljspeech_ckpt_1.0/step-3020000' \ -#--vocoder='cbhg' \ -#--config_vocoder='configs/ljspeech.yaml' \ -#--checkpoint_vocoder='checkpoint/cbhg/step-100000' \ +--vocoder='waveflow' \ +--config_vocoder='../waveflow/checkpoint/waveflow_res128_ljspeech_ckpt_1.0/waveflow_ljspeech.yaml' \ +--checkpoint_vocoder='../waveflow/checkpoint/waveflow_res128_ljspeech_ckpt_1.0/step-2000000' \ if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/transformer_tts/train_transformer.py b/examples/transformer_tts/train_transformer.py index 646176f..299676c 100644 --- a/examples/transformer_tts/train_transformer.py +++ b/examples/transformer_tts/train_transformer.py @@ -115,7 +115,7 @@ def main(args): iterator = iter(tqdm(reader)) batch = next(iterator) - character, mel, mel_input, pos_text, pos_mel = batch + character, mel, mel_input, pos_text, pos_mel, stop_tokens = batch mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model( character, mel_input, pos_text, pos_mel) @@ -126,11 +126,9 @@ def main(args): layers.abs(layers.elementwise_sub(postnet_pred, mel))) loss = mel_loss + post_mel_loss - # Note: When used stop token loss the learning did not work. 
- if cfg['network']['stop_token']: - label = (pos_mel == 0).astype(np.float32) - stop_loss = cross_entropy(stop_preds, label) - loss = loss + stop_loss + stop_loss = cross_entropy( + stop_preds, stop_tokens, weight=cfg['network']['stop_loss_weight']) + loss = loss + stop_loss if local_rank == 0: writer.add_scalars('training_loss', { @@ -138,8 +136,7 @@ def main(args): 'post_mel_loss': post_mel_loss.numpy() }, global_step) - if cfg['network']['stop_token']: - writer.add_scalar('stop_loss', stop_loss.numpy(), global_step) + writer.add_scalar('stop_loss', stop_loss.numpy(), global_step) if parallel: writer.add_scalars('alphas', { diff --git a/parakeet/models/fastspeech/length_regulator.py b/parakeet/models/fastspeech/length_regulator.py index ecf0327..4c539ff 100644 --- a/parakeet/models/fastspeech/length_regulator.py +++ b/parakeet/models/fastspeech/length_regulator.py @@ -37,13 +37,12 @@ class LengthRegulator(dg.Layer): filter_size=filter_size, dropout=dropout) - def LR(self, x, duration_predictor_output, alpha=1.0): + def LR(self, x, duration_predictor_output): output = [] batch_size = x.shape[0] for i in range(batch_size): output.append( - self.expand(x[i:i + 1], duration_predictor_output[i:i + 1], - alpha)) + self.expand(x[i:i + 1], duration_predictor_output[i:i + 1])) output = self.pad(output) return output @@ -58,7 +57,7 @@ class LengthRegulator(dg.Layer): out_padded = layers.stack(out_list) return out_padded - def expand(self, batch, predicted, alpha): + def expand(self, batch, predicted): out = [] time_steps = batch.shape[1] fertilities = predicted.numpy() @@ -92,8 +91,9 @@ class LengthRegulator(dg.Layer): output = self.LR(x, target) return output, duration_predictor_output else: - duration_predictor_output = layers.round(duration_predictor_output) - output = self.LR(x, duration_predictor_output, alpha) + duration_predictor_output = duration_predictor_output * alpha + duration_predictor_output = layers.ceil(duration_predictor_output) + output = self.LR(x, duration_predictor_output) mel_pos = dg.to_variable(np.arange(1, output.shape[1] + 1)).astype( np.int64) mel_pos = layers.unsqueeze(mel_pos, [0]) diff --git a/parakeet/models/transformer_tts/utils.py b/parakeet/models/transformer_tts/utils.py index 3fa4c63..9482c23 100644 --- a/parakeet/models/transformer_tts/utils.py +++ b/parakeet/models/transformer_tts/utils.py @@ -93,9 +93,9 @@ def guided_attention(N, T, g=0.2): return W -def cross_entropy(input, label, position_weight=1.0, epsilon=1e-30): +def cross_entropy(input, label, weight=1.0, epsilon=1e-30): output = -1 * label * layers.log(input + epsilon) - ( 1 - label) * layers.log(1 - input + epsilon) - output = output * (label * (position_weight - 1) + 1) + output = output * (label * (weight - 1) + 1) - return layers.reduce_sum(output, dim=[0, 1]) + return layers.reduce_mean(output, dim=[0, 1]) From 5b804b70e6956948a8003122e12c989b1a9a3875 Mon Sep 17 00:00:00 2001 From: lifuchen Date: Wed, 24 Jun 2020 02:49:07 +0000 Subject: [PATCH 3/4] modified some config name and default path. 
--- examples/fastspeech/README.md | 4 ++-- examples/fastspeech/synthesis.py | 10 +++++----- examples/fastspeech/synthesis.sh | 8 ++++---- examples/transformer_tts/README.md | 4 ++-- examples/transformer_tts/synthesis.py | 10 +++++----- examples/transformer_tts/synthesis.sh | 8 ++++---- 6 files changed, 22 insertions(+), 22 deletions(-) diff --git a/examples/fastspeech/README.md b/examples/fastspeech/README.md index 865f68a..0c40488 100644 --- a/examples/fastspeech/README.md +++ b/examples/fastspeech/README.md @@ -126,10 +126,10 @@ python synthesis.py \ --checkpoint=${CHECKPOINTPATH} \ --config='configs/ljspeech.yaml' \ --output=${OUTPUTPATH} \ ---vocoder='griffinlim' \ +--vocoder='griffin-lim' \ ``` -We currently support two vocoders, ``griffinlim`` and ``waveflow``. You can set ``--vocoder`` to use one of them. If you want to use ``waveflow`` as your vocoder, you need to set ``--config_vocoder`` and ``--checkpoint_vocoder`` which are the path of the config and checkpoint of vocoder. You can download the pretrain model of ``waveflow`` from [here](https://github.com/PaddlePaddle/Parakeet#vocoders). +We currently support two vocoders, ``Griffin-Lim`` algorithm and ``WaveFlow``. You can set ``--vocoder`` to use one of them. If you want to use ``waveflow`` as your vocoder, you need to set ``--config_vocoder`` and ``--checkpoint_vocoder`` which are the path of the config and checkpoint of vocoder. You can download the pre-trained model of ``waveflow`` from [here](https://github.com/PaddlePaddle/Parakeet#vocoders). Or you can run the script file directly. diff --git a/examples/fastspeech/synthesis.py b/examples/fastspeech/synthesis.py index 96eceb5..dde776f 100644 --- a/examples/fastspeech/synthesis.py +++ b/examples/fastspeech/synthesis.py @@ -39,8 +39,8 @@ def add_config_options_to_parser(parser): parser.add_argument( "--vocoder", type=str, - default="griffinlim", - choices=['griffinlim', 'waveflow'], + default="griffin-lim", + choices=['griffin-lim', 'waveflow'], help="vocoder method") parser.add_argument( "--config_vocoder", type=str, help="path of the vocoder config file") @@ -53,11 +53,11 @@ def add_config_options_to_parser(parser): ) parser.add_argument( - "--checkpoint", type=str, help="fastspeech checkpoint to synthesis") + "--checkpoint", type=str, help="fastspeech checkpoint for synthesis") parser.add_argument( "--checkpoint_vocoder", type=str, - help="vocoder checkpoint to synthesis") + help="vocoder checkpoint for synthesis") parser.add_argument( "--output", @@ -96,7 +96,7 @@ def synthesis(text_input, args): _, mel_output_postnet = model(text, pos_text, alpha=args.alpha) - if args.vocoder == 'griffinlim': + if args.vocoder == 'griffin-lim': #synthesis use griffin-lim wav = synthesis_with_griffinlim(mel_output_postnet, cfg['audio']) elif args.vocoder == 'waveflow': diff --git a/examples/fastspeech/synthesis.sh b/examples/fastspeech/synthesis.sh index a94376f..1ebed1b 100644 --- a/examples/fastspeech/synthesis.sh +++ b/examples/fastspeech/synthesis.sh @@ -4,12 +4,12 @@ CUDA_VISIBLE_DEVICES=0 \ python -u synthesis.py \ --use_gpu=1 \ --alpha=1.0 \ ---checkpoint='./checkpoint/fastspeech/step-162000' \ ---config='configs/ljspeech.yaml' \ +--checkpoint='./fastspeech_ljspeech_ckpt_1.0/fastspeech/step-162000' \ +--config='fastspeech_ljspeech_ckpt_1.0/ljspeech.yaml' \ --output='./synthesis' \ --vocoder='waveflow' \ ---config_vocoder='../waveflow/checkpoint/waveflow_res128_ljspeech_ckpt_1.0/waveflow_ljspeech.yaml' \ 
---checkpoint_vocoder='../waveflow/checkpoint/waveflow_res128_ljspeech_ckpt_1.0/step-2000000' \ +--config_vocoder='./waveflow_res128_ljspeech_ckpt_1.0/waveflow_ljspeech.yaml' \ +--checkpoint_vocoder='./waveflow_res128_ljspeech_ckpt_1.0/step-2000000' \ diff --git a/examples/transformer_tts/README.md b/examples/transformer_tts/README.md index b449c6a..e8e0131 100644 --- a/examples/transformer_tts/README.md +++ b/examples/transformer_tts/README.md @@ -95,10 +95,10 @@ python synthesis.py \ --output=${OUTPUTPATH} \ --config='configs/ljspeech.yaml' \ --checkpoint_transformer=${CHECKPOINTPATH} \ ---vocoder='griffinlim' \ +--vocoder='griffin-lim' \ ``` -We currently support two vocoders, ``griffinlim`` and ``waveflow``. You can set ``--vocoder`` to use one of them. If you want to use ``waveflow`` as your vocoder, you need to set ``--config_vocoder`` and ``--checkpoint_vocoder`` which are the path of the config and checkpoint of vocoder. You can download the pretrain model of ``waveflow`` from [here](https://github.com/PaddlePaddle/Parakeet#vocoders). +We currently support two vocoders, ``Griffin-Lim`` algorithm and ``WaveFlow``. You can set ``--vocoder`` to use one of them. If you want to use ``waveflow`` as your vocoder, you need to set ``--config_vocoder`` and ``--checkpoint_vocoder`` which are the path of the config and checkpoint of vocoder. You can download the pre-trained model of ``waveflow`` from [here](https://github.com/PaddlePaddle/Parakeet#vocoders). Or you can run the script file directly. diff --git a/examples/transformer_tts/synthesis.py b/examples/transformer_tts/synthesis.py index 9464638..effbffd 100644 --- a/examples/transformer_tts/synthesis.py +++ b/examples/transformer_tts/synthesis.py @@ -49,19 +49,19 @@ def add_config_options_to_parser(parser): parser.add_argument( "--checkpoint_transformer", type=str, - help="transformer_tts checkpoint to synthesis") + help="transformer_tts checkpoint for synthesis") parser.add_argument( "--vocoder", type=str, - default="griffinlim", - choices=['griffinlim', 'waveflow'], + default="griffin-lim", + choices=['griffin-lim', 'waveflow'], help="vocoder method") parser.add_argument( "--config_vocoder", type=str, help="path of the vocoder config file") parser.add_argument( "--checkpoint_vocoder", type=str, - help="vocoder checkpoint to synthesis") + help="vocoder checkpoint for synthesis") parser.add_argument( "--output", @@ -124,7 +124,7 @@ def synthesis(text_input, args): i * 4 + j, dataformats="HWC") - if args.vocoder == 'griffinlim': + if args.vocoder == 'griffin-lim': #synthesis use griffin-lim wav = synthesis_with_griffinlim(postnet_pred, cfg['audio']) elif args.vocoder == 'waveflow': diff --git a/examples/transformer_tts/synthesis.sh b/examples/transformer_tts/synthesis.sh index 1ceee83..be91cd4 100644 --- a/examples/transformer_tts/synthesis.sh +++ b/examples/transformer_tts/synthesis.sh @@ -4,11 +4,11 @@ CUDA_VISIBLE_DEVICES=0 \ python -u synthesis.py \ --use_gpu=0 \ --output='./synthesis' \ ---config='configs/ljspeech.yaml' \ ---checkpoint_transformer='./checkpoint/transformer/step-120000' \ +--config='transformer_tts_ljspeech_ckpt_1.0/ljspeech.yaml' \ +--checkpoint_transformer='./transformer_tts_ljspeech_ckpt_1.0/step-120000' \ --vocoder='waveflow' \ ---config_vocoder='../waveflow/checkpoint/waveflow_res128_ljspeech_ckpt_1.0/waveflow_ljspeech.yaml' \ ---checkpoint_vocoder='../waveflow/checkpoint/waveflow_res128_ljspeech_ckpt_1.0/step-2000000' \ +--config_vocoder='./waveflow_res128_ljspeech_ckpt_1.0/waveflow_ljspeech.yaml' \ 
+--checkpoint_vocoder='./waveflow_res128_ljspeech_ckpt_1.0/step-2000000' \
 
 if [ $? -ne 0 ]; then
     echo "Failed in training!"

From a333e64f7963bafcbd0a8dd9ba4d9b5c6ffcab3f Mon Sep 17 00:00:00 2001
From: lifuchen
Date: Wed, 24 Jun 2020 03:16:05 +0000
Subject: [PATCH 4/4] modified README of transformer_tts and fastspeech

---
 examples/fastspeech/README.md      | 2 +-
 examples/transformer_tts/README.md | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/fastspeech/README.md b/examples/fastspeech/README.md
index 0c40488..08c3cfd 100644
--- a/examples/fastspeech/README.md
+++ b/examples/fastspeech/README.md
@@ -129,7 +129,7 @@ python synthesis.py \
 --vocoder='griffin-lim' \
 ```
 
-We currently support two vocoders, ``Griffin-Lim`` algorithm and ``WaveFlow``. You can set ``--vocoder`` to use one of them. If you want to use ``waveflow`` as your vocoder, you need to set ``--config_vocoder`` and ``--checkpoint_vocoder`` which are the path of the config and checkpoint of vocoder. You can download the pre-trained model of ``waveflow`` from [here](https://github.com/PaddlePaddle/Parakeet#vocoders).
+We currently support two vocoders, the Griffin-Lim algorithm and WaveFlow. You can set ``--vocoder`` to use either of them. If you want to use WaveFlow as your vocoder, you also need to set ``--config_vocoder`` and ``--checkpoint_vocoder``, which are the paths to the vocoder's config file and checkpoint. You can download the pre-trained WaveFlow model from [here](https://github.com/PaddlePaddle/Parakeet#vocoders).
 
 Or you can run the script file directly.
 
diff --git a/examples/transformer_tts/README.md b/examples/transformer_tts/README.md
index e8e0131..f1e73fe 100644
--- a/examples/transformer_tts/README.md
+++ b/examples/transformer_tts/README.md
@@ -98,7 +98,7 @@
 --vocoder='griffin-lim' \
 ```
 
-We currently support two vocoders, ``Griffin-Lim`` algorithm and ``WaveFlow``. You can set ``--vocoder`` to use one of them. If you want to use ``waveflow`` as your vocoder, you need to set ``--config_vocoder`` and ``--checkpoint_vocoder`` which are the path of the config and checkpoint of vocoder. You can download the pre-trained model of ``waveflow`` from [here](https://github.com/PaddlePaddle/Parakeet#vocoders).
+We currently support two vocoders, the Griffin-Lim algorithm and WaveFlow. You can set ``--vocoder`` to use either of them. If you want to use WaveFlow as your vocoder, you also need to set ``--config_vocoder`` and ``--checkpoint_vocoder``, which are the paths to the vocoder's config file and checkpoint. You can download the pre-trained WaveFlow model from [here](https://github.com/PaddlePaddle/Parakeet#vocoders).
 
 Or you can run the script file directly.
 
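
For reference, the reweighted stop-token loss introduced earlier in this series (the `cross_entropy` helper in `parakeet/models/transformer_tts/utils.py`, now called with `weight=cfg['network']['stop_loss_weight']` and reduced with a mean instead of a sum) boils down to a few lines. The sketch below is a plain-NumPy illustration only: the function name `stop_token_cross_entropy`, the toy inputs, and the example weight of 8.0 are illustrative placeholders, not part of the Parakeet API.

```python
import numpy as np


def stop_token_cross_entropy(probs, labels, weight=1.0, epsilon=1e-30):
    """Weighted binary cross entropy over stop-token probabilities.

    probs and labels have shape (batch, time); labels are 1.0 at stop
    frames and 0.0 elsewhere. `weight` plays the role of
    cfg['network']['stop_loss_weight'] and up-weights the rare positive frames.
    """
    loss = -labels * np.log(probs + epsilon) \
           - (1.0 - labels) * np.log(1.0 - probs + epsilon)
    # Scale only the positive-label terms by `weight`; negatives keep weight 1.
    loss = loss * (labels * (weight - 1.0) + 1.0)
    # Mean over batch and time, matching the switch from reduce_sum to reduce_mean.
    return loss.mean()


# Toy example: one utterance of three frames that stops at the last frame.
probs = np.array([[0.05, 0.10, 0.90]])
labels = np.array([[0.0, 0.0, 1.0]])
print(stop_token_cross_entropy(probs, labels, weight=8.0))
```

Because the reduction is now a mean rather than a sum, the stop loss no longer grows with the mel length, which is presumably why it can be added to the mel losses unconditionally; the positive-frame weight compensates for stop frames being a small fraction of each utterance.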