fix distillation model export and pred save (#3869)
This commit is contained in:
parent
073a591cee
commit
03b7daa5be
|
@ -93,6 +93,9 @@ def main():
|
||||||
for key in config["Architecture"]["Models"]:
|
for key in config["Architecture"]["Models"]:
|
||||||
config["Architecture"]["Models"][key]["Head"][
|
config["Architecture"]["Models"][key]["Head"][
|
||||||
"out_channels"] = char_num
|
"out_channels"] = char_num
|
||||||
|
# just one final tensor needs to to exported for inference
|
||||||
|
config["Architecture"]["Models"][key][
|
||||||
|
"return_all_feats"] = False
|
||||||
else: # base rec model
|
else: # base rec model
|
||||||
config["Architecture"]["Head"]["out_channels"] = char_num
|
config["Architecture"]["Head"]["out_channels"] = char_num
|
||||||
model = build_model(config["Architecture"])
|
model = build_model(config["Architecture"])
|
||||||
|
|
|
@ -121,7 +121,7 @@ def main():
|
||||||
if len(post_result[key][0]) >= 2:
|
if len(post_result[key][0]) >= 2:
|
||||||
rec_info[key] = {
|
rec_info[key] = {
|
||||||
"label": post_result[key][0][0],
|
"label": post_result[key][0][0],
|
||||||
"score": post_result[key][0][1],
|
"score": float(post_result[key][0][1]),
|
||||||
}
|
}
|
||||||
info = json.dumps(rec_info)
|
info = json.dumps(rec_info)
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue