Merge branch 'commit' into 'master'

Modified data preprocessing and synthesis of transformer_tts and fastspeech. See merge request !62

commit 563d3bae74
@@ -87,7 +87,7 @@ python train.py \
 --use_gpu=1 \
 --data=${DATAPATH} \
 --alignments_path=${ALIGNMENTS_PATH} \
---output='./experiment' \
+--output=${OUTPUTPATH} \
 --config='configs/ljspeech.yaml' \
 ```
 
@@ -105,7 +105,7 @@ python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog tr
 --use_gpu=1 \
 --data=${DATAPATH} \
 --alignments_path=${ALIGNMENTS_PATH} \
---output='./experiment' \
+--output=${OUTPUTPATH} \
 --config='configs/ljspeech.yaml' \
 ```
 
@@ -123,14 +123,13 @@ After training the FastSpeech, audio can be synthesized by running ``synthesis.p
 python synthesis.py \
 --use_gpu=1 \
 --alpha=1.0 \
---checkpoint='./checkpoint/fastspeech/step-120000' \
+--checkpoint=${CHECKPOINTPATH} \
 --config='configs/ljspeech.yaml' \
---config_clarine='../clarinet/configs/config.yaml' \
---checkpoint_clarinet='../clarinet/checkpoint/step-500000' \
---output='./synthesis' \
+--output=${OUTPUTPATH} \
+--vocoder='griffin-lim' \
 ```
 
-We use Clarinet to synthesis wav, so it necessary for you to prepare a pre-trained [Clarinet checkpoint](https://paddlespeech.bj.bcebos.com/Parakeet/clarinet_ljspeech_ckpt_1.0.zip).
+We currently support two vocoders, Griffin-Lim algorithm and WaveFlow. You can set ``--vocoder`` to use one of them. If you want to use WaveFlow as your vocoder, you need to set ``--config_vocoder`` and ``--checkpoint_vocoder`` which are the path of the config and checkpoint of vocoder. You can download the pre-trained model of WaveFlow from [here](https://github.com/PaddlePaddle/Parakeet#vocoders).
 
 Or you can run the script file directly.
 
@@ -141,3 +140,5 @@ sh synthesis.sh
 For more help on arguments
 
 ``python synthesis.py --help``.
+
+Then you can find the synthesized audio files in ``${OUTPUTPATH}/samples``.
@@ -27,7 +27,6 @@ from collections import OrderedDict
 import paddle.fluid as fluid
 import paddle.fluid.dygraph as dg
 from parakeet.models.transformer_tts.utils import *
-from parakeet import audio
 from parakeet.models.transformer_tts import TransformerTTS
 from parakeet.models.fastspeech.utils import get_alignment
 from parakeet.utils import io
@@ -78,25 +77,6 @@ def alignments(args):
         header=None,
         quoting=csv.QUOTE_NONE,
         names=["fname", "raw_text", "normalized_text"])
-    ljspeech_processor = audio.AudioProcessor(
-        sample_rate=cfg['audio']['sr'],
-        num_mels=cfg['audio']['num_mels'],
-        min_level_db=cfg['audio']['min_level_db'],
-        ref_level_db=cfg['audio']['ref_level_db'],
-        n_fft=cfg['audio']['n_fft'],
-        win_length=cfg['audio']['win_length'],
-        hop_length=cfg['audio']['hop_length'],
-        power=cfg['audio']['power'],
-        preemphasis=cfg['audio']['preemphasis'],
-        signal_norm=True,
-        symmetric_norm=False,
-        max_norm=1.,
-        mel_fmin=0,
-        mel_fmax=None,
-        clip_norm=True,
-        griffin_lim_iters=60,
-        do_trim_silence=False,
-        sound_norm=False)
 
     pbar = tqdm(range(len(table)))
     alignments = OrderedDict()
@@ -107,11 +87,26 @@ def alignments(args):
         text = fluid.layers.unsqueeze(dg.to_variable(text), [0])
         pos_text = np.arange(1, text.shape[1] + 1)
         pos_text = fluid.layers.unsqueeze(dg.to_variable(pos_text), [0])
-        wav = ljspeech_processor.load_wav(
-            os.path.join(args.data, 'wavs', fname + ".wav"))
-        mel_input = ljspeech_processor.melspectrogram(wav).astype(
-            np.float32)
-        mel_input = np.transpose(mel_input, axes=(1, 0))
+        # load
+        wav, _ = librosa.load(
+            str(os.path.join(args.data, 'wavs', fname + ".wav")))
+
+        spec = librosa.stft(
+            y=wav,
+            n_fft=cfg['audio']['n_fft'],
+            win_length=cfg['audio']['win_length'],
+            hop_length=cfg['audio']['hop_length'])
+        mag = np.abs(spec)
+        mel = librosa.filters.mel(sr=cfg['audio']['sr'],
+                                  n_fft=cfg['audio']['n_fft'],
+                                  n_mels=cfg['audio']['num_mels'],
+                                  fmin=cfg['audio']['fmin'],
+                                  fmax=cfg['audio']['fmax'])
+        mel = np.matmul(mel, mag)
+        mel = np.log(np.maximum(mel, 1e-5))
+
+        mel_input = np.transpose(mel, axes=(1, 0))
         mel_input = fluid.layers.unsqueeze(dg.to_variable(mel_input), [0])
         mel_lens = mel_input.shape[1]
 
@@ -125,7 +120,7 @@ def alignments(args):
         alignment, _ = get_alignment(attn_probs, mel_lens,
                                      network_cfg['decoder_num_head'])
         alignments[fname] = alignment
-    with open(args.output + '.txt', "wb") as f:
+    with open(args.output + '.pkl', "wb") as f:
         pickle.dump(alignments, f)
 
 
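For orientation, the preprocessing now computes log-mel spectrograms directly with librosa instead of going through the removed ``AudioProcessor``. A minimal, self-contained sketch of that pipeline; the default values are assumptions taken from the updated ``ljspeech.yaml`` shown below, not a definitive API:

```python
import librosa
import numpy as np


def log_mel_spectrogram(wav_path, sr=22050, n_fft=1024, win_length=1024,
                        hop_length=256, num_mels=80, fmin=0, fmax=8000):
    """Log-mel extraction mirroring the updated alignments/data code (a sketch)."""
    wav, _ = librosa.load(wav_path, sr=sr)
    # magnitude spectrogram of shape (1 + n_fft // 2, T)
    spec = librosa.stft(y=wav, n_fft=n_fft, win_length=win_length,
                        hop_length=hop_length)
    mag = np.abs(spec)
    # mel filter bank of shape (num_mels, 1 + n_fft // 2)
    mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=num_mels,
                                    fmin=fmin, fmax=fmax)
    mel = np.matmul(mel_basis, mag)
    # clamp before the log to avoid log(0), as in the diff
    return np.log(np.maximum(mel, 1e-5))
```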
@@ -1,10 +1,13 @@
 audio:
   num_mels: 80 #the number of mel bands when calculating mel spectrograms.
-  n_fft: 2048 #the number of fft components.
+  n_fft: 1024 #the number of fft components.
   sr: 22050 #the sampling rate of audio data file.
   hop_length: 256 #the number of samples to advance between frames.
   win_length: 1024 #the length (width) of the window function.
+  preemphasis: 0.97
   power: 1.2 #the power to raise before griffin-lim.
+  fmin: 0
+  fmax: 8000
 
 network:
   encoder_n_layer: 6 #the number of FFT Block in encoder.
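As a quick, purely illustrative sanity check on these values (not part of the change): with a centered STFT, ``n_fft: 1024`` gives 513 frequency bins and ``hop_length: 256`` gives roughly one frame per 256 samples.

```python
import numpy as np
import librosa

# assumed values taken from the updated ljspeech.yaml
sr, n_fft, win_length, hop_length = 22050, 1024, 1024, 256

wav = np.zeros(sr, dtype=np.float32)  # one second of silence, 22050 samples
spec = librosa.stft(y=wav, n_fft=n_fft, win_length=win_length, hop_length=hop_length)
print(spec.shape)  # (513, 87): 1 + n_fft // 2 bins, about 1 + 22050 // 256 frames
```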
@@ -42,12 +42,7 @@ class LJSpeechLoader:
 
         LJSPEECH_ROOT = Path(data_path)
         metadata = LJSpeechMetaData(LJSPEECH_ROOT, alignments_path)
-        transformer = LJSpeech(
-            sr=config['sr'],
-            n_fft=config['n_fft'],
-            num_mels=config['num_mels'],
-            win_length=config['win_length'],
-            hop_length=config['hop_length'])
+        transformer = LJSpeech(config)
         dataset = TransformDataset(metadata, transformer)
         dataset = CacheDataset(dataset)
 
@@ -96,18 +91,16 @@ class LJSpeechMetaData(DatasetMixin):
 
 
 class LJSpeech(object):
-    def __init__(self,
-                 sr=22050,
-                 n_fft=2048,
-                 num_mels=80,
-                 win_length=1024,
-                 hop_length=256):
+    def __init__(self, cfg):
         super(LJSpeech, self).__init__()
-        self.sr = sr
-        self.n_fft = n_fft
-        self.num_mels = num_mels
-        self.win_length = win_length
-        self.hop_length = hop_length
+        self.sr = cfg['sr']
+        self.n_fft = cfg['n_fft']
+        self.num_mels = cfg['num_mels']
+        self.win_length = cfg['win_length']
+        self.hop_length = cfg['hop_length']
+        self.preemphasis = cfg['preemphasis']
+        self.fmin = cfg['fmin']
+        self.fmax = cfg['fmax']
 
     def __call__(self, metadatum):
         """All the code for generating an Example from a metadatum. If you want a
@@ -125,7 +118,11 @@ class LJSpeech(object):
             win_length=self.win_length,
             hop_length=self.hop_length)
         mag = np.abs(spec)
-        mel = librosa.filters.mel(self.sr, self.n_fft, n_mels=self.num_mels)
+        mel = librosa.filters.mel(self.sr,
+                                  self.n_fft,
+                                  n_mels=self.num_mels,
+                                  fmin=self.fmin,
+                                  fmax=self.fmax)
         mel = np.matmul(mel, mag)
         mel = np.log(np.maximum(mel, 1e-5))
         phonemes = np.array(
@@ -28,6 +28,8 @@ from parakeet.models.fastspeech.fastspeech import FastSpeech
 from parakeet.models.transformer_tts.utils import *
 from parakeet.models.wavenet import WaveNet, UpsampleNet
 from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet
+from parakeet.modules import weight_norm
+from parakeet.models.waveflow import WaveFlowModule
 from parakeet.utils.layer_tools import freeze
 from parakeet.utils import io
 
@@ -35,7 +37,13 @@ from parakeet.utils import io
 def add_config_options_to_parser(parser):
     parser.add_argument("--config", type=str, help="path of the config file")
     parser.add_argument(
-        "--config_clarinet", type=str, help="path of the clarinet config file")
+        "--vocoder",
+        type=str,
+        default="griffin-lim",
+        choices=['griffin-lim', 'waveflow'],
+        help="vocoder method")
+    parser.add_argument(
+        "--config_vocoder", type=str, help="path of the vocoder config file")
     parser.add_argument("--use_gpu", type=int, default=0, help="device to use")
     parser.add_argument(
         "--alpha",
@@ -45,11 +53,11 @@ def add_config_options_to_parser(parser):
     )
 
     parser.add_argument(
-        "--checkpoint", type=str, help="fastspeech checkpoint to synthesis")
+        "--checkpoint", type=str, help="fastspeech checkpoint for synthesis")
     parser.add_argument(
-        "--checkpoint_clarinet",
+        "--checkpoint_vocoder",
         type=str,
-        help="clarinet checkpoint to synthesis")
+        help="vocoder checkpoint for synthesis")
 
     parser.add_argument(
         "--output",
@@ -88,110 +96,68 @@ def synthesis(text_input, args):
 
         _, mel_output_postnet = model(text, pos_text, alpha=args.alpha)
 
-        result = np.exp(mel_output_postnet.numpy())
-        mel_output_postnet = fluid.layers.transpose(
-            fluid.layers.squeeze(mel_output_postnet, [0]), [1, 0])
-        mel_output_postnet = np.exp(mel_output_postnet.numpy())
-        basis = librosa.filters.mel(cfg['audio']['sr'], cfg['audio']['n_fft'],
-                                    cfg['audio']['num_mels'])
-        inv_basis = np.linalg.pinv(basis)
-        spec = np.maximum(1e-10, np.dot(inv_basis, mel_output_postnet))
-
-        # synthesis use clarinet
-        wav_clarinet = synthesis_with_clarinet(
-            args.config_clarinet, args.checkpoint_clarinet, result, place)
-        writer.add_audio(text_input + '(clarinet)', wav_clarinet, 0,
+        if args.vocoder == 'griffin-lim':
+            #synthesis use griffin-lim
+            wav = synthesis_with_griffinlim(mel_output_postnet, cfg['audio'])
+        elif args.vocoder == 'waveflow':
+            wav = synthesis_with_waveflow(mel_output_postnet, args,
+                                          args.checkpoint_vocoder, place)
+        else:
+            print(
+                'vocoder error, we only support griffinlim and waveflow, but recevied %s.'
+                % args.vocoder)
+
+        writer.add_audio(text_input + '(' + args.vocoder + ')', wav, 0,
                          cfg['audio']['sr'])
         if not os.path.exists(os.path.join(args.output, 'samples')):
             os.mkdir(os.path.join(args.output, 'samples'))
-        write(
-            os.path.join(os.path.join(args.output, 'samples'), 'clarinet.wav'),
-            cfg['audio']['sr'], wav_clarinet)
-
-        #synthesis use griffin-lim
-        wav = librosa.core.griffinlim(
-            spec**cfg['audio']['power'],
-            hop_length=cfg['audio']['hop_length'],
-            win_length=cfg['audio']['win_length'])
-        writer.add_audio(text_input + '(griffin-lim)', wav, 0, cfg['audio']['sr'])
         write(
             os.path.join(
-                os.path.join(args.output, 'samples'), 'grinffin-lim.wav'),
+                os.path.join(args.output, 'samples'), args.vocoder + '.wav'),
             cfg['audio']['sr'], wav)
         print("Synthesis completed !!!")
         writer.close()
 
 
-def synthesis_with_clarinet(config_path, checkpoint, mel_spectrogram, place):
-    with open(config_path, 'rt') as f:
-        config = yaml.safe_load(f)
-
-    data_config = config["data"]
-    n_mels = data_config["n_mels"]
-
-    teacher_config = config["teacher"]
-    n_loop = teacher_config["n_loop"]
-    n_layer = teacher_config["n_layer"]
-    filter_size = teacher_config["filter_size"]
-
-    # only batch=1 for validation is enabled
-
-    with dg.guard(place):
-        # conditioner(upsampling net)
-        conditioner_config = config["conditioner"]
-        upsampling_factors = conditioner_config["upsampling_factors"]
-        upsample_net = UpsampleNet(upscale_factors=upsampling_factors)
-        freeze(upsample_net)
-
-        residual_channels = teacher_config["residual_channels"]
-        loss_type = teacher_config["loss_type"]
-        output_dim = teacher_config["output_dim"]
-        log_scale_min = teacher_config["log_scale_min"]
-        assert loss_type == "mog" and output_dim == 3, \
-            "the teacher wavenet should be a wavenet with single gaussian output"
-
-        teacher = WaveNet(n_loop, n_layer, residual_channels, output_dim,
-                          n_mels, filter_size, loss_type, log_scale_min)
-        # load & freeze upsample_net & teacher
-        freeze(teacher)
-
-        student_config = config["student"]
-        n_loops = student_config["n_loops"]
-        n_layers = student_config["n_layers"]
-        student_residual_channels = student_config["residual_channels"]
-        student_filter_size = student_config["filter_size"]
-        student_log_scale_min = student_config["log_scale_min"]
-        student = ParallelWaveNet(n_loops, n_layers, student_residual_channels,
-                                  n_mels, student_filter_size)
-
-        stft_config = config["stft"]
-        stft = STFT(
-            n_fft=stft_config["n_fft"],
-            hop_length=stft_config["hop_length"],
-            win_length=stft_config["win_length"])
-
-        lmd = config["loss"]["lmd"]
-        model = Clarinet(upsample_net, teacher, student, stft,
-                         student_log_scale_min, lmd)
-        io.load_parameters(model=model, checkpoint_path=checkpoint)
-
-        if not os.path.exists(args.output):
-            os.makedirs(args.output)
-        model.eval()
-
-        # Rescale mel_spectrogram.
-        min_level, ref_level = 1e-5, 20  # hard code it
-        mel_spectrogram = 20 * np.log10(np.maximum(min_level, mel_spectrogram))
-        mel_spectrogram = mel_spectrogram - ref_level
-        mel_spectrogram = np.clip((mel_spectrogram + 100) / 100, 0, 1)
-
-        mel_spectrogram = dg.to_variable(mel_spectrogram)
-        mel_spectrogram = fluid.layers.transpose(mel_spectrogram, [0, 2, 1])
-
-        wav_var = model.synthesis(mel_spectrogram)
-        wav_np = wav_var.numpy()[0]
-
-        return wav_np
+def synthesis_with_griffinlim(mel_output, cfg):
+    mel_output = fluid.layers.transpose(
+        fluid.layers.squeeze(mel_output, [0]), [1, 0])
+    mel_output = np.exp(mel_output.numpy())
+    basis = librosa.filters.mel(cfg['sr'],
+                                cfg['n_fft'],
+                                cfg['num_mels'],
+                                fmin=cfg['fmin'],
+                                fmax=cfg['fmax'])
+    inv_basis = np.linalg.pinv(basis)
+    spec = np.maximum(1e-10, np.dot(inv_basis, mel_output))
+
+    wav = librosa.core.griffinlim(
+        spec**cfg['power'],
+        hop_length=cfg['hop_length'],
+        win_length=cfg['win_length'])
+
+    return wav
+
+
+def synthesis_with_waveflow(mel_output, args, checkpoint, place):
+    fluid.enable_dygraph(place)
+    args.config = args.config_vocoder
+    args.use_fp16 = False
+    config = io.add_yaml_config_to_args(args)
+
+    mel_spectrogram = fluid.layers.transpose(mel_output, [0, 2, 1])
+
+    # Build model.
+    waveflow = WaveFlowModule(config)
+    io.load_parameters(model=waveflow, checkpoint_path=checkpoint)
+    for layer in waveflow.sublayers():
+        if isinstance(layer, weight_norm.WeightNormWrapper):
+            layer.remove_weight_norm()
+
+    # Run model inference.
+    wav = waveflow.synthesize(mel_spectrogram, sigma=config.sigma)
+    return wav.numpy()[0]
 
 
 if __name__ == '__main__':
@@ -199,5 +165,6 @@ if __name__ == '__main__':
     add_config_options_to_parser(parser)
     args = parser.parse_args()
     pprint(vars(args))
-    synthesis("Simple as this proposition is, it is necessary to be stated,",
-              args)
+    synthesis(
+        "Don't argue with the people of strong determination, because they may change the fact!",
+        args)
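For readers following the new synthesis path, here is a standalone sketch of what ``synthesis_with_griffinlim`` boils down to. The default values are assumptions taken from the updated ``ljspeech.yaml``, and a librosa version where ``griffinlim`` is exposed at the top level is assumed; this is an illustration, not the repository's API:

```python
import numpy as np
import librosa


def griffinlim_from_log_mel(log_mel, sr=22050, n_fft=1024, num_mels=80,
                            fmin=0, fmax=8000, power=1.2,
                            hop_length=256, win_length=1024):
    """Invert a log-mel spectrogram (num_mels, T) back to a waveform (a sketch)."""
    mel = np.exp(log_mel)                              # undo the log taken in preprocessing
    basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=num_mels,
                                fmin=fmin, fmax=fmax)
    inv_basis = np.linalg.pinv(basis)                  # pseudo-inverse of the mel filter bank
    spec = np.maximum(1e-10, np.dot(inv_basis, mel))   # approximate linear magnitude
    return librosa.griffinlim(spec**power,
                              hop_length=hop_length,
                              win_length=win_length)
```

The WaveFlow path instead feeds the predicted mel directly to a pre-trained flow model, so it needs the extra ``--config_vocoder`` and ``--checkpoint_vocoder`` arguments shown in the scripts below.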
@@ -1,13 +1,17 @@
 # train model
 
+CUDA_VISIBLE_DEVICES=0 \
 python -u synthesis.py \
 --use_gpu=1 \
 --alpha=1.0 \
---checkpoint='./checkpoint/fastspeech/step-120000' \
---config='configs/ljspeech.yaml' \
---config_clarine='../clarinet/configs/config.yaml' \
---checkpoint_clarinet='../clarinet/checkpoint/step-500000' \
+--checkpoint='./fastspeech_ljspeech_ckpt_1.0/fastspeech/step-162000' \
+--config='fastspeech_ljspeech_ckpt_1.0/ljspeech.yaml' \
 --output='./synthesis' \
+--vocoder='waveflow' \
+--config_vocoder='./waveflow_res128_ljspeech_ckpt_1.0/waveflow_ljspeech.yaml' \
+--checkpoint_vocoder='./waveflow_res128_ljspeech_ckpt_1.0/step-2000000' \
 
 if [ $? -ne 0 ]; then
     echo "Failed in synthesis!"
@@ -3,7 +3,7 @@ export CUDA_VISIBLE_DEVICES=0
 python -u train.py \
 --use_gpu=1 \
 --data='../../dataset/LJSpeech-1.1' \
---alignments_path='./alignments/alignments.txt' \
+--alignments_path='./alignments/alignments.pkl' \
 --output='./experiment' \
 --config='configs/ljspeech.yaml' \
 #--checkpoint='./checkpoint/fastspeech/step-120000' \
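The renamed alignments file is a pickle of the ``OrderedDict`` written by ``alignments.py`` above (one duration array per utterance). A hypothetical sketch of how a consumer could read it back; the path matches the script above and the shape of the values is assumed:

```python
import pickle
from collections import OrderedDict

# Hypothetical reader for the file produced by alignments.py (pickle.dump of an OrderedDict).
with open('./alignments/alignments.pkl', 'rb') as f:
    alignments = pickle.load(f)

assert isinstance(alignments, OrderedDict)
for fname, alignment in alignments.items():
    # each entry maps an utterance id to its attention-derived durations
    print(fname, alignment)
    break
```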
@@ -56,7 +56,7 @@ TransformerTTS model can be trained by running ``train_transformer.py``.
 python train_transformer.py \
 --use_gpu=1 \
 --data=${DATAPATH} \
---output='./experiment' \
+--output=${OUTPUTPATH} \
 --config='configs/ljspeech.yaml' \
 ```
 
@@ -73,7 +73,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3
 python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog train_transformer.py \
 --use_gpu=1 \
 --data=${DATAPATH} \
---output='./experiment' \
+--output=${OUTPUTPATH} \
 --config='configs/ljspeech.yaml' \
 ```
 
@@ -85,61 +85,28 @@ For more help on arguments
 
 ``python train_transformer.py --help``.
 
-## Train Vocoder
-
-Vocoder model can be trained by running ``train_vocoder.py``.
-
-```bash
-python train_vocoder.py \
---use_gpu=1 \
---data=${DATAPATH} \
---output='./vocoder' \
---config='configs/ljspeech.yaml' \
-```
-
-Or you can run the script file directly.
-
-```bash
-sh train_vocoder.sh
-```
-
-If you want to train on multiple GPUs, you must start training in the following way.
-
-```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3
-python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog train_vocoder.py \
---use_gpu=1 \
---data=${DATAPATH} \
---output='./vocoder' \
---config='configs/ljspeech.yaml' \
-```
-
-If you wish to resume from an existing model, See [Saving-&-Loading](#Saving-&-Loading) for details of checkpoint loading.
-
-For more help on arguments
-
-``python train_vocoder.py --help``.
-
 ## Synthesis
 
-After training the TransformerTTS and vocoder model, audio can be synthesized by running ``synthesis.py``.
+After training the TransformerTTS, audio can be synthesized by running ``synthesis.py``.
 
 ```bash
 python synthesis.py \
---max_len=300 \
---use_gpu=1 \
---output='./synthesis' \
+--use_gpu=0 \
+--output=${OUTPUTPATH} \
 --config='configs/ljspeech.yaml' \
---checkpoint_transformer='./checkpoint/transformer/step-120000' \
---checkpoint_vocoder='./checkpoint/vocoder/step-100000' \
+--checkpoint_transformer=${CHECKPOINTPATH} \
+--vocoder='griffin-lim' \
 ```
 
+We currently support two vocoders, Griffin-Lim algorithm and WaveFlow. You can set ``--vocoder`` to use one of them. If you want to use WaveFlow as your vocoder, you need to set ``--config_vocoder`` and ``--checkpoint_vocoder`` which are the path of the config and checkpoint of vocoder. You can download the pre-trained model of WaveFlow from [here](https://github.com/PaddlePaddle/Parakeet#vocoders).
+
 Or you can run the script file directly.
 
 ```bash
 sh synthesis.sh
 ```
 
 For more help on arguments
 
 ``python synthesis.py --help``.
+
+Then you can find the synthesized audio files in ``${OUTPUTPATH}/samples``.
@@ -1,13 +1,13 @@
 audio:
   num_mels: 80
-  n_fft: 2048
+  n_fft: 1024
   sr: 22050
   preemphasis: 0.97
-  hop_length: 256 #275
-  win_length: 1024 #1102
+  hop_length: 256
+  win_length: 1024
   power: 1.2
-  min_level_db: -100
-  ref_level_db: 20
+  fmin: 0
+  fmax: 8000
 
 network:
   hidden_size: 256
@@ -17,7 +17,7 @@ network:
   decoder_num_head: 4
   decoder_n_layers: 3
   outputs_per_step: 1
-  stop_token: False
+  stop_loss_weight: 8
 
 vocoder:
   hidden_size: 256
@@ -19,7 +19,6 @@ import csv
 
 from paddle import fluid
 from parakeet import g2p
-from parakeet import audio
 from parakeet.data.sampler import *
 from parakeet.data.datacargo import DataCargo
 from parakeet.data.batch import TextIDBatcher, SpecBatcher
@@ -98,25 +97,14 @@ class LJSpeech(object):
     def __init__(self, config):
         super(LJSpeech, self).__init__()
         self.config = config
-        self._ljspeech_processor = audio.AudioProcessor(
-            sample_rate=config['sr'],
-            num_mels=config['num_mels'],
-            min_level_db=config['min_level_db'],
-            ref_level_db=config['ref_level_db'],
-            n_fft=config['n_fft'],
-            win_length=config['win_length'],
-            hop_length=config['hop_length'],
-            power=config['power'],
-            preemphasis=config['preemphasis'],
-            signal_norm=True,
-            symmetric_norm=False,
-            max_norm=1.,
-            mel_fmin=0,
-            mel_fmax=None,
-            clip_norm=True,
-            griffin_lim_iters=60,
-            do_trim_silence=False,
-            sound_norm=False)
+        self.sr = config['sr']
+        self.n_mels = config['num_mels']
+        self.preemphasis = config['preemphasis']
+        self.n_fft = config['n_fft']
+        self.win_length = config['win_length']
+        self.hop_length = config['hop_length']
+        self.fmin = config['fmin']
+        self.fmax = config['fmax']
 
     def __call__(self, metadatum):
         """All the code for generating an Example from a metadatum. If you want a
@@ -127,14 +115,26 @@ class LJSpeech(object):
         """
         fname, raw_text, normalized_text = metadatum
 
-        # load -> trim -> preemphasis -> stft -> magnitude -> mel_scale -> logscale -> normalize
-        wav = self._ljspeech_processor.load_wav(str(fname))
-        mag = self._ljspeech_processor.spectrogram(wav).astype(np.float32)
-        mel = self._ljspeech_processor.melspectrogram(wav).astype(np.float32)
-        phonemes = np.array(
+        # load
+        wav, _ = librosa.load(str(fname))
+
+        spec = librosa.stft(
+            y=wav,
+            n_fft=self.n_fft,
+            win_length=self.win_length,
+            hop_length=self.hop_length)
+        mag = np.abs(spec)
+        mel = librosa.filters.mel(sr=self.sr,
+                                  n_fft=self.n_fft,
+                                  n_mels=self.n_mels,
+                                  fmin=self.fmin,
+                                  fmax=self.fmax)
+        mel = np.matmul(mel, mag)
+        mel = np.log(np.maximum(mel, 1e-5))
+
+        characters = np.array(
             g2p.en.text_to_sequence(normalized_text), dtype=np.int64)
-        return (mag, mel, phonemes
-                ) # maybe we need to implement it as a map in the future
+        return (mag, mel, characters)
 
 
 def batch_examples(batch):
@@ -144,6 +144,7 @@ def batch_examples(batch):
     text_lens = []
     pos_texts = []
     pos_mels = []
+    stop_tokens = []
     for data in batch:
         _, mel, text = data
         mel_inputs.append(
@@ -155,6 +156,8 @@ def batch_examples(batch):
         pos_mels.append(np.arange(1, mel.shape[1] + 1))
         mels.append(mel)
         texts.append(text)
+        stop_token = np.append(np.zeros([mel.shape[1] - 1], np.float32), 1.0)
+        stop_tokens.append(stop_token)
 
     # Sort by text_len in descending order
     texts = [
@@ -182,18 +185,24 @@ def batch_examples(batch):
         for i, _ in sorted(
             zip(pos_mels, text_lens), key=lambda x: x[1], reverse=True)
     ]
+    stop_tokens = [
+        i
+        for i, _ in sorted(
+            zip(stop_tokens, text_lens), key=lambda x: x[1], reverse=True)
+    ]
     text_lens = sorted(text_lens, reverse=True)
 
     # Pad sequence with largest len of the batch
     texts = TextIDBatcher(pad_id=0)(texts) #(B, T)
     pos_texts = TextIDBatcher(pad_id=0)(pos_texts) #(B,T)
     pos_mels = TextIDBatcher(pad_id=0)(pos_mels) #(B,T)
+    stop_tokens = TextIDBatcher(pad_id=1, dtype=np.float32)(pos_mels)
    mels = np.transpose(
        SpecBatcher(pad_value=0.)(mels), axes=(0, 2, 1)) #(B,T,num_mels)
    mel_inputs = np.transpose(
        SpecBatcher(pad_value=0.)(mel_inputs), axes=(0, 2, 1)) #(B,T,num_mels)
 
-    return (texts, mels, mel_inputs, pos_texts, pos_mels)
+    return (texts, mels, mel_inputs, pos_texts, pos_mels, stop_tokens)
 
 
 def batch_examples_vocoder(batch):
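The new stop-token targets are binary vectors over decoder frames: zero everywhere except the final frame of each utterance. A tiny illustration of the line used in ``batch_examples`` (the length is invented for the example):

```python
import numpy as np

# stop-token target for a single utterance, mirroring batch_examples:
# zeros for every mel frame except the last one, which is marked 1.0
mel_frames = 5  # assumed length, for illustration only
stop_token = np.append(np.zeros([mel_frames - 1], np.float32), 1.0)
print(stop_token)  # [0. 0. 0. 0. 1.]
```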
@@ -25,29 +25,43 @@ import paddle.fluid as fluid
 import paddle.fluid.dygraph as dg
 from parakeet.g2p.en import text_to_sequence
 from parakeet.models.transformer_tts.utils import *
-from parakeet import audio
-from parakeet.models.transformer_tts import Vocoder
 from parakeet.models.transformer_tts import TransformerTTS
+from parakeet.models.waveflow import WaveFlowModule
+from parakeet.modules.weight_norm import WeightNormWrapper
 from parakeet.utils import io
 
 
 def add_config_options_to_parser(parser):
     parser.add_argument("--config", type=str, help="path of the config file")
     parser.add_argument("--use_gpu", type=int, default=0, help="device to use")
+    parser.add_argument(
+        "--stop_threshold",
+        type=float,
+        default=0.5,
+        help="The threshold of stop token which indicates the time step should stop generate spectrum or not."
+    )
     parser.add_argument(
         "--max_len",
         type=int,
-        default=200,
+        default=1000,
         help="The max length of audio when synthsis.")
 
     parser.add_argument(
         "--checkpoint_transformer",
         type=str,
-        help="transformer_tts checkpoint to synthesis")
+        help="transformer_tts checkpoint for synthesis")
+    parser.add_argument(
+        "--vocoder",
+        type=str,
+        default="griffin-lim",
+        choices=['griffin-lim', 'waveflow'],
+        help="vocoder method")
+    parser.add_argument(
+        "--config_vocoder", type=str, help="path of the vocoder config file")
     parser.add_argument(
         "--checkpoint_vocoder",
        type=str,
-        help="vocoder checkpoint to synthesis")
+        help="vocoder checkpoint for synthesis")
 
     parser.add_argument(
         "--output",
@@ -82,14 +96,6 @@ def synthesis(text_input, args):
             model=model, checkpoint_path=args.checkpoint_transformer)
         model.eval()
 
-        with fluid.unique_name.guard():
-            model_vocoder = Vocoder(
-                cfg['train']['batch_size'], cfg['vocoder']['hidden_size'],
-                cfg['audio']['num_mels'], cfg['audio']['n_fft'])
-            # Load parameters.
-            global_step = io.load_parameters(
-                model=model_vocoder, checkpoint_path=args.checkpoint_vocoder)
-            model_vocoder.eval()
         # init input
         text = np.asarray(text_to_sequence(text_input))
         text = fluid.layers.unsqueeze(dg.to_variable(text).astype(np.int64), [0])
@@ -98,42 +104,16 @@ def synthesis(text_input, args):
         pos_text = fluid.layers.unsqueeze(
             dg.to_variable(pos_text).astype(np.int64), [0])
 
-        pbar = tqdm(range(args.max_len))
-        for i in pbar:
+        for i in range(args.max_len):
             pos_mel = np.arange(1, mel_input.shape[1] + 1)
             pos_mel = fluid.layers.unsqueeze(
                 dg.to_variable(pos_mel).astype(np.int64), [0])
             mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
                 text, mel_input, pos_text, pos_mel)
+            if stop_preds.numpy()[0, -1] > args.stop_threshold:
+                break
             mel_input = fluid.layers.concat(
                 [mel_input, postnet_pred[:, -1:, :]], axis=1)
-
-        mag_pred = model_vocoder(postnet_pred)
-
-        _ljspeech_processor = audio.AudioProcessor(
-            sample_rate=cfg['audio']['sr'],
-            num_mels=cfg['audio']['num_mels'],
-            min_level_db=cfg['audio']['min_level_db'],
-            ref_level_db=cfg['audio']['ref_level_db'],
-            n_fft=cfg['audio']['n_fft'],
-            win_length=cfg['audio']['win_length'],
-            hop_length=cfg['audio']['hop_length'],
-            power=cfg['audio']['power'],
-            preemphasis=cfg['audio']['preemphasis'],
-            signal_norm=True,
-            symmetric_norm=False,
-            max_norm=1.,
-            mel_fmin=0,
-            mel_fmax=None,
-            clip_norm=True,
-            griffin_lim_iters=60,
-            do_trim_silence=False,
-            sound_norm=False)
-
-        # synthesis with cbhg
-        wav = _ljspeech_processor.inv_spectrogram(
-            fluid.layers.transpose(fluid.layers.squeeze(mag_pred, [0]), [1, 0])
-            .numpy())
         global_step = 0
         for i, prob in enumerate(attn_probs):
             for j in range(4):
@@ -144,32 +124,79 @@ def synthesis(text_input, args):
                     i * 4 + j,
                     dataformats="HWC")
 
-        writer.add_audio(text_input + '(cbhg)', wav, 0, cfg['audio']['sr'])
+        if args.vocoder == 'griffin-lim':
+            #synthesis use griffin-lim
+            wav = synthesis_with_griffinlim(postnet_pred, cfg['audio'])
+        elif args.vocoder == 'waveflow':
+            # synthesis use waveflow
+            wav = synthesis_with_waveflow(postnet_pred, args,
+                                          args.checkpoint_vocoder, place)
+        else:
+            print(
+                'vocoder error, we only support griffinlim and waveflow, but recevied %s.'
+                % args.vocoder)
+
+        writer.add_audio(text_input + '(' + args.vocoder + ')', wav, 0,
+                         cfg['audio']['sr'])
         if not os.path.exists(os.path.join(args.output, 'samples')):
             os.mkdir(os.path.join(args.output, 'samples'))
         write(
-            os.path.join(os.path.join(args.output, 'samples'), 'cbhg.wav'),
-            cfg['audio']['sr'], wav)
-
-        # synthesis with griffin-lim
-        wav = _ljspeech_processor.inv_melspectrogram(
-            fluid.layers.transpose(
-                fluid.layers.squeeze(postnet_pred, [0]), [1, 0]).numpy())
-        writer.add_audio(text_input + '(griffin)', wav, 0, cfg['audio']['sr'])
-
-        write(
-            os.path.join(os.path.join(args.output, 'samples'), 'griffin.wav'),
+            os.path.join(
+                os.path.join(args.output, 'samples'), args.vocoder + '.wav'),
             cfg['audio']['sr'], wav)
         print("Synthesis completed !!!")
         writer.close()
 
 
+def synthesis_with_griffinlim(mel_output, cfg):
+    # synthesis with griffin-lim
+    mel_output = fluid.layers.transpose(
+        fluid.layers.squeeze(mel_output, [0]), [1, 0])
+    mel_output = np.exp(mel_output.numpy())
+    basis = librosa.filters.mel(cfg['sr'],
+                                cfg['n_fft'],
+                                cfg['num_mels'],
+                                fmin=cfg['fmin'],
+                                fmax=cfg['fmax'])
+    inv_basis = np.linalg.pinv(basis)
+    spec = np.maximum(1e-10, np.dot(inv_basis, mel_output))
+
+    wav = librosa.core.griffinlim(
+        spec**cfg['power'],
+        hop_length=cfg['hop_length'],
+        win_length=cfg['win_length'])
+
+    return wav
+
+
+def synthesis_with_waveflow(mel_output, args, checkpoint, place):
+    fluid.enable_dygraph(place)
+    args.config = args.config_vocoder
+    args.use_fp16 = False
+    config = io.add_yaml_config_to_args(args)
+
+    mel_spectrogram = fluid.layers.transpose(
+        fluid.layers.squeeze(mel_output, [0]), [1, 0])
+    mel_spectrogram = fluid.layers.unsqueeze(mel_spectrogram, [0])
+
+    # Build model.
+    waveflow = WaveFlowModule(config)
+    io.load_parameters(model=waveflow, checkpoint_path=checkpoint)
+    for layer in waveflow.sublayers():
+        if isinstance(layer, WeightNormWrapper):
+            layer.remove_weight_norm()
+
+    # Run model inference.
+    wav = waveflow.synthesize(mel_spectrogram, sigma=config.sigma)
+    return wav.numpy()[0]
+
+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description="Synthesis model")
     add_config_options_to_parser(parser)
     args = parser.parse_args()
     # Print the whole config setting.
     pprint(vars(args))
-    synthesis("Parakeet stands for Paddle PARAllel text-to-speech toolkit.",
-              args)
+    synthesis(
+        "Life was like a box of chocolates, you never know what you're gonna get.",
+        args)
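The inference loop now terminates as soon as the stop prediction for the newest frame exceeds ``--stop_threshold`` instead of always running the full ``--max_len`` steps. A simplified sketch of that control flow; ``step_fn`` here is a hypothetical stand-in for one TransformerTTS decoder step, not the repository's actual model API:

```python
import numpy as np


def autoregressive_decode(step_fn, max_len=1000, stop_threshold=0.5):
    """Greedy mel decoding with a stop-token gate (a sketch of the updated loop).

    step_fn(frames_so_far) is assumed to return (next_frame, stop_probability).
    """
    frames = []
    for _ in range(max_len):
        next_frame, stop_prob = step_fn(frames)
        # same check as `stop_preds.numpy()[0, -1] > args.stop_threshold` in the diff:
        # stop before appending the frame that triggered the gate
        if stop_prob > stop_threshold:
            break
        frames.append(next_frame)
    return np.stack(frames) if frames else np.zeros((0,))
```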
@@ -2,12 +2,13 @@
 # train model
 CUDA_VISIBLE_DEVICES=0 \
 python -u synthesis.py \
---max_len=300 \
---use_gpu=1 \
+--use_gpu=0 \
 --output='./synthesis' \
---config='configs/ljspeech.yaml' \
---checkpoint_transformer='./checkpoint/transformer/step-120000' \
---checkpoint_vocoder='./checkpoint/vocoder/step-100000' \
+--config='transformer_tts_ljspeech_ckpt_1.0/ljspeech.yaml' \
+--checkpoint_transformer='./transformer_tts_ljspeech_ckpt_1.0/step-120000' \
+--vocoder='waveflow' \
+--config_vocoder='./waveflow_res128_ljspeech_ckpt_1.0/waveflow_ljspeech.yaml' \
+--checkpoint_vocoder='./waveflow_res128_ljspeech_ckpt_1.0/step-2000000' \
 
 if [ $? -ne 0 ]; then
     echo "Failed in training!"
@@ -115,7 +115,7 @@ def main(args):
                 iterator = iter(tqdm(reader))
                 batch = next(iterator)
 
-            character, mel, mel_input, pos_text, pos_mel = batch
+            character, mel, mel_input, pos_text, pos_mel, stop_tokens = batch
 
             mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
                 character, mel_input, pos_text, pos_mel)
@@ -126,11 +126,9 @@ def main(args):
                 layers.abs(layers.elementwise_sub(postnet_pred, mel)))
             loss = mel_loss + post_mel_loss
 
-            # Note: When used stop token loss the learning did not work.
-            if cfg['network']['stop_token']:
-                label = (pos_mel == 0).astype(np.float32)
-                stop_loss = cross_entropy(stop_preds, label)
-                loss = loss + stop_loss
+            stop_loss = cross_entropy(
+                stop_preds, stop_tokens, weight=cfg['network']['stop_loss_weight'])
+            loss = loss + stop_loss
 
             if local_rank == 0:
                 writer.add_scalars('training_loss', {
@@ -138,8 +136,7 @@ def main(args):
                     'post_mel_loss': post_mel_loss.numpy()
                 }, global_step)
 
-                if cfg['network']['stop_token']:
-                    writer.add_scalar('stop_loss', stop_loss.numpy(), global_step)
+                writer.add_scalar('stop_loss', stop_loss.numpy(), global_step)
 
                 if parallel:
                     writer.add_scalars('alphas', {
@@ -98,7 +98,7 @@ def main(args):
         local_rank,
         is_vocoder=True).reader()
 
-    for epoch in range(cfg['train']['max_epochs']):
+    for epoch in range(cfg['train']['max_iteration']):
         pbar = tqdm(reader)
         for i, data in enumerate(pbar):
             pbar.set_description('Processing at epoch %d' % epoch)
@@ -37,13 +37,12 @@ class LengthRegulator(dg.Layer):
             filter_size=filter_size,
             dropout=dropout)
 
-    def LR(self, x, duration_predictor_output, alpha=1.0):
+    def LR(self, x, duration_predictor_output):
         output = []
         batch_size = x.shape[0]
         for i in range(batch_size):
             output.append(
-                self.expand(x[i:i + 1], duration_predictor_output[i:i + 1],
-                            alpha))
+                self.expand(x[i:i + 1], duration_predictor_output[i:i + 1]))
         output = self.pad(output)
         return output
 
@@ -58,7 +57,7 @@ class LengthRegulator(dg.Layer):
         out_padded = layers.stack(out_list)
         return out_padded
 
-    def expand(self, batch, predicted, alpha):
+    def expand(self, batch, predicted):
         out = []
         time_steps = batch.shape[1]
         fertilities = predicted.numpy()
@@ -92,8 +91,9 @@ class LengthRegulator(dg.Layer):
             output = self.LR(x, target)
             return output, duration_predictor_output
         else:
-            duration_predictor_output = layers.round(duration_predictor_output)
-            output = self.LR(x, duration_predictor_output, alpha)
+            duration_predictor_output = duration_predictor_output * alpha
+            duration_predictor_output = layers.ceil(duration_predictor_output)
+            output = self.LR(x, duration_predictor_output)
             mel_pos = dg.to_variable(np.arange(1, output.shape[1] + 1)).astype(
                 np.int64)
             mel_pos = layers.unsqueeze(mel_pos, [0])
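With this change, speed control multiplies the predicted durations by ``alpha`` and rounds up before expansion, rather than passing ``alpha`` down into ``expand``. A small numeric illustration (the duration values are invented):

```python
import numpy as np

durations = np.array([2.0, 1.0, 3.0])   # predicted frames per phoneme, illustrative only
alpha = 1.25                            # alpha > 1 repeats frames more often, slowing speech

scaled = np.ceil(durations * alpha)     # mirrors layers.ceil(duration_predictor_output * alpha)
print(scaled)                           # [3. 2. 4.] -> total mel length grows from 6 to 9 frames
```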
@@ -93,9 +93,9 @@ def guided_attention(N, T, g=0.2):
     return W
 
 
-def cross_entropy(input, label, position_weight=1.0, epsilon=1e-30):
+def cross_entropy(input, label, weight=1.0, epsilon=1e-30):
     output = -1 * label * layers.log(input + epsilon) - (
         1 - label) * layers.log(1 - input + epsilon)
-    output = output * (label * (position_weight - 1) + 1)
+    output = output * (label * (weight - 1) + 1)
 
-    return layers.reduce_sum(output, dim=[0, 1])
+    return layers.reduce_mean(output, dim=[0, 1])
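The renamed ``weight`` argument is fed from the new ``stop_loss_weight`` config value (8 in ljspeech.yaml), and the loss is now averaged rather than summed. A rough NumPy sketch of the same computation, with toy values chosen for illustration:

```python
import numpy as np


def weighted_bce(pred, label, weight=8.0, epsilon=1e-30):
    """NumPy sketch of the updated cross_entropy: positive (stop) frames are
    up-weighted by `weight`, and the result is averaged instead of summed."""
    loss = -label * np.log(pred + epsilon) - (1 - label) * np.log(1 - pred + epsilon)
    loss = loss * (label * (weight - 1) + 1)  # the weight applies only where label == 1
    return loss.mean()


# toy values: the last frame is the stop frame
pred = np.array([0.1, 0.2, 0.9])
label = np.array([0.0, 0.0, 1.0])
print(weighted_bce(pred, label))
```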