commit
d40302e6d6
|
@ -14,7 +14,6 @@
|
|||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
import traceback
|
||||
from paddle.io import Dataset
|
||||
|
||||
from .imaug import transform, create_operators
|
||||
|
@ -46,7 +45,6 @@ class SimpleDataSet(Dataset):
|
|||
self.seed = seed
|
||||
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
||||
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
|
||||
self.check_data()
|
||||
self.data_idx_order_list = list(range(len(self.data_lines)))
|
||||
if self.mode == "train" and self.do_shuffle:
|
||||
self.shuffle_data_random()
|
||||
|
@ -103,18 +101,25 @@ class SimpleDataSet(Dataset):
|
|||
|
||||
def __getitem__(self, idx):
|
||||
file_idx = self.data_idx_order_list[idx]
|
||||
data = self.data_lines[file_idx]
|
||||
data_line = self.data_lines[file_idx]
|
||||
try:
|
||||
data_line = data_line.decode('utf-8')
|
||||
substr = data_line.strip("\n").split(self.delimiter)
|
||||
file_name = substr[0]
|
||||
label = substr[1]
|
||||
img_path = os.path.join(self.data_dir, file_name)
|
||||
data = {'img_path': img_path, 'label': label}
|
||||
if not os.path.exists(img_path):
|
||||
raise Exception("{} does not exist!".format(img_path))
|
||||
with open(data['img_path'], 'rb') as f:
|
||||
img = f.read()
|
||||
data['image'] = img
|
||||
data['ext_data'] = self.get_ext_data()
|
||||
outs = transform(data, self.ops)
|
||||
except:
|
||||
error_meg = traceback.format_exc()
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
"When parsing file {} and label {}, error happened with msg: {}".format(
|
||||
data['img_path'],data['label'], error_meg))
|
||||
"When parsing line {}, error happened with msg: {}".format(
|
||||
data_line, e))
|
||||
outs = None
|
||||
if outs is None:
|
||||
# during evaluation, we should fix the idx to get same results for many times of evaluation.
|
||||
|
@ -125,17 +130,3 @@ class SimpleDataSet(Dataset):
|
|||
|
||||
def __len__(self):
|
||||
return len(self.data_idx_order_list)
|
||||
|
||||
def check_data(self):
|
||||
new_data_lines = []
|
||||
for data_line in self.data_lines:
|
||||
data_line = data_line.decode('utf-8')
|
||||
substr = data_line.strip("\n").strip("\r").split(self.delimiter)
|
||||
file_name = substr[0]
|
||||
label = substr[1]
|
||||
img_path = os.path.join(self.data_dir, file_name)
|
||||
if os.path.exists(img_path):
|
||||
new_data_lines.append({'img_path': img_path, 'label': label})
|
||||
else:
|
||||
self.logger.info("{} does not exist!".format(img_path))
|
||||
self.data_lines = new_data_lines
|
|
@ -46,7 +46,7 @@ class DistillationModel(nn.Layer):
|
|||
pretrained = model_config.pop("pretrained")
|
||||
model = BaseModel(model_config)
|
||||
if pretrained is not None:
|
||||
model = load_pretrained_params(model, pretrained)
|
||||
load_pretrained_params(model, pretrained)
|
||||
if freeze_params:
|
||||
for param in model.parameters():
|
||||
param.trainable = False
|
||||
|
|
Loading…
Reference in New Issue