From 759999c738978a2870079b87e4acaac83128e6e5 Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Thu, 10 Jun 2021 04:06:06 +0800 Subject: [PATCH] STFT and MelScale: register filters as buffer. --- parakeet/modules/audio.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/parakeet/modules/audio.py b/parakeet/modules/audio.py index 16c64a4..c44aa66 100644 --- a/parakeet/modules/audio.py +++ b/parakeet/modules/audio.py @@ -20,7 +20,7 @@ import librosa from librosa.util import pad_center import numpy as np -__all__ = ["quantize", "dequantize", "STFT"] +__all__ = ["quantize", "dequantize", "STFT", "MelScale"] def quantize(values, n_bands): @@ -96,10 +96,10 @@ class STFT(nn.Layer): Defaults to True. pad_mode : string or function - If center=True, this argument is passed to np.pad for padding the edges - of the signal y. By default (pad_mode="reflect"), y is padded on both - sides with its own reflection, mirrored around its first and last - sample respectively. If center=False, this argument is ignored. + If center=True, this argument is passed to np.pad for padding the edges + of the signal y. By default (pad_mode="reflect"), y is padded on both + sides with its own reflection, mirrored around its first and last + sample respectively. If center=False, this argument is ignored. @@ -163,17 +163,15 @@ class STFT(nn.Layer): w = np.concatenate([w_real, w_imag], axis=0) w = w * window w = np.expand_dims(w, 1) - self.weight = paddle.cast( - paddle.to_tensor(w), paddle.get_default_dtype()) + weight = paddle.cast(paddle.to_tensor(w), paddle.get_default_dtype()) + self.register_buffer("weight", weight) def forward(self, x): """Compute the stft transform. - Parameters ------------ x : Tensor [shape=(B, T)] The input waveform. - Returns ------------ real : Tensor [shape=(B, C, frames)] @@ -195,36 +193,32 @@ class STFT(nn.Layer): def power(self, x): """Compute the power spectrum. - Parameters ------------ x : Tensor [shape=(B, T)] The input waveform. - Returns ------------ Tensor [shape=(B, C, T)] The power spectrum. """ - real, imag = self(x) + real, imag = self.forward(x) power = real**2 + imag**2 return power def magnitude(self, x): """Compute the magnitude of the spectrum. - Parameters ------------ x : Tensor [shape=(B, T)] The input waveform. - Returns ------------ Tensor [shape=(B, C, T)] The magnitude of the spectrum. """ power = self.power(x) - magnitude = paddle.sqrt(power) + magnitude = paddle.sqrt(power) # TODO(chenfeiyu): maybe clipping return magnitude @@ -232,7 +226,9 @@ class MelScale(nn.Layer): def __init__(self, sr, n_fft, n_mels, fmin, fmax): super().__init__() mel_basis = librosa.filters.mel(sr, n_fft, n_mels, fmin, fmax) - self.weight = paddle.to_tensor(mel_basis) + # self.weight = paddle.to_tensor(mel_basis) + weight = paddle.to_tensor(mel_basis, dtype=paddle.get_default_dtype()) + self.register_buffer("weight", weight) def forward(self, spec): # (n_mels, n_freq) * (batch_size, n_freq, n_frames)