diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py index 8f8fcb4d..0906167f 100644 --- a/ppocr/data/simple_dataset.py +++ b/ppocr/data/simple_dataset.py @@ -69,6 +69,36 @@ class SimpleDataSet(Dataset): random.shuffle(self.data_lines) return + def get_ext_data(self): + ext_data_num = 0 + for op in self.ops: + if hasattr(op, 'ext_data_num'): + ext_data_num = getattr(op, 'ext_data_num') + break + load_data_ops = self.ops[:2] + ext_data = [] + + while len(ext_data) < ext_data_num: + file_idx = self.data_idx_order_list[np.random.randint(self.__len__( + ))] + data_line = self.data_lines[file_idx] + data_line = data_line.decode('utf-8') + substr = data_line.strip("\n").split(self.delimiter) + file_name = substr[0] + label = substr[1] + img_path = os.path.join(self.data_dir, file_name) + data = {'img_path': img_path, 'label': label} + if not os.path.exists(img_path): + continue + with open(data['img_path'], 'rb') as f: + img = f.read() + data['image'] = img + data = transform(data, load_data_ops) + if data is None: + continue + ext_data.append(data) + return ext_data + def __getitem__(self, idx): file_idx = self.data_idx_order_list[idx] data_line = self.data_lines[file_idx] @@ -84,6 +114,7 @@ class SimpleDataSet(Dataset): with open(data['img_path'], 'rb') as f: img = f.read() data['image'] = img + data['ext_data'] = self.get_ext_data() outs = transform(data, self.ops) except Exception as e: self.logger.error(