all ready
This commit is contained in:
parent
297871d4be
commit
93670ab5a2
|
@ -3,7 +3,7 @@ Global:
|
|||
epoch_num: 72
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 5
|
||||
save_model_dir: ./output/rec/srn
|
||||
save_model_dir: ./output/rec/srn_new
|
||||
save_epoch_step: 3
|
||||
# evaluation is run every 5000 iterations after the 0th iteration
|
||||
eval_batch_step: [0, 5000]
|
||||
|
@ -25,8 +25,10 @@ Global:
|
|||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
clip_norm: 10.0
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.0001
|
||||
|
||||
Architecture:
|
||||
|
@ -58,7 +60,6 @@ Train:
|
|||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/srn_train_data_duiqi
|
||||
#label_file_list: ["./train_data/ic15_data/1.txt"]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
|
@ -77,7 +78,7 @@ Train:
|
|||
loader:
|
||||
shuffle: False
|
||||
batch_size_per_card: 64
|
||||
drop_last: True
|
||||
drop_last: False
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
|
|
|
@ -359,6 +359,7 @@ class PrepareDecoder(nn.Layer):
|
|||
self.emb0 = paddle.nn.Embedding(
|
||||
num_embeddings=src_vocab_size,
|
||||
embedding_dim=self.src_emb_dim,
|
||||
padding_idx=bos_idx,
|
||||
weight_attr=paddle.ParamAttr(
|
||||
name=word_emb_param_name,
|
||||
initializer=nn.initializer.Normal(0., src_emb_dim**-0.5)))
|
||||
|
|
|
@ -182,14 +182,15 @@ class SRNLabelDecode(BaseRecLabelDecode):
|
|||
|
||||
preds_prob = np.reshape(preds_prob, [-1, 25])
|
||||
|
||||
text = self.decode(preds_idx, preds_prob)
|
||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
|
||||
|
||||
if label is None:
|
||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||
return text
|
||||
label = self.decode(label, is_remove_duplicate=False)
|
||||
label = self.decode(label, is_remove_duplicate=True)
|
||||
return text, label
|
||||
|
||||
def decode(self, text_index, text_prob=None, is_remove_duplicate=True):
|
||||
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||
""" convert text-index into text-label. """
|
||||
result_list = []
|
||||
ignored_tokens = self.get_ignored_tokens()
|
||||
|
|
|
@ -242,6 +242,12 @@ def train(config,
|
|||
# eval
|
||||
if global_step > start_eval_step and \
|
||||
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
|
||||
model_average = paddle.optimizer.ModelAverage(
|
||||
0.15,
|
||||
parameters=model.parameters(),
|
||||
min_average_window=10000,
|
||||
max_average_window=15625)
|
||||
model_average.apply()
|
||||
cur_metirc = eval(model, valid_dataloader, post_process_class,
|
||||
eval_class)
|
||||
cur_metirc_str = 'cur metirc, {}'.format(', '.join(
|
||||
|
@ -277,6 +283,7 @@ def train(config,
|
|||
best_model_dict[main_indicator],
|
||||
global_step)
|
||||
global_step += 1
|
||||
optimizer.clear_grad()
|
||||
batch_start = time.time()
|
||||
if dist.get_rank() == 0:
|
||||
save_model(
|
||||
|
|
Loading…
Reference in New Issue