update rec_sar_head
This commit is contained in:
parent
073fad37ba
commit
df4a2f6a7e
|
@ -9,7 +9,7 @@ from paddle import nn
|
|||
class SARLoss(nn.Layer):
|
||||
def __init__(self, **kwargs):
|
||||
super(SARLoss, self).__init__()
|
||||
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean", ignore_index=92)
|
||||
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean", ignore_index=96)
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
predict = predicts[:, :-1, :] # ignore last index of outputs to be in same seq_len with targets
|
||||
|
|
|
@ -118,8 +118,7 @@ class BaseDecoder(nn.Layer):
|
|||
class ParallelSARDecoder(BaseDecoder):
|
||||
"""
|
||||
Args:
|
||||
num_classes (int): Output class number.
|
||||
channels (list[int]): Network layer channels.
|
||||
out_channels (int): Output class number.
|
||||
enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
|
||||
dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
|
||||
dec_drop_rnn (float): Dropout of RNN layer in decoder.
|
||||
|
@ -137,7 +136,7 @@ class ParallelSARDecoder(BaseDecoder):
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_classes=93, # 90 + unknown + start + padding
|
||||
out_channels, # 90 + unknown + start + padding
|
||||
enc_bi_rnn=False,
|
||||
dec_bi_rnn=False,
|
||||
dec_drop_rnn=0.0,
|
||||
|
@ -148,8 +147,6 @@ class ParallelSARDecoder(BaseDecoder):
|
|||
pred_dropout=0.1,
|
||||
max_text_length=30,
|
||||
mask=True,
|
||||
start_idx=91,
|
||||
padding_idx=92, # 92
|
||||
pred_concat=True,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
|
@ -157,7 +154,8 @@ class ParallelSARDecoder(BaseDecoder):
|
|||
self.num_classes = num_classes
|
||||
self.enc_bi_rnn = enc_bi_rnn
|
||||
self.d_k = d_k
|
||||
self.start_idx = start_idx
|
||||
self.start_idx = out_channels - 2
|
||||
self.padding_idx = out_channels - 1
|
||||
self.max_seq_len = max_text_length
|
||||
self.mask = mask
|
||||
self.pred_concat = pred_concat
|
||||
|
@ -191,7 +189,7 @@ class ParallelSARDecoder(BaseDecoder):
|
|||
|
||||
# Decoder input embedding
|
||||
self.embedding = nn.Embedding(
|
||||
self.num_classes, encoder_rnn_out_size, padding_idx=padding_idx)
|
||||
self.num_classes, encoder_rnn_out_size, padding_idx=self.padding_idx)
|
||||
|
||||
# Prediction layer
|
||||
self.pred_dropout = nn.Dropout(pred_dropout)
|
||||
|
@ -330,6 +328,7 @@ class ParallelSARDecoder(BaseDecoder):
|
|||
|
||||
class SARHead(nn.Layer):
|
||||
def __init__(self,
|
||||
out_channels,
|
||||
enc_bi_rnn=False,
|
||||
enc_drop_rnn=0.1,
|
||||
enc_gru=False,
|
||||
|
@ -351,7 +350,8 @@ class SARHead(nn.Layer):
|
|||
|
||||
# decoder module
|
||||
self.decoder = ParallelSARDecoder(
|
||||
enc_bi_rnn=enc_bi_rnn,
|
||||
out_channels=out_channels,
|
||||
enc_bi_rnn=enc_bi_rnn,
|
||||
dec_bi_rnn=dec_bi_rnn,
|
||||
dec_drop_rnn=dec_drop_rnn,
|
||||
dec_gru=dec_gru,
|
||||
|
|
Loading…
Reference in New Issue