delete data_num_per_epoch
This commit is contained in:
parent
e23c4de5d8
commit
b8ba703548
|
@ -27,17 +27,13 @@ class SimpleDataSet(Dataset):
|
||||||
global_config = config['Global']
|
global_config = config['Global']
|
||||||
dataset_config = config[mode]['dataset']
|
dataset_config = config[mode]['dataset']
|
||||||
loader_config = config[mode]['loader']
|
loader_config = config[mode]['loader']
|
||||||
if 'data_num_per_epoch' in loader_config.keys():
|
|
||||||
data_num_per_epoch = loader_config['data_num_per_epoch']
|
|
||||||
else:
|
|
||||||
data_num_per_epoch = None
|
|
||||||
|
|
||||||
self.delimiter = dataset_config.get('delimiter', '\t')
|
self.delimiter = dataset_config.get('delimiter', '\t')
|
||||||
label_file_list = dataset_config.pop('label_file_list')
|
label_file_list = dataset_config.pop('label_file_list')
|
||||||
data_source_num = len(label_file_list)
|
data_source_num = len(label_file_list)
|
||||||
ratio_list = dataset_config.get("ratio_list", [1.0])
|
ratio_list = dataset_config.get("ratio_list", [1.0])
|
||||||
if isinstance(ratio_list, (float, int)):
|
if isinstance(ratio_list, (float, int)):
|
||||||
ratio_list = [float(ratio_list)] * len(data_source_num)
|
ratio_list = [float(ratio_list)] * int(data_source_num)
|
||||||
|
|
||||||
assert len(
|
assert len(
|
||||||
ratio_list
|
ratio_list
|
||||||
|
@ -46,34 +42,26 @@ class SimpleDataSet(Dataset):
|
||||||
self.do_shuffle = loader_config['shuffle']
|
self.do_shuffle = loader_config['shuffle']
|
||||||
|
|
||||||
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
||||||
self.data_lines = self.get_image_info_list(label_file_list, ratio_list,
|
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
|
||||||
data_num_per_epoch)
|
|
||||||
self.data_idx_order_list = list(range(len(self.data_lines)))
|
self.data_idx_order_list = list(range(len(self.data_lines)))
|
||||||
if mode.lower() == "train":
|
if mode.lower() == "train":
|
||||||
self.shuffle_data_random()
|
self.shuffle_data_random()
|
||||||
self.ops = create_operators(dataset_config['transforms'], global_config)
|
self.ops = create_operators(dataset_config['transforms'], global_config)
|
||||||
|
|
||||||
def _sample_dataset(self, datas, sample_ratio, data_num_per_epoch=None):
|
def _sample_dataset(self, datas, sample_ratio):
|
||||||
sample_num = round(len(datas) * sample_ratio)
|
sample_num = round(len(datas) * sample_ratio)
|
||||||
|
|
||||||
if data_num_per_epoch is not None:
|
|
||||||
sample_num = int(data_num_per_epoch * sample_ratio)
|
|
||||||
|
|
||||||
nums, rem = int(sample_num // len(datas)), int(sample_num % len(datas))
|
nums, rem = int(sample_num // len(datas)), int(sample_num % len(datas))
|
||||||
return list(datas) * nums + random.sample(datas, rem)
|
return list(datas) * nums + random.sample(datas, rem)
|
||||||
|
|
||||||
def get_image_info_list(self,
|
def get_image_info_list(self, file_list, ratio_list):
|
||||||
file_list,
|
|
||||||
ratio_list,
|
|
||||||
data_num_per_epoch=None):
|
|
||||||
if isinstance(file_list, str):
|
if isinstance(file_list, str):
|
||||||
file_list = [file_list]
|
file_list = [file_list]
|
||||||
data_lines = []
|
data_lines = []
|
||||||
for idx, file in enumerate(file_list):
|
for idx, file in enumerate(file_list):
|
||||||
with open(file, "rb") as f:
|
with open(file, "rb") as f:
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
lines = self._sample_dataset(lines, ratio_list[idx],
|
lines = self._sample_dataset(lines, ratio_list[idx])
|
||||||
data_num_per_epoch)
|
|
||||||
data_lines.extend(lines)
|
data_lines.extend(lines)
|
||||||
return data_lines
|
return data_lines
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue