clean code: remove deprecated modules
parent 5270774bb0
commit 53d0382fc7
@@ -13,5 +13,6 @@
# limitations under the License.

from .dataset import *
from .datasets import *
from .sampler import *
from .batch import *
@@ -1,272 +0,0 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle import fluid
import paddle.fluid.layers as F
import paddle.fluid.dygraph as dg


class Pool1D(dg.Layer):
    """
    A Pool 1D block implemented with Pool2D.
    """

    def __init__(self,
                 pool_size=-1,
                 pool_type='max',
                 pool_stride=1,
                 pool_padding=0,
                 global_pooling=False,
                 use_cudnn=True,
                 ceil_mode=False,
                 exclusive=True,
                 data_format='NCT'):
        super(Pool1D, self).__init__()
        self.pool_size = pool_size
        self.pool_type = pool_type
        self.pool_stride = pool_stride
        self.pool_padding = pool_padding
        self.global_pooling = global_pooling
        self.use_cudnn = use_cudnn
        self.ceil_mode = ceil_mode
        self.exclusive = exclusive
        self.data_format = data_format

        self.pool2d = dg.Pool2D(
            [1, pool_size],
            pool_type=pool_type,
            pool_stride=[1, pool_stride],
            pool_padding=[0, pool_padding],
            global_pooling=global_pooling,
            use_cudnn=use_cudnn,
            ceil_mode=ceil_mode,
            exclusive=exclusive)

    def forward(self, x):
        """
        Args:
            x (Variable): Shape(B, C_in, 1, T), the input, where C_in means
                input channels.
        Returns:
            x (Variable): Shape(B, C_out, 1, T), the outputs, where C_out means
                output channels (num_filters).
        """
        if self.data_format == 'NTC':
            x = fluid.layers.transpose(x, [0, 2, 1])
        x = fluid.layers.unsqueeze(x, [2])
        x = self.pool2d(x)
        x = fluid.layers.squeeze(x, [2])
        if self.data_format == 'NTC':
            x = fluid.layers.transpose(x, [0, 2, 1])
        return x


class Conv1D(dg.Conv2D):
    """A standard Conv1D layer that uses the (B, C, T) data layout. It inherits Conv2D and
    uses the (B, C, 1, T) data layout to compute 1D convolution. Nothing more.
    NOTE: we inherit Conv2D instead of encapsulating a Conv2D layer to keep it a simple
    layer instead of a composite one, so we can easily apply weight norm to it.
    """

    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 param_attr=None,
                 bias_attr=None,
                 use_cudnn=True,
                 act=None,
                 dtype='float32'):
        super(Conv1D, self).__init__(
            num_channels,
            num_filters, (1, filter_size),
            stride=(1, stride),
            padding=(0, padding),
            dilation=(1, dilation),
            groups=groups,
            param_attr=param_attr,
            bias_attr=bias_attr,
            use_cudnn=use_cudnn,
            act=act,
            dtype=dtype)

    def forward(self, x):
        """Compute Conv1D by unsqueezing the input and squeezing the output.

        Args:
            x (Variable): shape(B, C_in, T_in), dtype float32, input of Conv1D.

        Returns:
            Variable: shape(B, C_out, T_out), dtype float32, output of Conv1D.
        """
        x = F.unsqueeze(x, [2])
        x = super(Conv1D, self).forward(x)  # maybe risky here
        x = F.squeeze(x, [2])
        return x


class Conv1DTranspose(dg.Conv2DTranspose):
    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 padding=0,
                 stride=1,
                 dilation=1,
                 groups=1,
                 param_attr=None,
                 bias_attr=None,
                 use_cudnn=True,
                 act=None,
                 dtype='float32'):
        super(Conv1DTranspose, self).__init__(
            num_channels,
            num_filters, (1, filter_size),
            output_size=None,
            padding=(0, padding),
            stride=(1, stride),
            dilation=(1, dilation),
            groups=groups,
            param_attr=param_attr,
            bias_attr=bias_attr,
            use_cudnn=use_cudnn,
            act=act,
            dtype=dtype)

    def forward(self, x):
        """Compute Conv1DTranspose by unsqueezing the input and squeezing the output.

        Args:
            x (Variable): shape(B, C_in, T_in), dtype float32, input of Conv1DTranspose.

        Returns:
            Variable: shape(B, C_out, T_out), dtype float32, output of Conv1DTranspose.
        """
        x = F.unsqueeze(x, [2])
        x = super(Conv1DTranspose, self).forward(x)  # maybe risky here
        x = F.squeeze(x, [2])
        return x


class Conv1DCell(Conv1D):
    """A causal conv1d cell. It uses causal padding, i.e. padding (receptive_field - 1, 0).
    But Conv2D in dygraph does not support asymmetric padding yet, so we just pad
    (receptive_field - 1, receptive_field - 1) and drop the last receptive_field - 1 steps of
    the output.

    It is a cell in the sense that it acts like an RNN cell. It does not support stride > 1,
    and it ensures a 1-to-1 mapping from input time steps to output time steps.
    """

    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 dilation=1,
                 causal=False,
                 groups=1,
                 param_attr=None,
                 bias_attr=None,
                 use_cudnn=True,
                 act=None,
                 dtype='float32'):
        receptive_field = 1 + dilation * (filter_size - 1)
        padding = receptive_field - 1 if causal else receptive_field // 2
        self._receptive_field = receptive_field
        self.causal = causal
        super(Conv1DCell, self).__init__(
            num_channels,
            num_filters,
            filter_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            groups=groups,
            param_attr=param_attr,
            bias_attr=bias_attr,
            use_cudnn=use_cudnn,
            act=act,
            dtype=dtype)

    def forward(self, x):
        """Compute Conv1D by unsqueezing the input and squeezing the output.

        Args:
            x (Variable): shape(B, C_in, T), dtype float32, input of Conv1D.

        Returns:
            Variable: shape(B, C_out, T), dtype float32, output of Conv1D.
        """
        # it ensures that output time steps == input time steps
        time_steps = x.shape[-1]
        x = super(Conv1DCell, self).forward(x)
        if x.shape[-1] != time_steps:
            x = x[:, :, :time_steps]
        return x

    @property
    def receptive_field(self):
        return self._receptive_field

    def start_sequence(self):
        """Prepare the Conv1DCell to generate a new sequence. This method should be called before calling add_input multiple times.

        WARNING:
            This method accesses `self.weight` directly. If a `Conv1DCell` object is wrapped in a `WeightNormWrapper`, make sure this method is called only after the `WeightNormWrapper`'s hook is called.
            `WeightNormWrapper` removes the wrapped layer's `weight` and adds `weight_v` and `weight_g` to re-compute the wrapped layer's weight as $weight = weight_g * weight_v / ||weight_v||$. (Recomputing the `weight` is a hook run before calling the wrapped layer's `forward` method.)
            Whenever a `WeightNormWrapper`'s `forward` method is called, the wrapped layer's weight is updated. But when loading from a checkpoint, `weight_v` and `weight_g` are updated while the wrapped layer's weight is not, since it is no longer a `Parameter`. You should manually call `remove_weight_norm` or `hook` to re-compute the wrapped layer's weight before calling this method if you don't call `forward` first.
            So when loading a model which uses `Conv1DCell` objects wrapped in `WeightNormWrapper`s, remember to call `remove_weight_norm` for all `WeightNormWrapper`s before synthesizing. Also, removing weight norm speeds up computation.
        """
        if not self.causal:
            raise ValueError(
                "Only causal conv1d cells should use start_sequence")
        if self.receptive_field == 1:
            raise ValueError(
                "Convolution block with receptive field = 1 does not need"
                " to be implemented as a Conv1DCell. Conv1D suffices")
        self._buffer = None
        self._reshaped_weight = F.reshape(self.weight, (self._num_filters, -1))

    def add_input(self, x_t):
        """This method works similarly to forward but in a `step-in-step-out` fashion.

        Args:
            x_t (Variable): shape(B, C_in, T=1), dtype float32, input of Conv1D.

        Returns:
            Variable: shape(B, C_out, T=1), dtype float32, output of Conv1D.
        """
        batch_size, c_in, _ = x_t.shape
        if self._buffer is None:
            self._buffer = F.zeros(
                (batch_size, c_in, self.receptive_field), dtype=x_t.dtype)
        self._buffer = F.concat([self._buffer[:, :, 1:], x_t], -1)
        if self._dilation[1] > 1:
            input = F.strided_slice(
                self._buffer,
                axes=[2],
                starts=[0],
                ends=[self.receptive_field],
                strides=[self._dilation[1]])
        else:
            input = self._buffer
        input = F.reshape(input, (batch_size, -1))
        y_t = F.matmul(input, self._reshaped_weight, transpose_y=True)
        y_t = y_t + self.bias
        y_t = F.unsqueeze(y_t, [-1])
        return y_t
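
Editor's note: a usage sketch for the removed Conv1DCell, which supports whole-sequence forward() for training and step-by-step add_input() for autoregressive synthesis. This is a minimal illustration, not part of the commit: it assumes a Paddle 1.x environment where paddle.fluid.dygraph is still available, the layer sizes are made-up values, and the import path follows the parakeet.modules.customized module referenced elsewhere in this diff.

import numpy as np
import paddle.fluid.dygraph as dg

from parakeet.modules.customized import Conv1DCell  # module removed by this commit

with dg.guard():
    # a causal cell: receptive_field = 1 + dilation * (filter_size - 1) = 3
    cell = Conv1DCell(num_channels=4, num_filters=8, filter_size=2,
                      dilation=2, causal=True)

    # training / teacher-forced mode: one call over the whole (B, C, T) sequence
    x = dg.to_variable(np.random.randn(1, 4, 10).astype("float32"))
    y = cell(x)  # shape (1, 8, 10), same number of time steps as the input

    # autoregressive mode: reset the internal buffer, then feed one step at a time
    cell.start_sequence()
    for t in range(10):
        x_t = dg.to_variable(np.random.randn(1, 4, 1).astype("float32"))
        y_t = cell.add_input(x_t)  # shape (1, 8, 1)
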
@@ -1,64 +0,0 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as layers


class DynamicGRU(dg.Layer):
    def __init__(self,
                 size,
                 param_attr=None,
                 bias_attr=None,
                 is_reverse=False,
                 gate_activation='sigmoid',
                 candidate_activation='tanh',
                 h_0=None,
                 origin_mode=False,
                 init_size=None):
        super(DynamicGRU, self).__init__()
        self.gru_unit = dg.GRUUnit(
            size * 3,
            param_attr=param_attr,
            bias_attr=bias_attr,
            activation=candidate_activation,
            gate_activation=gate_activation,
            origin_mode=origin_mode)
        self.size = size
        self.h_0 = h_0
        self.is_reverse = is_reverse

    def forward(self, inputs):
        """
        Dynamic GRU block.

        Args:
            inputs (Variable): shape(B, T, C), dtype float32, the input value.

        Returns:
            output (Variable): shape(B, T, C), the result computed by the GRU.
        """
        hidden = self.h_0
        res = []
        for i in range(inputs.shape[1]):
            if self.is_reverse:
                i = inputs.shape[1] - 1 - i
            input_ = inputs[:, i:i + 1, :]
            input_ = layers.reshape(input_, [-1, input_.shape[2]])
            hidden, reset, gate = self.gru_unit(input_, hidden)
            hidden_ = layers.reshape(hidden, [-1, 1, hidden.shape[1]])
            res.append(hidden_)
        if self.is_reverse:
            res = res[::-1]
        res = layers.concat(res, axis=1)
        return res
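
Editor's note: a sketch of how the removed DynamicGRU was driven. It unrolls GRUUnit over the time axis; GRUUnit(size * 3) expects inputs that are already projected to 3 * size features and an initial hidden state h_0 of shape (B, size). The concrete batch size, length, and hidden size below are illustrative assumptions, and the sketch again assumes the Paddle 1.x dygraph API.

import numpy as np
import paddle.fluid.dygraph as dg

with dg.guard():
    size = 16
    # (B, size) initial hidden state; DynamicGRU requires it to be provided
    h_0 = dg.to_variable(np.zeros((2, size), dtype="float32"))
    gru = DynamicGRU(size, h_0=h_0)

    # the input feature dimension must be 3 * size, as expected by GRUUnit
    inputs = dg.to_variable(np.random.randn(2, 5, 3 * size).astype("float32"))
    outputs = gru(inputs)  # shape (2, 5, size)
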
@@ -1,93 +0,0 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as layers
import paddle.fluid as fluid
import math
from parakeet.modules.customized import Conv1D


class PositionwiseFeedForward(dg.Layer):
    def __init__(self,
                 d_in,
                 num_hidden,
                 filter_size,
                 padding=0,
                 use_cudnn=True,
                 dropout=0.1):
        """A two-feed-forward-layer module.

        Args:
            d_in (int): the size of the input channel.
            num_hidden (int): the size of the hidden layer in the network.
            filter_size (int): the filter size of Conv.
            padding (int, optional): the padding size of Conv. Defaults to 0.
            use_cudnn (bool, optional): use cudnn in Conv or not. Defaults to True.
            dropout (float, optional): dropout probability. Defaults to 0.1.
        """
        super(PositionwiseFeedForward, self).__init__()
        self.num_hidden = num_hidden
        self.use_cudnn = use_cudnn
        self.dropout = dropout

        k = math.sqrt(1.0 / d_in)
        self.w_1 = Conv1D(
            num_channels=d_in,
            num_filters=num_hidden,
            filter_size=filter_size,
            padding=padding,
            param_attr=fluid.ParamAttr(
                initializer=fluid.initializer.XavierInitializer()),
            bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Uniform(
                low=-k, high=k)),
            use_cudnn=use_cudnn)
        k = math.sqrt(1.0 / num_hidden)
        self.w_2 = Conv1D(
            num_channels=num_hidden,
            num_filters=d_in,
            filter_size=filter_size,
            padding=padding,
            param_attr=fluid.ParamAttr(
                initializer=fluid.initializer.XavierInitializer()),
            bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Uniform(
                low=-k, high=k)),
            use_cudnn=use_cudnn)
        self.layer_norm = dg.LayerNorm(d_in)

    def forward(self, input):
        """
        Compute the feed-forward network result.

        Args:
            input (Variable): shape(B, T, C), dtype float32, the input value.

        Returns:
            output (Variable): shape(B, T, C), the result after the FFN.
        """
        x = layers.transpose(input, [0, 2, 1])
        # FFN network
        x = self.w_2(layers.relu(self.w_1(x)))

        # dropout
        x = layers.dropout(
            x, self.dropout, dropout_implementation='upscale_in_train')

        x = layers.transpose(x, [0, 2, 1])
        # residual connection
        x = x + input

        # layer normalization
        output = self.layer_norm(x)

        return output
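
Editor's note: a sketch of the removed PositionwiseFeedForward in use. The block maps (B, T, C) features back to (B, T, C) via two Conv1D layers, dropout, a residual connection, and layer normalization; filter_size=1 is chosen here so the time dimension is preserved without extra padding. The sizes are illustrative assumptions under the same Paddle 1.x dygraph setup.

import numpy as np
import paddle.fluid.dygraph as dg

with dg.guard():
    ffn = PositionwiseFeedForward(d_in=256, num_hidden=1024, filter_size=1)
    x = dg.to_variable(np.random.randn(2, 50, 256).astype("float32"))
    # conv1d x 2 -> dropout -> residual -> layer norm, shape stays (2, 50, 256)
    y = ffn(x)
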
@@ -1,158 +0,0 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from numba import jit

from paddle import fluid
import paddle.fluid.dygraph as dg


def masked_mean(inputs, mask):
    """
    Args:
        inputs (Variable): Shape(B, C, 1, T), the input, where B means
            batch size, C means channels of input, T means timesteps of
            the input.
        mask (Variable): Shape(B, T), a mask.
    Returns:
        loss (Variable): Shape(1, ), masked mean.
    """
    channels = inputs.shape[1]
    reshaped_mask = fluid.layers.reshape(
        mask, shape=[mask.shape[0], 1, 1, mask.shape[-1]])
    expanded_mask = fluid.layers.expand(
        reshaped_mask, expand_times=[1, channels, 1, 1])
    expanded_mask.stop_gradient = True

    valid_cnt = fluid.layers.reduce_sum(expanded_mask)
    valid_cnt.stop_gradient = True

    masked_inputs = inputs * expanded_mask
    loss = fluid.layers.reduce_sum(masked_inputs) / valid_cnt
    return loss


@jit(nopython=True)
def guided_attention(N, max_N, T, max_T, g):
    W = np.zeros((max_N, max_T), dtype=np.float32)
    for n in range(N):
        for t in range(T):
            W[n, t] = 1 - np.exp(-(n / N - t / T)**2 / (2 * g * g))
    return W


def guided_attentions(input_lengths, target_lengths, max_target_len, g=0.2):
    B = len(input_lengths)
    max_input_len = input_lengths.max()
    W = np.zeros((B, max_target_len, max_input_len), dtype=np.float32)
    for b in range(B):
        W[b] = guided_attention(input_lengths[b], max_input_len,
                                target_lengths[b], max_target_len, g).T
    return W


class TTSLoss(object):
    def __init__(self,
                 masked_weight=0.0,
                 priority_weight=0.0,
                 binary_divergence_weight=0.0,
                 guided_attention_sigma=0.2):
        self.masked_weight = masked_weight
        self.priority_weight = priority_weight
        self.binary_divergence_weight = binary_divergence_weight
        self.guided_attention_sigma = guided_attention_sigma

    def l1_loss(self, prediction, target, mask, priority_bin=None):
        abs_diff = fluid.layers.abs(prediction - target)

        # basic mask-weighted l1 loss
        w = self.masked_weight
        if w > 0 and mask is not None:
            base_l1_loss = w * masked_mean(abs_diff, mask) + (
                1 - w) * fluid.layers.reduce_mean(abs_diff)
        else:
            base_l1_loss = fluid.layers.reduce_mean(abs_diff)

        if self.priority_weight > 0 and priority_bin is not None:
            # mask-weighted priority channels' l1-loss
            priority_abs_diff = fluid.layers.slice(
                abs_diff, axes=[1], starts=[0], ends=[priority_bin])
            if w > 0 and mask is not None:
                priority_loss = w * masked_mean(priority_abs_diff, mask) + (
                    1 - w) * fluid.layers.reduce_mean(priority_abs_diff)
            else:
                priority_loss = fluid.layers.reduce_mean(priority_abs_diff)

            # priority weighted sum
            p = self.priority_weight
            loss = p * priority_loss + (1 - p) * base_l1_loss
        else:
            loss = base_l1_loss
        return loss

    def binary_divergence(self, prediction, target, mask):
        flattened_prediction = fluid.layers.reshape(prediction, [-1, 1])
        flattened_target = fluid.layers.reshape(target, [-1, 1])
        flattened_loss = fluid.layers.log_loss(
            flattened_prediction, flattened_target, epsilon=1e-8)
        bin_div = fluid.layers.reshape(flattened_loss, prediction.shape)

        w = self.masked_weight
        if w > 0 and mask is not None:
            loss = w * masked_mean(bin_div, mask) + (
                1 - w) * fluid.layers.reduce_mean(bin_div)
        else:
            loss = fluid.layers.reduce_mean(bin_div)
        return loss

    @staticmethod
    def done_loss(done_hat, done):
        flat_done_hat = fluid.layers.reshape(done_hat, [-1, 1])
        flat_done = fluid.layers.reshape(done, [-1, 1])
        loss = fluid.layers.log_loss(flat_done_hat, flat_done, epsilon=1e-8)
        loss = fluid.layers.reduce_mean(loss)
        return loss

    def attention_loss(self, predicted_attention, input_lengths,
                       target_lengths):
        """
        Given valid encoder_lengths and decoder_lengths, compute a diagonal
        guide, and compute loss from the predicted attention and the guide.

        Args:
            predicted_attention (Variable): Shape(*, B, T_dec, T_enc), the
                alignment tensor, where B means batch size, T_dec means number
                of time steps of the decoder, T_enc means the number of time
                steps of the encoder, * means other possible dimensions.
            input_lengths (numpy.ndarray): Shape(B,), dtype:int64, valid lengths
                (time steps) of encoder outputs.
            target_lengths (numpy.ndarray): Shape(batch_size,), dtype:int64,
                valid lengths (time steps) of decoder outputs.

        Returns:
            loss (Variable): Shape(1, ), attention loss.
        """
        n_attention, batch_size, max_target_len, max_input_len = (
            predicted_attention.shape)
        soft_mask = guided_attentions(input_lengths, target_lengths,
                                      max_target_len,
                                      self.guided_attention_sigma)
        soft_mask_ = dg.to_variable(soft_mask)
        loss = fluid.layers.reduce_mean(predicted_attention * soft_mask_)
        return loss
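
Editor's note: the removed guided_attention builds the soft diagonal penalty W[n, t] = 1 - exp(-((n / N - t / T)^2) / (2 * g^2)), and guided_attentions stacks one such guide per batch item for TTSLoss.attention_loss, so alignments far from the diagonal are penalized more. A numpy-only sketch using those two functions, with made-up lengths:

import numpy as np

# three utterances with different encoder/decoder lengths (illustrative values)
input_lengths = np.array([10, 12, 8], dtype=np.int64)    # encoder time steps per sample
target_lengths = np.array([50, 60, 40], dtype=np.int64)  # decoder time steps per sample

W = guided_attentions(input_lengths, target_lengths, max_target_len=60, g=0.2)
print(W.shape)     # (3, 60, 12): (B, max_T_dec, max_T_enc)
print(W[0, 0, 0])  # 0.0 on the diagonal; the penalty grows away from it
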
@@ -1,282 +0,0 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle import fluid
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as F

from parakeet.modules import customized as L


def norm(param, dim, power):
    powered = F.pow(param, power)
    in_dtype = powered.dtype
    if in_dtype == fluid.core.VarDesc.VarType.FP16:
        powered = F.cast(powered, "float32")
    powered_norm = F.reduce_sum(powered, dim=dim, keep_dim=False)
    norm_ = F.pow(powered_norm, 1. / power)
    if in_dtype == fluid.core.VarDesc.VarType.FP16:
        norm_ = F.cast(norm_, "float16")
    return norm_


def norm_except(param, dim, power):
    """Computes the norm over all dimensions except dim.
    It differs from the pytorch implementation in that it does not keep dim.
    This difference is related to the broadcast mechanism in paddle.
    See elementwise_mul for more.
    """
    shape = param.shape
    ndim = len(shape)

    if dim is None:
        return norm(param, dim, power)
    elif dim == 0:
        param_matrix = F.reshape(param, (shape[0], -1))
        return norm(param_matrix, dim=1, power=power)
    elif dim == -1 or dim == ndim - 1:
        param_matrix = F.reshape(param, (-1, shape[-1]))
        return norm(param_matrix, dim=0, power=power)
    else:
        perm = list(range(ndim))
        perm[0] = dim
        perm[dim] = 0
        transposed_param = F.transpose(param, perm)
        return norm_except(transposed_param, dim=0, power=power)


def compute_l2_normalized_weight(v, g, dim):
    shape = v.shape
    ndim = len(shape)

    if dim is None:
        v_normalized = v / (F.sqrt(F.reduce_sum(F.square(v))) + 1e-12)
    elif dim == 0:
        param_matrix = F.reshape(v, (shape[0], -1))
        v_normalized = F.l2_normalize(param_matrix, axis=1)
        v_normalized = F.reshape(v_normalized, shape)
    elif dim == -1 or dim == ndim - 1:
        param_matrix = F.reshape(v, (-1, shape[-1]))
        v_normalized = F.l2_normalize(param_matrix, axis=0)
        v_normalized = F.reshape(v_normalized, shape)
    else:
        perm = list(range(ndim))
        perm[0] = dim
        perm[dim] = 0
        transposed_param = F.transpose(v, perm)
        transposed_shape = transposed_param.shape
        param_matrix = F.reshape(transposed_param,
                                 (transposed_param.shape[0], -1))
        v_normalized = F.l2_normalize(param_matrix, axis=1)
        v_normalized = F.reshape(v_normalized, transposed_shape)
        v_normalized = F.transpose(v_normalized, perm)
    weight = F.elementwise_mul(v_normalized, g, axis=dim)
    return weight


def compute_weight(v, g, dim, power):
    assert len(g.shape) == 1, "magnitude should be a vector"
    if power == 2:
        in_dtype = v.dtype
        if in_dtype == fluid.core.VarDesc.VarType.FP16:
            v = F.cast(v, "float32")
            g = F.cast(g, "float32")
        weight = compute_l2_normalized_weight(v, g, dim)
        if in_dtype == fluid.core.VarDesc.VarType.FP16:
            weight = F.cast(weight, "float16")
        return weight
    else:
        v_normalized = F.elementwise_div(
            v, (norm_except(v, dim, power) + 1e-12), axis=dim)
        weight = F.elementwise_mul(v_normalized, g, axis=dim)
        return weight


class WeightNormWrapper(dg.Layer):
    def __init__(self, layer, param_name="weight", dim=0, power=2):
        super(WeightNormWrapper, self).__init__()

        self.param_name = param_name
        self.dim = dim
        self.power = power
        self.layer = layer

        w_v = param_name + "_v"
        w_g = param_name + "_g"

        # we could also use numpy to compute this, after all, it is run only once
        # at initialization.
        original_weight = getattr(layer, param_name)
        self.add_parameter(
            w_v,
            self.create_parameter(
                shape=original_weight.shape, dtype=original_weight.dtype))
        with dg.no_grad():
            F.assign(original_weight, getattr(self, w_v))
        delattr(layer, param_name)
        temp = norm_except(getattr(self, w_v), self.dim, self.power)
        self.add_parameter(
            w_g, self.create_parameter(
                shape=temp.shape, dtype=temp.dtype))
        with dg.no_grad():
            F.assign(temp, getattr(self, w_g))

        # also set this when setting up
        setattr(self.layer, self.param_name,
                compute_weight(
                    getattr(self, w_v),
                    getattr(self, w_g), self.dim, self.power))

        self.weigth_norm_applied = True

    # hook to compute weight with v & g
    def hook(self):
        w_v = self.param_name + "_v"
        w_g = self.param_name + "_g"
        setattr(self.layer, self.param_name,
                compute_weight(
                    getattr(self, w_v),
                    getattr(self, w_g), self.dim, self.power))

    def remove_weight_norm(self):
        self.hook()
        self.weigth_norm_applied = False

    def forward(self, *args, **kwargs):
        if self.weigth_norm_applied == True:
            self.hook()
        return self.layer(*args, **kwargs)

    def __getattr__(self, key):
        """
        this is used for attr forwarding.
        """
        if key in self._parameters:
            return self._parameters[key]
        elif key in self._sub_layers:
            return self._sub_layers[key]
        elif key is "layer":
            return self._sub_layers["layer"]
        else:
            return getattr(
                object.__getattribute__(self, "_sub_layers")["layer"], key)


def Linear(input_dim,
           output_dim,
           param_attr=None,
           bias_attr=None,
           act=None,
           dtype="float32"):
    # a weight norm applied linear layer.
    lin = dg.Linear(input_dim, output_dim, param_attr, bias_attr, act, dtype)
    lin = WeightNormWrapper(lin, dim=1)
    return lin


def Conv1D(num_channels,
           num_filters,
           filter_size,
           stride=1,
           padding=0,
           dilation=1,
           groups=1,
           param_attr=None,
           bias_attr=None,
           use_cudnn=True,
           act=None,
           dtype='float32'):
    conv = L.Conv1D(num_channels, num_filters, filter_size, stride, padding,
                    dilation, groups, param_attr, bias_attr, use_cudnn, act,
                    dtype)
    conv = WeightNormWrapper(conv, dim=0)
    return conv


def Conv1DTranspose(num_channels,
                    num_filters,
                    filter_size,
                    padding=0,
                    stride=1,
                    dilation=1,
                    groups=1,
                    param_attr=None,
                    bias_attr=None,
                    use_cudnn=True,
                    act=None,
                    dtype='float32'):
    conv = L.Conv1DTranspose(num_channels, num_filters, filter_size, padding,
                             stride, dilation, groups, param_attr, bias_attr,
                             use_cudnn, act, dtype)
    conv = WeightNormWrapper(conv, dim=0)
    return conv


def Conv1DCell(num_channels,
               num_filters,
               filter_size,
               dilation=1,
               causal=False,
               groups=1,
               param_attr=None,
               bias_attr=None,
               use_cudnn=True,
               act=None,
               dtype='float32'):
    conv = L.Conv1DCell(num_channels, num_filters, filter_size, dilation,
                        causal, groups, param_attr, bias_attr, use_cudnn, act,
                        dtype)
    conv = WeightNormWrapper(conv, dim=0)
    return conv


def Conv2D(num_channels,
           num_filters,
           filter_size,
           stride=1,
           padding=0,
           dilation=1,
           groups=1,
           param_attr=None,
           bias_attr=None,
           use_cudnn=True,
           act=None,
           dtype='float32'):
    # a conv2d layer with weight norm wrapper
    conv = dg.Conv2D(num_channels, num_filters, filter_size, stride, padding,
                     dilation, groups, param_attr, bias_attr, use_cudnn, act,
                     dtype)
    conv = WeightNormWrapper(conv, dim=0)
    return conv


def Conv2DTranspose(num_channels,
                    num_filters,
                    filter_size,
                    output_size=None,
                    padding=0,
                    stride=1,
                    dilation=1,
                    groups=1,
                    param_attr=None,
                    bias_attr=None,
                    use_cudnn=True,
                    act=None,
                    dtype='float32'):
    # a conv2d transpose layer with weight norm wrapper.
    conv = dg.Conv2DTranspose(num_channels, num_filters, filter_size,
                              output_size, padding, stride, dilation, groups,
                              param_attr, bias_attr, use_cudnn, act, dtype)
    conv = WeightNormWrapper(conv, dim=0)
    return conv
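
Editor's note: a sketch of the removed WeightNormWrapper, which reparameterizes a layer's weight as weight = weight_g * weight_v / ||weight_v|| and recomputes it in a hook before each forward call; remove_weight_norm() bakes the current weight back in, which the Conv1DCell.start_sequence warning above relies on after loading a checkpoint. The wrapped Linear layer and its sizes are illustrative assumptions (mirroring the Linear factory in the removed file), under the Paddle 1.x dygraph API.

import numpy as np
import paddle.fluid.dygraph as dg

with dg.guard():
    # wrap a plain Linear layer; its `weight` is replaced by weight_v and weight_g
    lin = dg.Linear(8, 4)
    lin = WeightNormWrapper(lin, param_name="weight", dim=1)

    x = dg.to_variable(np.random.randn(2, 8).astype("float32"))
    y = lin(x)  # forward() recomputes weight from (weight_v, weight_g), then calls the layer

    # after loading a checkpoint (which only restores weight_v / weight_g),
    # bake the weight back in before step-by-step synthesis
    lin.remove_weight_norm()
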