update rec_sar_head
This commit is contained in:
parent
073fad37ba
commit
df4a2f6a7e
|
@ -9,7 +9,7 @@ from paddle import nn
|
||||||
class SARLoss(nn.Layer):
|
class SARLoss(nn.Layer):
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super(SARLoss, self).__init__()
|
super(SARLoss, self).__init__()
|
||||||
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean", ignore_index=92)
|
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean", ignore_index=96)
|
||||||
|
|
||||||
def forward(self, predicts, batch):
|
def forward(self, predicts, batch):
|
||||||
predict = predicts[:, :-1, :] # ignore last index of outputs to be in same seq_len with targets
|
predict = predicts[:, :-1, :] # ignore last index of outputs to be in same seq_len with targets
|
||||||
|
|
|
@ -118,8 +118,7 @@ class BaseDecoder(nn.Layer):
|
||||||
class ParallelSARDecoder(BaseDecoder):
|
class ParallelSARDecoder(BaseDecoder):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
num_classes (int): Output class number.
|
out_channels (int): Output class number.
|
||||||
channels (list[int]): Network layer channels.
|
|
||||||
enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
|
enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
|
||||||
dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
|
dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
|
||||||
dec_drop_rnn (float): Dropout of RNN layer in decoder.
|
dec_drop_rnn (float): Dropout of RNN layer in decoder.
|
||||||
|
@ -137,7 +136,7 @@ class ParallelSARDecoder(BaseDecoder):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_classes=93, # 90 + unknown + start + padding
|
out_channels, # 90 + unknown + start + padding
|
||||||
enc_bi_rnn=False,
|
enc_bi_rnn=False,
|
||||||
dec_bi_rnn=False,
|
dec_bi_rnn=False,
|
||||||
dec_drop_rnn=0.0,
|
dec_drop_rnn=0.0,
|
||||||
|
@ -148,8 +147,6 @@ class ParallelSARDecoder(BaseDecoder):
|
||||||
pred_dropout=0.1,
|
pred_dropout=0.1,
|
||||||
max_text_length=30,
|
max_text_length=30,
|
||||||
mask=True,
|
mask=True,
|
||||||
start_idx=91,
|
|
||||||
padding_idx=92, # 92
|
|
||||||
pred_concat=True,
|
pred_concat=True,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -157,7 +154,8 @@ class ParallelSARDecoder(BaseDecoder):
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.enc_bi_rnn = enc_bi_rnn
|
self.enc_bi_rnn = enc_bi_rnn
|
||||||
self.d_k = d_k
|
self.d_k = d_k
|
||||||
self.start_idx = start_idx
|
self.start_idx = out_channels - 2
|
||||||
|
self.padding_idx = out_channels - 1
|
||||||
self.max_seq_len = max_text_length
|
self.max_seq_len = max_text_length
|
||||||
self.mask = mask
|
self.mask = mask
|
||||||
self.pred_concat = pred_concat
|
self.pred_concat = pred_concat
|
||||||
|
@ -191,7 +189,7 @@ class ParallelSARDecoder(BaseDecoder):
|
||||||
|
|
||||||
# Decoder input embedding
|
# Decoder input embedding
|
||||||
self.embedding = nn.Embedding(
|
self.embedding = nn.Embedding(
|
||||||
self.num_classes, encoder_rnn_out_size, padding_idx=padding_idx)
|
self.num_classes, encoder_rnn_out_size, padding_idx=self.padding_idx)
|
||||||
|
|
||||||
# Prediction layer
|
# Prediction layer
|
||||||
self.pred_dropout = nn.Dropout(pred_dropout)
|
self.pred_dropout = nn.Dropout(pred_dropout)
|
||||||
|
@ -330,6 +328,7 @@ class ParallelSARDecoder(BaseDecoder):
|
||||||
|
|
||||||
class SARHead(nn.Layer):
|
class SARHead(nn.Layer):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
out_channels,
|
||||||
enc_bi_rnn=False,
|
enc_bi_rnn=False,
|
||||||
enc_drop_rnn=0.1,
|
enc_drop_rnn=0.1,
|
||||||
enc_gru=False,
|
enc_gru=False,
|
||||||
|
@ -351,6 +350,7 @@ class SARHead(nn.Layer):
|
||||||
|
|
||||||
# decoder module
|
# decoder module
|
||||||
self.decoder = ParallelSARDecoder(
|
self.decoder = ParallelSARDecoder(
|
||||||
|
out_channels=out_channels,
|
||||||
enc_bi_rnn=enc_bi_rnn,
|
enc_bi_rnn=enc_bi_rnn,
|
||||||
dec_bi_rnn=dec_bi_rnn,
|
dec_bi_rnn=dec_bi_rnn,
|
||||||
dec_drop_rnn=dec_drop_rnn,
|
dec_drop_rnn=dec_drop_rnn,
|
||||||
|
|
Loading…
Reference in New Issue