add func
This commit is contained in:
parent
2e6b4be5ed
commit
c0f508b24b
|
@ -9,6 +9,7 @@ logger = logging.getLogger(__name__)
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'manual_seed',
|
'manual_seed',
|
||||||
'seq_len_to_mask',
|
'seq_len_to_mask',
|
||||||
|
'to_one_hot',
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -49,3 +50,17 @@ def seq_len_to_mask(seq_len: Union[List, np.ndarray, torch.Tensor], max_len=None
|
||||||
raise logger.error("Only support 1-d list or 1-d numpy.ndarray or 1-d torch.Tensor.")
|
raise logger.error("Only support 1-d list or 1-d numpy.ndarray or 1-d torch.Tensor.")
|
||||||
|
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def to_one_hot(x: torch.Tensor, length: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
:param x: [B] 一般是 target 的值
|
||||||
|
:param length: L 一般是关系种类树
|
||||||
|
:return: [B, L] 每一行,只有对应位置为1,其余为0
|
||||||
|
"""
|
||||||
|
B = x.size(0)
|
||||||
|
x_one_hot = torch.zeros(B, length)
|
||||||
|
for i in range(B):
|
||||||
|
x_one_hot[i, x[i]] = 1.0
|
||||||
|
|
||||||
|
return x_one_hot
|
Loading…
Reference in New Issue