This commit is contained in:
Jethong 2021-04-22 20:32:40 +08:00
parent 2365c70e4a
commit 9a8f5d081b
3 changed files with 11 additions and 11 deletions

View File

@ -72,13 +72,13 @@ Train:
dataset:
name: PGDataSet
data_dir: ./train_data/total_text/train
label_file_list: [./train_data/total_text/train/]
label_file_list: [./train_data/total_text/train/total_text.txt]
ratio_list: [1.0]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- E2ELabelEncode_train:
- E2ELabelEncodeTrain:
- PGProcessTrain:
batch_size: 14 # same as loader: batch_size_per_card
min_crop_size: 24
@ -96,12 +96,12 @@ Eval:
dataset:
name: PGDataSet
data_dir: ./train_data/total_text/test
label_file_list: [./train_data/total_text/test/]
label_file_list: [./train_data/total_text/test/total_text.txt]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- E2ELabelEncode_test:
- E2ELabelEncodeTest:
- E2EResizeForTest:
max_side_len: 768
- NormalizeImage:
@ -111,7 +111,7 @@ Eval:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'image', 'shape', 'polys', 'texts', 'tags', 'img_id']
keep_keys: [ 'image', 'shape', 'polys', 'texts', 'ignore_tags', 'img_id']
loader:
shuffle: False
drop_last: False

View File

@ -187,14 +187,14 @@ class CTCLabelEncode(BaseRecLabelEncode):
return dict_character
class E2ELabelEncode_test(BaseRecLabelEncode):
class E2ELabelEncodeTest(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
character_type='EN',
use_space_char=False,
**kwargs):
super(E2ELabelEncode_test,
super(E2ELabelEncodeTest,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
@ -217,7 +217,7 @@ class E2ELabelEncode_test(BaseRecLabelEncode):
boxes = np.array(boxes, dtype=np.float32)
txt_tags = np.array(txt_tags, dtype=np.bool)
data['polys'] = boxes
data['tags'] = txt_tags
data['ignore_tags'] = txt_tags
temp_texts = []
for text in txts:
text = text.lower()
@ -231,7 +231,7 @@ class E2ELabelEncode_test(BaseRecLabelEncode):
return data
class E2ELabelEncode_train(object):
class E2ELabelEncodeTrain(object):
def __init__(self, **kwargs):
pass
@ -255,7 +255,7 @@ class E2ELabelEncode_train(object):
data['polys'] = boxes
data['texts'] = txts
data['tags'] = txt_tags
data['ignore_tags'] = txt_tags
return data

View File

@ -750,7 +750,7 @@ class PGProcessTrain(object):
input_size = 512
im = data['image']
text_polys = data['polys']
text_tags = data['tags']
text_tags = data['ignore_tags']
text_strs = data['texts']
h, w, _ = im.shape
text_polys, text_tags, hv_tags = self.check_and_validate_polys(