Upgrade waveflow to 1.8.0
This commit is contained in:
parent
8716a1843c
commit
9b8fd9f93d
|
@ -348,7 +348,7 @@ class WaveFlowModule(dg.Layer):
|
||||||
mel = self.conditioner(mel)
|
mel = self.conditioner(mel)
|
||||||
assert mel.shape[2] >= audio.shape[1]
|
assert mel.shape[2] >= audio.shape[1]
|
||||||
# Prune out the tail of audio/mel so that time/n_group == 0.
|
# Prune out the tail of audio/mel so that time/n_group == 0.
|
||||||
pruned_len = audio.shape[1] // self.n_group * self.n_group
|
pruned_len = int(audio.shape[1] // self.n_group * self.n_group)
|
||||||
|
|
||||||
if audio.shape[1] > pruned_len:
|
if audio.shape[1] > pruned_len:
|
||||||
audio = audio[:, :pruned_len]
|
audio = audio[:, :pruned_len]
|
||||||
|
|
|
@ -87,7 +87,14 @@ def compute_l2_normalized_weight(v, g, dim):
|
||||||
def compute_weight(v, g, dim, power):
|
def compute_weight(v, g, dim, power):
|
||||||
assert len(g.shape) == 1, "magnitude should be a vector"
|
assert len(g.shape) == 1, "magnitude should be a vector"
|
||||||
if power == 2:
|
if power == 2:
|
||||||
return compute_l2_normalized_weight(v, g, dim)
|
in_dtype = v.dtype
|
||||||
|
if in_dtype == fluid.core.VarDesc.VarType.FP16:
|
||||||
|
v = F.cast(v, "float32")
|
||||||
|
g = F.cast(g, "float32")
|
||||||
|
weight = compute_l2_normalized_weight(v, g, dim)
|
||||||
|
if in_dtype == fluid.core.VarDesc.VarType.FP16:
|
||||||
|
weight = F.cast(weight, "float16")
|
||||||
|
return weight
|
||||||
else:
|
else:
|
||||||
v_normalized = F.elementwise_div(
|
v_normalized = F.elementwise_div(
|
||||||
v, (norm_except(v, dim, power) + 1e-12), axis=dim)
|
v, (norm_except(v, dim, power) + 1e-12), axis=dim)
|
||||||
|
|
Loading…
Reference in New Issue