update code
This commit is contained in:
parent
aa7e9ac34e
commit
29e2fed39c
|
@ -208,14 +208,14 @@ def build_export(config, main_prog, startup_prog):
|
||||||
with fluid.unique_name.guard():
|
with fluid.unique_name.guard():
|
||||||
func_infor = config['Architecture']['function']
|
func_infor = config['Architecture']['function']
|
||||||
model = create_module(func_infor)(params=config)
|
model = create_module(func_infor)(params=config)
|
||||||
loss_type = config['Global']['loss_type']
|
algorithm = config['Global']['algorithm']
|
||||||
if loss_type == "srn":
|
if algorithm == "SRN":
|
||||||
image, others, outputs = model(mode='export')
|
image, others, outputs = model(mode='export')
|
||||||
else:
|
else:
|
||||||
image, outputs = model(mode='export')
|
image, outputs = model(mode='export')
|
||||||
fetches_var_name = sorted([name for name in outputs.keys()])
|
fetches_var_name = sorted([name for name in outputs.keys()])
|
||||||
fetches_var = [outputs[name] for name in fetches_var_name]
|
fetches_var = [outputs[name] for name in fetches_var_name]
|
||||||
if loss_type == "srn":
|
if algorithm == "SRN":
|
||||||
others_var_names = sorted([name for name in others.keys()])
|
others_var_names = sorted([name for name in others.keys()])
|
||||||
feeded_var_names = [image.name] + others_var_names
|
feeded_var_names = [image.name] + others_var_names
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue