delete data_num_per_epoch
This commit is contained in:
parent
e23c4de5d8
commit
b8ba703548
|
@ -27,17 +27,13 @@ class SimpleDataSet(Dataset):
|
|||
global_config = config['Global']
|
||||
dataset_config = config[mode]['dataset']
|
||||
loader_config = config[mode]['loader']
|
||||
if 'data_num_per_epoch' in loader_config.keys():
|
||||
data_num_per_epoch = loader_config['data_num_per_epoch']
|
||||
else:
|
||||
data_num_per_epoch = None
|
||||
|
||||
self.delimiter = dataset_config.get('delimiter', '\t')
|
||||
label_file_list = dataset_config.pop('label_file_list')
|
||||
data_source_num = len(label_file_list)
|
||||
ratio_list = dataset_config.get("ratio_list", [1.0])
|
||||
if isinstance(ratio_list, (float, int)):
|
||||
ratio_list = [float(ratio_list)] * len(data_source_num)
|
||||
ratio_list = [float(ratio_list)] * int(data_source_num)
|
||||
|
||||
assert len(
|
||||
ratio_list
|
||||
|
@ -46,34 +42,26 @@ class SimpleDataSet(Dataset):
|
|||
self.do_shuffle = loader_config['shuffle']
|
||||
|
||||
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
||||
self.data_lines = self.get_image_info_list(label_file_list, ratio_list,
|
||||
data_num_per_epoch)
|
||||
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
|
||||
self.data_idx_order_list = list(range(len(self.data_lines)))
|
||||
if mode.lower() == "train":
|
||||
self.shuffle_data_random()
|
||||
self.ops = create_operators(dataset_config['transforms'], global_config)
|
||||
|
||||
def _sample_dataset(self, datas, sample_ratio, data_num_per_epoch=None):
|
||||
def _sample_dataset(self, datas, sample_ratio):
|
||||
sample_num = round(len(datas) * sample_ratio)
|
||||
|
||||
if data_num_per_epoch is not None:
|
||||
sample_num = int(data_num_per_epoch * sample_ratio)
|
||||
|
||||
nums, rem = int(sample_num // len(datas)), int(sample_num % len(datas))
|
||||
return list(datas) * nums + random.sample(datas, rem)
|
||||
|
||||
def get_image_info_list(self,
|
||||
file_list,
|
||||
ratio_list,
|
||||
data_num_per_epoch=None):
|
||||
def get_image_info_list(self, file_list, ratio_list):
|
||||
if isinstance(file_list, str):
|
||||
file_list = [file_list]
|
||||
data_lines = []
|
||||
for idx, file in enumerate(file_list):
|
||||
with open(file, "rb") as f:
|
||||
lines = f.readlines()
|
||||
lines = self._sample_dataset(lines, ratio_list[idx],
|
||||
data_num_per_epoch)
|
||||
lines = self._sample_dataset(lines, ratio_list[idx])
|
||||
data_lines.extend(lines)
|
||||
return data_lines
|
||||
|
||||
|
|
Loading…
Reference in New Issue