add max_iteration into configuration, fix np.pad for lower versions of numpy

This commit is contained in:
chenfeiyu 2020-08-11 09:12:50 +00:00
parent 610181d4c0
commit 3717ac1342
4 changed files with 5 additions and 6 deletions

View File

@ -87,7 +87,7 @@ runs/Jul07_09-39-34_instance-mqcyj27y-4/
... ...
``` ```
Since e use waveflow to synthesize audio while training, so download the trained waveflow model and extract it in current directory before training. Since we use waveflow to synthesize audio while training, so download the trained waveflow model and extract it in current directory before training.
```bash ```bash
wget https://paddlespeech.bj.bcebos.com/Parakeet/waveflow_res128_ljspeech_ckpt_1.0.zip wget https://paddlespeech.bj.bcebos.com/Parakeet/waveflow_res128_ljspeech_ckpt_1.0.zip

View File

@ -39,6 +39,7 @@ clip_value: 5.0
clip_norm: 100.0 clip_norm: 100.0
# training: # training:
max_iteration: 1000000
batch_size: 16 batch_size: 16
report_interval: 10000 report_interval: 10000
save_interval: 10000 save_interval: 10000

View File

@ -62,10 +62,8 @@ class DataCollector(object):
for example in examples: for example in examples:
text, spec, mel, _ = example text, spec, mel, _ = example
text_seqs.append(en.text_to_sequence(text, self.p_pronunciation)) text_seqs.append(en.text_to_sequence(text, self.p_pronunciation))
# if max_frames - mel.shape[0] < 0: specs.append(np.pad(spec, [(0, max_frames - spec.shape[0]), (0, 0)], mode="constant"))
# import pdb; pdb.set_trace() mels.append(np.pad(mel, [(0, max_frames - mel.shape[0]), (0, 0)], mode="constant"))
specs.append(np.pad(spec, [(0, max_frames - spec.shape[0]), (0, 0)]))
mels.append(np.pad(mel, [(0, max_frames - mel.shape[0]), (0, 0)]))
specs = np.stack(specs) specs = np.stack(specs)
mels = np.stack(mels) mels = np.stack(mels)

View File

@ -81,7 +81,7 @@ def train(args, config):
optim = create_optimizer(model, config) optim = create_optimizer(model, config)
global global_step global global_step
max_iteration = 1000000 max_iteration = config["max_iteration"]
iterator = iter(tqdm.tqdm(train_loader)) iterator = iter(tqdm.tqdm(train_loader))
while global_step <= max_iteration: while global_step <= max_iteration: