Upgrade waveflow to 1.8.0

This commit is contained in:
Yibing Liu 2020-05-22 07:16:45 +00:00
parent 8716a1843c
commit 9b8fd9f93d
2 changed files with 9 additions and 2 deletions

View File

@ -348,7 +348,7 @@ class WaveFlowModule(dg.Layer):
mel = self.conditioner(mel) mel = self.conditioner(mel)
assert mel.shape[2] >= audio.shape[1] assert mel.shape[2] >= audio.shape[1]
# Prune out the tail of audio/mel so that time/n_group == 0. # Prune out the tail of audio/mel so that time/n_group == 0.
pruned_len = audio.shape[1] // self.n_group * self.n_group pruned_len = int(audio.shape[1] // self.n_group * self.n_group)
if audio.shape[1] > pruned_len: if audio.shape[1] > pruned_len:
audio = audio[:, :pruned_len] audio = audio[:, :pruned_len]

View File

@ -87,7 +87,14 @@ def compute_l2_normalized_weight(v, g, dim):
def compute_weight(v, g, dim, power): def compute_weight(v, g, dim, power):
assert len(g.shape) == 1, "magnitude should be a vector" assert len(g.shape) == 1, "magnitude should be a vector"
if power == 2: if power == 2:
return compute_l2_normalized_weight(v, g, dim) in_dtype = v.dtype
if in_dtype == fluid.core.VarDesc.VarType.FP16:
v = F.cast(v, "float32")
g = F.cast(g, "float32")
weight = compute_l2_normalized_weight(v, g, dim)
if in_dtype == fluid.core.VarDesc.VarType.FP16:
weight = F.cast(weight, "float16")
return weight
else: else:
v_normalized = F.elementwise_div( v_normalized = F.elementwise_div(
v, (norm_except(v, dim, power) + 1e-12), axis=dim) v, (norm_except(v, dim, power) + 1e-12), axis=dim)