delete data_num_per_epoch

This commit is contained in:
LDOUBLEV 2020-12-10 10:19:39 +08:00
parent e23c4de5d8
commit b8ba703548
1 changed file with 5 additions and 17 deletions

View File

@@ -27,17 +27,13 @@ class SimpleDataSet(Dataset):
global_config = config['Global']
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
if 'data_num_per_epoch' in loader_config.keys():
data_num_per_epoch = loader_config['data_num_per_epoch']
else:
data_num_per_epoch = None
self.delimiter = dataset_config.get('delimiter', '\t')
label_file_list = dataset_config.pop('label_file_list')
data_source_num = len(label_file_list)
ratio_list = dataset_config.get("ratio_list", [1.0])
if isinstance(ratio_list, (float, int)):
ratio_list = [float(ratio_list)] * len(data_source_num)
ratio_list = [float(ratio_list)] * int(data_source_num)
assert len(
ratio_list
@@ -46,34 +42,26 @@ class SimpleDataSet(Dataset):
self.do_shuffle = loader_config['shuffle']
logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list,
data_num_per_epoch)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.data_idx_order_list = list(range(len(self.data_lines)))
if mode.lower() == "train":
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
def _sample_dataset(self, datas, sample_ratio, data_num_per_epoch=None):
def _sample_dataset(self, datas, sample_ratio):
sample_num = round(len(datas) * sample_ratio)
if data_num_per_epoch is not None:
sample_num = int(data_num_per_epoch * sample_ratio)
nums, rem = int(sample_num // len(datas)), int(sample_num % len(datas))
return list(datas) * nums + random.sample(datas, rem)
def get_image_info_list(self, file_list, ratio_list):
    """Read every label file and return the sampled lines of all of them.

    Args:
        file_list: A single path or a list of paths to label files. Files
            are opened in binary mode, so each line is a ``bytes`` object.
        ratio_list: Sampling ratios indexed in parallel with ``file_list``;
            ``ratio_list[i]`` is passed to ``self._sample_dataset`` for
            the ``i``-th file.

    Returns:
        One flat list with the sampled lines of each file, in file order.
    """
    # Accept a bare path as shorthand for a one-element list.
    if isinstance(file_list, str):
        file_list = [file_list]
    data_lines = []
    for idx, label_file in enumerate(file_list):
        with open(label_file, "rb") as f:
            lines = f.readlines()
        # Sample exactly once per file, after the handle is closed.
        data_lines.extend(self._sample_dataset(lines, ratio_list[idx]))
    return data_lines