diff --git a/parakeet/models/waveflow/waveflow_modules.py b/parakeet/models/waveflow/waveflow_modules.py index 51f9108..03f873b 100644 --- a/parakeet/models/waveflow/waveflow_modules.py +++ b/parakeet/models/waveflow/waveflow_modules.py @@ -348,7 +348,7 @@ class WaveFlowModule(dg.Layer): mel = self.conditioner(mel) assert mel.shape[2] >= audio.shape[1] # Prune out the tail of audio/mel so that time/n_group == 0. - pruned_len = audio.shape[1] // self.n_group * self.n_group + pruned_len = int(audio.shape[1] // self.n_group * self.n_group) if audio.shape[1] > pruned_len: audio = audio[:, :pruned_len] diff --git a/parakeet/modules/weight_norm.py b/parakeet/modules/weight_norm.py index 7f68cd9..82203d6 100644 --- a/parakeet/modules/weight_norm.py +++ b/parakeet/modules/weight_norm.py @@ -87,7 +87,14 @@ def compute_l2_normalized_weight(v, g, dim): def compute_weight(v, g, dim, power): assert len(g.shape) == 1, "magnitude should be a vector" if power == 2: - return compute_l2_normalized_weight(v, g, dim) + in_dtype = v.dtype + if in_dtype == fluid.core.VarDesc.VarType.FP16: + v = F.cast(v, "float32") + g = F.cast(g, "float32") + weight = compute_l2_normalized_weight(v, g, dim) + if in_dtype == fluid.core.VarDesc.VarType.FP16: + weight = F.cast(weight, "float16") + return weight else: v_normalized = F.elementwise_div( v, (norm_except(v, dim, power) + 1e-12), axis=dim)