Merge pull request #1807 from littletomatodonkey/dyg/fix_seed
fix data replication for multi-cards sampling
This commit is contained in:
commit
a27a43ec0f
|
@ -51,7 +51,7 @@ signal.signal(signal.SIGINT, term_mp)
|
||||||
signal.signal(signal.SIGTERM, term_mp)
|
signal.signal(signal.SIGTERM, term_mp)
|
||||||
|
|
||||||
|
|
||||||
def build_dataloader(config, mode, device, logger):
|
def build_dataloader(config, mode, device, logger, seed=None):
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
|
|
||||||
support_dict = ['SimpleDataSet', 'LMDBDateSet']
|
support_dict = ['SimpleDataSet', 'LMDBDateSet']
|
||||||
|
@ -61,7 +61,7 @@ def build_dataloader(config, mode, device, logger):
|
||||||
assert mode in ['Train', 'Eval', 'Test'
|
assert mode in ['Train', 'Eval', 'Test'
|
||||||
], "Mode should be Train, Eval or Test."
|
], "Mode should be Train, Eval or Test."
|
||||||
|
|
||||||
dataset = eval(module_name)(config, mode, logger)
|
dataset = eval(module_name)(config, mode, logger, seed)
|
||||||
loader_config = config[mode]['loader']
|
loader_config = config[mode]['loader']
|
||||||
batch_size = loader_config['batch_size_per_card']
|
batch_size = loader_config['batch_size_per_card']
|
||||||
drop_last = loader_config['drop_last']
|
drop_last = loader_config['drop_last']
|
||||||
|
|
|
@ -21,7 +21,7 @@ from .imaug import transform, create_operators
|
||||||
|
|
||||||
|
|
||||||
class LMDBDateSet(Dataset):
|
class LMDBDateSet(Dataset):
|
||||||
def __init__(self, config, mode, logger):
|
def __init__(self, config, mode, logger, seed=None):
|
||||||
super(LMDBDateSet, self).__init__()
|
super(LMDBDateSet, self).__init__()
|
||||||
|
|
||||||
global_config = config['Global']
|
global_config = config['Global']
|
||||||
|
|
|
@ -20,7 +20,7 @@ from .imaug import transform, create_operators
|
||||||
|
|
||||||
|
|
||||||
class SimpleDataSet(Dataset):
|
class SimpleDataSet(Dataset):
|
||||||
def __init__(self, config, mode, logger):
|
def __init__(self, config, mode, logger, seed=None):
|
||||||
super(SimpleDataSet, self).__init__()
|
super(SimpleDataSet, self).__init__()
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
||||||
|
@ -41,6 +41,7 @@ class SimpleDataSet(Dataset):
|
||||||
self.data_dir = dataset_config['data_dir']
|
self.data_dir = dataset_config['data_dir']
|
||||||
self.do_shuffle = loader_config['shuffle']
|
self.do_shuffle = loader_config['shuffle']
|
||||||
|
|
||||||
|
self.seed = seed
|
||||||
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
||||||
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
|
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
|
||||||
self.data_idx_order_list = list(range(len(self.data_lines)))
|
self.data_idx_order_list = list(range(len(self.data_lines)))
|
||||||
|
@ -55,6 +56,7 @@ class SimpleDataSet(Dataset):
|
||||||
for idx, file in enumerate(file_list):
|
for idx, file in enumerate(file_list):
|
||||||
with open(file, "rb") as f:
|
with open(file, "rb") as f:
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
|
random.seed(self.seed)
|
||||||
lines = random.sample(lines,
|
lines = random.sample(lines,
|
||||||
round(len(lines) * ratio_list[idx]))
|
round(len(lines) * ratio_list[idx]))
|
||||||
data_lines.extend(lines)
|
data_lines.extend(lines)
|
||||||
|
@ -62,6 +64,7 @@ class SimpleDataSet(Dataset):
|
||||||
|
|
||||||
def shuffle_data_random(self):
|
def shuffle_data_random(self):
|
||||||
if self.do_shuffle:
|
if self.do_shuffle:
|
||||||
|
random.seed(self.seed)
|
||||||
random.shuffle(self.data_lines)
|
random.shuffle(self.data_lines)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -182,8 +182,8 @@ def train(config,
|
||||||
start_epoch = 1
|
start_epoch = 1
|
||||||
|
|
||||||
for epoch in range(start_epoch, epoch_num + 1):
|
for epoch in range(start_epoch, epoch_num + 1):
|
||||||
if epoch > 0:
|
train_dataloader = build_dataloader(
|
||||||
train_dataloader = build_dataloader(config, 'Train', device, logger)
|
config, 'Train', device, logger, seed=epoch)
|
||||||
train_batch_cost = 0.0
|
train_batch_cost = 0.0
|
||||||
train_reader_cost = 0.0
|
train_reader_cost = 0.0
|
||||||
batch_sum = 0
|
batch_sum = 0
|
||||||
|
|
Loading…
Reference in New Issue