add config
This commit is contained in:
parent
40bf3b1053
commit
48898ac357
|
@ -0,0 +1,194 @@
|
||||||
|
Global:
|
||||||
|
use_gpu: true
|
||||||
|
epoch_num: 1200
|
||||||
|
log_smooth_window: 20
|
||||||
|
print_batch_step: 2
|
||||||
|
save_model_dir: ./output/ch_db_mv3/
|
||||||
|
save_epoch_step: 1200
|
||||||
|
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||||
|
eval_batch_step: [3000, 2000]
|
||||||
|
cal_metric_during_train: False
|
||||||
|
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||||
|
checkpoints:
|
||||||
|
save_inference_dir:
|
||||||
|
use_visualdl: False
|
||||||
|
infer_img: doc/imgs_en/img_10.jpg
|
||||||
|
save_res_path: ./output/det_db/predicts_db.txt
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
name: DistillationModel
|
||||||
|
algorithm: Distillation
|
||||||
|
Models:
|
||||||
|
Student:
|
||||||
|
pretrained:
|
||||||
|
freeze_params: false
|
||||||
|
return_all_feats: false
|
||||||
|
model_type: det
|
||||||
|
algorithm: DB
|
||||||
|
Backbone:
|
||||||
|
name: MobileNetV3
|
||||||
|
scale: 0.5
|
||||||
|
model_name: large
|
||||||
|
disable_se: True
|
||||||
|
Neck:
|
||||||
|
name: DBFPN
|
||||||
|
out_channels: 96
|
||||||
|
Head:
|
||||||
|
name: DBHead
|
||||||
|
k: 50
|
||||||
|
Student2:
|
||||||
|
pretrained:
|
||||||
|
freeze_params: false
|
||||||
|
return_all_feats: false
|
||||||
|
model_type: det
|
||||||
|
algorithm: DB
|
||||||
|
Transform:
|
||||||
|
Backbone:
|
||||||
|
name: MobileNetV3
|
||||||
|
scale: 0.5
|
||||||
|
model_name: large
|
||||||
|
disable_se: True
|
||||||
|
Neck:
|
||||||
|
name: DBFPN
|
||||||
|
out_channels: 96
|
||||||
|
Head:
|
||||||
|
name: DBHead
|
||||||
|
k: 50
|
||||||
|
Teacher:
|
||||||
|
model_type: det
|
||||||
|
algorithm: DB
|
||||||
|
Transform:
|
||||||
|
Backbone:
|
||||||
|
name: ResNet
|
||||||
|
layers: 18
|
||||||
|
Neck:
|
||||||
|
name: DBFPN
|
||||||
|
out_channels: 256
|
||||||
|
Head:
|
||||||
|
name: DBHead
|
||||||
|
k: 50
|
||||||
|
|
||||||
|
Loss:
|
||||||
|
name: CombinedLoss
|
||||||
|
loss_config_list:
|
||||||
|
- DistillationDilaDBLoss:
|
||||||
|
weight: 1.0
|
||||||
|
model_name_list: ["Student", "Student2", "Teacher"]
|
||||||
|
key: maps
|
||||||
|
balance_loss: true
|
||||||
|
main_loss_type: DiceLoss
|
||||||
|
alpha: 5
|
||||||
|
beta: 10
|
||||||
|
ohem_ratio: 3
|
||||||
|
- DistillationDMLLoss:
|
||||||
|
maps_name: ["thrink_maps"]
|
||||||
|
weight: 1.0
|
||||||
|
act: "softmax"
|
||||||
|
model_name_pairs: ["Student", "Student2"]
|
||||||
|
key: maps
|
||||||
|
- DistillationDBLoss:
|
||||||
|
model_name_list: ["Student", "Teacher"]
|
||||||
|
key: maps
|
||||||
|
name: DBLoss
|
||||||
|
balance_loss: true
|
||||||
|
main_loss_type: DiceLoss
|
||||||
|
alpha: 5
|
||||||
|
beta: 10
|
||||||
|
ohem_ratio: 3
|
||||||
|
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: Adam
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
lr:
|
||||||
|
name: Cosine
|
||||||
|
learning_rate: 0.001
|
||||||
|
warmup_epoch: 2
|
||||||
|
regularizer:
|
||||||
|
name: 'L2'
|
||||||
|
factor: 0
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
name: DistillationCTDBPostProcessCLabelDecode
|
||||||
|
model_name: ["Student", "Student2"]
|
||||||
|
key: head_out
|
||||||
|
thresh: 0.3
|
||||||
|
box_thresh: 0.6
|
||||||
|
max_candidates: 1000
|
||||||
|
unclip_ratio: 1.5
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
name: DistillationMetric
|
||||||
|
base_metric_name: DetMetric
|
||||||
|
main_indicator: hmean
|
||||||
|
key: "Student"
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: ./train_data/icdar2015/text_localization/
|
||||||
|
label_file_list:
|
||||||
|
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||||
|
ratio_list: [1.0]
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- DetLabelEncode: # Class handling label
|
||||||
|
- IaaAugment:
|
||||||
|
augmenter_args:
|
||||||
|
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
|
||||||
|
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
|
||||||
|
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
|
||||||
|
- EastRandomCropData:
|
||||||
|
size: [960, 960]
|
||||||
|
max_tries: 50
|
||||||
|
keep_ratio: true
|
||||||
|
- MakeBorderMap:
|
||||||
|
shrink_ratio: 0.4
|
||||||
|
thresh_min: 0.3
|
||||||
|
thresh_max: 0.7
|
||||||
|
- MakeShrinkMap:
|
||||||
|
shrink_ratio: 0.4
|
||||||
|
min_text_size: 8
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1./255.
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: 'hwc'
|
||||||
|
- ToCHWImage:
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
|
||||||
|
loader:
|
||||||
|
shuffle: True
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 8
|
||||||
|
num_workers: 4
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: ./train_data/icdar2015/text_localization/
|
||||||
|
label_file_list:
|
||||||
|
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- DetLabelEncode: # Class handling label
|
||||||
|
- DetResizeForTest:
|
||||||
|
# image_shape: [736, 1280]
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1./255.
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: 'hwc'
|
||||||
|
- ToCHWImage:
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 1 # must be 1
|
||||||
|
num_workers: 2
|
|
@ -132,6 +132,96 @@ class DistillationCTCLoss(CTCLoss):
|
||||||
return loss_dict
|
return loss_dict
|
||||||
|
|
||||||
|
|
||||||
|
class DistillationDBLoss(DBLoss):
|
||||||
|
def __init__(self,
|
||||||
|
model_name_list=[],
|
||||||
|
balance_loss=True,
|
||||||
|
main_loss_type='DiceLoss',
|
||||||
|
alpha=5,
|
||||||
|
beta=10,
|
||||||
|
ohem_ratio=3,
|
||||||
|
eps=1e-6,
|
||||||
|
name="db_loss",
|
||||||
|
**kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.model_name_list = model_name_list
|
||||||
|
self.name = name
|
||||||
|
self.key = None
|
||||||
|
|
||||||
|
def forward(self, preicts, batch):
|
||||||
|
loss_dict = {}
|
||||||
|
for idx, model_name in enumerate(self.model_name_list):
|
||||||
|
out = predicts[model_name]
|
||||||
|
if self.key is not None:
|
||||||
|
out = out[self.key]
|
||||||
|
loss = super().forward(out, batch)
|
||||||
|
|
||||||
|
if isinstance(loss, dict):
|
||||||
|
for key in loss.keys():
|
||||||
|
if key == "loss":
|
||||||
|
continue
|
||||||
|
name = "{}_{}_{}".format(self.name, model_name, key)
|
||||||
|
loss_dict[name] = loss[key]
|
||||||
|
else:
|
||||||
|
loss_dict["{}_{}".format(self.name, model_name)] = loss
|
||||||
|
|
||||||
|
loss_dict = _sum_loss(loss_dict)
|
||||||
|
return loss_dict
|
||||||
|
|
||||||
|
|
||||||
|
class DistillationDilaDBLoss(DBLoss):
|
||||||
|
def __init__(self,
|
||||||
|
model_name_pairs=[],
|
||||||
|
balance_loss=True,
|
||||||
|
main_loss_type='DiceLoss',
|
||||||
|
alpha=5,
|
||||||
|
beta=10,
|
||||||
|
ohem_ratio=3,
|
||||||
|
eps=1e-6,
|
||||||
|
name="dila_dbloss"):
|
||||||
|
super().__init__()
|
||||||
|
self.model_name_pairs = model_name_pairs
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def forward(self, predicts, batch):
|
||||||
|
loss_dict = dict()
|
||||||
|
for idx, pair in enumerate(self.model_name_pairs):
|
||||||
|
stu_outs = predicts[pair[0]]
|
||||||
|
tch_outs = predicts[pair[1]]
|
||||||
|
if self.key is not None:
|
||||||
|
stu_preds = stu_outs[self.key]
|
||||||
|
tch_preds = tch_outs[self.key]
|
||||||
|
|
||||||
|
stu_shrink_maps = stu_preds[:, 0, :, :]
|
||||||
|
stu_binary_maps = stu_preds[:, 2, :, :]
|
||||||
|
|
||||||
|
# dilation to teacher prediction
|
||||||
|
dilation_w = np.array([[1, 1], [1, 1]])
|
||||||
|
th_shrink_maps = tch_preds[:, 0, :, :]
|
||||||
|
th_shrink_maps = th_shrink_maps.numpy() > 0.3 # thresh = 0.3
|
||||||
|
dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32)
|
||||||
|
for i in range(th_shrink_maps.shape[0]):
|
||||||
|
dilate_maps[i] = cv2.dilate(
|
||||||
|
th_shrink_maps[i, :, :].astype(np.uint8), dilation_w)
|
||||||
|
th_shrink_maps = paddle.to_tensor(dilate_maps)
|
||||||
|
|
||||||
|
label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = batch[
|
||||||
|
1:]
|
||||||
|
|
||||||
|
# calculate the shrink map loss
|
||||||
|
bce_loss = self.alpha * self.bce_loss(
|
||||||
|
stu_shrink_maps, th_shrink_maps, label_shrink_mask)
|
||||||
|
loss_binary_maps = self.dice_loss(stu_binary_maps, th_shrink_maps,
|
||||||
|
label_shrink_mask)
|
||||||
|
|
||||||
|
# k = f"{self.name}_{pair[0]}_{pair[1]}"
|
||||||
|
k = "{}_{}_{}".format(self.name, pair[0], pair[1])
|
||||||
|
loss_dict[k] = bce_loss + loss_binary_maps
|
||||||
|
|
||||||
|
loss_dict = _sum_loss(loss_dict)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
class DistillationDistanceLoss(DistanceLoss):
|
class DistillationDistanceLoss(DistanceLoss):
|
||||||
"""
|
"""
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -79,7 +79,7 @@ class BaseModel(nn.Layer):
|
||||||
x = self.neck(x)
|
x = self.neck(x)
|
||||||
y["neck_out"] = x
|
y["neck_out"] = x
|
||||||
x = self.head(x, targets=data)
|
x = self.head(x, targets=data)
|
||||||
if type(x) is dict:
|
if isinstance(x, dict):
|
||||||
y.update(x)
|
y.update(x)
|
||||||
else:
|
else:
|
||||||
y["head_out"] = x
|
y["head_out"] = x
|
||||||
|
|
Loading…
Reference in New Issue