Move ModelAverage to incubate (paddle.optimizer.ModelAverage -> paddle.incubate.optimizer.ModelAverage)
This commit is contained in:
parent
93670ab5a2
commit
ed2f0de95e
|
@ -42,6 +42,6 @@ class SRNLoss(nn.Layer):
|
||||||
cost_gsrm = paddle.reshape(x=paddle.sum(cost_gsrm), shape=[1])
|
cost_gsrm = paddle.reshape(x=paddle.sum(cost_gsrm), shape=[1])
|
||||||
cost_vsfd = paddle.reshape(x=paddle.sum(cost_vsfd), shape=[1])
|
cost_vsfd = paddle.reshape(x=paddle.sum(cost_vsfd), shape=[1])
|
||||||
|
|
||||||
sum_cost = cost_word + cost_vsfd * 2.0 + cost_gsrm * 0.15
|
sum_cost = cost_word * 3.0 + cost_vsfd + cost_gsrm * 0.15
|
||||||
|
|
||||||
return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd}
|
return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd}
|
||||||
|
|
|
@ -182,12 +182,12 @@ class SRNLabelDecode(BaseRecLabelDecode):
|
||||||
|
|
||||||
preds_prob = np.reshape(preds_prob, [-1, 25])
|
preds_prob = np.reshape(preds_prob, [-1, 25])
|
||||||
|
|
||||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
|
text = self.decode(preds_idx, preds_prob)
|
||||||
|
|
||||||
if label is None:
|
if label is None:
|
||||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||||
return text
|
return text
|
||||||
label = self.decode(label, is_remove_duplicate=True)
|
label = self.decode(label)
|
||||||
return text, label
|
return text, label
|
||||||
|
|
||||||
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||||
|
|
|
@ -174,6 +174,7 @@ def train(config,
|
||||||
best_model_dict = {main_indicator: 0}
|
best_model_dict = {main_indicator: 0}
|
||||||
best_model_dict.update(pre_best_model_dict)
|
best_model_dict.update(pre_best_model_dict)
|
||||||
train_stats = TrainingStats(log_smooth_window, ['lr'])
|
train_stats = TrainingStats(log_smooth_window, ['lr'])
|
||||||
|
model_average = False
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
if 'start_epoch' in best_model_dict:
|
if 'start_epoch' in best_model_dict:
|
||||||
|
@ -197,6 +198,7 @@ def train(config,
|
||||||
if config['Architecture']['algorithm'] == "SRN":
|
if config['Architecture']['algorithm'] == "SRN":
|
||||||
others = batch[-4:]
|
others = batch[-4:]
|
||||||
preds = model(images, others)
|
preds = model(images, others)
|
||||||
|
model_average = True
|
||||||
else:
|
else:
|
||||||
preds = model(images)
|
preds = model(images)
|
||||||
loss = loss_class(preds, batch)
|
loss = loss_class(preds, batch)
|
||||||
|
@ -242,12 +244,13 @@ def train(config,
|
||||||
# eval
|
# eval
|
||||||
if global_step > start_eval_step and \
|
if global_step > start_eval_step and \
|
||||||
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
|
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
|
||||||
model_average = paddle.optimizer.ModelAverage(
|
if model_average:
|
||||||
0.15,
|
Model_Average = paddle.incubate.optimizer.ModelAverage(
|
||||||
parameters=model.parameters(),
|
0.15,
|
||||||
min_average_window=10000,
|
parameters=model.parameters(),
|
||||||
max_average_window=15625)
|
min_average_window=10000,
|
||||||
model_average.apply()
|
max_average_window=15625)
|
||||||
|
Model_Average.apply()
|
||||||
cur_metirc = eval(model, valid_dataloader, post_process_class,
|
cur_metirc = eval(model, valid_dataloader, post_process_class,
|
||||||
eval_class)
|
eval_class)
|
||||||
cur_metirc_str = 'cur metirc, {}'.format(', '.join(
|
cur_metirc_str = 'cur metirc, {}'.format(', '.join(
|
||||||
|
|
Loading…
Reference in New Issue