diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py index 4e1ff0ae..ea27a865 100644 --- a/ppocr/data/__init__.py +++ b/ppocr/data/__init__.py @@ -51,7 +51,7 @@ signal.signal(signal.SIGINT, term_mp) signal.signal(signal.SIGTERM, term_mp) -def build_dataloader(config, mode, device, logger): +def build_dataloader(config, mode, device, logger, seed=None): config = copy.deepcopy(config) support_dict = ['SimpleDataSet', 'LMDBDateSet'] @@ -61,7 +61,7 @@ def build_dataloader(config, mode, device, logger): assert mode in ['Train', 'Eval', 'Test' ], "Mode should be Train, Eval or Test." - dataset = eval(module_name)(config, mode, logger) + dataset = eval(module_name)(config, mode, logger, seed) loader_config = config[mode]['loader'] batch_size = loader_config['batch_size_per_card'] drop_last = loader_config['drop_last'] diff --git a/ppocr/data/lmdb_dataset.py b/ppocr/data/lmdb_dataset.py index e7bb6dd3..bd0630f6 100644 --- a/ppocr/data/lmdb_dataset.py +++ b/ppocr/data/lmdb_dataset.py @@ -21,7 +21,7 @@ from .imaug import transform, create_operators class LMDBDateSet(Dataset): - def __init__(self, config, mode, logger): + def __init__(self, config, mode, logger, seed=None): super(LMDBDateSet, self).__init__() global_config = config['Global'] diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py index ab17dd1a..d2a86b0f 100644 --- a/ppocr/data/simple_dataset.py +++ b/ppocr/data/simple_dataset.py @@ -20,7 +20,7 @@ from .imaug import transform, create_operators class SimpleDataSet(Dataset): - def __init__(self, config, mode, logger): + def __init__(self, config, mode, logger, seed=None): super(SimpleDataSet, self).__init__() self.logger = logger @@ -41,6 +41,7 @@ class SimpleDataSet(Dataset): self.data_dir = dataset_config['data_dir'] self.do_shuffle = loader_config['shuffle'] + self.seed = seed logger.info("Initialize indexs of datasets:%s" % label_file_list) self.data_lines = self.get_image_info_list(label_file_list, ratio_list) self.data_idx_order_list = list(range(len(self.data_lines))) @@ -55,6 +56,7 @@ class SimpleDataSet(Dataset): for idx, file in enumerate(file_list): with open(file, "rb") as f: lines = f.readlines() + random.seed(self.seed) lines = random.sample(lines, round(len(lines) * ratio_list[idx])) data_lines.extend(lines) @@ -62,6 +64,7 @@ class SimpleDataSet(Dataset): def shuffle_data_random(self): if self.do_shuffle: + random.seed(self.seed) random.shuffle(self.data_lines) return diff --git a/tools/program.py b/tools/program.py index c2915426..cbca715a 100755 --- a/tools/program.py +++ b/tools/program.py @@ -182,8 +182,8 @@ def train(config, start_epoch = 1 for epoch in range(start_epoch, epoch_num + 1): - if epoch > 0: - train_dataloader = build_dataloader(config, 'Train', device, logger) + train_dataloader = build_dataloader( + config, 'Train', device, logger, seed=epoch) train_batch_cost = 0.0 train_reader_cost = 0.0 batch_sum = 0