add embedding and fix eval when distillation (#3112)
This commit is contained in:
parent
56ab75b0ea
commit
0938c44a0a
|
@ -52,9 +52,10 @@ Architecture:
|
|||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 48
|
||||
hidden_size: 64
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00001
|
||||
Teacher:
|
||||
pretrained:
|
||||
|
@ -71,9 +72,10 @@ Architecture:
|
|||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 48
|
||||
hidden_size: 64
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00001
|
||||
|
||||
|
||||
|
|
|
@ -33,19 +33,47 @@ def get_para_bias_attr(l2_decay, k):
|
|||
|
||||
|
||||
class CTCHead(nn.Layer):
|
||||
def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
fc_decay=0.0004,
|
||||
mid_channels=None,
|
||||
**kwargs):
|
||||
super(CTCHead, self).__init__()
|
||||
weight_attr, bias_attr = get_para_bias_attr(
|
||||
l2_decay=fc_decay, k=in_channels)
|
||||
self.fc = nn.Linear(
|
||||
in_channels,
|
||||
out_channels,
|
||||
weight_attr=weight_attr,
|
||||
bias_attr=bias_attr)
|
||||
if mid_channels is None:
|
||||
weight_attr, bias_attr = get_para_bias_attr(
|
||||
l2_decay=fc_decay, k=in_channels)
|
||||
self.fc = nn.Linear(
|
||||
in_channels,
|
||||
out_channels,
|
||||
weight_attr=weight_attr,
|
||||
bias_attr=bias_attr)
|
||||
else:
|
||||
weight_attr1, bias_attr1 = get_para_bias_attr(
|
||||
l2_decay=fc_decay, k=in_channels)
|
||||
self.fc1 = nn.Linear(
|
||||
in_channels,
|
||||
mid_channels,
|
||||
weight_attr=weight_attr1,
|
||||
bias_attr=bias_attr1)
|
||||
|
||||
weight_attr2, bias_attr2 = get_para_bias_attr(
|
||||
l2_decay=fc_decay, k=mid_channels)
|
||||
self.fc2 = nn.Linear(
|
||||
mid_channels,
|
||||
out_channels,
|
||||
weight_attr=weight_attr2,
|
||||
bias_attr=bias_attr2)
|
||||
self.out_channels = out_channels
|
||||
self.mid_channels = mid_channels
|
||||
|
||||
def forward(self, x, labels=None):
|
||||
predicts = self.fc(x)
|
||||
if self.mid_channels is None:
|
||||
predicts = self.fc(x)
|
||||
else:
|
||||
predicts = self.fc1(x)
|
||||
predicts = self.fc2(predicts)
|
||||
|
||||
if not self.training:
|
||||
predicts = F.softmax(predicts, axis=2)
|
||||
return predicts
|
||||
|
|
|
@ -44,8 +44,15 @@ def main():
|
|||
# build model
|
||||
# for rec algorithm
|
||||
if hasattr(post_process_class, 'character'):
|
||||
config['Architecture']["Head"]['out_channels'] = len(
|
||||
getattr(post_process_class, 'character'))
|
||||
char_num = len(getattr(post_process_class, 'character'))
|
||||
if config['Architecture']["algorithm"] in ["Distillation",
|
||||
]: # distillation model
|
||||
for key in config['Architecture']["Models"]:
|
||||
config['Architecture']["Models"][key]["Head"][
|
||||
'out_channels'] = char_num
|
||||
else: # base rec model
|
||||
config['Architecture']["Head"]['out_channels'] = char_num
|
||||
|
||||
model = build_model(config['Architecture'])
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
|
||||
|
|
Loading…
Reference in New Issue