ParakeetEricRoss/parakeet/modules/nets_utils.py

158 lines
4.5 KiB
Python
Raw Normal View History

2021-07-12 14:01:43 +08:00
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
2021-07-13 15:55:56 +08:00
from paddle import nn
from typeguard import check_argument_types
2021-07-12 14:01:43 +08:00
def pad_list(xs, pad_value):
    """Perform padding for the list of tensors.

    Parameters
    ----------
    xs : List[Tensor]
        List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
    pad_value : float
        Value for padding.

    Returns
    ----------
    Tensor
        Padded tensor (B, Tmax, `*`), where Tmax is the largest first
        dimension in ``xs``. The result has the same dtype as ``xs[0]``.

    Raises
    ----------
    ValueError
        If ``xs`` is empty.

    Examples
    ----------
    >>> x = [paddle.ones([4]), paddle.ones([2]), paddle.ones([1])]
    >>> x
    [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
    >>> pad_list(x, 0)
    tensor([[1., 1., 1., 1.],
            [1., 1., 0., 0.],
            [1., 0., 0., 0.]])

    """
    n_batch = len(xs)
    if n_batch == 0:
        # max() below would fail with an obscure message; fail clearly instead.
        raise ValueError("pad_list requires at least one tensor")
    max_len = max(x.shape[0] for x in xs)
    # Allocate the output with the inputs' dtype. Without an explicit dtype,
    # paddle.full uses the default float dtype, which would silently change
    # the dtype of integer inputs (e.g. token ids).
    pad = paddle.full(
        [n_batch, max_len, *xs[0].shape[1:]], pad_value, dtype=xs[0].dtype)
    for i, x in enumerate(xs):
        # Copy each sequence into the front of its row; the tail keeps pad_value.
        pad[i, :x.shape[0]] = x
    return pad
def make_pad_mask(lengths, length_dim=-1):
    """Make mask tensor containing indices of padded part.

    Parameters
    ----------
    lengths : LongTensor, List or Tuple
        Batch of lengths (B,).
    length_dim : int, optional
        Only validated here (it must not be 0); it does not otherwise
        affect the result.

    Returns
    ----------
    Tensor(bool)
        Mask tensor of shape (B, Tmax), where Tmax = max(lengths). An entry
        is True where the position index is >= the sequence length, i.e. at
        padded positions.

    Raises
    ----------
    ValueError
        If ``length_dim`` is 0.

    Examples
    ----------
    With only lengths.

    >>> lengths = [5, 3, 2]
    >>> make_pad_mask(lengths)
    masks = [[0, 0, 0, 0 ,0],
             [0, 0, 0, 1, 1],
             [0, 0, 1, 1, 1]]
    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))
    if isinstance(lengths, tuple):
        # Tuples have no ``tolist``; accept them like plain lists.
        lengths = list(lengths)
    elif not isinstance(lengths, list):
        # Tensor-like input (e.g. a paddle Tensor): convert to a Python list.
        lengths = lengths.tolist()
    bs = len(lengths)
    maxlen = int(max(lengths))
    # Row of position indices [0, maxlen) repeated for every batch element.
    seq_range = paddle.arange(0, maxlen, dtype=paddle.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand([bs, maxlen])
    # (B, 1) column of lengths so the comparison broadcasts across positions.
    seq_length_expand = paddle.to_tensor(
        lengths, dtype=seq_range_expand.dtype).unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask
def make_non_pad_mask(lengths, length_dim=-1):
    """Make mask tensor containing indices of non-padded part.

    The result is the element-wise logical negation of
    ``make_pad_mask(lengths, length_dim)``.

    Parameters
    ----------
    lengths : LongTensor or List
        Batch of lengths (B,).
    length_dim : int, optional
        Forwarded unchanged to ``make_pad_mask``, which rejects 0.

    Returns
    ----------
    Tensor(bool)
        Mask tensor of shape (B, Tmax), where Tmax = max(lengths). An entry
        is True at non-padded positions, i.e. where the position index is
        smaller than the sequence length.

    Raises
    ----------
    ValueError
        If ``length_dim`` is 0 (raised by ``make_pad_mask``).

    Examples
    ----------
    With only lengths.

    >>> lengths = [5, 3, 2]
    >>> make_non_pad_mask(lengths)
    masks = [[1, 1, 1, 1 ,1],
             [1, 1, 1, 0, 0],
             [1, 1, 0, 0, 0]]
    """
    return paddle.logical_not(make_pad_mask(lengths, length_dim))
2021-07-13 15:55:56 +08:00
def initialize(model: nn.Layer, init: str):
    """Configure paddle's global weight/bias initializers by name.

    The weight initializer is selected from ``init``. The bias initializer is
    always ``nn.initializer.Constant()``. Both are installed via
    ``nn.initializer.set_global_initializer``.

    NOTE(review): ``model`` is not used in the body. The parameters it
    already holds are not re-initialized here. Per paddle's documentation,
    the global initializer applies to parameters created after this call;
    confirm this matches the callers' expectations.

    Parameters
    ----------
    model : paddle.nn.Layer
        Target (currently unused).
    init : str
        Method of initialization: one of "xavier_uniform", "xavier_normal",
        "kaiming_uniform" or "kaiming_normal".

    Raises
    ----------
    ValueError
        If ``init`` is not one of the supported names.
    """
    # typeguard inspects this function's frame, so it must be called here.
    assert check_argument_types()
    # Map each supported name to its weight-initializer class. Instances are
    # created only for the selected entry.
    weight_factories = {
        "xavier_uniform": nn.initializer.XavierUniform,
        "xavier_normal": nn.initializer.XavierNormal,
        "kaiming_uniform": nn.initializer.KaimingUniform,
        "kaiming_normal": nn.initializer.KaimingNormal,
    }
    factory = weight_factories.get(init)
    if factory is None:
        raise ValueError("Unknown initialization: " + init)
    weight_initializer = factory()
    bias_initializer = nn.initializer.Constant()
    nn.initializer.set_global_initializer(weight_initializer, bias_initializer)