Merge pull request #1807 from littletomatodonkey/dyg/fix_seed
fix data replication for multi-cards sampling
This commit is contained in:
commit
a27a43ec0f
|
@ -51,7 +51,7 @@ signal.signal(signal.SIGINT, term_mp)
|
|||
signal.signal(signal.SIGTERM, term_mp)
|
||||
|
||||
|
||||
def build_dataloader(config, mode, device, logger):
|
||||
def build_dataloader(config, mode, device, logger, seed=None):
|
||||
config = copy.deepcopy(config)
|
||||
|
||||
support_dict = ['SimpleDataSet', 'LMDBDateSet']
|
||||
|
@ -61,7 +61,7 @@ def build_dataloader(config, mode, device, logger):
|
|||
assert mode in ['Train', 'Eval', 'Test'
|
||||
], "Mode should be Train, Eval or Test."
|
||||
|
||||
dataset = eval(module_name)(config, mode, logger)
|
||||
dataset = eval(module_name)(config, mode, logger, seed)
|
||||
loader_config = config[mode]['loader']
|
||||
batch_size = loader_config['batch_size_per_card']
|
||||
drop_last = loader_config['drop_last']
|
||||
|
|
|
@ -21,7 +21,7 @@ from .imaug import transform, create_operators
|
|||
|
||||
|
||||
class LMDBDateSet(Dataset):
|
||||
def __init__(self, config, mode, logger):
|
||||
def __init__(self, config, mode, logger, seed=None):
|
||||
super(LMDBDateSet, self).__init__()
|
||||
|
||||
global_config = config['Global']
|
||||
|
|
|
@ -20,7 +20,7 @@ from .imaug import transform, create_operators
|
|||
|
||||
|
||||
class SimpleDataSet(Dataset):
|
||||
def __init__(self, config, mode, logger):
|
||||
def __init__(self, config, mode, logger, seed=None):
|
||||
super(SimpleDataSet, self).__init__()
|
||||
self.logger = logger
|
||||
|
||||
|
@ -41,6 +41,7 @@ class SimpleDataSet(Dataset):
|
|||
self.data_dir = dataset_config['data_dir']
|
||||
self.do_shuffle = loader_config['shuffle']
|
||||
|
||||
self.seed = seed
|
||||
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
||||
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
|
||||
self.data_idx_order_list = list(range(len(self.data_lines)))
|
||||
|
@ -55,6 +56,7 @@ class SimpleDataSet(Dataset):
|
|||
for idx, file in enumerate(file_list):
|
||||
with open(file, "rb") as f:
|
||||
lines = f.readlines()
|
||||
random.seed(self.seed)
|
||||
lines = random.sample(lines,
|
||||
round(len(lines) * ratio_list[idx]))
|
||||
data_lines.extend(lines)
|
||||
|
@ -62,6 +64,7 @@ class SimpleDataSet(Dataset):
|
|||
|
||||
def shuffle_data_random(self):
|
||||
if self.do_shuffle:
|
||||
random.seed(self.seed)
|
||||
random.shuffle(self.data_lines)
|
||||
return
|
||||
|
||||
|
|
|
@ -182,8 +182,8 @@ def train(config,
|
|||
start_epoch = 1
|
||||
|
||||
for epoch in range(start_epoch, epoch_num + 1):
|
||||
if epoch > 0:
|
||||
train_dataloader = build_dataloader(config, 'Train', device, logger)
|
||||
train_dataloader = build_dataloader(
|
||||
config, 'Train', device, logger, seed=epoch)
|
||||
train_batch_cost = 0.0
|
||||
train_reader_cost = 0.0
|
||||
batch_sum = 0
|
||||
|
|
Loading…
Reference in New Issue