fix seed typo
This commit is contained in:
commit
d5e6df05ca
|
@ -0,0 +1,126 @@
|
|||
Global:
|
||||
debug: false
|
||||
use_gpu: true
|
||||
epoch_num: 800
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec_mobile_pp-OCRv2_enhanced_ctc_loss
|
||||
save_epoch_step: 3
|
||||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: true
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: false
|
||||
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
|
||||
character_type: ch
|
||||
max_text_length: 25
|
||||
infer_mode: false
|
||||
use_space_char: true
|
||||
distributed: true
|
||||
save_res_path: ./output/rec/predicts_mobile_pp-OCRv2_enhanced_ctc_loss.txt
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Piecewise
|
||||
decay_epochs : [700, 800]
|
||||
values : [0.001, 0.0001]
|
||||
warmup_epoch: 5
|
||||
regularizer:
|
||||
name: L2
|
||||
factor: 2.0e-05
|
||||
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: CRNN
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 64
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
return_feats: true
|
||||
|
||||
Loss:
|
||||
name: CombinedLoss
|
||||
loss_config_list:
|
||||
- CTCLoss:
|
||||
use_focal_loss: false
|
||||
weight: 1.0
|
||||
- CenterLoss:
|
||||
weight: 0.05
|
||||
num_classes: 6625
|
||||
feat_dim: 96
|
||||
init_center: false
|
||||
center_file_path: "./train_center.pkl"
|
||||
# you can also try to add ace loss on your own dataset
|
||||
# - ACELoss:
|
||||
# weight: 0.1
|
||||
|
||||
PostProcess:
|
||||
name: CTCLabelDecode
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/
|
||||
label_file_list:
|
||||
- ./train_data/train_list.txt
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- RecAug:
|
||||
- CTCLabelEncode:
|
||||
- RecResizeImg:
|
||||
image_shape: [3, 32, 320]
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- label
|
||||
- length
|
||||
- label_ace
|
||||
loader:
|
||||
shuffle: true
|
||||
batch_size_per_card: 128
|
||||
drop_last: true
|
||||
num_workers: 8
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data
|
||||
label_file_list:
|
||||
- ./train_data/val_list.txt
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- CTCLabelEncode:
|
||||
- RecResizeImg:
|
||||
image_shape: [3, 32, 320]
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- label
|
||||
- length
|
||||
loader:
|
||||
shuffle: false
|
||||
drop_last: false
|
||||
batch_size_per_card: 128
|
||||
num_workers: 8
|
|
@ -215,6 +215,11 @@ class CTCLabelEncode(BaseRecLabelEncode):
|
|||
data['length'] = np.array(len(text))
|
||||
text = text + [0] * (self.max_text_len - len(text))
|
||||
data['label'] = np.array(text)
|
||||
|
||||
label = [0] * len(self.character)
|
||||
for x in text:
|
||||
label[x] += 1
|
||||
data['label_ace'] = np.array(label)
|
||||
return data
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
|
|
|
@ -52,7 +52,6 @@ def build_loss(config):
|
|||
'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss',
|
||||
'TableAttentionLoss', 'SARLoss', 'AsterLoss'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception('loss only support {}'.format(
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
|
||||
|
||||
class ACELoss(nn.Layer):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
self.loss_func = nn.CrossEntropyLoss(
|
||||
weight=None,
|
||||
ignore_index=0,
|
||||
reduction='none',
|
||||
soft_label=True,
|
||||
axis=-1)
|
||||
|
||||
def __call__(self, predicts, batch):
|
||||
if isinstance(predicts, (list, tuple)):
|
||||
predicts = predicts[-1]
|
||||
B, N = predicts.shape[:2]
|
||||
div = paddle.to_tensor([N]).astype('float32')
|
||||
|
||||
predicts = nn.functional.softmax(predicts, axis=-1)
|
||||
aggregation_preds = paddle.sum(predicts, axis=1)
|
||||
aggregation_preds = paddle.divide(aggregation_preds, div)
|
||||
|
||||
length = batch[2].astype("float32")
|
||||
batch = batch[3].astype("float32")
|
||||
batch[:, 0] = paddle.subtract(div, length)
|
||||
|
||||
batch = paddle.divide(batch, div)
|
||||
|
||||
loss = self.loss_func(aggregation_preds, batch)
|
||||
|
||||
return {"loss_ace": loss}
|
|
@ -0,0 +1,89 @@
|
|||
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
|
||||
class CenterLoss(nn.Layer):
|
||||
"""
|
||||
Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_classes=6625,
|
||||
feat_dim=96,
|
||||
init_center=False,
|
||||
center_file_path=None):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.feat_dim = feat_dim
|
||||
self.centers = paddle.randn(
|
||||
shape=[self.num_classes, self.feat_dim]).astype(
|
||||
"float64") #random center
|
||||
|
||||
if init_center:
|
||||
assert os.path.exists(
|
||||
center_file_path
|
||||
), f"center path({center_file_path}) must exist when init_center is set as True."
|
||||
with open(center_file_path, 'rb') as f:
|
||||
char_dict = pickle.load(f)
|
||||
for key in char_dict.keys():
|
||||
self.centers[key] = paddle.to_tensor(char_dict[key])
|
||||
|
||||
def __call__(self, predicts, batch):
|
||||
assert isinstance(predicts, (list, tuple))
|
||||
features, predicts = predicts
|
||||
|
||||
feats_reshape = paddle.reshape(
|
||||
features, [-1, features.shape[-1]]).astype("float64")
|
||||
label = paddle.argmax(predicts, axis=2)
|
||||
label = paddle.reshape(label, [label.shape[0] * label.shape[1]])
|
||||
|
||||
batch_size = feats_reshape.shape[0]
|
||||
|
||||
#calc feat * feat
|
||||
dist1 = paddle.sum(paddle.square(feats_reshape), axis=1, keepdim=True)
|
||||
dist1 = paddle.expand(dist1, [batch_size, self.num_classes])
|
||||
|
||||
#dist2 of centers
|
||||
dist2 = paddle.sum(paddle.square(self.centers), axis=1,
|
||||
keepdim=True) #num_classes
|
||||
dist2 = paddle.expand(dist2,
|
||||
[self.num_classes, batch_size]).astype("float64")
|
||||
dist2 = paddle.transpose(dist2, [1, 0])
|
||||
|
||||
#first x * x + y * y
|
||||
distmat = paddle.add(dist1, dist2)
|
||||
tmp = paddle.matmul(feats_reshape,
|
||||
paddle.transpose(self.centers, [1, 0]))
|
||||
distmat = distmat - 2.0 * tmp
|
||||
|
||||
#generate the mask
|
||||
classes = paddle.arange(self.num_classes).astype("int64")
|
||||
label = paddle.expand(
|
||||
paddle.unsqueeze(label, 1), (batch_size, self.num_classes))
|
||||
mask = paddle.equal(
|
||||
paddle.expand(classes, [batch_size, self.num_classes]),
|
||||
label).astype("float64") #get mask
|
||||
dist = paddle.multiply(distmat, mask)
|
||||
loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
|
||||
return {'loss_center': loss}
|
|
@ -15,6 +15,10 @@
|
|||
import paddle
|
||||
import paddle.nn as nn
|
||||
|
||||
from .rec_ctc_loss import CTCLoss
|
||||
from .center_loss import CenterLoss
|
||||
from .ace_loss import ACELoss
|
||||
|
||||
from .distillation_loss import DistillationCTCLoss
|
||||
from .distillation_loss import DistillationDMLLoss
|
||||
from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
|
||||
|
|
|
@ -112,7 +112,7 @@ class DistillationDMLLoss(DMLLoss):
|
|||
if isinstance(loss, dict):
|
||||
for key in loss:
|
||||
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
|
||||
0], pair[1], map_name, idx)] = loss[key]
|
||||
0], pair[1], self.maps_name, idx)] = loss[key]
|
||||
else:
|
||||
loss_dict["{}_{}_{}".format(self.name, self.maps_name[
|
||||
_c], idx)] = loss
|
||||
|
|
|
@ -21,16 +21,24 @@ from paddle import nn
|
|||
|
||||
|
||||
class CTCLoss(nn.Layer):
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, use_focal_loss=False, **kwargs):
|
||||
super(CTCLoss, self).__init__()
|
||||
self.loss_func = nn.CTCLoss(blank=0, reduction='none')
|
||||
self.use_focal_loss = use_focal_loss
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
if isinstance(predicts, (list, tuple)):
|
||||
predicts = predicts[-1]
|
||||
predicts = predicts.transpose((1, 0, 2))
|
||||
N, B, _ = predicts.shape
|
||||
preds_lengths = paddle.to_tensor([N] * B, dtype='int64')
|
||||
labels = batch[1].astype("int32")
|
||||
label_lengths = batch[2].astype('int64')
|
||||
loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)
|
||||
if self.use_focal_loss:
|
||||
weight = paddle.exp(-loss)
|
||||
weight = paddle.subtract(paddle.to_tensor([1.0]), weight)
|
||||
weight = paddle.square(weight) * self.focal_loss_alpha
|
||||
loss = paddle.multiply(loss, weight)
|
||||
loss = loss.mean() # sum
|
||||
return {'loss': loss}
|
||||
|
|
|
@ -38,6 +38,7 @@ class CTCHead(nn.Layer):
|
|||
out_channels,
|
||||
fc_decay=0.0004,
|
||||
mid_channels=None,
|
||||
return_feats=False,
|
||||
**kwargs):
|
||||
super(CTCHead, self).__init__()
|
||||
if mid_channels is None:
|
||||
|
@ -66,14 +67,22 @@ class CTCHead(nn.Layer):
|
|||
bias_attr=bias_attr2)
|
||||
self.out_channels = out_channels
|
||||
self.mid_channels = mid_channels
|
||||
self.return_feats = return_feats
|
||||
|
||||
def forward(self, x, targets=None):
|
||||
if self.mid_channels is None:
|
||||
predicts = self.fc(x)
|
||||
else:
|
||||
predicts = self.fc1(x)
|
||||
predicts = self.fc2(predicts)
|
||||
|
||||
x = self.fc1(x)
|
||||
predicts = self.fc2(x)
|
||||
|
||||
if self.return_feats:
|
||||
result = (x, predicts)
|
||||
else:
|
||||
result = predicts
|
||||
|
||||
if not self.training:
|
||||
predicts = F.softmax(predicts, axis=2)
|
||||
return predicts
|
||||
result = predicts
|
||||
|
||||
return result
|
||||
|
|
|
@ -18,6 +18,7 @@ from __future__ import print_function
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import copy
|
||||
import platform
|
||||
|
||||
__all__ = ['build_post_process']
|
||||
|
||||
|
@ -28,7 +29,10 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di
|
|||
TableLabelDecode, NRTRLabelDecode, SARLabelDecode , SEEDLabelDecode
|
||||
from .cls_postprocess import ClsPostProcess
|
||||
from .pg_postprocess import PGPostProcess
|
||||
from .pse_postprocess import PSEPostProcess
|
||||
|
||||
if platform.system() != "Windows":
|
||||
# pse is not support in Windows
|
||||
from .pse_postprocess import PSEPostProcess
|
||||
|
||||
|
||||
def build_post_process(config, global_config=None):
|
||||
|
|
|
@ -111,6 +111,8 @@ class CTCLabelDecode(BaseRecLabelDecode):
|
|||
character_type, use_space_char)
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[-1]
|
||||
if isinstance(preds, paddle.Tensor):
|
||||
preds = preds.numpy()
|
||||
preds_idx = preds.argmax(axis=2)
|
||||
|
|
|
@ -49,6 +49,12 @@ def export_single_model(model, arch_config, save_path, logger):
|
|||
]
|
||||
]
|
||||
model = to_static(model, input_spec=other_shape)
|
||||
elif arch_config["algorithm"] == "SAR":
|
||||
other_shape = [
|
||||
paddle.static.InputSpec(
|
||||
shape=[None, 3, 48, 160], dtype="float32"),
|
||||
]
|
||||
model = to_static(model, input_spec=other_shape)
|
||||
else:
|
||||
infer_shape = [3, -1, -1]
|
||||
if arch_config["model_type"] == "rec":
|
||||
|
|
|
@ -68,6 +68,13 @@ class TextRecognizer(object):
|
|||
"character_dict_path": args.rec_char_dict_path,
|
||||
"use_space_char": args.use_space_char
|
||||
}
|
||||
elif self.rec_algorithm == "SAR":
|
||||
postprocess_params = {
|
||||
'name': 'SARLabelDecode',
|
||||
"character_type": args.rec_char_type,
|
||||
"character_dict_path": args.rec_char_dict_path,
|
||||
"use_space_char": args.use_space_char
|
||||
}
|
||||
self.postprocess_op = build_post_process(postprocess_params)
|
||||
self.predictor, self.input_tensor, self.output_tensors, self.config = \
|
||||
utility.create_predictor(args, 'rec', logger)
|
||||
|
@ -194,6 +201,41 @@ class TextRecognizer(object):
|
|||
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||
gsrm_slf_attn_bias2)
|
||||
|
||||
def resize_norm_img_sar(self, img, image_shape,
|
||||
width_downsample_ratio=0.25):
|
||||
imgC, imgH, imgW_min, imgW_max = image_shape
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
valid_ratio = 1.0
|
||||
# make sure new_width is an integral multiple of width_divisor.
|
||||
width_divisor = int(1 / width_downsample_ratio)
|
||||
# resize
|
||||
ratio = w / float(h)
|
||||
resize_w = math.ceil(imgH * ratio)
|
||||
if resize_w % width_divisor != 0:
|
||||
resize_w = round(resize_w / width_divisor) * width_divisor
|
||||
if imgW_min is not None:
|
||||
resize_w = max(imgW_min, resize_w)
|
||||
if imgW_max is not None:
|
||||
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
|
||||
resize_w = min(imgW_max, resize_w)
|
||||
resized_image = cv2.resize(img, (resize_w, imgH))
|
||||
resized_image = resized_image.astype('float32')
|
||||
# norm
|
||||
if image_shape[0] == 1:
|
||||
resized_image = resized_image / 255
|
||||
resized_image = resized_image[np.newaxis, :]
|
||||
else:
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
resize_shape = resized_image.shape
|
||||
padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
|
||||
padding_im[:, :, 0:resize_w] = resized_image
|
||||
pad_shape = padding_im.shape
|
||||
|
||||
return padding_im, resize_shape, pad_shape, valid_ratio
|
||||
|
||||
def __call__(self, img_list):
|
||||
img_num = len(img_list)
|
||||
# Calculate the aspect ratio of all text bars
|
||||
|
@ -216,11 +258,19 @@ class TextRecognizer(object):
|
|||
wh_ratio = w * 1.0 / h
|
||||
max_wh_ratio = max(max_wh_ratio, wh_ratio)
|
||||
for ino in range(beg_img_no, end_img_no):
|
||||
if self.rec_algorithm != "SRN":
|
||||
if self.rec_algorithm != "SRN" and self.rec_algorithm != "SAR":
|
||||
norm_img = self.resize_norm_img(img_list[indices[ino]],
|
||||
max_wh_ratio)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
norm_img_batch.append(norm_img)
|
||||
elif self.rec_algorithm == "SAR":
|
||||
norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
|
||||
img_list[indices[ino]], self.rec_image_shape)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
valid_ratio = np.expand_dims(valid_ratio, axis=0)
|
||||
valid_ratios = []
|
||||
valid_ratios.append(valid_ratio)
|
||||
norm_img_batch.append(norm_img)
|
||||
else:
|
||||
norm_img = self.process_image_srn(
|
||||
img_list[indices[ino]], self.rec_image_shape, 8, 25)
|
||||
|
@ -266,6 +316,25 @@ class TextRecognizer(object):
|
|||
if self.benchmark:
|
||||
self.autolog.times.stamp()
|
||||
preds = {"predict": outputs[2]}
|
||||
elif self.rec_algorithm == "SAR":
|
||||
valid_ratios = np.concatenate(valid_ratios)
|
||||
inputs = [
|
||||
norm_img_batch,
|
||||
valid_ratios,
|
||||
]
|
||||
input_names = self.predictor.get_input_names()
|
||||
for i in range(len(input_names)):
|
||||
input_tensor = self.predictor.get_input_handle(input_names[
|
||||
i])
|
||||
input_tensor.copy_from_cpu(inputs[i])
|
||||
self.predictor.run()
|
||||
outputs = []
|
||||
for output_tensor in self.output_tensors:
|
||||
output = output_tensor.copy_to_cpu()
|
||||
outputs.append(output)
|
||||
if self.benchmark:
|
||||
self.autolog.times.stamp()
|
||||
preds = outputs[0]
|
||||
else:
|
||||
self.input_tensor.copy_from_cpu(norm_img_batch)
|
||||
self.predictor.run()
|
||||
|
|
|
@ -394,20 +394,6 @@ def preprocess(is_train=False):
|
|||
config = load_config(FLAGS.config)
|
||||
merge_config(FLAGS.opt)
|
||||
|
||||
# check if set use_gpu=True in paddlepaddle cpu version
|
||||
use_gpu = config['Global']['use_gpu']
|
||||
check_gpu(use_gpu)
|
||||
|
||||
alg = config['Architecture']['algorithm']
|
||||
assert alg in [
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
|
||||
'SEED']
|
||||
|
||||
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
||||
device = paddle.set_device(device)
|
||||
|
||||
config['Global']['distributed'] = dist.get_world_size() != 1
|
||||
if is_train:
|
||||
# save_config
|
||||
save_model_dir = config['Global']['save_model_dir']
|
||||
|
@ -419,6 +405,27 @@ def preprocess(is_train=False):
|
|||
else:
|
||||
log_file = None
|
||||
logger = get_logger(name='root', log_file=log_file)
|
||||
|
||||
# check if set use_gpu=True in paddlepaddle cpu version
|
||||
use_gpu = config['Global']['use_gpu']
|
||||
check_gpu(use_gpu)
|
||||
|
||||
alg = config['Architecture']['algorithm']
|
||||
assert alg in [
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
|
||||
'SEED']
|
||||
windows_not_support_list = ['PSE']
|
||||
if platform.system() == "Windows" and alg in windows_not_support_list:
|
||||
logger.warning('{} is not support in Windows now'.format(
|
||||
windows_not_support_list))
|
||||
sys.exit()
|
||||
|
||||
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
||||
device = paddle.set_device(device)
|
||||
|
||||
config['Global']['distributed'] = dist.get_world_size() != 1
|
||||
|
||||
if config['Global']['use_visualdl']:
|
||||
from visualdl import LogWriter
|
||||
save_model_dir = config['Global']['save_model_dir']
|
||||
|
|
Loading…
Reference in New Issue