polish srn anno
This commit is contained in:
parent
234bb38c8a
commit
fa12cf0b6d
|
@ -28,6 +28,13 @@ gradient_clip = 10
|
|||
|
||||
|
||||
class SRNPredict(object):
|
||||
"""
|
||||
SRN:
|
||||
see arxiv: https://arxiv.org/abs/2003.12294
|
||||
args:
|
||||
params(dict): the hyperparameters used to build the network
|
||||
"""
|
||||
|
||||
def __init__(self, params):
|
||||
super(SRNPredict, self).__init__()
|
||||
self.char_num = params['char_num']
|
||||
|
@ -39,7 +46,15 @@ class SRNPredict(object):
|
|||
self.hidden_dims = params['hidden_dims']
|
||||
|
||||
def pvam(self, inputs, others):
|
||||
"""
|
||||
Parallel Visual Attention Module
|
||||
|
||||
args:
|
||||
inputs(variable): Feature map extracted from backbone network
|
||||
others(list): Other location information variables
|
||||
|
||||
return: pvam_features
|
||||
"""
|
||||
b, c, h, w = inputs.shape
|
||||
conv_features = fluid.layers.reshape(x=inputs, shape=[-1, c, h * w])
|
||||
conv_features = fluid.layers.transpose(x=conv_features, perm=[0, 2, 1])
|
||||
|
@ -98,6 +113,15 @@ class SRNPredict(object):
|
|||
return pvam_features
|
||||
|
||||
def gsrm(self, pvam_features, others):
|
||||
"""
|
||||
Global Semantic Reasoning Module
|
||||
|
||||
args:
|
||||
pvam_features(variable): Feature map extracted from pvam
|
||||
others(list): Other location information variables
|
||||
|
||||
return: gsrm_features, word_out, gsrm_out
|
||||
"""
|
||||
|
||||
#===== GSRM Visual-to-semantic embedding block =====
|
||||
b, t, c = pvam_features.shape
|
||||
|
@ -190,7 +214,15 @@ class SRNPredict(object):
|
|||
return gsrm_features, word_out, gsrm_out
|
||||
|
||||
def vsfd(self, pvam_features, gsrm_features):
|
||||
"""
|
||||
Visual-Semantic Fusion Decoder Module
|
||||
|
||||
args:
|
||||
pvam_features(variable): Feature map extracted from pvam
|
||||
gsrm_features(variable): Feature map extracted from gsrm
|
||||
|
||||
return: fc_out
|
||||
"""
|
||||
#===== Visual-Semantic Fusion Decoder Module =====
|
||||
b, t, c1 = pvam_features.shape
|
||||
b, t, c2 = gsrm_features.shape
|
||||
|
|
Loading…
Reference in New Issue