Merge branch 'reborn' of https://github.com/iclementine/Parakeet into reborn
commit f255eee029

@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from parakeet.models.clarinet import *
+#from parakeet.models.clarinet import *
 from parakeet.models.waveflow import *
-from parakeet.models.wavenet import *
+#from parakeet.models.wavenet import *

 from parakeet.models.transformer_tts import *
-from parakeet.models.deepvoice3 import *
+#from parakeet.models.deepvoice3 import *
 # from parakeet.models.fastspeech import *
@@ -273,12 +273,14 @@ class MLPPreNet(nn.Layer):
         super(MLPPreNet, self).__init__()
         self.lin1 = nn.Linear(d_input, d_hidden)
         self.lin2 = nn.Linear(d_hidden, d_hidden)
+        self.lin3 = nn.Linear(d_hidden, d_hidden)
         self.dropout = dropout

     def forward(self, x, dropout):
         l1 = F.dropout(F.relu(self.lin1(x)), self.dropout, training=self.training)
         l2 = F.dropout(F.relu(self.lin2(l1)), self.dropout, training=self.training)
-        return l2
+        l3 = self.lin3(l2)
+        return l3

 # NOTE: not used in
 class CNNPreNet(nn.Layer):
@@ -317,6 +319,7 @@ class CNNPostNet(nn.Layer):
                 Conv1dBatchNorm(c_in, c_out, kernel_size,
                                 weight_attr=I.XavierUniform(),
                                 padding=padding))
+        self.last_bn = nn.BatchNorm1D(d_output)
         # for a layer that ends with a normalization layer that is targeted to
         # output a non zero-central output, it may take a long time to
         # train the scale and bias
@@ -328,7 +331,7 @@ class CNNPostNet(nn.Layer):
             x = layer(x)
             if i != (len(self.convs) - 1):
                 x = F.tanh(x)
-        x = x_in + x
+        x = self.last_bn(x_in + x)
         return x


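Note: the two CNNPostNet hunks above route the residual sum through a trailing BatchNorm1D, so the post-net output is re-centered before it is returned (see the comment about non zero-central outputs). A minimal sketch of the revised forward pass, written against the public paddle.nn API rather than the repository's Conv1dBatchNorm layer:

import paddle
from paddle import nn
import paddle.nn.functional as F

class TinyPostNet(nn.Layer):
    """Illustrative only: two plain Conv1D layers stand in for Conv1dBatchNorm."""
    def __init__(self, d_mel=80, d_hidden=256, kernel_size=5):
        super().__init__()
        padding = (kernel_size - 1) // 2
        self.convs = nn.LayerList([
            nn.Conv1D(d_mel, d_hidden, kernel_size, padding=padding),
            nn.Conv1D(d_hidden, d_mel, kernel_size, padding=padding),
        ])
        self.last_bn = nn.BatchNorm1D(d_mel)

    def forward(self, x):  # x: (B, C_mel, T)
        x_in = x
        for i, layer in enumerate(self.convs):
            x = layer(x)
            if i != len(self.convs) - 1:
                x = F.tanh(x)
        # normalize the residual sum instead of returning x_in + x directly
        return self.last_bn(x_in + x)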
@@ -491,7 +494,7 @@ class TransformerTTS(nn.Layer):
             decoder_output = paddle.concat([decoder_output, mel_output[:, -self.r:, :]], 1)

             # stop condition: (if any ouput frame of the output multiframes hits the stop condition)
-            if paddle.any(paddle.argmax(stop_logits[0, :, :], axis=-1) == self.stop_prob_index):
+            if paddle.any(paddle.argmax(stop_logits[0, -self.r:, :], axis=-1) == self.stop_prob_index):
                 if verbose:
                     print("Hits stop condition.")
                 break
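Note: the inference fix above narrows the stop check to the r frames emitted in the current decoder step; checking every frame decoded so far (the old stop_logits[0, :, :]) lets a stray early stop prediction keep ending synthesis prematurely. A hedged sketch of the check, with an assumed stop_prob_index value:

import paddle

def should_stop(stop_logits, r, stop_prob_index=1):
    # stop_logits: (1, T_dec, 2); only the last r positions were produced
    # by the current step, so only they are allowed to signal "stop".
    newest = paddle.argmax(stop_logits[0, -r:, :], axis=-1)
    return bool(paddle.any(newest == stop_prob_index))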
@@ -526,6 +529,34 @@ class TransformerTTSLoss(nn.Layer):
         stop_loss = L.masked_softmax_with_cross_entropy(
             stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))

+        loss = mel_loss1 + mel_loss2 + stop_loss
+        losses = dict(
+            loss=loss, # total loss
+            mel_loss1=mel_loss1, # ouput mel loss
+            mel_loss2=mel_loss2, # intermediate mel loss
+            stop_loss=stop_loss # stop prob loss
+        )
+        return losses
+
+
+class AdaptiveTransformerTTSLoss(nn.Layer):
+    def __init__(self):
+        super(AdaptiveTransformerTTSLoss, self).__init__()
+
+    def forward(self, mel_output, mel_intermediate, mel_target, stop_logits, stop_probs):
+        mask = masking.feature_mask(mel_target, axis=-1, dtype=mel_target.dtype)
+        mask1 = paddle.unsqueeze(mask, -1)
+        mel_loss1 = L.masked_l1_loss(mel_output, mel_target, mask1)
+        mel_loss2 = L.masked_l1_loss(mel_intermediate, mel_target, mask1)
+
+        batch_size, mel_len = mask.shape
+        valid_lengths = mask.sum(-1).astype("int64")
+        last_position = F.one_hot(valid_lengths - 1, num_classes=mel_len)
+        stop_loss_scale = valid_lengths.sum() / batch_size - 1
+        mask2 = mask + last_position.scale(stop_loss_scale - 1).astype(mask.dtype)
+        stop_loss = L.masked_softmax_with_cross_entropy(
+            stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
+
         loss = mel_loss1 + mel_loss2 + stop_loss
         losses = dict(
             loss=loss, # total loss
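Note: in AdaptiveTransformerTTSLoss above, only one frame per utterance carries the positive "stop" label, so a plain validity mask would let the negative frames dominate the stop loss. mask2 keeps weight 1 on valid frames and raises the last valid frame to roughly the average valid length (stop_loss_scale). A worked numpy mirror of that construction, with made-up lengths:

import numpy as np

valid_lengths = np.array([4, 6])      # valid mel frames per utterance
mel_len = 6                           # padded length
mask = (np.arange(mel_len)[None, :] < valid_lengths[:, None]).astype("float32")

last_position = np.eye(mel_len, dtype="float32")[valid_lengths - 1]
stop_loss_scale = valid_lengths.sum() / len(valid_lengths) - 1    # (4 + 6) / 2 - 1 = 4.0
mask2 = mask + last_position * (stop_loss_scale - 1)

print(mask2)
# [[1. 1. 1. 4. 0. 0.]
#  [1. 1. 1. 1. 1. 4.]]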
@@ -1,3 +1,5 @@
+import numba
+import numpy as np
 import paddle
 from paddle import nn
 from paddle.nn import functional as F
@@ -12,7 +14,7 @@ def weighted_mean(input, weight):
     Returns:
         Tensor: shape(1,), weighted mean tensor with the same dtype as input.
     """
     weight = paddle.cast(weight, input.dtype)
     return paddle.mean(input * weight)

 def masked_l1_loss(prediction, target, mask):
@@ -22,3 +24,32 @@ def masked_l1_loss(prediction, target, mask):
 def masked_softmax_with_cross_entropy(logits, label, mask, axis=-1):
     ce = F.softmax_with_cross_entropy(logits, label, axis=axis)
     return weighted_mean(ce, mask)
+
+def diagonal_loss(attentions, input_lengths, target_lengths, g=0.2, multihead=False):
+    """A metric to evaluate how diagonal a attention distribution is."""
+    W = guided_attentions(input_lengths, target_lengths, g)
+    W_tensor = paddle.to_tensor(W)
+    if not multihead:
+        return paddle.mean(attentions * W_tensor)
+    else:
+        return paddle.mean(attentions * paddle.unsqueeze(W_tensor, 1))
+
+@numba.jit(nopython=True)
+def guided_attention(N, max_N, T, max_T, g):
+    W = np.zeros((max_T, max_N), dtype=np.float32)
+    for t in range(T):
+        for n in range(N):
+            W[t, n] = 1 - np.exp(-(n / N - t / T)**2 / (2 * g * g))
+    # (T_dec, T_enc)
+    return W
+
+def guided_attentions(input_lengths, target_lengths, g=0.2):
+    B = len(input_lengths)
+    max_input_len = input_lengths.max()
+    max_target_len = target_lengths.max()
+    W = np.zeros((B, max_target_len, max_input_len), dtype=np.float32)
+    for b in range(B):
+        W[b] = guided_attention(input_lengths[b], max_input_len,
+                                target_lengths[b], max_target_len, g)
+    # (B, T_dec, T_enc)
+    return W
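Note: guided_attention above builds the penalty matrix W[t, n] = 1 - exp(-((n/N - t/T)^2) / (2 g^2)), which is near 0 along the text/mel diagonal and approaches 1 far from it, so diagonal_loss (the mean of attention * W) is small when the alignment is roughly diagonal. An assumed usage sketch with toy lengths, calling the guided_attentions function added above:

import numpy as np

input_lengths = np.array([12, 9])      # encoder (text) lengths
target_lengths = np.array([30, 24])    # decoder (mel) lengths

W = guided_attentions(input_lengths, target_lengths, g=0.2)
print(W.shape)                         # (2, 30, 12): (B, T_dec, T_enc)

uniform = np.full((2, 30, 12), 1.0 / 12, dtype=np.float32)
print((uniform * W).mean())            # average penalty for a uniform (non-diagonal) alignment
# wrap the model's attention in paddle.to_tensor(...) and call diagonal_loss(...) during training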
@@ -1,202 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import math
-import numpy as np
-import paddle.fluid as fluid
-import paddle.fluid.dygraph as dg
-import paddle.fluid.layers as layers
-
-
-class Linear(dg.Layer):
-    def __init__(self,
-                 in_features,
-                 out_features,
-                 is_bias=True,
-                 dtype="float32"):
-        super(Linear, self).__init__()
-        self.in_features = in_features
-        self.out_features = out_features
-        self.dtype = dtype
-        self.weight = fluid.ParamAttr(
-            initializer=fluid.initializer.XavierInitializer())
-        self.bias = is_bias
-
-        if is_bias is not False:
-            k = math.sqrt(1.0 / in_features)
-            self.bias = fluid.ParamAttr(initializer=fluid.initializer.Uniform(
-                low=-k, high=k))
-
-        self.linear = dg.Linear(
-            in_features,
-            out_features,
-            param_attr=self.weight,
-            bias_attr=self.bias, )
-
-    def forward(self, x):
-        x = self.linear(x)
-        return x
-
-
-class ScaledDotProductAttention(dg.Layer):
-    def __init__(self, d_key):
-        """Scaled dot product attention module.
-
-        Args:
-            d_key (int): the dim of key in multihead attention.
-        """
-        super(ScaledDotProductAttention, self).__init__()
-
-        self.d_key = d_key
-
-    # please attention this mask is diff from pytorch
-    def forward(self,
-                key,
-                value,
-                query,
-                mask=None,
-                query_mask=None,
-                dropout=0.1):
-        """
-        Compute scaled dot product attention.
-
-        Args:
-            key (Variable): shape(B, T, C), dtype float32, the input key of scaled dot product attention.
-            value (Variable): shape(B, T, C), dtype float32, the input value of scaled dot product attention.
-            query (Variable): shape(B, T, C), dtype float32, the input query of scaled dot product attention.
-            mask (Variable, optional): shape(B, T_q, T_k), dtype float32, the mask of key. Defaults to None.
-            query_mask (Variable, optional): shape(B, T_q, T_q), dtype float32, the mask of query. Defaults to None.
-            dropout (float32, optional): the probability of dropout. Defaults to 0.1.
-        Returns:
-            result (Variable): shape(B, T, C), the result of mutihead attention.
-            attention (Variable): shape(n_head * B, T, C), the attention of key.
-        """
-        # Compute attention score
-        attention = layers.matmul(
-            query, key, transpose_y=True, alpha=self.d_key
-            **-0.5)  #transpose the last dim in y
-
-        # Mask key to ignore padding
-        if mask is not None:
-            attention = attention + mask
-        attention = layers.softmax(attention, use_cudnn=True)
-        attention = layers.dropout(
-            attention, dropout, dropout_implementation='upscale_in_train')
-
-        # Mask query to ignore padding
-        if query_mask is not None:
-            attention = attention * query_mask
-
-        result = layers.matmul(attention, value)
-        return result, attention
-
-
-class MultiheadAttention(dg.Layer):
-    def __init__(self,
-                 num_hidden,
-                 d_k,
-                 d_q,
-                 num_head=4,
-                 is_bias=False,
-                 dropout=0.1,
-                 is_concat=True):
-        """Multihead Attention.
-
-        Args:
-            num_hidden (int): the number of hidden layer in network.
-            d_k (int): the dim of key in multihead attention.
-            d_q (int): the dim of query in multihead attention.
-            num_head (int, optional): the head number of multihead attention. Defaults to 4.
-            is_bias (bool, optional): whether have bias in linear layers. Default to False.
-            dropout (float, optional): dropout probability of FFTBlock. Defaults to 0.1.
-            is_concat (bool, optional): whether concat query and result. Default to True.
-        """
-        super(MultiheadAttention, self).__init__()
-        self.num_hidden = num_hidden
-        self.num_head = num_head
-        self.d_k = d_k
-        self.d_q = d_q
-        self.dropout = dropout
-        self.is_concat = is_concat
-
-        self.key = Linear(num_hidden, num_head * d_k, is_bias=is_bias)
-        self.value = Linear(num_hidden, num_head * d_k, is_bias=is_bias)
-        self.query = Linear(num_hidden, num_head * d_q, is_bias=is_bias)
-
-        self.scal_attn = ScaledDotProductAttention(d_k)
-
-        if self.is_concat:
-            self.fc = Linear(num_head * d_q * 2, num_hidden)
-        else:
-            self.fc = Linear(num_head * d_q, num_hidden)
-
-        self.layer_norm = dg.LayerNorm(num_hidden)
-
-    def forward(self, key, value, query_input, mask=None, query_mask=None):
-        """
-        Compute attention.
-
-        Args:
-            key (Variable): shape(B, T, C), dtype float32, the input key of attention.
-            value (Variable): shape(B, T, C), dtype float32, the input value of attention.
-            query_input (Variable): shape(B, T, C), dtype float32, the input query of attention.
-            mask (Variable, optional): shape(B, T_query, T_key), dtype float32, the mask of key. Defaults to None.
-            query_mask (Variable, optional): shape(B, T_query, T_key), dtype float32, the mask of query. Defaults to None.
-
-        Returns:
-            result (Variable): shape(B, T, C), the result of mutihead attention.
-            attention (Variable): shape(num_head * B, T, C), the attention of key and query.
-        """
-
-        batch_size = key.shape[0]
-        seq_len_key = key.shape[1]
-        seq_len_query = query_input.shape[1]
-
-        # Make multihead attention
-        key = layers.reshape(
-            self.key(key), [batch_size, seq_len_key, self.num_head, self.d_k])
-        value = layers.reshape(
-            self.value(value),
-            [batch_size, seq_len_key, self.num_head, self.d_k])
-        query = layers.reshape(
-            self.query(query_input),
-            [batch_size, seq_len_query, self.num_head, self.d_q])
-
-        key = layers.reshape(
-            layers.transpose(key, [2, 0, 1, 3]), [-1, seq_len_key, self.d_k])
-        value = layers.reshape(
-            layers.transpose(value, [2, 0, 1, 3]),
-            [-1, seq_len_key, self.d_k])
-        query = layers.reshape(
-            layers.transpose(query, [2, 0, 1, 3]),
-            [-1, seq_len_query, self.d_q])
-
-        result, attention = self.scal_attn(
-            key, value, query, mask=mask, query_mask=query_mask)
-
-        # concat all multihead result
-        result = layers.reshape(
-            result, [self.num_head, batch_size, seq_len_query, self.d_q])
-        result = layers.reshape(
-            layers.transpose(result, [1, 2, 0, 3]),
-            [batch_size, seq_len_query, -1])
-        if self.is_concat:
-            result = layers.concat([query_input, result], axis=-1)
-        result = layers.dropout(
-            self.fc(result),
-            self.dropout,
-            dropout_implementation='upscale_in_train')
-        result = result + query_input
-
-        result = self.layer_norm(result)
-        return result, attention
@@ -0,0 +1,21 @@
+import argparse
+
+def default_argument_parser():
+    parser = argparse.ArgumentParser()
+
+    # data and outpu
+    parser.add_argument("--config", metavar="FILE", help="path of the config file to overwrite to default config with.")
+    parser.add_argument("--data", metavar="DATA_DIR", help="path to the datatset.")
+    parser.add_argument("--output", metavar="OUTPUT_DIR", help="path to save checkpoint and log. If not provided, a directory is created in runs/ to save outputs.")
+
+    # load from saved checkpoint
+    parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load")
+
+    # running
+    parser.add_argument("--device", type=str, choices=["cpu", "gpu"], help="device type to use, cpu and gpu are supported.")
+    parser.add_argument("--nprocs", type=int, default=1, help="number of parallel processes to use.")
+
+    # overwrite extra config and default config
+    parser.add_argument("--opts", nargs=argparse.REMAINDER, help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
+
+    return parser
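Note: this parser is meant to be shared by the training entry points. An illustrative way to consume it; the import path below is an assumption about where the new file lives, not stated in the diff:

from parakeet.training.cli import default_argument_parser  # assumed module path for the file added above

def main():
    parser = default_argument_parser()
    args = parser.parse_args()
    # e.g. python train.py --config conf.yaml --data LJSpeech-1.1 --device gpu --opts training.lr 1e-4
    print(args.config, args.data, args.output)
    print(args.checkpoint_path, args.device, args.nprocs)
    print(args.opts)  # leftover KEY VALUE pairs collected by argparse.REMAINDER

if __name__ == "__main__":
    main()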
@@ -0,0 +1,137 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+
+import numpy as np
+import paddle
+from paddle import distributed as dist
+from parakeet.utils import mp_tools
+
+
+def _load_latest_checkpoint(checkpoint_dir):
+    """Get the iteration number corresponding to the latest saved checkpoint
+
+    Args:
+        checkpoint_dir (str): the directory where checkpoint is saved.
+
+    Returns:
+        int: the latest iteration number.
+    """
+    checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
+    # Create checkpoint index file if not exist.
+    if (not os.path.isfile(checkpoint_record)):
+        return 0
+
+    # Fetch the latest checkpoint index.
+    with open(checkpoint_record, "r") as handle:
+        latest_checkpoint = handle.readline().split()[-1]
+        iteration = int(latest_checkpoint.split("-")[-1])
+
+    return iteration
+
+def _save_checkpoint(checkpoint_dir, iteration):
+    """Save the iteration number of the latest model to be checkpointed.
+
+    Args:
+        checkpoint_dir (str): the directory where checkpoint is saved.
+        iteration (int): the latest iteration number.
+
+    Returns:
+        None
+    """
+    checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
+    # Update the latest checkpoint index.
+    with open(checkpoint_record, "w") as handle:
+        handle.write("model_checkpoint_path: step-{}".format(iteration))
+
+def load_parameters(model,
+                    optimizer=None,
+                    checkpoint_dir=None,
+                    checkpoint_path=None):
+    """Load a specific model checkpoint from disk.
+
+    Args:
+        model (obj): model to load parameters.
+        optimizer (obj, optional): optimizer to load states if needed.
+            Defaults to None.
+        checkpoint_dir (str, optional): the directory where checkpoint is saved.
+        checkpoint_path (str, optional): if specified, load the checkpoint
+            stored in the checkpoint_path and the argument 'checkpoint_dir' will
+            be ignored. Defaults to None.
+
+    Returns:
+        iteration (int): number of iterations that the loaded checkpoint has
+            been trained.
+    """
+    if checkpoint_path is not None:
+        iteration = int(os.path.basename(checkpoint_path).split("-")[-1])
+    elif checkpoint_dir is not None:
+        iteration = _load_latest_checkpoint(checkpoint_dir)
+        if iteration == 0:
+            return iteration
+        checkpoint_path = os.path.join(checkpoint_dir,
+                                       "step-{}".format(iteration))
+    else:
+        raise ValueError(
+            "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
+        )
+
+    local_rank = dist.get_rank()
+
+    params_path = checkpoint_path + ".pdparams"
+    model_dict = paddle.load(params_path)
+    model.set_state_dict(model_dict)
+    print("[checkpoint] Rank {}: loaded model from {}".format(
+        local_rank, params_path))
+
+    optimizer_path = checkpoint_path + ".pdopt"
+    if optimizer and os.path.isfile(optimizer_path):
+        optimizer_dict = paddle.load(optimizer_path)
+        optimizer.set_state_dict(optimizer_dict)
+        print("[checkpoint] Rank {}: loaded optimizer state from {}".
+              format(local_rank, optimizer_path))
+
+    return iteration
+
+@mp_tools.rank_zero_only
+def save_parameters(checkpoint_dir, iteration, model, optimizer=None):
+    """Checkpoint the latest trained model parameters.
+
+    Args:
+        checkpoint_dir (str): the directory where checkpoint is saved.
+        iteration (int): the latest iteration number.
+        model (obj): model to be checkpointed.
+        optimizer (obj, optional): optimizer to be checkpointed.
+            Defaults to None.
+
+    Returns:
+        None
+    """
+    checkpoint_path = os.path.join(checkpoint_dir, "step-{}".format(iteration))
+
+    model_dict = model.state_dict()
+    params_path = checkpoint_path + ".pdparams"
+    paddle.save(model_dict, params_path)
+    print("[checkpoint] Saved model to {}".format(params_path))
+
+    if optimizer:
+        opt_dict = optimizer.state_dict()
+        optimizer_path = checkpoint_path + ".pdopt"
+        paddle.save(opt_dict, optimizer_path)
+        print("[checkpoint] Saved optimzier state to {}".format(
+            optimizer_path))
+
+    _save_checkpoint(checkpoint_dir, iteration)
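Note: a sketch of how the two public helpers added above fit into a training loop; the module path, directory layout, and checkpoint interval are assumptions, not taken from the diff:

import os
import paddle
from parakeet.utils.checkpoint import load_parameters, save_parameters  # assumed module path

model = paddle.nn.Linear(8, 8)                                    # stand-in model
optimizer = paddle.optimizer.Adam(parameters=model.parameters())

checkpoint_dir = "runs/demo/checkpoints"                          # assumed layout
os.makedirs(checkpoint_dir, exist_ok=True)

# returns 0 when no "checkpoint" record file exists yet, so a fresh run starts at step 0
start_step = load_parameters(model, optimizer=optimizer, checkpoint_dir=checkpoint_dir)

for step in range(start_step + 1, start_step + 1001):
    # ... forward / backward / optimizer.step() ...
    if step % 500 == 0:
        # rank_zero_only means only process 0 writes files in multi-process runs
        save_parameters(checkpoint_dir, step, model, optimizer)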