ParakeetEricRoss/parakeet/modules/nets_utils.py

112 lines
3.3 KiB
Python
Raw Normal View History

2021-07-12 14:01:43 +08:00
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
# Pad every sequence to the length of the longest one in this batch.
def pad_list(xs, pad_value):
    """Pad a list of variable-length tensors into a single batch tensor.

    Each tensor is copied into the leading ``T_i`` positions of its own row;
    the remaining positions keep ``pad_value``.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
            The trailing dimensions and the dtype are taken from ``xs[0]``.
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`) with the same dtype as ``xs[0]``.

    Raises:
        ValueError: If ``xs`` is empty.

    Examples:
        >>> x = [paddle.ones([4]), paddle.ones([2]), paddle.ones([1])]
        >>> pad_list(x, 0).tolist()
        [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]]
    """
    # len() rather than truthiness: xs may be a container whose truth value
    # is ambiguous.
    if len(xs) == 0:
        raise ValueError("pad_list() requires a non-empty list of tensors")

    first = xs[0]
    max_len = max(x.shape[0] for x in xs)
    # Without an explicit dtype paddle.full falls back to the default float
    # type, which would silently change the dtype of integer inputs.
    padded = paddle.full(
        [len(xs), max_len, *first.shape[1:]], pad_value, dtype=first.dtype)
    for i, x in enumerate(xs):
        padded[i, :x.shape[0]] = x
    return padded
def make_pad_mask(lengths, length_dim=-1):
    """Make a boolean mask tensor that is True at the padded positions.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,). Objects exposing
            ``tolist()`` are converted with it; any other sequence (e.g. a
            tuple) is copied into a list.
        length_dim (int, optional): Only validated here (it must not be 0);
            it does not otherwise influence the returned mask.

    Returns:
        Tensor: Bool mask of shape (B, max(lengths)); an element is True when
        its position index is greater than or equal to the sequence length.

    Raises:
        ValueError: If ``length_dim`` is 0 or ``lengths`` is empty.

    Examples:
        With only lengths.
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if not isinstance(lengths, list):
        # Tensors/arrays provide tolist(); plain sequences such as tuples
        # do not, so fall back to list() instead of raising AttributeError.
        to_list = getattr(lengths, "tolist", None)
        lengths = to_list() if callable(to_list) else list(lengths)
    if len(lengths) == 0:
        raise ValueError("lengths must not be empty")

    batch_size = len(lengths)
    max_len = int(max(lengths))
    # Row of position indices [0, max_len), repeated once per batch entry.
    positions = paddle.arange(0, max_len, dtype=paddle.int64)
    positions = positions.unsqueeze(0).expand([batch_size, max_len])
    # Column of lengths (B, 1) so the comparison broadcasts along time.
    seq_lengths = paddle.to_tensor(
        lengths, dtype=positions.dtype).unsqueeze(-1)
    return positions >= seq_lengths
def make_non_pad_mask(lengths, length_dim=-1):
    """Make a boolean mask tensor that is True at the non-padded positions.

    The result is the element-wise logical negation of
    ``make_pad_mask(lengths, length_dim)``.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        length_dim (int, optional): Forwarded unchanged to ``make_pad_mask``.

    Returns:
        Tensor: Bool mask marking the non-padded part of each sequence.

    Examples:
        With only lengths.
        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1 ,1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]
    """
    padded_positions = make_pad_mask(lengths, length_dim)
    return paddle.logical_not(padded_positions)