update for enable static
This commit is contained in:
parent
e66b53ca3e
commit
aa3266b192
|
@ -24,6 +24,7 @@ sys.path.append(os.path.join(__dir__, '..', '..', '..'))
|
|||
sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
|
||||
|
||||
import program
|
||||
import paddle
|
||||
from paddle import fluid
|
||||
from ppocr.utils.utility import initial_logger
|
||||
logger = initial_logger()
|
||||
|
@ -32,6 +33,12 @@ from paddleslim.prune import load_model
|
|||
|
||||
|
||||
def main():
|
||||
# Run code with static graph mode.
|
||||
try:
|
||||
paddle.enable_static()
|
||||
except:
|
||||
pass
|
||||
|
||||
startup_prog, eval_program, place, config, _ = program.preprocess()
|
||||
|
||||
feeded_var_names, target_vars, fetches_var_name = program.build_export(
|
||||
|
|
|
@ -50,7 +50,12 @@ skip_list = [
|
|||
|
||||
|
||||
def main():
|
||||
paddle.enable_static()
|
||||
# Run code with static graph mode.
|
||||
try:
|
||||
paddle.enable_static()
|
||||
except:
|
||||
pass
|
||||
|
||||
config = program.load_config(FLAGS.config)
|
||||
program.merge_config(FLAGS.opt)
|
||||
logger.info(config)
|
||||
|
|
|
@ -25,6 +25,7 @@ sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
|
|||
|
||||
import json
|
||||
import cv2
|
||||
import paddle
|
||||
from paddle import fluid
|
||||
import paddleslim as slim
|
||||
from copy import deepcopy
|
||||
|
@ -60,6 +61,12 @@ def eval_function(eval_args, mode='eval'):
|
|||
|
||||
|
||||
def main():
|
||||
# Run code with static graph mode.
|
||||
try:
|
||||
paddle.enable_static()
|
||||
except:
|
||||
pass
|
||||
|
||||
config = program.load_config(FLAGS.config)
|
||||
program.merge_config(FLAGS.opt)
|
||||
logger.info(config)
|
||||
|
|
|
@ -77,7 +77,12 @@ def main():
|
|||
# The decay coefficient of moving average, default is 0.9
|
||||
'moving_rate': 0.9,
|
||||
}
|
||||
paddle.enable_static()
|
||||
# Run code with static graph mode.
|
||||
try:
|
||||
paddle.enable_static()
|
||||
except:
|
||||
pass
|
||||
|
||||
startup_prog, eval_program, place, config, alg_type = program.preprocess()
|
||||
|
||||
feeded_var_names, target_vars, fetches_var_name = program.build_export(
|
||||
|
|
|
@ -85,7 +85,12 @@ def get_optimizer():
|
|||
|
||||
|
||||
def main():
|
||||
paddle.enable_static()
|
||||
# Run code with static graph mode.
|
||||
try:
|
||||
paddle.enable_static()
|
||||
except:
|
||||
pass
|
||||
|
||||
train_build_outputs = program.build(
|
||||
config, train_program, startup_program, mode='train')
|
||||
train_loader = train_build_outputs[0]
|
||||
|
|
Loading…
Reference in New Issue