fix rec distillation (#3994)
* fix rec distillation * add dist cfg * fix yaml
This commit is contained in:
parent
51f4a2c375
commit
7dc56191f5
|
@ -4,7 +4,7 @@ Global:
|
||||||
epoch_num: 800
|
epoch_num: 800
|
||||||
log_smooth_window: 20
|
log_smooth_window: 20
|
||||||
print_batch_step: 10
|
print_batch_step: 10
|
||||||
save_model_dir: ./output/rec_chinese_lite_distillation_v2.1
|
save_model_dir: ./output/rec_mobile_pp-OCRv2
|
||||||
save_epoch_step: 3
|
save_epoch_step: 3
|
||||||
eval_batch_step: [0, 2000]
|
eval_batch_step: [0, 2000]
|
||||||
cal_metric_during_train: true
|
cal_metric_during_train: true
|
||||||
|
@ -19,7 +19,7 @@ Global:
|
||||||
infer_mode: false
|
infer_mode: false
|
||||||
use_space_char: true
|
use_space_char: true
|
||||||
distributed: true
|
distributed: true
|
||||||
save_res_path: ./output/rec/predicts_chinese_lite_distillation_v2.1.txt
|
save_res_path: ./output/rec/predicts_mobile_pp-OCRv2.txt
|
||||||
|
|
||||||
|
|
||||||
Optimizer:
|
Optimizer:
|
||||||
|
@ -35,79 +35,32 @@ Optimizer:
|
||||||
name: L2
|
name: L2
|
||||||
factor: 2.0e-05
|
factor: 2.0e-05
|
||||||
|
|
||||||
Architecture:
|
|
||||||
model_type: &model_type "rec"
|
|
||||||
name: DistillationModel
|
|
||||||
algorithm: Distillation
|
|
||||||
Models:
|
|
||||||
Teacher:
|
|
||||||
pretrained:
|
|
||||||
freeze_params: false
|
|
||||||
return_all_feats: true
|
|
||||||
model_type: *model_type
|
|
||||||
algorithm: CRNN
|
|
||||||
Transform:
|
|
||||||
Backbone:
|
|
||||||
name: MobileNetV1Enhance
|
|
||||||
scale: 0.5
|
|
||||||
Neck:
|
|
||||||
name: SequenceEncoder
|
|
||||||
encoder_type: rnn
|
|
||||||
hidden_size: 64
|
|
||||||
Head:
|
|
||||||
name: CTCHead
|
|
||||||
mid_channels: 96
|
|
||||||
fc_decay: 0.00002
|
|
||||||
Student:
|
|
||||||
pretrained:
|
|
||||||
freeze_params: false
|
|
||||||
return_all_feats: true
|
|
||||||
model_type: *model_type
|
|
||||||
algorithm: CRNN
|
|
||||||
Transform:
|
|
||||||
Backbone:
|
|
||||||
name: MobileNetV1Enhance
|
|
||||||
scale: 0.5
|
|
||||||
Neck:
|
|
||||||
name: SequenceEncoder
|
|
||||||
encoder_type: rnn
|
|
||||||
hidden_size: 64
|
|
||||||
Head:
|
|
||||||
name: CTCHead
|
|
||||||
mid_channels: 96
|
|
||||||
fc_decay: 0.00002
|
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
model_type: rec
|
||||||
|
algorithm: CRNN
|
||||||
|
Transform:
|
||||||
|
Backbone:
|
||||||
|
name: MobileNetV1Enhance
|
||||||
|
scale: 0.5
|
||||||
|
Neck:
|
||||||
|
name: SequenceEncoder
|
||||||
|
encoder_type: rnn
|
||||||
|
hidden_size: 64
|
||||||
|
Head:
|
||||||
|
name: CTCHead
|
||||||
|
mid_channels: 96
|
||||||
|
fc_decay: 0.00002
|
||||||
|
|
||||||
Loss:
|
Loss:
|
||||||
name: CombinedLoss
|
name: CTCLoss
|
||||||
loss_config_list:
|
|
||||||
- DistillationCTCLoss:
|
|
||||||
weight: 1.0
|
|
||||||
model_name_list: ["Student", "Teacher"]
|
|
||||||
key: head_out
|
|
||||||
- DistillationDMLLoss:
|
|
||||||
weight: 1.0
|
|
||||||
act: "softmax"
|
|
||||||
model_name_pairs:
|
|
||||||
- ["Student", "Teacher"]
|
|
||||||
key: head_out
|
|
||||||
- DistillationDistanceLoss:
|
|
||||||
weight: 1.0
|
|
||||||
mode: "l2"
|
|
||||||
model_name_pairs:
|
|
||||||
- ["Student", "Teacher"]
|
|
||||||
key: backbone_out
|
|
||||||
|
|
||||||
PostProcess:
|
PostProcess:
|
||||||
name: DistillationCTCLabelDecode
|
name: CTCLabelDecode
|
||||||
model_name: ["Student", "Teacher"]
|
|
||||||
key: head_out
|
|
||||||
|
|
||||||
Metric:
|
Metric:
|
||||||
name: DistillationMetric
|
name: RecMetric
|
||||||
base_metric_name: RecMetric
|
|
||||||
main_indicator: acc
|
main_indicator: acc
|
||||||
key: "Student"
|
|
||||||
|
|
||||||
Train:
|
Train:
|
||||||
dataset:
|
dataset:
|
||||||
|
@ -132,7 +85,6 @@ Train:
|
||||||
shuffle: true
|
shuffle: true
|
||||||
batch_size_per_card: 128
|
batch_size_per_card: 128
|
||||||
drop_last: true
|
drop_last: true
|
||||||
num_sections: 1
|
|
||||||
num_workers: 8
|
num_workers: 8
|
||||||
Eval:
|
Eval:
|
||||||
dataset:
|
dataset:
|
||||||
|
|
|
@ -0,0 +1,160 @@
|
||||||
|
Global:
|
||||||
|
debug: false
|
||||||
|
use_gpu: true
|
||||||
|
epoch_num: 800
|
||||||
|
log_smooth_window: 20
|
||||||
|
print_batch_step: 10
|
||||||
|
save_model_dir: ./output/rec_pp-OCRv2_distillation
|
||||||
|
save_epoch_step: 3
|
||||||
|
eval_batch_step: [0, 2000]
|
||||||
|
cal_metric_during_train: true
|
||||||
|
pretrained_model:
|
||||||
|
checkpoints:
|
||||||
|
save_inference_dir:
|
||||||
|
use_visualdl: false
|
||||||
|
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||||
|
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
|
||||||
|
character_type: ch
|
||||||
|
max_text_length: 25
|
||||||
|
infer_mode: false
|
||||||
|
use_space_char: true
|
||||||
|
distributed: true
|
||||||
|
save_res_path: ./output/rec/predicts_pp-OCRv2_distillation.txt
|
||||||
|
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: Adam
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
lr:
|
||||||
|
name: Piecewise
|
||||||
|
decay_epochs : [700, 800]
|
||||||
|
values : [0.001, 0.0001]
|
||||||
|
warmup_epoch: 5
|
||||||
|
regularizer:
|
||||||
|
name: L2
|
||||||
|
factor: 2.0e-05
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
model_type: &model_type "rec"
|
||||||
|
name: DistillationModel
|
||||||
|
algorithm: Distillation
|
||||||
|
Models:
|
||||||
|
Teacher:
|
||||||
|
pretrained:
|
||||||
|
freeze_params: false
|
||||||
|
return_all_feats: true
|
||||||
|
model_type: *model_type
|
||||||
|
algorithm: CRNN
|
||||||
|
Transform:
|
||||||
|
Backbone:
|
||||||
|
name: MobileNetV1Enhance
|
||||||
|
scale: 0.5
|
||||||
|
Neck:
|
||||||
|
name: SequenceEncoder
|
||||||
|
encoder_type: rnn
|
||||||
|
hidden_size: 64
|
||||||
|
Head:
|
||||||
|
name: CTCHead
|
||||||
|
mid_channels: 96
|
||||||
|
fc_decay: 0.00002
|
||||||
|
Student:
|
||||||
|
pretrained:
|
||||||
|
freeze_params: false
|
||||||
|
return_all_feats: true
|
||||||
|
model_type: *model_type
|
||||||
|
algorithm: CRNN
|
||||||
|
Transform:
|
||||||
|
Backbone:
|
||||||
|
name: MobileNetV1Enhance
|
||||||
|
scale: 0.5
|
||||||
|
Neck:
|
||||||
|
name: SequenceEncoder
|
||||||
|
encoder_type: rnn
|
||||||
|
hidden_size: 64
|
||||||
|
Head:
|
||||||
|
name: CTCHead
|
||||||
|
mid_channels: 96
|
||||||
|
fc_decay: 0.00002
|
||||||
|
|
||||||
|
|
||||||
|
Loss:
|
||||||
|
name: CombinedLoss
|
||||||
|
loss_config_list:
|
||||||
|
- DistillationCTCLoss:
|
||||||
|
weight: 1.0
|
||||||
|
model_name_list: ["Student", "Teacher"]
|
||||||
|
key: head_out
|
||||||
|
- DistillationDMLLoss:
|
||||||
|
weight: 1.0
|
||||||
|
act: "softmax"
|
||||||
|
use_log: true
|
||||||
|
model_name_pairs:
|
||||||
|
- ["Student", "Teacher"]
|
||||||
|
key: head_out
|
||||||
|
- DistillationDistanceLoss:
|
||||||
|
weight: 1.0
|
||||||
|
mode: "l2"
|
||||||
|
model_name_pairs:
|
||||||
|
- ["Student", "Teacher"]
|
||||||
|
key: backbone_out
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
name: DistillationCTCLabelDecode
|
||||||
|
model_name: ["Student", "Teacher"]
|
||||||
|
key: head_out
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
name: DistillationMetric
|
||||||
|
base_metric_name: RecMetric
|
||||||
|
main_indicator: acc
|
||||||
|
key: "Student"
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: ./train_data/
|
||||||
|
label_file_list:
|
||||||
|
- ./train_data/train_list.txt
|
||||||
|
transforms:
|
||||||
|
- DecodeImage:
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: false
|
||||||
|
- RecAug:
|
||||||
|
- CTCLabelEncode:
|
||||||
|
- RecResizeImg:
|
||||||
|
image_shape: [3, 32, 320]
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys:
|
||||||
|
- image
|
||||||
|
- label
|
||||||
|
- length
|
||||||
|
loader:
|
||||||
|
shuffle: true
|
||||||
|
batch_size_per_card: 128
|
||||||
|
drop_last: true
|
||||||
|
num_sections: 1
|
||||||
|
num_workers: 8
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: ./train_data
|
||||||
|
label_file_list:
|
||||||
|
- ./train_data/val_list.txt
|
||||||
|
transforms:
|
||||||
|
- DecodeImage:
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: false
|
||||||
|
- CTCLabelEncode:
|
||||||
|
- RecResizeImg:
|
||||||
|
image_shape: [3, 32, 320]
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys:
|
||||||
|
- image
|
||||||
|
- label
|
||||||
|
- length
|
||||||
|
loader:
|
||||||
|
shuffle: false
|
||||||
|
drop_last: false
|
||||||
|
batch_size_per_card: 128
|
||||||
|
num_workers: 8
|
|
@ -39,7 +39,7 @@ PaddleOCR中集成了知识蒸馏的算法,具体地,有以下几个主要
|
||||||
|
|
||||||
### 2.1 识别配置文件解析
|
### 2.1 识别配置文件解析
|
||||||
|
|
||||||
配置文件在[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml)。
|
配置文件在[ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml)。
|
||||||
|
|
||||||
#### 2.1.1 模型结构
|
#### 2.1.1 模型结构
|
||||||
|
|
||||||
|
@ -246,6 +246,39 @@ Metric:
|
||||||
关于`DistillationMetric`更加具体的实现可以参考: [distillation_metric.py](../../ppocr/metrics/distillation_metric.py#L24)。
|
关于`DistillationMetric`更加具体的实现可以参考: [distillation_metric.py](../../ppocr/metrics/distillation_metric.py#L24)。
|
||||||
|
|
||||||
|
|
||||||
|
#### 2.1.5 蒸馏模型微调
|
||||||
|
|
||||||
|
对蒸馏得到的识别蒸馏进行微调有2种方式。
|
||||||
|
|
||||||
|
(1)基于知识蒸馏的微调:这种情况比较简单,下载预训练模型,在[ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml)中配置好预训练模型路径以及自己的数据路径,即可进行模型微调训练。
|
||||||
|
|
||||||
|
(2)微调时不使用知识蒸馏:这种情况,需要首先将预训练模型中的学生模型参数提取出来,具体步骤如下。
|
||||||
|
|
||||||
|
* 首先下载预训练模型并解压。
|
||||||
|
```shell
|
||||||
|
# 下面预训练模型并解压
|
||||||
|
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar
|
||||||
|
tar -xf ch_PP-OCRv2_rec_train.tar
|
||||||
|
```
|
||||||
|
|
||||||
|
* 然后使用python,对其中的学生模型参数进行提取
|
||||||
|
|
||||||
|
```python
|
||||||
|
import paddle
|
||||||
|
# 加载预训练模型
|
||||||
|
all_params = paddle.load("ch_PP-OCRv2_rec_train/best_accuracy.pdparams")
|
||||||
|
# 查看权重参数的keys
|
||||||
|
print(all_params.keys())
|
||||||
|
# 学生模型的权重提取
|
||||||
|
s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key}
|
||||||
|
# 查看学生模型权重参数的keys
|
||||||
|
print(s_params.keys())
|
||||||
|
# 保存
|
||||||
|
paddle.save(s_params, "ch_PP-OCRv2_rec_train/student.pdparams")
|
||||||
|
```
|
||||||
|
|
||||||
|
转化完成之后,使用[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml),修改预训练模型的路径(为导出的`student.pdparams`模型路径)以及自己的数据路径,即可进行模型微调。
|
||||||
|
|
||||||
### 2.2 检测配置文件解析
|
### 2.2 检测配置文件解析
|
||||||
|
|
||||||
* coming soon!
|
* coming soon!
|
||||||
|
|
|
@ -56,31 +56,34 @@ class CELoss(nn.Layer):
|
||||||
|
|
||||||
class KLJSLoss(object):
|
class KLJSLoss(object):
|
||||||
def __init__(self, mode='kl'):
|
def __init__(self, mode='kl'):
|
||||||
assert mode in ['kl', 'js', 'KL', 'JS'], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
|
assert mode in ['kl', 'js', 'KL', 'JS'
|
||||||
|
], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
|
|
||||||
def __call__(self, p1, p2, reduction="mean"):
|
def __call__(self, p1, p2, reduction="mean"):
|
||||||
|
|
||||||
loss = paddle.multiply(p2, paddle.log( (p2+1e-5)/(p1+1e-5) + 1e-5))
|
loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
|
||||||
|
|
||||||
if self.mode.lower() == "js":
|
if self.mode.lower() == "js":
|
||||||
loss += paddle.multiply(p1, paddle.log((p1+1e-5)/(p2+1e-5) + 1e-5))
|
loss += paddle.multiply(
|
||||||
|
p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
|
||||||
loss *= 0.5
|
loss *= 0.5
|
||||||
if reduction == "mean":
|
if reduction == "mean":
|
||||||
loss = paddle.mean(loss, axis=[1,2])
|
loss = paddle.mean(loss, axis=[1, 2])
|
||||||
elif reduction=="none" or reduction is None:
|
elif reduction == "none" or reduction is None:
|
||||||
return loss
|
return loss
|
||||||
else:
|
else:
|
||||||
loss = paddle.sum(loss, axis=[1,2])
|
loss = paddle.sum(loss, axis=[1, 2])
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
class DMLLoss(nn.Layer):
|
class DMLLoss(nn.Layer):
|
||||||
"""
|
"""
|
||||||
DMLLoss
|
DMLLoss
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, act=None):
|
def __init__(self, act=None, use_log=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if act is not None:
|
if act is not None:
|
||||||
assert act in ["softmax", "sigmoid"]
|
assert act in ["softmax", "sigmoid"]
|
||||||
|
@ -91,19 +94,23 @@ class DMLLoss(nn.Layer):
|
||||||
else:
|
else:
|
||||||
self.act = None
|
self.act = None
|
||||||
|
|
||||||
|
self.use_log = use_log
|
||||||
|
|
||||||
self.jskl_loss = KLJSLoss(mode="js")
|
self.jskl_loss = KLJSLoss(mode="js")
|
||||||
|
|
||||||
def forward(self, out1, out2):
|
def forward(self, out1, out2):
|
||||||
if self.act is not None:
|
if self.act is not None:
|
||||||
out1 = self.act(out1)
|
out1 = self.act(out1)
|
||||||
out2 = self.act(out2)
|
out2 = self.act(out2)
|
||||||
if len(out1.shape) < 2:
|
if self.use_log:
|
||||||
|
# for recognition distillation, log is needed for feature map
|
||||||
log_out1 = paddle.log(out1)
|
log_out1 = paddle.log(out1)
|
||||||
log_out2 = paddle.log(out2)
|
log_out2 = paddle.log(out2)
|
||||||
loss = (F.kl_div(
|
loss = (F.kl_div(
|
||||||
log_out1, out2, reduction='batchmean') + F.kl_div(
|
log_out1, out2, reduction='batchmean') + F.kl_div(
|
||||||
log_out2, out1, reduction='batchmean')) / 2.0
|
log_out2, out1, reduction='batchmean')) / 2.0
|
||||||
else:
|
else:
|
||||||
|
# for detection distillation log is not needed
|
||||||
loss = self.jskl_loss(out1, out2)
|
loss = self.jskl_loss(out1, out2)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
|
@ -49,11 +49,15 @@ class CombinedLoss(nn.Layer):
|
||||||
loss = loss_func(input, batch, **kargs)
|
loss = loss_func(input, batch, **kargs)
|
||||||
if isinstance(loss, paddle.Tensor):
|
if isinstance(loss, paddle.Tensor):
|
||||||
loss = {"loss_{}_{}".format(str(loss), idx): loss}
|
loss = {"loss_{}_{}".format(str(loss), idx): loss}
|
||||||
|
|
||||||
weight = self.loss_weight[idx]
|
weight = self.loss_weight[idx]
|
||||||
for key in loss.keys():
|
|
||||||
if key == "loss":
|
loss = {key: loss[key] * weight for key in loss}
|
||||||
loss_all += loss[key] * weight
|
|
||||||
else:
|
if "loss" in loss:
|
||||||
loss_dict["{}_{}".format(key, idx)] = loss[key]
|
loss_all += loss["loss"]
|
||||||
|
else:
|
||||||
|
loss_all += paddle.add_n(list(loss.values()))
|
||||||
|
loss_dict.update(loss)
|
||||||
loss_dict["loss"] = loss_all
|
loss_dict["loss"] = loss_all
|
||||||
return loss_dict
|
return loss_dict
|
||||||
|
|
|
@ -44,10 +44,11 @@ class DistillationDMLLoss(DMLLoss):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
model_name_pairs=[],
|
model_name_pairs=[],
|
||||||
act=None,
|
act=None,
|
||||||
|
use_log=False,
|
||||||
key=None,
|
key=None,
|
||||||
maps_name=None,
|
maps_name=None,
|
||||||
name="dml"):
|
name="dml"):
|
||||||
super().__init__(act=act)
|
super().__init__(act=act, use_log=use_log)
|
||||||
assert isinstance(model_name_pairs, list)
|
assert isinstance(model_name_pairs, list)
|
||||||
self.key = key
|
self.key = key
|
||||||
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
|
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
|
||||||
|
@ -57,7 +58,8 @@ class DistillationDMLLoss(DMLLoss):
|
||||||
def _check_model_name_pairs(self, model_name_pairs):
|
def _check_model_name_pairs(self, model_name_pairs):
|
||||||
if not isinstance(model_name_pairs, list):
|
if not isinstance(model_name_pairs, list):
|
||||||
return []
|
return []
|
||||||
elif isinstance(model_name_pairs[0], list) and isinstance(model_name_pairs[0][0], str):
|
elif isinstance(model_name_pairs[0], list) and isinstance(
|
||||||
|
model_name_pairs[0][0], str):
|
||||||
return model_name_pairs
|
return model_name_pairs
|
||||||
else:
|
else:
|
||||||
return [model_name_pairs]
|
return [model_name_pairs]
|
||||||
|
@ -112,8 +114,8 @@ class DistillationDMLLoss(DMLLoss):
|
||||||
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
|
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
|
||||||
0], pair[1], map_name, idx)] = loss[key]
|
0], pair[1], map_name, idx)] = loss[key]
|
||||||
else:
|
else:
|
||||||
loss_dict["{}_{}_{}".format(self.name, self.maps_name[_c],
|
loss_dict["{}_{}_{}".format(self.name, self.maps_name[
|
||||||
idx)] = loss
|
_c], idx)] = loss
|
||||||
|
|
||||||
loss_dict = _sum_loss(loss_dict)
|
loss_dict = _sum_loss(loss_dict)
|
||||||
|
|
||||||
|
|
|
@ -108,14 +108,15 @@ def load_dygraph_params(config, model, logger, optimizer):
|
||||||
for k1, k2 in zip(state_dict.keys(), params.keys()):
|
for k1, k2 in zip(state_dict.keys(), params.keys()):
|
||||||
if list(state_dict[k1].shape) == list(params[k2].shape):
|
if list(state_dict[k1].shape) == list(params[k2].shape):
|
||||||
new_state_dict[k1] = params[k2]
|
new_state_dict[k1] = params[k2]
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
|
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
|
||||||
)
|
)
|
||||||
model.set_state_dict(new_state_dict)
|
model.set_state_dict(new_state_dict)
|
||||||
logger.info(f"loaded pretrained_model successful from {pm}")
|
logger.info(f"loaded pretrained_model successful from {pm}")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def load_pretrained_params(model, path):
|
def load_pretrained_params(model, path):
|
||||||
if path is None:
|
if path is None:
|
||||||
return False
|
return False
|
||||||
|
@ -138,6 +139,7 @@ def load_pretrained_params(model, path):
|
||||||
print(f"load pretrain successful from {path}")
|
print(f"load pretrain successful from {path}")
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def save_model(model,
|
def save_model(model,
|
||||||
optimizer,
|
optimizer,
|
||||||
model_path,
|
model_path,
|
||||||
|
|
Loading…
Reference in New Issue