Merge remote-tracking branch 'origin/dygraph' into dygraph

commit d9c5148fdc
@@ -66,6 +66,7 @@ class StdTextDrawer(object):
                corpus_list.append(corpus[0:i])
                text_input_list.append(text_input)
                corpus = corpus[i:]
                i = 0
                break
            draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font)
            char_x += char_size
@@ -78,7 +79,6 @@ class StdTextDrawer(object):

            corpus_list.append(corpus[0:i])
            text_input_list.append(text_input)
            corpus = corpus[i:]
            break

        return corpus_list, text_input_list
@@ -17,7 +17,7 @@ Global:
  character_type: ch
  max_text_length: 25
  infer_mode: false
  use_space_char: false
  use_space_char: true
  distributed: true
  save_res_path: ./output/rec/predicts_chinese_lite_distillation_v2.1.txt

@@ -27,48 +27,29 @@ Optimizer:
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.0005
    name: Piecewise
    decay_epochs : [700, 800]
    values : [0.001, 0.0001]
    warmup_epoch: 5
  regularizer:
    name: L2
    factor: 1.0e-05
    factor: 2.0e-05

Architecture:
  model_type: &model_type "rec"
  name: DistillationModel
  algorithm: Distillation
  Models:
    Student:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: rec
      algorithm: CRNN
      Transform:
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: small
        small_stride: [1, 2, 2, 2]
      Neck:
        name: SequenceEncoder
        encoder_type: rnn
        hidden_size: 64
      Head:
        name: CTCHead
        mid_channels: 96
        fc_decay: 0.00001
    Teacher:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: rec
      model_type: *model_type
      algorithm: CRNN
      Transform:
      Backbone:
        name: MobileNetV3
        name: MobileNetV1Enhance
        scale: 0.5
        model_name: small
        small_stride: [1, 2, 2, 2]
      Neck:
        name: SequenceEncoder
        encoder_type: rnn
@@ -76,7 +57,25 @@ Architecture:
      Head:
        name: CTCHead
        mid_channels: 96
        fc_decay: 0.00001
        fc_decay: 0.00002
    Student:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: CRNN
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
      Neck:
        name: SequenceEncoder
        encoder_type: rnn
        hidden_size: 64
      Head:
        name: CTCHead
        mid_channels: 96
        fc_decay: 0.00002


Loss:
@@ -37,6 +37,17 @@ from paddleslim.dygraph.quant import QAT
from ppocr.data import build_dataloader


def export_single_model(quanter, model, infer_shape, save_path, logger):
    quanter.save_quantized_model(
        model,
        save_path,
        input_spec=[
            paddle.static.InputSpec(
                shape=[None] + infer_shape, dtype='float32')
        ])
    logger.info('inference QAT model is saved to {}'.format(save_path))


def main():
    ############################################################################################################
    # 1. quantization configs
@@ -76,7 +87,14 @@ def main():
    # for rec algorithm
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        config['Architecture']["Head"]['out_channels'] = char_num
        if config['Architecture']["algorithm"] in ["Distillation",
                                                   ]:  # distillation model
            for key in config['Architecture']["Models"]:
                config['Architecture']["Models"][key]["Head"][
                    'out_channels'] = char_num
        else:  # base rec model
            config['Architecture']["Head"]['out_channels'] = char_num

    model = build_model(config['Architecture'])

    # get QAT model
@@ -92,25 +110,30 @@ def main():
    # build dataloader
    valid_dataloader = build_dataloader(config, 'Eval', device, logger)

    use_srn = config['Architecture']['algorithm'] == "SRN"
    model_type = config['Architecture']['model_type']
    # start eval
    metirc = program.eval(model, valid_dataloader, post_process_class,
                          eval_class)
                          eval_class, model_type, use_srn)

    logger.info('metric eval ***************')
    for k, v in metirc.items():
    for k, v in metric.items():
        logger.info('{}:{}'.format(k, v))

    save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
    infer_shape = [3, 32, 100] if config['Architecture'][
        'model_type'] != "det" else [3, 640, 640]

    quanter.save_quantized_model(
        model,
        save_path,
        input_spec=[
            paddle.static.InputSpec(
                shape=[None] + infer_shape, dtype='float32')
        ])
    logger.info('inference QAT model is saved to {}'.format(save_path))
    save_path = config["Global"]["save_inference_dir"]

    arch_config = config["Architecture"]
    if arch_config["algorithm"] in ["Distillation", ]:  # distillation model
        for idx, name in enumerate(model.model_name_list):
            sub_model_save_path = os.path.join(save_path, name, "inference")
            export_single_model(quanter, model.model_list[idx], infer_shape,
                                sub_model_save_path, logger)
    else:
        save_path = os.path.join(save_path, "inference")
        export_single_model(quanter, model, infer_shape, save_path, logger)


if __name__ == "__main__":
@@ -109,9 +109,18 @@ def main(config, device, logger, vdl_writer):
    # for rec algorithm
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        config['Architecture']["Head"]['out_channels'] = char_num
        if config['Architecture']["algorithm"] in ["Distillation",
                                                   ]:  # distillation model
            for key in config['Architecture']["Models"]:
                config['Architecture']["Models"][key]["Head"][
                    'out_channels'] = char_num
        else:  # base rec model
            config['Architecture']["Head"]['out_channels'] = char_num
    model = build_model(config['Architecture'])

    quanter = QAT(config=quant_config, act_preprocess=PACT)
    quanter.quantize(model)

    if config['Global']['distributed']:
        model = paddle.DataParallel(model)

@@ -132,8 +141,6 @@ def main(config, device, logger, vdl_writer):

    logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
                format(len(train_dataloader), len(valid_dataloader)))
    quanter = QAT(config=quant_config, act_preprocess=PACT)
    quanter.quantize(model)

    # start train
    program.train(config, train_dataloader, valid_dataloader, device, model,
@@ -0,0 +1,251 @@
# Knowledge Distillation


## 1. Introduction

### 1.1 What is Knowledge Distillation

In recent years, deep neural networks have proven to be an extremely effective way to solve problems in fields such as computer vision and natural language processing. With a well-designed network and sufficient training, the resulting model generally outperforms traditional algorithms.

When enough data is available, growing a model's parameter count in a sensible way can significantly improve its performance, but it also makes the model far more complex, and large models are expensive to use in production scenarios.

Deep neural networks usually carry considerable parameter redundancy, and several families of methods exist to compress a model and reduce its parameter count, such as pruning, quantization, and knowledge distillation. Knowledge distillation uses a teacher model to guide a student model on a specific task, giving the small model a large accuracy gain without changing its parameter count.

In addition, mutual learning has grown out of knowledge distillation: the paper [Deep Mutual Learning](https://arxiv.org/abs/1706.00384) shows that two identical models that supervise each other during training achieve better results than a single model trained on its own.

### 1.2 Knowledge Distillation in PaddleOCR

Whether a large model distills a small one, or small models learn from each other and update their parameters, distillation is essentially mutual supervision between the outputs or feature maps of different models; the only differences are (1) whether a model's parameters are frozen and (2) whether a model loads pretrained weights.

When a large model distills a small one, the large model usually loads pretrained weights and has its parameters frozen; when small models distill each other, they usually do not load pretrained weights and all of their parameters remain learnable.

Distillation is not limited to two models: mutual learning among several models is also very common, so a knowledge distillation framework needs to support that case as well.

The knowledge distillation algorithms integrated in PaddleOCR have the following main features:

- Any networks can learn from each other: the sub-networks do not need identical structures or pretrained weights, and there is no limit on the number of sub-networks; new ones are simply added to the configuration file.
- Loss functions are freely configurable through the configuration file, either a single loss or a combination of several losses.
- Every model-related stage — distillation training, prediction, evaluation, and export — is supported, which makes the distilled models easy to use and deploy.

Through knowledge distillation, on the general Chinese and English text recognition task, the model gains more than 3% accuracy with no increase in prediction time; combined with a tuned learning-rate schedule and small architecture adjustments, the final improvement exceeds 5%.


## 2. Configuration File Walkthrough

During knowledge distillation training, the data preprocessing, optimizer, learning rate, and global settings stay unchanged; only the configuration of the model architecture, loss function, post-processing, and metric modules needs to be adjusted.

The recognition and detection distillation configuration files are used below to walk through distillation training and configuration.

### 2.1 Recognition Configuration File Walkthrough

The configuration file is [rec_chinese_lite_train_distillation_v2.1.yml](../../configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml).

#### 2.1.1 Model Architecture

In a knowledge distillation task, the model architecture is configured as follows.

```yaml
Architecture:
  model_type: &model_type "rec"    # model type, e.g. rec or det; every sub-network reuses this type through the anchor
  name: DistillationModel          # architecture name; for a distillation task this is DistillationModel, which builds the corresponding structure
  algorithm: Distillation          # algorithm name
  Models:                          # Models, holding the sub-network configurations
    Teacher:                       # sub-network name; it must contain at least `pretrained` and `freeze_params`; the remaining fields are the sub-network's constructor arguments
      pretrained:                  # whether this sub-network loads pretrained weights
      freeze_params: false         # whether its parameters are frozen
      return_all_feats: true       # sub-network argument: whether to return all features; if False, only the final output is returned
      model_type: *model_type      # model type
      algorithm: CRNN              # algorithm name of this sub-network; the remaining fields are constructor arguments, identical to a normal model training configuration
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
      Neck:
        name: SequenceEncoder
        encoder_type: rnn
        hidden_size: 64
      Head:
        name: CTCHead
        mid_channels: 96
        fc_decay: 0.00002
    Student:                       # the second sub-network; this is a DML example, so the two sub-networks have identical structures and both learn their parameters
      pretrained:                  # the fields below mirror the ones above
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: CRNN
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
      Neck:
        name: SequenceEncoder
        encoder_type: rnn
        hidden_size: 64
      Head:
        name: CTCHead
        mid_channels: 96
        fc_decay: 0.00002
```

If you want to train with more sub-networks, you can add the corresponding fields to the configuration file following the pattern of `Student` and `Teacher`. For example, for three models that supervise each other and train jointly, `Architecture` can be written as follows.

```yaml
Architecture:
  model_type: &model_type "rec"
  name: DistillationModel
  algorithm: Distillation
  Models:
    Teacher:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: CRNN
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
      Neck:
        name: SequenceEncoder
        encoder_type: rnn
        hidden_size: 64
      Head:
        name: CTCHead
        mid_channels: 96
        fc_decay: 0.00002
    Student:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: CRNN
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
      Neck:
        name: SequenceEncoder
        encoder_type: rnn
        hidden_size: 64
      Head:
        name: CTCHead
        mid_channels: 96
        fc_decay: 0.00002
    Student2:                      # a new sub-network introduced into the distillation task; everything else matches the configuration above
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: CRNN
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
      Neck:
        name: SequenceEncoder
        encoder_type: rnn
        hidden_size: 64
      Head:
        name: CTCHead
        mid_channels: 96
        fc_decay: 0.00002
```

When trained this way, the final model contains three sub-networks: `Teacher`, `Student`, and `Student2`.

The concrete implementation of the `DistillationModel` class can be found in [distillation_model.py](../../ppocr/modeling/architectures/distillation_model.py).

The model's final `forward` output is a dict whose keys are the sub-network names, here `Student` and `Teacher`, and whose values are the corresponding sub-network outputs, each of which can be a `Tensor` (only the network's final layer is returned) or a `dict` (intermediate features are returned as well).

In the recognition task, to allow more loss functions and keep the distillation method extensible, each sub-network's output is stored as a `dict` containing the outputs of its sub-modules. Taking this recognition model as an example, the output of every sub-network is a `dict` with the keys `backbone_out`, `neck_out`, and `head_out`, whose values are the tensors of the corresponding modules. For the configuration above, the output of `DistillationModel` has the following format.

```json
{
  "Teacher": {
    "backbone_out": tensor,
    "neck_out": tensor,
    "head_out": tensor,
  },
  "Student": {
    "backbone_out": tensor,
    "neck_out": tensor,
    "head_out": tensor,
  }
}
```
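To make the structure above concrete, here is a minimal sketch of how a distillation component can pick out the tensors it needs from the nested output dict. This is an illustration, not PaddleOCR's actual implementation; the random tensors and their shapes are stand-ins.

```python
import paddle

# Stand-in for the DistillationModel forward output described above;
# during real training each value is the corresponding module's tensor.
predicts = {
    "Teacher": {"backbone_out": paddle.rand([1, 512, 1, 25]),
                "neck_out": paddle.rand([25, 1, 64]),
                "head_out": paddle.rand([1, 25, 6625])},
    "Student": {"backbone_out": paddle.rand([1, 512, 1, 25]),
                "neck_out": paddle.rand([25, 1, 64]),
                "head_out": paddle.rand([1, 25, 6625])},
}

# A loss configured with model_name_pairs=[["Student", "Teacher"]] and
# key="head_out" would compare exactly these two tensors.
student_logits = predicts["Student"]["head_out"]
teacher_logits = predicts["Teacher"]["head_out"]
```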
#### 2.1.2 Loss Function

In a knowledge distillation task, the loss function is configured as follows.

```yaml
Loss:
  name: CombinedLoss                            # loss function name; the loss class is built from this name
  loss_config_list:                             # list of loss configurations, required by CombinedLoss
  - DistillationCTCLoss:                        # distillation CTC loss, inheriting from the standard CTC loss
      weight: 1.0                               # loss weight; every loss in loss_config_list must contain this field
      model_name_list: ["Student", "Teacher"]   # extract these two sub-networks' outputs from the distillation model's predictions and compute the CTC loss against the ground truth
      key: head_out                             # take the tensor under this key from the sub-network output dict
  - DistillationDMLLoss:                        # distillation DML loss, inheriting from the standard DMLLoss
      weight: 1.0                               # weight
      act: "softmax"                            # activation applied to the inputs; can be softmax, sigmoid, or None (default None)
      model_name_pairs:                         # pairs of sub-network names used to compute the DML loss; to compute it for other sub-networks, keep appending to this list
      - ["Student", "Teacher"]
      key: head_out                             # take the tensor under this key from the sub-network output dict
  - DistillationDistanceLoss:                   # distillation distance loss
      weight: 1.0                               # weight
      mode: "l2"                                # distance metric; currently supports l1, l2, smooth_l1
      model_name_pairs:                         # pairs of sub-network names used to compute the distance loss
      - ["Student", "Teacher"]
      key: backbone_out                         # take the tensor under this key from the sub-network output dict
```

All of the distillation losses above inherit from standard loss classes; their main job is to parse the distillation model's output, locate the intermediate node (tensor) used for the loss, and then compute the loss with the standard loss class.

Taking the configuration above as an example, the final distillation training loss contains the following three parts.

- The CTC loss between the ground truth and the final outputs (`head_out`) of `Student` and `Teacher`, with weight 1. Since both sub-networks update their parameters here, both need a loss against the ground truth.
- The DML loss between the final outputs (`head_out`) of `Student` and `Teacher`, with weight 1.
- The l2 loss between the backbone outputs (`backbone_out`) of `Student` and `Teacher`, with weight 1.

For the concrete implementation of `CombinedLoss`, see [combined_loss.py](../../ppocr/losses/combined_loss.py#L23); for the distillation losses such as `DistillationCTCLoss`, see [distillation_loss.py](../../ppocr/losses/distillation_loss.py).
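As an illustration of the pattern these wrappers follow — parse the output dict, then delegate to a standard loss — here is a simplified sketch. The class below is a stand-in for demonstration, not the actual PaddleOCR `DistillationDMLLoss`; it uses a symmetric KL divergence as the underlying DML loss.

```python
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class SimpleDMLLoss(nn.Layer):
    """Simplified stand-in for a distillation DML-style loss wrapper."""

    def __init__(self, model_name_pairs, key, weight=1.0):
        super().__init__()
        self.model_name_pairs = model_name_pairs  # e.g. [["Student", "Teacher"]]
        self.key = key                            # e.g. "head_out"
        self.weight = weight

    def forward(self, predicts, batch=None):
        loss_dict = {}
        for name1, name2 in self.model_name_pairs:
            # Locate the tensors inside the nested output dict,
            # then apply a standard (here: symmetric KL) loss.
            out1 = F.softmax(predicts[name1][self.key], axis=-1)
            out2 = F.softmax(predicts[name2][self.key], axis=-1)
            kl_12 = F.kl_div(paddle.log(out1), out2, reduction='mean')
            kl_21 = F.kl_div(paddle.log(out2), out1, reduction='mean')
            loss_dict["dml_{}_{}".format(name1, name2)] = (
                (kl_12 + kl_21) / 2 * self.weight)
        return loss_dict
```

A `CombinedLoss`-style container can then sum the entries of each such `loss_dict`, weighted per the configuration, into the total training loss.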
#### 2.1.3 Post-processing

In a knowledge distillation task, post-processing is configured as follows.

```yaml
PostProcess:
  name: DistillationCTCLabelDecode      # CTC decoding post-process for the distillation task, inheriting from the standard CTCLabelDecode class
  model_name: ["Student", "Teacher"]    # extract and decode the outputs of these two sub-networks from the distillation model's predictions
  key: head_out                         # take the tensor under this key from the sub-network output dict
```

With the configuration above, the CTC decoding outputs of both the `Student` and `Teacher` sub-networks are computed, and a `dict` is returned whose keys are the names of the processed sub-networks and whose values are the corresponding decoded results.

For the concrete implementation of `DistillationCTCLabelDecode`, see [rec_postprocess.py](../../ppocr/postprocess/rec_postprocess.py#L128).
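Downstream code that only cares about one branch can simply index the returned dict by sub-network name. The decoded values below are hypothetical, shown only to illustrate the shape of the result:

```python
# Hypothetical decoded output of DistillationCTCLabelDecode:
# one (text, score) list per sub-network name.
post_result = {
    "Student": [("韩国小馆", 0.9834)],
    "Teacher": [("韩国小馆", 0.9785)],
}

# At deployment time usually only the Student branch is kept.
text, score = post_result["Student"][0]
print(text, score)
```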
#### 2.1.4 Metric Computation

In a knowledge distillation task, metric computation is configured as follows.

```yaml
Metric:
  name: DistillationMetric         # metric computation for the distillation task, inheriting from the standard metric class
  base_metric_name: RecMetric      # base metric class; the metrics of the model outputs are computed with this class
  main_indicator: acc              # name of the main indicator
  key: "Student"                   # the main_indicator of this sub-network is used as the criterion for saving the best model
```

With the configuration above, the `acc` metric of the `Student` sub-network is used as the criterion for saving the best model, while the `acc` metrics of all sub-networks are still printed in the logs.

For the concrete implementation of `DistillationMetric`, see [distillation_metric.py](../../ppocr/metrics/distillation_metric.py#L24).


### 2.2 Detection Configuration File Walkthrough

* coming soon!
@@ -23,6 +23,7 @@ from .random_crop_data import EastRandomCropData, PSERandomCrop

from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
from .randaugment import RandAugment
from .copy_paste import CopyPaste
from .operators import *
from .label_ops import *

@@ -0,0 +1,166 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import cv2
import random
import numpy as np
from PIL import Image
from shapely.geometry import Polygon

from ppocr.data.imaug.iaa_augment import IaaAugment
from ppocr.data.imaug.random_crop_data import is_poly_outside_rect
from tools.infer.utility import get_rotate_crop_image


class CopyPaste(object):
    def __init__(self, objects_paste_ratio=0.2, limit_paste=True, **kwargs):
        self.ext_data_num = 1
        self.objects_paste_ratio = objects_paste_ratio
        self.limit_paste = limit_paste
        augmenter_args = [{'type': 'Resize', 'args': {'size': [0.5, 3]}}]
        self.aug = IaaAugment(augmenter_args)

    def __call__(self, data):
        src_img = data['image']
        src_polys = data['polys'].tolist()
        src_ignores = data['ignore_tags'].tolist()
        ext_data = data['ext_data'][0]
        ext_image = ext_data['image']
        ext_polys = ext_data['polys']
        ext_ignores = ext_data['ignore_tags']

        indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]]
        select_num = max(
            1, min(int(self.objects_paste_ratio * len(ext_polys)), 30))

        random.shuffle(indexs)
        select_idxs = indexs[:select_num]
        select_polys = ext_polys[select_idxs]
        select_ignores = ext_ignores[select_idxs]

        src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
        ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
        src_img = Image.fromarray(src_img).convert('RGBA')
        for poly, tag in zip(select_polys, select_ignores):
            box_img = get_rotate_crop_image(ext_image, poly)

            src_img, box = self.paste_img(src_img, box_img, src_polys)
            if box is not None:
                src_polys.append(box)
                src_ignores.append(tag)
        src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
        h, w = src_img.shape[:2]
        src_polys = np.array(src_polys)
        src_polys[:, :, 0] = np.clip(src_polys[:, :, 0], 0, w)
        src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
        data['image'] = src_img
        data['polys'] = src_polys
        data['ignore_tags'] = np.array(src_ignores)
        return data

    def paste_img(self, src_img, box_img, src_polys):
        box_img_pil = Image.fromarray(box_img).convert('RGBA')
        src_w, src_h = src_img.size
        box_w, box_h = box_img_pil.size

        angle = np.random.randint(0, 360)
        box = np.array([[[0, 0], [box_w, 0], [box_w, box_h], [0, box_h]]])
        box = rotate_bbox(box_img, box, angle)[0]
        box_img_pil = box_img_pil.rotate(angle, expand=1)
        box_w, box_h = box_img_pil.width, box_img_pil.height
        if src_w - box_w < 0 or src_h - box_h < 0:
            return src_img, None

        paste_x, paste_y = self.select_coord(src_polys, box, src_w - box_w,
                                             src_h - box_h)
        if paste_x is None:
            return src_img, None
        box[:, 0] += paste_x
        box[:, 1] += paste_y
        r, g, b, A = box_img_pil.split()
        src_img.paste(box_img_pil, (paste_x, paste_y), mask=A)

        return src_img, box

    def select_coord(self, src_polys, box, endx, endy):
        if self.limit_paste:
            xmin, ymin, xmax, ymax = box[:, 0].min(), box[:, 1].min(
            ), box[:, 0].max(), box[:, 1].max()
            for _ in range(50):
                paste_x = random.randint(0, endx)
                paste_y = random.randint(0, endy)
                xmin1 = xmin + paste_x
                xmax1 = xmax + paste_x
                ymin1 = ymin + paste_y
                ymax1 = ymax + paste_y

                num_poly_in_rect = 0
                for poly in src_polys:
                    if not is_poly_outside_rect(poly, xmin1, ymin1,
                                                xmax1 - xmin1, ymax1 - ymin1):
                        num_poly_in_rect += 1
                        break
                if num_poly_in_rect == 0:
                    return paste_x, paste_y
            return None, None
        else:
            paste_x = random.randint(0, endx)
            paste_y = random.randint(0, endy)
            return paste_x, paste_y


def get_union(pD, pG):
    return Polygon(pD).union(Polygon(pG)).area


def get_intersection_over_union(pD, pG):
    return get_intersection(pD, pG) / get_union(pD, pG)


def get_intersection(pD, pG):
    return Polygon(pD).intersection(Polygon(pG)).area


def rotate_bbox(img, text_polys, angle, scale=1):
    """
    from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py
    Args:
        img: np.ndarray
        text_polys: np.ndarray N*4*2
        angle: int
        scale: int

    Returns:

    """
    w = img.shape[1]
    h = img.shape[0]

    rangle = np.deg2rad(angle)
    nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w))
    nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w))
    rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
    rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
    rot_mat[0, 2] += rot_move[0]
    rot_mat[1, 2] += rot_move[1]

    # ---------------------- rotate box ----------------------
    rot_text_polys = list()
    for bbox in text_polys:
        point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
        point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
        point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
        point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
        rot_text_polys.append([point1, point2, point3, point4])
    return np.array(rot_text_polys, dtype=np.float32)
@@ -14,6 +14,7 @@
import numpy as np
import os
import random
import traceback
from paddle.io import Dataset

from .imaug import transform, create_operators
@@ -69,6 +70,36 @@ class SimpleDataSet(Dataset):
        random.shuffle(self.data_lines)
        return

    def get_ext_data(self):
        ext_data_num = 0
        for op in self.ops:
            if hasattr(op, 'ext_data_num'):
                ext_data_num = getattr(op, 'ext_data_num')
                break
        load_data_ops = self.ops[:2]
        ext_data = []

        while len(ext_data) < ext_data_num:
            file_idx = self.data_idx_order_list[np.random.randint(self.__len__(
            ))]
            data_line = self.data_lines[file_idx]
            data_line = data_line.decode('utf-8')
            substr = data_line.strip("\n").split(self.delimiter)
            file_name = substr[0]
            label = substr[1]
            img_path = os.path.join(self.data_dir, file_name)
            data = {'img_path': img_path, 'label': label}
            if not os.path.exists(img_path):
                continue
            with open(data['img_path'], 'rb') as f:
                img = f.read()
                data['image'] = img
            data = transform(data, load_data_ops)
            if data is None:
                continue
            ext_data.append(data)
        return ext_data

    def __getitem__(self, idx):
        file_idx = self.data_idx_order_list[idx]
        data_line = self.data_lines[file_idx]
@@ -84,11 +115,13 @@ class SimpleDataSet(Dataset):
            with open(data['img_path'], 'rb') as f:
                img = f.read()
                data['image'] = img
            data['ext_data'] = self.get_ext_data()
            outs = transform(data, self.ops)
        except Exception as e:
        except:
            error_meg = traceback.format_exc()
            self.logger.error(
                "When parsing line {}, error happened with msg: {}".format(
                    data_line, e))
                    data_line, error_meg))
            outs = None
        if outs is None:
            # during evaluation, we should fix the idx to get same results for many times of evaluation.
@@ -12,33 +12,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ['build_backbone']
__all__ = ["build_backbone"]


def build_backbone(config, model_type):
    if model_type == 'det':
    if model_type == "det":
        from .det_mobilenet_v3 import MobileNetV3
        from .det_resnet_vd import ResNet
        from .det_resnet_vd_sast import ResNet_SAST
        support_dict = ['MobileNetV3', 'ResNet', 'ResNet_SAST']
    elif model_type == 'rec' or model_type == 'cls':
        support_dict = ["MobileNetV3", "ResNet", "ResNet_SAST"]
    elif model_type == "rec" or model_type == "cls":
        from .rec_mobilenet_v3 import MobileNetV3
        from .rec_resnet_vd import ResNet
        from .rec_resnet_fpn import ResNetFPN
        support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
    elif model_type == 'e2e':
        from .rec_mv1_enhance import MobileNetV1Enhance
        support_dict = [
            "MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN"
        ]
    elif model_type == "e2e":
        from .e2e_resnet_vd_pg import ResNet
        support_dict = ['ResNet']
        support_dict = ["ResNet"]
    elif model_type == "table":
        from .table_resnet_vd import ResNet
        from .table_mobilenet_v3 import MobileNetV3
        support_dict = ['ResNet', 'MobileNetV3']
        support_dict = ["ResNet", "MobileNetV3"]
    else:
        raise NotImplementedError

    module_name = config.pop('name')
    module_name = config.pop("name")
    assert module_name in support_dict, Exception(
        'when model typs is {}, backbone only support {}'.format(model_type,
        "when model typs is {}, backbone only support {}".format(model_type,
                                                                 support_dict))
    module_class = eval(module_name)(**config)
    return module_class
@@ -0,0 +1,256 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import KaimingNormal
import math
import numpy as np
import paddle
from paddle import ParamAttr, reshape, transpose, concat, split
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import KaimingNormal
import math
from paddle.nn.functional import hardswish, hardsigmoid
from paddle.regularizer import L2Decay


class ConvBNLayer(nn.Layer):
    def __init__(self,
                 num_channels,
                 filter_size,
                 num_filters,
                 stride,
                 padding,
                 channels=None,
                 num_groups=1,
                 act='hard_swish'):
        super(ConvBNLayer, self).__init__()

        self._conv = Conv2D(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            weight_attr=ParamAttr(initializer=KaimingNormal()),
            bias_attr=False)

        self._batch_norm = BatchNorm(
            num_filters,
            act=act,
            param_attr=ParamAttr(regularizer=L2Decay(0.0)),
            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))

    def forward(self, inputs):
        y = self._conv(inputs)
        y = self._batch_norm(y)
        return y


class DepthwiseSeparable(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters1,
                 num_filters2,
                 num_groups,
                 stride,
                 scale,
                 dw_size=3,
                 padding=1,
                 use_se=False):
        super(DepthwiseSeparable, self).__init__()
        self.use_se = use_se
        self._depthwise_conv = ConvBNLayer(
            num_channels=num_channels,
            num_filters=int(num_filters1 * scale),
            filter_size=dw_size,
            stride=stride,
            padding=padding,
            num_groups=int(num_groups * scale))
        if use_se:
            self._se = SEModule(int(num_filters1 * scale))
        self._pointwise_conv = ConvBNLayer(
            num_channels=int(num_filters1 * scale),
            filter_size=1,
            num_filters=int(num_filters2 * scale),
            stride=1,
            padding=0)

    def forward(self, inputs):
        y = self._depthwise_conv(inputs)
        if self.use_se:
            y = self._se(y)
        y = self._pointwise_conv(y)
        return y


class MobileNetV1Enhance(nn.Layer):
    def __init__(self, in_channels=3, scale=0.5, **kwargs):
        super().__init__()
        self.scale = scale
        self.block_list = []

        self.conv1 = ConvBNLayer(
            num_channels=3,
            filter_size=3,
            channels=3,
            num_filters=int(32 * scale),
            stride=2,
            padding=1)

        conv2_1 = DepthwiseSeparable(
            num_channels=int(32 * scale),
            num_filters1=32,
            num_filters2=64,
            num_groups=32,
            stride=1,
            scale=scale)
        self.block_list.append(conv2_1)

        conv2_2 = DepthwiseSeparable(
            num_channels=int(64 * scale),
            num_filters1=64,
            num_filters2=128,
            num_groups=64,
            stride=1,
            scale=scale)
        self.block_list.append(conv2_2)

        conv3_1 = DepthwiseSeparable(
            num_channels=int(128 * scale),
            num_filters1=128,
            num_filters2=128,
            num_groups=128,
            stride=1,
            scale=scale)
        self.block_list.append(conv3_1)

        conv3_2 = DepthwiseSeparable(
            num_channels=int(128 * scale),
            num_filters1=128,
            num_filters2=256,
            num_groups=128,
            stride=(2, 1),
            scale=scale)
        self.block_list.append(conv3_2)

        conv4_1 = DepthwiseSeparable(
            num_channels=int(256 * scale),
            num_filters1=256,
            num_filters2=256,
            num_groups=256,
            stride=1,
            scale=scale)
        self.block_list.append(conv4_1)

        conv4_2 = DepthwiseSeparable(
            num_channels=int(256 * scale),
            num_filters1=256,
            num_filters2=512,
            num_groups=256,
            stride=(2, 1),
            scale=scale)
        self.block_list.append(conv4_2)

        for _ in range(5):
            conv5 = DepthwiseSeparable(
                num_channels=int(512 * scale),
                num_filters1=512,
                num_filters2=512,
                num_groups=512,
                stride=1,
                dw_size=5,
                padding=2,
                scale=scale,
                use_se=False)
            self.block_list.append(conv5)

        conv5_6 = DepthwiseSeparable(
            num_channels=int(512 * scale),
            num_filters1=512,
            num_filters2=1024,
            num_groups=512,
            stride=(2, 1),
            dw_size=5,
            padding=2,
            scale=scale,
            use_se=True)
        self.block_list.append(conv5_6)

        conv6 = DepthwiseSeparable(
            num_channels=int(1024 * scale),
            num_filters1=1024,
            num_filters2=1024,
            num_groups=1024,
            stride=1,
            dw_size=5,
            padding=2,
            use_se=True,
            scale=scale)
        self.block_list.append(conv6)

        self.block_list = nn.Sequential(*self.block_list)

        self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
        self.out_channels = int(1024 * scale)

    def forward(self, inputs):
        y = self.conv1(inputs)
        y = self.block_list(y)
        y = self.pool(y)
        return y


class SEModule(nn.Layer):
    def __init__(self, channel, reduction=4):
        super(SEModule, self).__init__()
        self.avg_pool = AdaptiveAvgPool2D(1)
        self.conv1 = Conv2D(
            in_channels=channel,
            out_channels=channel // reduction,
            kernel_size=1,
            stride=1,
            padding=0,
            weight_attr=ParamAttr(),
            bias_attr=ParamAttr())
        self.conv2 = Conv2D(
            in_channels=channel // reduction,
            out_channels=channel,
            kernel_size=1,
            stride=1,
            padding=0,
            weight_attr=ParamAttr(),
            bias_attr=ParamAttr())

    def forward(self, inputs):
        outputs = self.avg_pool(inputs)
        outputs = self.conv1(outputs)
        outputs = F.relu(outputs)
        outputs = self.conv2(outputs)
        outputs = hardsigmoid(outputs)
        return paddle.multiply(x=inputs, y=outputs)
@@ -230,15 +230,8 @@ class GridGenerator(nn.Layer):
    def build_inv_delta_C_paddle(self, C):
        """ Return inv_delta_C which is needed to calculate T """
        F = self.F
        hat_C = paddle.zeros((F, F), dtype='float64')  # F x F
        for i in range(0, F):
            for j in range(i, F):
                if i == j:
                    hat_C[i, j] = 1
                else:
                    r = paddle.norm(C[i] - C[j])
                    hat_C[i, j] = r
                    hat_C[j, i] = r
        hat_eye = paddle.eye(F, dtype='float64')  # F x F
        hat_C = paddle.norm(C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
        hat_C = (hat_C**2) * paddle.log(hat_C)
        delta_C = paddle.concat(  # F+3 x F+3
            [
@@ -25,7 +25,7 @@ import paddle

from ppocr.utils.logging import get_logger

__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
__all__ = ['init_model', 'save_model', 'load_dygraph_params']


def _mkdir_if_not_exist(path, logger):
@@ -89,6 +89,34 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
    return best_model_dict


def load_dygraph_params(config, model, logger, optimizer):
    ckp = config['Global']['checkpoints']
    if ckp and os.path.exists(ckp + ".pdparams"):
        pre_best_model_dict = init_model(config, model, optimizer)
        return pre_best_model_dict
    else:
        pm = config['Global']['pretrained_model']
        if pm is None:
            return {}
        if not os.path.exists(pm) and not os.path.exists(pm + ".pdparams"):
            logger.info(f"The pretrained_model {pm} does not exists!")
            return {}
        pm = pm if pm.endswith('.pdparams') else pm + '.pdparams'
        params = paddle.load(pm)
        state_dict = model.state_dict()
        new_state_dict = {}
        for k1, k2 in zip(state_dict.keys(), params.keys()):
            if list(state_dict[k1].shape) == list(params[k2].shape):
                new_state_dict[k1] = params[k2]
            else:
                logger.info(
                    f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
                )
        model.set_state_dict(new_state_dict)
        logger.info(f"loaded pretrained_model successful from {pm}")
        return {}


def save_model(model,
               optimizer,
               model_path,
@@ -0,0 +1,35 @@
model_name:ocr_det
python:python3.7
gpu_list:0|0,1
Global.auto_cast:False
Global.epoch_num:10
Global.save_model_dir:./output/
Global.save_inference_dir:./output/
Train.loader.batch_size_per_card:
Global.use_gpu
Global.pretrained_model

trainer:norm|pact
norm_train:tools/train.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained
quant_train:deploy/slim/quantization/quant.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model=./pretrain_models/det_mv3_db_v2.0_train/best_accuracy
fpgm_train:null
distill_train:null

eval:tools/eval.py -c configs/det/det_mv3_db.yml -o

norm_export:tools/export_model.py -c configs/det/det_mv3_db.yml -o
quant_export:deploy/slim/quantization/export_model.py -c configs/det/det_mv3_db.yml -o
fpgm_export:deploy/slim/prune/export_prune_model.py
distill_export:null

inference:tools/infer/predict_det.py
--use_gpu:True|False
--enable_mkldnn:True|False
--cpu_threads:1|6
--rec_batch_num:1
--use_tensorrt:True|False
--precision:fp32|fp16|int8
--det_model_dir
--image_dir
--save_log_path
@@ -0,0 +1,138 @@
#!/bin/bash
FILENAME=$1
# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer', 'infer']
MODE=$2

dataline=$(cat ${FILENAME})

# parser params
IFS=$'\n'
lines=(${dataline})
function func_parser_key(){
    strs=$1
    IFS=":"
    array=(${strs})
    tmp=${array[0]}
    echo ${tmp}
}
function func_parser_value(){
    strs=$1
    IFS=":"
    array=(${strs})
    tmp=${array[1]}
    echo ${tmp}
}
IFS=$'\n'
# The training params
model_name=$(func_parser_value "${lines[0]}")
train_model_list=$(func_parser_value "${lines[0]}")
trainer_list=$(func_parser_value "${lines[10]}")

# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer']
MODE=$2
# prepare pretrained weights and dataset
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar
cd pretrain_models && tar xf det_mv3_db_v2.0_train.tar && cd ../

if [ ${MODE} = "lite_train_infer" ];then
    # pretrain lite train data
    rm -rf ./train_data/icdar2015
    wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_lite.tar
    cd ./train_data/ && tar xf icdar2015_lite.tar
    ln -s ./icdar2015_lite ./icdar2015
    cd ../
    epoch=10
    eval_batch_step=10
elif [ ${MODE} = "whole_train_infer" ];then
    rm -rf ./train_data/icdar2015
    wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar
    cd ./train_data/ && tar xf icdar2015.tar && cd ../
    epoch=500
    eval_batch_step=200
elif [ ${MODE} = "whole_infer" ];then
    rm -rf ./train_data/icdar2015
    wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_infer.tar
    cd ./train_data/ && tar xf icdar2015_infer.tar
    ln -s ./icdar2015_infer ./icdar2015
    cd ../
    epoch=10
    eval_batch_step=10
else
    rm -rf ./train_data/icdar2015
    wget -nc -P ./train_data https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar
    if [ ${model_name} = "ocr_det" ]; then
        eval_model_name="ch_ppocr_mobile_v2.0_det_train"
        wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar
        cd ./inference && tar xf ${eval_model_name}.tar && cd ../
    else
        eval_model_name="ch_ppocr_mobile_v2.0_rec_train"
        wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_train.tar
        cd ./inference && tar xf ${eval_model_name}.tar && cd ../
    fi
fi


IFS='|'
for train_model in ${train_model_list[*]}; do
    if [ ${train_model} = "ocr_det" ];then
        model_name="ocr_det"
        yml_file="configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml"
        wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar
        cd ./inference && tar xf ch_det_data_50.tar && cd ../
        img_dir="./inference/ch_det_data_50/all-sum-510"
        data_dir=./inference/ch_det_data_50/
        data_label_file=[./inference/ch_det_data_50/test_gt_50.txt]
    elif [ ${train_model} = "ocr_rec" ];then
        model_name="ocr_rec"
        yml_file="configs/rec/rec_mv3_none_bilstm_ctc.yml"
        wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_rec_data_200.tar
        cd ./inference && tar xf ch_rec_data_200.tar && cd ../
        img_dir="./inference/ch_rec_data_200/"
    fi

    # eval
    for slim_trainer in ${trainer_list[*]}; do
        if [ ${slim_trainer} = "norm" ]; then
            if [ ${model_name} = "ocr_det" ]; then
                eval_model_name="ch_ppocr_mobile_v2.0_det_train"
                wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar
                cd ./inference && tar xf ${eval_model_name}.tar && cd ../
            else
                eval_model_name="ch_ppocr_mobile_v2.0_rec_train"
                wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_train.tar
                cd ./inference && tar xf ${eval_model_name}.tar && cd ../
            fi
        elif [ ${slim_trainer} = "pact" ]; then
            if [ ${model_name} = "ocr_det" ]; then
                eval_model_name="ch_ppocr_mobile_v2.0_det_quant_train"
                wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_quant_train.tar
                cd ./inference && tar xf ${eval_model_name}.tar && cd ../
            else
                eval_model_name="ch_ppocr_mobile_v2.0_rec_quant_train"
                wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_rec_quant_train.tar
                cd ./inference && tar xf ${eval_model_name}.tar && cd ../
            fi
        elif [ ${slim_trainer} = "distill" ]; then
            if [ ${model_name} = "ocr_det" ]; then
                eval_model_name="ch_ppocr_mobile_v2.0_det_distill_train"
                wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_distill_train.tar
                cd ./inference && tar xf ${eval_model_name}.tar && cd ../
            else
                eval_model_name="ch_ppocr_mobile_v2.0_rec_distill_train"
                wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_rec_distill_train.tar
                cd ./inference && tar xf ${eval_model_name}.tar && cd ../
            fi
        elif [ ${slim_trainer} = "fpgm" ]; then
            if [ ${model_name} = "ocr_det" ]; then
                eval_model_name="ch_ppocr_mobile_v2.0_det_prune_train"
                wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_train.tar
                cd ./inference && tar xf ${eval_model_name}.tar && cd ../
            else
                eval_model_name="ch_ppocr_mobile_v2.0_rec_prune_train"
                wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_rec_prune_train.tar
                cd ./inference && tar xf ${eval_model_name}.tar && cd ../
            fi
        fi
    done
done
@@ -0,0 +1,221 @@
#!/bin/bash
FILENAME=$1
# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer', 'infer']
MODE=$2

dataline=$(cat ${FILENAME})

# parser params
IFS=$'\n'
lines=(${dataline})
function func_parser_key(){
    strs=$1
    IFS=":"
    array=(${strs})
    tmp=${array[0]}
    echo ${tmp}
}
function func_parser_value(){
    strs=$1
    IFS=":"
    array=(${strs})
    tmp=${array[1]}
    echo ${tmp}
}
function status_check(){
    last_status=$1   # the exit code
    run_command=$2
    run_log=$3
    if [ $last_status -eq 0 ]; then
        echo -e "\033[33m Run successfully with command - ${run_command}! \033[0m" | tee -a ${run_log}
    else
        echo -e "\033[33m Run failed with command - ${run_command}! \033[0m" | tee -a ${run_log}
    fi
}

IFS=$'\n'
# The training params
model_name=$(func_parser_value "${lines[0]}")
python=$(func_parser_value "${lines[1]}")
gpu_list=$(func_parser_value "${lines[2]}")
autocast_list=$(func_parser_value "${lines[3]}")
autocast_key=$(func_parser_key "${lines[3]}")
epoch_key=$(func_parser_key "${lines[4]}")
save_model_key=$(func_parser_key "${lines[5]}")
save_infer_key=$(func_parser_key "${lines[6]}")
train_batch_key=$(func_parser_key "${lines[7]}")
train_use_gpu_key=$(func_parser_key "${lines[8]}")
pretrain_model_key=$(func_parser_key "${lines[9]}")

trainer_list=$(func_parser_value "${lines[10]}")
norm_trainer=$(func_parser_value "${lines[11]}")
pact_trainer=$(func_parser_value "${lines[12]}")
fpgm_trainer=$(func_parser_value "${lines[13]}")
distill_trainer=$(func_parser_value "${lines[14]}")

eval_py=$(func_parser_value "${lines[15]}")
norm_export=$(func_parser_value "${lines[16]}")
pact_export=$(func_parser_value "${lines[17]}")
fpgm_export=$(func_parser_value "${lines[18]}")
distill_export=$(func_parser_value "${lines[19]}")

inference_py=$(func_parser_value "${lines[20]}")
use_gpu_key=$(func_parser_key "${lines[21]}")
use_gpu_list=$(func_parser_value "${lines[21]}")
use_mkldnn_key=$(func_parser_key "${lines[22]}")
use_mkldnn_list=$(func_parser_value "${lines[22]}")
cpu_threads_key=$(func_parser_key "${lines[23]}")
cpu_threads_list=$(func_parser_value "${lines[23]}")
batch_size_key=$(func_parser_key "${lines[24]}")
batch_size_list=$(func_parser_value "${lines[24]}")
use_trt_key=$(func_parser_key "${lines[25]}")
use_trt_list=$(func_parser_value "${lines[25]}")
precision_key=$(func_parser_key "${lines[26]}")
precision_list=$(func_parser_value "${lines[26]}")
model_dir_key=$(func_parser_key "${lines[27]}")
image_dir_key=$(func_parser_key "${lines[28]}")
save_log_key=$(func_parser_key "${lines[29]}")

LOG_PATH="./test/output"
mkdir -p ${LOG_PATH}
status_log="${LOG_PATH}/results.log"

if [ ${MODE} = "lite_train_infer" ]; then
    export infer_img_dir="./train_data/icdar2015/text_localization/ch4_test_images/"
    export epoch_num=10
elif [ ${MODE} = "whole_infer" ]; then
    export infer_img_dir="./train_data/icdar2015/text_localization/ch4_test_images/"
    export epoch_num=10
elif [ ${MODE} = "whole_train_infer" ]; then
    export infer_img_dir="./train_data/icdar2015/text_localization/ch4_test_images/"
    export epoch_num=300
else
    export infer_img_dir="./inference/ch_det_data_50/all-sum-510"
    export infer_model_dir="./inference/ch_ppocr_mobile_v2.0_det_train/best_accuracy"
fi


function func_inference(){
    IFS='|'
    _python=$1
    _script=$2
    _model_dir=$3
    _log_path=$4
    _img_dir=$5

    # inference
    for use_gpu in ${use_gpu_list[*]}; do
        if [ ${use_gpu} = "False" ]; then
            for use_mkldnn in ${use_mkldnn_list[*]}; do
                for threads in ${cpu_threads_list[*]}; do
                    for batch_size in ${batch_size_list[*]}; do
                        _save_log_path="${_log_path}/infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_batchsize_${batch_size}"
                        command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${use_mkldnn_key}=${use_mkldnn} ${cpu_threads_key}=${threads} ${model_dir_key}=${_model_dir} ${batch_size_key}=${batch_size} ${image_dir_key}=${_img_dir} ${save_log_key}=${_save_log_path} --benchmark=True"
                        eval $command
                        status_check $? "${command}" "${status_log}"
                    done
                done
            done
        else
            for use_trt in ${use_trt_list[*]}; do
                for precision in ${precision_list[*]}; do
                    if [ ${use_trt} = "False" ] && [ ${precision} != "fp32" ]; then
                        continue
                    fi
                    for batch_size in ${batch_size_list[*]}; do
                        _save_log_path="${_log_path}/infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}"
                        command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${use_trt_key}=${use_trt} ${precision_key}=${precision} ${model_dir_key}=${_model_dir} ${batch_size_key}=${batch_size} ${image_dir_key}=${_img_dir} ${save_log_key}=${_save_log_path} --benchmark=True"
                        eval $command
                        status_check $? "${command}" "${status_log}"
                    done
                done
            done
        fi
    done
}

if [ ${MODE} != "infer" ]; then

    IFS="|"
    for gpu in ${gpu_list[*]}; do
        train_use_gpu=True
        if [ ${gpu} = "-1" ];then
            train_use_gpu=False
            env=""
        elif [ ${#gpu} -le 1 ];then
            env="export CUDA_VISIBLE_DEVICES=${gpu}"
        elif [ ${#gpu} -le 15 ];then
            IFS=","
            array=(${gpu})
            env="export CUDA_VISIBLE_DEVICES=${array[0]}"
            IFS="|"
        else
            IFS=";"
            array=(${gpu})
            ips=${array[0]}
            gpu=${array[1]}
            IFS="|"
        fi
        for autocast in ${autocast_list[*]}; do
            for trainer in ${trainer_list[*]}; do
                if [ ${trainer} = "pact" ]; then
                    run_train=${pact_trainer}
                    run_export=${pact_export}
                elif [ ${trainer} = "fpgm" ]; then
                    run_train=${fpgm_trainer}
                    run_export=${fpgm_export}
                elif [ ${trainer} = "distill" ]; then
                    run_train=${distill_trainer}
                    run_export=${distill_export}
                else
                    run_train=${norm_trainer}
                    run_export=${norm_export}
                fi

                if [ ${run_train} = "null" ]; then
                    continue
                fi
                if [ ${run_export} = "null" ]; then
                    continue
                fi

                save_log="${LOG_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}"
                if [ ${#gpu} -le 2 ];then  # epoch_num #TODO
                    cmd="${python} ${run_train} ${train_use_gpu_key}=${train_use_gpu} ${autocast_key}=${autocast} ${epoch_key}=${epoch_num} ${save_model_key}=${save_log} "
                elif [ ${#gpu} -le 15 ];then
                    cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${autocast_key}=${autocast} ${epoch_key}=${epoch_num} ${save_model_key}=${save_log}"
                else
                    cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${autocast_key}=${autocast} ${epoch_key}=${epoch_num} ${save_model_key}=${save_log}"
                fi
                # run train
                eval $cmd
                status_check $? "${cmd}" "${status_log}"

                # run eval
                eval_cmd="${python} ${eval_py} ${save_model_key}=${save_log} ${pretrain_model_key}=${save_log}/latest"
                eval $eval_cmd
                status_check $? "${eval_cmd}" "${status_log}"

                # run export model
                save_infer_path="${save_log}"
                export_cmd="${python} ${run_export} ${save_model_key}=${save_log} ${pretrain_model_key}=${save_log}/latest ${save_infer_key}=${save_infer_path}"
                eval $export_cmd
                status_check $? "${export_cmd}" "${status_log}"

                # run inference
                save_infer_path="${save_log}"
                func_inference "${python}" "${inference_py}" "${save_infer_path}" "${LOG_PATH}" "${infer_img_dir}"
            done
        done
    done

else
    save_infer_path="${LOG_PATH}/${MODE}"
    run_export=${norm_export}
    export_cmd="${python} ${run_export} ${save_model_key}=${save_infer_path} ${pretrain_model_key}=${infer_model_dir} ${save_infer_key}=${save_infer_path}"
    eval $export_cmd
    status_check $? "${export_cmd}" "${status_log}"

    # run inference
    func_inference "${python}" "${inference_py}" "${save_infer_path}" "${LOG_PATH}" "${infer_img_dir}"
fi
@@ -5,5 +5,5 @@ recursive-include ppocr/utils *.txt utility.py logging.py network.py
recursive-include ppocr/data/ *.py
recursive-include ppocr/postprocess *.py
recursive-include tools/infer *.py
recursive-include ppstructure *.py
recursive-include test1 *.py

@@ -146,23 +146,3 @@ def main():
            logger.info(item['res'])
        save_res(result, save_folder, img_name)
        logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))

if __name__ == '__main__':
    table_engine = PaddleStructure(show_log=True)

    img_path = '../test/test_imgs/PMC1173095_006_00.png'
    img = cv2.imread(img_path)
    result = table_engine(img)
    save_res(result, '/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR/output/table',
             os.path.basename(img_path).split('.')[0])

    for line in result:
        print(line)

    from PIL import Image

    font_path = '../doc/fonts/simfang.ttf'
    image = Image.open(img_path).convert('RGB')
    im_show = draw_result(image, result, font_path=font_path)
    im_show = Image.fromarray(im_show)
    im_show.save('result.jpg')
@@ -20,8 +20,6 @@ import shutil
with open('../requirements.txt', encoding="utf-8-sig") as f:
    requirements = f.readlines()
    requirements.append('tqdm')
    requirements.append('layoutparser')
    requirements.append('iopath')


def readme():
@@ -1,232 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import time
import logging

import paddle
import paddle.inference as paddle_infer

from pathlib import Path

CUR_DIR = os.path.dirname(os.path.abspath(__file__))


class PaddleInferBenchmark(object):
    def __init__(self,
                 config,
                 model_info: dict={},
                 data_info: dict={},
                 perf_info: dict={},
                 resource_info: dict={},
                 save_log_path: str="",
                 **kwargs):
        """
        Construct PaddleInferBenchmark Class to format logs.
        args:
            config(paddle.inference.Config): paddle inference config
            model_info(dict): basic model info
                {'model_name': 'resnet50'
                 'precision': 'fp32'}
            data_info(dict): input data info
                {'batch_size': 1
                 'shape': '3,224,224'
                 'data_num': 1000}
            perf_info(dict): performance result
                {'preprocess_time_s': 1.0
                 'inference_time_s': 2.0
                 'postprocess_time_s': 1.0
                 'total_time_s': 4.0}
            resource_info(dict):
                cpu and gpu resources
                {'cpu_rss': 100
                 'gpu_rss': 100
                 'gpu_util': 60}
        """
        # PaddleInferBenchmark Log Version
        self.log_version = 1.0

        # Paddle Version
        self.paddle_version = paddle.__version__
        self.paddle_commit = paddle.__git_commit__
        paddle_infer_info = paddle_infer.get_version()
        self.paddle_branch = paddle_infer_info.strip().split(': ')[-1]

        # model info
        self.model_info = model_info

        # data info
        self.data_info = data_info

        # perf info
        self.perf_info = perf_info

        try:
            self.model_name = model_info['model_name']
            self.precision = model_info['precision']

            self.batch_size = data_info['batch_size']
            self.shape = data_info['shape']
            self.data_num = data_info['data_num']

            self.preprocess_time_s = round(perf_info['preprocess_time_s'], 4)
            self.inference_time_s = round(perf_info['inference_time_s'], 4)
            self.postprocess_time_s = round(perf_info['postprocess_time_s'], 4)
            self.total_time_s = round(perf_info['total_time_s'], 4)
        except:
            self.print_help()
            raise ValueError(
                "Set argument wrong, please check input argument and its type")

        # conf info
        self.config_status = self.parse_config(config)
        self.save_log_path = save_log_path
        # mem info
        if isinstance(resource_info, dict):
            self.cpu_rss_mb = int(resource_info.get('cpu_rss_mb', 0))
            self.gpu_rss_mb = int(resource_info.get('gpu_rss_mb', 0))
            self.gpu_util = round(resource_info.get('gpu_util', 0), 2)
        else:
            self.cpu_rss_mb = 0
            self.gpu_rss_mb = 0
            self.gpu_util = 0

        # init benchmark logger
        self.benchmark_logger()

    def benchmark_logger(self):
        """
        benchmark logger
        """
        # Init logger
        FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        log_output = f"{self.save_log_path}/{self.model_name}.log"
        Path(f"{self.save_log_path}").mkdir(parents=True, exist_ok=True)
        logging.basicConfig(
            level=logging.INFO,
            format=FORMAT,
            handlers=[
                logging.FileHandler(
                    filename=log_output, mode='w'),
                logging.StreamHandler(),
            ])
        self.logger = logging.getLogger(__name__)
        self.logger.info(
            f"Paddle Inference benchmark log will be saved to {log_output}")

    def parse_config(self, config) -> dict:
        """
        parse paddle predictor config
        args:
            config(paddle.inference.Config): paddle inference config
        return:
            config_status(dict): dict style config info
        """
        config_status = {}
        config_status['runtime_device'] = "gpu" if config.use_gpu() else "cpu"
        config_status['ir_optim'] = config.ir_optim()
        config_status['enable_tensorrt'] = config.tensorrt_engine_enabled()
        config_status['precision'] = self.precision
        config_status['enable_mkldnn'] = config.mkldnn_enabled()
        config_status[
            'cpu_math_library_num_threads'] = config.cpu_math_library_num_threads(
            )
        return config_status

    def report(self, identifier=None):
        """
        print log report
        args:
            identifier(string): identify log
        """
        if identifier:
            identifier = f"[{identifier}]"
        else:
            identifier = ""

        self.logger.info("\n")
        self.logger.info(
            "---------------------- Paddle info ----------------------")
        self.logger.info(f"{identifier} paddle_version: {self.paddle_version}")
        self.logger.info(f"{identifier} paddle_commit: {self.paddle_commit}")
        self.logger.info(f"{identifier} paddle_branch: {self.paddle_branch}")
        self.logger.info(f"{identifier} log_api_version: {self.log_version}")
        self.logger.info(
            "----------------------- Conf info -----------------------")
        self.logger.info(
            f"{identifier} runtime_device: {self.config_status['runtime_device']}"
        )
        self.logger.info(
            f"{identifier} ir_optim: {self.config_status['ir_optim']}")
        self.logger.info(f"{identifier} enable_memory_optim: {True}")
        self.logger.info(
            f"{identifier} enable_tensorrt: {self.config_status['enable_tensorrt']}"
        )
        self.logger.info(
            f"{identifier} enable_mkldnn: {self.config_status['enable_mkldnn']}")
        self.logger.info(
            f"{identifier} cpu_math_library_num_threads: {self.config_status['cpu_math_library_num_threads']}"
        )
        self.logger.info(
            "----------------------- Model info ----------------------")
        self.logger.info(f"{identifier} model_name: {self.model_name}")
        self.logger.info(f"{identifier} precision: {self.precision}")
        self.logger.info(
            "----------------------- Data info -----------------------")
        self.logger.info(f"{identifier} batch_size: {self.batch_size}")
        self.logger.info(f"{identifier} input_shape: {self.shape}")
        self.logger.info(f"{identifier} data_num: {self.data_num}")
        self.logger.info(
            "----------------------- Perf info -----------------------")
        self.logger.info(
            f"{identifier} cpu_rss(MB): {self.cpu_rss_mb}, gpu_rss(MB): {self.gpu_rss_mb}, gpu_util: {self.gpu_util}%"
        )
        self.logger.info(
            f"{identifier} total time spent(s): {self.total_time_s}")
        self.logger.info(
            f"{identifier} preprocess_time(ms): {round(self.preprocess_time_s*1000, 1)}, inference_time(ms): {round(self.inference_time_s*1000, 1)}, postprocess_time(ms): {round(self.postprocess_time_s*1000, 1)}"
        )

    def print_help(self):
        """
        print function help
        """
        print("""Usage:
            ==== Print inference benchmark logs. ====
            config = paddle.inference.Config()
            model_info = {'model_name': 'resnet50'
                          'precision': 'fp32'}
            data_info = {'batch_size': 1
                         'shape': '3,224,224'
                         'data_num': 1000}
            perf_info = {'preprocess_time_s': 1.0
                         'inference_time_s': 2.0
                         'postprocess_time_s': 1.0
                         'total_time_s': 4.0}
            resource_info = {'cpu_rss_mb': 100
                             'gpu_rss_mb': 100
                             'gpu_util': 60}
            log = PaddleInferBenchmark(config, model_info, data_info, perf_info, resource_info)
            log('Test')
            """)

    def __call__(self, identifier=None):
        """
        __call__
        args:
            identifier(string): identify log
        """
        self.report(identifier)
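For reference, before this deletion the class was driven as sketched below. This is a minimal sketch reconstructed from the docstring and print_help() above, not a real benchmark run; the timing and resource numbers are placeholders, and save_log_path is pointed at a writable local directory.

    # Sketch of how the now-deleted PaddleInferBenchmark was invoked,
    # pieced together from its docstring and print_help(); all numbers
    # are placeholders, not measurements.
    import paddle.inference as paddle_infer

    config = paddle_infer.Config()  # a default config is enough for parse_config
    model_info = {'model_name': 'resnet50', 'precision': 'fp32'}
    data_info = {'batch_size': 1, 'shape': '3,224,224', 'data_num': 1000}
    perf_info = {
        'preprocess_time_s': 1.0,
        'inference_time_s': 2.0,
        'postprocess_time_s': 1.0,
        'total_time_s': 4.0,
    }
    resource_info = {'cpu_rss_mb': 100, 'gpu_rss_mb': 100, 'gpu_util': 60}

    log = PaddleInferBenchmark(config, model_info, data_info, perf_info,
                               resource_info, save_log_path="./output")
    log('Test')  # __call__ forwards to report(identifier)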
@@ -48,8 +48,6 @@ class TextClassifier(object):
        self.predictor, self.input_tensor, self.output_tensors, _ = \
            utility.create_predictor(args, 'cls', logger)

        self.cls_times = utility.Timer()

    def resize_norm_img(self, img):
        imgC, imgH, imgW = self.cls_image_shape
        h = img.shape[0]

@@ -85,35 +83,28 @@ class TextClassifier(object):
        cls_res = [['', 0.0]] * img_num
        batch_num = self.cls_batch_num
        elapse = 0
        self.cls_times.total_time.start()
        for beg_img_no in range(0, img_num, batch_num):

            end_img_no = min(img_num, beg_img_no + batch_num)
            norm_img_batch = []
            max_wh_ratio = 0
            starttime = time.time()
            for ino in range(beg_img_no, end_img_no):
                h, w = img_list[indices[ino]].shape[0:2]
                wh_ratio = w * 1.0 / h
                max_wh_ratio = max(max_wh_ratio, wh_ratio)
            self.cls_times.preprocess_time.start()
            for ino in range(beg_img_no, end_img_no):
                norm_img = self.resize_norm_img(img_list[indices[ino]])
                norm_img = norm_img[np.newaxis, :]
                norm_img_batch.append(norm_img)
            norm_img_batch = np.concatenate(norm_img_batch)
            norm_img_batch = norm_img_batch.copy()
            starttime = time.time()
            self.cls_times.preprocess_time.end()
            self.cls_times.inference_time.start()

            self.input_tensor.copy_from_cpu(norm_img_batch)
            self.predictor.run()
            prob_out = self.output_tensors[0].copy_to_cpu()
            self.cls_times.inference_time.end()
            self.cls_times.postprocess_time.start()
            self.predictor.try_shrink_memory()
            cls_result = self.postprocess_op(prob_out)
            self.cls_times.postprocess_time.end()
            elapse += time.time() - starttime
            for rno in range(len(cls_result)):
                label, score = cls_result[rno]

@@ -121,9 +112,7 @@ class TextClassifier(object):
                if '180' in label and score > self.cls_thresh:
                    img_list[indices[beg_img_no + rno]] = cv2.rotate(
                        img_list[indices[beg_img_no + rno]], 1)
        self.cls_times.total_time.end()
        self.cls_times.img_num += img_num
        elapse = self.cls_times.total_time.value()
        elapse = time.time() - starttime
        return img_list, cls_res, elapse
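The classifier's only side effect on the downstream pipeline is the in-place 180-degree flip kept above. A self-contained sketch of that step follows; the label, score, threshold, and crop are made-up stand-ins, and rotate code 1 in the diff is cv2.ROTATE_180.

    import cv2
    import numpy as np

    # Hypothetical stand-ins for one classifier result and its text crop.
    label, score, cls_thresh = '180', 0.93, 0.9
    crop = np.zeros((32, 100, 3), dtype=np.uint8)

    if '180' in label and score > cls_thresh:
        # rotate code 1 == cv2.ROTATE_180, matching cv2.rotate(img, 1) above
        crop = cv2.rotate(crop, cv2.ROTATE_180)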
@@ -31,8 +31,6 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process

# import tools.infer.benchmark_utils as benchmark_utils

logger = get_logger()


@@ -100,6 +98,24 @@ class TextDetector(object):
        self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
            args, 'det', logger)

        if args.benchmark:
            import auto_log
            pid = os.getpid()
            self.autolog = auto_log.AutoLogger(
                model_name="det",
                model_precision=args.precision,
                batch_size=1,
                data_shape="dynamic",
                save_path=args.save_log_path,
                inference_config=self.config,
                pids=pid,
                process_name=None,
                gpu_ids=0,
                time_keys=[
                    'preprocess_time', 'inference_time', 'postprocess_time'
                ],
                warmup=10)

    def order_points_clockwise(self, pts):
        """
        reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py

@@ -158,6 +174,10 @@ class TextDetector(object):
        data = {'image': img}

        st = time.time()

        if args.benchmark:
            self.autolog.times.start()

        data = transform(data, self.preprocess_op)
        img, shape_list = data
        if img is None:

@@ -166,12 +186,17 @@ class TextDetector(object):
        shape_list = np.expand_dims(shape_list, axis=0)
        img = img.copy()

        if args.benchmark:
            self.autolog.times.stamp()

        self.input_tensor.copy_from_cpu(img)
        self.predictor.run()
        outputs = []
        for output_tensor in self.output_tensors:
            output = output_tensor.copy_to_cpu()
            outputs.append(output)
        if args.benchmark:
            self.autolog.times.stamp()

        preds = {}
        if self.det_algorithm == "EAST":

@@ -187,7 +212,7 @@ class TextDetector(object):
        else:
            raise NotImplementedError

        self.predictor.try_shrink_memory()
        #self.predictor.try_shrink_memory()
        post_result = self.postprocess_op(preds, shape_list)
        dt_boxes = post_result[0]['points']
        if self.det_algorithm == "SAST" and self.det_sast_polygon:

@@ -195,6 +220,8 @@ class TextDetector(object):
        else:
            dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)

        if args.benchmark:
            self.autolog.times.end(stamp=True)
        et = time.time()
        return dt_boxes, et - st


@@ -212,8 +239,6 @@ if __name__ == "__main__":
        for i in range(10):
            res = text_detector(img)

    cpu_mem, gpu_mem, gpu_util = 0, 0, 0

    if not os.path.exists(draw_img_save):
        os.makedirs(draw_img_save)
    for image_file in image_file_list:

@@ -237,3 +262,6 @@ if __name__ == "__main__":
                                "det_res_{}".format(img_name_pure))
        cv2.imwrite(img_path, src_im)
        logger.info("The visualized image saved in {}".format(img_path))

    if args.benchmark:
        text_detector.autolog.report()
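The detector now delegates timing to the auto_log package: times.start() opens one end-to-end record, each stamp() closes a stage, end(stamp=True) closes the final stage, and report() prints the summary. Below is a minimal sketch of that lifecycle assuming auto_log is installed; it only reuses the calls that appear in this diff, with sleeps standing in for the pipeline stages, and in real use inference_config receives the predictor's paddle.inference.Config rather than None.

    import os
    import time
    import auto_log

    autolog = auto_log.AutoLogger(
        model_name="det",
        model_precision="fp32",
        batch_size=1,
        data_shape="dynamic",
        save_path="./output/auto_log.log",
        inference_config=None,  # pass the predictor config in real use
        pids=os.getpid(),
        process_name=None,
        gpu_ids=0,
        time_keys=['preprocess_time', 'inference_time', 'postprocess_time'],
        warmup=10)

    for _ in range(2):
        autolog.times.start()          # begin one end-to-end record
        time.sleep(0.01)               # preprocess
        autolog.times.stamp()          # close the preprocess stage
        time.sleep(0.02)               # predictor.run()
        autolog.times.stamp()          # close the inference stage
        time.sleep(0.01)               # postprocess
        autolog.times.end(stamp=True)  # close the last stage and the record

    autolog.report()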
@@ -28,7 +28,6 @@ import traceback
import paddle

import tools.infer.utility as utility
import tools.infer.benchmark_utils as benchmark_utils
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif

@@ -66,8 +65,6 @@ class TextRecognizer(object):
        self.predictor, self.input_tensor, self.output_tensors, self.config = \
            utility.create_predictor(args, 'rec', logger)

        self.rec_times = utility.Timer()

    def resize_norm_img(self, img, max_wh_ratio):
        imgC, imgH, imgW = self.rec_image_shape
        assert imgC == img.shape[2]

@@ -168,14 +165,13 @@ class TextRecognizer(object):
            width_list.append(img.shape[1] / float(img.shape[0]))
        # Sorting can speed up the recognition process
        indices = np.argsort(np.array(width_list))
        self.rec_times.total_time.start()
        rec_res = [['', 0.0]] * img_num
        batch_num = self.rec_batch_num
        st = time.time()
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            norm_img_batch = []
            max_wh_ratio = 0
            self.rec_times.preprocess_time.start()
            for ino in range(beg_img_no, end_img_no):
                h, w = img_list[indices[ino]].shape[0:2]
                wh_ratio = w * 1.0 / h

@@ -216,23 +212,18 @@ class TextRecognizer(object):
                    gsrm_slf_attn_bias1_list,
                    gsrm_slf_attn_bias2_list,
                ]
                self.rec_times.preprocess_time.end()
                self.rec_times.inference_time.start()
                input_names = self.predictor.get_input_names()
                for i in range(len(input_names)):
                    input_tensor = self.predictor.get_input_handle(input_names[
                        i])
                    input_tensor.copy_from_cpu(inputs[i])
                self.predictor.run()
                self.rec_times.inference_time.end()
                outputs = []
                for output_tensor in self.output_tensors:
                    output = output_tensor.copy_to_cpu()
                    outputs.append(output)
                preds = {"predict": outputs[2]}
            else:
                self.rec_times.preprocess_time.end()
                self.rec_times.inference_time.start()
                self.input_tensor.copy_from_cpu(norm_img_batch)
                self.predictor.run()


@@ -241,15 +232,11 @@ class TextRecognizer(object):
                    output = output_tensor.copy_to_cpu()
                    outputs.append(output)
                preds = outputs[0]
            self.rec_times.inference_time.end()
            self.rec_times.postprocess_time.start()
            rec_result = self.postprocess_op(preds)
            for rno in range(len(rec_result)):
                rec_res[indices[beg_img_no + rno]] = rec_result[rno]
            self.rec_times.postprocess_time.end()
            self.rec_times.img_num += int(norm_img_batch.shape[0])
        self.rec_times.total_time.end()
        return rec_res, self.rec_times.total_time.value()

        return rec_res, time.time() - st


def main(args):

@@ -278,12 +265,6 @@ def main(args):
        img_list.append(img)
    try:
        rec_res, _ = text_recognizer(img_list)
        if args.benchmark:
            cm, gm, gu = utility.get_current_memory_mb(0)
            cpu_mem += cm
            gpu_mem += gm
            gpu_util += gu
            count += 1

    except Exception as E:
        logger.info(traceback.format_exc())

@@ -292,38 +273,6 @@ def main(args):
    for ino in range(len(img_list)):
        logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
                                               rec_res[ino]))
    if args.benchmark:
        mems = {
            'cpu_rss_mb': cpu_mem / count,
            'gpu_rss_mb': gpu_mem / count,
            'gpu_util': gpu_util * 100 / count
        }
    else:
        mems = None
    logger.info("The predict time about recognizer module is as follows: ")
    rec_time_dict = text_recognizer.rec_times.report(average=True)
    rec_model_name = args.rec_model_dir

    if args.benchmark:
        # construct log information
        model_info = {
            'model_name': args.rec_model_dir.split('/')[-1],
            'precision': args.precision
        }
        data_info = {
            'batch_size': args.rec_batch_num,
            'shape': 'dynamic_shape',
            'data_num': rec_time_dict['img_num']
        }
        perf_info = {
            'preprocess_time_s': rec_time_dict['preprocess_time'],
            'inference_time_s': rec_time_dict['inference_time'],
            'postprocess_time_s': rec_time_dict['postprocess_time'],
            'total_time_s': rec_time_dict['total_time']
        }
        benchmark_log = benchmark_utils.PaddleInferBenchmark(
            text_recognizer.config, model_info, data_info, perf_info, mems)
        benchmark_log("Rec")


if __name__ == "__main__":
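As the "Sorting can speed up the recognition process" comment above indicates, crops are processed in aspect-ratio order so that each batch shares a similar width and wastes less padding, and the argsort indices map each result back to its original slot. A compact sketch of that index bookkeeping with made-up crop shapes:

    import numpy as np

    # Hypothetical crop sizes (h, w); the w/h ratios drive batch grouping.
    shapes = [(32, 320), (32, 64), (32, 128), (32, 640)]
    width_list = [w / float(h) for h, w in shapes]
    indices = np.argsort(np.array(width_list))  # process narrow crops first

    batch_num = 2
    rec_res = [None] * len(shapes)
    for beg in range(0, len(shapes), batch_num):
        end = min(len(shapes), beg + batch_num)
        batch = [shapes[indices[i]] for i in range(beg, end)]
        # a real run would resize to the batch's max ratio and call the
        # predictor; here the "result" is just the shape itself
        for rno, item in enumerate(batch):
            rec_res[indices[beg + rno]] = item  # restore original order

    assert rec_res == shapes  # order recovered despite sorted processing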
@@ -33,8 +33,7 @@ import tools.infer.predict_det as predict_det
import tools.infer.predict_cls as predict_cls
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from tools.infer.utility import draw_ocr_box_txt, get_current_memory_mb
import tools.infer.benchmark_utils as benchmark_utils
from tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image
logger = get_logger()


@@ -50,39 +49,6 @@ class TextSystem(object):
        if self.use_angle_cls:
            self.text_classifier = predict_cls.TextClassifier(args)

    def get_rotate_crop_image(self, img, points):
        '''
        img_height, img_width = img.shape[0:2]
        left = int(np.min(points[:, 0]))
        right = int(np.max(points[:, 0]))
        top = int(np.min(points[:, 1]))
        bottom = int(np.max(points[:, 1]))
        img_crop = img[top:bottom, left:right, :].copy()
        points[:, 0] = points[:, 0] - left
        points[:, 1] = points[:, 1] - top
        '''
        img_crop_width = int(
            max(
                np.linalg.norm(points[0] - points[1]),
                np.linalg.norm(points[2] - points[3])))
        img_crop_height = int(
            max(
                np.linalg.norm(points[0] - points[3]),
                np.linalg.norm(points[1] - points[2])))
        pts_std = np.float32([[0, 0], [img_crop_width, 0],
                              [img_crop_width, img_crop_height],
                              [0, img_crop_height]])
        M = cv2.getPerspectiveTransform(points, pts_std)
        dst_img = cv2.warpPerspective(
            img,
            M, (img_crop_width, img_crop_height),
            borderMode=cv2.BORDER_REPLICATE,
            flags=cv2.INTER_CUBIC)
        dst_img_height, dst_img_width = dst_img.shape[0:2]
        if dst_img_height * 1.0 / dst_img_width >= 1.5:
            dst_img = np.rot90(dst_img)
        return dst_img

    def print_draw_crop_rec_res(self, img_crop_list, rec_res):
        bbox_num = len(img_crop_list)
        for bno in range(bbox_num):

@@ -103,7 +69,7 @@ class TextSystem(object):

        for bno in range(len(dt_boxes)):
            tmp_box = copy.deepcopy(dt_boxes[bno])
            img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
            img_crop = get_rotate_crop_image(ori_im, tmp_box)
            img_crop_list.append(img_crop)
        if self.use_angle_cls and cls:
            img_crop_list, angle_list, elapse = self.text_classifier(

@@ -158,7 +124,7 @@ def main(args):
        img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
        for i in range(10):
            res = text_sys(img)


    total_time = 0
    cpu_mem, gpu_mem, gpu_util = 0, 0, 0
    _st = time.time()

@@ -175,12 +141,6 @@ def main(args):
        dt_boxes, rec_res = text_sys(img)
        elapse = time.time() - starttime
        total_time += elapse
        if args.benchmark and idx % 20 == 0:
            cm, gm, gu = get_current_memory_mb(0)
            cpu_mem += cm
            gpu_mem += gm
            gpu_util += gu
            count += 1

        logger.info(
            str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse))

@@ -215,61 +175,6 @@ def main(args):
    logger.info("\nThe predict total time is {}".format(total_time))

    img_num = text_sys.text_detector.det_times.img_num
    if args.benchmark:
        mems = {
            'cpu_rss_mb': cpu_mem / count,
            'gpu_rss_mb': gpu_mem / count,
            'gpu_util': gpu_util * 100 / count
        }
    else:
        mems = None
    det_time_dict = text_sys.text_detector.det_times.report(average=True)
    rec_time_dict = text_sys.text_recognizer.rec_times.report(average=True)
    det_model_name = args.det_model_dir
    rec_model_name = args.rec_model_dir

    # construct det log information
    model_info = {
        'model_name': args.det_model_dir.split('/')[-1],
        'precision': args.precision
    }
    data_info = {
        'batch_size': 1,
        'shape': 'dynamic_shape',
        'data_num': det_time_dict['img_num']
    }
    perf_info = {
        'preprocess_time_s': det_time_dict['preprocess_time'],
        'inference_time_s': det_time_dict['inference_time'],
        'postprocess_time_s': det_time_dict['postprocess_time'],
        'total_time_s': det_time_dict['total_time']
    }

    benchmark_log = benchmark_utils.PaddleInferBenchmark(
        text_sys.text_detector.config, model_info, data_info, perf_info, mems,
        args.save_log_path)
    benchmark_log("Det")

    # construct rec log information
    model_info = {
        'model_name': args.rec_model_dir.split('/')[-1],
        'precision': args.precision
    }
    data_info = {
        'batch_size': args.rec_batch_num,
        'shape': 'dynamic_shape',
        'data_num': rec_time_dict['img_num']
    }
    perf_info = {
        'preprocess_time_s': rec_time_dict['preprocess_time'],
        'inference_time_s': rec_time_dict['inference_time'],
        'postprocess_time_s': rec_time_dict['postprocess_time'],
        'total_time_s': rec_time_dict['total_time']
    }
    benchmark_log = benchmark_utils.PaddleInferBenchmark(
        text_sys.text_recognizer.config, model_info, data_info, perf_info, mems,
        args.save_log_path)
    benchmark_log("Rec")


if __name__ == "__main__":
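get_rotate_crop_image, now a module-level helper in tools.infer.utility as imported above, warps one detected quadrilateral to an axis-aligned patch and flips tall patches upright before recognition. A small usage sketch with a synthetic image and a made-up box:

    import numpy as np
    from tools.infer.utility import get_rotate_crop_image

    # Synthetic white image and one detected quadrilateral, clockwise from
    # top-left; the helper expects a float32 4x2 array.
    img = np.full((200, 300, 3), 255, dtype=np.uint8)
    box = np.array([[40, 50], [260, 60], [255, 120], [35, 110]],
                   dtype=np.float32)

    crop = get_rotate_crop_image(img, box)
    # The quad is warped flat; patches with h/w >= 1.5 come back rotated
    # 90 degrees so the recognizer always sees horizontal text.
    print(crop.shape)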
@@ -37,6 +37,7 @@ def init_args():
    parser.add_argument("--use_gpu", type=str2bool, default=True)
    parser.add_argument("--ir_optim", type=str2bool, default=True)
    parser.add_argument("--use_tensorrt", type=str2bool, default=False)
    parser.add_argument("--min_subgraph_size", type=int, default=3)
    parser.add_argument("--precision", type=str, default="fp32")
    parser.add_argument("--gpu_mem", type=int, default=500)

@@ -124,76 +125,6 @@ def parse_args():
    return parser.parse_args()


class Times(object):
    def __init__(self):
        self.time = 0.
        self.st = 0.
        self.et = 0.

    def start(self):
        self.st = time.time()

    def end(self, accumulative=True):
        self.et = time.time()
        if accumulative:
            self.time += self.et - self.st
        else:
            self.time = self.et - self.st

    def reset(self):
        self.time = 0.
        self.st = 0.
        self.et = 0.

    def value(self):
        return round(self.time, 4)


class Timer(Times):
    def __init__(self):
        super(Timer, self).__init__()
        self.total_time = Times()
        self.preprocess_time = Times()
        self.inference_time = Times()
        self.postprocess_time = Times()
        self.img_num = 0

    def info(self, average=False):
        logger.info("----------------------- Perf info -----------------------")
        logger.info("total_time: {}, img_num: {}".format(self.total_time.value(
        ), self.img_num))
        preprocess_time = round(self.preprocess_time.value() / self.img_num,
                                4) if average else self.preprocess_time.value()
        postprocess_time = round(
            self.postprocess_time.value() / self.img_num,
            4) if average else self.postprocess_time.value()
        inference_time = round(self.inference_time.value() / self.img_num,
                               4) if average else self.inference_time.value()

        average_latency = self.total_time.value() / self.img_num
        logger.info("average_latency(ms): {:.2f}, QPS: {:2f}".format(
            average_latency * 1000, 1 / average_latency))
        logger.info(
            "preprocess_latency(ms): {:.2f}, inference_latency(ms): {:.2f}, postprocess_latency(ms): {:.2f}".
            format(preprocess_time * 1000, inference_time * 1000,
                   postprocess_time * 1000))

    def report(self, average=False):
        dic = {}
        dic['preprocess_time'] = round(
            self.preprocess_time.value() / self.img_num,
            4) if average else self.preprocess_time.value()
        dic['postprocess_time'] = round(
            self.postprocess_time.value() / self.img_num,
            4) if average else self.postprocess_time.value()
        dic['inference_time'] = round(
            self.inference_time.value() / self.img_num,
            4) if average else self.inference_time.value()
        dic['img_num'] = self.img_num
        dic['total_time'] = round(self.total_time.value(), 4)
        return dic


def create_predictor(args, mode, logger):
    if mode == "det":
        model_dir = args.det_model_dir
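For reference, the removed Timer aggregated one Times object per stage, and its report() produced the dict that PaddleInferBenchmark consumed in the mains above. A minimal sketch of that flow, assuming the Times/Timer definitions shown above are available; sleeps stand in for real work:

    import time

    timer = Timer()
    timer.total_time.start()
    timer.preprocess_time.start(); time.sleep(0.01); timer.preprocess_time.end()
    timer.inference_time.start(); time.sleep(0.02); timer.inference_time.end()
    timer.postprocess_time.start(); time.sleep(0.01); timer.postprocess_time.end()
    timer.img_num += 1
    timer.total_time.end()

    print(timer.report(average=True))
    # -> {'preprocess_time': ..., 'postprocess_time': ...,
    #     'inference_time': ..., 'img_num': 1, 'total_time': ...}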
@@ -212,11 +143,10 @@ def create_predictor(args, mode, logger):
    model_file_path = model_dir + "/inference.pdmodel"
    params_file_path = model_dir + "/inference.pdiparams"
    if not os.path.exists(model_file_path):
        logger.info("not find model file path {}".format(model_file_path))
        sys.exit(0)
        raise ValueError("not find model file path {}".format(model_file_path))
    if not os.path.exists(params_file_path):
        logger.info("not find params file path {}".format(params_file_path))
        sys.exit(0)
        raise ValueError("not find params file path {}".format(
            params_file_path))

    config = inference.Config(model_file_path, params_file_path)


@@ -236,14 +166,17 @@ def create_predictor(args, mode, logger):
        config.enable_tensorrt_engine(
            precision_mode=inference.PrecisionType.Float32,
            max_batch_size=args.max_batch_size,
            min_subgraph_size=3)  # skip the minimum trt subgraph
        if mode == "det" and "mobile" in model_file_path:
            min_subgraph_size=args.min_subgraph_size)
        # skip the minimum trt subgraph
        if mode == "det":
            min_input_shape = {
                "x": [1, 3, 50, 50],
                "conv2d_92.tmp_0": [1, 96, 20, 20],
                "conv2d_91.tmp_0": [1, 96, 10, 10],
                "conv2d_59.tmp_0": [1, 96, 20, 20],
                "nearest_interp_v2_1.tmp_0": [1, 96, 10, 10],
                "nearest_interp_v2_2.tmp_0": [1, 96, 20, 20],
                "conv2d_124.tmp_0": [1, 96, 20, 20],
                "nearest_interp_v2_3.tmp_0": [1, 24, 20, 20],
                "nearest_interp_v2_4.tmp_0": [1, 24, 20, 20],
                "nearest_interp_v2_5.tmp_0": [1, 24, 20, 20],

@@ -254,7 +187,9 @@ def create_predictor(args, mode, logger):
                "x": [1, 3, 2000, 2000],
                "conv2d_92.tmp_0": [1, 96, 400, 400],
                "conv2d_91.tmp_0": [1, 96, 200, 200],
                "conv2d_59.tmp_0": [1, 96, 400, 400],
                "nearest_interp_v2_1.tmp_0": [1, 96, 200, 200],
                "conv2d_124.tmp_0": [1, 256, 400, 400],
                "nearest_interp_v2_2.tmp_0": [1, 96, 400, 400],
                "nearest_interp_v2_3.tmp_0": [1, 24, 400, 400],
                "nearest_interp_v2_4.tmp_0": [1, 24, 400, 400],

@@ -266,39 +201,16 @@ def create_predictor(args, mode, logger):
                "x": [1, 3, 640, 640],
                "conv2d_92.tmp_0": [1, 96, 160, 160],
                "conv2d_91.tmp_0": [1, 96, 80, 80],
                "conv2d_59.tmp_0": [1, 96, 160, 160],
                "nearest_interp_v2_1.tmp_0": [1, 96, 80, 80],
                "nearest_interp_v2_2.tmp_0": [1, 96, 160, 160],
                "conv2d_124.tmp_0": [1, 256, 160, 160],
                "nearest_interp_v2_3.tmp_0": [1, 24, 160, 160],
                "nearest_interp_v2_4.tmp_0": [1, 24, 160, 160],
                "nearest_interp_v2_5.tmp_0": [1, 24, 160, 160],
                "elementwise_add_7": [1, 56, 40, 40],
                "nearest_interp_v2_0.tmp_0": [1, 96, 40, 40]
            }
        if mode == "det" and "server" in model_file_path:
            min_input_shape = {
                "x": [1, 3, 50, 50],
                "conv2d_59.tmp_0": [1, 96, 20, 20],
                "nearest_interp_v2_2.tmp_0": [1, 96, 20, 20],
                "nearest_interp_v2_3.tmp_0": [1, 24, 20, 20],
                "nearest_interp_v2_4.tmp_0": [1, 24, 20, 20],
                "nearest_interp_v2_5.tmp_0": [1, 24, 20, 20]
            }
            max_input_shape = {
                "x": [1, 3, 2000, 2000],
                "conv2d_59.tmp_0": [1, 96, 400, 400],
                "nearest_interp_v2_2.tmp_0": [1, 96, 400, 400],
                "nearest_interp_v2_3.tmp_0": [1, 24, 400, 400],
                "nearest_interp_v2_4.tmp_0": [1, 24, 400, 400],
                "nearest_interp_v2_5.tmp_0": [1, 24, 400, 400]
            }
            opt_input_shape = {
                "x": [1, 3, 640, 640],
                "conv2d_59.tmp_0": [1, 96, 160, 160],
                "nearest_interp_v2_2.tmp_0": [1, 96, 160, 160],
                "nearest_interp_v2_3.tmp_0": [1, 24, 160, 160],
                "nearest_interp_v2_4.tmp_0": [1, 24, 160, 160],
                "nearest_interp_v2_5.tmp_0": [1, 24, 160, 160]
            }
        elif mode == "rec":
            min_input_shape = {"x": [args.rec_batch_num, 3, 32, 10]}
            max_input_shape = {"x": [args.rec_batch_num, 3, 32, 2000]}
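The min/max/opt dicts above define the dynamic-shape ranges for the TensorRT subgraphs; the surrounding code, outside these hunks, presumably registers them on the config. A hedged sketch of that registration for the rec branch via set_trt_dynamic_shape_info, the Paddle Inference call for TRT dynamic shapes; the model paths and the opt width of 320 are assumptions, not shown in this diff:

    from paddle import inference

    rec_batch_num = 6
    config = inference.Config("inference.pdmodel", "inference.pdiparams")
    config.enable_use_gpu(500, 0)  # gpu_mem in MB, device id
    config.enable_tensorrt_engine(
        precision_mode=inference.PrecisionType.Float32,
        max_batch_size=10,
        min_subgraph_size=3)
    min_input_shape = {"x": [rec_batch_num, 3, 32, 10]}
    max_input_shape = {"x": [rec_batch_num, 3, 32, 2000]}
    opt_input_shape = {"x": [rec_batch_num, 3, 32, 320]}  # assumed opt width
    config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
                                      opt_input_shape)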
@@ -328,11 +240,11 @@ def create_predictor(args, mode, logger):

    # enable memory optim
    config.enable_memory_optim()
    config.disable_glog_info()
    #config.disable_glog_info()

    config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
    if mode == 'table':
        config.delete_pass("fc_fuse_pass")  # not supported for table
        config.delete_pass("fc_fuse_pass")  # not supported for table
    config.switch_use_feed_fetch_ops(False)
    config.switch_ir_optim(True)


@@ -597,29 +509,39 @@ def draw_boxes(image, boxes, scores=None, drop_score=0.5):
    return image


def get_current_memory_mb(gpu_id=None):
    """
    Obtain the CPU and GPU memory usage of the current process.
    Note that this function is time-consuming.
    """
    import pynvml
    import psutil
    import GPUtil
    pid = os.getpid()
    p = psutil.Process(pid)
    info = p.memory_full_info()
    cpu_mem = info.uss / 1024. / 1024.
    gpu_mem = 0
    gpu_percent = 0
    if gpu_id is not None:
        GPUs = GPUtil.getGPUs()
        gpu_load = GPUs[gpu_id].load
        gpu_percent = gpu_load
        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
        gpu_mem = meminfo.used / 1024. / 1024.
    return round(cpu_mem, 4), round(gpu_mem, 4), round(gpu_percent, 4)
def get_rotate_crop_image(img, points):
    '''
    img_height, img_width = img.shape[0:2]
    left = int(np.min(points[:, 0]))
    right = int(np.max(points[:, 0]))
    top = int(np.min(points[:, 1]))
    bottom = int(np.max(points[:, 1]))
    img_crop = img[top:bottom, left:right, :].copy()
    points[:, 0] = points[:, 0] - left
    points[:, 1] = points[:, 1] - top
    '''
    assert len(points) == 4, "shape of points must be 4*2"
    img_crop_width = int(
        max(
            np.linalg.norm(points[0] - points[1]),
            np.linalg.norm(points[2] - points[3])))
    img_crop_height = int(
        max(
            np.linalg.norm(points[0] - points[3]),
            np.linalg.norm(points[1] - points[2])))
    pts_std = np.float32([[0, 0], [img_crop_width, 0],
                          [img_crop_width, img_crop_height],
                          [0, img_crop_height]])
    M = cv2.getPerspectiveTransform(points, pts_std)
    dst_img = cv2.warpPerspective(
        img,
        M, (img_crop_width, img_crop_height),
        borderMode=cv2.BORDER_REPLICATE,
        flags=cv2.INTER_CUBIC)
    dst_img_height, dst_img_width = dst_img.shape[0:2]
    if dst_img_height * 1.0 / dst_img_width >= 1.5:
        dst_img = np.rot90(dst_img)
    return dst_img


if __name__ == '__main__':
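The removed get_current_memory_mb was sampled periodically and averaged into the mems dict, as in the predict_rec and predict_system mains above. A sketch of that sampling loop, assuming the function definition above, its psutil/pynvml/GPUtil dependencies, and a visible GPU 0:

    # Sampling pattern the mains above used with get_current_memory_mb.
    cpu_mem = gpu_mem = gpu_util = 0.0
    count = 0
    for idx in range(100):
        # ... run one inference here ...
        if idx % 20 == 0:  # sample every 20 images, as in predict_system
            cm, gm, gu = get_current_memory_mb(0)
            cpu_mem += cm
            gpu_mem += gm
            gpu_util += gu
            count += 1

    mems = {
        'cpu_rss_mb': cpu_mem / count,
        'gpu_rss_mb': gpu_mem / count,
        'gpu_util': gpu_util * 100 / count,
    }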
@@ -35,7 +35,7 @@ from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import init_model, load_dygraph_params
import tools.program as program

dist.get_world_size()

@@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer):
    # build metric
    eval_class = build_metric(config['Metric'])
    # load pretrain model
    pre_best_model_dict = init_model(config, model, optimizer)
    pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)

    logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
    if valid_dataloader is not None: