diff --git a/configs/rec/rec_mv3_tps_bilstm_att.yml b/configs/rec/rec_mv3_tps_bilstm_att.yml new file mode 100644 index 00000000..c64b2ccc --- /dev/null +++ b/configs/rec/rec_mv3_tps_bilstm_att.yml @@ -0,0 +1,102 @@ +Global: + use_gpu: true + epoch_num: 72 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec/rec_mv3_tps_bilstm_att/ + save_epoch_step: 3 + # evaluation is run every 5000 iterations after the 4000th iteration + eval_batch_step: [0, 2000] + # if pretrained_model is saved in static mode, load_static_weights must set to True + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words/ch/word_1.jpg + # for data or label process + character_dict_path: + character_type: en + max_text_length: 25 + infer_mode: False + use_space_char: False + + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + learning_rate: 0.0005 + regularizer: + name: 'L2' + factor: 0.00001 + +Architecture: + model_type: rec + algorithm: RARE + Transform: + name: TPS + num_fiducial: 20 + loc_lr: 0.1 + model_name: small + Backbone: + name: MobileNetV3 + scale: 0.5 + model_name: large + Neck: + name: SequenceEncoder + encoder_type: rnn + hidden_size: 96 + Head: + name: AttentionHead + hidden_size: 96 + + +Loss: + name: AttentionLoss + +PostProcess: + name: AttnLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDateSet + data_dir: ../training/ + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - AttnLabelEncode: # Class handling label + - RecResizeImg: + image_shape: [3, 32, 100] + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 256 + drop_last: True + num_workers: 8 + +Eval: + dataset: + name: LMDBDateSet + data_dir: ../validation/ + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - AttnLabelEncode: # Class handling label + - RecResizeImg: + image_shape: [3, 32, 100] + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 256 + num_workers: 1 diff --git a/configs/rec/rec_r34_vd_tps_bilstm_att.yml b/configs/rec/rec_r34_vd_tps_bilstm_att.yml new file mode 100644 index 00000000..f42bfdcc --- /dev/null +++ b/configs/rec/rec_r34_vd_tps_bilstm_att.yml @@ -0,0 +1,103 @@ +Global: + use_gpu: true + epoch_num: 400 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec/b3_rare_r34_none_gru/ + save_epoch_step: 3 + # evaluation is run every 5000 iterations after the 4000th iteration + eval_batch_step: [0, 2000] + # if pretrained_model is saved in static mode, load_static_weights must set to True + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words/ch/word_1.jpg + # for data or label process + character_dict_path: + character_type: en + max_text_length: 25 + infer_mode: False + use_space_char: False + + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + learning_rate: 0.0005 + regularizer: + name: 'L2' + factor: 0.00000 + +Architecture: + model_type: rec + algorithm: RARE + Transform: + name: TPS + num_fiducial: 20 + loc_lr: 0.1 + model_name: large + Backbone: + name: ResNet + layers: 34 + + Neck: + name: SequenceEncoder + encoder_type: rnn + hidden_size: 256 #96 + Head: + name: AttentionHead # AttentionHead + hidden_size: 256 # + l2_decay: 0.00001 + + +Loss: + name: AttentionLoss + +PostProcess: + name: AttnLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDateSet + data_dir: ../training/ + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - AttnLabelEncode: # Class handling label + - RecResizeImg: + image_shape: [3, 32, 100] + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 256 + drop_last: True + num_workers: 8 + +Eval: + dataset: + name: LMDBDateSet + data_dir: ../validation/ + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - AttnLabelEncode: # Class handling label + - RecResizeImg: + image_shape: [3, 32, 100] + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 256 + num_workers: 8 diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 14c1cc9c..6d9ea190 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -197,16 +197,30 @@ class AttnLabelEncode(BaseRecLabelEncode): super(AttnLabelEncode, self).__init__(max_text_length, character_dict_path, character_type, use_space_char) - self.beg_str = "sos" - self.end_str = "eos" def add_special_char(self, dict_character): - dict_character = [self.beg_str, self.end_str] + dict_character + self.beg_str = "sos" + self.end_str = "eos" + dict_character = [self.beg_str] + dict_character + [self.end_str] return dict_character - def __call__(self, text): + def __call__(self, data): + text = data['label'] text = self.encode(text) - return text + if text is None: + return None + if len(text) > self.max_text_len: + return None + data['length'] = np.array(len(text)) + text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len + - len(text) - 1) + data['label'] = np.array(text) + return data + + def get_ignored_tokens(self): + beg_idx = self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + return [beg_idx, end_idx] def get_beg_end_flag_idx(self, beg_or_end): if beg_or_end == "beg": diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index 4673d35c..94314235 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -23,11 +23,14 @@ def build_loss(config): # rec loss from .rec_ctc_loss import CTCLoss + from .rec_att_loss import AttentionLoss # cls loss from .cls_loss import ClsLoss - support_dict = ['DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss'] + support_dict = [ + 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss' + ] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/losses/rec_att_loss.py b/ppocr/losses/rec_att_loss.py new file mode 100644 index 00000000..6e2f6748 --- /dev/null +++ b/ppocr/losses/rec_att_loss.py @@ -0,0 +1,39 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn + + +class AttentionLoss(nn.Layer): + def __init__(self, **kwargs): + super(AttentionLoss, self).__init__() + self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none') + + def forward(self, predicts, batch): + targets = batch[1].astype("int64") + label_lengths = batch[2].astype('int64') + batch_size, num_steps, num_classes = predicts.shape[0], predicts.shape[ + 1], predicts.shape[2] + assert len(targets.shape) == len(list(predicts.shape)) - 1, \ + "The target's shape and inputs's shape is [N, d] and [N, num_steps]" + + inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]]) + targets = paddle.reshape(targets, [-1]) + + return {'loss': paddle.sum(self.loss_func(inputs, targets))} diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 78074709..29d0ba80 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -23,10 +23,13 @@ def build_head(config): # rec head from .rec_ctc_head import CTCHead + from .rec_att_head import AttentionHead # cls head from .cls_head import ClsHead - support_dict = ['DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead'] + support_dict = [ + 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead' + ] module_name = config.pop('name') assert module_name in support_dict, Exception('head only support {}'.format( diff --git a/ppocr/modeling/heads/rec_att_head.py b/ppocr/modeling/heads/rec_att_head.py new file mode 100644 index 00000000..d01f0e6c --- /dev/null +++ b/ppocr/modeling/heads/rec_att_head.py @@ -0,0 +1,211 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import numpy as np +from paddle.jit import to_static + + +class AttentionHead(nn.Layer): + def __init__(self, in_channels, out_channels, hidden_size, **kwargs): + super(AttentionHead, self).__init__() + self.input_size = in_channels + self.hidden_size = hidden_size + self.num_classes = out_channels + + self.attention_cell = AttentionGRUCell( + in_channels, hidden_size, out_channels, use_gru=False) + self.generator = nn.Linear(hidden_size, out_channels) + + def _char_to_onehot(self, input_char, onehot_dim): + input_ont_hot = F.one_hot(input_char, onehot_dim) + return input_ont_hot + + def forward(self, inputs, targets=None, batch_max_length=25): + batch_size = inputs.shape[0] + num_steps = batch_max_length + + hidden = paddle.zeros((batch_size, self.hidden_size)) + output_hiddens = [] + + if targets is not None: + for i in range(num_steps): + char_onehots = self._char_to_onehot( + targets[:, i], onehot_dim=self.num_classes) + (outputs, hidden), alpha = self.attention_cell(hidden, inputs, + char_onehots) + output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) + output = paddle.concat(output_hiddens, axis=1) + probs = self.generator(output) + + else: + targets = paddle.zeros(shape=[batch_size], dtype="int32") + probs = None + + for i in range(num_steps): + char_onehots = self._char_to_onehot( + targets, onehot_dim=self.num_classes) + (outputs, hidden), alpha = self.attention_cell(hidden, inputs, + char_onehots) + probs_step = self.generator(outputs) + probs = paddle.unsqueeze( + probs_step, axis=1) if probs is None else paddle.concat( + [probs, paddle.unsqueeze( + probs_step, axis=1)], axis=1) + next_input = probs_step.argmax(axis=1) + targets = next_input + + return probs + + +class AttentionGRUCell(nn.Layer): + def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False): + super(AttentionGRUCell, self).__init__() + self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False) + self.h2h = nn.Linear(hidden_size, hidden_size) + self.score = nn.Linear(hidden_size, 1, bias_attr=False) + + self.rnn = nn.GRUCell( + input_size=input_size + num_embeddings, hidden_size=hidden_size) + + self.hidden_size = hidden_size + + def forward(self, prev_hidden, batch_H, char_onehots): + + batch_H_proj = self.i2h(batch_H) + prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1) + + res = paddle.add(batch_H_proj, prev_hidden_proj) + res = paddle.tanh(res) + e = self.score(res) + + alpha = F.softmax(e, axis=1) + alpha = paddle.transpose(alpha, [0, 2, 1]) + context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1) + concat_context = paddle.concat([context, char_onehots], 1) + + cur_hidden = self.rnn(concat_context, prev_hidden) + + return cur_hidden, alpha + + +class AttentionLSTM(nn.Layer): + def __init__(self, in_channels, out_channels, hidden_size, **kwargs): + super(AttentionLSTM, self).__init__() + self.input_size = in_channels + self.hidden_size = hidden_size + self.num_classes = out_channels + + self.attention_cell = AttentionLSTMCell( + in_channels, hidden_size, out_channels, use_gru=False) + self.generator = nn.Linear(hidden_size, out_channels) + + def _char_to_onehot(self, input_char, onehot_dim): + input_ont_hot = F.one_hot(input_char, onehot_dim) + return input_ont_hot + + def forward(self, inputs, targets=None, batch_max_length=25): + batch_size = inputs.shape[0] + num_steps = batch_max_length + + hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros( + (batch_size, self.hidden_size))) + output_hiddens = [] + + if targets is not None: + for i in range(num_steps): + # one-hot vectors for a i-th char + char_onehots = self._char_to_onehot( + targets[:, i], onehot_dim=self.num_classes) + hidden, alpha = self.attention_cell(hidden, inputs, + char_onehots) + + hidden = (hidden[1][0], hidden[1][1]) + output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1)) + output = paddle.concat(output_hiddens, axis=1) + probs = self.generator(output) + + else: + targets = paddle.zeros(shape=[batch_size], dtype="int32") + probs = None + + for i in range(num_steps): + char_onehots = self._char_to_onehot( + targets, onehot_dim=self.num_classes) + hidden, alpha = self.attention_cell(hidden, inputs, + char_onehots) + probs_step = self.generator(hidden[0]) + hidden = (hidden[1][0], hidden[1][1]) + probs = paddle.unsqueeze( + probs_step, axis=1) if probs is None else paddle.concat( + [probs, paddle.unsqueeze( + probs_step, axis=1)], axis=1) + + next_input = probs_step.argmax(axis=1) + + targets = next_input + + return probs + + +class AttentionLSTMCell(nn.Layer): + def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False): + super(AttentionLSTMCell, self).__init__() + self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False) + self.h2h = nn.Linear(hidden_size, hidden_size) + self.score = nn.Linear(hidden_size, 1, bias_attr=False) + if not use_gru: + self.rnn = nn.LSTMCell( + input_size=input_size + num_embeddings, hidden_size=hidden_size) + else: + self.rnn = nn.GRUCell( + input_size=input_size + num_embeddings, hidden_size=hidden_size) + + self.hidden_size = hidden_size + + def forward(self, prev_hidden, batch_H, char_onehots): + batch_H_proj = self.i2h(batch_H) + prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1) + res = paddle.add(batch_H_proj, prev_hidden_proj) + res = paddle.tanh(res) + e = self.score(res) + + alpha = F.softmax(e, axis=1) + alpha = paddle.transpose(alpha, [0, 2, 1]) + context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1) + concat_context = paddle.concat([context, char_onehots], 1) + cur_hidden = self.rnn(concat_context, prev_hidden) + + return cur_hidden, alpha + + +if __name__ == '__main__': + paddle.disable_static() + + model = Attention(100, 200, 10) + + x = np.random.uniform(-1, 1, [2, 10, 100]).astype(np.float32) + y = np.random.randint(0, 10, [2, 21]).astype(np.int32) + + xp = paddle.to_tensor(x) + yp = paddle.to_tensor(y) + + res = model(inputs=xp, targets=yp, is_train=True, batch_max_length=20) + print("res: ", res.shape) diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index c9b42e08..2b8d00a9 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -30,7 +30,8 @@ def build_post_process(config, global_config=None): from .cls_postprocess import ClsPostProcess support_dict = [ - 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess' + 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', + 'AttnLabelDecode', 'ClsPostProcess', 'AttnLabelDecode' ] config = copy.deepcopy(config) diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 65ed4671..1ac35246 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -133,16 +133,52 @@ class AttnLabelDecode(BaseRecLabelDecode): **kwargs): super(AttnLabelDecode, self).__init__(character_dict_path, character_type, use_space_char) - self.beg_str = "sos" - self.end_str = "eos" def add_special_char(self, dict_character): - dict_character = [self.beg_str, self.end_str] + dict_character + self.beg_str = "sos" + self.end_str = "eos" + dict_character = dict_character + dict_character = [self.beg_str] + dict_character + [self.end_str] return dict_character - def __call__(self, text): + def __call__(self, preds, label=None, *args, **kwargs): + """ text = self.decode(text) - return text + if label is None: + return text + else: + label = self.decode(label, is_remove_duplicate=False) + return text, label + """ + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() + + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) + if label is None: + return text + label = self.decode(label, is_remove_duplicate=True) + return text, label + + def encoder(self, labels, labels_length): + """ + used to encoder labels readed from LMDB dataset, forexample: + [35, 25, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] encode to + 'you': [0, 35,25,31, 37, 0, ...] 'sos'you'eos' + """ + if isinstance(labels, paddle.Tensor): + labels = labels.numpy() + batch_max_length = labels.shape[ + 1] + 2 # add start token 'sos' and end token 'eos' + new_labels = np.zeros( + [labels.shape[0], batch_max_length]).astype(np.int64) + for i in range(labels.shape[0]): + new_labels[i, 1:1 + labels_length[i]] = labels[i, :labels_length[ + i]] # new_labels[i, 0] = 'sos' token + new_labels[i, labels_length[i] + 1] = len( + self.character) - 1 # add end charactor 'eos' token + return new_labels def get_ignored_tokens(self): beg_idx = self.get_beg_end_flag_idx("beg")