Merge pull request #1807 from littletomatodonkey/dyg/fix_seed

fix data replication for multi-cards sampling
This commit is contained in:
zhoujun 2021-01-26 12:45:23 +08:00 committed by GitHub
commit a27a43ec0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 9 additions and 6 deletions

View File

@ -51,7 +51,7 @@ signal.signal(signal.SIGINT, term_mp)
signal.signal(signal.SIGTERM, term_mp) signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger): def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config) config = copy.deepcopy(config)
support_dict = ['SimpleDataSet', 'LMDBDateSet'] support_dict = ['SimpleDataSet', 'LMDBDateSet']
@ -61,7 +61,7 @@ def build_dataloader(config, mode, device, logger):
assert mode in ['Train', 'Eval', 'Test' assert mode in ['Train', 'Eval', 'Test'
], "Mode should be Train, Eval or Test." ], "Mode should be Train, Eval or Test."
dataset = eval(module_name)(config, mode, logger) dataset = eval(module_name)(config, mode, logger, seed)
loader_config = config[mode]['loader'] loader_config = config[mode]['loader']
batch_size = loader_config['batch_size_per_card'] batch_size = loader_config['batch_size_per_card']
drop_last = loader_config['drop_last'] drop_last = loader_config['drop_last']

View File

@ -21,7 +21,7 @@ from .imaug import transform, create_operators
class LMDBDateSet(Dataset): class LMDBDateSet(Dataset):
def __init__(self, config, mode, logger): def __init__(self, config, mode, logger, seed=None):
super(LMDBDateSet, self).__init__() super(LMDBDateSet, self).__init__()
global_config = config['Global'] global_config = config['Global']

View File

@ -20,7 +20,7 @@ from .imaug import transform, create_operators
class SimpleDataSet(Dataset): class SimpleDataSet(Dataset):
def __init__(self, config, mode, logger): def __init__(self, config, mode, logger, seed=None):
super(SimpleDataSet, self).__init__() super(SimpleDataSet, self).__init__()
self.logger = logger self.logger = logger
@ -41,6 +41,7 @@ class SimpleDataSet(Dataset):
self.data_dir = dataset_config['data_dir'] self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle'] self.do_shuffle = loader_config['shuffle']
self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_list) logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list) self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.data_idx_order_list = list(range(len(self.data_lines))) self.data_idx_order_list = list(range(len(self.data_lines)))
@ -55,6 +56,7 @@ class SimpleDataSet(Dataset):
for idx, file in enumerate(file_list): for idx, file in enumerate(file_list):
with open(file, "rb") as f: with open(file, "rb") as f:
lines = f.readlines() lines = f.readlines()
random.seed(self.seed)
lines = random.sample(lines, lines = random.sample(lines,
round(len(lines) * ratio_list[idx])) round(len(lines) * ratio_list[idx]))
data_lines.extend(lines) data_lines.extend(lines)
@ -62,6 +64,7 @@ class SimpleDataSet(Dataset):
def shuffle_data_random(self): def shuffle_data_random(self):
if self.do_shuffle: if self.do_shuffle:
random.seed(self.seed)
random.shuffle(self.data_lines) random.shuffle(self.data_lines)
return return

View File

@ -182,8 +182,8 @@ def train(config,
start_epoch = 1 start_epoch = 1
for epoch in range(start_epoch, epoch_num + 1): for epoch in range(start_epoch, epoch_num + 1):
if epoch > 0: train_dataloader = build_dataloader(
train_dataloader = build_dataloader(config, 'Train', device, logger) config, 'Train', device, logger, seed=epoch)
train_batch_cost = 0.0 train_batch_cost = 0.0
train_reader_cost = 0.0 train_reader_cost = 0.0
batch_sum = 0 batch_sum = 0