This commit is contained in:
leo 2019-12-05 20:46:18 +08:00
parent 2e6b4be5ed
commit c0f508b24b
1 changed files with 15 additions and 0 deletions

View File

@ -9,6 +9,7 @@ logger = logging.getLogger(__name__)
__all__ = [ __all__ = [
'manual_seed', 'manual_seed',
'seq_len_to_mask', 'seq_len_to_mask',
'to_one_hot',
] ]
@ -49,3 +50,17 @@ def seq_len_to_mask(seq_len: Union[List, np.ndarray, torch.Tensor], max_len=None
raise logger.error("Only support 1-d list or 1-d numpy.ndarray or 1-d torch.Tensor.") raise logger.error("Only support 1-d list or 1-d numpy.ndarray or 1-d torch.Tensor.")
return mask return mask
def to_one_hot(x: torch.Tensor, length: int) -> torch.Tensor:
"""
:param x: [B] 一般是 target 的值
:param length: L 一般是关系种类树
:return: [B, L] 每一行只有对应位置为1其余为0
"""
B = x.size(0)
x_one_hot = torch.zeros(B, length)
for i in range(B):
x_one_hot[i, x[i]] = 1.0
return x_one_hot