update code

This commit is contained in:
tink2123 2020-09-03 18:59:44 +08:00
parent aa7e9ac34e
commit 29e2fed39c
1 changed files with 3 additions and 3 deletions

View File

@ -208,14 +208,14 @@ def build_export(config, main_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
func_infor = config['Architecture']['function'] func_infor = config['Architecture']['function']
model = create_module(func_infor)(params=config) model = create_module(func_infor)(params=config)
loss_type = config['Global']['loss_type'] algorithm = config['Global']['algorithm']
if loss_type == "srn": if algorithm == "SRN":
image, others, outputs = model(mode='export') image, others, outputs = model(mode='export')
else: else:
image, outputs = model(mode='export') image, outputs = model(mode='export')
fetches_var_name = sorted([name for name in outputs.keys()]) fetches_var_name = sorted([name for name in outputs.keys()])
fetches_var = [outputs[name] for name in fetches_var_name] fetches_var = [outputs[name] for name in fetches_var_name]
if loss_type == "srn": if algorithm == "SRN":
others_var_names = sorted([name for name in others.keys()]) others_var_names = sorted([name for name in others.keys()])
feeded_var_names = [image.name] + others_var_names feeded_var_names = [image.name] + others_var_names
else: else: