112 lines
3.3 KiB
Python
112 lines
3.3 KiB
Python
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
import paddle
|
||
|
|
||
|
|
||
|
# Pad every sequence up to the length of the longest one in the batch.
|
||
|
def pad_list(xs, pad_value):
    """Pad a list of variable-length tensors into a single batch tensor.

    Every tensor is copied into the leading part of its row along the first
    (time) axis; the remaining positions keep ``pad_value``.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
            The trailing dimensions are taken from ``xs[0]``.
        pad_value (float): Value used to fill the padded positions.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`), created with the dtype of
            ``xs[0]``.

    Raises:
        ValueError: If ``xs`` is empty.

    Examples:
        >>> x = [paddle.ones([4]), paddle.ones([2]), paddle.ones([1])]
        >>> pad_list(x, 0).tolist()
        [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]]

    """
    n_batch = len(xs)
    if n_batch == 0:
        # Fail with a clear message instead of the opaque error from max().
        raise ValueError("pad_list expects a non-empty list of tensors")

    max_len = max(x.shape[0] for x in xs)

    # Allocate the buffer with the inputs' dtype; without an explicit dtype
    # paddle.full falls back to the default float type, which would silently
    # change the type of integer inputs.
    padded = paddle.full(
        [n_batch, max_len, *xs[0].shape[1:]], pad_value, dtype=xs[0].dtype)

    for i, x in enumerate(xs):
        # Only the first T_i steps of row i are overwritten.
        padded[i, :x.shape[0]] = x

    return padded
|
||
|
|
||
|
|
||
|
def make_pad_mask(lengths, length_dim=-1):
    """Make mask tensor containing indices of padded part.

    Args:
        lengths (LongTensor, List or Tuple): Batch of lengths (B,).
        length_dim (int, optional): Only validated here; it must not be 0.
            The value is otherwise unused by this implementation.

    Returns:
        Tensor: Boolean mask of shape (B, max(lengths)). An element is True
            where its position is greater than or equal to the length of
            its row, i.e. in the padded part.

    Raises:
        ValueError: If ``length_dim`` is 0.

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths).tolist()
        [[False, False, False, False, False],
         [False, False, False, True, True],
         [False, False, True, True, True]]

    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if isinstance(lengths, tuple):
        # Tuples have no .tolist(); normalise them to a plain list.
        lengths = list(lengths)
    elif not isinstance(lengths, list):
        lengths = lengths.tolist()

    bs = len(lengths)
    maxlen = int(max(lengths))

    # Row of positions 0..maxlen-1, repeated for every batch entry.
    seq_range = paddle.arange(0, maxlen, dtype=paddle.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand([bs, maxlen])

    # Lengths as a (B, 1) column so the comparison broadcasts across positions.
    seq_length_expand = paddle.to_tensor(
        lengths, dtype=seq_range_expand.dtype).unsqueeze(-1)

    return seq_range_expand >= seq_length_expand
|
||
|
|
||
|
|
||
|
def make_non_pad_mask(lengths, length_dim=-1):
    """Make mask tensor containing indices of non-padded part.

    This is the element-wise logical negation of ``make_pad_mask`` called
    with the same arguments.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        length_dim (int, optional): Forwarded unchanged to ``make_pad_mask``.

    Returns:
        Tensor: Boolean mask that is True on the non-padded part.

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1 ,1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]

    """
    padded_positions = make_pad_mask(lengths, length_dim)
    return paddle.logical_not(padded_positions)
|