add func
This commit is contained in:
parent
2e6b4be5ed
commit
c0f508b24b
|
@ -9,6 +9,7 @@ logger = logging.getLogger(__name__)
|
|||
__all__ = [
|
||||
'manual_seed',
|
||||
'seq_len_to_mask',
|
||||
'to_one_hot',
|
||||
]
|
||||
|
||||
|
||||
|
@ -49,3 +50,17 @@ def seq_len_to_mask(seq_len: Union[List, np.ndarray, torch.Tensor], max_len=None
|
|||
raise logger.error("Only support 1-d list or 1-d numpy.ndarray or 1-d torch.Tensor.")
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
def to_one_hot(x: torch.Tensor, length: int) -> torch.Tensor:
|
||||
"""
|
||||
:param x: [B] 一般是 target 的值
|
||||
:param length: L 一般是关系种类树
|
||||
:return: [B, L] 每一行,只有对应位置为1,其余为0
|
||||
"""
|
||||
B = x.size(0)
|
||||
x_one_hot = torch.zeros(B, length)
|
||||
for i in range(B):
|
||||
x_one_hot[i, x[i]] = 1.0
|
||||
|
||||
return x_one_hot
|
Loading…
Reference in New Issue