fix WeightNormWrapper, stop using CacheDataset for deep voice 3, pin numba version to 0.47.0
This commit is contained in:
parent
b7c584e2f7
commit
45af3a43b2
|
@ -230,7 +230,7 @@ def make_data_loader(data_root, config):
|
||||||
ref_level_db=c["ref_level_db"],
|
ref_level_db=c["ref_level_db"],
|
||||||
max_norm=c["max_norm"],
|
max_norm=c["max_norm"],
|
||||||
clip_norm=c["clip_norm"])
|
clip_norm=c["clip_norm"])
|
||||||
ljspeech = CacheDataset(TransformDataset(meta, transform))
|
ljspeech = TransformDataset(meta, transform)
|
||||||
|
|
||||||
# use meta data's text length as a sort key for the sampler
|
# use meta data's text length as a sort key for the sampler
|
||||||
batch_size = config["train"]["batch_size"]
|
batch_size = config["train"]["batch_size"]
|
||||||
|
|
|
@ -12,7 +12,6 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from paddle import fluid
|
from paddle import fluid
|
||||||
import paddle.fluid.dygraph as dg
|
import paddle.fluid.dygraph as dg
|
||||||
import paddle.fluid.layers as F
|
import paddle.fluid.layers as F
|
||||||
|
@ -44,10 +43,10 @@ def norm_except(param, dim, power):
|
||||||
if dim is None:
|
if dim is None:
|
||||||
return norm(param, dim, power)
|
return norm(param, dim, power)
|
||||||
elif dim == 0:
|
elif dim == 0:
|
||||||
param_matrix = F.reshape(param, (shape[0], np.prod(shape[1:])))
|
param_matrix = F.reshape(param, (shape[0], -1))
|
||||||
return norm(param_matrix, dim=1, power=power)
|
return norm(param_matrix, dim=1, power=power)
|
||||||
elif dim == -1 or dim == ndim - 1:
|
elif dim == -1 or dim == ndim - 1:
|
||||||
param_matrix = F.reshape(param, (np.prod(shape[:-1]), shape[-1]))
|
param_matrix = F.reshape(param, (-1, shape[-1]))
|
||||||
return norm(param_matrix, dim=0, power=power)
|
return norm(param_matrix, dim=0, power=power)
|
||||||
else:
|
else:
|
||||||
perm = list(range(ndim))
|
perm = list(range(ndim))
|
||||||
|
@ -62,24 +61,26 @@ def compute_l2_normalized_weight(v, g, dim):
|
||||||
ndim = len(shape)
|
ndim = len(shape)
|
||||||
|
|
||||||
if dim is None:
|
if dim is None:
|
||||||
v_normalized = v / (F.reduce_sum(F.square(v)) + 1e-12)
|
v_normalized = v / (F.sqrt(F.reduce_sum(F.square(v))) + 1e-12)
|
||||||
elif dim == 0:
|
elif dim == 0:
|
||||||
param_matrix = F.reshape(v, (shape[0], np.prod(shape[1:])))
|
param_matrix = F.reshape(v, (shape[0], -1))
|
||||||
v_normalized = F.l2_normalize(param_matrix, axis=1)
|
v_normalized = F.l2_normalize(param_matrix, axis=1)
|
||||||
|
v_normalized = F.reshape(v_normalized, shape)
|
||||||
elif dim == -1 or dim == ndim - 1:
|
elif dim == -1 or dim == ndim - 1:
|
||||||
param_matrix = F.reshape(v, (np.prod(shape[:-1]), shape[-1]))
|
param_matrix = F.reshape(v, (-1, shape[-1]))
|
||||||
v_normalized = F.l2_normalize(param_matrix, axis=0)
|
v_normalized = F.l2_normalize(param_matrix, axis=0)
|
||||||
|
v_normalized = F.reshape(v_normalized, shape)
|
||||||
else:
|
else:
|
||||||
perm = list(range(ndim))
|
perm = list(range(ndim))
|
||||||
perm[0] = dim
|
perm[0] = dim
|
||||||
perm[dim] = 0
|
perm[dim] = 0
|
||||||
transposed_param = F.transpose(v, perm)
|
transposed_param = F.transpose(v, perm)
|
||||||
param_matrix = F.reshape(
|
transposed_shape = transposed_param.shape
|
||||||
transposed_param,
|
param_matrix = F.reshape(transposed_param,
|
||||||
(transposed_param.shape[0], np.prod(transposed_param.shape[1:])))
|
(transposed_param.shape[0], -1))
|
||||||
v_normalized = F.l2_normalize(param_matrix, axis=1)
|
v_normalized = F.l2_normalize(param_matrix, axis=1)
|
||||||
|
v_normalized = F.reshape(v_normalized, transposed_shape)
|
||||||
v_normalized = F.transpose(v_normalized, perm)
|
v_normalized = F.transpose(v_normalized, perm)
|
||||||
v_normalized = F.reshape(v_normalized, shape)
|
|
||||||
weight = F.elementwise_mul(v_normalized, g, axis=dim)
|
weight = F.elementwise_mul(v_normalized, g, axis=dim)
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue