add to_static export for speedyspeech and pwg, at the cost of making lots of comprimises
This commit is contained in:
parent
4ba8e7e342
commit
133294340c
|
@ -28,6 +28,8 @@ import paddle
|
||||||
from paddle import nn
|
from paddle import nn
|
||||||
from paddle.nn import functional as F
|
from paddle.nn import functional as F
|
||||||
from paddle import distributed as dist
|
from paddle import distributed as dist
|
||||||
|
from paddle import jit
|
||||||
|
from paddle.static import InputSpec
|
||||||
from yacs.config import CfgNode
|
from yacs.config import CfgNode
|
||||||
|
|
||||||
from parakeet.datasets.data_table import DataTable
|
from parakeet.datasets.data_table import DataTable
|
||||||
|
@ -62,16 +64,35 @@ def evaluate(args, speedyspeech_config, pwg_config):
|
||||||
mu = paddle.to_tensor(mu)
|
mu = paddle.to_tensor(mu)
|
||||||
std = paddle.to_tensor(std)
|
std = paddle.to_tensor(std)
|
||||||
speedyspeech_normalizer = ZScore(mu, std)
|
speedyspeech_normalizer = ZScore(mu, std)
|
||||||
|
speedyspeech_normalizer.eval()
|
||||||
|
|
||||||
stat = np.load(args.pwg_stat)
|
stat = np.load(args.pwg_stat)
|
||||||
mu, std = stat
|
mu, std = stat
|
||||||
mu = paddle.to_tensor(mu)
|
mu = paddle.to_tensor(mu)
|
||||||
std = paddle.to_tensor(std)
|
std = paddle.to_tensor(std)
|
||||||
pwg_normalizer = ZScore(mu, std)
|
pwg_normalizer = ZScore(mu, std)
|
||||||
|
pwg_normalizer.eval()
|
||||||
|
|
||||||
speedyspeech_inferencce = SpeedySpeechInference(speedyspeech_normalizer,
|
speedyspeech_inference = SpeedySpeechInference(speedyspeech_normalizer,
|
||||||
model)
|
model)
|
||||||
|
speedyspeech_inference.eval()
|
||||||
|
speedyspeech_inference = jit.to_static(
|
||||||
|
speedyspeech_inference,
|
||||||
|
input_spec=[
|
||||||
|
InputSpec(
|
||||||
|
[-1], dtype=paddle.int64), InputSpec(
|
||||||
|
[-1], dtype=paddle.int64)
|
||||||
|
])
|
||||||
|
paddle.jit.save(speedyspeech_inference,
|
||||||
|
os.path.join(args.inference_dir, "speedyspeech"))
|
||||||
|
|
||||||
pwg_inference = PWGInference(pwg_normalizer, vocoder)
|
pwg_inference = PWGInference(pwg_normalizer, vocoder)
|
||||||
|
pwg_inference.eval()
|
||||||
|
pwg_inference = jit.to_static(
|
||||||
|
pwg_inference,
|
||||||
|
input_spec=[InputSpec(
|
||||||
|
[-1, 80], dtype=paddle.float32), ])
|
||||||
|
paddle.jit.save(pwg_inference, os.path.join(args.inference_dir, "pwg"))
|
||||||
|
|
||||||
output_dir = Path(args.output_dir)
|
output_dir = Path(args.output_dir)
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
@ -82,7 +103,7 @@ def evaluate(args, speedyspeech_config, pwg_config):
|
||||||
tones = paddle.to_tensor(datum["tones"])
|
tones = paddle.to_tensor(datum["tones"])
|
||||||
|
|
||||||
with paddle.no_grad():
|
with paddle.no_grad():
|
||||||
wav = pwg_inference(speedyspeech_inferencce(phones, tones))
|
wav = pwg_inference(speedyspeech_inference(phones, tones))
|
||||||
sf.write(
|
sf.write(
|
||||||
output_dir / (utt_id + ".wav"),
|
output_dir / (utt_id + ".wav"),
|
||||||
wav.numpy(),
|
wav.numpy(),
|
||||||
|
@ -97,7 +118,7 @@ def main():
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--speedyspeech-config",
|
"--speedyspeech-config",
|
||||||
type=str,
|
type=str,
|
||||||
help="config file to overwrite default config")
|
help="config file for speedyspeech.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--speedyspeech-checkpoint",
|
"--speedyspeech-checkpoint",
|
||||||
type=str,
|
type=str,
|
||||||
|
@ -108,10 +129,7 @@ def main():
|
||||||
help="mean and standard deviation used to normalize spectrogram when training speedyspeech."
|
help="mean and standard deviation used to normalize spectrogram when training speedyspeech."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--pwg-config",
|
"--pwg-config", type=str, help="config file for parallelwavegan.")
|
||||||
type=str,
|
|
||||||
help="mean and standard deviation used to normalize spectrogram when training speedyspeech."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--pwg-params",
|
"--pwg-params",
|
||||||
type=str,
|
type=str,
|
||||||
|
@ -123,6 +141,8 @@ def main():
|
||||||
)
|
)
|
||||||
parser.add_argument("--test-metadata", type=str, help="test metadata")
|
parser.add_argument("--test-metadata", type=str, help="test metadata")
|
||||||
parser.add_argument("--output-dir", type=str, help="output dir")
|
parser.add_argument("--output-dir", type=str, help="output dir")
|
||||||
|
parser.add_argument(
|
||||||
|
"--inference-dir", type=str, help="dir to save inference models")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--device", type=str, default="gpu", help="device type to use")
|
"--device", type=str, default="gpu", help="device type to use")
|
||||||
parser.add_argument("--verbose", type=int, default=1, help="verbose")
|
parser.add_argument("--verbose", type=int, default=1, help="verbose")
|
||||||
|
|
|
@ -7,4 +7,5 @@ python synthesize.py \
|
||||||
--pwg-stat=../../parallelwave_gan/baker/dump/train/stats.npy \
|
--pwg-stat=../../parallelwave_gan/baker/dump/train/stats.npy \
|
||||||
--test-metadata=dump/test/norm/metadata.jsonl \
|
--test-metadata=dump/test/norm/metadata.jsonl \
|
||||||
--output-dir=exp/debug/test \
|
--output-dir=exp/debug/test \
|
||||||
|
--inference-dir=exp/debug/inference \
|
||||||
--device="gpu"
|
--device="gpu"
|
||||||
|
|
|
@ -25,6 +25,8 @@ import paddle
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
import paddle
|
import paddle
|
||||||
|
from paddle import jit
|
||||||
|
from paddle.static import InputSpec
|
||||||
from paddle import nn
|
from paddle import nn
|
||||||
from paddle.nn import functional as F
|
from paddle.nn import functional as F
|
||||||
from paddle import distributed as dist
|
from paddle import distributed as dist
|
||||||
|
@ -72,9 +74,26 @@ def evaluate(args, speedyspeech_config, pwg_config):
|
||||||
std = paddle.to_tensor(std)
|
std = paddle.to_tensor(std)
|
||||||
pwg_normalizer = ZScore(mu, std)
|
pwg_normalizer = ZScore(mu, std)
|
||||||
|
|
||||||
speedyspeech_inferencce = SpeedySpeechInference(speedyspeech_normalizer,
|
speedyspeech_inference = SpeedySpeechInference(speedyspeech_normalizer,
|
||||||
model)
|
model)
|
||||||
|
speedyspeech_inference.eval()
|
||||||
|
speedyspeech_inference = jit.to_static(
|
||||||
|
speedyspeech_inference,
|
||||||
|
input_spec=[
|
||||||
|
InputSpec(
|
||||||
|
[-1], dtype=paddle.int64), InputSpec(
|
||||||
|
[-1], dtype=paddle.int64)
|
||||||
|
])
|
||||||
|
paddle.jit.save(speedyspeech_inference,
|
||||||
|
os.path.join(args.inference_dir, "speedyspeech"))
|
||||||
|
|
||||||
pwg_inference = PWGInference(pwg_normalizer, vocoder)
|
pwg_inference = PWGInference(pwg_normalizer, vocoder)
|
||||||
|
pwg_inference.eval()
|
||||||
|
pwg_inference = jit.to_static(
|
||||||
|
pwg_inference,
|
||||||
|
input_spec=[InputSpec(
|
||||||
|
[-1, 80], dtype=paddle.float32), ])
|
||||||
|
paddle.jit.save(pwg_inference, os.path.join(args.inference_dir, "pwg"))
|
||||||
|
|
||||||
output_dir = Path(args.output_dir)
|
output_dir = Path(args.output_dir)
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
@ -83,7 +102,7 @@ def evaluate(args, speedyspeech_config, pwg_config):
|
||||||
phones, tones = text_analysis(sentence)
|
phones, tones = text_analysis(sentence)
|
||||||
|
|
||||||
with paddle.no_grad():
|
with paddle.no_grad():
|
||||||
wav = pwg_inference(speedyspeech_inferencce(phones, tones))
|
wav = pwg_inference(speedyspeech_inference(phones, tones))
|
||||||
sf.write(
|
sf.write(
|
||||||
output_dir / (utt_id + ".wav"),
|
output_dir / (utt_id + ".wav"),
|
||||||
wav.numpy(),
|
wav.numpy(),
|
||||||
|
@ -98,7 +117,7 @@ def main():
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--speedyspeech-config",
|
"--speedyspeech-config",
|
||||||
type=str,
|
type=str,
|
||||||
help="config file to overwrite default config")
|
help="config file for speedyspeech.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--speedyspeech-checkpoint",
|
"--speedyspeech-checkpoint",
|
||||||
type=str,
|
type=str,
|
||||||
|
@ -109,10 +128,7 @@ def main():
|
||||||
help="mean and standard deviation used to normalize spectrogram when training speedyspeech."
|
help="mean and standard deviation used to normalize spectrogram when training speedyspeech."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--pwg-config",
|
"--pwg-config", type=str, help="config file for parallelwavegan.")
|
||||||
type=str,
|
|
||||||
help="mean and standard deviation used to normalize spectrogram when training speedyspeech."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--pwg-params",
|
"--pwg-params",
|
||||||
type=str,
|
type=str,
|
||||||
|
@ -127,6 +143,8 @@ def main():
|
||||||
type=str,
|
type=str,
|
||||||
help="text to synthesize, a 'utt_id sentence' pair per line")
|
help="text to synthesize, a 'utt_id sentence' pair per line")
|
||||||
parser.add_argument("--output-dir", type=str, help="output dir")
|
parser.add_argument("--output-dir", type=str, help="output dir")
|
||||||
|
parser.add_argument(
|
||||||
|
"--inference-dir", type=str, help="dir to save inference models")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--device", type=str, default="gpu", help="device type to use")
|
"--device", type=str, default="gpu", help="device type to use")
|
||||||
parser.add_argument("--verbose", type=int, default=1, help="verbose")
|
parser.add_argument("--verbose", type=int, default=1, help="verbose")
|
||||||
|
|
|
@ -6,5 +6,6 @@ python synthesize_e2e.py \
|
||||||
--pwg-params=../../parallelwave_gan/baker/converted.pdparams \
|
--pwg-params=../../parallelwave_gan/baker/converted.pdparams \
|
||||||
--pwg-stat=../../parallelwave_gan/baker/dump/train/stats.npy \
|
--pwg-stat=../../parallelwave_gan/baker/dump/train/stats.npy \
|
||||||
--text=sentences.txt \
|
--text=sentences.txt \
|
||||||
--output-dir=exp/e2e \
|
--output-dir=exp/debug/e2e \
|
||||||
|
--inference-dir=exp/debug/inference \
|
||||||
--device="gpu"
|
--device="gpu"
|
||||||
|
|
|
@ -44,7 +44,7 @@ class Stretch2D(nn.Layer):
|
||||||
self.h_scale = h_scale
|
self.h_scale = h_scale
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x):
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
@ -115,7 +115,7 @@ class UpsampleNet(nn.Layer):
|
||||||
nn, nonlinear_activation)(**nonlinear_activation_params)
|
nn, nonlinear_activation)(**nonlinear_activation_params)
|
||||||
self.up_layers.append(nonlinear)
|
self.up_layers.append(nonlinear)
|
||||||
|
|
||||||
def forward(self, c: Tensor) -> Tensor:
|
def forward(self, c):
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
@ -197,7 +197,7 @@ class ConvInUpsampleNet(nn.Layer):
|
||||||
freq_axis_kernel_size=freq_axis_kernel_size,
|
freq_axis_kernel_size=freq_axis_kernel_size,
|
||||||
use_causal_conv=use_causal_conv)
|
use_causal_conv=use_causal_conv)
|
||||||
|
|
||||||
def forward(self, c: Tensor) -> Tensor:
|
def forward(self, c):
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
@ -283,7 +283,7 @@ class ResidualBlock(nn.Layer):
|
||||||
self.conv1x1_skip = nn.Conv1D(
|
self.conv1x1_skip = nn.Conv1D(
|
||||||
gate_out_channels, skip_channels, kernel_size=1, bias_attr=bias)
|
gate_out_channels, skip_channels, kernel_size=1, bias_attr=bias)
|
||||||
|
|
||||||
def forward(self, x: Tensor, c: Tensor) -> Tuple[Tensor, Tensor]:
|
def forward(self, x, c):
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
@ -439,7 +439,7 @@ class PWGGenerator(nn.Layer):
|
||||||
if use_weight_norm:
|
if use_weight_norm:
|
||||||
self.apply_weight_norm()
|
self.apply_weight_norm()
|
||||||
|
|
||||||
def forward(self, x: Tensor, c: Tensor) -> Tensor:
|
def forward(self, x, c):
|
||||||
"""Generate waveform.
|
"""Generate waveform.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
|
@ -492,8 +492,7 @@ class PWGGenerator(nn.Layer):
|
||||||
|
|
||||||
self.apply(_remove_weight_norm)
|
self.apply(_remove_weight_norm)
|
||||||
|
|
||||||
def inference(self, c: Optional[Tensor]=None,
|
def inference(self, c=None):
|
||||||
x: Optional[Tensor]=None) -> Tensor:
|
|
||||||
"""Waveform generation. This function is used for single instance
|
"""Waveform generation. This function is used for single instance
|
||||||
inference.
|
inference.
|
||||||
|
|
||||||
|
@ -510,17 +509,11 @@ class PWGGenerator(nn.Layer):
|
||||||
Tensor
|
Tensor
|
||||||
Shape (T, C_out), the generated waveform
|
Shape (T, C_out), the generated waveform
|
||||||
"""
|
"""
|
||||||
if x is not None:
|
|
||||||
x = paddle.transpose(x, [1, 0]).unsqueeze(0) # pseudo batch
|
|
||||||
else:
|
|
||||||
assert c is not None
|
|
||||||
x = paddle.randn(
|
x = paddle.randn(
|
||||||
[1, self.in_channels, c.shape[0] * self.upsample_factor])
|
[1, self.in_channels, paddle.shape(c)[0] * self.upsample_factor])
|
||||||
|
|
||||||
if c is not None:
|
|
||||||
c = paddle.transpose(c, [1, 0]).unsqueeze(0) # pseudo batch
|
c = paddle.transpose(c, [1, 0]).unsqueeze(0) # pseudo batch
|
||||||
c = nn.Pad1D(self.aux_context_window, mode='replicate')(c)
|
c = nn.Pad1D(self.aux_context_window, mode='replicate')(c)
|
||||||
out = self.forward(x, c).squeeze(0).transpose([1, 0])
|
out = self(x, c).squeeze(0).transpose([1, 0])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@ -603,7 +596,7 @@ class PWGDiscriminator(nn.Layer):
|
||||||
if use_weight_norm:
|
if use_weight_norm:
|
||||||
self.apply_weight_norm()
|
self.apply_weight_norm()
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x):
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
@ -730,7 +723,7 @@ class ResidualPWGDiscriminator(nn.Layer):
|
||||||
if use_weight_norm:
|
if use_weight_norm:
|
||||||
self.apply_weight_norm()
|
self.apply_weight_norm()
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x):
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
|
|
@ -205,7 +205,19 @@ class SpeedySpeech(nn.Layer):
|
||||||
pred_durations = self.duration_predictor(encodings) # (1, T)
|
pred_durations = self.duration_predictor(encodings) # (1, T)
|
||||||
durations_to_expand = paddle.round(pred_durations.exp())
|
durations_to_expand = paddle.round(pred_durations.exp())
|
||||||
durations_to_expand = (durations_to_expand).astype(paddle.int64)
|
durations_to_expand = (durations_to_expand).astype(paddle.int64)
|
||||||
encodings = expand(encodings, durations_to_expand)
|
|
||||||
|
slens = paddle.sum(durations_to_expand, -1) # [1]
|
||||||
|
t_dec = slens[0] # [1]
|
||||||
|
t_enc = paddle.shape(pred_durations)[-1]
|
||||||
|
M = paddle.zeros([1, t_dec, t_enc])
|
||||||
|
|
||||||
|
k = paddle.full([1], 0, dtype=paddle.int64)
|
||||||
|
for j in range(t_enc):
|
||||||
|
d = durations_to_expand[0, j]
|
||||||
|
M[0, k:k + d, j] = 1
|
||||||
|
k += d
|
||||||
|
|
||||||
|
encodings = paddle.matmul(M, encodings)
|
||||||
|
|
||||||
shape = paddle.shape(encodings)
|
shape = paddle.shape(encodings)
|
||||||
t_dec, feature_size = shape[1], shape[2]
|
t_dec, feature_size = shape[1], shape[2]
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import paddle
|
||||||
from paddle import nn
|
from paddle import nn
|
||||||
|
|
||||||
|
|
||||||
|
@ -23,7 +24,11 @@ class ZScore(nn.Layer):
|
||||||
self.register_buffer("sigma", sigma)
|
self.register_buffer("sigma", sigma)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return (x - self.mu) / self.sigma
|
# NOTE: to be compatible with paddle's to_static, we must explicitly
|
||||||
|
# call multiply, or add, etc, instead of +-*/, etc.
|
||||||
|
return paddle.divide(paddle.subtract(x, self.mu), self.sigma)
|
||||||
|
|
||||||
def inverse(self, x):
|
def inverse(self, x):
|
||||||
return x * self.sigma + self.mu
|
# NOTE: to be compatible with paddle's to_static, we must explicitly
|
||||||
|
# call multiply, or add, etc, instead of +-*/, etc.
|
||||||
|
return paddle.add(paddle.multiply(x, self.sigma), self.mu)
|
||||||
|
|
|
@ -26,10 +26,12 @@ def sinusoid_position_encoding(num_positions: int,
|
||||||
feature_size: int,
|
feature_size: int,
|
||||||
omega: float=1.0,
|
omega: float=1.0,
|
||||||
start_pos: int=0,
|
start_pos: int=0,
|
||||||
dtype=None) -> Tensor:
|
dtype=None) -> paddle.Tensor:
|
||||||
# return tensor shape (num_positions, feature_size)
|
# return tensor shape (num_positions, feature_size)
|
||||||
if (feature_size % 2 != 0):
|
# NOTE: to be compatible with paddle's to_static, we cannnot raise
|
||||||
raise ValueError("size should be divisible by 2")
|
# an exception here, take care of it by yourself
|
||||||
|
# if (feature_size % 2 != 0):
|
||||||
|
# raise ValueError("size should be divisible by 2")
|
||||||
dtype = dtype or paddle.get_default_dtype()
|
dtype = dtype or paddle.get_default_dtype()
|
||||||
|
|
||||||
channel = paddle.arange(0, feature_size, 2, dtype=dtype)
|
channel = paddle.arange(0, feature_size, 2, dtype=dtype)
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import Tensor
|
||||||
|
from paddle.static import InputSpec
|
||||||
|
from paddle.nn import functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def sinusoid_position_encoding(num_positions: int,
|
||||||
|
feature_size: int,
|
||||||
|
omega: float=1.0,
|
||||||
|
start_pos: int=0,
|
||||||
|
dtype=None) -> paddle.Tensor:
|
||||||
|
# return tensor shape (num_positions, feature_size)
|
||||||
|
|
||||||
|
if (feature_size % 2 != 0):
|
||||||
|
raise ValueError("size should be divisible by 2")
|
||||||
|
dtype = dtype or paddle.get_default_dtype()
|
||||||
|
|
||||||
|
channel = paddle.arange(0, feature_size, 2, dtype=dtype)
|
||||||
|
index = paddle.arange(start_pos, start_pos + num_positions, 1, dtype=dtype)
|
||||||
|
p = (paddle.unsqueeze(index, -1) *
|
||||||
|
omega) / (10000.0**(channel / float(feature_size)))
|
||||||
|
encodings = paddle.zeros([num_positions, feature_size], dtype=dtype)
|
||||||
|
encodings[:, 0::2] = paddle.sin(p)
|
||||||
|
encodings[:, 1::2] = paddle.cos(p)
|
||||||
|
return encodings
|
||||||
|
|
||||||
|
|
||||||
|
def call_it(x):
|
||||||
|
shape = paddle.shape(x)
|
||||||
|
a = shape[0]
|
||||||
|
b = shape[1]
|
||||||
|
c = sinusoid_position_encoding(a, b)
|
||||||
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
call_it(paddle.randn([8, 32]))
|
||||||
|
m = paddle.jit.to_static(
|
||||||
|
call_it, input_spec=[InputSpec(
|
||||||
|
[-1, -1], dtype=paddle.int32)])
|
||||||
|
m(paddle.randn([8, 32]).astype(paddle.int32))
|
|
@ -15,7 +15,8 @@
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
from paddle.jit import to_static
|
from paddle import nn
|
||||||
|
from paddle.jit import to_static, save
|
||||||
from paddle.static import InputSpec
|
from paddle.static import InputSpec
|
||||||
|
|
||||||
|
|
||||||
|
@ -32,3 +33,33 @@ def test_applicative_evaluation():
|
||||||
|
|
||||||
print(x)
|
print(x)
|
||||||
print(y)
|
print(y)
|
||||||
|
|
||||||
|
|
||||||
|
def test_nested_sequential():
|
||||||
|
class Net(nn.Layer):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
group1 = nn.Sequential(
|
||||||
|
nn.Linear(2, 3),
|
||||||
|
nn.Sigmoid(), )
|
||||||
|
group2 = nn.Sequential(
|
||||||
|
nn.Sequential(nn.Linear(3, 3)),
|
||||||
|
nn.Linear(3, 4),
|
||||||
|
nn.ReLU(), )
|
||||||
|
self.layers = nn.Sequential(group1, group2)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.layers(x)
|
||||||
|
|
||||||
|
net = Net()
|
||||||
|
x = paddle.randn([4, 2])
|
||||||
|
y = net(x)
|
||||||
|
print(y)
|
||||||
|
|
||||||
|
subgraph = to_static(net, input_spec=[InputSpec([-1, 2])])
|
||||||
|
paddle.jit.save(subgraph, './temp_test_to_static')
|
||||||
|
|
||||||
|
fn = paddle.jit.load('./temp_test_to_static')
|
||||||
|
y = fn(x)
|
||||||
|
|
||||||
|
print(y)
|
||||||
|
|
Loading…
Reference in New Issue