Merge pull request #970 from WenmuZhou/dygraph
解决crnn训练时对labels进行合并的bug
This commit is contained in:
commit
f1048e296e
|
@ -21,18 +21,6 @@ from paddle import nn
|
|||
from ppocr.modeling.heads.rec_ctc_head import get_para_bias_attr
|
||||
|
||||
|
||||
class EncoderWithReshape(nn.Layer):
    """Flatten a 4-D feature map into a sequence of per-position feature vectors.

    The two trailing axes of the input are merged into a single axis, which is
    then swapped with the channel axis: ``(B, C, H, W)`` -> ``(B, H * W, C)``.
    """

    def __init__(self, in_channels, **kwargs):
        """Record the channel count; extra keyword arguments are accepted and unused."""
        super().__init__()
        # The channel axis is only moved by forward(), never resized.
        self.out_channels = in_channels

    def forward(self, x):
        """Return ``x`` rearranged from ``(B, C, H, W)`` to ``(B, H * W, C)``."""
        # Unpacking into four names also enforces that the input is exactly 4-D.
        batch, channels, _height, _width = x.shape
        flattened = x.reshape((batch, channels, -1))
        return flattened.transpose([0, 2, 1])  # (NTC)(batch, width, channels)
|
||||
|
||||
|
||||
class Im2Seq(nn.Layer):
|
||||
def __init__(self, in_channels, **kwargs):
|
||||
super().__init__()
|
||||
|
@ -40,9 +28,8 @@ class Im2Seq(nn.Layer):
|
|||
|
||||
def forward(self, x):
    """Convert a height-1 feature map into a (batch, width, channels) sequence.

    Args:
        x: 4-D array whose ``shape`` unpacks to ``(B, C, H, W)``; ``H`` must be 1.

    Returns:
        ``x`` rearranged to ``(B, W, C)``.

    Raises:
        AssertionError: if ``H`` is not 1.
    """
    B, C, H, W = x.shape
    assert H == 1
    # NOTE(review): this listing previously chained two layout conversions
    # (transpose + reshape((-1, C)), then reshape((B, -1, W)) + transpose).
    # Running both scrambles channel and width data, so only the conversion
    # matching the documented (batch, width, channels) output is kept.
    x = x.reshape((B, -1, W))
    x = x.transpose((0, 2, 1))  # (NTC)(batch, width, channels)
    return x
|
||||
|
||||
|
||||
|
@ -50,49 +37,10 @@ class EncoderWithRNN(nn.Layer):
|
|||
def __init__(self, in_channels, hidden_size):
    """Build a two-layer bidirectional LSTM encoder.

    Args:
        in_channels: feature size of each input step, passed to ``nn.LSTM``.
        hidden_size: hidden size passed to ``nn.LSTM``.
    """
    super(EncoderWithRNN, self).__init__()
    # A bidirectional LSTM emits features for both directions, so the
    # advertised output width is twice the hidden size.
    self.out_channels = hidden_size * 2
    self.lstm = nn.LSTM(
        in_channels, hidden_size, direction='bidirectional', num_layers=2)
|
||||
|
||||
def forward(self, x):
    """Run ``x`` through ``self.lstm`` and return only its output sequence.

    Args:
        x: input passed unchanged to ``self.lstm``.

    Returns:
        The first element of the pair returned by ``self.lstm``.
    """
    # self.lstm returns a pair; the second element is not needed here.
    outputs, _ = self.lstm(x)
    return outputs
|
||||
|
||||
|
@ -118,13 +66,13 @@ class EncoderWithFC(nn.Layer):
|
|||
class SequenceEncoder(nn.Layer):
|
||||
def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs):
|
||||
super(SequenceEncoder, self).__init__()
|
||||
self.encoder_reshape = EncoderWithReshape(in_channels)
|
||||
self.encoder_reshape = Im2Seq(in_channels)
|
||||
self.out_channels = self.encoder_reshape.out_channels
|
||||
if encoder_type == 'reshape':
|
||||
self.only_reshape = True
|
||||
else:
|
||||
support_encoder_dict = {
|
||||
'reshape': EncoderWithReshape,
|
||||
'reshape': Im2Seq,
|
||||
'fc': EncoderWithFC,
|
||||
'rnn': EncoderWithRNN
|
||||
}
|
||||
|
|
|
@ -70,6 +70,7 @@ class BaseRecLabelDecode(object):
|
|||
if text_index[batch_idx][idx] in ignored_tokens:
|
||||
continue
|
||||
if is_remove_duplicate:
|
||||
# only for predict
|
||||
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
|
||||
batch_idx][idx]:
|
||||
continue
|
||||
|
@ -107,7 +108,7 @@ class CTCLabelDecode(BaseRecLabelDecode):
|
|||
text = self.decode(preds_idx, preds_prob)
|
||||
if label is None:
|
||||
return text
|
||||
label = self.decode(label)
|
||||
label = self.decode(label, is_remove_duplicate=False)
|
||||
return text, label
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
|
|
Loading…
Reference in New Issue