STFT and MelScale: register filters as buffer.
This commit is contained in:
parent
c306f5c2b3
commit
759999c738
|
@ -20,7 +20,7 @@ import librosa
|
|||
from librosa.util import pad_center
|
||||
import numpy as np
|
||||
|
||||
__all__ = ["quantize", "dequantize", "STFT"]
|
||||
__all__ = ["quantize", "dequantize", "STFT", "MelScale"]
|
||||
|
||||
|
||||
def quantize(values, n_bands):
|
||||
|
@ -96,10 +96,10 @@ class STFT(nn.Layer):
|
|||
Defaults to True.
|
||||
|
||||
pad_mode : string or function
|
||||
If center=True, this argument is passed to np.pad for padding the edges
|
||||
of the signal y. By default (pad_mode="reflect"), y is padded on both
|
||||
sides with its own reflection, mirrored around its first and last
|
||||
sample respectively. If center=False, this argument is ignored.
|
||||
If center=True, this argument is passed to np.pad for padding the edges
|
||||
of the signal y. By default (pad_mode="reflect"), y is padded on both
|
||||
sides with its own reflection, mirrored around its first and last
|
||||
sample respectively. If center=False, this argument is ignored.
|
||||
|
||||
|
||||
|
||||
|
@ -163,17 +163,15 @@ class STFT(nn.Layer):
|
|||
w = np.concatenate([w_real, w_imag], axis=0)
|
||||
w = w * window
|
||||
w = np.expand_dims(w, 1)
|
||||
self.weight = paddle.cast(
|
||||
paddle.to_tensor(w), paddle.get_default_dtype())
|
||||
weight = paddle.cast(paddle.to_tensor(w), paddle.get_default_dtype())
|
||||
self.register_buffer("weight", weight)
|
||||
|
||||
def forward(self, x):
|
||||
"""Compute the stft transform.
|
||||
|
||||
Parameters
|
||||
------------
|
||||
x : Tensor [shape=(B, T)]
|
||||
The input waveform.
|
||||
|
||||
Returns
|
||||
------------
|
||||
real : Tensor [shape=(B, C, frames)]
|
||||
|
@ -195,36 +193,32 @@ class STFT(nn.Layer):
|
|||
|
||||
def power(self, x):
    """Compute the power spectrum.

    Parameters
    ------------
    x : Tensor [shape=(B, T)]
        The input waveform.

    Returns
    ------------
    Tensor [shape=(B, C, frames)]
        The power spectrum, i.e. ``real**2 + imag**2`` of the transform
        of ``x``.
    """
    # Run the transform exactly once; it yields the real and imaginary
    # parts as two separate values.
    real, imag = self.forward(x)
    return real**2 + imag**2
|
||||
|
||||
def magnitude(self, x):
    """Compute the magnitude of the spectrum.

    Parameters
    ------------
    x : Tensor [shape=(B, T)]
        The input waveform.

    Returns
    ------------
    Tensor [shape=(B, C, frames)]
        The magnitude of the spectrum, i.e. the element-wise square root
        of the power spectrum returned by ``self.power``.
    """
    power = self.power(x)
    # Take the square root once; no lower bound is applied to ``power``
    # before the sqrt (see the TODO).
    magnitude = paddle.sqrt(power)  # TODO(chenfeiyu): maybe clipping
    return magnitude
|
||||
|
||||
|
||||
|
@ -232,7 +226,9 @@ class MelScale(nn.Layer):
|
|||
def __init__(self, sr, n_fft, n_mels, fmin, fmax):
    """Build the mel filterbank and register it as the ``weight`` buffer.

    All arguments are forwarded unchanged to ``librosa.filters.mel``.

    Parameters
    ------------
    sr
        Sampling rate of the signal.
    n_fft
        Number of FFT components.
    n_mels
        Number of mel bands to generate.
    fmin
        Lowest frequency of the filterbank.
    fmax
        Highest frequency of the filterbank.
    """
    super().__init__()
    # Pass by keyword: librosa >= 0.10 made these arguments keyword-only,
    # and older releases accept the same names, so this works on both.
    mel_basis = librosa.filters.mel(
        sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
    # Cast to the framework's default dtype and register as a buffer
    # instead of assigning a plain attribute.
    weight = paddle.to_tensor(mel_basis, dtype=paddle.get_default_dtype())
    self.register_buffer("weight", weight)
|
||||
|
||||
def forward(self, spec):
|
||||
# (n_mels, n_freq) * (batch_size, n_freq, n_frames)
|
||||
|
|
Loading…
Reference in New Issue