change wavenet to use on-the-fly prepeocessing

This commit is contained in:
chenfeiyu 2021-04-19 19:58:36 +08:00
parent e06c6cdfe1
commit 3741cc49ca
13 changed files with 262 additions and 216 deletions

View File

@ -92,8 +92,8 @@ class LJSpeechCollector(object):
text_lens = np.array(sorted(text_lens, reverse=True), dtype=np.int64) text_lens = np.array(sorted(text_lens, reverse=True), dtype=np.int64)
# Pad sequence with largest len of the batch # Pad sequence with largest len of the batch
texts = batch_text_id(texts, pad_id=self.padding_idx) texts, _ = batch_text_id(texts, pad_id=self.padding_idx)
mels = np.transpose( mels, _ = np.transpose(
batch_spec( batch_spec(
mels, pad_value=self.padding_value), axes=(0, 2, 1)) mels, pad_value=self.padding_value), axes=(0, 2, 1))

View File

@ -44,9 +44,9 @@ def collate_aishell3_examples(examples):
spec_lengths = np.array([item.shape[1] for item in mel], dtype=np.int64) spec_lengths = np.array([item.shape[1] for item in mel], dtype=np.int64)
T_dec = np.max(spec_lengths) T_dec = np.max(spec_lengths)
stop_tokens = (np.arange(T_dec) >= np.expand_dims(spec_lengths, -1)).astype(np.float32) stop_tokens = (np.arange(T_dec) >= np.expand_dims(spec_lengths, -1)).astype(np.float32)
phones = batch_text_id(phones) phones, _ = batch_text_id(phones)
tones = batch_text_id(tones) tones, _ = batch_text_id(tones)
mel = batch_spec(mel) mel, _ = batch_spec(mel)
mel = np.transpose(mel, (0, 2, 1)) mel = np.transpose(mel, (0, 2, 1))
embed = np.stack(embed) embed = np.stack(embed)
# 7 fields # 7 fields

View File

@ -40,9 +40,9 @@ def collate_baker_examples(examples):
spec_lengths = np.array([item.shape[1] for item in mel], dtype=np.int64) spec_lengths = np.array([item.shape[1] for item in mel], dtype=np.int64)
T_dec = np.max(spec_lengths) T_dec = np.max(spec_lengths)
stop_tokens = (np.arange(T_dec) >= np.expand_dims(spec_lengths, -1)).astype(np.float32) stop_tokens = (np.arange(T_dec) >= np.expand_dims(spec_lengths, -1)).astype(np.float32)
phones = batch_text_id(phones) phone, _ = batch_text_id(phones)
tones = batch_text_id(tones) tones, _ = batch_text_id(tones)
mel = batch_spec(mel) mel, _ = batch_spec(mel)
mel = np.transpose(mel, (0, 2, 1)) mel = np.transpose(mel, (0, 2, 1))
return phones, tones, mel, text_lengths, spec_lengths, stop_tokens return phones, tones, mel, text_lengths, spec_lengths, stop_tokens

View File

@ -53,6 +53,6 @@ def collate_vctk_examples(examples):
slens = np.array([item.shape[1] for item in mels], dtype=np.int64) slens = np.array([item.shape[1] for item in mels], dtype=np.int64)
speaker_ids = np.array(speaker_ids, dtype=np.int64) speaker_ids = np.array(speaker_ids, dtype=np.int64)
phonemes = batch_text_id(phonemes, pad_id=0) phonemes, _ = batch_text_id(phonemes, pad_id=0)
mels = np.transpose(batch_spec(mels, pad_value=0.), [0, 2, 1]) mels, _ = np.transpose(batch_spec(mels, pad_value=0.), [0, 2, 1])
return phonemes, plens, mels, slens, speaker_ids return phonemes, plens, mels, slens, speaker_ids

View File

@ -76,9 +76,9 @@ class LJSpeechCollector(object):
mels = [example[1] for example in examples] mels = [example[1] for example in examples]
stop_probs = [example[2] for example in examples] stop_probs = [example[2] for example in examples]
ids = batch_text_id(ids, pad_id=self.padding_idx) ids, _ = batch_text_id(ids, pad_id=self.padding_idx)
mels = batch_spec(mels, pad_value=self.padding_value) mels, _ = batch_spec(mels, pad_value=self.padding_value)
stop_probs = batch_text_id(stop_probs, pad_id=self.padding_idx) stop_probs, _ = batch_text_id(stop_probs, pad_id=self.padding_idx)
return ids, np.transpose(mels, [0, 2, 1]), stop_probs return ids, np.transpose(mels, [0, 2, 1]), stop_probs

View File

@ -61,8 +61,8 @@ class LJSpeechCollector(object):
def __call__(self, examples): def __call__(self, examples):
mels = [example[0] for example in examples] mels = [example[0] for example in examples]
wavs = [example[1] for example in examples] wavs = [example[1] for example in examples]
mels = batch_spec(mels, pad_value=self.padding_value) mels, _ = batch_spec(mels, pad_value=self.padding_value)
wavs = batch_wav(wavs, pad_value=self.padding_value) wavs, _ = batch_wav(wavs, pad_value=self.padding_value)
return mels, wavs return mels, wavs

View File

@ -20,10 +20,12 @@ _C.data = CN(
batch_size=8, # batch size batch_size=8, # batch size
valid_size=16, # the first N examples are reserved for validation valid_size=16, # the first N examples are reserved for validation
sample_rate=22050, # Hz, sample rate sample_rate=22050, # Hz, sample rate
n_fft=2048, # fft frame size n_fft=1024, # fft frame size
win_length=1024, # window size win_length=1024, # window size
hop_length=256, # hop size between ajacent frame hop_length=256, # hop size between ajacent frame
# f_max=8000, # Hz, max frequency when converting to mel top_db=60, # db, used to trim silence
fmin = 0, # Hz, max frequency when converting to mel
fmax=8000, # Hz, max frequency when converting to mel
n_mels=80, # mel bands n_mels=80, # mel bands
train_clip_seconds=0.5, # audio clip length(in seconds) train_clip_seconds=0.5, # audio clip length(in seconds)
)) ))

View File

@ -16,136 +16,43 @@ import os
from pathlib import Path from pathlib import Path
import pickle import pickle
import numpy as np import numpy as np
import librosa
import pandas import pandas
from paddle.io import Dataset, DataLoader from paddle.io import Dataset, DataLoader
from parakeet.data.batch import batch_spec, batch_wav
from parakeet.data import dataset
from parakeet.audio import AudioProcessor
class LJSpeech(Dataset): class LJSpeech(Dataset):
"""A simple dataset adaptor for the processed ljspeech dataset.""" """A simple dataset adaptor for the processed ljspeech dataset."""
def __init__(self, root): def __init__(self, root, sample_rate, length, top_db):
self.root = Path(root).expanduser() self.root = Path(root).expanduser()
meta_data = pandas.read_csv( self.metadata = pandas.read_csv(
str(self.root / "metadata.csv"), str(self.root / "metadata.csv"),
sep="\t", sep="|",
header=None, header=None,
names=["fname", "frames", "samples"]) names=["fname", "text", "normalized_text"])
self.wav_dir = self.root / "wavs"
records = [] self.sr = sample_rate
for row in meta_data.itertuples(): self.top_db = top_db
mel_path = str(self.root / "mel" / (row.fname + ".npy")) self.length = length # samples in the clip
wav_path = str(self.root / "wav" / (row.fname + ".npy"))
records.append((mel_path, wav_path))
self.records = records
def __getitem__(self, i): def __getitem__(self, i):
mel_name, wav_name = self.records[i] fname = self.metadata.iloc[0].fname
mel = np.load(mel_name) fpath = (self.wav_dir / fname).with_suffix(".wav")
wav = np.load(wav_name) y, sr = librosa.load(fpath, self.sr)
return mel, wav y, _ = librosa.effects.trim(y, top_db=self.top_db)
y = librosa.util.normalize(y)
y = y.astype(np.float32)
# pad or trim
if y.size <= self.length:
y = np.pad(y, [0, self.length - len(y)], mode='constant')
else:
start = np.random.randint(0, 1 + len(y) - self.length)
y = y[start: start + self.length]
return y
def __len__(self): def __len__(self):
return len(self.records) return len(self.metadata)
class LJSpeechCollector(object):
"""A simple callable to batch LJSpeech examples."""
def __init__(self, padding_value=0.):
self.padding_value = padding_value
def __call__(self, examples):
batch_size = len(examples)
mels = [example[0] for example in examples]
wavs = [example[1] for example in examples]
mels = batch_spec(mels, pad_value=self.padding_value)
wavs = batch_wav(wavs, pad_value=self.padding_value)
audio_starts = np.zeros((batch_size, ), dtype=np.int64)
return mels, wavs, audio_starts
class LJSpeechClipCollector(object):
def __init__(self, clip_frames=65, hop_length=256):
self.clip_frames = clip_frames
self.hop_length = hop_length
def __call__(self, examples):
mels = []
wavs = []
starts = []
for example in examples:
mel, wav_clip, start = self.clip(example)
mels.append(mel)
wavs.append(wav_clip)
starts.append(start)
mels = batch_spec(mels)
wavs = np.stack(wavs)
starts = np.array(starts, dtype=np.int64)
return mels, wavs, starts
def clip(self, example):
mel, wav = example
frames = mel.shape[-1]
start = np.random.randint(0, frames - self.clip_frames)
wav_clip = wav[start * self.hop_length:(start + self.clip_frames) *
self.hop_length]
return mel, wav_clip, start
class DataCollector(object):
def __init__(self,
context_size,
sample_rate,
hop_length,
train_clip_seconds,
valid=False):
frames_per_second = sample_rate // hop_length
train_clip_frames = int(
np.ceil(train_clip_seconds * frames_per_second))
context_frames = context_size // hop_length
self.num_frames = train_clip_frames + context_frames
self.sample_rate = sample_rate
self.hop_length = hop_length
self.valid = valid
def random_crop(self, sample):
audio, mel_spectrogram = sample
audio_frames = int(audio.size) // self.hop_length
max_start_frame = audio_frames - self.num_frames
assert max_start_frame >= 0, "audio is too short to be cropped"
frame_start = np.random.randint(0, max_start_frame)
# frame_start = 0 # norandom
frame_end = frame_start + self.num_frames
audio_start = frame_start * self.hop_length
audio_end = frame_end * self.hop_length
audio = audio[audio_start:audio_end]
return audio, mel_spectrogram, audio_start
def __call__(self, samples):
# transform them first
if self.valid:
samples = [(audio, mel_spectrogram, 0)
for audio, mel_spectrogram in samples]
else:
samples = [self.random_crop(sample) for sample in samples]
# batch them
audios = [sample[0] for sample in samples]
audio_starts = [sample[2] for sample in samples]
mels = [sample[1] for sample in samples]
mels = batch_spec(mels)
if self.valid:
audios = batch_wav(audios, dtype=np.float32)
else:
audios = np.array(audios, dtype=np.float32)
audio_starts = np.array(audio_starts, dtype=np.int64)
return audios, mels, audio_starts

View File

@ -30,9 +30,13 @@ from parakeet.utils import scheduler, mp_tools
from parakeet.training.cli import default_argument_parser from parakeet.training.cli import default_argument_parser
from parakeet.training.experiment import ExperimentBase from parakeet.training.experiment import ExperimentBase
from parakeet.utils.mp_tools import rank_zero_only from parakeet.utils.mp_tools import rank_zero_only
from parakeet.datasets import AudioDataset, AudioSegmentDataset
from parakeet.data import batch_wav
from parakeet.modules.audio import STFT, MelScale
from config import get_cfg_defaults from config import get_cfg_defaults
from ljspeech import LJSpeech, LJSpeechClipCollector, LJSpeechCollector from ljspeech import LJSpeech
class Experiment(ExperimentBase): class Experiment(ExperimentBase):
@ -61,38 +65,47 @@ class Experiment(ExperimentBase):
grad_clip=paddle.nn.ClipGradByGlobalNorm( grad_clip=paddle.nn.ClipGradByGlobalNorm(
config.training.gradient_max_norm)) config.training.gradient_max_norm))
self.stft = STFT(config.data.n_fft, config.data.hop_length, config.data.win_length)
self.mel_scale = MelScale(config.data.sample_rate, config.data.n_fft, config.data.n_mels, config.data.fmin, config.data.fmax)
self.model = model self.model = model
self.model_core = model._layers if self.parallel else model self.model_core = model._layers if self.parallel else model
self.optimizer = optimizer self.optimizer = optimizer
def setup_dataloader(self): def setup_dataloader(self):
config = self.config config = self.config
args = self.args args = self.args
ljspeech_dataset = LJSpeech(args.data)
valid_set, train_set = dataset.split(ljspeech_dataset,
config.data.valid_size)
# convolutional net's causal padding size # convolutional net's causal padding size
context_size = config.model.n_stack \ context_size = config.model.n_stack \
* sum([(config.model.filter_size - 1) * 2**i for i in range(config.model.n_loop)]) \ * sum([(config.model.filter_size - 1) * 2**i for i in range(config.model.n_loop)]) \
+ 1 + 1
context_frames = context_size // config.data.hop_length
# frames used to compute loss # frames used to compute loss
frames_per_second = config.data.sample_rate // config.data.hop_length train_clip_size = int(config.data.train_clip_seconds * config.data.sample_rate)
train_clip_frames = math.ceil(config.data.train_clip_seconds * length = context_size + train_clip_size
frames_per_second)
root = Path(args.data).expanduser()
file_paths = sorted(list((root / "wavs").rglob("*.wav")))
train_set = AudioSegmentDataset(
file_paths[config.data.valid_size:],
config.data.sample_rate,
length,
top_db=config.data.top_db)
valid_set = AudioDataset(
file_paths[:config.data.valid_size],
config.data.sample_rate,
top_db=config.data.top_db)
num_frames = train_clip_frames + context_frames
batch_fn = LJSpeechClipCollector(num_frames, config.data.hop_length)
if not self.parallel: if not self.parallel:
train_loader = DataLoader( train_loader = DataLoader(
train_set, train_set,
batch_size=config.data.batch_size, batch_size=config.data.batch_size,
shuffle=True, shuffle=True,
drop_last=True, drop_last=True,
collate_fn=batch_fn) num_workers=1,
)
else: else:
sampler = DistributedBatchSampler( sampler = DistributedBatchSampler(
train_set, train_set,
@ -100,25 +113,36 @@ class Experiment(ExperimentBase):
shuffle=True, shuffle=True,
drop_last=True) drop_last=True)
train_loader = DataLoader( train_loader = DataLoader(
train_set, batch_sampler=sampler, collate_fn=batch_fn) train_set, batch_sampler=sampler, num_workers=1)
valid_batch_fn = LJSpeechCollector()
valid_loader = DataLoader( valid_loader = DataLoader(
valid_set, batch_size=1, collate_fn=valid_batch_fn) valid_set,
batch_size=config.data.batch_size,
num_workers=1,
collate_fn=batch_wav)
self.train_loader = train_loader self.train_loader = train_loader
self.valid_loader = valid_loader self.valid_loader = valid_loader
def train_batch(self): def train_batch(self):
# load data
start = time.time() start = time.time()
batch = self.read_batch() batch = self.read_batch()
data_loader_time = time.time() - start data_loader_time = time.time() - start
self.model.train() self.model.train()
self.optimizer.clear_grad() self.optimizer.clear_grad()
mel, wav, audio_starts = batch wav = batch
y = self.model(wav, mel, audio_starts) # data preprocessing
S = self.stft.magnitude(wav)
mel = self.mel_scale(S)
logmel = 20 * paddle.log10(mel, paddle.clip(mel, min=1e-5))
logmel = paddle.clip((logmel + 80) / 100, min=0.0, max=1.0)
# forward & backward
y = self.model(wav, logmel)
loss = self.model_core.loss(y, wav) loss = self.model_core.loss(y, wav)
loss.backward() loss.backward()
self.optimizer.step() self.optimizer.step()
@ -129,24 +153,43 @@ class Experiment(ExperimentBase):
msg += "step: {}, ".format(self.iteration) msg += "step: {}, ".format(self.iteration)
msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time, msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time,
iteration_time) iteration_time)
msg += "loss: {:>.6f}".format(loss_value) msg += "train/loss: {:>.6f}, ".format(loss_value)
msg += "lr: {:>.6f}".format(self.optimizer.get_lr())
self.logger.info(msg) self.logger.info(msg)
if dist.get_rank() == 0: if dist.get_rank() == 0:
self.visualizer.add_scalar( self.visualizer.add_scalar(
"train/loss", loss_value, global_step=self.iteration) "train/loss", loss_value, self.iteration)
self.visualizer.add_scalar(
"train/lr", self.optimizer.get_lr(), self.iteration)
# now we have to call learning rate scheduler.step() mannually
self.optimizer._learning_rate.step()
@mp_tools.rank_zero_only @mp_tools.rank_zero_only
@paddle.no_grad() @paddle.no_grad()
def valid(self): def valid(self):
valid_iterator = iter(self.valid_loader)
valid_losses = [] valid_losses = []
mel, wav, audio_starts = next(valid_iterator)
y = self.model(wav, mel, audio_starts) for batch in self.valid_loader:
loss = self.model_core.loss(y, wav) wav, length = batch
valid_losses.append(float(loss)) # data preprocessing
valid_loss = np.mean(valid_losses) S = self.stft.magnitude(wav)
mel = self.mel_scale(S)
logmel = 20 * paddle.log10(mel, paddle.clip(mel, min=1e-5))
logmel = paddle.clip((logmel + 80) / 100, min=0.0, max=1.0)
y = self.model(wav, logmel)
loss = self.model_core.loss(y, wav)
valid_losses.append(float(loss))
valid_loss = np.mean(valid_losses)
msg = "Rank: {}, ".format(dist.get_rank())
msg += "step: {}, ".format(self.iteration)
msg += "valid/loss: {:>.6f}".format(valid_loss)
self.logger.info(msg)
self.visualizer.add_scalar( self.visualizer.add_scalar(
"valid/loss", valid_loss, global_step=self.iteration) "valid/loss", valid_loss, self.iteration)
def main_sp(config, args): def main_sp(config, args):

View File

@ -65,7 +65,7 @@ def batch_text_id(minibatch, pad_id=0, dtype=np.int64):
mode='constant', mode='constant',
constant_values=pad_id)) constant_values=pad_id))
return np.array(batch, dtype=dtype) return np.array(batch, dtype=dtype), np.array(lengths, dtype=np.int64)
class WavBatcher(object): class WavBatcher(object):
@ -106,7 +106,7 @@ def batch_wav(minibatch, pad_value=0., dtype=np.float32):
np.pad(example, [(0, pad_len)], np.pad(example, [(0, pad_len)],
mode='constant', mode='constant',
constant_values=pad_value)) constant_values=pad_value))
return np.array(batch, dtype=dtype) return np.array(batch, dtype=dtype), np.array(lengths, dtype=np.int64)
class SpecBatcher(object): class SpecBatcher(object):
@ -160,4 +160,4 @@ def batch_spec(minibatch, pad_value=0., time_major=False, dtype=np.float32):
np.pad(example, [(0, 0), (0, pad_len)], np.pad(example, [(0, 0), (0, pad_len)],
mode='constant', mode='constant',
constant_values=pad_value)) constant_values=pad_value))
return np.array(batch, dtype=dtype) return np.array(batch, dtype=dtype), np.array(lengths, dtype=np.int64)

View File

@ -15,24 +15,75 @@
from paddle.io import Dataset from paddle.io import Dataset
import os import os
import librosa import librosa
from pathlib import Path
import numpy as np
from typing import List
__all__ = ["AudioFolderDataset"] __all__ = ["AudioSegmentDataset", "AudioDataset", "AudioFolderDataset"]
class AudioFolderDataset(Dataset): class AudioSegmentDataset(Dataset):
def __init__(self, path, sample_rate, extension="wav"): """A simple dataset adaptor for audio files to train vocoders.
self.root = os.path.expanduser(path) Read -> trim silence -> normalize -> extract a segment
self.sample_rate = sample_rate """
self.extension = extension def __init__(self, file_paths: List[Path], sample_rate: int, length: int,
self.file_names = [ top_db: float):
os.path.join(self.root, x) for x in os.listdir(self.root) \ self.file_paths = file_paths
if os.path.splitext(x)[-1] == self.extension] self.sr = sample_rate
self.length = len(self.file_names) self.top_db = top_db
self.length = length # samples in the clip
def __len__(self):
return self.length
def __getitem__(self, i): def __getitem__(self, i):
file_name = self.file_names[i] fpath = self.file_paths[i]
y, _ = librosa.load(file_name, sr=self.sample_rate) # pylint: disable=unused-variable y, sr = librosa.load(fpath, self.sr)
y, _ = librosa.effects.trim(y, top_db=self.top_db)
y = librosa.util.normalize(y)
y = y.astype(np.float32)
# pad or trim
if y.size <= self.length:
y = np.pad(y, [0, self.length - len(y)], mode='constant')
else:
start = np.random.randint(0, 1 + len(y) - self.length)
y = y[start:start + self.length]
return y return y
def __len__(self):
return len(self.file_paths)
class AudioDataset(Dataset):
"""A simple dataset adaptor for the audio files.
Read -> trim silence -> normalize
"""
def __init__(self,
file_paths: List[Path],
sample_rate: int,
top_db: float = 60):
self.file_paths = file_paths
self.sr = sample_rate
self.top_db = top_db
def __getitem__(self, i):
fpath = self.file_paths[i]
y, sr = librosa.load(fpath, self.sr)
y, _ = librosa.effects.trim(y, top_db=self.top_db)
y = librosa.util.normalize(y)
y = y.astype(np.float32)
return y
def __len__(self):
return len(self.file_paths)
class AudioFolderDataset(AudioDataset):
def __init__(
self,
root,
sample_rate,
top_db=60,
extension=".wav",
):
root = Path(root).expanduser()
file_paths = sorted(list(root.rglob("*{}".format(extension))))
super().__init__(file_paths, sample_rate, top_db)

View File

@ -101,9 +101,7 @@ class UpsampleNet(nn.LayerList):
def __init__(self, upscale_factors=[16, 16]): def __init__(self, upscale_factors=[16, 16]):
super(UpsampleNet, self).__init__() super(UpsampleNet, self).__init__()
self.upscale_factors = list(upscale_factors) self.upscale_factors = list(upscale_factors)
self.upscale_factor = 1 self.upscale_factor = np.prod(upscale_factors)
for item in upscale_factors:
self.upscale_factor *= item
for factor in self.upscale_factors: for factor in self.upscale_factors:
self.append( self.append(
@ -224,13 +222,15 @@ class ResidualBlock(nn.Layer):
other ResidualBlocks. other ResidualBlocks.
""" """
h = x h = x
length = x.shape[-1]
# dilated conv # dilated conv
h = self.conv(h) h = self.conv(h)
# condition # condition
# NOTE: expanded condition may have a larger timesteps than x
if condition is not None: if condition is not None:
h += self.condition_proj(condition) h += self.condition_proj(condition)[:, :, :length]
# gated tanh # gated tanh
content, gate = paddle.split(h, 2, axis=1) content, gate = paddle.split(h, 2, axis=1)
@ -822,7 +822,7 @@ class ConditionalWaveNet(nn.Layer):
loss_type=loss_type, loss_type=loss_type,
log_scale_min=log_scale_min) log_scale_min=log_scale_min)
def forward(self, audio, mel, audio_start): def forward(self, audio, mel):
"""Compute the output distribution given the mel spectrogram and the input(for teacher force training). """Compute the output distribution given the mel spectrogram and the input(for teacher force training).
Parameters Parameters
@ -845,13 +845,13 @@ class ConditionalWaveNet(nn.Layer):
""" """
audio_length = audio.shape[1] # audio clip's length audio_length = audio.shape[1] # audio clip's length
condition = self.encoder(mel) condition = self.encoder(mel)
condition_slice = crop(condition, audio_start, audio_length)
# shifting 1 step # shifting 1 step
audio = audio[:, :-1] audio = audio[:, :-1]
condition_slice = condition_slice[:, :, 1:] condition = condition[:, :, 1:]
y = self.decoder(audio, condition_slice) y = self.decoder(audio, condition)
return y return y
def loss(self, y, t): def loss(self, y, t):

View File

@ -16,6 +16,8 @@ import paddle
from paddle import nn from paddle import nn
from paddle.nn import functional as F from paddle.nn import functional as F
from scipy import signal from scipy import signal
import librosa
from librosa.util import pad_center
import numpy as np import numpy as np
__all__ = ["quantize", "dequantize", "STFT"] __all__ = ["quantize", "dequantize", "STFT"]
@ -88,6 +90,19 @@ class STFT(nn.Layer):
Name of window function, see `scipy.signal.get_window` for more Name of window function, see `scipy.signal.get_window` for more
details. Defaults to "hanning". details. Defaults to "hanning".
center : bool
If True, the signal y is padded so that frame D[:, t] is centered
at y[t * hop_length]. If False, then D[:, t] begins at y[t * hop_length].
Defaults to True.
pad_mode : string or function
If center=True, this argument is passed to np.pad for padding the edges
of the signal y. By default (pad_mode="reflect"), y is padded on both
sides with its own reflection, mirrored around its first and last
sample respectively. If center=False, this argument is ignored.
Notes Notes
----------- -----------
It behaves like ``librosa.core.stft``. See ``librosa.core.stft`` for more It behaves like ``librosa.core.stft``. See ``librosa.core.stft`` for more
@ -101,29 +116,47 @@ class STFT(nn.Layer):
""" """
def __init__(self, n_fft, hop_length, win_length, window="hanning"): def __init__(self, n_fft, hop_length=None, win_length=None, window="hanning", center=True, pad_mode="reflect"):
super(STFT, self).__init__() super().__init__()
# By default, use the entire frame
if win_length is None:
win_length = n_fft
# Set the default hop, if it's not already specified
if hop_length is None:
hop_length = int(win_length // 4)
self.hop_length = hop_length self.hop_length = hop_length
self.n_bin = 1 + n_fft // 2 self.n_bin = 1 + n_fft // 2
self.n_fft = n_fft self.n_fft = n_fft
self.center = center
self.pad_mode = pad_mode
# calculate window # calculate window
window = signal.get_window(window, win_length) window = signal.get_window(window, win_length, fftbins=True)
# pad window to n_fft size
if n_fft != win_length: if n_fft != win_length:
pad = (n_fft - win_length) // 2 window = pad_center(window, n_fft, mode="constant")
window = np.pad(window, ((pad, pad), ), 'constant') #lpad = (n_fft - win_length) // 2
#rpad = n_fft - win_length - lpad
#window = np.pad(window, ((lpad, pad), ), 'constant')
# calculate weights # calculate weights
r = np.arange(0, n_fft) #r = np.arange(0, n_fft)
M = np.expand_dims(r, -1) * np.expand_dims(r, 0) #M = np.expand_dims(r, -1) * np.expand_dims(r, 0)
w_real = np.reshape(window * #w_real = np.reshape(window *
np.cos(2 * np.pi * M / n_fft)[:self.n_bin], #np.cos(2 * np.pi * M / n_fft)[:self.n_bin],
(self.n_bin, 1, 1, self.n_fft)) #(self.n_bin, 1, self.n_fft))
w_imag = np.reshape(window * #w_imag = np.reshape(window *
np.sin(-2 * np.pi * M / n_fft)[:self.n_bin], #np.sin(-2 * np.pi * M / n_fft)[:self.n_bin],
(self.n_bin, 1, 1, self.n_fft)) #(self.n_bin, 1, self.n_fft))
weight = np.fft.fft(np.eye(n_fft))[:self.n_bin]
w_real = weight.real
w_imag = weight.imag
w = np.concatenate([w_real, w_imag], axis=0) w = np.concatenate([w_real, w_imag], axis=0)
w = w * window
w = np.expand_dims(w, 1)
self.weight = paddle.cast( self.weight = paddle.cast(
paddle.to_tensor(w), paddle.get_default_dtype()) paddle.to_tensor(w), paddle.get_default_dtype())
@ -137,23 +170,20 @@ class STFT(nn.Layer):
Returns Returns
------------ ------------
real : Tensor [shape=(B, C, 1, frames)] real : Tensor [shape=(B, C, frames)]
The real part of the spectrogram. The real part of the spectrogram.
imag : Tensor [shape=(B, C, 1, frames)] imag : Tensor [shape=(B, C, frames)]
The image part of the spectrogram. The image part of the spectrogram.
""" """
# x(batch_size, time_steps) x = paddle.unsqueeze(x, axis=1)
# pad it first with reflect mode if self.center:
# TODO(chenfeiyu): report an issue on paddle.flip x = F.pad(x, [self.n_fft // 2, self.n_fft // 2],
pad_start = paddle.reverse(x[:, 1:1 + self.n_fft // 2], axis=[1]) data_format='NCL', mode=self.pad_mode)
pad_stop = paddle.reverse(x[:, -(1 + self.n_fft // 2):-1], axis=[1])
x = paddle.concat([pad_start, x, pad_stop], axis=-1)
# to BC1T, C=1 # to BCT, C=1
x = paddle.unsqueeze(x, axis=[1, 2]) out = F.conv1d(x, self.weight, stride=self.hop_length)
out = F.conv2d(x, self.weight, stride=(1, self.hop_length)) real, imag = paddle.chunk(out, 2, axis=1) # BCT
real, imag = paddle.chunk(out, 2, axis=1) # BC1T
return real, imag return real, imag
def power(self, x): def power(self, x):
@ -166,7 +196,7 @@ class STFT(nn.Layer):
Returns Returns
------------ ------------
Tensor [shape=(B, C, 1, T)] Tensor [shape=(B, C, T)]
The power spectrum. The power spectrum.
""" """
real, imag = self(x) real, imag = self(x)
@ -183,9 +213,22 @@ class STFT(nn.Layer):
Returns Returns
------------ ------------
Tensor [shape=(B, C, 1, T)] Tensor [shape=(B, C, T)]
The magnitude of the spectrum. The magnitude of the spectrum.
""" """
power = self.power(x) power = self.power(x)
magnitude = paddle.sqrt(power) magnitude = paddle.sqrt(power)
return magnitude return magnitude
class MelScale(nn.Layer):
def __init__(self, sr, n_fft, n_mels, fmin, fmax):
super().__init__()
mel_basis = librosa.filters.mel(sr, n_fft, n_mels, fmin, fmax)
print(mel_basis.shape)
self.weight = paddle.to_tensor(mel_basis)
def forward(self, spec):
# (n_mels, n_freq) * (batch_size, n_freq, n_frames)
mel = paddle.matmul(self.weight, spec)
return mel