update rec_sar_head

andyjpaddle 2021-09-07 03:33:02 +00:00
parent 073fad37ba
commit df4a2f6a7e
2 changed files with 10 additions and 10 deletions

ppocr/losses/rec_sar_loss.py

@@ -9,7 +9,7 @@ from paddle import nn
 class SARLoss(nn.Layer):
     def __init__(self, **kwargs):
         super(SARLoss, self).__init__()
-        self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean", ignore_index=92)
+        self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean", ignore_index=96)

     def forward(self, predicts, batch):
         predict = predicts[:, :-1, :]  # ignore last index of outputs to be in same seq_len with targets

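Note: the only functional change in the loss is the hardcoded ignore_index, which has to equal the decoder's padding index so that padded target positions contribute nothing to the cross-entropy. A minimal sketch of that relationship (illustrative values and shapes, not the commit's code; the commit hardcodes 96, which presumably corresponds to a 97-class dictionary in this setup):

    import paddle
    import paddle.nn as nn

    out_channels = 97                    # hypothetical dictionary size (chars + unknown + start + padding)
    padding_idx = out_channels - 1       # 96, mirroring self.padding_idx = out_channels - 1 in the head

    loss_func = nn.CrossEntropyLoss(reduction="mean", ignore_index=padding_idx)

    logits = paddle.randn([2, 5, out_channels])        # (batch, seq_len, num_classes)
    targets = paddle.randint(0, out_channels, [2, 5])  # int64 labels
    targets[:, -1] = padding_idx                       # padded tail positions are skipped by the loss

    loss = loss_func(logits.reshape([-1, out_channels]), targets.reshape([-1]))
    print(float(loss))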
ppocr/modeling/heads/rec_sar_head.py

@@ -118,8 +118,7 @@ class BaseDecoder(nn.Layer):
 class ParallelSARDecoder(BaseDecoder):
     """
     Args:
-        num_classes (int): Output class number.
-        channels (list[int]): Network layer channels.
+        out_channels (int): Output class number.
         enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
         dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
         dec_drop_rnn (float): Dropout of RNN layer in decoder.
@@ -137,7 +136,7 @@ class ParallelSARDecoder(BaseDecoder):
     """

     def __init__(self,
-                 num_classes=93,  # 90 + unknown + start + padding
+                 out_channels,  # 90 + unknown + start + padding
                  enc_bi_rnn=False,
                  dec_bi_rnn=False,
                  dec_drop_rnn=0.0,
@@ -148,8 +147,6 @@ class ParallelSARDecoder(BaseDecoder):
                  pred_dropout=0.1,
                  max_text_length=30,
                  mask=True,
-                 start_idx=91,
-                 padding_idx=92,  # 92
                  pred_concat=True,
                  **kwargs):
         super().__init__()
@@ -157,7 +154,8 @@ class ParallelSARDecoder(BaseDecoder):
         self.num_classes = num_classes
         self.enc_bi_rnn = enc_bi_rnn
         self.d_k = d_k
-        self.start_idx = start_idx
+        self.start_idx = out_channels - 2
+        self.padding_idx = out_channels - 1
         self.max_seq_len = max_text_length
         self.mask = mask
         self.pred_concat = pred_concat
@@ -191,7 +189,7 @@ class ParallelSARDecoder(BaseDecoder):
         # Decoder input embedding
         self.embedding = nn.Embedding(
-            self.num_classes, encoder_rnn_out_size, padding_idx=padding_idx)
+            self.num_classes, encoder_rnn_out_size, padding_idx=self.padding_idx)

         # Prediction layer
         self.pred_dropout = nn.Dropout(pred_dropout)
@@ -330,6 +328,7 @@
 class SARHead(nn.Layer):
     def __init__(self,
+                 out_channels,
                  enc_bi_rnn=False,
                  enc_drop_rnn=0.1,
                  enc_gru=False,
@@ -351,7 +350,8 @@ class SARHead(nn.Layer):
         # decoder module
         self.decoder = ParallelSARDecoder(
+            out_channels=out_channels,
             enc_bi_rnn=enc_bi_rnn,
             dec_bi_rnn=dec_bi_rnn,
             dec_drop_rnn=dec_drop_rnn,
             dec_gru=dec_gru,
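
Note: taken together, the head-side change removes the start_idx / padding_idx keywords; both special indices are now derived from out_channels, so a caller only has to keep the dictionary size consistent and the embedding's padding_idx automatically matches the loss's ignore_index. A small sketch of the new convention, assuming the dictionary layout implied by the removed defaults (90 chars + unknown + start + padding), with the SARHead call abridged:

    # Hypothetical dictionary layout; values chosen to reproduce the old defaults.
    num_chars = 90
    out_channels = num_chars + 3      # + unknown + start + padding

    start_idx = out_channels - 2      # derived, replaces the old start_idx=91 default
    padding_idx = out_channels - 1    # derived, replaces the old padding_idx=92 default
    assert (start_idx, padding_idx) == (91, 92)

    # Callers now pass only out_channels (remaining arguments elided):
    # head = SARHead(out_channels=out_channels, enc_bi_rnn=False, dec_bi_rnn=False)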