diff --git a/ppocr/modeling/heads/rec_srn_all_head.py b/ppocr/modeling/heads/rec_srn_all_head.py
index e1bb955d..c2a70115 100755
--- a/ppocr/modeling/heads/rec_srn_all_head.py
+++ b/ppocr/modeling/heads/rec_srn_all_head.py
@@ -28,6 +28,13 @@ gradient_clip = 10
 
 
 class SRNPredict(object):
+    """
+    SRN head:
+    see arXiv: https://arxiv.org/abs/2003.12294
+    args:
+        params(dict): the hyperparameters used to build the network
+    """
+
     def __init__(self, params):
         super(SRNPredict, self).__init__()
         self.char_num = params['char_num']
@@ -39,7 +46,15 @@ class SRNPredict(object):
         self.hidden_dims = params['hidden_dims']
 
     def pvam(self, inputs, others):
+        """
+        Parallel Visual Attention Module (PVAM)
+        args:
+            inputs(Variable): feature map extracted by the backbone network
+            others(list): auxiliary position-encoding variables
+
+        return: pvam_features
+        """
         b, c, h, w = inputs.shape
         conv_features = fluid.layers.reshape(x=inputs, shape=[-1, c, h * w])
         conv_features = fluid.layers.transpose(x=conv_features, perm=[0, 2, 1])
 
@@ -98,6 +113,15 @@ class SRNPredict(object):
         return pvam_features
 
     def gsrm(self, pvam_features, others):
+        """
+        Global Semantic Reasoning Module (GSRM)
+
+        args:
+            pvam_features(Variable): feature map produced by pvam
+            others(list): auxiliary position-encoding variables
+
+        return: gsrm_features, word_out, gsrm_out
+        """
         #===== GSRM Visual-to-semantic embedding block =====
         b, t, c = pvam_features.shape
 
@@ -190,7 +214,15 @@ class SRNPredict(object):
         return gsrm_features, word_out, gsrm_out
 
     def vsfd(self, pvam_features, gsrm_features):
+        """
+        Visual-Semantic Fusion Decoder (VSFD) Module
+        args:
+            pvam_features(Variable): feature map produced by pvam
+            gsrm_features(Variable): feature map produced by gsrm
+
+        return: fc_out
+        """
         #===== Visual-Semantic Fusion Decoder Module =====
         b, t, c1 = pvam_features.shape
         b, t, c2 = gsrm_features.shape
 
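Note for reviewers, not part of the diff: the pvam docstring above documents the paper's parallel glimpse computation, e_{t,i} = w_e^T tanh(W_o o_t + W_v v_i) with a softmax over the h * w visual positions. Below is a minimal NumPy sketch of that step; pvam_sketch and the parameters w_v, w_o, w_e are hypothetical stand-ins for the fluid.layers.fc weights the real graph builds, and the toy shapes are made up for illustration.

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def pvam_sketch(conv_features, order_embedding, w_v, w_o, w_e):
    # conv_features: (b, n, c) backbone features flattened over h * w,
    # matching the reshape/transpose at the top of the real pvam
    # order_embedding: (t, c) learned reading-order embedding, one row
    # per output character slot
    vis = conv_features @ w_v        # (b, n, c)
    order = order_embedding @ w_o    # (t, c)
    # e[b, t, n] = w_e^T tanh(W_o o_t + W_v v_n), computed for all t at once
    scores = np.tanh(vis[:, None, :, :] + order[None, :, None, :]) @ w_e  # (b, t, n)
    alpha = softmax(scores, axis=-1)  # attend over visual positions
    return alpha @ conv_features      # (b, t, c) glimpses, one per slot

# toy usage: batch 2, 8x32 feature map (n = 256), 25 character slots, c = 64
b, n, t, c = 2, 256, 25, 64
rng = np.random.default_rng(0)
glimpses = pvam_sketch(
    rng.standard_normal((b, n, c)), rng.standard_normal((t, c)),
    rng.standard_normal((c, c)), rng.standard_normal((c, c)),
    rng.standard_normal(c))
assert glimpses.shape == (b, t, c)

Because the query is a position embedding rather than a recurrent hidden state, all t glimpses come out in one shot, which is what makes the module "parallel".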
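Similarly for gsrm: per the paper, the visual-to-semantic embedding block decodes a provisional character at every slot (word_out in the return triple) and re-embeds it, and subsequent transformer units reason over that sequence. A sketch of just the embedding block, with w_cls and char_embedding as hypothetical parameters:

import numpy as np

def gsrm_v2s_sketch(pvam_features, w_cls, char_embedding):
    # pvam_features: (b, t, c); w_cls: (c, char_num);
    # char_embedding: (char_num, c)
    logits = pvam_features @ w_cls     # per-slot character scores
    word_ids = logits.argmax(axis=-1)  # (b, t) provisional decode -> word_out
    return char_embedding[word_ids]    # (b, t, c) semantic features fed onward

The reasoning half, the transformer stack that produces gsrm_features and gsrm_out, is omitted here since it would dwarf the sketch.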
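Finally, the vsfd docstring covers the gated fusion from the paper: f_t = sigmoid(W_f [g_t; s_t]) and fused_t = f_t * g_t + (1 - f_t) * s_t, after which (presumably, given the fc_out name and self.char_num) a final fc projects to character classes. A sketch under those assumptions, with w_f and b_f as hypothetical gate parameters:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def vsfd_sketch(pvam_features, gsrm_features, w_f, b_f):
    # pvam_features, gsrm_features: (b, t, c) aligned visual/semantic maps
    # w_f: (2c, c), b_f: (c,) -- hypothetical fc parameters for the gate
    combined = np.concatenate([pvam_features, gsrm_features], axis=-1)  # (b, t, 2c)
    gate = sigmoid(combined @ w_f + b_f)                                # (b, t, c)
    # per-channel choice between the visual and the semantic stream
    return gate * pvam_features + (1.0 - gate) * gsrm_features

The gate lets the decoder lean on semantics where the image is ambiguous (occlusion, blur) and on vision elsewhere, which is the paper's stated motivation for gating rather than simply concatenating the two streams.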