STFT and MelScale: register filters as buffer.

This commit is contained in:
chenfeiyu 2021-06-10 04:06:06 +08:00
parent c306f5c2b3
commit 759999c738
1 changed files with 12 additions and 16 deletions

View File

@ -20,7 +20,7 @@ import librosa
from librosa.util import pad_center
import numpy as np
__all__ = ["quantize", "dequantize", "STFT"]
__all__ = ["quantize", "dequantize", "STFT", "MelScale"]
def quantize(values, n_bands):
@ -96,10 +96,10 @@ class STFT(nn.Layer):
Defaults to True.
pad_mode : string or function
If center=True, this argument is passed to np.pad for padding the edges
of the signal y. By default (pad_mode="reflect"), y is padded on both
sides with its own reflection, mirrored around its first and last
sample respectively. If center=False, this argument is ignored.
If center=True, this argument is passed to np.pad for padding the edges
of the signal y. By default (pad_mode="reflect"), y is padded on both
sides with its own reflection, mirrored around its first and last
sample respectively. If center=False, this argument is ignored.
@ -163,17 +163,15 @@ class STFT(nn.Layer):
w = np.concatenate([w_real, w_imag], axis=0)
w = w * window
w = np.expand_dims(w, 1)
self.weight = paddle.cast(
paddle.to_tensor(w), paddle.get_default_dtype())
weight = paddle.cast(paddle.to_tensor(w), paddle.get_default_dtype())
self.register_buffer("weight", weight)
def forward(self, x):
"""Compute the stft transform.
Parameters
------------
x : Tensor [shape=(B, T)]
The input waveform.
Returns
------------
real : Tensor [shape=(B, C, frames)]
@ -195,36 +193,32 @@ class STFT(nn.Layer):
def power(self, x):
"""Compute the power spectrum.
Parameters
------------
x : Tensor [shape=(B, T)]
The input waveform.
Returns
------------
Tensor [shape=(B, C, T)]
The power spectrum.
"""
real, imag = self(x)
real, imag = self.forward(x)
power = real**2 + imag**2
return power
def magnitude(self, x):
"""Compute the magnitude of the spectrum.
Parameters
------------
x : Tensor [shape=(B, T)]
The input waveform.
Returns
------------
Tensor [shape=(B, C, T)]
The magnitude of the spectrum.
"""
power = self.power(x)
magnitude = paddle.sqrt(power)
magnitude = paddle.sqrt(power) # TODO(chenfeiyu): maybe clipping
return magnitude
@ -232,7 +226,9 @@ class MelScale(nn.Layer):
def __init__(self, sr, n_fft, n_mels, fmin, fmax):
super().__init__()
mel_basis = librosa.filters.mel(sr, n_fft, n_mels, fmin, fmax)
self.weight = paddle.to_tensor(mel_basis)
# self.weight = paddle.to_tensor(mel_basis)
weight = paddle.to_tensor(mel_basis, dtype=paddle.get_default_dtype())
self.register_buffer("weight", weight)
def forward(self, spec):
# (n_mels, n_freq) * (batch_size, n_freq, n_frames)