add sar dict
This commit is contained in:
parent
89c9f363e6
commit
e760df4b1e
|
@ -9,11 +9,14 @@ from paddle import nn
|
|||
class SARLoss(nn.Layer):
|
||||
def __init__(self, **kwargs):
|
||||
super(SARLoss, self).__init__()
|
||||
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean", ignore_index=96)
|
||||
self.loss_func = paddle.nn.loss.CrossEntropyLoss(
|
||||
reduction="mean", ignore_index=92)
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
predict = predicts[:, :-1, :] # ignore last index of outputs to be in same seq_len with targets
|
||||
label = batch[1].astype("int64")[:, 1:] # ignore first index of target in loss calculation
|
||||
predict = predicts[:, :
|
||||
-1, :] # ignore last index of outputs to be in same seq_len with targets
|
||||
label = batch[1].astype(
|
||||
"int64")[:, 1:] # ignore first index of target in loss calculation
|
||||
batch_size, num_steps, num_classes = predict.shape[0], predict.shape[
|
||||
1], predict.shape[2]
|
||||
assert len(label.shape) == len(list(predict.shape)) - 1, \
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
0
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
a
|
||||
b
|
||||
c
|
||||
d
|
||||
e
|
||||
f
|
||||
g
|
||||
h
|
||||
i
|
||||
j
|
||||
k
|
||||
l
|
||||
m
|
||||
n
|
||||
o
|
||||
p
|
||||
q
|
||||
r
|
||||
s
|
||||
t
|
||||
u
|
||||
v
|
||||
w
|
||||
x
|
||||
y
|
||||
z
|
||||
A
|
||||
B
|
||||
C
|
||||
D
|
||||
E
|
||||
F
|
||||
G
|
||||
H
|
||||
I
|
||||
J
|
||||
K
|
||||
L
|
||||
M
|
||||
N
|
||||
O
|
||||
P
|
||||
Q
|
||||
R
|
||||
S
|
||||
T
|
||||
U
|
||||
V
|
||||
W
|
||||
X
|
||||
Y
|
||||
Z
|
||||
!
|
||||
"
|
||||
#
|
||||
$
|
||||
%
|
||||
&
|
||||
'
|
||||
(
|
||||
)
|
||||
*
|
||||
+
|
||||
,
|
||||
-
|
||||
.
|
||||
/
|
||||
:
|
||||
;
|
||||
<
|
||||
=
|
||||
>
|
||||
?
|
||||
@
|
||||
[
|
||||
\
|
||||
]
|
||||
_
|
||||
`
|
||||
~
|
Loading…
Reference in New Issue