Merge pull request #2347 from littletomatodonkey/dyg/fix_pre_rec
fix eval res vary for different times
This commit is contained in:
commit
d6ee6bdb48
|
@ -23,6 +23,7 @@ class SimpleDataSet(Dataset):
|
||||||
def __init__(self, config, mode, logger, seed=None):
|
def __init__(self, config, mode, logger, seed=None):
|
||||||
super(SimpleDataSet, self).__init__()
|
super(SimpleDataSet, self).__init__()
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
self.mode = mode.lower()
|
||||||
|
|
||||||
global_config = config['Global']
|
global_config = config['Global']
|
||||||
dataset_config = config[mode]['dataset']
|
dataset_config = config[mode]['dataset']
|
||||||
|
@ -45,7 +46,7 @@ class SimpleDataSet(Dataset):
|
||||||
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
||||||
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
|
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
|
||||||
self.data_idx_order_list = list(range(len(self.data_lines)))
|
self.data_idx_order_list = list(range(len(self.data_lines)))
|
||||||
if mode.lower() == "train":
|
if self.mode == "train" and self.do_shuffle:
|
||||||
self.shuffle_data_random()
|
self.shuffle_data_random()
|
||||||
self.ops = create_operators(dataset_config['transforms'], global_config)
|
self.ops = create_operators(dataset_config['transforms'], global_config)
|
||||||
|
|
||||||
|
@ -56,6 +57,7 @@ class SimpleDataSet(Dataset):
|
||||||
for idx, file in enumerate(file_list):
|
for idx, file in enumerate(file_list):
|
||||||
with open(file, "rb") as f:
|
with open(file, "rb") as f:
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
|
if self.mode == "train" or ratio_list[idx] < 1.0:
|
||||||
random.seed(self.seed)
|
random.seed(self.seed)
|
||||||
lines = random.sample(lines,
|
lines = random.sample(lines,
|
||||||
round(len(lines) * ratio_list[idx]))
|
round(len(lines) * ratio_list[idx]))
|
||||||
|
@ -63,7 +65,6 @@ class SimpleDataSet(Dataset):
|
||||||
return data_lines
|
return data_lines
|
||||||
|
|
||||||
def shuffle_data_random(self):
|
def shuffle_data_random(self):
|
||||||
if self.do_shuffle:
|
|
||||||
random.seed(self.seed)
|
random.seed(self.seed)
|
||||||
random.shuffle(self.data_lines)
|
random.shuffle(self.data_lines)
|
||||||
return
|
return
|
||||||
|
@ -90,7 +91,10 @@ class SimpleDataSet(Dataset):
|
||||||
data_line, e))
|
data_line, e))
|
||||||
outs = None
|
outs = None
|
||||||
if outs is None:
|
if outs is None:
|
||||||
return self.__getitem__(np.random.randint(self.__len__()))
|
# during evaluation, we should fix the idx to get same results for many times of evaluation.
|
||||||
|
rnd_idx = np.random.randint(self.__len__(
|
||||||
|
)) if self.mode == "train" else (idx + 1) % self.__len__()
|
||||||
|
return self.__getitem__(rnd_idx)
|
||||||
return outs
|
return outs
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
|
Loading…
Reference in New Issue