update code
This commit is contained in:
parent
aa7e9ac34e
commit
29e2fed39c
|
@ -208,14 +208,14 @@ def build_export(config, main_prog, startup_prog):
|
|||
with fluid.unique_name.guard():
|
||||
func_infor = config['Architecture']['function']
|
||||
model = create_module(func_infor)(params=config)
|
||||
loss_type = config['Global']['loss_type']
|
||||
if loss_type == "srn":
|
||||
algorithm = config['Global']['algorithm']
|
||||
if algorithm == "SRN":
|
||||
image, others, outputs = model(mode='export')
|
||||
else:
|
||||
image, outputs = model(mode='export')
|
||||
fetches_var_name = sorted([name for name in outputs.keys()])
|
||||
fetches_var = [outputs[name] for name in fetches_var_name]
|
||||
if loss_type == "srn":
|
||||
if algorithm == "SRN":
|
||||
others_var_names = sorted([name for name in others.keys()])
|
||||
feeded_var_names = [image.name] + others_var_names
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue