Merge pull request #1368 from WenmuZhou/dygraph_rc
[Dygraph] add config file of DB, CRNN,Rosetta, StarNet
This commit is contained in:
commit
5b2f6b7524
|
@ -45,9 +45,7 @@ Optimizer:
|
|||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
# name: Cosine
|
||||
learning_rate: 0.001
|
||||
# warmup_epoch: 0
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
|
|
@ -0,0 +1,130 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 1200
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/det_rc/det_r50_vd/
|
||||
save_epoch_step: 1200
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [5000,4000]
|
||||
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||
load_static_weights: True
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_en/img_10.jpg
|
||||
save_res_path: ./output/det_db/predicts_db.txt
|
||||
|
||||
Architecture:
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ResNet
|
||||
layers: 50
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 256
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
|
||||
Loss:
|
||||
name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
learning_rate: 0.001
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
PostProcess:
|
||||
name: DBPostProcess
|
||||
thresh: 0.3
|
||||
box_thresh: 0.7
|
||||
max_candidates: 1000
|
||||
unclip_ratio: 1.5
|
||||
|
||||
Metric:
|
||||
name: DetMetric
|
||||
main_indicator: hmean
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||
ratio_list: [0.5]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- IaaAugment:
|
||||
augmenter_args:
|
||||
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
|
||||
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
|
||||
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
|
||||
- EastRandomCropData:
|
||||
size: [640, 640]
|
||||
max_tries: 50
|
||||
keep_ratio: true
|
||||
- MakeBorderMap:
|
||||
shrink_ratio: 0.4
|
||||
thresh_min: 0.3
|
||||
thresh_max: 0.7
|
||||
- MakeShrinkMap:
|
||||
shrink_ratio: 0.4
|
||||
min_text_size: 8
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
|
||||
loader:
|
||||
shuffle: True
|
||||
drop_last: False
|
||||
batch_size_per_card: 16
|
||||
num_workers: 8
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- DetResizeForTest:
|
||||
image_shape: [736, 1280]
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1 # must be 1
|
||||
num_workers: 8
|
|
@ -5,7 +5,7 @@ Global:
|
|||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec/mv3_none_bilstm_ctc/
|
||||
save_epoch_step: 3
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
# evaluation is run every 2000 iterations
|
||||
eval_batch_step: [0, 2000]
|
||||
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||
cal_metric_during_train: True
|
||||
|
@ -13,7 +13,7 @@ Global:
|
|||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||
infer_img: doc/imgs_words_en/word_10.png
|
||||
# for data or label process
|
||||
character_dict_path:
|
||||
character_type: en
|
||||
|
@ -21,7 +21,6 @@ Global:
|
|||
infer_mode: False
|
||||
use_space_char: False
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
Global:
|
||||
use_gpu: True
|
||||
epoch_num: 72
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec/mv3_none_none_ctc/
|
||||
save_epoch_step: 3
|
||||
# evaluation is run every 2000 iterations
|
||||
eval_batch_step: [0, 2000]
|
||||
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words_en/word_10.png
|
||||
# for data or label process
|
||||
character_dict_path:
|
||||
character_type: en
|
||||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
learning_rate: 0.0005
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: Rosetta
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: large
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: reshape
|
||||
Head:
|
||||
name: CTCHead
|
||||
fc_decay: 0.0004
|
||||
|
||||
Loss:
|
||||
name: CTCLoss
|
||||
|
||||
PostProcess:
|
||||
name: CTCLabelDecode
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: LMDBDateSet
|
||||
data_dir: ./train_data/data_lmdb_release/training/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- CTCLabelEncode: # Class handling label
|
||||
- RecResizeImg:
|
||||
image_shape: [3, 32, 100]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: False
|
||||
batch_size_per_card: 256
|
||||
drop_last: True
|
||||
num_workers: 8
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: LMDBDateSet
|
||||
data_dir: ./train_data/data_lmdb_release/validation/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- CTCLabelEncode: # Class handling label
|
||||
- RecResizeImg:
|
||||
image_shape: [3, 32, 100]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 256
|
||||
num_workers: 8
|
|
@ -0,0 +1,100 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 72
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec/mv3_tps_bilstm_ctc/
|
||||
save_epoch_step: 3
|
||||
# evaluation is run every 2000 iterations
|
||||
eval_batch_step: [0, 2000]
|
||||
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words_en/word_10.png
|
||||
# for data or label process
|
||||
character_dict_path:
|
||||
character_type: en
|
||||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
learning_rate: 0.0005
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: STARNet
|
||||
Transform:
|
||||
name: TPS
|
||||
num_fiducial: 20
|
||||
loc_lr: 0.1
|
||||
model_name: small
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: large
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 96
|
||||
Head:
|
||||
name: CTCHead
|
||||
fc_decay: 0.0004
|
||||
|
||||
Loss:
|
||||
name: CTCLoss
|
||||
|
||||
PostProcess:
|
||||
name: CTCLabelDecode
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: LMDBDateSet
|
||||
data_dir: ./train_data/data_lmdb_release/training/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- CTCLabelEncode: # Class handling label
|
||||
- RecResizeImg:
|
||||
image_shape: [3, 32, 100]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: False
|
||||
batch_size_per_card: 256
|
||||
drop_last: True
|
||||
num_workers: 8
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: LMDBDateSet
|
||||
data_dir: ./train_data/data_lmdb_release/validation/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- CTCLabelEncode: # Class handling label
|
||||
- RecResizeImg:
|
||||
image_shape: [3, 32, 100]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 256
|
||||
num_workers: 4
|
|
@ -5,7 +5,7 @@ Global:
|
|||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec/r34_vd_none_bilstm_ctc/
|
||||
save_epoch_step: 3
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
# evaluation is run every 2000 iterations
|
||||
eval_batch_step: [0, 2000]
|
||||
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||
cal_metric_during_train: True
|
||||
|
@ -13,7 +13,7 @@ Global:
|
|||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||
infer_img: doc/imgs_words_en/word_10.png
|
||||
# for data or label process
|
||||
character_dict_path:
|
||||
character_type: en
|
||||
|
@ -21,7 +21,6 @@ Global:
|
|||
infer_mode: False
|
||||
use_space_char: False
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
|
@ -71,7 +70,7 @@ Train:
|
|||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: False
|
||||
shuffle: True
|
||||
batch_size_per_card: 256
|
||||
drop_last: True
|
||||
num_workers: 8
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 72
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec/r34_vd_none_none_ctc/
|
||||
save_epoch_step: 3
|
||||
# evaluation is run every 2000 iterations
|
||||
eval_batch_step: [0, 2000]
|
||||
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words_en/word_10.png
|
||||
# for data or label process
|
||||
character_dict_path:
|
||||
character_type: en
|
||||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
learning_rate: 0.0005
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: Rosetta
|
||||
Backbone:
|
||||
name: ResNet
|
||||
layers: 34
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: reshape
|
||||
Head:
|
||||
name: CTCHead
|
||||
fc_decay: 0.0004
|
||||
|
||||
Loss:
|
||||
name: CTCLoss
|
||||
|
||||
PostProcess:
|
||||
name: CTCLabelDecode
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: LMDBDateSet
|
||||
data_dir: ./train_data/data_lmdb_release/training/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- CTCLabelEncode: # Class handling label
|
||||
- RecResizeImg:
|
||||
image_shape: [3, 32, 100]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: True
|
||||
batch_size_per_card: 256
|
||||
drop_last: True
|
||||
num_workers: 8
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: LMDBDateSet
|
||||
data_dir: ./train_data/data_lmdb_release/validation/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- CTCLabelEncode: # Class handling label
|
||||
- RecResizeImg:
|
||||
image_shape: [3, 32, 100]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 256
|
||||
num_workers: 4
|
|
@ -5,7 +5,7 @@ Global:
|
|||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec/r34_vd_tps_bilstm_ctc/
|
||||
save_epoch_step: 3
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
# evaluation is run every 2000 iterations
|
||||
eval_batch_step: [0, 2000]
|
||||
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||
cal_metric_during_train: True
|
||||
|
@ -13,7 +13,7 @@ Global:
|
|||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||
infer_img: doc/imgs_words_en/word_10.png
|
||||
# for data or label process
|
||||
character_dict_path:
|
||||
character_type: en
|
||||
|
@ -21,7 +21,6 @@ Global:
|
|||
infer_mode: False
|
||||
use_space_char: False
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
|
@ -34,7 +33,7 @@ Optimizer:
|
|||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: CRNN
|
||||
algorithm: STARNet
|
||||
Transform:
|
||||
name: TPS
|
||||
num_fiducial: 20
|
||||
|
|
|
@ -1,214 +0,0 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import string
|
||||
import re
|
||||
from .check import check_config_params
|
||||
import sys
|
||||
|
||||
|
||||
class CharacterOps(object):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self, config):
|
||||
self.character_type = config['character_type']
|
||||
self.loss_type = config['loss_type']
|
||||
self.max_text_len = config['max_text_length']
|
||||
if self.character_type == "en":
|
||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
dict_character = list(self.character_str)
|
||||
elif self.character_type == "ch":
|
||||
character_dict_path = config['character_dict_path']
|
||||
add_space = False
|
||||
if 'use_space_char' in config:
|
||||
add_space = config['use_space_char']
|
||||
self.character_str = ""
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
for line in lines:
|
||||
line = line.decode('utf-8').strip("\n").strip("\r\n")
|
||||
self.character_str += line
|
||||
if add_space:
|
||||
self.character_str += " "
|
||||
dict_character = list(self.character_str)
|
||||
elif self.character_type == "en_sensitive":
|
||||
# same with ASTER setting (use 94 char).
|
||||
self.character_str = string.printable[:-6]
|
||||
dict_character = list(self.character_str)
|
||||
else:
|
||||
self.character_str = None
|
||||
assert self.character_str is not None, \
|
||||
"Nonsupport type of the character: {}".format(self.character_str)
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
if self.loss_type == "attention":
|
||||
dict_character = [self.beg_str, self.end_str] + dict_character
|
||||
elif self.loss_type == "srn":
|
||||
dict_character = dict_character + [self.beg_str, self.end_str]
|
||||
self.dict = {}
|
||||
for i, char in enumerate(dict_character):
|
||||
self.dict[char] = i
|
||||
self.character = dict_character
|
||||
|
||||
def encode(self, text):
|
||||
"""convert text-label into text-index.
|
||||
input:
|
||||
text: text labels of each image. [batch_size]
|
||||
|
||||
output:
|
||||
text: concatenated text index for CTCLoss.
|
||||
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
|
||||
length: length of each text. [batch_size]
|
||||
"""
|
||||
if self.character_type == "en":
|
||||
text = text.lower()
|
||||
|
||||
text_list = []
|
||||
for char in text:
|
||||
if char not in self.dict:
|
||||
continue
|
||||
text_list.append(self.dict[char])
|
||||
text = np.array(text_list)
|
||||
return text
|
||||
|
||||
def decode(self, text_index, is_remove_duplicate=False):
|
||||
""" convert text-index into text-label. """
|
||||
char_list = []
|
||||
char_num = self.get_char_num()
|
||||
|
||||
if self.loss_type == "attention":
|
||||
beg_idx = self.get_beg_end_flag_idx("beg")
|
||||
end_idx = self.get_beg_end_flag_idx("end")
|
||||
ignored_tokens = [beg_idx, end_idx]
|
||||
else:
|
||||
ignored_tokens = [char_num]
|
||||
|
||||
for idx in range(len(text_index)):
|
||||
if text_index[idx] in ignored_tokens:
|
||||
continue
|
||||
if is_remove_duplicate:
|
||||
if idx > 0 and text_index[idx - 1] == text_index[idx]:
|
||||
continue
|
||||
char_list.append(self.character[int(text_index[idx])])
|
||||
text = ''.join(char_list)
|
||||
return text
|
||||
|
||||
def get_char_num(self):
|
||||
return len(self.character)
|
||||
|
||||
def get_beg_end_flag_idx(self, beg_or_end):
|
||||
if self.loss_type == "attention":
|
||||
if beg_or_end == "beg":
|
||||
idx = np.array(self.dict[self.beg_str])
|
||||
elif beg_or_end == "end":
|
||||
idx = np.array(self.dict[self.end_str])
|
||||
else:
|
||||
assert False, "Unsupport type %s in get_beg_end_flag_idx"\
|
||||
% beg_or_end
|
||||
return idx
|
||||
else:
|
||||
err = "error in get_beg_end_flag_idx when using the loss %s"\
|
||||
% (self.loss_type)
|
||||
assert False, err
|
||||
|
||||
|
||||
def cal_predicts_accuracy(char_ops,
|
||||
preds,
|
||||
preds_lod,
|
||||
labels,
|
||||
labels_lod,
|
||||
is_remove_duplicate=False):
|
||||
acc_num = 0
|
||||
img_num = 0
|
||||
for ino in range(len(labels_lod) - 1):
|
||||
beg_no = preds_lod[ino]
|
||||
end_no = preds_lod[ino + 1]
|
||||
preds_text = preds[beg_no:end_no].reshape(-1)
|
||||
preds_text = char_ops.decode(preds_text, is_remove_duplicate)
|
||||
|
||||
beg_no = labels_lod[ino]
|
||||
end_no = labels_lod[ino + 1]
|
||||
labels_text = labels[beg_no:end_no].reshape(-1)
|
||||
labels_text = char_ops.decode(labels_text, is_remove_duplicate)
|
||||
img_num += 1
|
||||
|
||||
if preds_text == labels_text:
|
||||
acc_num += 1
|
||||
acc = acc_num * 1.0 / img_num
|
||||
return acc, acc_num, img_num
|
||||
|
||||
|
||||
def cal_predicts_accuracy_srn(char_ops,
|
||||
preds,
|
||||
labels,
|
||||
max_text_len,
|
||||
is_debug=False):
|
||||
acc_num = 0
|
||||
img_num = 0
|
||||
|
||||
char_num = char_ops.get_char_num()
|
||||
|
||||
total_len = preds.shape[0]
|
||||
img_num = int(total_len / max_text_len)
|
||||
for i in range(img_num):
|
||||
cur_label = []
|
||||
cur_pred = []
|
||||
for j in range(max_text_len):
|
||||
if labels[j + i * max_text_len] != int(char_num-1): #0
|
||||
cur_label.append(labels[j + i * max_text_len][0])
|
||||
else:
|
||||
break
|
||||
|
||||
for j in range(max_text_len + 1):
|
||||
if j < len(cur_label) and preds[j + i * max_text_len][
|
||||
0] != cur_label[j]:
|
||||
break
|
||||
elif j == len(cur_label) and j == max_text_len:
|
||||
acc_num += 1
|
||||
break
|
||||
elif j == len(cur_label) and preds[j + i * max_text_len][0] == int(char_num-1):
|
||||
acc_num += 1
|
||||
break
|
||||
acc = acc_num * 1.0 / img_num
|
||||
return acc, acc_num, img_num
|
||||
|
||||
|
||||
def convert_rec_attention_infer_res(preds):
|
||||
img_num = preds.shape[0]
|
||||
target_lod = [0]
|
||||
convert_ids = []
|
||||
for ino in range(img_num):
|
||||
end_pos = np.where(preds[ino, :] == 1)[0]
|
||||
if len(end_pos) <= 1:
|
||||
text_list = preds[ino, 1:]
|
||||
else:
|
||||
text_list = preds[ino, 1:end_pos[1]]
|
||||
target_lod.append(target_lod[ino] + len(text_list))
|
||||
convert_ids = convert_ids + list(text_list)
|
||||
convert_ids = np.array(convert_ids)
|
||||
convert_ids = convert_ids.reshape((-1, 1))
|
||||
return convert_ids, target_lod
|
||||
|
||||
|
||||
def convert_rec_label_to_lod(ori_labels):
|
||||
img_num = len(ori_labels)
|
||||
target_lod = [0]
|
||||
convert_ids = []
|
||||
for ino in range(img_num):
|
||||
target_lod.append(target_lod[ino] + len(ori_labels[ino]))
|
||||
convert_ids = convert_ids + list(ori_labels[ino])
|
||||
convert_ids = np.array(convert_ids)
|
||||
convert_ids = convert_ids.reshape((-1, 1))
|
||||
return convert_ids, target_lod
|
|
@ -1,31 +0,0 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import sys
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_config_params(config, config_name, params):
|
||||
for param in params:
|
||||
if param not in config:
|
||||
err = "param %s didn't find in %s!" % (param, config_name)
|
||||
assert False, err
|
||||
return
|
Loading…
Reference in New Issue