change wavenet to use on-the-fly preprocessing

chenfeiyu 2021-04-19 19:58:36 +08:00
parent e06c6cdfe1
commit 3741cc49ca
13 changed files with 262 additions and 216 deletions
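In short: instead of reading precomputed mel/wav .npy pairs from disk, the datasets now load raw wavs and the spectrogram features are computed inside the training loop. A rough sketch of the per-example loading step (mirroring AudioSegmentDataset below; illustrative, not the exact code):

import librosa
import numpy as np

def load_clip(path, sr=22050, top_db=60, length=None):
    # decode and resample, then trim leading/trailing silence
    y, _ = librosa.load(path, sr=sr)
    y, _ = librosa.effects.trim(y, top_db=top_db)
    y = librosa.util.normalize(y).astype(np.float32)
    if length is not None:
        if y.size <= length:  # pad short clips ...
            y = np.pad(y, [0, length - len(y)], mode='constant')
        else:  # ... or take a random fixed-size crop
            start = np.random.randint(0, 1 + len(y) - length)
            y = y[start:start + length]
    return y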

View File

@ -92,8 +92,8 @@ class LJSpeechCollector(object):
text_lens = np.array(sorted(text_lens, reverse=True), dtype=np.int64)
# Pad sequence with largest len of the batch
texts = batch_text_id(texts, pad_id=self.padding_idx)
mels = np.transpose(
batch_spec(mels, pad_value=self.padding_value), axes=(0, 2, 1))
texts, _ = batch_text_id(texts, pad_id=self.padding_idx)
mels, _ = batch_spec(mels, pad_value=self.padding_value)
mels = np.transpose(mels, axes=(0, 2, 1))

View File

@ -44,9 +44,9 @@ def collate_aishell3_examples(examples):
spec_lengths = np.array([item.shape[1] for item in mel], dtype=np.int64)
T_dec = np.max(spec_lengths)
stop_tokens = (np.arange(T_dec) >= np.expand_dims(spec_lengths, -1)).astype(np.float32)
phones = batch_text_id(phones)
tones = batch_text_id(tones)
mel = batch_spec(mel)
phones, _ = batch_text_id(phones)
tones, _ = batch_text_id(tones)
mel, _ = batch_spec(mel)
mel = np.transpose(mel, (0, 2, 1))
embed = np.stack(embed)
# 7 fields

View File

@ -40,9 +40,9 @@ def collate_baker_examples(examples):
spec_lengths = np.array([item.shape[1] for item in mel], dtype=np.int64)
T_dec = np.max(spec_lengths)
stop_tokens = (np.arange(T_dec) >= np.expand_dims(spec_lengths, -1)).astype(np.float32)
phones = batch_text_id(phones)
tones = batch_text_id(tones)
mel = batch_spec(mel)
phones, _ = batch_text_id(phones)
tones, _ = batch_text_id(tones)
mel, _ = batch_spec(mel)
mel = np.transpose(mel, (0, 2, 1))
return phones, tones, mel, text_lengths, spec_lengths, stop_tokens

View File

@ -53,6 +53,6 @@ def collate_vctk_examples(examples):
slens = np.array([item.shape[1] for item in mels], dtype=np.int64)
speaker_ids = np.array(speaker_ids, dtype=np.int64)
phonemes = batch_text_id(phonemes, pad_id=0)
mels = np.transpose(batch_spec(mels, pad_value=0.), [0, 2, 1])
phonemes, _ = batch_text_id(phonemes, pad_id=0)
mels, _ = batch_spec(mels, pad_value=0.)
mels = np.transpose(mels, [0, 2, 1])
return phonemes, plens, mels, slens, speaker_ids

View File

@ -76,9 +76,9 @@ class LJSpeechCollector(object):
mels = [example[1] for example in examples]
stop_probs = [example[2] for example in examples]
ids = batch_text_id(ids, pad_id=self.padding_idx)
mels = batch_spec(mels, pad_value=self.padding_value)
stop_probs = batch_text_id(stop_probs, pad_id=self.padding_idx)
ids, _ = batch_text_id(ids, pad_id=self.padding_idx)
mels, _ = batch_spec(mels, pad_value=self.padding_value)
stop_probs, _ = batch_text_id(stop_probs, pad_id=self.padding_idx)
return ids, np.transpose(mels, [0, 2, 1]), stop_probs

View File

@ -61,8 +61,8 @@ class LJSpeechCollector(object):
def __call__(self, examples):
mels = [example[0] for example in examples]
wavs = [example[1] for example in examples]
mels = batch_spec(mels, pad_value=self.padding_value)
wavs = batch_wav(wavs, pad_value=self.padding_value)
mels, _ = batch_spec(mels, pad_value=self.padding_value)
wavs, _ = batch_wav(wavs, pad_value=self.padding_value)
return mels, wavs

View File

@ -20,10 +20,12 @@ _C.data = CN(
batch_size=8, # batch size
valid_size=16, # the first N examples are reserved for validation
sample_rate=22050, # Hz, sample rate
n_fft=2048, # fft frame size
n_fft=1024, # fft frame size
win_length=1024, # window size
hop_length=256, # hop size between adjacent frames
# f_max=8000, # Hz, max frequency when converting to mel
top_db=60, # db, used to trim silence
fmin=0, # Hz, min frequency when converting to mel
fmax=8000, # Hz, max frequency when converting to mel
n_mels=80, # mel bands
train_clip_seconds=0.5, # audio clip length (in seconds)
))
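For orientation, these values fix the size of each training example used by train.py below; a quick illustrative computation (names follow train.py):

sample_rate = 22050
hop_length = 256
train_clip_seconds = 0.5
train_clip_size = int(train_clip_seconds * sample_rate)  # 11025 samples scored by the loss
# each example is context_size (the WaveNet receptive field) + train_clip_size samples long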

View File

@ -16,136 +16,43 @@ import os
from pathlib import Path
import pickle
import numpy as np
import librosa
import pandas
from paddle.io import Dataset, DataLoader
from parakeet.data.batch import batch_spec, batch_wav
from parakeet.data import dataset
from parakeet.audio import AudioProcessor
class LJSpeech(Dataset):
"""A simple dataset adaptor for the processed ljspeech dataset."""
def __init__(self, root):
def __init__(self, root, sample_rate, length, top_db):
self.root = Path(root).expanduser()
meta_data = pandas.read_csv(
self.metadata = pandas.read_csv(
str(self.root / "metadata.csv"),
sep="\t",
sep="|",
header=None,
names=["fname", "frames", "samples"])
records = []
for row in meta_data.itertuples():
mel_path = str(self.root / "mel" / (row.fname + ".npy"))
wav_path = str(self.root / "wav" / (row.fname + ".npy"))
records.append((mel_path, wav_path))
self.records = records
names=["fname", "text", "normalized_text"])
self.wav_dir = self.root / "wavs"
self.sr = sample_rate
self.top_db = top_db
self.length = length # samples in the clip
def __getitem__(self, i):
mel_name, wav_name = self.records[i]
mel = np.load(mel_name)
wav = np.load(wav_name)
return mel, wav
fname = self.metadata.iloc[i].fname
fpath = (self.wav_dir / fname).with_suffix(".wav")
y, _ = librosa.load(fpath, sr=self.sr)
y, _ = librosa.effects.trim(y, top_db=self.top_db)
y = librosa.util.normalize(y)
y = y.astype(np.float32)
# pad or trim
if y.size <= self.length:
y = np.pad(y, [0, self.length - len(y)], mode='constant')
else:
start = np.random.randint(0, 1 + len(y) - self.length)
y = y[start: start + self.length]
return y
def __len__(self):
return len(self.records)
return len(self.metadata)
class LJSpeechCollector(object):
"""A simple callable to batch LJSpeech examples."""
def __init__(self, padding_value=0.):
self.padding_value = padding_value
def __call__(self, examples):
batch_size = len(examples)
mels = [example[0] for example in examples]
wavs = [example[1] for example in examples]
mels = batch_spec(mels, pad_value=self.padding_value)
wavs = batch_wav(wavs, pad_value=self.padding_value)
audio_starts = np.zeros((batch_size, ), dtype=np.int64)
return mels, wavs, audio_starts
class LJSpeechClipCollector(object):
def __init__(self, clip_frames=65, hop_length=256):
self.clip_frames = clip_frames
self.hop_length = hop_length
def __call__(self, examples):
mels = []
wavs = []
starts = []
for example in examples:
mel, wav_clip, start = self.clip(example)
mels.append(mel)
wavs.append(wav_clip)
starts.append(start)
mels = batch_spec(mels)
wavs = np.stack(wavs)
starts = np.array(starts, dtype=np.int64)
return mels, wavs, starts
def clip(self, example):
mel, wav = example
frames = mel.shape[-1]
start = np.random.randint(0, frames - self.clip_frames)
wav_clip = wav[start * self.hop_length:(start + self.clip_frames) *
self.hop_length]
return mel, wav_clip, start
class DataCollector(object):
def __init__(self,
context_size,
sample_rate,
hop_length,
train_clip_seconds,
valid=False):
frames_per_second = sample_rate // hop_length
train_clip_frames = int(
np.ceil(train_clip_seconds * frames_per_second))
context_frames = context_size // hop_length
self.num_frames = train_clip_frames + context_frames
self.sample_rate = sample_rate
self.hop_length = hop_length
self.valid = valid
def random_crop(self, sample):
audio, mel_spectrogram = sample
audio_frames = int(audio.size) // self.hop_length
max_start_frame = audio_frames - self.num_frames
assert max_start_frame >= 0, "audio is too short to be cropped"
frame_start = np.random.randint(0, max_start_frame)
# frame_start = 0 # norandom
frame_end = frame_start + self.num_frames
audio_start = frame_start * self.hop_length
audio_end = frame_end * self.hop_length
audio = audio[audio_start:audio_end]
return audio, mel_spectrogram, audio_start
def __call__(self, samples):
# transform them first
if self.valid:
samples = [(audio, mel_spectrogram, 0)
for audio, mel_spectrogram in samples]
else:
samples = [self.random_crop(sample) for sample in samples]
# batch them
audios = [sample[0] for sample in samples]
audio_starts = [sample[2] for sample in samples]
mels = [sample[1] for sample in samples]
mels = batch_spec(mels)
if self.valid:
audios = batch_wav(audios, dtype=np.float32)
else:
audios = np.array(audios, dtype=np.float32)
audio_starts = np.array(audio_starts, dtype=np.int64)
return audios, mels, audio_starts
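Because every training example is now a fixed-length clip, the clip/crop collators removed above are unnecessary; the train loader in train.py below passes no collate_fn, and default collation amounts to a plain stack (illustrative sketch):

import numpy as np

def collate(examples):  # each example: a (length,) float32 clip
    return np.stack(examples)  # -> (batch_size, length)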

View File

@ -30,9 +30,13 @@ from parakeet.utils import scheduler, mp_tools
from parakeet.training.cli import default_argument_parser
from parakeet.training.experiment import ExperimentBase
from parakeet.utils.mp_tools import rank_zero_only
from parakeet.datasets import AudioDataset, AudioSegmentDataset
from parakeet.data import batch_wav
from parakeet.modules.audio import STFT, MelScale
from config import get_cfg_defaults
from ljspeech import LJSpeech, LJSpeechClipCollector, LJSpeechCollector
from ljspeech import LJSpeech
class Experiment(ExperimentBase):
@ -60,39 +64,48 @@ class Experiment(ExperimentBase):
parameters=model.parameters(),
grad_clip=paddle.nn.ClipGradByGlobalNorm(
config.training.gradient_max_norm))
self.stft = STFT(config.data.n_fft, config.data.hop_length, config.data.win_length)
self.mel_scale = MelScale(config.data.sample_rate, config.data.n_fft, config.data.n_mels, config.data.fmin, config.data.fmax)
self.model = model
self.model_core = model._layers if self.parallel else model
self.optimizer = optimizer
def setup_dataloader(self):
config = self.config
args = self.args
ljspeech_dataset = LJSpeech(args.data)
valid_set, train_set = dataset.split(ljspeech_dataset,
config.data.valid_size)
# convolutional net's causal padding size
context_size = config.model.n_stack \
* sum([(config.model.filter_size - 1) * 2**i for i in range(config.model.n_loop)]) \
+ 1
context_frames = context_size // config.data.hop_length
# frames used to compute loss
frames_per_second = config.data.sample_rate // config.data.hop_length
train_clip_frames = math.ceil(config.data.train_clip_seconds *
frames_per_second)
train_clip_size = int(config.data.train_clip_seconds * config.data.sample_rate)
length = context_size + train_clip_size
root = Path(args.data).expanduser()
file_paths = sorted(list((root / "wavs").rglob("*.wav")))
train_set = AudioSegmentDataset(
file_paths[config.data.valid_size:],
config.data.sample_rate,
length,
top_db=config.data.top_db)
valid_set = AudioDataset(
file_paths[:config.data.valid_size],
config.data.sample_rate,
top_db=config.data.top_db)
num_frames = train_clip_frames + context_frames
batch_fn = LJSpeechClipCollector(num_frames, config.data.hop_length)
if not self.parallel:
train_loader = DataLoader(
train_set,
batch_size=config.data.batch_size,
shuffle=True,
drop_last=True,
collate_fn=batch_fn)
num_workers=1,
)
else:
sampler = DistributedBatchSampler(
train_set,
@ -100,25 +113,36 @@ class Experiment(ExperimentBase):
shuffle=True,
drop_last=True)
train_loader = DataLoader(
train_set, batch_sampler=sampler, collate_fn=batch_fn)
train_set, batch_sampler=sampler, num_workers=1)
valid_batch_fn = LJSpeechCollector()
valid_loader = DataLoader(
valid_set, batch_size=1, collate_fn=valid_batch_fn)
valid_set,
batch_size=config.data.batch_size,
num_workers=1,
collate_fn=batch_wav)
self.train_loader = train_loader
self.valid_loader = valid_loader
def train_batch(self):
# load data
start = time.time()
batch = self.read_batch()
data_loader_time = time.time() - start
self.model.train()
self.optimizer.clear_grad()
mel, wav, audio_starts = batch
wav = batch
# data preprocessing
S = self.stft.magnitude(wav)
mel = self.mel_scale(S)
logmel = 20 * paddle.log10(paddle.clip(mel, min=1e-5))
logmel = paddle.clip((logmel + 80) / 100, min=0.0, max=1.0)
# forward & backward
y = self.model(wav, mel, audio_starts)
y = self.model(wav, logmel)
loss = self.model_core.loss(y, wav)
loss.backward()
self.optimizer.step()
@ -129,24 +153,43 @@ class Experiment(ExperimentBase):
msg += "step: {}, ".format(self.iteration)
msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time,
iteration_time)
msg += "loss: {:>.6f}".format(loss_value)
msg += "train/loss: {:>.6f}, ".format(loss_value)
msg += "lr: {:>.6f}".format(self.optimizer.get_lr())
self.logger.info(msg)
if dist.get_rank() == 0:
self.visualizer.add_scalar(
"train/loss", loss_value, global_step=self.iteration)
"train/loss", loss_value, self.iteration)
self.visualizer.add_scalar(
"train/lr", self.optimizer.get_lr(), self.iteration)
# now we have to call learning rate scheduler.step() manually
self.optimizer._learning_rate.step()
@mp_tools.rank_zero_only
@paddle.no_grad()
def valid(self):
valid_iterator = iter(self.valid_loader)
valid_losses = []
mel, wav, audio_starts = next(valid_iterator)
y = self.model(wav, mel, audio_starts)
loss = self.model_core.loss(y, wav)
valid_losses.append(float(loss))
valid_loss = np.mean(valid_losses)
for batch in self.valid_loader:
wav, length = batch
# data preprocessing
S = self.stft.magnitude(wav)
mel = self.mel_scale(S)
logmel = 20 * paddle.log10(paddle.clip(mel, min=1e-5))
logmel = paddle.clip((logmel + 80) / 100, min=0.0, max=1.0)
y = self.model(wav, logmel)
loss = self.model_core.loss(y, wav)
valid_losses.append(float(loss))
valid_loss = np.mean(valid_losses)
msg = "Rank: {}, ".format(dist.get_rank())
msg += "step: {}, ".format(self.iteration)
msg += "valid/loss: {:>.6f}".format(valid_loss)
self.logger.info(msg)
self.visualizer.add_scalar(
"valid/loss", valid_loss, global_step=self.iteration)
"valid/loss", valid_loss, self.iteration)
def main_sp(config, args):
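The training and validation loops above compress the linear mel spectrogram to a normalized log scale. A small numeric check of that mapping, in numpy for illustration:

import numpy as np

mel = np.array([1e-6, 1e-4, 1.0])  # linear mel magnitudes
logmel = 20 * np.log10(np.clip(mel, 1e-5, None))  # to dB, floored at -100 dB
norm = np.clip((logmel + 80) / 100, 0.0, 1.0)  # maps [-80 dB, +20 dB] to [0, 1]
print(norm)  # [0.  0.  0.8]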

View File

@ -65,7 +65,7 @@ def batch_text_id(minibatch, pad_id=0, dtype=np.int64):
mode='constant',
constant_values=pad_id))
return np.array(batch, dtype=dtype)
return np.array(batch, dtype=dtype), np.array(lengths, dtype=np.int64)
class WavBatcher(object):
@ -106,7 +106,7 @@ def batch_wav(minibatch, pad_value=0., dtype=np.float32):
np.pad(example, [(0, pad_len)],
mode='constant',
constant_values=pad_value))
return np.array(batch, dtype=dtype)
return np.array(batch, dtype=dtype), np.array(lengths, dtype=np.int64)
class SpecBatcher(object):
@ -160,4 +160,4 @@ def batch_spec(minibatch, pad_value=0., time_major=False, dtype=np.float32):
np.pad(example, [(0, 0), (0, pad_len)],
mode='constant',
constant_values=pad_value))
return np.array(batch, dtype=dtype)
return np.array(batch, dtype=dtype), np.array(lengths, dtype=np.int64)
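All three helpers now return a (batch, lengths) pair, which is why every call site touched by this commit unpacks two values. A minimal usage sketch:

import numpy as np
from parakeet.data.batch import batch_wav

wavs = [np.zeros((100,), dtype=np.float32), np.zeros((80,), dtype=np.float32)]
batch, lengths = batch_wav(wavs)  # batch: (2, 100) padded array; lengths: [100, 80]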

View File

@ -15,24 +15,75 @@
from paddle.io import Dataset
import os
import librosa
from pathlib import Path
import numpy as np
from typing import List
__all__ = ["AudioFolderDataset"]
__all__ = ["AudioSegmentDataset", "AudioDataset", "AudioFolderDataset"]
class AudioFolderDataset(Dataset):
def __init__(self, path, sample_rate, extension="wav"):
self.root = os.path.expanduser(path)
self.sample_rate = sample_rate
self.extension = extension
self.file_names = [
os.path.join(self.root, x) for x in os.listdir(self.root) \
if os.path.splitext(x)[-1] == self.extension]
self.length = len(self.file_names)
def __len__(self):
return self.length
class AudioSegmentDataset(Dataset):
"""A simple dataset adaptor for audio files to train vocoders.
Read -> trim silence -> normalize -> extract a segment
"""
def __init__(self, file_paths: List[Path], sample_rate: int, length: int,
top_db: float):
self.file_paths = file_paths
self.sr = sample_rate
self.top_db = top_db
self.length = length # samples in the clip
def __getitem__(self, i):
file_name = self.file_names[i]
y, _ = librosa.load(file_name, sr=self.sample_rate) # pylint: disable=unused-variable
fpath = self.file_paths[i]
y, _ = librosa.load(fpath, sr=self.sr)
y, _ = librosa.effects.trim(y, top_db=self.top_db)
y = librosa.util.normalize(y)
y = y.astype(np.float32)
# pad or trim
if y.size <= self.length:
y = np.pad(y, [0, self.length - len(y)], mode='constant')
else:
start = np.random.randint(0, 1 + len(y) - self.length)
y = y[start:start + self.length]
return y
def __len__(self):
return len(self.file_paths)
class AudioDataset(Dataset):
"""A simple dataset adaptor for the audio files.
Read -> trim silence -> normalize
"""
def __init__(self,
file_paths: List[Path],
sample_rate: int,
top_db: float = 60):
self.file_paths = file_paths
self.sr = sample_rate
self.top_db = top_db
def __getitem__(self, i):
fpath = self.file_paths[i]
y, _ = librosa.load(fpath, sr=self.sr)
y, _ = librosa.effects.trim(y, top_db=self.top_db)
y = librosa.util.normalize(y)
y = y.astype(np.float32)
return y
def __len__(self):
return len(self.file_paths)
class AudioFolderDataset(AudioDataset):
def __init__(
self,
root,
sample_rate,
top_db=60,
extension=".wav",
):
root = Path(root).expanduser()
file_paths = sorted(list(root.rglob("*{}".format(extension))))
super().__init__(file_paths, sample_rate, top_db)
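A usage sketch of the two new dataset classes (the paths and the clip length are illustrative):

from pathlib import Path
from parakeet.datasets import AudioDataset, AudioSegmentDataset

wav_paths = sorted(Path("~/datasets/LJSpeech-1.1/wavs").expanduser().glob("*.wav"))
train_set = AudioSegmentDataset(wav_paths[16:], sample_rate=22050,
                                length=16384,  # illustrative: train_clip_size + context_size
                                top_db=60)
valid_set = AudioDataset(wav_paths[:16], sample_rate=22050, top_db=60)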

View File

@ -101,9 +101,7 @@ class UpsampleNet(nn.LayerList):
def __init__(self, upscale_factors=[16, 16]):
super(UpsampleNet, self).__init__()
self.upscale_factors = list(upscale_factors)
self.upscale_factor = 1
for item in upscale_factors:
self.upscale_factor *= item
self.upscale_factor = np.prod(upscale_factors)
for factor in self.upscale_factors:
self.append(
@ -224,13 +222,15 @@ class ResidualBlock(nn.Layer):
other ResidualBlocks.
"""
h = x
length = x.shape[-1]
# dilated conv
h = self.conv(h)
# condition
# NOTE: the expanded condition may have more timesteps than x
if condition is not None:
h += self.condition_proj(condition)
h += self.condition_proj(condition)[:, :, :length]
# gated tanh
content, gate = paddle.split(h, 2, axis=1)
@ -822,7 +822,7 @@ class ConditionalWaveNet(nn.Layer):
loss_type=loss_type,
log_scale_min=log_scale_min)
def forward(self, audio, mel, audio_start):
def forward(self, audio, mel):
"""Compute the output distribution given the mel spectrogram and the input(for teacher force training).
Parameters
@ -845,13 +845,13 @@ class ConditionalWaveNet(nn.Layer):
"""
audio_length = audio.shape[1] # audio clip's length
condition = self.encoder(mel)
condition_slice = crop(condition, audio_start, audio_length)
# shifting 1 step
audio = audio[:, :-1]
condition_slice = condition_slice[:, :, 1:]
condition = condition[:, :, 1:]
y = self.decoder(audio, condition_slice)
y = self.decoder(audio, condition)
return y
def loss(self, y, t):
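One assumption behind dropping audio_start (stated here, not in the commit): each on-the-fly example already carries its own conditioning context, and the encoder's total upsampling is expected to equal the hop size, so every mel frame expands to hop_length waveform steps:

import numpy as np

upscale_factors = [16, 16]  # UpsampleNet's default, see above
assert np.prod(upscale_factors) == 256  # assumed to equal config.data.hop_length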

View File

@ -16,6 +16,8 @@ import paddle
from paddle import nn
from paddle.nn import functional as F
from scipy import signal
import librosa
from librosa.util import pad_center
import numpy as np
__all__ = ["quantize", "dequantize", "STFT"]
@ -88,6 +90,19 @@ class STFT(nn.Layer):
Name of window function, see `scipy.signal.get_window` for more
details. Defaults to "hanning".
center : bool
If True, the signal y is padded so that frame D[:, t] is centered
at y[t * hop_length]. If False, then D[:, t] begins at y[t * hop_length].
Defaults to True.
pad_mode : string or function
If center=True, this argument is passed to np.pad for padding the edges
of the signal y. By default (pad_mode="reflect"), y is padded on both
sides with its own reflection, mirrored around its first and last
sample respectively. If center=False, this argument is ignored.
Notes
-----------
It behaves like ``librosa.core.stft``. See ``librosa.core.stft`` for more
@ -101,29 +116,47 @@ class STFT(nn.Layer):
"""
def __init__(self, n_fft, hop_length, win_length, window="hanning"):
super(STFT, self).__init__()
def __init__(self, n_fft, hop_length=None, win_length=None, window="hanning", center=True, pad_mode="reflect"):
super().__init__()
# By default, use the entire frame
if win_length is None:
win_length = n_fft
# Set the default hop, if it's not already specified
if hop_length is None:
hop_length = int(win_length // 4)
self.hop_length = hop_length
self.n_bin = 1 + n_fft // 2
self.n_fft = n_fft
self.center = center
self.pad_mode = pad_mode
# calculate window
window = signal.get_window(window, win_length)
window = signal.get_window(window, win_length, fftbins=True)
# pad window to n_fft size
if n_fft != win_length:
pad = (n_fft - win_length) // 2
window = np.pad(window, ((pad, pad), ), 'constant')
window = pad_center(window, n_fft, mode="constant")
# calculate weights
r = np.arange(0, n_fft)
M = np.expand_dims(r, -1) * np.expand_dims(r, 0)
w_real = np.reshape(window *
np.cos(2 * np.pi * M / n_fft)[:self.n_bin],
(self.n_bin, 1, 1, self.n_fft))
w_imag = np.reshape(window *
np.sin(-2 * np.pi * M / n_fft)[:self.n_bin],
(self.n_bin, 1, 1, self.n_fft))
weight = np.fft.fft(np.eye(n_fft))[:self.n_bin]
w_real = weight.real
w_imag = weight.imag
w = np.concatenate([w_real, w_imag], axis=0)
w = w * window
w = np.expand_dims(w, 1)
self.weight = paddle.cast(
paddle.to_tensor(w), paddle.get_default_dtype())
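The rewritten weight setup builds the convolution kernels from the DFT matrix: np.fft.fft(np.eye(n_fft))[:self.n_bin] keeps the non-negative-frequency rows, so a strided conv1d with these (windowed) kernels computes the STFT. A quick numpy check of the underlying identity:

import numpy as np

n_fft = 8
weight = np.fft.fft(np.eye(n_fft))[:1 + n_fft // 2]  # rows of the DFT matrix
x = np.random.randn(n_fft)
assert np.allclose(weight @ x, np.fft.rfft(x))  # one frame equals its rfft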
@ -137,23 +170,20 @@ class STFT(nn.Layer):
Returns
------------
real : Tensor [shape=(B, C, 1, frames)]
real : Tensor [shape=(B, C, frames)]
The real part of the spectrogram.
imag : Tensor [shape=(B, C, 1, frames)]
imag : Tensor [shape=(B, C, frames)]
The imaginary part of the spectrogram.
"""
# x(batch_size, time_steps)
# pad it first with reflect mode
# TODO(chenfeiyu): report an issue on paddle.flip
pad_start = paddle.reverse(x[:, 1:1 + self.n_fft // 2], axis=[1])
pad_stop = paddle.reverse(x[:, -(1 + self.n_fft // 2):-1], axis=[1])
x = paddle.concat([pad_start, x, pad_stop], axis=-1)
x = paddle.unsqueeze(x, axis=1)
if self.center:
x = F.pad(x, [self.n_fft // 2, self.n_fft // 2],
data_format='NCL', mode=self.pad_mode)
# to BC1T, C=1
x = paddle.unsqueeze(x, axis=[1, 2])
out = F.conv2d(x, self.weight, stride=(1, self.hop_length))
real, imag = paddle.chunk(out, 2, axis=1) # BC1T
# to BCT, C=1
out = F.conv1d(x, self.weight, stride=self.hop_length)
real, imag = paddle.chunk(out, 2, axis=1) # BCT
return real, imag
def power(self, x):
@ -166,7 +196,7 @@ class STFT(nn.Layer):
Returns
------------
Tensor [shape=(B, C, 1, T)]
Tensor [shape=(B, C, T)]
The power spectrum.
"""
real, imag = self(x)
@ -183,9 +213,22 @@ class STFT(nn.Layer):
Returns
------------
Tensor [shape=(B, C, 1, T)]
Tensor [shape=(B, C, T)]
The magnitude of the spectrum.
"""
power = self.power(x)
magnitude = paddle.sqrt(power)
return magnitude
class MelScale(nn.Layer):
def __init__(self, sr, n_fft, n_mels, fmin, fmax):
super().__init__()
mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
self.weight = paddle.to_tensor(mel_basis)
def forward(self, spec):
# (n_mels, n_freq) * (batch_size, n_freq, n_frames)
mel = paddle.matmul(self.weight, spec)
return mel
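Putting the two modules together, mirroring the preprocessing in train.py (shapes illustrative; values taken from the config above):

import paddle
from parakeet.modules.audio import STFT, MelScale

stft = STFT(n_fft=1024, hop_length=256, win_length=1024)
mel_scale = MelScale(sr=22050, n_fft=1024, n_mels=80, fmin=0, fmax=8000)

wav = paddle.randn([4, 22050])  # (batch_size, time_steps)
spec = stft.magnitude(wav)  # (batch_size, 1 + n_fft // 2, frames)
mel = mel_scale(spec)  # (batch_size, n_mels, frames)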