diff --git a/README.md b/README.md index 4814fbb9..001251e4 100644 --- a/README.md +++ b/README.md @@ -4,12 +4,11 @@ English | [简体中文](README_cn.md) PaddleOCR aims to create rich, leading, and practical OCR tools that help users train better models and apply them into practice. **Recent updates** +- 2020.8.16 Release text detection algorithm [SAST](https://arxiv.org/abs/1908.05498) and text recognition algorithm [SRN](https://arxiv.org/abs/2003.12294) - 2020.7.23, Release the playback and PPT of live class on BiliBili station, PaddleOCR Introduction, [address](https://aistudio.baidu.com/aistudio/course/introduce/1519) - 2020.7.15, Add mobile App demo , support both iOS and Android ( based on easyedge and Paddle Lite) - 2020.7.15, Improve the deployment ability, add the C + + inference , serving deployment. In addtion, the benchmarks of the ultra-lightweight OCR model are provided. - 2020.7.15, Add several related datasets, data annotation and synthesis tools. -- 2020.7.9 Add a new model to support recognize the character "space". -- 2020.7.9 Add the data augument and learning rate decay strategies during training. - [more](./doc/doc_en/update_en.md) ## Features @@ -91,7 +90,7 @@ Mobile DEMO experience (based on EasyEdge and Paddle-Lite, supports iOS and Andr PaddleOCR open source text detection algorithms list: - [x] EAST([paper](https://arxiv.org/abs/1704.03155)) - [x] DB([paper](https://arxiv.org/abs/1911.08947)) -- [ ] SAST([paper](https://arxiv.org/abs/1908.05498))(Baidu Self-Research, comming soon) +- [x] SAST([paper](https://arxiv.org/abs/1908.05498))(Baidu Self-Research) On the ICDAR2015 dataset, the text detection result is as follows: @@ -101,6 +100,7 @@ On the ICDAR2015 dataset, the text detection result is as follows: |EAST|MobileNetV3|81.67%|79.83%|80.74%|[Download link](https://paddleocr.bj.bcebos.com/det_mv3_east.tar)| |DB|ResNet50_vd|83.79%|80.65%|82.19%|[Download link](https://paddleocr.bj.bcebos.com/det_r50_vd_db.tar)| |DB|MobileNetV3|75.92%|73.18%|74.53%|[Download link](https://paddleocr.bj.bcebos.com/det_mv3_db.tar)| +|SAST|ResNet50_vd|92.18%|82.96%|87.33%|[Download link](https://paddleocr.bj.bcebos.com/SAST/sast_r50_vd_icdar2015.tar)| For use of [LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/datasets_en.md#1-icdar2019-lsvt) street view dataset with a total of 3w training data,the related configuration and pre-trained models for text detection task are as follows: |Model|Backbone|Configuration file|Pre-trained model| @@ -120,7 +120,7 @@ PaddleOCR open-source text recognition algorithms list: - [x] Rosetta([paper](https://arxiv.org/abs/1910.05085)) - [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html)) - [x] RARE([paper](https://arxiv.org/abs/1603.03915v1)) -- [ ] SRN([paper](https://arxiv.org/abs/2003.12294))(Baidu Self-Research, comming soon) +- [x] SRN([paper](https://arxiv.org/abs/2003.12294))(Baidu Self-Research) Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow: @@ -134,8 +134,14 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r |STAR-Net|MobileNetV3|81.56%|rec_mv3_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_ctc.tar)| |RARE|Resnet34_vd|84.90%|rec_r34_vd_tps_bilstm_attn|[Download link](https://paddleocr.bj.bcebos.com/rec_r34_vd_tps_bilstm_attn.tar)| |RARE|MobileNetV3|83.32%|rec_mv3_tps_bilstm_attn|[Download link](https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_attn.tar)| +|SRN|Resnet50_vd_fpn|88.33%|rec_r50fpn_vd_none_srn|[Download link](https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar)| + +**Note:** SRN model uses data expansion method to expand the two training sets mentioned above, and the expanded data can be downloaded from [Baidu Drive](todo). + +The average accuracy of the two-stage training in the original paper is 89.74%, and that of one stage training in paddleocr is 88.33%. Both pre-trained weights can be downloaded [here](https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar). We use [LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/datasets_en.md#1-icdar2019-lsvt) dataset and cropout 30w traning data from original photos by using position groundtruth and make some calibration needed. In addition, based on the LSVT corpus, 500w synthetic data is generated to train the model. The related configuration and pre-trained models are as follows: + |Model|Backbone|Configuration file|Pre-trained model| |-|-|-|-| |ultra-lightweight OCR model|MobileNetV3|rec_chinese_lite_train.yml|[Download link](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn.tar)|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar) & [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance.tar)| diff --git a/README_cn.md b/README_cn.md index ebfc4b1d..45627083 100644 --- a/README_cn.md +++ b/README_cn.md @@ -4,12 +4,11 @@ PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力使用者训练出更好的模型,并应用落地。 **近期更新** +- 2020.8.16 开源文本检测算法[SAST](https://arxiv.org/abs/1908.05498)和文本识别算法[SRN](https://arxiv.org/abs/2003.12294) - 2020.7.23 发布7月21日B站直播课回放和PPT,PaddleOCR开源大礼包全面解读,[获取地址](https://aistudio.baidu.com/aistudio/course/introduce/1519) - 2020.7.15 添加基于EasyEdge和Paddle-Lite的移动端DEMO,支持iOS和Android系统 - 2020.7.15 完善预测部署,添加基于C++预测引擎推理、服务化部署和端侧部署方案,以及超轻量级中文OCR模型预测耗时Benchmark - 2020.7.15 整理OCR相关数据集、常用数据标注以及合成工具 -- 2020.7.9 添加支持空格的识别模型,识别效果,预测及训练方式请参考快速开始和文本识别训练相关文档 -- 2020.7.9 添加数据增强、学习率衰减策略,具体参考[配置文件](./doc/doc_ch/config.md) - [more](./doc/doc_ch/update.md) @@ -93,7 +92,7 @@ PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力 PaddleOCR开源的文本检测算法列表: - [x] EAST([paper](https://arxiv.org/abs/1704.03155)) - [x] DB([paper](https://arxiv.org/abs/1911.08947)) -- [ ] SAST([paper](https://arxiv.org/abs/1908.05498))(百度自研, coming soon) +- [x] SAST([paper](https://arxiv.org/abs/1908.05498))(百度自研) 在ICDAR2015文本检测公开数据集上,算法效果如下: @@ -103,8 +102,10 @@ PaddleOCR开源的文本检测算法列表: |EAST|MobileNetV3|81.67%|79.83%|80.74%|[下载链接](https://paddleocr.bj.bcebos.com/det_mv3_east.tar)| |DB|ResNet50_vd|83.79%|80.65%|82.19%|[下载链接](https://paddleocr.bj.bcebos.com/det_r50_vd_db.tar)| |DB|MobileNetV3|75.92%|73.18%|74.53%|[下载链接](https://paddleocr.bj.bcebos.com/det_mv3_db.tar)| +|SAST|ResNet50_vd|92.18%|82.96%|87.33%|[下载链接](https://paddleocr.bj.bcebos.com/SAST/sast_r50_vd_icdar2015.tar)| 使用[LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/datasets.md#1icdar2019-lsvt)街景数据集共3w张数据,训练中文检测模型的相关配置和预训练文件如下: + |模型|骨干网络|配置文件|预训练模型| |-|-|-|-| |超轻量中文模型|MobileNetV3|det_mv3_db.yml|[下载链接](https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db.tar)| @@ -124,9 +125,6 @@ PaddleOCR开源的文本识别算法列表: - [x] RARE([paper](https://arxiv.org/abs/1603.03915v1)) - [x] SRN([paper](https://arxiv.org/abs/2003.12294))(百度自研) -*备注:* SRN模型使用了数据扰动方法对上述提到对两个训练集进行增广,增广后的数据可以在[百度网盘](todo)上下载。 -原始论文使用两阶段训练平均精度为89.74%,PaddleOCR中使用one-stage训练,平均精度为88.33%。两种预训练权重均在[下载链接](https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar)中。 - 参考[DTRB](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下: |模型|骨干网络|Avg Accuracy|模型存储命名|下载链接| @@ -141,6 +139,9 @@ PaddleOCR开源的文本识别算法列表: |RARE|MobileNetV3|83.32%|rec_mv3_tps_bilstm_attn|[下载链接](https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_attn.tar)| |SRN|Resnet50_vd_fpn|88.33%|rec_r50fpn_vd_none_srn|[下载链接](https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar)| +**说明:** SRN模型使用了数据扰动方法对上述提到对两个训练集进行增广,增广后的数据可以在[百度网盘](todo)上下载。 +原始论文使用两阶段训练平均精度为89.74%,PaddleOCR中使用one-stage训练,平均精度为88.33%。两种预训练权重均在[下载链接](https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar)中。 + 使用[LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/datasets.md#1icdar2019-lsvt)街景数据集根据真值将图crop出来30w数据,进行位置校准。此外基于LSVT语料生成500w合成数据训练中文模型,相关配置和预训练文件如下: |模型|骨干网络|配置文件|预训练模型| diff --git a/doc/doc_ch/update.md b/doc/doc_ch/update.md index e22d05b4..1cd77885 100644 --- a/doc/doc_ch/update.md +++ b/doc/doc_ch/update.md @@ -1,4 +1,5 @@ # 更新 +- 2020.8.16 开源文本检测算法[SAST](https://arxiv.org/abs/1908.05498)和文本识别算法[SRN](https://arxiv.org/abs/2003.12294) - 2020.7.23 发布7月21日B站直播课回放和PPT,PaddleOCR开源大礼包全面解读,[获取地址](https://aistudio.baidu.com/aistudio/course/introduce/1519) - 2020.7.15 添加基于EasyEdge和Paddle-Lite的移动端DEMO,支持iOS和Android系统 - 2020.7.15 完善预测部署,添加基于C++预测引擎推理、服务化部署和端侧部署方案,以及超轻量级中文OCR模型预测耗时Benchmark diff --git a/doc/doc_en/update_en.md b/doc/doc_en/update_en.md index ef02d9db..dc839d89 100644 --- a/doc/doc_en/update_en.md +++ b/doc/doc_en/update_en.md @@ -1,4 +1,5 @@ # RECENT UPDATES +- 2020.8.16 Release text detection algorithm [SAST](https://arxiv.org/abs/1908.05498) and text recognition algorithm [SRN](https://arxiv.org/abs/2003.12294) - 2020.7.23, Release the playback and PPT of live class on BiliBili station, PaddleOCR Introduction, [address](https://aistudio.baidu.com/aistudio/course/introduce/1519) - 2020.7.15, Add mobile App demo , support both iOS and Android ( based on easyedge and Paddle Lite) - 2020.7.15, Improve the deployment ability, add the C + + inference , serving deployment. In addtion, the benchmarks of the ultra-lightweight Chinese OCR model are provided. diff --git a/ppocr/data/det/dataset_traversal.py b/ppocr/data/det/dataset_traversal.py index 9ab19175..bd055c82 100644 --- a/ppocr/data/det/dataset_traversal.py +++ b/ppocr/data/det/dataset_traversal.py @@ -73,7 +73,7 @@ class TrainReader(object): data_size_list.append(len(image_files)) fetch_record_list.append(0) - image_batch, poly_batch = [], [] + image_batch = [] # get a batch of img_fns and poly_fns for i in range(0, len(batch_size_list)): bs = batch_size_list[i] diff --git a/ppocr/data/det/sast_process.py b/ppocr/data/det/sast_process.py index 1ce1dc0b..74a84846 100644 --- a/ppocr/data/det/sast_process.py +++ b/ppocr/data/det/sast_process.py @@ -593,38 +593,6 @@ class SASTProcessTrain(object): return np.array(quad_list) - def rotate_im_poly(self, im, text_polys): - """ - rotate image with 90 / 180 / 270 degre - """ - im_w, im_h = im.shape[1], im.shape[0] - dst_im = im.copy() - dst_polys = [] - rand_degree_ratio = np.random.rand() - rand_degree_cnt = 1 - #if rand_degree_ratio > 0.333 and rand_degree_ratio < 0.666: - # rand_degree_cnt = 2 - #elif rand_degree_ratio > 0.666: - if rand_degree_ratio > 0.5: - rand_degree_cnt = 3 - for i in range(rand_degree_cnt): - dst_im = np.rot90(dst_im) - rot_degree = -90 * rand_degree_cnt - rot_angle = rot_degree * math.pi / 180.0 - n_poly = text_polys.shape[0] - cx, cy = 0.5 * im_w, 0.5 * im_h - ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0] - for i in range(n_poly): - wordBB = text_polys[i] - poly = [] - for j in range(4):#16->4 - sx, sy = wordBB[j][0], wordBB[j][1] - dx = math.cos(rot_angle) * (sx - cx) - math.sin(rot_angle) * (sy - cy) + ncx - dy = math.sin(rot_angle) * (sx - cx) + math.cos(rot_angle) * (sy - cy) + ncy - poly.append([dx, dy]) - dst_polys.append(poly) - return dst_im, np.array(dst_polys, dtype=np.float32) - def extract_polys(self, poly_txt_path): """ Read text_polys, txt_tags, txts from give txt file. @@ -653,39 +621,13 @@ class SASTProcessTrain(object): return None if text_polys.shape[0] == 0: return None - # #add rotate cases - # if np.random.rand() < 0.5: - # im, text_polys = self.rotate_im_poly(im, text_polys) + h, w, _ = im.shape - # text_polys, text_tags = self.check_and_validate_polys(text_polys, - # text_tags, h, w) text_polys, text_tags, hv_tags = self.check_and_validate_polys(text_polys, text_tags, (h, w)) if text_polys.shape[0] == 0: return None - # # random scale this image - # rd_scale = np.random.choice(self.random_scale) - # im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale) - # text_polys *= rd_scale - # if np.random.rand() < self.background_ratio: - # outs = self.crop_background_infor(im, text_polys, text_tags, - # text_strs) - # else: - # outs = self.crop_foreground_infor(im, text_polys, text_tags, - # text_strs) - - # if outs is None: - # return None - # im, score_map, geo_map, training_mask = outs - # score_map = score_map[np.newaxis, ::4, ::4].astype(np.float32) - # geo_map = np.swapaxes(geo_map, 1, 2) - # geo_map = np.swapaxes(geo_map, 1, 0) - # geo_map = geo_map[:, ::4, ::4].astype(np.float32) - # training_mask = training_mask[np.newaxis, ::4, ::4] - # training_mask = training_mask.astype(np.float32) - # return im, score_map, geo_map, training_mask - #set aspect ratio and keep area fix asp_scales = np.arange(1.0, 1.55, 0.1) asp_scale = np.random.choice(asp_scales) @@ -781,28 +723,9 @@ class SASTProcessTrain(object): im_padded[:, :, 0] /= (255.0 * 0.225) im_padded = im_padded.transpose((2, 0, 1)) - # images.append(im_padded[::-1, :, :]) - # tcl_maps.append(score_map[np.newaxis, :, :]) - # border_maps.append(border_map.transpose((2, 0, 1))) - # training_masks.append(training_mask[np.newaxis, :, :]) - # tvos.append(tvo_map.transpose((2, 0, 1))) - # tcos.append(tco_map.transpose((2, 0, 1))) - - # # After a batch should begin - # if len(images) == batch_size: - # yield np.array(images, dtype=np.float32), \ - # np.array(tcl_maps, dtype=np.float32), \ - # np.array(tvos, dtype=np.float32), \ - # np.array(tcos, dtype=np.float32), \ - # np.array(border_maps, dtype=np.float32), \ - # np.array(training_masks, dtype=np.float32), \ - - # images, tcl_maps, border_maps, training_masks = [], [], [], [] - # tvos, tcos = [], [] - - # return im_padded, score_map, border_map, training_mask, tvo_map, tco_map return im_padded[::-1, :, :], score_map[np.newaxis, :, :], border_map.transpose((2, 0, 1)), training_mask[np.newaxis, :, :], tvo_map.transpose((2, 0, 1)), tco_map.transpose((2, 0, 1)) + class SASTProcessTest(object): """ SAST process function for test @@ -814,46 +737,6 @@ class SASTProcessTest(object): else: self.max_side_len = 2400 - # def resize_image(self, im): - # """ - # resize image to a size multiple of 32 which is required by the network - # :param im: the resized image - # :param max_side_len: limit of max image size to avoid out of memory in gpu - # :return: the resized image and the resize ratio - # """ - # max_side_len = self.max_side_len - # h, w, _ = im.shape - - # resize_w = w - # resize_h = h - - # # limit the max side - # if max(resize_h, resize_w) > max_side_len: - # if resize_h > resize_w: - # ratio = float(max_side_len) / resize_h - # else: - # ratio = float(max_side_len) / resize_w - # else: - # ratio = 1. - # resize_h = int(resize_h * ratio) - # resize_w = int(resize_w * ratio) - # if resize_h % 32 == 0: - # resize_h = resize_h - # elif resize_h // 32 <= 1: - # resize_h = 32 - # else: - # resize_h = (resize_h // 32 - 1) * 32 - # if resize_w % 32 == 0: - # resize_w = resize_w - # elif resize_w // 32 <= 1: - # resize_w = 32 - # else: - # resize_w = (resize_w // 32 - 1) * 32 - # im = cv2.resize(im, (int(resize_w), int(resize_h))) - # ratio_h = resize_h / float(h) - # ratio_w = resize_w / float(w) - # return im, (ratio_h, ratio_w) - def resize_image(self, im): """ resize image to a size multiple of max_stride which is required by the network diff --git a/ppocr/modeling/architectures/det_model.py b/ppocr/modeling/architectures/det_model.py index d2b82a55..54d3a479 100644 --- a/ppocr/modeling/architectures/det_model.py +++ b/ppocr/modeling/architectures/det_model.py @@ -105,7 +105,6 @@ class DetModel(object): input_mask = fluid.layers.data( name='mask', shape=[1, 128, 128], dtype='float32') input_tvo = fluid.layers.data( - # name='tvo', shape=[5, 128, 128], dtype='float32') name='tvo', shape=[9, 128, 128], dtype='float32') input_tco = fluid.layers.data( name='tco', shape=[3, 128, 128], dtype='float32') diff --git a/ppocr/modeling/heads/det_sast_head.py b/ppocr/modeling/heads/det_sast_head.py index d198d71a..b5e19b84 100644 --- a/ppocr/modeling/heads/det_sast_head.py +++ b/ppocr/modeling/heads/det_sast_head.py @@ -24,7 +24,7 @@ from collections import OrderedDict class SASTHead(object): """ SAST: - see arxiv: https:// + see arxiv: https://arxiv.org/abs/1908.05498 args: params(dict): the super parameters for network build """ @@ -89,7 +89,7 @@ class SASTHead(object): g[i] = fluid.layers.relu(g[i]) g[i] = conv_bn_layer(input=g[i], num_filters=num_outputs[i], filter_size=3, stride=1, act='relu', name='fpn_down_g%d_1'%i) g[i] = conv_bn_layer(input=g[i], num_filters=num_outputs[i+1], filter_size=3, stride=2, act=None, name='fpn_down_g%d_2'%i) - print("g[{}] shape: {}".format(i, g[i].shape)) + # print("g[{}] shape: {}".format(i, g[i].shape)) g[2] = fluid.layers.elementwise_add(x=g[1], y=h[2]) g[2] = fluid.layers.relu(g[2]) g[2] = conv_bn_layer(input=g[2], num_filters=num_outputs[2], @@ -106,14 +106,14 @@ class SASTHead(object): f_score = conv_bn_layer(input=f_score, num_filters=128, filter_size=1, stride=1, act='relu', name='f_score3') f_score = conv_bn_layer(input=f_score, num_filters=1, filter_size=3, stride=1, name='f_score4') f_score = fluid.layers.sigmoid(f_score) - print("f_score shape: {}".format(f_score.shape)) + # print("f_score shape: {}".format(f_score.shape)) #f_boder f_border = conv_bn_layer(input=f_common, num_filters=64, filter_size=1, stride=1, act='relu', name='f_border1') f_border = conv_bn_layer(input=f_border, num_filters=64, filter_size=3, stride=1, act='relu', name='f_border2') f_border = conv_bn_layer(input=f_border, num_filters=128, filter_size=1, stride=1, act='relu', name='f_border3') f_border = conv_bn_layer(input=f_border, num_filters=4, filter_size=3, stride=1, name='f_border4') - print("f_border shape: {}".format(f_border.shape)) + # print("f_border shape: {}".format(f_border.shape)) return f_score, f_border @@ -124,14 +124,14 @@ class SASTHead(object): f_tvo = conv_bn_layer(input=f_tvo, num_filters=64, filter_size=3, stride=1, act='relu', name='f_tvo2') f_tvo = conv_bn_layer(input=f_tvo, num_filters=128, filter_size=1, stride=1, act='relu', name='f_tvo3') f_tvo = conv_bn_layer(input=f_tvo, num_filters=8, filter_size=3, stride=1, name='f_tvo4') - print("f_tvo shape: {}".format(f_tvo.shape)) + # print("f_tvo shape: {}".format(f_tvo.shape)) #f_tco f_tco = conv_bn_layer(input=f_common, num_filters=64, filter_size=1, stride=1, act='relu', name='f_tco1') f_tco = conv_bn_layer(input=f_tco, num_filters=64, filter_size=3, stride=1, act='relu', name='f_tco2') f_tco = conv_bn_layer(input=f_tco, num_filters=128, filter_size=1, stride=1, act='relu', name='f_tco3') f_tco = conv_bn_layer(input=f_tco, num_filters=2, filter_size=3, stride=1, name='f_tco4') - print("f_tco shape: {}".format(f_tco.shape)) + # print("f_tco shape: {}".format(f_tco.shape)) return f_tvo, f_tco @@ -161,7 +161,7 @@ class SASTHead(object): #weighted sum fh_weight = fluid.layers.matmul(fh_attn, fh_g) fh_weight = fluid.layers.reshape(fh_weight, [f_shape[0], f_shape[2], f_shape[3], 128]) - print("fh_weight: {}".format(fh_weight.shape)) + # print("fh_weight: {}".format(fh_weight.shape)) fh_weight = fluid.layers.transpose(fh_weight, [0, 3, 1, 2]) fh_weight = conv_bn_layer(input=fh_weight, num_filters=128, filter_size=1, stride=1, name='fh_weight') #short cut @@ -187,7 +187,7 @@ class SASTHead(object): #weighted sum fv_weight = fluid.layers.matmul(fv_attn, fv_g) fv_weight = fluid.layers.reshape(fv_weight, [f_shape[0], f_shape[3], f_shape[2], 128]) - print("fv_weight: {}".format(fv_weight.shape)) + # print("fv_weight: {}".format(fv_weight.shape)) fv_weight = fluid.layers.transpose(fv_weight, [0, 3, 2, 1]) fv_weight = conv_bn_layer(input=fv_weight, num_filters=128, filter_size=1, stride=1, name='fv_weight') #short cut @@ -199,22 +199,22 @@ class SASTHead(object): return f_attn def __call__(self, blocks, with_cab=False): - for k, v in blocks.items(): - print(k, v.shape) + # for k, v in blocks.items(): + # print(k, v.shape) #down fpn f_down = self.FPN_Down_Fusion(blocks) - print("f_down shape: {}".format(f_down.shape)) + # print("f_down shape: {}".format(f_down.shape)) #up fpn f_up = self.FPN_Up_Fusion(blocks) - print("f_up shape: {}".format(f_up.shape)) + # print("f_up shape: {}".format(f_up.shape)) #fusion f_common = fluid.layers.elementwise_add(x=f_down, y=f_up) f_common = fluid.layers.relu(f_common) - print("f_common: {}".format(f_common.shape)) + # print("f_common: {}".format(f_common.shape)) if self.with_cab: - print('enhence f_common with CAB.') + # print('enhence f_common with CAB.') f_common = self.cross_attention(f_common) f_score, f_border= self.SAST_Header1(f_common)