fix missing imports, fix ljspeech.yaml config key: encoder_channels

This commit is contained in:
chenfeiyu 2020-02-13 03:42:20 +00:00 committed by liuyibing01
parent 5beef513af
commit 173693f469
5 changed files with 30 additions and 33 deletions

View File

@ -21,6 +21,7 @@ transform:
# db scale
min_level_db: -100
ref_level_db: 20
clip_norm: true
loss:
@ -48,20 +49,20 @@ model:
embedding_weight_std: 0.1
freeze_embedding: false
padding_idx: 0
encoder_channels: 256
encoder_channels: 512
# decoder
query_position_rate: 1.0
key_position_rate: 1.29
trainable_positional_encodings: false
kernel_size: 3
decoder_channels: 512
decoder_channels: 256
downsample_factor: 4
outputs_per_step: 1
# attention
key_position_rate: true
value_position_rate: true
key_projection: true
value_projection: true
force_monotonic_attention: true
window_backward: -1
window_ahead: 3
@ -88,16 +89,3 @@ train:
snap_interval: 1000
eval_interval: 10000
save_interval: 10000

View File

@ -1,6 +1,6 @@
import os
import argparse
import ruamel.yamls
import ruamel.yaml
import numpy as np
import soundfile as sf
@ -22,6 +22,11 @@ if __name__ == "__main__":
parser.add_argument("checkpoint", type=str, help="checkpoint to load.")
parser.add_argument("text", type=str, help="text file to synthesize")
parser.add_argument("output_path", type=str, help="path to save results")
parser.add_argument("-g",
"--device",
type=int,
default=-1,
help="device to use")
args = parser.parse_args()
with open(args.config, 'rt') as f:
@ -67,7 +72,7 @@ if __name__ == "__main__":
use_memory_mask = model_config["use_memory_mask"]
query_position_rate = model_config["query_position_rate"]
key_position_rate = model_config["key_position_rate"]
window_behind = model_config["window_behind"]
window_backward = model_config["window_backward"]
window_ahead = model_config["window_ahead"]
key_projection = model_config["key_projection"]
value_projection = model_config["value_projection"]
@ -76,11 +81,12 @@ if __name__ == "__main__":
freeze_embedding, filter_size, encoder_channels,
n_mels, decoder_channels, r,
trainable_positional_encodings, use_memory_mask,
query_position_rate, key_position_rate, window_behind,
window_ahead, key_projection, value_projection,
downsample_factor, linear_dim, use_decoder_states,
converter_channels, dropout)
query_position_rate, key_position_rate,
window_backward, window_ahead, key_projection,
value_projection, downsample_factor, linear_dim,
use_decoder_states, converter_channels, dropout)
summary(dv3)
state, _ = dg.load_dygraph(args.checkpoint)
dv3.set_dict(state)

View File

@ -1,6 +1,6 @@
import os
import argparse
import ruamel.yamls
import ruamel.yaml
import numpy as np
from matplotlib import cm
import matplotlib.pyplot as plt
@ -15,10 +15,9 @@ import paddle.fluid.layers as F
import paddle.fluid.dygraph as dg
from parakeet.g2p import en
from parakeet.models.deepvoice3.encoder import ConvSpec
from parakeet.data import FilterDataset, TransformDataset, FilterDataset
from parakeet.data import DataCargo, PartialyRandomizedSimilarTimeLengthSampler, SequentialSampler
from parakeet.models.deepvoice3 import Encoder, Decoder, Converter, DeepVoice3
from parakeet.models.deepvoice3 import Encoder, Decoder, Converter, DeepVoice3, ConvSpec
from parakeet.models.deepvoice3.loss import TTSLoss
from parakeet.utils.layer_tools import summary
@ -128,7 +127,7 @@ if __name__ == "__main__":
use_memory_mask = model_config["use_memory_mask"]
query_position_rate = model_config["query_position_rate"]
key_position_rate = model_config["key_position_rate"]
window_behind = model_config["window_behind"]
window_backward = model_config["window_backward"]
window_ahead = model_config["window_ahead"]
key_projection = model_config["key_projection"]
value_projection = model_config["value_projection"]
@ -137,10 +136,10 @@ if __name__ == "__main__":
freeze_embedding, filter_size, encoder_channels,
n_mels, decoder_channels, r,
trainable_positional_encodings, use_memory_mask,
query_position_rate, key_position_rate, window_behind,
window_ahead, key_projection, value_projection,
downsample_factor, linear_dim, use_decoder_states,
converter_channels, dropout)
query_position_rate, key_position_rate,
window_backward, window_ahead, key_projection,
value_projection, downsample_factor, linear_dim,
use_decoder_states, converter_channels, dropout)
# =========================loss=========================
loss_config = config["loss"]

View File

@ -0,0 +1,4 @@
from .dataset import *
from .datacargo import *
from .sampler import *
from .batch import *

View File

@ -1,4 +1,4 @@
from parakeet.models.deepvoice3.encoder import Encoder
from parakeet.models.deepvoice3.decoder import Decoder
from parakeet.models.deepvoice3.encoder import Encoder, ConvSpec
from parakeet.models.deepvoice3.decoder import Decoder, WindowRange
from parakeet.models.deepvoice3.converter import Converter
from parakeet.models.deepvoice3.model import DeepVoice3