add to_static export for speedyspeech and pwg, at the cost of making lots of comprimises

This commit is contained in:
chenfeiyu 2021-07-21 16:57:35 +08:00
parent 4ba8e7e342
commit 133294340c
10 changed files with 184 additions and 44 deletions

View File

@ -28,6 +28,8 @@ import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle import distributed as dist
from paddle import jit
from paddle.static import InputSpec
from yacs.config import CfgNode
from parakeet.datasets.data_table import DataTable
@ -62,16 +64,35 @@ def evaluate(args, speedyspeech_config, pwg_config):
mu = paddle.to_tensor(mu)
std = paddle.to_tensor(std)
speedyspeech_normalizer = ZScore(mu, std)
speedyspeech_normalizer.eval()
stat = np.load(args.pwg_stat)
mu, std = stat
mu = paddle.to_tensor(mu)
std = paddle.to_tensor(std)
pwg_normalizer = ZScore(mu, std)
pwg_normalizer.eval()
speedyspeech_inference = SpeedySpeechInference(speedyspeech_normalizer,
model)
speedyspeech_inference.eval()
speedyspeech_inference = jit.to_static(
speedyspeech_inference,
input_spec=[
InputSpec(
[-1], dtype=paddle.int64), InputSpec(
[-1], dtype=paddle.int64)
])
paddle.jit.save(speedyspeech_inference,
os.path.join(args.inference_dir, "speedyspeech"))
speedyspeech_inferencce = SpeedySpeechInference(speedyspeech_normalizer,
model)
pwg_inference = PWGInference(pwg_normalizer, vocoder)
pwg_inference.eval()
pwg_inference = jit.to_static(
pwg_inference,
input_spec=[InputSpec(
[-1, 80], dtype=paddle.float32), ])
paddle.jit.save(pwg_inference, os.path.join(args.inference_dir, "pwg"))
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
@ -82,7 +103,7 @@ def evaluate(args, speedyspeech_config, pwg_config):
tones = paddle.to_tensor(datum["tones"])
with paddle.no_grad():
wav = pwg_inference(speedyspeech_inferencce(phones, tones))
wav = pwg_inference(speedyspeech_inference(phones, tones))
sf.write(
output_dir / (utt_id + ".wav"),
wav.numpy(),
@ -97,7 +118,7 @@ def main():
parser.add_argument(
"--speedyspeech-config",
type=str,
help="config file to overwrite default config")
help="config file for speedyspeech.")
parser.add_argument(
"--speedyspeech-checkpoint",
type=str,
@ -108,10 +129,7 @@ def main():
help="mean and standard deviation used to normalize spectrogram when training speedyspeech."
)
parser.add_argument(
"--pwg-config",
type=str,
help="mean and standard deviation used to normalize spectrogram when training speedyspeech."
)
"--pwg-config", type=str, help="config file for parallelwavegan.")
parser.add_argument(
"--pwg-params",
type=str,
@ -123,6 +141,8 @@ def main():
)
parser.add_argument("--test-metadata", type=str, help="test metadata")
parser.add_argument("--output-dir", type=str, help="output dir")
parser.add_argument(
"--inference-dir", type=str, help="dir to save inference models")
parser.add_argument(
"--device", type=str, default="gpu", help="device type to use")
parser.add_argument("--verbose", type=int, default=1, help="verbose")

View File

@ -7,4 +7,5 @@ python synthesize.py \
--pwg-stat=../../parallelwave_gan/baker/dump/train/stats.npy \
--test-metadata=dump/test/norm/metadata.jsonl \
--output-dir=exp/debug/test \
--inference-dir=exp/debug/inference \
--device="gpu"

View File

@ -25,6 +25,8 @@ import paddle
import numpy as np
import soundfile as sf
import paddle
from paddle import jit
from paddle.static import InputSpec
from paddle import nn
from paddle.nn import functional as F
from paddle import distributed as dist
@ -72,9 +74,26 @@ def evaluate(args, speedyspeech_config, pwg_config):
std = paddle.to_tensor(std)
pwg_normalizer = ZScore(mu, std)
speedyspeech_inferencce = SpeedySpeechInference(speedyspeech_normalizer,
model)
speedyspeech_inference = SpeedySpeechInference(speedyspeech_normalizer,
model)
speedyspeech_inference.eval()
speedyspeech_inference = jit.to_static(
speedyspeech_inference,
input_spec=[
InputSpec(
[-1], dtype=paddle.int64), InputSpec(
[-1], dtype=paddle.int64)
])
paddle.jit.save(speedyspeech_inference,
os.path.join(args.inference_dir, "speedyspeech"))
pwg_inference = PWGInference(pwg_normalizer, vocoder)
pwg_inference.eval()
pwg_inference = jit.to_static(
pwg_inference,
input_spec=[InputSpec(
[-1, 80], dtype=paddle.float32), ])
paddle.jit.save(pwg_inference, os.path.join(args.inference_dir, "pwg"))
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
@ -83,7 +102,7 @@ def evaluate(args, speedyspeech_config, pwg_config):
phones, tones = text_analysis(sentence)
with paddle.no_grad():
wav = pwg_inference(speedyspeech_inferencce(phones, tones))
wav = pwg_inference(speedyspeech_inference(phones, tones))
sf.write(
output_dir / (utt_id + ".wav"),
wav.numpy(),
@ -98,7 +117,7 @@ def main():
parser.add_argument(
"--speedyspeech-config",
type=str,
help="config file to overwrite default config")
help="config file for speedyspeech.")
parser.add_argument(
"--speedyspeech-checkpoint",
type=str,
@ -109,10 +128,7 @@ def main():
help="mean and standard deviation used to normalize spectrogram when training speedyspeech."
)
parser.add_argument(
"--pwg-config",
type=str,
help="mean and standard deviation used to normalize spectrogram when training speedyspeech."
)
"--pwg-config", type=str, help="config file for parallelwavegan.")
parser.add_argument(
"--pwg-params",
type=str,
@ -127,6 +143,8 @@ def main():
type=str,
help="text to synthesize, a 'utt_id sentence' pair per line")
parser.add_argument("--output-dir", type=str, help="output dir")
parser.add_argument(
"--inference-dir", type=str, help="dir to save inference models")
parser.add_argument(
"--device", type=str, default="gpu", help="device type to use")
parser.add_argument("--verbose", type=int, default=1, help="verbose")

View File

@ -6,5 +6,6 @@ python synthesize_e2e.py \
--pwg-params=../../parallelwave_gan/baker/converted.pdparams \
--pwg-stat=../../parallelwave_gan/baker/dump/train/stats.npy \
--text=sentences.txt \
--output-dir=exp/e2e \
--output-dir=exp/debug/e2e \
--inference-dir=exp/debug/inference \
--device="gpu"

View File

@ -44,7 +44,7 @@ class Stretch2D(nn.Layer):
self.h_scale = h_scale
self.mode = mode
def forward(self, x: Tensor) -> Tensor:
def forward(self, x):
"""
Parameters
----------
@ -115,7 +115,7 @@ class UpsampleNet(nn.Layer):
nn, nonlinear_activation)(**nonlinear_activation_params)
self.up_layers.append(nonlinear)
def forward(self, c: Tensor) -> Tensor:
def forward(self, c):
"""
Parameters
----------
@ -197,7 +197,7 @@ class ConvInUpsampleNet(nn.Layer):
freq_axis_kernel_size=freq_axis_kernel_size,
use_causal_conv=use_causal_conv)
def forward(self, c: Tensor) -> Tensor:
def forward(self, c):
"""
Parameters
----------
@ -283,7 +283,7 @@ class ResidualBlock(nn.Layer):
self.conv1x1_skip = nn.Conv1D(
gate_out_channels, skip_channels, kernel_size=1, bias_attr=bias)
def forward(self, x: Tensor, c: Tensor) -> Tuple[Tensor, Tensor]:
def forward(self, x, c):
"""
Parameters
----------
@ -439,7 +439,7 @@ class PWGGenerator(nn.Layer):
if use_weight_norm:
self.apply_weight_norm()
def forward(self, x: Tensor, c: Tensor) -> Tensor:
def forward(self, x, c):
"""Generate waveform.
Parameters
@ -492,8 +492,7 @@ class PWGGenerator(nn.Layer):
self.apply(_remove_weight_norm)
def inference(self, c: Optional[Tensor]=None,
x: Optional[Tensor]=None) -> Tensor:
def inference(self, c=None):
"""Waveform generation. This function is used for single instance
inference.
@ -510,17 +509,11 @@ class PWGGenerator(nn.Layer):
Tensor
Shape (T, C_out), the generated waveform
"""
if x is not None:
x = paddle.transpose(x, [1, 0]).unsqueeze(0) # pseudo batch
else:
assert c is not None
x = paddle.randn(
[1, self.in_channels, c.shape[0] * self.upsample_factor])
if c is not None:
c = paddle.transpose(c, [1, 0]).unsqueeze(0) # pseudo batch
c = nn.Pad1D(self.aux_context_window, mode='replicate')(c)
out = self.forward(x, c).squeeze(0).transpose([1, 0])
x = paddle.randn(
[1, self.in_channels, paddle.shape(c)[0] * self.upsample_factor])
c = paddle.transpose(c, [1, 0]).unsqueeze(0) # pseudo batch
c = nn.Pad1D(self.aux_context_window, mode='replicate')(c)
out = self(x, c).squeeze(0).transpose([1, 0])
return out
@ -603,7 +596,7 @@ class PWGDiscriminator(nn.Layer):
if use_weight_norm:
self.apply_weight_norm()
def forward(self, x: Tensor) -> Tensor:
def forward(self, x):
"""
Parameters
----------
@ -730,7 +723,7 @@ class ResidualPWGDiscriminator(nn.Layer):
if use_weight_norm:
self.apply_weight_norm()
def forward(self, x: Tensor) -> Tensor:
def forward(self, x):
"""
Parameters
----------

View File

@ -205,7 +205,19 @@ class SpeedySpeech(nn.Layer):
pred_durations = self.duration_predictor(encodings) # (1, T)
durations_to_expand = paddle.round(pred_durations.exp())
durations_to_expand = (durations_to_expand).astype(paddle.int64)
encodings = expand(encodings, durations_to_expand)
slens = paddle.sum(durations_to_expand, -1) # [1]
t_dec = slens[0] # [1]
t_enc = paddle.shape(pred_durations)[-1]
M = paddle.zeros([1, t_dec, t_enc])
k = paddle.full([1], 0, dtype=paddle.int64)
for j in range(t_enc):
d = durations_to_expand[0, j]
M[0, k:k + d, j] = 1
k += d
encodings = paddle.matmul(M, encodings)
shape = paddle.shape(encodings)
t_dec, feature_size = shape[1], shape[2]

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
@ -23,7 +24,11 @@ class ZScore(nn.Layer):
self.register_buffer("sigma", sigma)
def forward(self, x):
return (x - self.mu) / self.sigma
# NOTE: to be compatible with paddle's to_static, we must explicitly
# call multiply, or add, etc, instead of +-*/, etc.
return paddle.divide(paddle.subtract(x, self.mu), self.sigma)
def inverse(self, x):
return x * self.sigma + self.mu
# NOTE: to be compatible with paddle's to_static, we must explicitly
# call multiply, or add, etc, instead of +-*/, etc.
return paddle.add(paddle.multiply(x, self.sigma), self.mu)

View File

@ -26,10 +26,12 @@ def sinusoid_position_encoding(num_positions: int,
feature_size: int,
omega: float=1.0,
start_pos: int=0,
dtype=None) -> Tensor:
dtype=None) -> paddle.Tensor:
# return tensor shape (num_positions, feature_size)
if (feature_size % 2 != 0):
raise ValueError("size should be divisible by 2")
# NOTE: to be compatible with paddle's to_static, we cannnot raise
# an exception here, take care of it by yourself
# if (feature_size % 2 != 0):
# raise ValueError("size should be divisible by 2")
dtype = dtype or paddle.get_default_dtype()
channel = paddle.arange(0, feature_size, 2, dtype=dtype)

57
tests/test_raise.py Normal file
View File

@ -0,0 +1,57 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import paddle
from paddle import Tensor
from paddle.static import InputSpec
from paddle.nn import functional as F
def sinusoid_position_encoding(num_positions: int,
feature_size: int,
omega: float=1.0,
start_pos: int=0,
dtype=None) -> paddle.Tensor:
# return tensor shape (num_positions, feature_size)
if (feature_size % 2 != 0):
raise ValueError("size should be divisible by 2")
dtype = dtype or paddle.get_default_dtype()
channel = paddle.arange(0, feature_size, 2, dtype=dtype)
index = paddle.arange(start_pos, start_pos + num_positions, 1, dtype=dtype)
p = (paddle.unsqueeze(index, -1) *
omega) / (10000.0**(channel / float(feature_size)))
encodings = paddle.zeros([num_positions, feature_size], dtype=dtype)
encodings[:, 0::2] = paddle.sin(p)
encodings[:, 1::2] = paddle.cos(p)
return encodings
def call_it(x):
shape = paddle.shape(x)
a = shape[0]
b = shape[1]
c = sinusoid_position_encoding(a, b)
return c
call_it(paddle.randn([8, 32]))
m = paddle.jit.to_static(
call_it, input_spec=[InputSpec(
[-1, -1], dtype=paddle.int32)])
m(paddle.randn([8, 32]).astype(paddle.int32))

View File

@ -15,7 +15,8 @@
import math
import paddle
from paddle.jit import to_static
from paddle import nn
from paddle.jit import to_static, save
from paddle.static import InputSpec
@ -32,3 +33,33 @@ def test_applicative_evaluation():
print(x)
print(y)
def test_nested_sequential():
class Net(nn.Layer):
def __init__(self):
super().__init__()
group1 = nn.Sequential(
nn.Linear(2, 3),
nn.Sigmoid(), )
group2 = nn.Sequential(
nn.Sequential(nn.Linear(3, 3)),
nn.Linear(3, 4),
nn.ReLU(), )
self.layers = nn.Sequential(group1, group2)
def forward(self, x):
return self.layers(x)
net = Net()
x = paddle.randn([4, 2])
y = net(x)
print(y)
subgraph = to_static(net, input_spec=[InputSpec([-1, 2])])
paddle.jit.save(subgraph, './temp_test_to_static')
fn = paddle.jit.load('./temp_test_to_static')
y = fn(x)
print(y)