Merge pull request #2347 from littletomatodonkey/dyg/fix_pre_rec

fix eval res vary for different times
This commit is contained in:
Double_V 2021-03-31 09:56:03 +08:00 committed by GitHub
commit d6ee6bdb48
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 12 additions and 8 deletions

View File

@ -23,6 +23,7 @@ class SimpleDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None): def __init__(self, config, mode, logger, seed=None):
super(SimpleDataSet, self).__init__() super(SimpleDataSet, self).__init__()
self.logger = logger self.logger = logger
self.mode = mode.lower()
global_config = config['Global'] global_config = config['Global']
dataset_config = config[mode]['dataset'] dataset_config = config[mode]['dataset']
@ -45,7 +46,7 @@ class SimpleDataSet(Dataset):
logger.info("Initialize indexs of datasets:%s" % label_file_list) logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list) self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.data_idx_order_list = list(range(len(self.data_lines))) self.data_idx_order_list = list(range(len(self.data_lines)))
if mode.lower() == "train": if self.mode == "train" and self.do_shuffle:
self.shuffle_data_random() self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config) self.ops = create_operators(dataset_config['transforms'], global_config)
@ -56,6 +57,7 @@ class SimpleDataSet(Dataset):
for idx, file in enumerate(file_list): for idx, file in enumerate(file_list):
with open(file, "rb") as f: with open(file, "rb") as f:
lines = f.readlines() lines = f.readlines()
if self.mode == "train" or ratio_list[idx] < 1.0:
random.seed(self.seed) random.seed(self.seed)
lines = random.sample(lines, lines = random.sample(lines,
round(len(lines) * ratio_list[idx])) round(len(lines) * ratio_list[idx]))
@ -63,7 +65,6 @@ class SimpleDataSet(Dataset):
return data_lines return data_lines
def shuffle_data_random(self): def shuffle_data_random(self):
if self.do_shuffle:
random.seed(self.seed) random.seed(self.seed)
random.shuffle(self.data_lines) random.shuffle(self.data_lines)
return return
@ -90,7 +91,10 @@ class SimpleDataSet(Dataset):
data_line, e)) data_line, e))
outs = None outs = None
if outs is None: if outs is None:
return self.__getitem__(np.random.randint(self.__len__())) # during evaluation, we should fix the idx to get same results for many times of evaluation.
rnd_idx = np.random.randint(self.__len__(
)) if self.mode == "train" else (idx + 1) % self.__len__()
return self.__getitem__(rnd_idx)
return outs return outs
def __len__(self): def __len__(self):