From b6f0a90366e1e4eb78317523a81626e5d40beff3 Mon Sep 17 00:00:00 2001 From: Topdu <784990967@qq.com> Date: Mon, 16 Aug 2021 11:33:15 +0000 Subject: [PATCH 01/18] add rec_nrtr --- configs/rec/rec_mtb_nrtr.yml | 100 +++ ppocr/data/imaug/label_ops.py | 28 + ppocr/modeling/heads/__init__.py | 5 +- ppocr/modeling/heads/multiheadAttention.py | 365 +++++++++ ppocr/modeling/heads/rec_nrtr_optim_head.py | 779 ++++++++++++++++++++ ppocr/postprocess/rec_postprocess.py | 63 ++ tools/eval.py | 2 - 7 files changed, 1338 insertions(+), 4 deletions(-) create mode 100644 configs/rec/rec_mtb_nrtr.yml create mode 100755 ppocr/modeling/heads/multiheadAttention.py create mode 100644 ppocr/modeling/heads/rec_nrtr_optim_head.py diff --git a/configs/rec/rec_mtb_nrtr.yml b/configs/rec/rec_mtb_nrtr.yml new file mode 100644 index 00000000..d5d36cfa --- /dev/null +++ b/configs/rec/rec_mtb_nrtr.yml @@ -0,0 +1,100 @@ +Global: + use_gpu: True + epoch_num: 21 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec/nrtr_final/ + save_epoch_step: 1 + # evaluation is run every 2000 iterations + eval_batch_step: [0, 2000] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words_en/word_10.png + # for data or label process + character_dict_path: + character_type: EN_symbol + max_text_length: 25 + infer_mode: False + use_space_char: True + save_res_path: ./output/rec/predicts_nrtr.txt + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.99 + clip_norm: 5.0 + lr: + name: Cosine + learning_rate: 0.0005 + warmup_epoch: 2 + regularizer: + name: 'L2' + factor: 0. + +Architecture: + model_type: rec + algorithm: NRTR + in_channels: 1 + Transform: + Backbone: + name: MTB + cnn_num: 2 + Head: + name: TransformerOptim + d_model: 512 + num_encoder_layers: 6 + beam_size: -1 # When Beam size is greater than 0, it means to use beam search when evaluation. 
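+    # Note (illustrative): beam_size <= 0 keeps greedy decoding at evaluation
+    # time; any positive value switches evaluation to beam search with that
+    # beam width.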
+ + +Loss: + name: NRTRLoss + smoothing: True + +PostProcess: + name: NRTRLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDataSet + data_dir: /paddle/data/ocr_data/training/ + transforms: + - NRTRDecodeImage: # load image + img_mode: BGR + channel_first: False + - NRTRLabelEncode: # Class handling label + - PILResize: + image_shape: [100, 32] + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 512 + drop_last: True + num_workers: 8 + +Eval: + dataset: + name: LMDBDataSet + data_dir: /paddle/data/ocr_data/evaluation/ + transforms: + - NRTRDecodeImage: # load image + img_mode: BGR + channel_first: False + - NRTRLabelEncode: # Class handling label + - PILResize: + image_shape: [100, 32] + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 256 + num_workers: 1 + use_shared_memory: False diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index bba3209f..39ff8930 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -159,6 +159,34 @@ class BaseRecLabelEncode(object): return text_list +class NRTRLabelEncode(BaseRecLabelEncode): + """ Convert between text-label and text-index """ + + def __init__(self, + max_text_length, + character_dict_path=None, + character_type='EN_symbol', + use_space_char=False, + **kwargs): + + super(NRTRLabelEncode, + self).__init__(max_text_length, character_dict_path, + character_type, use_space_char) + def __call__(self, data): + text = data['label'] + text = self.encode(text) + if text is None: + return None + data['length'] = np.array(len(text)) + text.insert(0, 2) + text.append(3) + text = text + [0] * (self.max_text_len - len(text)) + data['label'] = np.array(text) + return data + def add_special_char(self, dict_character): + dict_character = ['blank','','',''] + dict_character + return dict_character + class CTCLabelEncode(BaseRecLabelEncode): """ Convert between text-label and text-index """ diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 4852c7f2..10acd0fa 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -26,12 +26,13 @@ def build_head(config): from .rec_ctc_head import CTCHead from .rec_att_head import AttentionHead from .rec_srn_head import SRNHead - + from .rec_nrtr_optim_head import TransformerOptim + # cls head from .cls_head import ClsHead support_dict = [ 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead', - 'SRNHead', 'PGHead'] + 'SRNHead', 'PGHead', 'TransformerOptim'] module_name = config.pop('name') diff --git a/ppocr/modeling/heads/multiheadAttention.py b/ppocr/modeling/heads/multiheadAttention.py new file mode 100755 index 00000000..f18e9957 --- /dev/null +++ b/ppocr/modeling/heads/multiheadAttention.py @@ -0,0 +1,365 @@ +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle.nn import Linear +from paddle.nn.initializer import XavierUniform as xavier_uniform_ +from paddle.nn.initializer import Constant as constant_ +from paddle.nn.initializer import XavierNormal as xavier_normal_ + +zeros_ = constant_(value=0.) +ones_ = constant_(value=1.) + +class MultiheadAttention(nn.Layer): + r"""Allows the model to jointly attend to information + from different representation subspaces. 
+ See reference: Attention Is All You Need + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) + + Args: + embed_dim: total dimension of the model + num_heads: parallel attention layers, or heads + + Examples:: + + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + """ + + def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False): + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias) + + if add_bias_kv: + self.bias_k = self.create_parameter( + shape=(1, 1, embed_dim), default_initializer=zeros_) + self.add_parameter("bias_k", self.bias_k) + self.bias_v = self.create_parameter( + shape=(1, 1, embed_dim), default_initializer=zeros_) + self.add_parameter("bias_v", self.bias_v) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self._reset_parameters() + + self.conv1 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) + self.conv2 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim * 2, kernel_size=(1, 1)) + self.conv3 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim * 3, kernel_size=(1, 1)) + + def _reset_parameters(self): + + + xavier_uniform_(self.out_proj.weight) + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + def forward(self, query, key, value, key_padding_mask=None, incremental_state=None, + need_weights=True, static_kv=False, attn_mask=None, qkv_ = [False,False,False]): + """ + Inputs of forward function + query: [target length, batch size, embed dim] + key: [sequence length, batch size, embed dim] + value: [sequence length, batch size, embed dim] + key_padding_mask: if True, mask padding based on batch size + incremental_state: if provided, previous time steps are cashed + need_weights: output attn_output_weights + static_kv: key and value are static + + Outputs of forward function + attn_output: [target length, batch size, embed dim] + attn_output_weights: [batch size, target length, sequence length] + """ + qkv_same = qkv_[0] + kv_same = qkv_[1] + + tgt_len, bsz, embed_dim = query.shape + assert embed_dim == self.embed_dim + assert list(query.shape) == [tgt_len, bsz, embed_dim] + assert key.shape == value.shape + + if qkv_same: + # self-attention + q, k, v = self._in_proj_qkv(query) + elif kv_same: + # encoder-decoder attention + q = self._in_proj_q(query) + if key is None: + assert value is None + k = v = None + else: + k, v = self._in_proj_kv(key) + else: + q = self._in_proj_q(query) + k = self._in_proj_k(key) + v = self._in_proj_v(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + self.bias_k = paddle.concat([self.bias_k for i in range(bsz)],axis=1) + self.bias_v = paddle.concat([self.bias_v for i in range(bsz)],axis=1) + k = paddle.concat([k, self.bias_k]) + v = paddle.concat([v, self.bias_v]) + if attn_mask is not None: + attn_mask = paddle.concat([attn_mask, paddle.zeros([attn_mask.shape[0], 1],dtype=attn_mask.dtype)], axis=1) 
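+            # Note (illustrative): because bias_k / bias_v append one extra
+            # key/value position per batch, attn_mask (above) and
+            # key_padding_mask (below) are each widened by one column to stay
+            # aligned with the lengthened key sequence.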
+ if key_padding_mask is not None: + key_padding_mask = paddle.concat( + [key_padding_mask,paddle.zeros([key_padding_mask.shape[0], 1],dtype=key_padding_mask.dtype)], axis=1) + + q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) + if k is not None: + k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) + if v is not None: + v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) + + + + src_len = k.shape[1] + + if key_padding_mask is not None: + assert key_padding_mask.shape[0] == bsz + assert key_padding_mask.shape[1] == src_len + + if self.add_zero_attn: + src_len += 1 + k = paddle.concat([k, paddle.zeros((k.shape[0], 1) + k.shape[2:],dtype=k.dtype)], axis=1) + v = paddle.concat([v, paddle.zeros((v.shape[0], 1) + v.shape[2:],dtype=v.dtype)], axis=1) + if attn_mask is not None: + attn_mask = paddle.concat([attn_mask, paddle.zeros([attn_mask.shape[0], 1],dtype=attn_mask.dtype)], axis=1) + if key_padding_mask is not None: + key_padding_mask = paddle.concat( + [key_padding_mask, paddle.zeros([key_padding_mask.shape[0], 1],dtype=key_padding_mask.dtype)], axis=1) + attn_output_weights = paddle.bmm(q, k.transpose([0,2,1])) + assert list(attn_output_weights.shape) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_output_weights += attn_mask + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len]) + key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32') + y = paddle.full(shape=key.shape, dtype='float32', fill_value='-inf') + y = paddle.where(key==0.,key, y) + attn_output_weights += y + attn_output_weights = attn_output_weights.reshape([bsz*self.num_heads, tgt_len, src_len]) + + attn_output_weights = F.softmax( + attn_output_weights.astype('float32'), axis=-1, + dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16 else attn_output_weights.dtype) + attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training) + + attn_output = paddle.bmm(attn_output_weights, v) + assert list(attn_output.shape) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn_output = attn_output.transpose([1, 0,2]).reshape([tgt_len, bsz, embed_dim]) + attn_output = self.out_proj(attn_output) + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len]) + attn_output_weights = attn_output_weights.sum(axis=1) / self.num_heads + else: + attn_output_weights = None + + return attn_output, attn_output_weights + + def _in_proj_qkv(self, query): + query = query.transpose([1, 2, 0]) + query = paddle.unsqueeze(query, axis=2) + res = self.conv3(query) + res = paddle.squeeze(res, axis=2) + res = res.transpose([2, 0, 1]) + return res.chunk(3, axis=-1) + + def _in_proj_kv(self, key): + key = key.transpose([1, 2, 0]) + key = paddle.unsqueeze(key, axis=2) + res = self.conv2(key) + res = paddle.squeeze(res, axis=2) + res = res.transpose([2, 0, 1]) + return res.chunk(2, axis=-1) + + def _in_proj_q(self, query): + query = query.transpose([1, 2, 0]) + query = paddle.unsqueeze(query, axis=2) + res = self.conv1(query) + res = paddle.squeeze(res, axis=2) + res = res.transpose([2, 0, 1]) + return res + + def _in_proj_k(self, key): + + key = key.transpose([1, 2, 0]) + key = paddle.unsqueeze(key, axis=2) + res = self.conv1(key) + res = paddle.squeeze(res, axis=2) + res = res.transpose([2, 
0, 1]) + return res + + def _in_proj_v(self, value): + + value = value.transpose([1,2,0])#(1, 2, 0) + value = paddle.unsqueeze(value, axis=2) + res = self.conv1(value) + res = paddle.squeeze(res, axis=2) + res = res.transpose([2, 0, 1]) + return res + + + +class MultiheadAttentionOptim(nn.Layer): + r"""Allows the model to jointly attend to information + from different representation subspaces. + See reference: Attention Is All You Need + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) + + Args: + embed_dim: total dimension of the model + num_heads: parallel attention layers, or heads + + Examples:: + + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + """ + + def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False): + super(MultiheadAttentionOptim, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias) + + self._reset_parameters() + + self.conv1 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) + self.conv2 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) + self.conv3 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) + + def _reset_parameters(self): + + + xavier_uniform_(self.out_proj.weight) + + + def forward(self, query, key, value, key_padding_mask=None, incremental_state=None, + need_weights=True, static_kv=False, attn_mask=None): + """ + Inputs of forward function + query: [target length, batch size, embed dim] + key: [sequence length, batch size, embed dim] + value: [sequence length, batch size, embed dim] + key_padding_mask: if True, mask padding based on batch size + incremental_state: if provided, previous time steps are cashed + need_weights: output attn_output_weights + static_kv: key and value are static + + Outputs of forward function + attn_output: [target length, batch size, embed dim] + attn_output_weights: [batch size, target length, sequence length] + """ + + + tgt_len, bsz, embed_dim = query.shape + assert embed_dim == self.embed_dim + assert list(query.shape) == [tgt_len, bsz, embed_dim] + assert key.shape == value.shape + + q = self._in_proj_q(query) + k = self._in_proj_k(key) + v = self._in_proj_v(value) + q *= self.scaling + + + q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) + k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) + v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) + + + src_len = k.shape[1] + + if key_padding_mask is not None: + assert key_padding_mask.shape[0] == bsz + assert key_padding_mask.shape[1] == src_len + + + attn_output_weights = paddle.bmm(q, k.transpose([0,2,1])) + assert list(attn_output_weights.shape) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_output_weights += attn_mask + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len]) + key = 
key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32') + + y = paddle.full(shape=key.shape, dtype='float32', fill_value='-inf') + + y = paddle.where(key==0.,key, y) + + attn_output_weights += y + attn_output_weights = attn_output_weights.reshape([bsz*self.num_heads, tgt_len, src_len]) + + attn_output_weights = F.softmax( + attn_output_weights.astype('float32'), axis=-1, + dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16 else attn_output_weights.dtype) + attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training) + + attn_output = paddle.bmm(attn_output_weights, v) + assert list(attn_output.shape) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn_output = attn_output.transpose([1, 0,2]).reshape([tgt_len, bsz, embed_dim]) + attn_output = self.out_proj(attn_output) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len]) + attn_output_weights = attn_output_weights.sum(axis=1) / self.num_heads + else: + attn_output_weights = None + + return attn_output, attn_output_weights + + + def _in_proj_q(self, query): + query = query.transpose([1, 2, 0]) + query = paddle.unsqueeze(query, axis=2) + res = self.conv1(query) + res = paddle.squeeze(res, axis=2) + res = res.transpose([2, 0, 1]) + return res + + def _in_proj_k(self, key): + + key = key.transpose([1, 2, 0]) + key = paddle.unsqueeze(key, axis=2) + res = self.conv2(key) + res = paddle.squeeze(res, axis=2) + res = res.transpose([2, 0, 1]) + return res + + def _in_proj_v(self, value): + + value = value.transpose([1,2,0])#(1, 2, 0) + value = paddle.unsqueeze(value, axis=2) + res = self.conv3(value) + res = paddle.squeeze(res, axis=2) + res = res.transpose([2, 0, 1]) + return res \ No newline at end of file diff --git a/ppocr/modeling/heads/rec_nrtr_optim_head.py b/ppocr/modeling/heads/rec_nrtr_optim_head.py new file mode 100644 index 00000000..b9a5100a --- /dev/null +++ b/ppocr/modeling/heads/rec_nrtr_optim_head.py @@ -0,0 +1,779 @@ +import math +import paddle +import copy +from paddle import nn +import paddle.nn.functional as F +from paddle.nn import LayerList +from paddle.nn.initializer import XavierNormal as xavier_uniform_ +from paddle.nn import Dropout, Linear, LayerNorm, Conv2D +import numpy as np +from ppocr.modeling.heads.multiheadAttention import MultiheadAttentionOptim +from paddle.nn.initializer import Constant as constant_ +from paddle.nn.initializer import XavierNormal as xavier_normal_ + +zeros_ = constant_(value=0.) +ones_ = constant_(value=1.) + +class TransformerOptim(nn.Layer): + r"""A transformer model. User is able to modify the attributes as needed. The architechture + is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer, + Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and + Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information + Processing Systems, pages 6000-6010. + + Args: + d_model: the number of expected features in the encoder/decoder inputs (default=512). + nhead: the number of heads in the multiheadattention models (default=8). + num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6). + num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + custom_encoder: custom encoder (default=None). 
+ custom_decoder: custom decoder (default=None). + + Examples:: + >>> transformer_model = nn.Transformer(src_vocab, tgt_vocab) + >>> transformer_model = nn.Transformer(src_vocab, tgt_vocab, nhead=16, num_encoder_layers=12) + """ + + def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, beam_size=0, + num_decoder_layers=6, dim_feedforward=1024, attention_dropout_rate=0.0, residual_dropout_rate=0.1, + custom_encoder=None, custom_decoder=None,in_channels=0,out_channels=0,dst_vocab_size=99,scale_embedding=True): + super(TransformerOptim, self).__init__() + self.embedding = Embeddings( + d_model=d_model, + vocab=dst_vocab_size, + padding_idx=0, + scale_embedding=scale_embedding + ) + self.positional_encoding = PositionalEncoding( + dropout=residual_dropout_rate, + dim=d_model, + ) + if custom_encoder is not None: + self.encoder = custom_encoder + else: + if num_encoder_layers > 0 : + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, attention_dropout_rate, residual_dropout_rate) + + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers) + else: + self.encoder = None + + if custom_decoder is not None: + self.decoder = custom_decoder + else: + decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, attention_dropout_rate, residual_dropout_rate) + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers) + + self._reset_parameters() + self.beam_size = beam_size + self.d_model = d_model + self.nhead = nhead + self.tgt_word_prj = nn.Linear(d_model, dst_vocab_size, bias_attr=False) + w0 = np.random.normal(0.0, d_model**-0.5,(d_model, dst_vocab_size)).astype(np.float32) + self.tgt_word_prj.weight.set_value(w0) + self.apply(self._init_weights) + + + def _init_weights(self, m): + + if isinstance(m, nn.Conv2D): + xavier_normal_(m.weight) + if m.bias is not None: + zeros_(m.bias) + + def forward_train(self,src,tgt): + tgt = tgt[:, :-1] + + + + tgt_key_padding_mask = self.generate_padding_mask(tgt) + tgt = self.embedding(tgt).transpose([1, 0, 2]) + tgt = self.positional_encoding(tgt) + tgt_mask = self.generate_square_subsequent_mask(tgt.shape[0]) + + if self.encoder is not None : + src = self.positional_encoding(src.transpose([1, 0, 2])) + memory = self.encoder(src) + else: + memory = src.squeeze(2).transpose([2, 0, 1]) + output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=None, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=None) + output = output.transpose([1, 0, 2]) + logit = self.tgt_word_prj(output) + return logit + + def forward(self, src, tgt=None): + r"""Take in and process masked source/target sequences. + + Args: + src: the sequence to the encoder (required). + tgt: the sequence to the decoder (required). + src_mask: the additive mask for the src sequence (optional). + tgt_mask: the additive mask for the tgt sequence (optional). + memory_mask: the additive mask for the encoder output (optional). + src_key_padding_mask: the ByteTensor mask for src keys per batch (optional). + tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional). + memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional). + + Shape: + - src: :math:`(S, N, E)`. + - tgt: :math:`(T, N, E)`. + - src_mask: :math:`(S, S)`. + - tgt_mask: :math:`(T, T)`. + - memory_mask: :math:`(T, S)`. + - src_key_padding_mask: :math:`(N, S)`. + - tgt_key_padding_mask: :math:`(N, T)`. + - memory_key_padding_mask: :math:`(N, S)`. 
+ + Note: [src/tgt/memory]_mask should be filled with + float('-inf') for the masked positions and float(0.0) else. These masks + ensure that predictions for position i depend only on the unmasked positions + j and are applied identically for each sequence in a batch. + [src/tgt/memory]_key_padding_mask should be a ByteTensor where True values are positions + that should be masked with float('-inf') and False values will be unchanged. + This mask ensures that no information will be taken from position i if + it is masked, and has a separate mask for each sequence in a batch. + + - output: :math:`(T, N, E)`. + + Note: Due to the multi-head attention architecture in the transformer model, + the output sequence length of a transformer is same as the input sequence + (i.e. target) length of the decode. + + where S is the source sequence length, T is the target sequence length, N is the + batch size, E is the feature number + + Examples: + >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask) + """ + if tgt is not None: + return self.forward_train(src, tgt) + else: + if self.beam_size > 0 : + return self.forward_beam(src) + else: + return self.forward_test(src) + + def forward_test(self, src): + bs = src.shape[0] + if self.encoder is not None : + src = self.positional_encoding(src.transpose([1, 0, 2])) + memory = self.encoder(src) + else: + memory = src.squeeze(2).transpose([2, 0, 1]) + dec_seq = paddle.full((bs,1), 2, dtype=paddle.int64) + for len_dec_seq in range(1, 25): + src_enc = memory.clone() + tgt_key_padding_mask = self.generate_padding_mask(dec_seq) + dec_seq_embed = self.embedding(dec_seq).transpose([1, 0, 2]) + dec_seq_embed = self.positional_encoding(dec_seq_embed) + tgt_mask = self.generate_square_subsequent_mask(dec_seq_embed.shape[0]) + output = self.decoder(dec_seq_embed, src_enc, tgt_mask=tgt_mask, memory_mask=None, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=None) + dec_output = output.transpose([1, 0, 2]) + + dec_output = dec_output[:, -1, :] # Pick the last step: (bh * bm) * d_h + word_prob = F.log_softmax(self.tgt_word_prj(dec_output), axis=1) + word_prob = word_prob.reshape([1, bs, -1]) + preds_idx = word_prob.argmax(axis=2) + + if paddle.equal_all(preds_idx[-1],paddle.full(preds_idx[-1].shape,3,dtype='int64')): + break + + preds_prob = word_prob.max(axis=2) + dec_seq = paddle.concat([dec_seq,preds_idx.reshape([-1,1])],axis=1) + + return dec_seq + + def forward_beam(self,images): + + ''' Translation work in one batch ''' + + def get_inst_idx_to_tensor_position_map(inst_idx_list): + ''' Indicate the position of an instance in a tensor. ''' + return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)} + + def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm): + ''' Collect tensor parts associated to active instances. ''' + + _, *d_hs = beamed_tensor.shape + n_curr_active_inst = len(curr_active_inst_idx) + new_shape = (n_curr_active_inst * n_bm, *d_hs) + + beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1])#contiguous() + beamed_tensor = beamed_tensor.index_select(paddle.to_tensor(curr_active_inst_idx),axis=0) + beamed_tensor = beamed_tensor.reshape([*new_shape]) + + return beamed_tensor + + + def collate_active_info( + src_enc, inst_idx_to_position_map, active_inst_idx_list): + # Sentences which are still active are collected, + # so the decoder will not run on completed sentences. 
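+            # Note (illustrative): inst_idx_to_position_map maps an original
+            # batch index to its current row in the pruned beam tensors, e.g.
+            # {0: 0, 2: 1} after instance 1 finishes means instance 2's
+            # encoder states now live in row 1 of active_src_enc.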
+ + n_prev_active_inst = len(inst_idx_to_position_map) + active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list] + active_inst_idx = paddle.to_tensor(active_inst_idx, dtype='int64') + active_src_enc = collect_active_part(src_enc.transpose([1, 0, 2]), active_inst_idx, n_prev_active_inst, n_bm).transpose([1, 0, 2]) + active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) + return active_src_enc, active_inst_idx_to_position_map + + def beam_decode_step( + inst_dec_beams, len_dec_seq, enc_output, inst_idx_to_position_map, n_bm, memory_key_padding_mask): + ''' Decode and update beam status, and then return active beam idx ''' + + def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq): + dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done] + dec_partial_seq = paddle.stack(dec_partial_seq) + + dec_partial_seq = dec_partial_seq.reshape([-1, len_dec_seq]) + return dec_partial_seq + + def prepare_beam_memory_key_padding_mask(inst_dec_beams, memory_key_padding_mask, n_bm): + keep = [] + for idx in (memory_key_padding_mask): + if not inst_dec_beams[idx].done: + keep.append(idx) + memory_key_padding_mask = memory_key_padding_mask[paddle.to_tensor(keep)] + len_s = memory_key_padding_mask.shape[-1] + n_inst = memory_key_padding_mask.shape[0] + memory_key_padding_mask = paddle.concat([memory_key_padding_mask for i in range(n_bm)],axis=1) + memory_key_padding_mask = memory_key_padding_mask.reshape([n_inst * n_bm, len_s])#repeat(1, n_bm) + return memory_key_padding_mask + + def predict_word(dec_seq, enc_output, n_active_inst, n_bm, memory_key_padding_mask): + tgt_key_padding_mask = self.generate_padding_mask(dec_seq) + dec_seq = self.embedding(dec_seq).transpose([1, 0, 2]) + dec_seq = self.positional_encoding(dec_seq) + tgt_mask = self.generate_square_subsequent_mask(dec_seq.shape[0]) + dec_output = self.decoder( + dec_seq, enc_output, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ).transpose([1, 0, 2]) + dec_output = dec_output[:, -1, :] # Pick the last step: (bh * bm) * d_h + word_prob = F.log_softmax(self.tgt_word_prj(dec_output), axis=1) + word_prob = word_prob.reshape([n_active_inst, n_bm, -1]) + return word_prob + + def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map): + active_inst_idx_list = [] + for inst_idx, inst_position in inst_idx_to_position_map.items(): + is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position]) + if not is_inst_complete: + active_inst_idx_list += [inst_idx] + + return active_inst_idx_list + + n_active_inst = len(inst_idx_to_position_map) + dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq) + memory_key_padding_mask = None + word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm, memory_key_padding_mask) + # Update the beam with predicted word prob information and collect incomplete instances + active_inst_idx_list = collect_active_inst_idx_list( + inst_dec_beams, word_prob, inst_idx_to_position_map) + return active_inst_idx_list + + def collect_hypothesis_and_scores(inst_dec_beams, n_best): + all_hyp, all_scores = [], [] + for inst_idx in range(len(inst_dec_beams)): + scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores() + all_scores += [scores[:n_best]] + hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]] + all_hyp += [hyps] + return all_hyp, all_scores + + with paddle.no_grad(): + #-- Encode + + if self.encoder is not 
None:
+                src = self.positional_encoding(images.transpose([1, 0, 2]))
+                src_enc = self.encoder(src).transpose([1, 0, 2])
+            else:
+                src_enc = images.squeeze(2).transpose([0, 2, 1])
+
+            #-- Repeat data for beam search
+            n_bm = self.beam_size
+            n_inst, len_s, d_h = src_enc.shape
+            src_enc = paddle.concat([src_enc for i in range(n_bm)], axis=1)
+            src_enc = src_enc.reshape([n_inst * n_bm, len_s, d_h]).transpose([1, 0, 2])  # repeat(1, n_bm, 1)
+            #-- Prepare beams
+            inst_dec_beams = [Beam(n_bm) for _ in range(n_inst)]
+
+            #-- Bookkeeping for active or not
+            active_inst_idx_list = list(range(n_inst))
+            inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)
+            #-- Decode
+            for len_dec_seq in range(1, 25):
+                src_enc_copy = src_enc.clone()
+                active_inst_idx_list = beam_decode_step(
+                    inst_dec_beams, len_dec_seq, src_enc_copy, inst_idx_to_position_map, n_bm, None)
+                if not active_inst_idx_list:
+                    break  # all instances have finished their path to <EOS>
+                src_enc, inst_idx_to_position_map = collate_active_info(
+                    src_enc_copy, inst_idx_to_position_map, active_inst_idx_list)
+        batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, 1)
+        result_hyp = []
+        for bs_hyp in batch_hyp:
+            bs_hyp_pad = bs_hyp[0] + [3] * (25 - len(bs_hyp[0]))
+            result_hyp.append(bs_hyp_pad)
+        return paddle.to_tensor(np.array(result_hyp), dtype=paddle.int64)
+
+    def generate_square_subsequent_mask(self, sz):
+        r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
+        Unmasked positions are filled with float(0.0).
+        """
+        mask = paddle.zeros([sz, sz], dtype='float32')
+        mask_inf = paddle.triu(paddle.full(shape=[sz, sz], dtype='float32', fill_value='-inf'), diagonal=1)
+        mask = mask + mask_inf
+        return mask
+
+    def generate_padding_mask(self, x):
+        padding_mask = x.equal(paddle.to_tensor(0, dtype=x.dtype))
+        return padding_mask
+
+    def _reset_parameters(self):
+        r"""Initiate parameters in the transformer model."""
+
+        for p in self.parameters():
+            if p.dim() > 1:
+                xavier_uniform_(p)
+
+
+class TransformerEncoder(nn.Layer):
+    r"""TransformerEncoder is a stack of N encoder layers
+
+    Args:
+        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
+        num_layers: the number of sub-encoder-layers in the encoder (required).
+        norm: the layer normalization component (optional).
+
+    Examples::
+        >>> encoder_layer = nn.TransformerEncoderLayer(d_model, nhead)
+        >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
+    """
+
+    def __init__(self, encoder_layer, num_layers):
+        super(TransformerEncoder, self).__init__()
+        self.layers = _get_clones(encoder_layer, num_layers)
+        self.num_layers = num_layers
+
+    def forward(self, src):
+        r"""Pass the input through the encoder layers in turn.
+
+        Args:
+            src: the sequence to the encoder (required).
+            mask: the mask for the src sequence (optional).
+            src_key_padding_mask: the mask for the src keys per batch (optional).
+
+        Shape:
+            see the docs in Transformer class.
+        """
+        output = src
+
+        for i in range(self.num_layers):
+            output = self.layers[i](output, src_mask=None,
+                                    src_key_padding_mask=None)
+
+        return output
+
+
+class TransformerDecoder(nn.Layer):
+    r"""TransformerDecoder is a stack of N decoder layers
+
+    Args:
+        decoder_layer: an instance of the TransformerDecoderLayer() class (required).
+        num_layers: the number of sub-decoder-layers in the decoder (required).
+        norm: the layer normalization component (optional).
+
+    Examples::
+        >>> decoder_layer = nn.TransformerDecoderLayer(d_model, nhead)
+        >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
+    """
+
+    def __init__(self, decoder_layer, num_layers):
+        super(TransformerDecoder, self).__init__()
+        self.layers = _get_clones(decoder_layer, num_layers)
+        self.num_layers = num_layers
+
+    def forward(self, tgt, memory, tgt_mask=None,
+                memory_mask=None, tgt_key_padding_mask=None,
+                memory_key_padding_mask=None):
+        r"""Pass the inputs (and mask) through the decoder layer in turn.
+
+        Args:
+            tgt: the sequence to the decoder (required).
+            memory: the sequence from the last layer of the encoder (required).
+            tgt_mask: the mask for the tgt sequence (optional).
+            memory_mask: the mask for the memory sequence (optional).
+            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
+            memory_key_padding_mask: the mask for the memory keys per batch (optional).
+
+        Shape:
+            see the docs in Transformer class.
+        """
+        output = tgt
+        for i in range(self.num_layers):
+            output = self.layers[i](output, memory, tgt_mask=tgt_mask,
+                                    memory_mask=memory_mask,
+                                    tgt_key_padding_mask=tgt_key_padding_mask,
+                                    memory_key_padding_mask=memory_key_padding_mask)
+
+        return output
+
+
+class TransformerEncoderLayer(nn.Layer):
+    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
+    This standard encoder layer is based on the paper "Attention Is All You Need".
+    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
+    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
+    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
+    in a different way during application.
+
+    Args:
+        d_model: the number of expected features in the input (required).
+        nhead: the number of heads in the multiheadattention models (required).
+        dim_feedforward: the dimension of the feedforward network model (default=2048).
+        dropout: the dropout value (default=0.1).
+
+    Examples::
+        >>> encoder_layer = nn.TransformerEncoderLayer(d_model, nhead)
+    """
+
+    def __init__(self, d_model, nhead, dim_feedforward=2048, attention_dropout_rate=0.0, residual_dropout_rate=0.1):
+        super(TransformerEncoderLayer, self).__init__()
+        self.self_attn = MultiheadAttentionOptim(d_model, nhead, dropout=attention_dropout_rate)
+
+        self.conv1 = Conv2D(in_channels=d_model, out_channels=dim_feedforward, kernel_size=(1, 1))
+        self.conv2 = Conv2D(in_channels=dim_feedforward, out_channels=d_model, kernel_size=(1, 1))
+
+        self.norm1 = LayerNorm(d_model)
+        self.norm2 = LayerNorm(d_model)
+        self.dropout1 = Dropout(residual_dropout_rate)
+        self.dropout2 = Dropout(residual_dropout_rate)
+
+    def forward(self, src, src_mask=None, src_key_padding_mask=None):
+        r"""Pass the input through the encoder layer.
+
+        Args:
+            src: the sequence to the encoder layer (required).
+            src_mask: the mask for the src sequence (optional).
+            src_key_padding_mask: the mask for the src keys per batch (optional).
+
+        Shape:
+            see the docs in Transformer class.
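+
+        Note (illustrative): the feed-forward block below applies two 1x1
+        convolutions over a [batch, channel, 1, length] view of the input,
+        which is equivalent to the position-wise Linear layers of the
+        original Transformer. A usage sketch, assuming src is a [S, N, 512]
+        tensor::
+
+            >>> layer = TransformerEncoderLayer(512, 8)
+            >>> out = layer(src)  # out keeps the [S, N, 512] shape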
+        """
+        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
+                              key_padding_mask=src_key_padding_mask)[0]
+        src = src + self.dropout1(src2)
+        src = self.norm1(src)
+
+        src = src.transpose([1, 2, 0])
+        src = paddle.unsqueeze(src, 2)
+        src2 = self.conv2(F.relu(self.conv1(src)))
+        src2 = paddle.squeeze(src2, 2)
+        src2 = src2.transpose([2, 0, 1])
+        src = paddle.squeeze(src, 2)
+        src = src.transpose([2, 0, 1])
+
+        src = src + self.dropout2(src2)
+        src = self.norm2(src)
+        return src
+
+
+class TransformerDecoderLayer(nn.Layer):
+    r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
+    This standard decoder layer is based on the paper "Attention Is All You Need".
+    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
+    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
+    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
+    in a different way during application.
+
+    Args:
+        d_model: the number of expected features in the input (required).
+        nhead: the number of heads in the multiheadattention models (required).
+        dim_feedforward: the dimension of the feedforward network model (default=2048).
+        dropout: the dropout value (default=0.1).
+
+    Examples::
+        >>> decoder_layer = nn.TransformerDecoderLayer(d_model, nhead)
+    """
+
+    def __init__(self, d_model, nhead, dim_feedforward=2048, attention_dropout_rate=0.0, residual_dropout_rate=0.1):
+        super(TransformerDecoderLayer, self).__init__()
+        self.self_attn = MultiheadAttentionOptim(d_model, nhead, dropout=attention_dropout_rate)
+        self.multihead_attn = MultiheadAttentionOptim(d_model, nhead, dropout=attention_dropout_rate)
+
+        self.conv1 = Conv2D(in_channels=d_model, out_channels=dim_feedforward, kernel_size=(1, 1))
+        self.conv2 = Conv2D(in_channels=dim_feedforward, out_channels=d_model, kernel_size=(1, 1))
+
+        self.norm1 = LayerNorm(d_model)
+        self.norm2 = LayerNorm(d_model)
+        self.norm3 = LayerNorm(d_model)
+        self.dropout1 = Dropout(residual_dropout_rate)
+        self.dropout2 = Dropout(residual_dropout_rate)
+        self.dropout3 = Dropout(residual_dropout_rate)
+
+    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
+                tgt_key_padding_mask=None, memory_key_padding_mask=None):
+        r"""Pass the inputs (and mask) through the decoder layer.
+
+        Args:
+            tgt: the sequence to the decoder layer (required).
+            memory: the sequence from the last layer of the encoder (required).
+            tgt_mask: the mask for the tgt sequence (optional).
+            memory_mask: the mask for the memory sequence (optional).
+            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
+            memory_key_padding_mask: the mask for the memory keys per batch (optional).
+
+        Shape:
+            see the docs in Transformer class.
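+
+        Note (illustrative): tgt_mask is normally the causal mask returned by
+        TransformerOptim.generate_square_subsequent_mask, so position i in tgt
+        only attends to positions <= i, while memory_mask is None for this
+        model. A usage sketch, assuming tgt is [T, N, 512], memory is
+        [S, N, 512], and transformer is a TransformerOptim instance::
+
+            >>> layer = TransformerDecoderLayer(512, 8)
+            >>> mask = transformer.generate_square_subsequent_mask(tgt.shape[0])
+            >>> out = layer(tgt, memory, tgt_mask=mask)  # [T, N, 512]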
+ """ + tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # default + tgt = tgt.transpose([1, 2, 0]) + tgt = paddle.unsqueeze(tgt, 2) + tgt2 = self.conv2(F.relu(self.conv1(tgt))) + tgt2 = paddle.squeeze(tgt2, 2) + tgt2 = tgt2.transpose([2, 0, 1]) + tgt = paddle.squeeze(tgt, 2) + tgt = tgt.transpose([2, 0, 1]) + + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + +def _get_clones(module, N): + return LayerList([copy.deepcopy(module) for i in range(N)]) + + + +class PositionalEncoding(nn.Layer): + r"""Inject some information about the relative or absolute position of the tokens + in the sequence. The positional encodings have the same dimension as + the embeddings, so that the two can be summed. Here, we use sine and cosine + functions of different frequencies. + .. math:: + \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) + \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) + \text{where pos is the word position and i is the embed idx) + Args: + d_model: the embed dim (required). + dropout: the dropout value (default=0.1). + max_len: the max. length of the incoming sequence (default=5000). + Examples: + >>> pos_encoder = PositionalEncoding(d_model) + """ + + def __init__(self, dropout, dim, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = paddle.zeros([max_len, dim]) + position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1) + div_term = paddle.exp(paddle.arange(0, dim, 2).astype('float32') * (-math.log(10000.0) / dim)) + pe[:, 0::2] = paddle.sin(position * div_term) + pe[:, 1::2] = paddle.cos(position * div_term) + pe = pe.unsqueeze(0) + pe = pe.transpose([1, 0, 2]) + self.register_buffer('pe', pe) + + def forward(self, x): + r"""Inputs of forward function + Args: + x: the sequence fed to the positional encoder model (required). + Shape: + x: [sequence length, batch size, embed dim] + output: [sequence length, batch size, embed dim] + Examples: + >>> output = pos_encoder(x) + """ + x = x + self.pe[:x.shape[0], :] + return self.dropout(x) + + +class PositionalEncoding_2d(nn.Layer): + r"""Inject some information about the relative or absolute position of the tokens + in the sequence. The positional encodings have the same dimension as + the embeddings, so that the two can be summed. Here, we use sine and cosine + functions of different frequencies. + .. math:: + \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) + \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) + \text{where pos is the word position and i is the embed idx) + Args: + d_model: the embed dim (required). + dropout: the dropout value (default=0.1). + max_len: the max. length of the incoming sequence (default=5000). 
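+        Note (illustrative): unlike the 1-D module above, this variant builds
+        separate width and height encodings, each scaled by a learned linear
+        gate computed from the globally average-pooled feature map.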
+ Examples: + >>> pos_encoder = PositionalEncoding(d_model) + """ + + def __init__(self, dropout, dim, max_len=5000): + super(PositionalEncoding_2d, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = paddle.zeros([max_len, dim]) + position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1) + div_term = paddle.exp(paddle.arange(0, dim, 2).astype('float32') * (-math.log(10000.0) / dim)) + pe[:, 0::2] = paddle.sin(position * div_term) + pe[:, 1::2] = paddle.cos(position * div_term) + pe = pe.unsqueeze(0).transpose([1, 0, 2]) + self.register_buffer('pe', pe) + + self.avg_pool_1 = nn.AdaptiveAvgPool2D((1, 1)) + self.linear1 = nn.Linear(dim, dim) + self.linear1.weight.data.fill_(1.) + self.avg_pool_2 = nn.AdaptiveAvgPool2D((1, 1)) + self.linear2 = nn.Linear(dim, dim) + self.linear2.weight.data.fill_(1.) + + def forward(self, x): + r"""Inputs of forward function + Args: + x: the sequence fed to the positional encoder model (required). + Shape: + x: [sequence length, batch size, embed dim] + output: [sequence length, batch size, embed dim] + Examples: + >>> output = pos_encoder(x) + """ + w_pe = self.pe[:x.shape[-1], :] + w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0) + w_pe = w_pe * w1 + w_pe = w_pe.transpose([1, 2, 0]) + w_pe = w_pe.unsqueeze(2) + + h_pe = self.pe[:x.shape[-2], :] + w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0) + h_pe = h_pe * w2 + h_pe = h_pe.transpose([1, 2, 0]) + h_pe = h_pe.unsqueeze(3) + + x = x + w_pe + h_pe + x = x.reshape([x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]).transpose([2,0,1]) + + return self.dropout(x) + + +class Embeddings(nn.Layer): + def __init__(self, d_model, vocab, padding_idx, scale_embedding): + super(Embeddings, self).__init__() + self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx) + w0 = np.random.normal(0.0, d_model**-0.5,(vocab, d_model)).astype(np.float32) + self.embedding.weight.set_value(w0) + self.d_model = d_model + self.scale_embedding = scale_embedding + + def forward(self, x): + if self.scale_embedding: + x = self.embedding(x) + return x * math.sqrt(self.d_model) + return self.embedding(x) + + + + + +class Beam(): + ''' Beam search ''' + + def __init__(self, size, device=False): + + self.size = size + self._done = False + # The score for each translation on the beam. + self.scores = paddle.zeros((size,), dtype=paddle.float32) + self.all_scores = [] + # The backpointers at each time-step. + self.prev_ks = [] + # The outputs at each time-step. + self.next_ys = [paddle.full((size,), 0, dtype=paddle.int64)] + self.next_ys[0][0] = 2 + + def get_current_state(self): + "Get the outputs for the current timestep." + return self.get_tentative_hypothesis() + + def get_current_origin(self): + "Get the backpointers for the current timestep." + return self.prev_ks[-1] + + @property + def done(self): + return self._done + + def advance(self, word_prob): + "Update beam status and check if finished or not." + num_words = word_prob.shape[1] + + # Sum the previous scores. 
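+        # Note (illustrative): self.scores holds one cumulative log-prob per
+        # beam; broadcasting it over the [beam, vocab] step log-probs yields
+        # the total score of every candidate extension before the flattened
+        # top-k selection below.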
+        if len(self.prev_ks) > 0:
+            beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob)
+        else:
+            beam_lk = word_prob[0]
+
+        flat_beam_lk = beam_lk.reshape([-1])
+        best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True)  # 1st sort
+        self.all_scores.append(self.scores)
+        self.scores = best_scores
+
+        # bestScoresId is flattened as a (beam x word) array,
+        # so we need to calculate which word and beam each score came from
+        prev_k = best_scores_id // num_words
+        self.prev_ks.append(prev_k)
+
+        self.next_ys.append(best_scores_id - prev_k * num_words)
+
+        # End condition is when top-of-beam is EOS.
+        if self.next_ys[-1][0] == 3:
+            self._done = True
+            self.all_scores.append(self.scores)
+
+        return self._done
+
+    def sort_scores(self):
+        "Sort the scores."
+        return self.scores, paddle.to_tensor(
+            [i for i in range(self.scores.shape[0])], dtype='int32')
+
+    def get_the_best_score_and_idx(self):
+        "Get the score of the best in the beam."
+        scores, ids = self.sort_scores()
+        return scores[1], ids[1]
+
+    def get_tentative_hypothesis(self):
+        "Get the decoded sequence for the current timestep."
+
+        if len(self.next_ys) == 1:
+            dec_seq = self.next_ys[0].unsqueeze(1)
+        else:
+            _, keys = self.sort_scores()
+            hyps = [self.get_hypothesis(k) for k in keys]
+            hyps = [[2] + h for h in hyps]
+            dec_seq = paddle.to_tensor(hyps, dtype='int64')
+
+        return dec_seq
+
+    def get_hypothesis(self, k):
+        """ Walk back to construct the full hypothesis. """
+        hyp = []
+        for j in range(len(self.prev_ks) - 1, -1, -1):
+            hyp.append(self.next_ys[j + 1][k])
+            k = self.prev_ks[j][k]
+        return list(map(lambda x: x.item(), hyp[::-1]))
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index ae5470a5..7350b4ec 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -156,6 +156,69 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
         return output
 
 
+class NRTRLabelDecode(BaseRecLabelDecode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 character_dict_path=None,
+                 character_type='EN_symbol',
+                 use_space_char=True,
+                 **kwargs):
+        super(NRTRLabelDecode, self).__init__(character_dict_path,
+                                              character_type, use_space_char)
+
+    def __call__(self, preds, label=None, *args, **kwargs):
+        if preds.dtype == paddle.int64:
+            if isinstance(preds, paddle.Tensor):
+                preds = preds.numpy()
+            if preds[0][0] == 2:
+                preds_idx = preds[:, 1:]
+            else:
+                preds_idx = preds
+
+            text = self.decode(preds_idx)
+            if label is None:
+                return text
+            label = self.decode(label[:, 1:])
+        else:
+            if isinstance(preds, paddle.Tensor):
+                preds = preds.numpy()
+            preds_idx = preds.argmax(axis=2)
+            preds_prob = preds.max(axis=2)
+            text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+            if label is None:
+                return text
+            label = self.decode(label[:, 1:])
+        return text, label
+
+    def add_special_char(self, dict_character):
+        dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
+        return dict_character
+
+    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+        """ convert text-index into text-label.
""" + result_list = [] + batch_size = len(text_index) + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] == 3: # end + break + try: + char_list.append(self.character[int(text_index[batch_idx][idx])]) + except: + continue + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + text = ''.join(char_list) + result_list.append((text.lower(), np.mean(conf_list))) + return result_list + + + class AttnLabelDecode(BaseRecLabelDecode): """ Convert between text-label and text-index """ diff --git a/tools/eval.py b/tools/eval.py index 66eb315f..d26f2a04 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -22,7 +22,6 @@ import sys __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) - from ppocr.data import build_dataloader from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process @@ -31,7 +30,6 @@ from ppocr.utils.save_load import init_model from ppocr.utils.utility import print_dict import tools.program as program - def main(): global_config = config['Global'] # build dataloader From 190dff57e131b322a2f2342516821f28c8b40ec3 Mon Sep 17 00:00:00 2001 From: topduke <784990967@qq.com> Date: Tue, 17 Aug 2021 20:51:01 +0800 Subject: [PATCH 02/18] Update eval.py --- tools/eval.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tools/eval.py b/tools/eval.py index d26f2a04..66eb315f 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -22,6 +22,7 @@ import sys __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) + from ppocr.data import build_dataloader from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process @@ -30,6 +31,7 @@ from ppocr.utils.save_load import init_model from ppocr.utils.utility import print_dict import tools.program as program + def main(): global_config = config['Global'] # build dataloader From 1623c17cdccd6108ebe68ee7cef2ffbae1a1cbf3 Mon Sep 17 00:00:00 2001 From: Topdu <784990967@qq.com> Date: Wed, 11 Aug 2021 09:50:51 +0000 Subject: [PATCH 03/18] add rec_nrtr --- configs/rec/rec_mtb_nrtr.yml | 16 + doc/doc_ch/algorithm_overview.md | 2 + doc/doc_ch/recognition.md | 1 + doc/doc_en/algorithm_overview_en.md | 1 + doc/doc_en/recognition_en.md | 2 +- ppocr/data/imaug/__init__.py | 2 +- ppocr/data/imaug/label_ops.py | 2 +- ppocr/data/imaug/operators.py | 32 ++ ppocr/data/imaug/rec_img_aug.py | 30 +- ppocr/losses/__init__.py | 6 +- ppocr/losses/rec_nrtr_loss.py | 38 ++ ppocr/metrics/rec_metric.py | 7 +- ppocr/modeling/architectures/base_model.py | 2 +- ppocr/modeling/backbones/__init__.py | 5 +- .../modeling/backbones/multiheadAttention.py | 365 ++++++++++++++++++ ppocr/modeling/backbones/rec_nrtr_mtb.py | 28 ++ ppocr/modeling/heads/rec_nrtr_optim_head.py | 4 + ppocr/modeling/necks/__init__.py | 2 +- ppocr/postprocess/__init__.py | 5 +- ppocr/postprocess/rec_postprocess.py | 5 +- ppocr/utils/dict_99.txt | 95 +++++ tools/eval.py | 2 + tools/program.py | 9 +- 23 files changed, 638 insertions(+), 23 deletions(-) create mode 100644 ppocr/losses/rec_nrtr_loss.py create mode 100755 ppocr/modeling/backbones/multiheadAttention.py create mode 100644 ppocr/modeling/backbones/rec_nrtr_mtb.py create mode 100644 ppocr/utils/dict_99.txt diff --git a/configs/rec/rec_mtb_nrtr.yml 
b/configs/rec/rec_mtb_nrtr.yml index d5d36cfa..86a833c5 100644 --- a/configs/rec/rec_mtb_nrtr.yml +++ b/configs/rec/rec_mtb_nrtr.yml @@ -3,22 +3,38 @@ Global: epoch_num: 21 log_smooth_window: 20 print_batch_step: 10 +<<<<<<< HEAD save_model_dir: ./output/rec/nrtr_final/ save_epoch_step: 1 # evaluation is run every 2000 iterations eval_batch_step: [0, 2000] cal_metric_during_train: True +======= + save_model_dir: ./output/rec/piloptimnrtr/ + save_epoch_step: 1 + # evaluation is run every 2000 iterations + eval_batch_step: [0, 2000] + cal_metric_during_train: False +>>>>>>> 9c67a7f... add rec_nrtr pretrained_model: checkpoints: save_inference_dir: use_visualdl: False infer_img: doc/imgs_words_en/word_10.png # for data or label process +<<<<<<< HEAD character_dict_path: character_type: EN_symbol max_text_length: 25 infer_mode: False use_space_char: True +======= + character_dict_path: ppocr/utils/dict_99.txt + character_type: dict_99 + max_text_length: 25 + infer_mode: False + use_space_char: False +>>>>>>> 9c67a7f... add rec_nrtr save_res_path: ./output/rec/predicts_nrtr.txt Optimizer: diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index 19d7a69c..9c352549 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -44,6 +44,7 @@ PaddleOCR基于动态图开源的文本识别算法列表: - [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11] - [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] - [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5] +- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2) 参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下: @@ -58,6 +59,7 @@ PaddleOCR基于动态图开源的文本识别算法列表: |RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)| |RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)| |SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) | +|NRTR|NRTR_MTB| 84.1% | rec_mtb_nrtr | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) | PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。 diff --git a/doc/doc_ch/recognition.md b/doc/doc_ch/recognition.md index 0f860065..6ce3003c 100644 --- a/doc/doc_ch/recognition.md +++ b/doc/doc_ch/recognition.md @@ -215,6 +215,7 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t | rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att | | rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att | | rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn | +| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder | 训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件: diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md index d70f99bb..fed9cf44 100755 --- a/doc/doc_en/algorithm_overview_en.md +++ b/doc/doc_en/algorithm_overview_en.md @@ -60,5 +60,6 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r |RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)| 
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)| |SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)| +|NRTR|NRTR_MTB| 84.1% | rec_mtb_nrtr | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) | Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md) diff --git a/doc/doc_en/recognition_en.md b/doc/doc_en/recognition_en.md index e23166e0..7a5e827d 100644 --- a/doc/doc_en/recognition_en.md +++ b/doc/doc_en/recognition_en.md @@ -207,7 +207,7 @@ If the evaluation set is large, the test will be time-consuming. It is recommend | rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att | | rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att | | rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn | - +| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder | For training Chinese data, it is recommended to use [rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file: diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index a808fd58..9f175382 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap from .make_shrink_map import MakeShrinkMap from .random_crop_data import EastRandomCropData, PSERandomCrop -from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg +from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, PILResize, CVResize from .randaugment import RandAugment from .operators import * from .label_ops import * diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 39ff8930..a233738c 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -96,7 +96,7 @@ class BaseRecLabelEncode(object): 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean', 'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', - 'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari' + 'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari','dict_99' ] assert character_type in support_character_type, "Only {} are supported now but get {}".format( support_character_type, character_type) diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py index 9c48b096..950c9377 100644 --- a/ppocr/data/imaug/operators.py +++ b/ppocr/data/imaug/operators.py @@ -57,6 +57,38 @@ class DecodeImage(object): return data +class NRTRDecodeImage(object): + """ decode image """ + + def __init__(self, img_mode='RGB', channel_first=False, **kwargs): + self.img_mode = img_mode + self.channel_first = channel_first + + def __call__(self, data): + img = data['image'] + if six.PY2: + assert type(img) is str and len( + img) > 0, "invalid input 'img' in DecodeImage" + else: + assert type(img) is bytes and len( + img) > 0, "invalid input 'img' in DecodeImage" + img = np.frombuffer(img, dtype='uint8') + + img = cv2.imdecode(img, 1) + + if 
img is None: + return None + if self.img_mode == 'GRAY': + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + elif self.img_mode == 'RGB': + assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape) + img = img[:, :, ::-1] + img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY) + if self.channel_first: + img = img.transpose((2, 0, 1)) + data['image'] = img + return data + class NormalizeImage(object): """ normalize image such as substract mean, divide std """ diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 28e6bd0b..13a5c71d 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -16,7 +16,7 @@ import math import cv2 import numpy as np import random - +from PIL import Image from .text_image_aug import tia_perspective, tia_stretch, tia_distort @@ -42,6 +42,34 @@ class ClsResizeImg(object): data['image'] = norm_img return data +class PILResize(object): + def __init__(self, image_shape, **kwargs): + self.image_shape = image_shape + + def __call__(self, data): + img = data['image'] + image_pil = Image.fromarray(np.uint8(img)) + norm_img = image_pil.resize(self.image_shape, Image.ANTIALIAS) + norm_img = np.array(norm_img) + norm_img = np.expand_dims(norm_img, -1) + norm_img = norm_img.transpose((2, 0, 1)) + data['image'] = norm_img.astype(np.float32) / 128. - 1. + return data + + +class CVResize(object): + def __init__(self, image_shape, **kwargs): + self.image_shape = image_shape + + def __call__(self, data): + img = data['image'] + #print(img) + norm_img = cv2.resize(img,self.image_shape) + norm_img = np.expand_dims(norm_img, -1) + norm_img = norm_img.transpose((2, 0, 1)) + data['image'] = norm_img.astype(np.float32) / 128. - 1. + return data + class RecResizeImg(object): def __init__(self, diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index bf10d298..e1c3ed95 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -25,7 +25,7 @@ from .det_sast_loss import SASTLoss from .rec_ctc_loss import CTCLoss from .rec_att_loss import AttentionLoss from .rec_srn_loss import SRNLoss - +from .rec_nrtr_loss import NRTRLoss # cls loss from .cls_loss import ClsLoss @@ -42,8 +42,8 @@ from .combined_loss import CombinedLoss def build_loss(config): support_dict = [ 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss', - 'SRNLoss', 'PGLoss', 'CombinedLoss' - ] + 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss'] + config = copy.deepcopy(config) module_name = config.pop('name') assert module_name in support_dict, Exception('loss only support {}'.format( diff --git a/ppocr/losses/rec_nrtr_loss.py b/ppocr/losses/rec_nrtr_loss.py new file mode 100644 index 00000000..915f506d --- /dev/null +++ b/ppocr/losses/rec_nrtr_loss.py @@ -0,0 +1,38 @@ +import paddle +from paddle import nn +import paddle.nn.functional as F + + +def cal_performance(pred, tgt): + + pred = pred.max(1)[1] + tgt = tgt.contiguous().view(-1) + non_pad_mask = tgt.ne(0) + n_correct = pred.eq(tgt) + n_correct = n_correct.masked_select(non_pad_mask).sum().item() + return n_correct + + +class NRTRLoss(nn.Layer): + def __init__(self,smoothing=True, **kwargs): + super(NRTRLoss, self).__init__() + self.loss_func = nn.CrossEntropyLoss(reduction='mean',ignore_index=0) + self.smoothing = smoothing + + def forward(self, pred, batch): + pred = pred.reshape([-1, pred.shape[2]]) + max_len = batch[2].max() + tgt = batch[1][:,1:2+max_len] + tgt = tgt.reshape([-1] ) + if self.smoothing: + eps = 0.1 + n_class = pred.shape[1] + one_hot = F.one_hot(tgt, 
pred.shape[1]) + one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) + log_prb = F.log_softmax(pred, axis=1) + non_pad_mask = paddle.not_equal(tgt, paddle.zeros(tgt.shape,dtype='int64')) + loss = -(one_hot * log_prb).sum(axis=1) + loss = loss.masked_select(non_pad_mask).mean() + else: + loss = self.loss_func(pred, tgt) + return {'loss': loss} diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py index 66c084d7..3712e6e9 100644 --- a/ppocr/metrics/rec_metric.py +++ b/ppocr/metrics/rec_metric.py @@ -30,7 +30,7 @@ class RecMetric(object): target = target.replace(" ", "") norm_edit_dis += Levenshtein.distance(pred, target) / max( len(pred), len(target), 1) - if pred == target: + if pred.lower() == target.lower(): correct_num += 1 all_num += 1 self.correct_num += correct_num @@ -48,8 +48,8 @@ class RecMetric(object): 'norm_edit_dis': 0, } """ - acc = 1.0 * self.correct_num / self.all_num - norm_edit_dis = 1 - self.norm_edit_dis / self.all_num + acc = 1.0 * self.correct_num / (self.all_num) + norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num) self.reset() return {'acc': acc, 'norm_edit_dis': norm_edit_dis} @@ -57,3 +57,4 @@ class RecMetric(object): self.correct_num = 0 self.all_num = 0 self.norm_edit_dis = 0 + \ No newline at end of file diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py index 4c941fcf..66da4b33 100644 --- a/ppocr/modeling/architectures/base_model.py +++ b/ppocr/modeling/architectures/base_model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - +import paddle from paddle import nn from ppocr.modeling.transforms import build_transform from ppocr.modeling.backbones import build_backbone diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index fe2c9bc3..73afbe11 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -25,7 +25,10 @@ def build_backbone(config, model_type): from .rec_mobilenet_v3 import MobileNetV3 from .rec_resnet_vd import ResNet from .rec_resnet_fpn import ResNetFPN - support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN'] + from .rec_nrtr_mtb import MTB + from .rec_swin import SwinTransformer + support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN','MTB','SwinTransformer'] + elif model_type == 'e2e': from .e2e_resnet_vd_pg import ResNet support_dict = ['ResNet'] diff --git a/ppocr/modeling/backbones/multiheadAttention.py b/ppocr/modeling/backbones/multiheadAttention.py new file mode 100755 index 00000000..f18e9957 --- /dev/null +++ b/ppocr/modeling/backbones/multiheadAttention.py @@ -0,0 +1,365 @@ +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle.nn import Linear +from paddle.nn.initializer import XavierUniform as xavier_uniform_ +from paddle.nn.initializer import Constant as constant_ +from paddle.nn.initializer import XavierNormal as xavier_normal_ + +zeros_ = constant_(value=0.) +ones_ = constant_(value=1.) + +class MultiheadAttention(nn.Layer): + r"""Allows the model to jointly attend to information + from different representation subspaces. + See reference: Attention Is All You Need + + .. 
math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) + + Args: + embed_dim: total dimension of the model + num_heads: parallel attention layers, or heads + + Examples:: + + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + """ + + def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False): + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias) + + if add_bias_kv: + self.bias_k = self.create_parameter( + shape=(1, 1, embed_dim), default_initializer=zeros_) + self.add_parameter("bias_k", self.bias_k) + self.bias_v = self.create_parameter( + shape=(1, 1, embed_dim), default_initializer=zeros_) + self.add_parameter("bias_v", self.bias_v) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self._reset_parameters() + + self.conv1 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) + self.conv2 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim * 2, kernel_size=(1, 1)) + self.conv3 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim * 3, kernel_size=(1, 1)) + + def _reset_parameters(self): + + + xavier_uniform_(self.out_proj.weight) + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + def forward(self, query, key, value, key_padding_mask=None, incremental_state=None, + need_weights=True, static_kv=False, attn_mask=None, qkv_ = [False,False,False]): + """ + Inputs of forward function + query: [target length, batch size, embed dim] + key: [sequence length, batch size, embed dim] + value: [sequence length, batch size, embed dim] + key_padding_mask: if True, mask padding based on batch size + incremental_state: if provided, previous time steps are cashed + need_weights: output attn_output_weights + static_kv: key and value are static + + Outputs of forward function + attn_output: [target length, batch size, embed dim] + attn_output_weights: [batch size, target length, sequence length] + """ + qkv_same = qkv_[0] + kv_same = qkv_[1] + + tgt_len, bsz, embed_dim = query.shape + assert embed_dim == self.embed_dim + assert list(query.shape) == [tgt_len, bsz, embed_dim] + assert key.shape == value.shape + + if qkv_same: + # self-attention + q, k, v = self._in_proj_qkv(query) + elif kv_same: + # encoder-decoder attention + q = self._in_proj_q(query) + if key is None: + assert value is None + k = v = None + else: + k, v = self._in_proj_kv(key) + else: + q = self._in_proj_q(query) + k = self._in_proj_k(key) + v = self._in_proj_v(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + self.bias_k = paddle.concat([self.bias_k for i in range(bsz)],axis=1) + self.bias_v = paddle.concat([self.bias_v for i in range(bsz)],axis=1) + k = paddle.concat([k, self.bias_k]) + v = paddle.concat([v, self.bias_v]) + if attn_mask is not None: + attn_mask = paddle.concat([attn_mask, paddle.zeros([attn_mask.shape[0], 1],dtype=attn_mask.dtype)], axis=1) + if key_padding_mask is not None: + 
key_padding_mask = paddle.concat( + [key_padding_mask,paddle.zeros([key_padding_mask.shape[0], 1],dtype=key_padding_mask.dtype)], axis=1) + + q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) + if k is not None: + k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) + if v is not None: + v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) + + + + src_len = k.shape[1] + + if key_padding_mask is not None: + assert key_padding_mask.shape[0] == bsz + assert key_padding_mask.shape[1] == src_len + + if self.add_zero_attn: + src_len += 1 + k = paddle.concat([k, paddle.zeros((k.shape[0], 1) + k.shape[2:],dtype=k.dtype)], axis=1) + v = paddle.concat([v, paddle.zeros((v.shape[0], 1) + v.shape[2:],dtype=v.dtype)], axis=1) + if attn_mask is not None: + attn_mask = paddle.concat([attn_mask, paddle.zeros([attn_mask.shape[0], 1],dtype=attn_mask.dtype)], axis=1) + if key_padding_mask is not None: + key_padding_mask = paddle.concat( + [key_padding_mask, paddle.zeros([key_padding_mask.shape[0], 1],dtype=key_padding_mask.dtype)], axis=1) + attn_output_weights = paddle.bmm(q, k.transpose([0,2,1])) + assert list(attn_output_weights.shape) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_output_weights += attn_mask + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len]) + key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32') + y = paddle.full(shape=key.shape, dtype='float32', fill_value='-inf') + y = paddle.where(key==0.,key, y) + attn_output_weights += y + attn_output_weights = attn_output_weights.reshape([bsz*self.num_heads, tgt_len, src_len]) + + attn_output_weights = F.softmax( + attn_output_weights.astype('float32'), axis=-1, + dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16 else attn_output_weights.dtype) + attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training) + + attn_output = paddle.bmm(attn_output_weights, v) + assert list(attn_output.shape) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn_output = attn_output.transpose([1, 0,2]).reshape([tgt_len, bsz, embed_dim]) + attn_output = self.out_proj(attn_output) + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len]) + attn_output_weights = attn_output_weights.sum(axis=1) / self.num_heads + else: + attn_output_weights = None + + return attn_output, attn_output_weights + + def _in_proj_qkv(self, query): + query = query.transpose([1, 2, 0]) + query = paddle.unsqueeze(query, axis=2) + res = self.conv3(query) + res = paddle.squeeze(res, axis=2) + res = res.transpose([2, 0, 1]) + return res.chunk(3, axis=-1) + + def _in_proj_kv(self, key): + key = key.transpose([1, 2, 0]) + key = paddle.unsqueeze(key, axis=2) + res = self.conv2(key) + res = paddle.squeeze(res, axis=2) + res = res.transpose([2, 0, 1]) + return res.chunk(2, axis=-1) + + def _in_proj_q(self, query): + query = query.transpose([1, 2, 0]) + query = paddle.unsqueeze(query, axis=2) + res = self.conv1(query) + res = paddle.squeeze(res, axis=2) + res = res.transpose([2, 0, 1]) + return res + + def _in_proj_k(self, key): + + key = key.transpose([1, 2, 0]) + key = paddle.unsqueeze(key, axis=2) + res = self.conv1(key) + res = paddle.squeeze(res, axis=2) + res = res.transpose([2, 0, 1]) + return res + + def 
_in_proj_v(self, value): + + value = value.transpose([1,2,0])#(1, 2, 0) + value = paddle.unsqueeze(value, axis=2) + res = self.conv1(value) + res = paddle.squeeze(res, axis=2) + res = res.transpose([2, 0, 1]) + return res + + + +class MultiheadAttentionOptim(nn.Layer): + r"""Allows the model to jointly attend to information + from different representation subspaces. + See reference: Attention Is All You Need + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) + + Args: + embed_dim: total dimension of the model + num_heads: parallel attention layers, or heads + + Examples:: + + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + """ + + def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False): + super(MultiheadAttentionOptim, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias) + + self._reset_parameters() + + self.conv1 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) + self.conv2 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) + self.conv3 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) + + def _reset_parameters(self): + + + xavier_uniform_(self.out_proj.weight) + + + def forward(self, query, key, value, key_padding_mask=None, incremental_state=None, + need_weights=True, static_kv=False, attn_mask=None): + """ + Inputs of forward function + query: [target length, batch size, embed dim] + key: [sequence length, batch size, embed dim] + value: [sequence length, batch size, embed dim] + key_padding_mask: if True, mask padding based on batch size + incremental_state: if provided, previous time steps are cashed + need_weights: output attn_output_weights + static_kv: key and value are static + + Outputs of forward function + attn_output: [target length, batch size, embed dim] + attn_output_weights: [batch size, target length, sequence length] + """ + + + tgt_len, bsz, embed_dim = query.shape + assert embed_dim == self.embed_dim + assert list(query.shape) == [tgt_len, bsz, embed_dim] + assert key.shape == value.shape + + q = self._in_proj_q(query) + k = self._in_proj_k(key) + v = self._in_proj_v(value) + q *= self.scaling + + + q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) + k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) + v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) + + + src_len = k.shape[1] + + if key_padding_mask is not None: + assert key_padding_mask.shape[0] == bsz + assert key_padding_mask.shape[1] == src_len + + + attn_output_weights = paddle.bmm(q, k.transpose([0,2,1])) + assert list(attn_output_weights.shape) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_output_weights += attn_mask + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len]) + key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32') + + y = 
paddle.full(shape=key.shape, dtype='float32', fill_value='-inf') + + y = paddle.where(key==0.,key, y) + + attn_output_weights += y + attn_output_weights = attn_output_weights.reshape([bsz*self.num_heads, tgt_len, src_len]) + + attn_output_weights = F.softmax( + attn_output_weights.astype('float32'), axis=-1, + dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16 else attn_output_weights.dtype) + attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training) + + attn_output = paddle.bmm(attn_output_weights, v) + assert list(attn_output.shape) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn_output = attn_output.transpose([1, 0,2]).reshape([tgt_len, bsz, embed_dim]) + attn_output = self.out_proj(attn_output) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len]) + attn_output_weights = attn_output_weights.sum(axis=1) / self.num_heads + else: + attn_output_weights = None + + return attn_output, attn_output_weights + + + def _in_proj_q(self, query): + query = query.transpose([1, 2, 0]) + query = paddle.unsqueeze(query, axis=2) + res = self.conv1(query) + res = paddle.squeeze(res, axis=2) + res = res.transpose([2, 0, 1]) + return res + + def _in_proj_k(self, key): + + key = key.transpose([1, 2, 0]) + key = paddle.unsqueeze(key, axis=2) + res = self.conv2(key) + res = paddle.squeeze(res, axis=2) + res = res.transpose([2, 0, 1]) + return res + + def _in_proj_v(self, value): + + value = value.transpose([1,2,0])#(1, 2, 0) + value = paddle.unsqueeze(value, axis=2) + res = self.conv3(value) + res = paddle.squeeze(res, axis=2) + res = res.transpose([2, 0, 1]) + return res \ No newline at end of file diff --git a/ppocr/modeling/backbones/rec_nrtr_mtb.py b/ppocr/modeling/backbones/rec_nrtr_mtb.py new file mode 100644 index 00000000..26a0dc7f --- /dev/null +++ b/ppocr/modeling/backbones/rec_nrtr_mtb.py @@ -0,0 +1,28 @@ +from paddle import nn + +class MTB(nn.Layer): + def __init__(self, cnn_num, in_channels): + super(MTB, self).__init__() + self.block = nn.Sequential() + self.out_channels = in_channels + self.cnn_num = cnn_num + if self.cnn_num == 2: + for i in range(self.cnn_num): + self.block.add_sublayer('conv_{}'.format(i), nn.Conv2D( + in_channels = in_channels if i == 0 else 32*(2**(i-1)), + out_channels = 32*(2**i), + kernel_size = 3, + stride = 2, + padding=1)) + self.block.add_sublayer('relu_{}'.format(i), nn.ReLU()) + self.block.add_sublayer('bn_{}'.format(i), nn.BatchNorm2D(32*(2**i))) + + def forward(self, images): + + x = self.block(images) + if self.cnn_num == 2: + # (b, w, h, c) + x = x.transpose([0, 3, 2, 1]) + x_shape = x.shape + x = x.reshape([x_shape[0], x_shape[1], x_shape[2] * x_shape[3]]) + return x diff --git a/ppocr/modeling/heads/rec_nrtr_optim_head.py b/ppocr/modeling/heads/rec_nrtr_optim_head.py index b9a5100a..1537b0ca 100644 --- a/ppocr/modeling/heads/rec_nrtr_optim_head.py +++ b/ppocr/modeling/heads/rec_nrtr_optim_head.py @@ -7,7 +7,11 @@ from paddle.nn import LayerList from paddle.nn.initializer import XavierNormal as xavier_uniform_ from paddle.nn import Dropout, Linear, LayerNorm, Conv2D import numpy as np +<<<<<<< HEAD from ppocr.modeling.heads.multiheadAttention import MultiheadAttentionOptim +======= +from ppocr.modeling.backbones.multiheadAttention import MultiheadAttentionOptim +>>>>>>> 9c67a7f... 
add rec_nrtr from paddle.nn.initializer import Constant as constant_ from paddle.nn.initializer import XavierNormal as xavier_normal_ diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py index 37a5cf78..1be38e93 100644 --- a/ppocr/modeling/necks/__init__.py +++ b/ppocr/modeling/necks/__init__.py @@ -21,7 +21,7 @@ def build_neck(config): from .sast_fpn import SASTFPN from .rnn import SequenceEncoder from .pg_fpn import PGFPN - support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN'] + support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN','TFEncoder'] module_name = config.pop('name') assert module_name in support_dict, Exception('neck only support {}'.format( diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index cd2b7ea7..f7f1bcd6 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -24,16 +24,15 @@ __all__ = ['build_post_process'] from .db_postprocess import DBPostProcess from .east_postprocess import EASTPostProcess from .sast_postprocess import SASTPostProcess -from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode +from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, NRTRLabelDecode from .cls_postprocess import ClsPostProcess from .pg_postprocess import PGPostProcess - def build_post_process(config, global_config=None): support_dict = [ 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess', - 'DistillationCTCLabelDecode' + 'DistillationCTCLabelDecode', 'NRTRLabelDecode' ] config = copy.deepcopy(config) diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 7350b4ec..e0f3b740 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -28,7 +28,7 @@ class BaseRecLabelDecode(object): 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr', - 'ne', 'EN', 'latin', 'arabic', 'cyrillic', 'devanagari' + 'ne', 'EN', 'latin', 'arabic', 'cyrillic', 'devanagari','dict_99' ] assert character_type in support_character_type, "Only {} are supported now but get {}".format( support_character_type, character_type) @@ -256,8 +256,7 @@ class AttnLabelDecode(BaseRecLabelDecode): if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ batch_idx][idx]: continue - char_list.append(self.character[int(text_index[batch_idx][ - idx])]) + char_list.append(self.character[int(text_index[batch_idx][idx])]) if text_prob is not None: conf_list.append(text_prob[batch_idx][idx]) else: diff --git a/ppocr/utils/dict_99.txt b/ppocr/utils/dict_99.txt new file mode 100644 index 00000000..e00863bf --- /dev/null +++ b/ppocr/utils/dict_99.txt @@ -0,0 +1,95 @@ +! +" +# +$ +% +& +' +( +) +* ++ +, +- +. +/ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +: +; +< += +> +? 
+@ +A +B +C +D +E +F +G +H +I +J +K +L +M +N +O +P +Q +R +S +T +U +V +W +X +Y +Z +[ +\ +] +^ +_ +` +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +{ +| +} +~ + \ No newline at end of file diff --git a/tools/eval.py b/tools/eval.py index d26f2a04..66eb315f 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -22,6 +22,7 @@ import sys __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) + from ppocr.data import build_dataloader from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process @@ -30,6 +31,7 @@ from ppocr.utils.save_load import init_model from ppocr.utils.utility import print_dict import tools.program as program + def main(): global_config = config['Global'] # build dataloader diff --git a/tools/program.py b/tools/program.py index 7641bed7..4b6dc9e4 100755 --- a/tools/program.py +++ b/tools/program.py @@ -186,7 +186,7 @@ def train(config, model.train() use_srn = config['Architecture']['algorithm'] == "SRN" - + use_nrtr = config['Architecture']['algorithm'] == "NRTR" if 'start_epoch' in best_model_dict: start_epoch = best_model_dict['start_epoch'] else: @@ -211,6 +211,9 @@ def train(config, others = batch[-4:] preds = model(images, others) model_average = True + elif use_nrtr: + max_len = batch[2].max() + preds = model(images, batch[1][:,:2+max_len]) else: preds = model(images) loss = loss_class(preds, batch) @@ -350,13 +353,11 @@ def eval(model, valid_dataloader, post_process_class, eval_class, break images = batch[0] start = time.time() - if use_srn: others = batch[-4:] preds = model(images, others) else: preds = model(images) - batch = [item.numpy() for item in batch] # Obtain usable results from post-processing methods post_result = post_process_class(preds, batch[1]) @@ -386,7 +387,7 @@ def preprocess(is_train=False): alg = config['Architecture']['algorithm'] assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', - 'CLS', 'PGNet', 'Distillation' + 'CLS', 'PGNet', 'Distillation','NRTR' ] device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' From 4cb824537a4c2924fe133d2d750131389d32360d Mon Sep 17 00:00:00 2001 From: Topdu <784990967@qq.com> Date: Tue, 17 Aug 2021 13:37:32 +0000 Subject: [PATCH 04/18] add rec_nrtr --- configs/rec/rec_mtb_nrtr.yml | 18 +--- ppocr/data/imaug/label_ops.py | 2 +- ppocr/metrics/rec_metric.py | 4 +- ppocr/modeling/architectures/base_model.py | 1 - ppocr/modeling/backbones/__init__.py | 2 +- ppocr/modeling/heads/rec_nrtr_optim_head.py | 4 - ppocr/postprocess/rec_postprocess.py | 2 +- ppocr/utils/dict_99.txt | 95 --------------------- 8 files changed, 6 insertions(+), 122 deletions(-) delete mode 100644 ppocr/utils/dict_99.txt diff --git a/configs/rec/rec_mtb_nrtr.yml b/configs/rec/rec_mtb_nrtr.yml index 86a833c5..d16657d8 100644 --- a/configs/rec/rec_mtb_nrtr.yml +++ b/configs/rec/rec_mtb_nrtr.yml @@ -3,38 +3,22 @@ Global: epoch_num: 21 log_smooth_window: 20 print_batch_step: 10 -<<<<<<< HEAD - save_model_dir: ./output/rec/nrtr_final/ + save_model_dir: ./output/rec/nrtr/ save_epoch_step: 1 # evaluation is run every 2000 iterations eval_batch_step: [0, 2000] cal_metric_during_train: True -======= - save_model_dir: ./output/rec/piloptimnrtr/ - save_epoch_step: 1 - # evaluation is run every 2000 iterations - eval_batch_step: [0, 2000] - cal_metric_during_train: False ->>>>>>> 9c67a7f... 
add rec_nrtr pretrained_model: checkpoints: save_inference_dir: use_visualdl: False infer_img: doc/imgs_words_en/word_10.png # for data or label process -<<<<<<< HEAD character_dict_path: character_type: EN_symbol max_text_length: 25 infer_mode: False use_space_char: True -======= - character_dict_path: ppocr/utils/dict_99.txt - character_type: dict_99 - max_text_length: 25 - infer_mode: False - use_space_char: False ->>>>>>> 9c67a7f... add rec_nrtr save_res_path: ./output/rec/predicts_nrtr.txt Optimizer: diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index a233738c..39ff8930 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -96,7 +96,7 @@ class BaseRecLabelEncode(object): 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean', 'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', - 'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari','dict_99' + 'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari' ] assert character_type in support_character_type, "Only {} are supported now but get {}".format( support_character_type, character_type) diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py index 3712e6e9..e4b65a50 100644 --- a/ppocr/metrics/rec_metric.py +++ b/ppocr/metrics/rec_metric.py @@ -30,7 +30,7 @@ class RecMetric(object): target = target.replace(" ", "") norm_edit_dis += Levenshtein.distance(pred, target) / max( len(pred), len(target), 1) - if pred.lower() == target.lower(): + if pred == target: correct_num += 1 all_num += 1 self.correct_num += correct_num @@ -57,4 +57,4 @@ class RecMetric(object): self.correct_num = 0 self.all_num = 0 self.norm_edit_dis = 0 - \ No newline at end of file + diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py index 66da4b33..52ad1593 100644 --- a/ppocr/modeling/architectures/base_model.py +++ b/ppocr/modeling/architectures/base_model.py @@ -14,7 +14,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import paddle from paddle import nn from ppocr.modeling.transforms import build_transform from ppocr.modeling.backbones import build_backbone diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index 73afbe11..49c34864 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -27,7 +27,7 @@ def build_backbone(config, model_type): from .rec_resnet_fpn import ResNetFPN from .rec_nrtr_mtb import MTB from .rec_swin import SwinTransformer - support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN','MTB','SwinTransformer'] + support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', 'SwinTransformer'] elif model_type == 'e2e': from .e2e_resnet_vd_pg import ResNet diff --git a/ppocr/modeling/heads/rec_nrtr_optim_head.py b/ppocr/modeling/heads/rec_nrtr_optim_head.py index 1537b0ca..b9a5100a 100644 --- a/ppocr/modeling/heads/rec_nrtr_optim_head.py +++ b/ppocr/modeling/heads/rec_nrtr_optim_head.py @@ -7,11 +7,7 @@ from paddle.nn import LayerList from paddle.nn.initializer import XavierNormal as xavier_uniform_ from paddle.nn import Dropout, Linear, LayerNorm, Conv2D import numpy as np -<<<<<<< HEAD from ppocr.modeling.heads.multiheadAttention import MultiheadAttentionOptim -======= -from ppocr.modeling.backbones.multiheadAttention import MultiheadAttentionOptim ->>>>>>> 9c67a7f... 
add rec_nrtr from paddle.nn.initializer import Constant as constant_ from paddle.nn.initializer import XavierNormal as xavier_normal_ diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index e0f3b740..371e2386 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -28,7 +28,7 @@ class BaseRecLabelDecode(object): 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr', - 'ne', 'EN', 'latin', 'arabic', 'cyrillic', 'devanagari','dict_99' + 'ne', 'EN', 'latin', 'arabic', 'cyrillic', 'devanagari' ] assert character_type in support_character_type, "Only {} are supported now but get {}".format( support_character_type, character_type) diff --git a/ppocr/utils/dict_99.txt b/ppocr/utils/dict_99.txt deleted file mode 100644 index e00863bf..00000000 --- a/ppocr/utils/dict_99.txt +++ /dev/null @@ -1,95 +0,0 @@ -! -" -# -$ -% -& -' -( -) -* -+ -, -- -. -/ -0 -1 -2 -3 -4 -5 -6 -7 -8 -9 -: -; -< -= -> -? -@ -A -B -C -D -E -F -G -H -I -J -K -L -M -N -O -P -Q -R -S -T -U -V -W -X -Y -Z -[ -\ -] -^ -_ -` -a -b -c -d -e -f -g -h -i -j -k -l -m -n -o -p -q -r -s -t -u -v -w -x -y -z -{ -| -} -~ - \ No newline at end of file From 2e76e46b3cca3589b55470b86811e58fd7da974d Mon Sep 17 00:00:00 2001 From: Topdu <784990967@qq.com> Date: Tue, 17 Aug 2021 13:44:44 +0000 Subject: [PATCH 05/18] update nrtr metric --- doc/doc_ch/algorithm_overview.md | 4 ++-- doc/doc_en/algorithm_overview_en.md | 3 ++- ppocr/modeling/necks/__init__.py | 2 +- tools/program.py | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index 9c352549..e8f23b54 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -44,7 +44,7 @@ PaddleOCR基于动态图开源的文本识别算法列表: - [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11] - [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] - [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5] -- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2) +- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2)) 参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下: @@ -59,7 +59,7 @@ PaddleOCR基于动态图开源的文本识别算法列表: |RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)| |RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)| |SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) | -|NRTR|NRTR_MTB| 84.1% | rec_mtb_nrtr | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) | +|NRTR|NRTR_MTB| 84.3% | rec_mtb_nrtr | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) | PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。 diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md index fed9cf44..8e8f0d3f 100755 --- a/doc/doc_en/algorithm_overview_en.md +++ b/doc/doc_en/algorithm_overview_en.md @@ -46,6 +46,7 @@ PaddleOCR open-source text recognition algorithms list: - [x] 
STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11] - [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] - [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5] +- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2)) Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow: @@ -60,6 +61,6 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r |RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)| |RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)| |SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)| -|NRTR|NRTR_MTB| 84.1% | rec_mtb_nrtr | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) | +|NRTR|NRTR_MTB| 84.3% | rec_mtb_nrtr | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) | Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md) diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py index 1be38e93..37a5cf78 100644 --- a/ppocr/modeling/necks/__init__.py +++ b/ppocr/modeling/necks/__init__.py @@ -21,7 +21,7 @@ def build_neck(config): from .sast_fpn import SASTFPN from .rnn import SequenceEncoder from .pg_fpn import PGFPN - support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN','TFEncoder'] + support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN'] module_name = config.pop('name') assert module_name in support_dict, Exception('neck only support {}'.format( diff --git a/tools/program.py b/tools/program.py index 4b6dc9e4..71076a19 100755 --- a/tools/program.py +++ b/tools/program.py @@ -387,7 +387,7 @@ def preprocess(is_train=False): alg = config['Architecture']['algorithm'] assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', - 'CLS', 'PGNet', 'Distillation','NRTR' + 'CLS', 'PGNet', 'Distillation', 'NRTR' ] device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' From 685394cbbedee09bd339119b75d878884ecf51dc Mon Sep 17 00:00:00 2001 From: topduke <784990967@qq.com> Date: Tue, 17 Aug 2021 21:46:50 +0800 Subject: [PATCH 06/18] Update rec_mtb_nrtr.yml --- configs/rec/rec_mtb_nrtr.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/rec/rec_mtb_nrtr.yml b/configs/rec/rec_mtb_nrtr.yml index d16657d8..c832b917 100644 --- a/configs/rec/rec_mtb_nrtr.yml +++ b/configs/rec/rec_mtb_nrtr.yml @@ -63,7 +63,7 @@ Metric: Train: dataset: name: LMDBDataSet - data_dir: /paddle/data/ocr_data/training/ + data_dir: ./train_data/data_lmdb_release/training/ transforms: - NRTRDecodeImage: # load image img_mode: BGR @@ -82,7 +82,7 @@ Train: Eval: dataset: name: LMDBDataSet - data_dir: /paddle/data/ocr_data/evaluation/ + data_dir: ./train_data/data_lmdb_release/evaluation/ transforms: - NRTRDecodeImage: # load image img_mode: BGR From 45313ff37a4a9122f07cf0ff5b86a3b1a7119be1 Mon Sep 17 00:00:00 2001 From: topduke <784990967@qq.com> Date: Tue, 17 Aug 2021 21:50:41 
+0800 Subject: [PATCH 07/18] Update multiheadAttention.py --- .../modeling/backbones/multiheadAttention.py | 210 +----------------- 1 file changed, 1 insertion(+), 209 deletions(-) diff --git a/ppocr/modeling/backbones/multiheadAttention.py b/ppocr/modeling/backbones/multiheadAttention.py index f18e9957..6aba81de 100755 --- a/ppocr/modeling/backbones/multiheadAttention.py +++ b/ppocr/modeling/backbones/multiheadAttention.py @@ -9,214 +9,6 @@ from paddle.nn.initializer import XavierNormal as xavier_normal_ zeros_ = constant_(value=0.) ones_ = constant_(value=1.) -class MultiheadAttention(nn.Layer): - r"""Allows the model to jointly attend to information - from different representation subspaces. - See reference: Attention Is All You Need - - .. math:: - \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O - \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) - - Args: - embed_dim: total dimension of the model - num_heads: parallel attention layers, or heads - - Examples:: - - >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value) - """ - - def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False): - super(MultiheadAttention, self).__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" - self.scaling = self.head_dim ** -0.5 - - self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias) - - if add_bias_kv: - self.bias_k = self.create_parameter( - shape=(1, 1, embed_dim), default_initializer=zeros_) - self.add_parameter("bias_k", self.bias_k) - self.bias_v = self.create_parameter( - shape=(1, 1, embed_dim), default_initializer=zeros_) - self.add_parameter("bias_v", self.bias_v) - else: - self.bias_k = self.bias_v = None - - self.add_zero_attn = add_zero_attn - - self._reset_parameters() - - self.conv1 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) - self.conv2 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim * 2, kernel_size=(1, 1)) - self.conv3 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim * 3, kernel_size=(1, 1)) - - def _reset_parameters(self): - - - xavier_uniform_(self.out_proj.weight) - if self.bias_k is not None: - xavier_normal_(self.bias_k) - if self.bias_v is not None: - xavier_normal_(self.bias_v) - - def forward(self, query, key, value, key_padding_mask=None, incremental_state=None, - need_weights=True, static_kv=False, attn_mask=None, qkv_ = [False,False,False]): - """ - Inputs of forward function - query: [target length, batch size, embed dim] - key: [sequence length, batch size, embed dim] - value: [sequence length, batch size, embed dim] - key_padding_mask: if True, mask padding based on batch size - incremental_state: if provided, previous time steps are cashed - need_weights: output attn_output_weights - static_kv: key and value are static - - Outputs of forward function - attn_output: [target length, batch size, embed dim] - attn_output_weights: [batch size, target length, sequence length] - """ - qkv_same = qkv_[0] - kv_same = qkv_[1] - - tgt_len, bsz, embed_dim = query.shape - assert embed_dim == self.embed_dim - assert list(query.shape) == [tgt_len, bsz, embed_dim] - assert key.shape == value.shape - - if qkv_same: - # self-attention - q, k, v = self._in_proj_qkv(query) - 
elif kv_same: - # encoder-decoder attention - q = self._in_proj_q(query) - if key is None: - assert value is None - k = v = None - else: - k, v = self._in_proj_kv(key) - else: - q = self._in_proj_q(query) - k = self._in_proj_k(key) - v = self._in_proj_v(value) - q *= self.scaling - - if self.bias_k is not None: - assert self.bias_v is not None - self.bias_k = paddle.concat([self.bias_k for i in range(bsz)],axis=1) - self.bias_v = paddle.concat([self.bias_v for i in range(bsz)],axis=1) - k = paddle.concat([k, self.bias_k]) - v = paddle.concat([v, self.bias_v]) - if attn_mask is not None: - attn_mask = paddle.concat([attn_mask, paddle.zeros([attn_mask.shape[0], 1],dtype=attn_mask.dtype)], axis=1) - if key_padding_mask is not None: - key_padding_mask = paddle.concat( - [key_padding_mask,paddle.zeros([key_padding_mask.shape[0], 1],dtype=key_padding_mask.dtype)], axis=1) - - q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) - if k is not None: - k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) - if v is not None: - v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) - - - - src_len = k.shape[1] - - if key_padding_mask is not None: - assert key_padding_mask.shape[0] == bsz - assert key_padding_mask.shape[1] == src_len - - if self.add_zero_attn: - src_len += 1 - k = paddle.concat([k, paddle.zeros((k.shape[0], 1) + k.shape[2:],dtype=k.dtype)], axis=1) - v = paddle.concat([v, paddle.zeros((v.shape[0], 1) + v.shape[2:],dtype=v.dtype)], axis=1) - if attn_mask is not None: - attn_mask = paddle.concat([attn_mask, paddle.zeros([attn_mask.shape[0], 1],dtype=attn_mask.dtype)], axis=1) - if key_padding_mask is not None: - key_padding_mask = paddle.concat( - [key_padding_mask, paddle.zeros([key_padding_mask.shape[0], 1],dtype=key_padding_mask.dtype)], axis=1) - attn_output_weights = paddle.bmm(q, k.transpose([0,2,1])) - assert list(attn_output_weights.shape) == [bsz * self.num_heads, tgt_len, src_len] - - if attn_mask is not None: - attn_mask = attn_mask.unsqueeze(0) - attn_output_weights += attn_mask - if key_padding_mask is not None: - attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len]) - key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32') - y = paddle.full(shape=key.shape, dtype='float32', fill_value='-inf') - y = paddle.where(key==0.,key, y) - attn_output_weights += y - attn_output_weights = attn_output_weights.reshape([bsz*self.num_heads, tgt_len, src_len]) - - attn_output_weights = F.softmax( - attn_output_weights.astype('float32'), axis=-1, - dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16 else attn_output_weights.dtype) - attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training) - - attn_output = paddle.bmm(attn_output_weights, v) - assert list(attn_output.shape) == [bsz * self.num_heads, tgt_len, self.head_dim] - attn_output = attn_output.transpose([1, 0,2]).reshape([tgt_len, bsz, embed_dim]) - attn_output = self.out_proj(attn_output) - if need_weights: - # average attention weights over heads - attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len]) - attn_output_weights = attn_output_weights.sum(axis=1) / self.num_heads - else: - attn_output_weights = None - - return attn_output, attn_output_weights - - def _in_proj_qkv(self, query): - query = query.transpose([1, 2, 0]) - query = paddle.unsqueeze(query, axis=2) - res = self.conv3(query) - res = paddle.squeeze(res, 
axis=2) - res = res.transpose([2, 0, 1]) - return res.chunk(3, axis=-1) - - def _in_proj_kv(self, key): - key = key.transpose([1, 2, 0]) - key = paddle.unsqueeze(key, axis=2) - res = self.conv2(key) - res = paddle.squeeze(res, axis=2) - res = res.transpose([2, 0, 1]) - return res.chunk(2, axis=-1) - - def _in_proj_q(self, query): - query = query.transpose([1, 2, 0]) - query = paddle.unsqueeze(query, axis=2) - res = self.conv1(query) - res = paddle.squeeze(res, axis=2) - res = res.transpose([2, 0, 1]) - return res - - def _in_proj_k(self, key): - - key = key.transpose([1, 2, 0]) - key = paddle.unsqueeze(key, axis=2) - res = self.conv1(key) - res = paddle.squeeze(res, axis=2) - res = res.transpose([2, 0, 1]) - return res - - def _in_proj_v(self, value): - - value = value.transpose([1,2,0])#(1, 2, 0) - value = paddle.unsqueeze(value, axis=2) - res = self.conv1(value) - res = paddle.squeeze(res, axis=2) - res = res.transpose([2, 0, 1]) - return res - - class MultiheadAttentionOptim(nn.Layer): r"""Allows the model to jointly attend to information @@ -362,4 +154,4 @@ class MultiheadAttentionOptim(nn.Layer): res = self.conv3(value) res = paddle.squeeze(res, axis=2) res = res.transpose([2, 0, 1]) - return res \ No newline at end of file + return res From 18349d177d21df56348923c5c39ca41e08eb2b59 Mon Sep 17 00:00:00 2001 From: topduke <784990967@qq.com> Date: Tue, 17 Aug 2021 21:53:10 +0800 Subject: [PATCH 08/18] Update multiheadAttention.py --- ppocr/modeling/heads/multiheadAttention.py | 210 +-------------------- 1 file changed, 1 insertion(+), 209 deletions(-) diff --git a/ppocr/modeling/heads/multiheadAttention.py b/ppocr/modeling/heads/multiheadAttention.py index f18e9957..6aba81de 100755 --- a/ppocr/modeling/heads/multiheadAttention.py +++ b/ppocr/modeling/heads/multiheadAttention.py @@ -9,214 +9,6 @@ from paddle.nn.initializer import XavierNormal as xavier_normal_ zeros_ = constant_(value=0.) ones_ = constant_(value=1.) -class MultiheadAttention(nn.Layer): - r"""Allows the model to jointly attend to information - from different representation subspaces. - See reference: Attention Is All You Need - - .. 
math:: - \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O - \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) - - Args: - embed_dim: total dimension of the model - num_heads: parallel attention layers, or heads - - Examples:: - - >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value) - """ - - def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False): - super(MultiheadAttention, self).__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" - self.scaling = self.head_dim ** -0.5 - - self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias) - - if add_bias_kv: - self.bias_k = self.create_parameter( - shape=(1, 1, embed_dim), default_initializer=zeros_) - self.add_parameter("bias_k", self.bias_k) - self.bias_v = self.create_parameter( - shape=(1, 1, embed_dim), default_initializer=zeros_) - self.add_parameter("bias_v", self.bias_v) - else: - self.bias_k = self.bias_v = None - - self.add_zero_attn = add_zero_attn - - self._reset_parameters() - - self.conv1 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) - self.conv2 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim * 2, kernel_size=(1, 1)) - self.conv3 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim * 3, kernel_size=(1, 1)) - - def _reset_parameters(self): - - - xavier_uniform_(self.out_proj.weight) - if self.bias_k is not None: - xavier_normal_(self.bias_k) - if self.bias_v is not None: - xavier_normal_(self.bias_v) - - def forward(self, query, key, value, key_padding_mask=None, incremental_state=None, - need_weights=True, static_kv=False, attn_mask=None, qkv_ = [False,False,False]): - """ - Inputs of forward function - query: [target length, batch size, embed dim] - key: [sequence length, batch size, embed dim] - value: [sequence length, batch size, embed dim] - key_padding_mask: if True, mask padding based on batch size - incremental_state: if provided, previous time steps are cashed - need_weights: output attn_output_weights - static_kv: key and value are static - - Outputs of forward function - attn_output: [target length, batch size, embed dim] - attn_output_weights: [batch size, target length, sequence length] - """ - qkv_same = qkv_[0] - kv_same = qkv_[1] - - tgt_len, bsz, embed_dim = query.shape - assert embed_dim == self.embed_dim - assert list(query.shape) == [tgt_len, bsz, embed_dim] - assert key.shape == value.shape - - if qkv_same: - # self-attention - q, k, v = self._in_proj_qkv(query) - elif kv_same: - # encoder-decoder attention - q = self._in_proj_q(query) - if key is None: - assert value is None - k = v = None - else: - k, v = self._in_proj_kv(key) - else: - q = self._in_proj_q(query) - k = self._in_proj_k(key) - v = self._in_proj_v(value) - q *= self.scaling - - if self.bias_k is not None: - assert self.bias_v is not None - self.bias_k = paddle.concat([self.bias_k for i in range(bsz)],axis=1) - self.bias_v = paddle.concat([self.bias_v for i in range(bsz)],axis=1) - k = paddle.concat([k, self.bias_k]) - v = paddle.concat([v, self.bias_v]) - if attn_mask is not None: - attn_mask = paddle.concat([attn_mask, paddle.zeros([attn_mask.shape[0], 1],dtype=attn_mask.dtype)], axis=1) - if key_padding_mask is not None: - 
key_padding_mask = paddle.concat( - [key_padding_mask,paddle.zeros([key_padding_mask.shape[0], 1],dtype=key_padding_mask.dtype)], axis=1) - - q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) - if k is not None: - k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) - if v is not None: - v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) - - - - src_len = k.shape[1] - - if key_padding_mask is not None: - assert key_padding_mask.shape[0] == bsz - assert key_padding_mask.shape[1] == src_len - - if self.add_zero_attn: - src_len += 1 - k = paddle.concat([k, paddle.zeros((k.shape[0], 1) + k.shape[2:],dtype=k.dtype)], axis=1) - v = paddle.concat([v, paddle.zeros((v.shape[0], 1) + v.shape[2:],dtype=v.dtype)], axis=1) - if attn_mask is not None: - attn_mask = paddle.concat([attn_mask, paddle.zeros([attn_mask.shape[0], 1],dtype=attn_mask.dtype)], axis=1) - if key_padding_mask is not None: - key_padding_mask = paddle.concat( - [key_padding_mask, paddle.zeros([key_padding_mask.shape[0], 1],dtype=key_padding_mask.dtype)], axis=1) - attn_output_weights = paddle.bmm(q, k.transpose([0,2,1])) - assert list(attn_output_weights.shape) == [bsz * self.num_heads, tgt_len, src_len] - - if attn_mask is not None: - attn_mask = attn_mask.unsqueeze(0) - attn_output_weights += attn_mask - if key_padding_mask is not None: - attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len]) - key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32') - y = paddle.full(shape=key.shape, dtype='float32', fill_value='-inf') - y = paddle.where(key==0.,key, y) - attn_output_weights += y - attn_output_weights = attn_output_weights.reshape([bsz*self.num_heads, tgt_len, src_len]) - - attn_output_weights = F.softmax( - attn_output_weights.astype('float32'), axis=-1, - dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16 else attn_output_weights.dtype) - attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training) - - attn_output = paddle.bmm(attn_output_weights, v) - assert list(attn_output.shape) == [bsz * self.num_heads, tgt_len, self.head_dim] - attn_output = attn_output.transpose([1, 0,2]).reshape([tgt_len, bsz, embed_dim]) - attn_output = self.out_proj(attn_output) - if need_weights: - # average attention weights over heads - attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len]) - attn_output_weights = attn_output_weights.sum(axis=1) / self.num_heads - else: - attn_output_weights = None - - return attn_output, attn_output_weights - - def _in_proj_qkv(self, query): - query = query.transpose([1, 2, 0]) - query = paddle.unsqueeze(query, axis=2) - res = self.conv3(query) - res = paddle.squeeze(res, axis=2) - res = res.transpose([2, 0, 1]) - return res.chunk(3, axis=-1) - - def _in_proj_kv(self, key): - key = key.transpose([1, 2, 0]) - key = paddle.unsqueeze(key, axis=2) - res = self.conv2(key) - res = paddle.squeeze(res, axis=2) - res = res.transpose([2, 0, 1]) - return res.chunk(2, axis=-1) - - def _in_proj_q(self, query): - query = query.transpose([1, 2, 0]) - query = paddle.unsqueeze(query, axis=2) - res = self.conv1(query) - res = paddle.squeeze(res, axis=2) - res = res.transpose([2, 0, 1]) - return res - - def _in_proj_k(self, key): - - key = key.transpose([1, 2, 0]) - key = paddle.unsqueeze(key, axis=2) - res = self.conv1(key) - res = paddle.squeeze(res, axis=2) - res = res.transpose([2, 0, 1]) - return res - - def 
_in_proj_v(self, value): - - value = value.transpose([1,2,0])#(1, 2, 0) - value = paddle.unsqueeze(value, axis=2) - res = self.conv1(value) - res = paddle.squeeze(res, axis=2) - res = res.transpose([2, 0, 1]) - return res - - class MultiheadAttentionOptim(nn.Layer): r"""Allows the model to jointly attend to information @@ -362,4 +154,4 @@ class MultiheadAttentionOptim(nn.Layer): res = self.conv3(value) res = paddle.squeeze(res, axis=2) res = res.transpose([2, 0, 1]) - return res \ No newline at end of file + return res From 8227ad1b5032fd14d2526d9b3676bb31bd3bff69 Mon Sep 17 00:00:00 2001 From: topduke <784990967@qq.com> Date: Tue, 17 Aug 2021 21:53:25 +0800 Subject: [PATCH 09/18] Delete multiheadAttention.py --- .../modeling/backbones/multiheadAttention.py | 157 ------------------ 1 file changed, 157 deletions(-) delete mode 100755 ppocr/modeling/backbones/multiheadAttention.py diff --git a/ppocr/modeling/backbones/multiheadAttention.py b/ppocr/modeling/backbones/multiheadAttention.py deleted file mode 100755 index 6aba81de..00000000 --- a/ppocr/modeling/backbones/multiheadAttention.py +++ /dev/null @@ -1,157 +0,0 @@ -import paddle -from paddle import nn -import paddle.nn.functional as F -from paddle.nn import Linear -from paddle.nn.initializer import XavierUniform as xavier_uniform_ -from paddle.nn.initializer import Constant as constant_ -from paddle.nn.initializer import XavierNormal as xavier_normal_ - -zeros_ = constant_(value=0.) -ones_ = constant_(value=1.) - - -class MultiheadAttentionOptim(nn.Layer): - r"""Allows the model to jointly attend to information - from different representation subspaces. - See reference: Attention Is All You Need - - .. math:: - \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O - \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) - - Args: - embed_dim: total dimension of the model - num_heads: parallel attention layers, or heads - - Examples:: - - >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value) - """ - - def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False): - super(MultiheadAttentionOptim, self).__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" - self.scaling = self.head_dim ** -0.5 - - self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias) - - self._reset_parameters() - - self.conv1 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) - self.conv2 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) - self.conv3 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) - - def _reset_parameters(self): - - - xavier_uniform_(self.out_proj.weight) - - - def forward(self, query, key, value, key_padding_mask=None, incremental_state=None, - need_weights=True, static_kv=False, attn_mask=None): - """ - Inputs of forward function - query: [target length, batch size, embed dim] - key: [sequence length, batch size, embed dim] - value: [sequence length, batch size, embed dim] - key_padding_mask: if True, mask padding based on batch size - incremental_state: if provided, previous time steps are cashed - need_weights: output attn_output_weights - static_kv: key and value are static - - Outputs of forward function - 
attn_output: [target length, batch size, embed dim] - attn_output_weights: [batch size, target length, sequence length] - """ - - - tgt_len, bsz, embed_dim = query.shape - assert embed_dim == self.embed_dim - assert list(query.shape) == [tgt_len, bsz, embed_dim] - assert key.shape == value.shape - - q = self._in_proj_q(query) - k = self._in_proj_k(key) - v = self._in_proj_v(value) - q *= self.scaling - - - q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) - k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) - v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) - - - src_len = k.shape[1] - - if key_padding_mask is not None: - assert key_padding_mask.shape[0] == bsz - assert key_padding_mask.shape[1] == src_len - - - attn_output_weights = paddle.bmm(q, k.transpose([0,2,1])) - assert list(attn_output_weights.shape) == [bsz * self.num_heads, tgt_len, src_len] - - if attn_mask is not None: - attn_mask = attn_mask.unsqueeze(0) - attn_output_weights += attn_mask - if key_padding_mask is not None: - attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len]) - key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32') - - y = paddle.full(shape=key.shape, dtype='float32', fill_value='-inf') - - y = paddle.where(key==0.,key, y) - - attn_output_weights += y - attn_output_weights = attn_output_weights.reshape([bsz*self.num_heads, tgt_len, src_len]) - - attn_output_weights = F.softmax( - attn_output_weights.astype('float32'), axis=-1, - dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16 else attn_output_weights.dtype) - attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training) - - attn_output = paddle.bmm(attn_output_weights, v) - assert list(attn_output.shape) == [bsz * self.num_heads, tgt_len, self.head_dim] - attn_output = attn_output.transpose([1, 0,2]).reshape([tgt_len, bsz, embed_dim]) - attn_output = self.out_proj(attn_output) - - if need_weights: - # average attention weights over heads - attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len]) - attn_output_weights = attn_output_weights.sum(axis=1) / self.num_heads - else: - attn_output_weights = None - - return attn_output, attn_output_weights - - - def _in_proj_q(self, query): - query = query.transpose([1, 2, 0]) - query = paddle.unsqueeze(query, axis=2) - res = self.conv1(query) - res = paddle.squeeze(res, axis=2) - res = res.transpose([2, 0, 1]) - return res - - def _in_proj_k(self, key): - - key = key.transpose([1, 2, 0]) - key = paddle.unsqueeze(key, axis=2) - res = self.conv2(key) - res = paddle.squeeze(res, axis=2) - res = res.transpose([2, 0, 1]) - return res - - def _in_proj_v(self, value): - - value = value.transpose([1,2,0])#(1, 2, 0) - value = paddle.unsqueeze(value, axis=2) - res = self.conv3(value) - res = paddle.squeeze(res, axis=2) - res = res.transpose([2, 0, 1]) - return res From 0486bc37a326979b7fe4bb3f505e7b3a6babe662 Mon Sep 17 00:00:00 2001 From: topduke <784990967@qq.com> Date: Wed, 18 Aug 2021 19:53:52 +0800 Subject: [PATCH 10/18] modify beam size --- configs/rec/rec_mtb_nrtr.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/rec/rec_mtb_nrtr.yml b/configs/rec/rec_mtb_nrtr.yml index c832b917..171ac7e3 100644 --- a/configs/rec/rec_mtb_nrtr.yml +++ b/configs/rec/rec_mtb_nrtr.yml @@ -46,7 +46,7 @@ Architecture: name: TransformerOptim d_model: 512 num_encoder_layers: 6 - beam_size: -1 # 
When Beam size is greater than 0, it means to use beam search when evaluation. + beam_size: 10 # When Beam size is greater than 0, it means to use beam search when evaluation. Loss: From a11e219970bd6762bfba2c6a0497c97ffaa6a094 Mon Sep 17 00:00:00 2001 From: topduke <784990967@qq.com> Date: Wed, 18 Aug 2021 19:54:57 +0800 Subject: [PATCH 11/18] delete blank line --- ppocr/losses/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index 7f4ab152..eed5a46e 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -44,7 +44,6 @@ from .table_att_loss import TableAttentionLoss def build_loss(config): support_dict = [ 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss', - 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss' ] From 55b76dcaa55fb69bf54dcac1f8d124deb15a5423 Mon Sep 17 00:00:00 2001 From: Topdu <784990967@qq.com> Date: Thu, 19 Aug 2021 09:31:02 +0000 Subject: [PATCH 12/18] delete blank lines and modify forward_train --- configs/rec/rec_mtb_nrtr.yml | 2 +- ppocr/modeling/backbones/__init__.py | 5 +- ppocr/modeling/backbones/rec_nrtr_mtb.py | 34 +- ppocr/modeling/heads/__init__.py | 7 +- ppocr/modeling/heads/multiheadAttention.py | 113 +++-- ppocr/modeling/heads/rec_nrtr_optim_head.py | 504 +++++++++++--------- tools/program.py | 11 +- 7 files changed, 388 insertions(+), 288 deletions(-) diff --git a/configs/rec/rec_mtb_nrtr.yml b/configs/rec/rec_mtb_nrtr.yml index 171ac7e3..c89de02b 100644 --- a/configs/rec/rec_mtb_nrtr.yml +++ b/configs/rec/rec_mtb_nrtr.yml @@ -46,7 +46,7 @@ Architecture: name: TransformerOptim d_model: 512 num_encoder_layers: 6 - beam_size: 10 # When Beam size is greater than 0, it means to use beam search when evaluation. + beam_size: 10 # When Beam size is greater than 0, it means to use beam search when evaluation. Loss: diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index 618b827d..f8ca7e40 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -27,8 +27,9 @@ def build_backbone(config, model_type): from .rec_resnet_fpn import ResNetFPN from .rec_mv1_enhance import MobileNetV1Enhance from .rec_nrtr_mtb import MTB - from .rec_swin import SwinTransformer - support_dict = ['MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', 'SwinTransformer'] + support_dict = [ + 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB' + ] elif model_type == "e2e": from .e2e_resnet_vd_pg import ResNet support_dict = ["ResNet"] diff --git a/ppocr/modeling/backbones/rec_nrtr_mtb.py b/ppocr/modeling/backbones/rec_nrtr_mtb.py index 26a0dc7f..04b5c9bb 100644 --- a/ppocr/modeling/backbones/rec_nrtr_mtb.py +++ b/ppocr/modeling/backbones/rec_nrtr_mtb.py @@ -1,5 +1,20 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
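The hunks that follow only reformat the MTB backbone, but its shape bookkeeping is worth pinning down: two stride-2 convolutions quarter the spatial size, and the "(b, w, h, c)" comment hints that height and channels are then flattened into the per-position feature. A minimal sanity check, assuming a local Paddle install and the module path used by this patch series (the expected output shape is an inference from the config, not something the diff itself states):

    import paddle
    from ppocr.modeling.backbones.rec_nrtr_mtb import MTB

    # 32x100 grayscale input, matching image_shape: [100, 32] and in_channels: 1
    x = paddle.randn([2, 1, 32, 100])
    feat = MTB(cnn_num=2, in_channels=1)(x)
    # H: 32 -> 16 -> 8, W: 100 -> 50 -> 25, channels: 1 -> 32 -> 64,
    # so flattening (H/4) * 64 = 8 * 64 = 512 lines up with d_model: 512
    print(feat.shape)  # expected: [2, 25, 512]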
+ from paddle import nn + class MTB(nn.Layer): def __init__(self, cnn_num, in_channels): super(MTB, self).__init__() @@ -8,17 +23,20 @@ class MTB(nn.Layer): self.cnn_num = cnn_num if self.cnn_num == 2: for i in range(self.cnn_num): - self.block.add_sublayer('conv_{}'.format(i), nn.Conv2D( - in_channels = in_channels if i == 0 else 32*(2**(i-1)), - out_channels = 32*(2**i), - kernel_size = 3, - stride = 2, - padding=1)) + self.block.add_sublayer( + 'conv_{}'.format(i), + nn.Conv2D( + in_channels=in_channels + if i == 0 else 32 * (2**(i - 1)), + out_channels=32 * (2**i), + kernel_size=3, + stride=2, + padding=1)) self.block.add_sublayer('relu_{}'.format(i), nn.ReLU()) - self.block.add_sublayer('bn_{}'.format(i), nn.BatchNorm2D(32*(2**i))) + self.block.add_sublayer('bn_{}'.format(i), + nn.BatchNorm2D(32 * (2**i))) def forward(self, images): - x = self.block(images) if self.cnn_num == 2: # (b, w, h, c) diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 63951cd5..11fd4b26 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -27,14 +27,13 @@ def build_head(config): from .rec_att_head import AttentionHead from .rec_srn_head import SRNHead from .rec_nrtr_optim_head import TransformerOptim - + # cls head from .cls_head import ClsHead support_dict = [ 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead', - - 'SRNHead', 'PGHead', 'TransformerOptim', 'TableAttentionHead'] - + 'SRNHead', 'PGHead', 'TransformerOptim', 'TableAttentionHead' + ] #table head from .table_att_head import TableAttentionHead diff --git a/ppocr/modeling/heads/multiheadAttention.py b/ppocr/modeling/heads/multiheadAttention.py index 6aba81de..4be37025 100755 --- a/ppocr/modeling/heads/multiheadAttention.py +++ b/ppocr/modeling/heads/multiheadAttention.py @@ -1,3 +1,17 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import paddle from paddle import nn import paddle.nn.functional as F @@ -11,7 +25,7 @@ ones_ = constant_(value=1.) class MultiheadAttentionOptim(nn.Layer): - r"""Allows the model to jointly attend to information + """Allows the model to jointly attend to information from different representation subspaces. 
See reference: Attention Is All You Need @@ -23,37 +37,43 @@ class MultiheadAttentionOptim(nn.Layer): embed_dim: total dimension of the model num_heads: parallel attention layers, or heads - Examples:: - - >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value) """ - def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False): + def __init__(self, + embed_dim, + num_heads, + dropout=0., + bias=True, + add_bias_kv=False, + add_zero_attn=False): super(MultiheadAttentionOptim, self).__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" - self.scaling = self.head_dim ** -0.5 - + self.scaling = self.head_dim**-0.5 self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias) - self._reset_parameters() - - self.conv1 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) - self.conv2 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) - self.conv3 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) + self.conv1 = paddle.nn.Conv2D( + in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) + self.conv2 = paddle.nn.Conv2D( + in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) + self.conv3 = paddle.nn.Conv2D( + in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) def _reset_parameters(self): - - xavier_uniform_(self.out_proj.weight) - - def forward(self, query, key, value, key_padding_mask=None, incremental_state=None, - need_weights=True, static_kv=False, attn_mask=None): + def forward(self, + query, + key, + value, + key_padding_mask=None, + incremental_state=None, + need_weights=True, + static_kv=False, + attn_mask=None): """ Inputs of forward function query: [target length, batch size, embed dim] @@ -68,8 +88,6 @@ class MultiheadAttentionOptim(nn.Layer): attn_output: [target length, batch size, embed dim] attn_output_weights: [batch size, target length, sequence length] """ - - tgt_len, bsz, embed_dim = query.shape assert embed_dim == self.embed_dim assert list(query.shape) == [tgt_len, bsz, embed_dim] @@ -80,11 +98,12 @@ class MultiheadAttentionOptim(nn.Layer): v = self._in_proj_v(value) q *= self.scaling - - q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) - k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) - v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) - + q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose( + [1, 0, 2]) + k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose( + [1, 0, 2]) + v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose( + [1, 0, 2]) src_len = k.shape[1] @@ -92,44 +111,48 @@ class MultiheadAttentionOptim(nn.Layer): assert key_padding_mask.shape[0] == bsz assert key_padding_mask.shape[1] == src_len - - attn_output_weights = paddle.bmm(q, k.transpose([0,2,1])) - assert list(attn_output_weights.shape) == [bsz * self.num_heads, tgt_len, src_len] + attn_output_weights = paddle.bmm(q, k.transpose([0, 2, 1])) + assert list(attn_output_weights. 
+ shape) == [bsz * self.num_heads, tgt_len, src_len] if attn_mask is not None: attn_mask = attn_mask.unsqueeze(0) attn_output_weights += attn_mask if key_padding_mask is not None: - attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len]) + attn_output_weights = attn_output_weights.reshape( + [bsz, self.num_heads, tgt_len, src_len]) key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32') - y = paddle.full(shape=key.shape, dtype='float32', fill_value='-inf') - - y = paddle.where(key==0.,key, y) - + y = paddle.where(key == 0., key, y) attn_output_weights += y - attn_output_weights = attn_output_weights.reshape([bsz*self.num_heads, tgt_len, src_len]) + attn_output_weights = attn_output_weights.reshape( + [bsz * self.num_heads, tgt_len, src_len]) attn_output_weights = F.softmax( - attn_output_weights.astype('float32'), axis=-1, - dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16 else attn_output_weights.dtype) - attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training) + attn_output_weights.astype('float32'), + axis=-1, + dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16 + else attn_output_weights.dtype) + attn_output_weights = F.dropout( + attn_output_weights, p=self.dropout, training=self.training) attn_output = paddle.bmm(attn_output_weights, v) - assert list(attn_output.shape) == [bsz * self.num_heads, tgt_len, self.head_dim] - attn_output = attn_output.transpose([1, 0,2]).reshape([tgt_len, bsz, embed_dim]) + assert list(attn_output. + shape) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn_output = attn_output.transpose([1, 0, 2]).reshape( + [tgt_len, bsz, embed_dim]) attn_output = self.out_proj(attn_output) if need_weights: # average attention weights over heads - attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len]) - attn_output_weights = attn_output_weights.sum(axis=1) / self.num_heads + attn_output_weights = attn_output_weights.reshape( + [bsz, self.num_heads, tgt_len, src_len]) + attn_output_weights = attn_output_weights.sum( + axis=1) / self.num_heads else: attn_output_weights = None - return attn_output, attn_output_weights - def _in_proj_q(self, query): query = query.transpose([1, 2, 0]) query = paddle.unsqueeze(query, axis=2) @@ -139,7 +162,6 @@ class MultiheadAttentionOptim(nn.Layer): return res def _in_proj_k(self, key): - key = key.transpose([1, 2, 0]) key = paddle.unsqueeze(key, axis=2) res = self.conv2(key) @@ -148,8 +170,7 @@ class MultiheadAttentionOptim(nn.Layer): return res def _in_proj_v(self, value): - - value = value.transpose([1,2,0])#(1, 2, 0) + value = value.transpose([1, 2, 0]) #(1, 2, 0) value = paddle.unsqueeze(value, axis=2) res = self.conv3(value) res = paddle.squeeze(res, axis=2) diff --git a/ppocr/modeling/heads/rec_nrtr_optim_head.py b/ppocr/modeling/heads/rec_nrtr_optim_head.py index b9a5100a..98f212d0 100644 --- a/ppocr/modeling/heads/rec_nrtr_optim_head.py +++ b/ppocr/modeling/heads/rec_nrtr_optim_head.py @@ -1,7 +1,21 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import math
 import paddle
 import copy
-from paddle import nn 
+from paddle import nn
 import paddle.nn.functional as F
 from paddle.nn import LayerList
 from paddle.nn.initializer import XavierNormal as xavier_uniform_
@@ -14,8 +28,9 @@ from paddle.nn.initializer import XavierNormal as xavier_normal_
 zeros_ = constant_(value=0.)
 ones_ = constant_(value=1.)
 
+
 class TransformerOptim(nn.Layer):
-    r"""A transformer model. User is able to modify the attributes as needed. The architecture
+    """A transformer model. User is able to modify the attributes as needed. The architecture
     is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
     Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
     Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
@@ -31,39 +46,50 @@ class TransformerOptim(nn.Layer):
         custom_encoder: custom encoder (default=None).
         custom_decoder: custom decoder (default=None).
 
-    Examples::
-        >>> transformer_model = nn.Transformer(src_vocab, tgt_vocab)
-        >>> transformer_model = nn.Transformer(src_vocab, tgt_vocab, nhead=16, num_encoder_layers=12)
     """
 
-    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, beam_size=0,
-                 num_decoder_layers=6, dim_feedforward=1024, attention_dropout_rate=0.0, residual_dropout_rate=0.1,
-                 custom_encoder=None, custom_decoder=None,in_channels=0,out_channels=0,dst_vocab_size=99,scale_embedding=True):
+    def __init__(self,
+                 d_model=512,
+                 nhead=8,
+                 num_encoder_layers=6,
+                 beam_size=0,
+                 num_decoder_layers=6,
+                 dim_feedforward=1024,
+                 attention_dropout_rate=0.0,
+                 residual_dropout_rate=0.1,
+                 custom_encoder=None,
+                 custom_decoder=None,
+                 in_channels=0,
+                 out_channels=0,
+                 dst_vocab_size=99,
+                 scale_embedding=True):
         super(TransformerOptim, self).__init__()
         self.embedding = Embeddings(
             d_model=d_model,
             vocab=dst_vocab_size,
             padding_idx=0,
-            scale_embedding=scale_embedding
-        )
+            scale_embedding=scale_embedding)
         self.positional_encoding = PositionalEncoding(
             dropout=residual_dropout_rate,
-            dim=d_model,
-        )
+            dim=d_model, )
         if custom_encoder is not None:
             self.encoder = custom_encoder
         else:
-            if num_encoder_layers > 0 :
-                encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, attention_dropout_rate, residual_dropout_rate)
-
-                self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers)
+            if num_encoder_layers > 0:
+                encoder_layer = TransformerEncoderLayer(
+                    d_model, nhead, dim_feedforward, attention_dropout_rate,
+                    residual_dropout_rate)
+                self.encoder = TransformerEncoder(encoder_layer,
+                                                  num_encoder_layers)
             else:
                 self.encoder = None
 
         if custom_decoder is not None:
             self.decoder = custom_decoder
         else:
-            decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, attention_dropout_rate, residual_dropout_rate)
+            decoder_layer = TransformerDecoderLayer(
+                d_model, nhead, dim_feedforward, attention_dropout_rate,
+                residual_dropout_rate)
             self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers)
 
         self._reset_parameters()
@@ -71,201 +97,205 @@ class TransformerOptim(nn.Layer):
         self.d_model = d_model
         self.nhead = nhead
self.tgt_word_prj = nn.Linear(d_model, dst_vocab_size, bias_attr=False) - w0 = np.random.normal(0.0, d_model**-0.5,(d_model, dst_vocab_size)).astype(np.float32) + w0 = np.random.normal(0.0, d_model**-0.5, + (d_model, dst_vocab_size)).astype(np.float32) self.tgt_word_prj.weight.set_value(w0) self.apply(self._init_weights) - def _init_weights(self, m): - + if isinstance(m, nn.Conv2D): xavier_normal_(m.weight) if m.bias is not None: zeros_(m.bias) - def forward_train(self,src,tgt): - tgt = tgt[:, :-1] + def forward_train(self, src, tgt): + tgt = tgt[:, :-1] - - - tgt_key_padding_mask = self.generate_padding_mask(tgt) - tgt = self.embedding(tgt).transpose([1, 0, 2]) - tgt = self.positional_encoding(tgt) - tgt_mask = self.generate_square_subsequent_mask(tgt.shape[0]) + tgt_key_padding_mask = self.generate_padding_mask(tgt) + tgt = self.embedding(tgt).transpose([1, 0, 2]) + tgt = self.positional_encoding(tgt) + tgt_mask = self.generate_square_subsequent_mask(tgt.shape[0]) - if self.encoder is not None : - src = self.positional_encoding(src.transpose([1, 0, 2])) - memory = self.encoder(src) - else: - memory = src.squeeze(2).transpose([2, 0, 1]) - output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=None, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=None) - output = output.transpose([1, 0, 2]) - logit = self.tgt_word_prj(output) - return logit - - def forward(self, src, tgt=None): - r"""Take in and process masked source/target sequences. + if self.encoder is not None: + src = self.positional_encoding(src.transpose([1, 0, 2])) + memory = self.encoder(src) + else: + memory = src.squeeze(2).transpose([2, 0, 1]) + output = self.decoder( + tgt, + memory, + tgt_mask=tgt_mask, + memory_mask=None, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=None) + output = output.transpose([1, 0, 2]) + logit = self.tgt_word_prj(output) + return logit + def forward(self, src, targets=None): + """Take in and process masked source/target sequences. Args: src: the sequence to the encoder (required). tgt: the sequence to the decoder (required). - src_mask: the additive mask for the src sequence (optional). - tgt_mask: the additive mask for the tgt sequence (optional). - memory_mask: the additive mask for the encoder output (optional). - src_key_padding_mask: the ByteTensor mask for src keys per batch (optional). - tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional). - memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional). - Shape: - src: :math:`(S, N, E)`. - tgt: :math:`(T, N, E)`. - - src_mask: :math:`(S, S)`. - - tgt_mask: :math:`(T, T)`. - - memory_mask: :math:`(T, S)`. - - src_key_padding_mask: :math:`(N, S)`. - - tgt_key_padding_mask: :math:`(N, T)`. - - memory_key_padding_mask: :math:`(N, S)`. - - Note: [src/tgt/memory]_mask should be filled with - float('-inf') for the masked positions and float(0.0) else. These masks - ensure that predictions for position i depend only on the unmasked positions - j and are applied identically for each sequence in a batch. - [src/tgt/memory]_key_padding_mask should be a ByteTensor where True values are positions - that should be masked with float('-inf') and False values will be unchanged. - This mask ensures that no information will be taken from position i if - it is masked, and has a separate mask for each sequence in a batch. - - - output: :math:`(T, N, E)`. 
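The -inf/0.0 additive-mask convention spelled out above is easy to verify in isolation. A standalone sketch, equivalent in effect to the generate_square_subsequent_mask helper defined further down in this file:

    import paddle

    sz = 3
    # -inf strictly above the diagonal, 0.0 elsewhere: position i may only
    # attend to positions j <= i once the mask is added to attention scores
    mask = paddle.triu(
        paddle.full([sz, sz], float('-inf'), dtype='float32'), diagonal=1)
    print(mask.numpy())
    # [[  0. -inf -inf]
    #  [  0.   0. -inf]
    #  [  0.   0.   0.]]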
- - Note: Due to the multi-head attention architecture in the transformer model, - the output sequence length of a transformer is same as the input sequence - (i.e. target) length of the decode. - - where S is the source sequence length, T is the target sequence length, N is the - batch size, E is the feature number - Examples: - >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask) + >>> output = transformer_model(src, tgt) """ - if tgt is not None: + + if self.training: + max_len = targets[1].max() + tgt = targets[0][:, :2 + max_len] return self.forward_train(src, tgt) else: - if self.beam_size > 0 : + if self.beam_size > 0: return self.forward_beam(src) else: return self.forward_test(src) def forward_test(self, src): bs = src.shape[0] - if self.encoder is not None : + if self.encoder is not None: src = self.positional_encoding(src.transpose([1, 0, 2])) memory = self.encoder(src) else: memory = src.squeeze(2).transpose([2, 0, 1]) - dec_seq = paddle.full((bs,1), 2, dtype=paddle.int64) + dec_seq = paddle.full((bs, 1), 2, dtype=paddle.int64) for len_dec_seq in range(1, 25): src_enc = memory.clone() tgt_key_padding_mask = self.generate_padding_mask(dec_seq) dec_seq_embed = self.embedding(dec_seq).transpose([1, 0, 2]) dec_seq_embed = self.positional_encoding(dec_seq_embed) - tgt_mask = self.generate_square_subsequent_mask(dec_seq_embed.shape[0]) - output = self.decoder(dec_seq_embed, src_enc, tgt_mask=tgt_mask, memory_mask=None, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=None) + tgt_mask = self.generate_square_subsequent_mask(dec_seq_embed.shape[ + 0]) + output = self.decoder( + dec_seq_embed, + src_enc, + tgt_mask=tgt_mask, + memory_mask=None, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=None) dec_output = output.transpose([1, 0, 2]) - - dec_output = dec_output[:, -1, :] # Pick the last step: (bh * bm) * d_h + + dec_output = dec_output[:, + -1, :] # Pick the last step: (bh * bm) * d_h word_prob = F.log_softmax(self.tgt_word_prj(dec_output), axis=1) word_prob = word_prob.reshape([1, bs, -1]) preds_idx = word_prob.argmax(axis=2) - - if paddle.equal_all(preds_idx[-1],paddle.full(preds_idx[-1].shape,3,dtype='int64')): + + if paddle.equal_all( + preds_idx[-1], + paddle.full( + preds_idx[-1].shape, 3, dtype='int64')): break preds_prob = word_prob.max(axis=2) - dec_seq = paddle.concat([dec_seq,preds_idx.reshape([-1,1])],axis=1) + dec_seq = paddle.concat( + [dec_seq, preds_idx.reshape([-1, 1])], axis=1) - return dec_seq + return dec_seq - def forward_beam(self,images): - + def forward_beam(self, images): ''' Translation work in one batch ''' def get_inst_idx_to_tensor_position_map(inst_idx_list): ''' Indicate the position of an instance in a tensor. ''' - return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)} + return { + inst_idx: tensor_position + for tensor_position, inst_idx in enumerate(inst_idx_list) + } - def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm): + def collect_active_part(beamed_tensor, curr_active_inst_idx, + n_prev_active_inst, n_bm): ''' Collect tensor parts associated to active instances. 
''' _, *d_hs = beamed_tensor.shape n_curr_active_inst = len(curr_active_inst_idx) new_shape = (n_curr_active_inst * n_bm, *d_hs) - beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1])#contiguous() - beamed_tensor = beamed_tensor.index_select(paddle.to_tensor(curr_active_inst_idx),axis=0) + beamed_tensor = beamed_tensor.reshape( + [n_prev_active_inst, -1]) #contiguous() + beamed_tensor = beamed_tensor.index_select( + paddle.to_tensor(curr_active_inst_idx), axis=0) beamed_tensor = beamed_tensor.reshape([*new_shape]) return beamed_tensor - - def collate_active_info( - src_enc, inst_idx_to_position_map, active_inst_idx_list): + def collate_active_info(src_enc, inst_idx_to_position_map, + active_inst_idx_list): # Sentences which are still active are collected, # so the decoder will not run on completed sentences. - + n_prev_active_inst = len(inst_idx_to_position_map) - active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list] + active_inst_idx = [ + inst_idx_to_position_map[k] for k in active_inst_idx_list + ] active_inst_idx = paddle.to_tensor(active_inst_idx, dtype='int64') - active_src_enc = collect_active_part(src_enc.transpose([1, 0, 2]), active_inst_idx, n_prev_active_inst, n_bm).transpose([1, 0, 2]) - active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) + active_src_enc = collect_active_part( + src_enc.transpose([1, 0, 2]), active_inst_idx, + n_prev_active_inst, n_bm).transpose([1, 0, 2]) + active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map( + active_inst_idx_list) return active_src_enc, active_inst_idx_to_position_map - def beam_decode_step( - inst_dec_beams, len_dec_seq, enc_output, inst_idx_to_position_map, n_bm, memory_key_padding_mask): + def beam_decode_step(inst_dec_beams, len_dec_seq, enc_output, + inst_idx_to_position_map, n_bm, + memory_key_padding_mask): ''' Decode and update beam status, and then return active beam idx ''' def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq): - dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done] + dec_partial_seq = [ + b.get_current_state() for b in inst_dec_beams if not b.done + ] dec_partial_seq = paddle.stack(dec_partial_seq) - + dec_partial_seq = dec_partial_seq.reshape([-1, len_dec_seq]) return dec_partial_seq - def prepare_beam_memory_key_padding_mask(inst_dec_beams, memory_key_padding_mask, n_bm): + def prepare_beam_memory_key_padding_mask( + inst_dec_beams, memory_key_padding_mask, n_bm): keep = [] for idx in (memory_key_padding_mask): if not inst_dec_beams[idx].done: keep.append(idx) - memory_key_padding_mask = memory_key_padding_mask[paddle.to_tensor(keep)] + memory_key_padding_mask = memory_key_padding_mask[ + paddle.to_tensor(keep)] len_s = memory_key_padding_mask.shape[-1] n_inst = memory_key_padding_mask.shape[0] - memory_key_padding_mask = paddle.concat([memory_key_padding_mask for i in range(n_bm)],axis=1) - memory_key_padding_mask = memory_key_padding_mask.reshape([n_inst * n_bm, len_s])#repeat(1, n_bm) + memory_key_padding_mask = paddle.concat( + [memory_key_padding_mask for i in range(n_bm)], axis=1) + memory_key_padding_mask = memory_key_padding_mask.reshape( + [n_inst * n_bm, len_s]) #repeat(1, n_bm) return memory_key_padding_mask - def predict_word(dec_seq, enc_output, n_active_inst, n_bm, memory_key_padding_mask): + def predict_word(dec_seq, enc_output, n_active_inst, n_bm, + memory_key_padding_mask): tgt_key_padding_mask = self.generate_padding_mask(dec_seq) dec_seq = 
self.embedding(dec_seq).transpose([1, 0, 2]) dec_seq = self.positional_encoding(dec_seq) - tgt_mask = self.generate_square_subsequent_mask(dec_seq.shape[0]) + tgt_mask = self.generate_square_subsequent_mask(dec_seq.shape[ + 0]) dec_output = self.decoder( - dec_seq, enc_output, + dec_seq, + enc_output, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask, ).transpose([1, 0, 2]) - dec_output = dec_output[:, -1, :] # Pick the last step: (bh * bm) * d_h + dec_output = dec_output[:, + -1, :] # Pick the last step: (bh * bm) * d_h word_prob = F.log_softmax(self.tgt_word_prj(dec_output), axis=1) word_prob = word_prob.reshape([n_active_inst, n_bm, -1]) return word_prob - def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map): + def collect_active_inst_idx_list(inst_beams, word_prob, + inst_idx_to_position_map): active_inst_idx_list = [] for inst_idx, inst_position in inst_idx_to_position_map.items(): - is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position]) + is_inst_complete = inst_beams[inst_idx].advance(word_prob[ + inst_position]) if not is_inst_complete: active_inst_idx_list += [inst_idx] @@ -274,7 +304,8 @@ class TransformerOptim(nn.Layer): n_active_inst = len(inst_idx_to_position_map) dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq) memory_key_padding_mask = None - word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm, memory_key_padding_mask) + word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm, + memory_key_padding_mask) # Update the beam with predicted word prob information and collect incomplete instances active_inst_idx_list = collect_active_inst_idx_list( inst_dec_beams, word_prob, inst_idx_to_position_map) @@ -285,14 +316,17 @@ class TransformerOptim(nn.Layer): for inst_idx in range(len(inst_dec_beams)): scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores() all_scores += [scores[:n_best]] - hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]] + hyps = [ + inst_dec_beams[inst_idx].get_hypothesis(i) + for i in tail_idxs[:n_best] + ] all_hyp += [hyps] return all_hyp, all_scores with paddle.no_grad(): #-- Encode - - if self.encoder is not None : + + if self.encoder is not None: src = self.positional_encoding(images.transpose([1, 0, 2])) src_enc = self.encoder(src).transpose([1, 0, 2]) else: @@ -301,45 +335,53 @@ class TransformerOptim(nn.Layer): #-- Repeat data for beam search n_bm = self.beam_size n_inst, len_s, d_h = src_enc.shape - src_enc = paddle.concat([src_enc for i in range(n_bm)],axis=1) - src_enc = src_enc.reshape([n_inst * n_bm, len_s, d_h]).transpose([1, 0, 2])#repeat(1, n_bm, 1) + src_enc = paddle.concat([src_enc for i in range(n_bm)], axis=1) + src_enc = src_enc.reshape([n_inst * n_bm, len_s, d_h]).transpose( + [1, 0, 2]) #repeat(1, n_bm, 1) #-- Prepare beams inst_dec_beams = [Beam(n_bm) for _ in range(n_inst)] #-- Bookkeeping for active or not active_inst_idx_list = list(range(n_inst)) - inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) + inst_idx_to_position_map = get_inst_idx_to_tensor_position_map( + active_inst_idx_list) #-- Decode for len_dec_seq in range(1, 25): src_enc_copy = src_enc.clone() active_inst_idx_list = beam_decode_step( - inst_dec_beams, len_dec_seq, src_enc_copy, inst_idx_to_position_map, n_bm, None) + inst_dec_beams, len_dec_seq, src_enc_copy, + inst_idx_to_position_map, n_bm, None) if not active_inst_idx_list: break # all instances have finished 
their path to <EOS>
                 src_enc, inst_idx_to_position_map = collate_active_info(
-                    src_enc_copy, inst_idx_to_position_map, active_inst_idx_list)
+                    src_enc_copy, inst_idx_to_position_map,
+                    active_inst_idx_list)
 
-        batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, 1)
+        batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams,
+                                                                1)
         result_hyp = []
         for bs_hyp in batch_hyp:
-            bs_hyp_pad =bs_hyp[0]+[3]*(25-len(bs_hyp[0]))
+            bs_hyp_pad = bs_hyp[0] + [3] * (25 - len(bs_hyp[0]))
             result_hyp.append(bs_hyp_pad)
-        return paddle.to_tensor(np.array(result_hyp),dtype=paddle.int64)
+        return paddle.to_tensor(np.array(result_hyp), dtype=paddle.int64)
 
     def generate_square_subsequent_mask(self, sz):
-        r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
+        """Generate a square mask for the sequence. The masked positions are filled with float('-inf').
         Unmasked positions are filled with float(0.0).
         """
-        mask = paddle.zeros([sz, sz],dtype='float32')
-        mask_inf = paddle.triu(paddle.full(shape=[sz,sz], dtype='float32', fill_value='-inf'),diagonal=1)
-        mask = mask+mask_inf
+        mask = paddle.zeros([sz, sz], dtype='float32')
+        mask_inf = paddle.triu(
+            paddle.full(
+                shape=[sz, sz], dtype='float32', fill_value='-inf'),
+            diagonal=1)
+        mask = mask + mask_inf
         return mask
 
     def generate_padding_mask(self, x):
-        padding_mask = x.equal(paddle.to_tensor(0,dtype=x.dtype))
+        padding_mask = x.equal(paddle.to_tensor(0, dtype=x.dtype))
         return padding_mask
 
     def _reset_parameters(self):
-        r"""Initiate parameters in the transformer model."""
+        """Initiate parameters in the transformer model."""
 
         for p in self.parameters():
             if p.dim() > 1:
@@ -347,16 +389,11 @@ class TransformerOptim(nn.Layer):
 
 
 class TransformerEncoder(nn.Layer):
-    r"""TransformerEncoder is a stack of N encoder layers
-
+    """TransformerEncoder is a stack of N encoder layers
     Args:
         encoder_layer: an instance of the TransformerEncoderLayer() class (required).
         num_layers: the number of sub-encoder-layers in the encoder (required).
         norm: the layer normalization component (optional).
-
-    Examples::
-        >>> encoder_layer = nn.TransformerEncoderLayer(d_model, nhead)
-        >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
     """
 
     def __init__(self, encoder_layer, num_layers):
@@ -364,50 +401,46 @@ class TransformerEncoder(nn.Layer):
         self.layers = _get_clones(encoder_layer, num_layers)
         self.num_layers = num_layers
 
-
     def forward(self, src):
-        r"""Pass the input through the encoder layers in turn.
-
+        """Pass the input through the encoder layers in turn.
         Args:
             src: the sequence to the encoder (required).
             mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
-
-        Shape:
-            see the docs in Transformer class.
         """
         output = src
 
         for i in range(self.num_layers):
-            output = self.layers[i](output, src_mask=None,
+            output = self.layers[i](output,
+                                    src_mask=None,
                                     src_key_padding_mask=None)
 
         return output
 
 
 class TransformerDecoder(nn.Layer):
-    r"""TransformerDecoder is a stack of N decoder layers
+    """TransformerDecoder is a stack of N decoder layers
 
     Args:
         decoder_layer: an instance of the TransformerDecoderLayer() class (required).
         num_layers: the number of sub-decoder-layers in the decoder (required).
         norm: the layer normalization component (optional).
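Both stacks above are built from a single layer instance that is deep-copied N times, so each sub-layer owns its own parameters. A sketch of the pattern, mirroring the _get_clones helper defined near the end of this file:

    import copy
    from paddle import nn

    def get_clones(module, N):
        # N independent copies: deepcopy duplicates the parameters,
        # so the cloned layers do not share weights
        return nn.LayerList([copy.deepcopy(module) for _ in range(N)])

    layers = get_clones(nn.Linear(512, 512), 6)
    print(len(layers))  # 6, as in num_encoder_layers: 6 from rec_mtb_nrtr.yml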
-    Examples::
-        >>> decoder_layer = nn.TransformerDecoderLayer(d_model, nhead)
-        >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
     """
 
     def __init__(self, decoder_layer, num_layers):
         super(TransformerDecoder, self).__init__()
         self.layers = _get_clones(decoder_layer, num_layers)
         self.num_layers = num_layers
 
-
-    def forward(self, tgt, memory, tgt_mask=None,
-                memory_mask=None, tgt_key_padding_mask=None,
+    def forward(self,
+                tgt,
+                memory,
+                tgt_mask=None,
+                memory_mask=None,
+                tgt_key_padding_mask=None,
                 memory_key_padding_mask=None):
-        r"""Pass the inputs (and mask) through the decoder layer in turn.
+        """Pass the inputs (and mask) through the decoder layer in turn.
 
         Args:
             tgt: the sequence to the decoder (required).
@@ -416,21 +449,22 @@ class TransformerDecoder(nn.Layer):
             memory_mask: the mask for the memory sequence (optional).
             tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
             memory_key_padding_mask: the mask for the memory keys per batch (optional).
-
-        Shape:
-            see the docs in Transformer class.
         """
         output = tgt
 
         for i in range(self.num_layers):
-            output = self.layers[i](output, memory, tgt_mask=tgt_mask,
-                                    memory_mask=memory_mask,
-                                    tgt_key_padding_mask=tgt_key_padding_mask,
-                                    memory_key_padding_mask=memory_key_padding_mask)
+            output = self.layers[i](
+                output,
+                memory,
+                tgt_mask=tgt_mask,
+                memory_mask=memory_mask,
+                tgt_key_padding_mask=tgt_key_padding_mask,
+                memory_key_padding_mask=memory_key_padding_mask)
 
         return output
 
+
 class TransformerEncoderLayer(nn.Layer):
-    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
+    """TransformerEncoderLayer is made up of self-attn and feedforward network.
     This standard encoder layer is based on the paper "Attention Is All You Need".
     Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N
     Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
@@ -443,16 +477,26 @@ class TransformerEncoderLayer(nn.Layer):
         dim_feedforward: the dimension of the feedforward network model (default=2048).
         dropout: the dropout value (default=0.1).
 
-    Examples::
-        >>> encoder_layer = nn.TransformerEncoderLayer(d_model, nhead)
     """
 
-    def __init__(self, d_model, nhead, dim_feedforward=2048, attention_dropout_rate=0.0, residual_dropout_rate=0.1):
+    def __init__(self,
+                 d_model,
+                 nhead,
+                 dim_feedforward=2048,
+                 attention_dropout_rate=0.0,
+                 residual_dropout_rate=0.1):
         super(TransformerEncoderLayer, self).__init__()
-        self.self_attn = MultiheadAttentionOptim(d_model, nhead, dropout=attention_dropout_rate)
+        self.self_attn = MultiheadAttentionOptim(
+            d_model, nhead, dropout=attention_dropout_rate)
 
-        self.conv1 = Conv2D(in_channels=d_model, out_channels=dim_feedforward, kernel_size=(1, 1))
-        self.conv2 = Conv2D(in_channels=dim_feedforward, out_channels=d_model, kernel_size=(1, 1))
+        self.conv1 = Conv2D(
+            in_channels=d_model,
+            out_channels=dim_feedforward,
+            kernel_size=(1, 1))
+        self.conv2 = Conv2D(
+            in_channels=dim_feedforward,
+            out_channels=d_model,
+            kernel_size=(1, 1))
 
         self.norm1 = LayerNorm(d_model)
         self.norm2 = LayerNorm(d_model)
@@ -460,18 +504,18 @@ class TransformerEncoderLayer(nn.Layer):
         self.dropout2 = Dropout(residual_dropout_rate)
 
     def forward(self, src, src_mask=None, src_key_padding_mask=None):
-        r"""Pass the input through the encoder layer.
-
+        """Pass the input through the encoder layer.
         Args:
             src: the sequence to the encoder layer (required).
             src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional). - - Shape: - see the docs in Transformer class. """ - src2 = self.self_attn(src, src, src, attn_mask=src_mask, - key_padding_mask=src_key_padding_mask)[0] + src2 = self.self_attn( + src, + src, + src, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] src = src + self.dropout1(src2) src = self.norm1(src) @@ -487,8 +531,9 @@ class TransformerEncoderLayer(nn.Layer): src = self.norm2(src) return src + class TransformerDecoderLayer(nn.Layer): - r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. + """TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. This standard decoder layer is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in @@ -501,17 +546,28 @@ class TransformerDecoderLayer(nn.Layer): dim_feedforward: the dimension of the feedforward network model (default=2048). dropout: the dropout value (default=0.1). - Examples:: - >>> decoder_layer = nn.TransformerDecoderLayer(d_model, nhead) """ - def __init__(self, d_model, nhead, dim_feedforward=2048, attention_dropout_rate=0.0, residual_dropout_rate=0.1): + def __init__(self, + d_model, + nhead, + dim_feedforward=2048, + attention_dropout_rate=0.0, + residual_dropout_rate=0.1): super(TransformerDecoderLayer, self).__init__() - self.self_attn = MultiheadAttentionOptim(d_model, nhead, dropout=attention_dropout_rate) - self.multihead_attn = MultiheadAttentionOptim(d_model, nhead, dropout=attention_dropout_rate) + self.self_attn = MultiheadAttentionOptim( + d_model, nhead, dropout=attention_dropout_rate) + self.multihead_attn = MultiheadAttentionOptim( + d_model, nhead, dropout=attention_dropout_rate) - self.conv1 = Conv2D(in_channels=d_model, out_channels=dim_feedforward, kernel_size=(1, 1)) - self.conv2 = Conv2D(in_channels=dim_feedforward, out_channels=d_model, kernel_size=(1, 1)) + self.conv1 = Conv2D( + in_channels=d_model, + out_channels=dim_feedforward, + kernel_size=(1, 1)) + self.conv2 = Conv2D( + in_channels=dim_feedforward, + out_channels=d_model, + kernel_size=(1, 1)) self.norm1 = LayerNorm(d_model) self.norm2 = LayerNorm(d_model) @@ -520,9 +576,14 @@ class TransformerDecoderLayer(nn.Layer): self.dropout2 = Dropout(residual_dropout_rate) self.dropout3 = Dropout(residual_dropout_rate) - def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, - tgt_key_padding_mask=None, memory_key_padding_mask=None): - r"""Pass the inputs (and mask) through the decoder layer. + def forward(self, + tgt, + memory, + tgt_mask=None, + memory_mask=None, + tgt_key_padding_mask=None, + memory_key_padding_mask=None): + """Pass the inputs (and mask) through the decoder layer. Args: tgt: the sequence to the decoder layer (required). @@ -532,15 +593,21 @@ class TransformerDecoderLayer(nn.Layer): tgt_key_padding_mask: the mask for the tgt keys per batch (optional). memory_key_padding_mask: the mask for the memory keys per batch (optional). - Shape: - see the docs in Transformer class. 
""" - tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, - key_padding_mask=tgt_key_padding_mask)[0] + tgt2 = self.self_attn( + tgt, + tgt, + tgt, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] tgt = tgt + self.dropout1(tgt2) tgt = self.norm1(tgt) - tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, - key_padding_mask=memory_key_padding_mask)[0] + tgt2 = self.multihead_attn( + tgt, + memory, + memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] tgt = tgt + self.dropout2(tgt2) tgt = self.norm2(tgt) @@ -562,9 +629,8 @@ def _get_clones(module, N): return LayerList([copy.deepcopy(module) for i in range(N)]) - class PositionalEncoding(nn.Layer): - r"""Inject some information about the relative or absolute position of the tokens + """Inject some information about the relative or absolute position of the tokens in the sequence. The positional encodings have the same dimension as the embeddings, so that the two can be summed. Here, we use sine and cosine functions of different frequencies. @@ -586,7 +652,9 @@ class PositionalEncoding(nn.Layer): pe = paddle.zeros([max_len, dim]) position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1) - div_term = paddle.exp(paddle.arange(0, dim, 2).astype('float32') * (-math.log(10000.0) / dim)) + div_term = paddle.exp( + paddle.arange(0, dim, 2).astype('float32') * + (-math.log(10000.0) / dim)) pe[:, 0::2] = paddle.sin(position * div_term) pe[:, 1::2] = paddle.cos(position * div_term) pe = pe.unsqueeze(0) @@ -594,7 +662,7 @@ class PositionalEncoding(nn.Layer): self.register_buffer('pe', pe) def forward(self, x): - r"""Inputs of forward function + """Inputs of forward function Args: x: the sequence fed to the positional encoder model (required). Shape: @@ -608,7 +676,7 @@ class PositionalEncoding(nn.Layer): class PositionalEncoding_2d(nn.Layer): - r"""Inject some information about the relative or absolute position of the tokens + """Inject some information about the relative or absolute position of the tokens in the sequence. The positional encodings have the same dimension as the embeddings, so that the two can be summed. Here, we use sine and cosine functions of different frequencies. @@ -630,7 +698,9 @@ class PositionalEncoding_2d(nn.Layer): pe = paddle.zeros([max_len, dim]) position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1) - div_term = paddle.exp(paddle.arange(0, dim, 2).astype('float32') * (-math.log(10000.0) / dim)) + div_term = paddle.exp( + paddle.arange(0, dim, 2).astype('float32') * + (-math.log(10000.0) / dim)) pe[:, 0::2] = paddle.sin(position * div_term) pe[:, 1::2] = paddle.cos(position * div_term) pe = pe.unsqueeze(0).transpose([1, 0, 2]) @@ -644,7 +714,7 @@ class PositionalEncoding_2d(nn.Layer): self.linear2.weight.data.fill_(1.) def forward(self, x): - r"""Inputs of forward function + """Inputs of forward function Args: x: the sequence fed to the positional encoder model (required). 
Shape: @@ -666,7 +736,9 @@ class PositionalEncoding_2d(nn.Layer): h_pe = h_pe.unsqueeze(3) x = x + w_pe + h_pe - x = x.reshape([x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]).transpose([2,0,1]) + x = x.reshape( + [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]).transpose( + [2, 0, 1]) return self.dropout(x) @@ -675,8 +747,9 @@ class Embeddings(nn.Layer): def __init__(self, d_model, vocab, padding_idx, scale_embedding): super(Embeddings, self).__init__() self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx) - w0 = np.random.normal(0.0, d_model**-0.5,(vocab, d_model)).astype(np.float32) - self.embedding.weight.set_value(w0) + w0 = np.random.normal(0.0, d_model**-0.5, + (vocab, d_model)).astype(np.float32) + self.embedding.weight.set_value(w0) self.d_model = d_model self.scale_embedding = scale_embedding @@ -687,9 +760,6 @@ class Embeddings(nn.Layer): return self.embedding(x) - - - class Beam(): ''' Beam search ''' @@ -698,12 +768,12 @@ class Beam(): self.size = size self._done = False # The score for each translation on the beam. - self.scores = paddle.zeros((size,), dtype=paddle.float32) + self.scores = paddle.zeros((size, ), dtype=paddle.float32) self.all_scores = [] # The backpointers at each time-step. self.prev_ks = [] # The outputs at each time-step. - self.next_ys = [paddle.full((size,), 0, dtype=paddle.int64)] + self.next_ys = [paddle.full((size, ), 0, dtype=paddle.int64)] self.next_ys[0][0] = 2 def get_current_state(self): @@ -729,28 +799,26 @@ class Beam(): beam_lk = word_prob[0] flat_beam_lk = beam_lk.reshape([-1]) - best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 1st sort + best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, + True) # 1st sort self.all_scores.append(self.scores) self.scores = best_scores - # bestScoresId is flattened as a (beam x word) array, # so we need to calculate which word and beam each score came from prev_k = best_scores_id // num_words self.prev_ks.append(prev_k) - - self.next_ys.append(best_scores_id - prev_k * num_words) - + self.next_ys.append(best_scores_id - prev_k * num_words) # End condition is when top-of-beam is EOS. - if self.next_ys[-1][0] == 3 : + if self.next_ys[-1][0] == 3: self._done = True self.all_scores.append(self.scores) - return self._done def sort_scores(self): "Sort the scores." - return self.scores, paddle.to_tensor([i for i in range(self.scores.shape[0])],dtype='int32') + return self.scores, paddle.to_tensor( + [i for i in range(self.scores.shape[0])], dtype='int32') def get_the_best_score_and_idx(self): "Get the score of the best in the beam." @@ -759,7 +827,6 @@ class Beam(): def get_tentative_hypothesis(self): "Get the decoded sequence for the current timestep." - if len(self.next_ys) == 1: dec_seq = self.next_ys[0].unsqueeze(1) else: @@ -767,13 +834,12 @@ class Beam(): hyps = [self.get_hypothesis(k) for k in keys] hyps = [[2] + h for h in hyps] dec_seq = paddle.to_tensor(hyps, dtype='int64') - return dec_seq def get_hypothesis(self, k): """ Walk back to construct the full hypothesis. 
""" hyp = [] for j in range(len(self.prev_ks) - 1, -1, -1): - hyp.append(self.next_ys[j+1][k]) + hyp.append(self.next_ys[j + 1][k]) k = self.prev_ks[j][k] return list(map(lambda x: x.item(), hyp[::-1])) diff --git a/tools/program.py b/tools/program.py index aa5f9388..60a5e482 100755 --- a/tools/program.py +++ b/tools/program.py @@ -189,9 +189,9 @@ def train(config, use_nrtr = config['Architecture']['algorithm'] == "NRTR" - try: + try: model_type = config['Architecture']['model_type'] - except: + except: model_type = None if 'start_epoch' in best_model_dict: @@ -216,11 +216,8 @@ def train(config, images = batch[0] if use_srn: model_average = True - if use_srn or model_type == 'table': + if use_srn or model_type == 'table' or use_nrtr: preds = model(images, data=batch[1:]) - elif use_nrtr: - max_len = batch[2].max() - preds = model(images, batch[1][:,:2+max_len]) else: preds = model(images) loss = loss_class(preds, batch) @@ -405,9 +402,7 @@ def preprocess(is_train=False): alg = config['Architecture']['algorithm'] assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', - 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn' - ] device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' From c6359258958bc878cfd50e67f978ab4691cf4153 Mon Sep 17 00:00:00 2001 From: topduke <784990967@qq.com> Date: Thu, 19 Aug 2021 17:34:38 +0800 Subject: [PATCH 13/18] Update rec_metric.py --- ppocr/metrics/rec_metric.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py index e4b65a50..3e82fe75 100644 --- a/ppocr/metrics/rec_metric.py +++ b/ppocr/metrics/rec_metric.py @@ -48,8 +48,8 @@ class RecMetric(object): 'norm_edit_dis': 0, } """ - acc = 1.0 * self.correct_num / (self.all_num) - norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num) + acc = 1.0 * self.correct_num / self.all_num + norm_edit_dis = 1 - self.norm_edit_dis / self.all_num self.reset() return {'acc': acc, 'norm_edit_dis': norm_edit_dis} From c8094e6575c3327be63f415ca274805631850ff8 Mon Sep 17 00:00:00 2001 From: topduke <784990967@qq.com> Date: Thu, 19 Aug 2021 19:08:23 +0800 Subject: [PATCH 14/18] Update rec_nrtr_optim_head.py --- ppocr/modeling/heads/rec_nrtr_optim_head.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ppocr/modeling/heads/rec_nrtr_optim_head.py b/ppocr/modeling/heads/rec_nrtr_optim_head.py index 98f212d0..63473c11 100644 --- a/ppocr/modeling/heads/rec_nrtr_optim_head.py +++ b/ppocr/modeling/heads/rec_nrtr_optim_head.py @@ -216,7 +216,7 @@ class TransformerOptim(nn.Layer): new_shape = (n_curr_active_inst * n_bm, *d_hs) beamed_tensor = beamed_tensor.reshape( - [n_prev_active_inst, -1]) #contiguous() + [n_prev_active_inst, -1]) beamed_tensor = beamed_tensor.index_select( paddle.to_tensor(curr_active_inst_idx), axis=0) beamed_tensor = beamed_tensor.reshape([*new_shape]) @@ -337,7 +337,7 @@ class TransformerOptim(nn.Layer): n_inst, len_s, d_h = src_enc.shape src_enc = paddle.concat([src_enc for i in range(n_bm)], axis=1) src_enc = src_enc.reshape([n_inst * n_bm, len_s, d_h]).transpose( - [1, 0, 2]) #repeat(1, n_bm, 1) + [1, 0, 2]) #-- Prepare beams inst_dec_beams = [Beam(n_bm) for _ in range(n_inst)] From 88f7c59e8fbe82ef9d5fbce03f91952cbaa59d27 Mon Sep 17 00:00:00 2001 From: topduke <784990967@qq.com> Date: Mon, 23 Aug 2021 11:42:58 +0800 Subject: [PATCH 15/18] Update __init__.py --- ppocr/postprocess/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ppocr/postprocess/__init__.py 
b/ppocr/postprocess/__init__.py index f1829e3e..61593987 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -24,10 +24,8 @@ __all__ = ['build_post_process'] from .db_postprocess import DBPostProcess, DistillationDBPostProcess from .east_postprocess import EASTPostProcess from .sast_postprocess import SASTPostProcess - from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, NRTRLabelDecode, \ TableLabelDecode - from .cls_postprocess import ClsPostProcess from .pg_postprocess import PGPostProcess From 533b15c82096067cc5e7c0ab5f463e7bbac5c224 Mon Sep 17 00:00:00 2001 From: topduke <784990967@qq.com> Date: Mon, 23 Aug 2021 11:43:45 +0800 Subject: [PATCH 16/18] Update program.py --- tools/program.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tools/program.py b/tools/program.py index 60a5e482..5dd0cbfa 100755 --- a/tools/program.py +++ b/tools/program.py @@ -186,7 +186,6 @@ def train(config, model.train() use_srn = config['Architecture']['algorithm'] == "SRN" - use_nrtr = config['Architecture']['algorithm'] == "NRTR" try: From 73058cc082ab2b4afadd8461d9663797b214088a Mon Sep 17 00:00:00 2001 From: topduke <784990967@qq.com> Date: Mon, 23 Aug 2021 11:45:10 +0800 Subject: [PATCH 17/18] Update program.py --- tools/program.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tools/program.py b/tools/program.py index 5dd0cbfa..e7742a8f 100755 --- a/tools/program.py +++ b/tools/program.py @@ -364,7 +364,6 @@ def eval(model, break images = batch[0] start = time.time() - if use_srn or model_type == 'table': preds = model(images, data=batch[1:]) else: From 2bf8ad9b7d686e8366ead5733313e13e0755a4c4 Mon Sep 17 00:00:00 2001 From: Topdu <784990967@qq.com> Date: Tue, 24 Aug 2021 07:46:43 +0000 Subject: [PATCH 18/18] modify transformeroptim, resize --- configs/rec/rec_mtb_nrtr.yml | 8 +++-- ppocr/data/imaug/__init__.py | 2 +- ppocr/data/imaug/rec_img_aug.py | 31 +++++++------------ ppocr/losses/rec_nrtr_loss.py | 22 +++++-------- ppocr/modeling/heads/__init__.py | 4 +-- ppocr/modeling/heads/multiheadAttention.py | 4 +-- ...ec_nrtr_optim_head.py => rec_nrtr_head.py} | 15 +++++---- 7 files changed, 35 insertions(+), 51 deletions(-) rename ppocr/modeling/heads/{rec_nrtr_optim_head.py => rec_nrtr_head.py} (98%) diff --git a/configs/rec/rec_mtb_nrtr.yml b/configs/rec/rec_mtb_nrtr.yml index c89de02b..635c392d 100644 --- a/configs/rec/rec_mtb_nrtr.yml +++ b/configs/rec/rec_mtb_nrtr.yml @@ -43,7 +43,7 @@ Architecture: name: MTB cnn_num: 2 Head: - name: TransformerOptim + name: Transformer d_model: 512 num_encoder_layers: 6 beam_size: 10 # When Beam size is greater than 0, it means to use beam search when evaluation. 
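The next hunks replace the bespoke PILResize transform with a single NRTRRecResizeImg that switches on resize_type. A standalone sketch of its PIL path, mirroring the implementation added to ppocr/data/imaug/rec_img_aug.py below (the function name here is illustrative, not from the patch):

    import numpy as np
    from PIL import Image

    def nrtr_resize_pil(img, image_shape=(100, 32)):
        # PIL's resize takes (W, H), so (100, 32) yields a 32x100 array
        pil_img = Image.fromarray(np.uint8(img))
        out = np.array(pil_img.resize(image_shape, Image.ANTIALIAS))
        out = np.expand_dims(out, -1).transpose((2, 0, 1))  # HWC -> CHW, C=1
        return out.astype(np.float32) / 128. - 1.           # scale to roughly [-1, 1)

    print(nrtr_resize_pil(np.zeros([48, 160], dtype=np.uint8)).shape)  # (1, 32, 100)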
@@ -69,8 +69,9 @@ Train: img_mode: BGR channel_first: False - NRTRLabelEncode: # Class handling label - - PILResize: + - NRTRRecResizeImg: image_shape: [100, 32] + resize_type: PIL # PIL or OpenCV - KeepKeys: keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order loader: @@ -88,8 +89,9 @@ Eval: img_mode: BGR channel_first: False - NRTRLabelEncode: # Class handling label - - PILResize: + - NRTRRecResizeImg: image_shape: [100, 32] + resize_type: PIL # PIL or OpenCV - KeepKeys: keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order loader: diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 5c384c1d..4418d075 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap from .make_shrink_map import MakeShrinkMap from .random_crop_data import EastRandomCropData, PSERandomCrop -from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, PILResize, CVResize +from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg from .randaugment import RandAugment from .copy_paste import CopyPaste from .operators import * diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 13a5c71d..e914d384 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -42,30 +42,21 @@ class ClsResizeImg(object): data['image'] = norm_img return data -class PILResize(object): - def __init__(self, image_shape, **kwargs): + +class NRTRRecResizeImg(object): + def __init__(self, image_shape, resize_type, **kwargs): self.image_shape = image_shape + self.resize_type = resize_type def __call__(self, data): img = data['image'] - image_pil = Image.fromarray(np.uint8(img)) - norm_img = image_pil.resize(self.image_shape, Image.ANTIALIAS) - norm_img = np.array(norm_img) - norm_img = np.expand_dims(norm_img, -1) - norm_img = norm_img.transpose((2, 0, 1)) - data['image'] = norm_img.astype(np.float32) / 128. - 1. - return data - - -class CVResize(object): - def __init__(self, image_shape, **kwargs): - self.image_shape = image_shape - - def __call__(self, data): - img = data['image'] - #print(img) - norm_img = cv2.resize(img,self.image_shape) - norm_img = np.expand_dims(norm_img, -1) + if self.resize_type == 'PIL': + image_pil = Image.fromarray(np.uint8(img)) + img = image_pil.resize(self.image_shape, Image.ANTIALIAS) + img = np.array(img) + if self.resize_type == 'OpenCV': + img = cv2.resize(img, self.image_shape) + norm_img = np.expand_dims(img, -1) norm_img = norm_img.transpose((2, 0, 1)) data['image'] = norm_img.astype(np.float32) / 128. - 1. 
         return data
diff --git a/ppocr/losses/rec_nrtr_loss.py b/ppocr/losses/rec_nrtr_loss.py
index 915f506d..41714dd2 100644
--- a/ppocr/losses/rec_nrtr_loss.py
+++ b/ppocr/losses/rec_nrtr_loss.py
@@ -3,34 +3,26 @@ from paddle import nn
 import paddle.nn.functional as F
 
 
-def cal_performance(pred, tgt):
-
-    pred = pred.max(1)[1]
-    tgt = tgt.contiguous().view(-1)
-    non_pad_mask = tgt.ne(0)
-    n_correct = pred.eq(tgt)
-    n_correct = n_correct.masked_select(non_pad_mask).sum().item()
-    return n_correct
-
-
 class NRTRLoss(nn.Layer):
-    def __init__(self,smoothing=True, **kwargs):
+    def __init__(self, smoothing=True, **kwargs):
         super(NRTRLoss, self).__init__()
-        self.loss_func = nn.CrossEntropyLoss(reduction='mean',ignore_index=0)
+        self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
         self.smoothing = smoothing
 
     def forward(self, pred, batch):
         pred = pred.reshape([-1, pred.shape[2]])
         max_len = batch[2].max()
-        tgt = batch[1][:,1:2+max_len]
-        tgt = tgt.reshape([-1] )
+        tgt = batch[1][:, 1:2 + max_len]
+        tgt = tgt.reshape([-1])
         if self.smoothing:
             eps = 0.1
             n_class = pred.shape[1]
             one_hot = F.one_hot(tgt, pred.shape[1])
             one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
             log_prb = F.log_softmax(pred, axis=1)
-            non_pad_mask = paddle.not_equal(tgt, paddle.zeros(tgt.shape,dtype='int64'))
+            non_pad_mask = paddle.not_equal(
+                tgt, paddle.zeros(
+                    tgt.shape, dtype='int64'))
             loss = -(one_hot * log_prb).sum(axis=1)
             loss = loss.masked_select(non_pad_mask).mean()
         else:
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index 11fd4b26..572ec4aa 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -26,13 +26,13 @@ def build_head(config):
     from .rec_ctc_head import CTCHead
     from .rec_att_head import AttentionHead
    from .rec_srn_head import SRNHead
-    from .rec_nrtr_optim_head import TransformerOptim
+    from .rec_nrtr_head import Transformer
 
     # cls head
     from .cls_head import ClsHead
 
     support_dict = [
         'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
-        'SRNHead', 'PGHead', 'TransformerOptim', 'TableAttentionHead'
+        'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead'
     ]
 
     #table head
diff --git a/ppocr/modeling/heads/multiheadAttention.py b/ppocr/modeling/heads/multiheadAttention.py
index 4be37025..651d4f57 100755
--- a/ppocr/modeling/heads/multiheadAttention.py
+++ b/ppocr/modeling/heads/multiheadAttention.py
@@ -24,7 +24,7 @@ zeros_ = constant_(value=0.)
 ones_ = constant_(value=1.)
 
 
-class MultiheadAttentionOptim(nn.Layer):
+class MultiheadAttention(nn.Layer):
     """Allows the model to jointly attend to information
     from different representation subspaces.
     See reference: Attention Is All You Need
@@ -46,7 +46,7 @@ class MultiheadAttentionOptim(nn.Layer):
                  bias=True,
                  add_bias_kv=False,
                  add_zero_attn=False):
-        super(MultiheadAttentionOptim, self).__init__()
+        super(MultiheadAttention, self).__init__()
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.dropout = dropout
diff --git a/ppocr/modeling/heads/rec_nrtr_optim_head.py b/ppocr/modeling/heads/rec_nrtr_head.py
similarity index 98%
rename from ppocr/modeling/heads/rec_nrtr_optim_head.py
rename to ppocr/modeling/heads/rec_nrtr_head.py
index 63473c11..05dba677 100644
--- a/ppocr/modeling/heads/rec_nrtr_optim_head.py
+++ b/ppocr/modeling/heads/rec_nrtr_head.py
@@ -21,7 +21,7 @@ from paddle.nn import LayerList
 from paddle.nn.initializer import XavierNormal as xavier_uniform_
 from paddle.nn import Dropout, Linear, LayerNorm, Conv2D
 import numpy as np
-from ppocr.modeling.heads.multiheadAttention import MultiheadAttentionOptim
+from ppocr.modeling.heads.multiheadAttention import MultiheadAttention
 from paddle.nn.initializer import Constant as constant_
 from paddle.nn.initializer import XavierNormal as xavier_normal_
 
@@ -29,7 +29,7 @@ zeros_ = constant_(value=0.)
 ones_ = constant_(value=1.)
 
 
-class TransformerOptim(nn.Layer):
+class Transformer(nn.Layer):
     """A transformer model. The user is able to modify the attributes as needed.
     The architecture is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
     Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
@@ -63,7 +63,7 @@ class TransformerOptim(nn.Layer):
                  out_channels=0,
                  dst_vocab_size=99,
                  scale_embedding=True):
-        super(TransformerOptim, self).__init__()
+        super(Transformer, self).__init__()
         self.embedding = Embeddings(
             d_model=d_model,
             vocab=dst_vocab_size,
@@ -215,8 +215,7 @@ class TransformerOptim(nn.Layer):
             n_curr_active_inst = len(curr_active_inst_idx)
             new_shape = (n_curr_active_inst * n_bm, *d_hs)
 
-            beamed_tensor = beamed_tensor.reshape(
-                [n_prev_active_inst, -1])
+            beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1])
             beamed_tensor = beamed_tensor.index_select(
                 paddle.to_tensor(curr_active_inst_idx), axis=0)
             beamed_tensor = beamed_tensor.reshape([*new_shape])
@@ -486,7 +485,7 @@ class TransformerEncoderLayer(nn.Layer):
                  attention_dropout_rate=0.0,
                  residual_dropout_rate=0.1):
         super(TransformerEncoderLayer, self).__init__()
-        self.self_attn = MultiheadAttentionOptim(
+        self.self_attn = MultiheadAttention(
             d_model, nhead, dropout=attention_dropout_rate)
 
         self.conv1 = Conv2D(
@@ -555,9 +554,9 @@ class TransformerDecoderLayer(nn.Layer):
                  attention_dropout_rate=0.0,
                  residual_dropout_rate=0.1):
         super(TransformerDecoderLayer, self).__init__()
-        self.self_attn = MultiheadAttentionOptim(
+        self.self_attn = MultiheadAttention(
             d_model, nhead, dropout=attention_dropout_rate)
-        self.multihead_attn = MultiheadAttentionOptim(
+        self.multihead_attn = MultiheadAttention(
             d_model, nhead, dropout=attention_dropout_rate)
         self.conv1 = Conv2D(
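
The beam_size option set in the rec_mtb_nrtr.yml hunk of PATCH 18 switches the head between greedy decoding and beam search at evaluation time. Below is a minimal sketch of that contract, assuming a hypothetical step_fn(prefix) that returns per-token log-probabilities; decode() is an illustrative name, not the PaddleOCR API. Only the start/end token ids (2 and 3, as inserted by NRTRLabelEncode in this series) come from the patch itself.

# Illustrative sketch, not PaddleOCR code: beam_size <= 0 means greedy
# decoding, beam_size > 0 means beam search, matching the config comment.
import numpy as np

BOS, EOS = 2, 3  # start/end ids used by NRTRLabelEncode in this series

def decode(step_fn, max_len=25, beam_size=-1):
    # step_fn(prefix: list of ints) -> numpy array of log-probs per token
    if beam_size <= 0:  # greedy: keep the single best token each step
        seq = [BOS]
        for _ in range(max_len):
            seq.append(int(np.argmax(step_fn(seq))))
            if seq[-1] == EOS:
                break
        return seq[1:]  # may end with EOS; fine for a sketch
    beams = [([BOS], 0.0)]  # beam search: keep beam_size best prefixes
    for _ in range(max_len):
        cand = []
        for seq, score in beams:
            if seq[-1] == EOS:  # finished beams carry over unchanged
                cand.append((seq, score))
                continue
            logp = step_fn(seq)
            for tok in np.argsort(logp)[-beam_size:]:  # top-k expansions
                cand.append((seq + [int(tok)], score + float(logp[tok])))
        beams = sorted(cand, key=lambda b: b[1], reverse=True)[:beam_size]
    return beams[0][0][1:]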
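
The rec_img_aug.py hunk merges PILResize and CVResize into a single NRTRRecResizeImg whose resize_type selects the backend; both paths share the same expand/transpose/normalize tail, so for the [100, 32] shape used in the config either backend should yield a (1, 32, 100) float tensor in roughly [-1, 1). A quick sanity check, assuming the patched module is importable and a Pillow version that still provides Image.ANTIALIAS; the grayscale crop is a made-up example:

# Sanity-check sketch for the merged transform above (input image is fabricated).
import numpy as np
from ppocr.data.imaug.rec_img_aug import NRTRRecResizeImg

img = (np.random.rand(48, 160) * 255).astype(np.uint8)  # H x W grayscale crop
for backend in ('PIL', 'OpenCV'):
    op = NRTRRecResizeImg(image_shape=[100, 32], resize_type=backend)
    out = op({'image': img})['image']
    # both backends should agree on layout: (1, 32, 100), values within [-1, 1)
    print(backend, out.shape, float(out.min()), float(out.max()))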
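
The smoothing branch of NRTRLoss in PATCH 18 spreads eps = 0.1 of the target probability mass uniformly over the non-target classes and masks out padding positions (id 0) before averaging. A standalone numpy rendering of the same math for checking it by hand; smoothed_ce and the sample logits are illustrative, not part of the patch, which runs on paddle tensors:

# Numpy sketch of the label-smoothing loss above (eps = 0.1, pad id 0 masked).
import numpy as np

def smoothed_ce(pred, tgt, eps=0.1, pad_id=0):
    n_class = pred.shape[1]
    one_hot = np.eye(n_class)[tgt]
    # the true class keeps 1 - eps; the remaining eps mass is spread uniformly
    # over the other n_class - 1 classes, so each row still sums to 1
    one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
    log_prb = pred - np.log(np.exp(pred).sum(axis=1, keepdims=True))  # log_softmax
    loss = -(one_hot * log_prb).sum(axis=1)
    return loss[tgt != pad_id].mean()  # padding positions contribute nothing

logits = np.random.randn(4, 6).astype(np.float32)  # 4 positions, 6 classes
targets = np.array([2, 5, 3, 0])                   # last position is padding
print(smoothed_ce(logits, targets))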