Upgrade waveflow to 1.8.0
This commit is contained in:
parent
8716a1843c
commit
9b8fd9f93d
|
@ -348,7 +348,7 @@ class WaveFlowModule(dg.Layer):
|
|||
mel = self.conditioner(mel)
|
||||
assert mel.shape[2] >= audio.shape[1]
|
||||
# Prune out the tail of audio/mel so that time/n_group == 0.
|
||||
pruned_len = audio.shape[1] // self.n_group * self.n_group
|
||||
pruned_len = int(audio.shape[1] // self.n_group * self.n_group)
|
||||
|
||||
if audio.shape[1] > pruned_len:
|
||||
audio = audio[:, :pruned_len]
|
||||
|
|
|
@ -87,7 +87,14 @@ def compute_l2_normalized_weight(v, g, dim):
|
|||
def compute_weight(v, g, dim, power):
|
||||
assert len(g.shape) == 1, "magnitude should be a vector"
|
||||
if power == 2:
|
||||
return compute_l2_normalized_weight(v, g, dim)
|
||||
in_dtype = v.dtype
|
||||
if in_dtype == fluid.core.VarDesc.VarType.FP16:
|
||||
v = F.cast(v, "float32")
|
||||
g = F.cast(g, "float32")
|
||||
weight = compute_l2_normalized_weight(v, g, dim)
|
||||
if in_dtype == fluid.core.VarDesc.VarType.FP16:
|
||||
weight = F.cast(weight, "float16")
|
||||
return weight
|
||||
else:
|
||||
v_normalized = F.elementwise_div(
|
||||
v, (norm_except(v, dim, power) + 1e-12), axis=dim)
|
||||
|
|
Loading…
Reference in New Issue