Merge pull request #520 from littletomatodonkey/fix_mv3
add stride config of backbone and fix load ckp
This commit is contained in:
commit
bad9f6cd74
|
@ -247,10 +247,12 @@ class SimpleReader(object):
|
|||
print("multiprocess is not fully compatible with Windows."
|
||||
"num_workers will be 1.")
|
||||
self.num_workers = 1
|
||||
if self.batch_size * get_device_num() > img_num:
|
||||
if self.batch_size * get_device_num(
|
||||
) * self.num_workers > img_num:
|
||||
raise Exception(
|
||||
"The number of the whole data ({}) is smaller than the batch_size * devices_num ({})".
|
||||
format(img_num, self.batch_size * get_device_num()))
|
||||
"The number of the whole data ({}) is smaller than the batch_size * devices_num * num_workers ({})".
|
||||
format(img_num, self.batch_size * get_device_num() *
|
||||
self.num_workers))
|
||||
for img_id in range(process_id, img_num, self.num_workers):
|
||||
label_infor = label_infor_list[img_id_list[img_id]]
|
||||
substr = label_infor.decode('utf-8').strip("\n").split("\t")
|
||||
|
|
|
@ -31,16 +31,28 @@ __all__ = [
|
|||
|
||||
class MobileNetV3():
|
||||
def __init__(self, params):
|
||||
self.scale = params['scale']
|
||||
model_name = params['model_name']
|
||||
self.scale = params.get("scale", 0.5)
|
||||
model_name = params.get("model_name", "small")
|
||||
large_stride = params.get("large_stride", [1, 2, 2, 2])
|
||||
small_stride = params.get("small_stride", [2, 2, 2, 2])
|
||||
|
||||
assert isinstance(large_stride, list), "large_stride type must " \
|
||||
"be list but got {}".format(type(large_stride))
|
||||
assert isinstance(small_stride, list), "small_stride type must " \
|
||||
"be list but got {}".format(type(small_stride))
|
||||
assert len(large_stride) == 4, "large_stride length must be " \
|
||||
"4 but got {}".format(len(large_stride))
|
||||
assert len(small_stride) == 4, "small_stride length must be " \
|
||||
"4 but got {}".format(len(small_stride))
|
||||
|
||||
self.inplanes = 16
|
||||
if model_name == "large":
|
||||
self.cfg = [
|
||||
# k, exp, c, se, nl, s,
|
||||
[3, 16, 16, False, 'relu', 1],
|
||||
[3, 64, 24, False, 'relu', (2, 1)],
|
||||
[3, 16, 16, False, 'relu', large_stride[0]],
|
||||
[3, 64, 24, False, 'relu', (large_stride[1], 1)],
|
||||
[3, 72, 24, False, 'relu', 1],
|
||||
[5, 72, 40, True, 'relu', (2, 1)],
|
||||
[5, 72, 40, True, 'relu', (large_stride[2], 1)],
|
||||
[5, 120, 40, True, 'relu', 1],
|
||||
[5, 120, 40, True, 'relu', 1],
|
||||
[3, 240, 80, False, 'hard_swish', 1],
|
||||
|
@ -49,7 +61,7 @@ class MobileNetV3():
|
|||
[3, 184, 80, False, 'hard_swish', 1],
|
||||
[3, 480, 112, True, 'hard_swish', 1],
|
||||
[3, 672, 112, True, 'hard_swish', 1],
|
||||
[5, 672, 160, True, 'hard_swish', (2, 1)],
|
||||
[5, 672, 160, True, 'hard_swish', (large_stride[3], 1)],
|
||||
[5, 960, 160, True, 'hard_swish', 1],
|
||||
[5, 960, 160, True, 'hard_swish', 1],
|
||||
]
|
||||
|
@ -58,15 +70,15 @@ class MobileNetV3():
|
|||
elif model_name == "small":
|
||||
self.cfg = [
|
||||
# k, exp, c, se, nl, s,
|
||||
[3, 16, 16, True, 'relu', (2, 1)],
|
||||
[3, 72, 24, False, 'relu', (2, 1)],
|
||||
[3, 16, 16, True, 'relu', (small_stride[0], 1)],
|
||||
[3, 72, 24, False, 'relu', (small_stride[1], 1)],
|
||||
[3, 88, 24, False, 'relu', 1],
|
||||
[5, 96, 40, True, 'hard_swish', (2, 1)],
|
||||
[5, 96, 40, True, 'hard_swish', (small_stride[2], 1)],
|
||||
[5, 240, 40, True, 'hard_swish', 1],
|
||||
[5, 240, 40, True, 'hard_swish', 1],
|
||||
[5, 120, 48, True, 'hard_swish', 1],
|
||||
[5, 144, 48, True, 'hard_swish', 1],
|
||||
[5, 288, 96, True, 'hard_swish', (2, 1)],
|
||||
[5, 288, 96, True, 'hard_swish', (small_stride[3], 1)],
|
||||
[5, 576, 96, True, 'hard_swish', 1],
|
||||
[5, 576, 96, True, 'hard_swish', 1],
|
||||
]
|
||||
|
|
|
@ -32,6 +32,7 @@ class CTCPredict(object):
|
|||
self.char_num = params['char_num']
|
||||
self.encoder = SequenceEncoder(params)
|
||||
self.encoder_type = params['encoder_type']
|
||||
self.fc_decay = params.get("fc_decay", 0.0004)
|
||||
|
||||
def __call__(self, inputs, labels=None, mode=None):
|
||||
encoder_features = self.encoder(inputs)
|
||||
|
@ -39,7 +40,7 @@ class CTCPredict(object):
|
|||
encoder_features = fluid.layers.concat(encoder_features, axis=1)
|
||||
name = "ctc_fc"
|
||||
para_attr, bias_attr = get_para_bias_attr(
|
||||
l2_decay=0.0004, k=encoder_features.shape[1], name=name)
|
||||
l2_decay=self.fc_decay, k=encoder_features.shape[1], name=name)
|
||||
predict = fluid.layers.fc(input=encoder_features,
|
||||
size=self.char_num + 1,
|
||||
param_attr=para_attr,
|
||||
|
|
|
@ -114,10 +114,10 @@ def init_model(config, program, exe):
|
|||
fluid.load(program, path, exe)
|
||||
logger.info("Finish initing model from {}".format(path))
|
||||
else:
|
||||
raise ValueError(
|
||||
"Model checkpoints {} does not exists,"
|
||||
"check if you lost the file prefix.".format(checkpoints + '.pdparams'))
|
||||
|
||||
raise ValueError("Model checkpoints {} does not exists,"
|
||||
"check if you lost the file prefix.".format(
|
||||
checkpoints + '.pdparams'))
|
||||
else:
|
||||
pretrain_weights = config['Global'].get('pretrain_weights')
|
||||
if pretrain_weights:
|
||||
path = pretrain_weights
|
||||
|
|
Loading…
Reference in New Issue