polish srn anno
This commit is contained in:
parent
234bb38c8a
commit
fa12cf0b6d
|
@ -28,6 +28,13 @@ gradient_clip = 10
|
||||||
|
|
||||||
|
|
||||||
class SRNPredict(object):
|
class SRNPredict(object):
|
||||||
|
"""
|
||||||
|
SRN:
|
||||||
|
see arxiv: https://arxiv.org/abs/2003.12294
|
||||||
|
args:
|
||||||
|
params(dict): the hyperparameters used to build the network
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, params):
|
def __init__(self, params):
|
||||||
super(SRNPredict, self).__init__()
|
super(SRNPredict, self).__init__()
|
||||||
self.char_num = params['char_num']
|
self.char_num = params['char_num']
|
||||||
|
@ -39,7 +46,15 @@ class SRNPredict(object):
|
||||||
self.hidden_dims = params['hidden_dims']
|
self.hidden_dims = params['hidden_dims']
|
||||||
|
|
||||||
def pvam(self, inputs, others):
|
def pvam(self, inputs, others):
|
||||||
|
"""
|
||||||
|
Parallel visual attention module model
|
||||||
|
|
||||||
|
args:
|
||||||
|
inputs(variable): Feature map extracted from backbone network
|
||||||
|
others(list): Other location information variables
|
||||||
|
|
||||||
|
return: pvam_features
|
||||||
|
"""
|
||||||
b, c, h, w = inputs.shape
|
b, c, h, w = inputs.shape
|
||||||
conv_features = fluid.layers.reshape(x=inputs, shape=[-1, c, h * w])
|
conv_features = fluid.layers.reshape(x=inputs, shape=[-1, c, h * w])
|
||||||
conv_features = fluid.layers.transpose(x=conv_features, perm=[0, 2, 1])
|
conv_features = fluid.layers.transpose(x=conv_features, perm=[0, 2, 1])
|
||||||
|
@ -98,6 +113,15 @@ class SRNPredict(object):
|
||||||
return pvam_features
|
return pvam_features
|
||||||
|
|
||||||
def gsrm(self, pvam_features, others):
|
def gsrm(self, pvam_features, others):
|
||||||
|
"""
|
||||||
|
Global Semantic Reasoning Module
|
||||||
|
|
||||||
|
args:
|
||||||
|
pvam_features(variable): Feature map extracted from pvam
|
||||||
|
others(list): Other location information variables
|
||||||
|
|
||||||
|
return: gsrm_features, word_out, gsrm_out
|
||||||
|
"""
|
||||||
|
|
||||||
#===== GSRM Visual-to-semantic embedding block =====
|
#===== GSRM Visual-to-semantic embedding block =====
|
||||||
b, t, c = pvam_features.shape
|
b, t, c = pvam_features.shape
|
||||||
|
@ -190,7 +214,15 @@ class SRNPredict(object):
|
||||||
return gsrm_features, word_out, gsrm_out
|
return gsrm_features, word_out, gsrm_out
|
||||||
|
|
||||||
def vsfd(self, pvam_features, gsrm_features):
|
def vsfd(self, pvam_features, gsrm_features):
|
||||||
|
"""
|
||||||
|
Visual-Semantic Fusion Decoder Module
|
||||||
|
|
||||||
|
args:
|
||||||
|
pvam_features(variable): Feature map extracted from pvam
|
||||||
|
gsrm_features(variable): Feature map extracted from gsrm
|
||||||
|
|
||||||
|
return: fc_out
|
||||||
|
"""
|
||||||
#===== Visual-Semantic Fusion Decoder Module =====
|
#===== Visual-Semantic Fusion Decoder Module =====
|
||||||
b, t, c1 = pvam_features.shape
|
b, t, c1 = pvam_features.shape
|
||||||
b, t, c2 = gsrm_features.shape
|
b, t, c2 = gsrm_features.shape
|
||||||
|
|
Loading…
Reference in New Issue