delete data_num_per_epoch

This commit is contained in:
LDOUBLEV 2020-12-10 10:19:39 +08:00
parent e23c4de5d8
commit b8ba703548
1 changed files with 5 additions and 17 deletions

View File

@ -27,17 +27,13 @@ class SimpleDataSet(Dataset):
global_config = config['Global'] global_config = config['Global']
dataset_config = config[mode]['dataset'] dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader'] loader_config = config[mode]['loader']
if 'data_num_per_epoch' in loader_config.keys():
data_num_per_epoch = loader_config['data_num_per_epoch']
else:
data_num_per_epoch = None
self.delimiter = dataset_config.get('delimiter', '\t') self.delimiter = dataset_config.get('delimiter', '\t')
label_file_list = dataset_config.pop('label_file_list') label_file_list = dataset_config.pop('label_file_list')
data_source_num = len(label_file_list) data_source_num = len(label_file_list)
ratio_list = dataset_config.get("ratio_list", [1.0]) ratio_list = dataset_config.get("ratio_list", [1.0])
if isinstance(ratio_list, (float, int)): if isinstance(ratio_list, (float, int)):
ratio_list = [float(ratio_list)] * len(data_source_num) ratio_list = [float(ratio_list)] * int(data_source_num)
assert len( assert len(
ratio_list ratio_list
@ -46,34 +42,26 @@ class SimpleDataSet(Dataset):
self.do_shuffle = loader_config['shuffle'] self.do_shuffle = loader_config['shuffle']
logger.info("Initialize indexs of datasets:%s" % label_file_list) logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list, self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
data_num_per_epoch)
self.data_idx_order_list = list(range(len(self.data_lines))) self.data_idx_order_list = list(range(len(self.data_lines)))
if mode.lower() == "train": if mode.lower() == "train":
self.shuffle_data_random() self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config) self.ops = create_operators(dataset_config['transforms'], global_config)
def _sample_dataset(self, datas, sample_ratio, data_num_per_epoch=None): def _sample_dataset(self, datas, sample_ratio):
sample_num = round(len(datas) * sample_ratio) sample_num = round(len(datas) * sample_ratio)
if data_num_per_epoch is not None:
sample_num = int(data_num_per_epoch * sample_ratio)
nums, rem = int(sample_num // len(datas)), int(sample_num % len(datas)) nums, rem = int(sample_num // len(datas)), int(sample_num % len(datas))
return list(datas) * nums + random.sample(datas, rem) return list(datas) * nums + random.sample(datas, rem)
def get_image_info_list(self, def get_image_info_list(self, file_list, ratio_list):
file_list,
ratio_list,
data_num_per_epoch=None):
if isinstance(file_list, str): if isinstance(file_list, str):
file_list = [file_list] file_list = [file_list]
data_lines = [] data_lines = []
for idx, file in enumerate(file_list): for idx, file in enumerate(file_list):
with open(file, "rb") as f: with open(file, "rb") as f:
lines = f.readlines() lines = f.readlines()
lines = self._sample_dataset(lines, ratio_list[idx], lines = self._sample_dataset(lines, ratio_list[idx])
data_num_per_epoch)
data_lines.extend(lines) data_lines.extend(lines)
return data_lines return data_lines