From 3741cc49ca7a041b36ccf477cea3aac14c88903f Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Mon, 19 Apr 2021 19:58:36 +0800 Subject: [PATCH] change wavenet to use on-the-fly preprocessing --- examples/tacotron2/ljspeech.py | 4 +- examples/tacotron2_aishell3/aishell3.py | 6 +- examples/tacotron2_baker/data.py | 6 +- examples/tacotron2_vctk/vctk.py | 4 +- examples/transformer_tts/ljspeech.py | 6 +- examples/waveflow/ljspeech.py | 4 +- examples/wavenet/config.py | 6 +- examples/wavenet/ljspeech.py | 141 ++++-------------------- examples/wavenet/train.py | 97 +++++++++++----- parakeet/data/batch.py | 6 +- parakeet/datasets/common.py | 81 +++++++++++--- parakeet/models/wavenet.py | 18 +-- parakeet/modules/audio.py | 99 ++++++++++++----- 13 files changed, 262 insertions(+), 216 deletions(-) diff --git a/examples/tacotron2/ljspeech.py b/examples/tacotron2/ljspeech.py index ad6cfbc..9acebc4 100644 --- a/examples/tacotron2/ljspeech.py +++ b/examples/tacotron2/ljspeech.py @@ -92,8 +92,8 @@ class LJSpeechCollector(object): text_lens = np.array(sorted(text_lens, reverse=True), dtype=np.int64) # Pad sequence with largest len of the batch - texts = batch_text_id(texts, pad_id=self.padding_idx) - mels = np.transpose( - batch_spec( - mels, pad_value=self.padding_value), axes=(0, 2, 1)) + texts, _ = batch_text_id(texts, pad_id=self.padding_idx) + mels, _ = batch_spec(mels, pad_value=self.padding_value) + mels = np.transpose(mels, axes=(0, 2, 1)) diff --git a/examples/tacotron2_aishell3/aishell3.py b/examples/tacotron2_aishell3/aishell3.py index 4a8d6c7..d017bdc 100644 --- a/examples/tacotron2_aishell3/aishell3.py +++ b/examples/tacotron2_aishell3/aishell3.py @@ -44,9 +44,9 @@ def collate_aishell3_examples(examples): spec_lengths = np.array([item.shape[1] for item in mel], dtype=np.int64) T_dec = np.max(spec_lengths) stop_tokens = (np.arange(T_dec) >= np.expand_dims(spec_lengths, -1)).astype(np.float32) - phones = batch_text_id(phones) - tones = batch_text_id(tones) - mel = batch_spec(mel) + phones, _ = batch_text_id(phones) + tones, _ = batch_text_id(tones) + mel, _ = batch_spec(mel) mel = np.transpose(mel, (0, 2, 1)) embed = np.stack(embed) # 7 fields diff --git a/examples/tacotron2_baker/data.py b/examples/tacotron2_baker/data.py index fc2d815..f91ee28 100644 --- a/examples/tacotron2_baker/data.py +++ b/examples/tacotron2_baker/data.py @@ -40,9 +40,9 @@ def collate_baker_examples(examples): spec_lengths = np.array([item.shape[1] for item in mel], dtype=np.int64) T_dec = np.max(spec_lengths) stop_tokens = (np.arange(T_dec) >= np.expand_dims(spec_lengths, -1)).astype(np.float32) - phones = batch_text_id(phones) - tones = batch_text_id(tones) - mel = batch_spec(mel) + phones, _ = batch_text_id(phones) + tones, _ = batch_text_id(tones) + mel, _ = batch_spec(mel) mel = np.transpose(mel, (0, 2, 1)) return phones, tones, mel, text_lengths, spec_lengths, stop_tokens diff --git a/examples/tacotron2_vctk/vctk.py b/examples/tacotron2_vctk/vctk.py index e4d8af9..d4c588d 100644 --- a/examples/tacotron2_vctk/vctk.py +++ b/examples/tacotron2_vctk/vctk.py @@ -53,6 +53,6 @@ def collate_vctk_examples(examples): slens = np.array([item.shape[1] for item in mels], dtype=np.int64) speaker_ids = np.array(speaker_ids, dtype=np.int64) - phonemes = batch_text_id(phonemes, pad_id=0) - mels = np.transpose(batch_spec(mels, pad_value=0.), [0, 2, 1]) + phonemes, _ = batch_text_id(phonemes, pad_id=0) + mels, _ = batch_spec(mels, pad_value=0.) + mels = np.transpose(mels, [0, 2, 1]) return phonemes, plens, mels, slens, speaker_ids
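[Note: the snippet below is illustrative only and is not part of the patch. batch_text_id, batch_wav and batch_spec now return a (batch, lengths) tuple instead of a bare array, which is why every collate function above unpacks a second value. A minimal sketch of the new calling convention, using made-up toy data:]

    import numpy as np
    from parakeet.data.batch import batch_text_id, batch_spec

    # two toy examples: token ids of different lengths, mel spectrograms of different widths
    texts = [np.array([1, 2, 3], dtype=np.int64), np.array([4, 5], dtype=np.int64)]
    mels = [np.random.randn(80, 10).astype(np.float32),
            np.random.randn(80, 7).astype(np.float32)]

    texts, text_lengths = batch_text_id(texts, pad_id=0)  # texts: (2, 3), text_lengths: [3, 2]
    mels, spec_lengths = batch_spec(mels, pad_value=0.)   # mels: (2, 80, 10), spec_lengths: [10, 7]
    mels = np.transpose(mels, (0, 2, 1))                  # (2, 10, 80), time axis second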
diff --git a/examples/transformer_tts/ljspeech.py b/examples/transformer_tts/ljspeech.py index 137db96..7f89503 100644 --- a/examples/transformer_tts/ljspeech.py +++ b/examples/transformer_tts/ljspeech.py @@ -76,9 +76,9 @@ class LJSpeechCollector(object): mels = [example[1] for example in examples] stop_probs = [example[2] for example in examples] - ids = batch_text_id(ids, pad_id=self.padding_idx) - mels = batch_spec(mels, pad_value=self.padding_value) - stop_probs = batch_text_id(stop_probs, pad_id=self.padding_idx) + ids, _ = batch_text_id(ids, pad_id=self.padding_idx) + mels, _ = batch_spec(mels, pad_value=self.padding_value) + stop_probs, _ = batch_text_id(stop_probs, pad_id=self.padding_idx) return ids, np.transpose(mels, [0, 2, 1]), stop_probs diff --git a/examples/waveflow/ljspeech.py b/examples/waveflow/ljspeech.py index e07303a..bc288bc 100644 --- a/examples/waveflow/ljspeech.py +++ b/examples/waveflow/ljspeech.py @@ -61,8 +61,8 @@ class LJSpeechCollector(object): def __call__(self, examples): mels = [example[0] for example in examples] wavs = [example[1] for example in examples] - mels = batch_spec(mels, pad_value=self.padding_value) - wavs = batch_wav(wavs, pad_value=self.padding_value) + mels, _ = batch_spec(mels, pad_value=self.padding_value) + wavs, _ = batch_wav(wavs, pad_value=self.padding_value) return mels, wavs diff --git a/examples/wavenet/config.py b/examples/wavenet/config.py index 658d416..8e09d8c 100644 --- a/examples/wavenet/config.py +++ b/examples/wavenet/config.py @@ -20,10 +20,12 @@ _C.data = CN( batch_size=8, # batch size valid_size=16, # the first N examples are reserved for validation sample_rate=22050, # Hz, sample rate - n_fft=2048, # fft frame size + n_fft=1024, # fft frame size win_length=1024, # window size hop_length=256, # hop size between ajacent frame - # f_max=8000, # Hz, max frequency when converting to mel + top_db=60, # db, used to trim silence + fmin=0, # Hz, min frequency when converting to mel + fmax=8000, # Hz, max frequency when converting to mel n_mels=80, # mel bands train_clip_seconds=0.5, # audio clip length(in seconds) )) diff --git a/examples/wavenet/ljspeech.py b/examples/wavenet/ljspeech.py index d1d3c67..92c53ff 100644 --- a/examples/wavenet/ljspeech.py +++ b/examples/wavenet/ljspeech.py @@ -16,136 +16,43 @@ import os from pathlib import Path import pickle import numpy as np +import librosa import pandas from paddle.io import Dataset, DataLoader -from parakeet.data.batch import batch_spec, batch_wav -from parakeet.data import dataset -from parakeet.audio import AudioProcessor - class LJSpeech(Dataset): """A simple dataset adaptor for the processed ljspeech dataset.""" - def __init__(self, root): + def __init__(self, root, sample_rate, length, top_db): self.root = Path(root).expanduser() - meta_data = pandas.read_csv( + self.metadata = pandas.read_csv( str(self.root / "metadata.csv"), - sep="\t", + sep="|", header=None, - names=["fname", "frames", "samples"]) - - records = [] - for row in meta_data.itertuples(): - mel_path = str(self.root / "mel" / (row.fname + ".npy")) - wav_path = str(self.root / "wav" / (row.fname + ".npy")) - records.append((mel_path, wav_path)) - self.records = records + names=["fname", "text", "normalized_text"]) + self.wav_dir = self.root / "wavs" + self.sr = sample_rate + self.top_db = top_db + self.length = length # samples in the clip def __getitem__(self, i): - mel_name, wav_name = self.records[i] - mel = np.load(mel_name) - wav = np.load(wav_name) - return mel, wav + fname = self.metadata.iloc[i].fname + fpath = (self.wav_dir / fname).with_suffix(".wav") + y, sr = librosa.load(fpath, self.sr) +
y, _ = librosa.effects.trim(y, top_db=self.top_db) + y = librosa.util.normalize(y) + y = y.astype(np.float32) + + # pad or trim + if y.size <= self.length: + y = np.pad(y, [0, self.length - len(y)], mode='constant') + else: + start = np.random.randint(0, 1 + len(y) - self.length) + y = y[start: start + self.length] + return y def __len__(self): - return len(self.records) + return len(self.metadata) -class LJSpeechCollector(object): - """A simple callable to batch LJSpeech examples.""" - - def __init__(self, padding_value=0.): - self.padding_value = padding_value - - def __call__(self, examples): - batch_size = len(examples) - mels = [example[0] for example in examples] - wavs = [example[1] for example in examples] - mels = batch_spec(mels, pad_value=self.padding_value) - wavs = batch_wav(wavs, pad_value=self.padding_value) - audio_starts = np.zeros((batch_size, ), dtype=np.int64) - return mels, wavs, audio_starts - - -class LJSpeechClipCollector(object): - def __init__(self, clip_frames=65, hop_length=256): - self.clip_frames = clip_frames - self.hop_length = hop_length - - def __call__(self, examples): - mels = [] - wavs = [] - starts = [] - for example in examples: - mel, wav_clip, start = self.clip(example) - mels.append(mel) - wavs.append(wav_clip) - starts.append(start) - mels = batch_spec(mels) - wavs = np.stack(wavs) - starts = np.array(starts, dtype=np.int64) - return mels, wavs, starts - - def clip(self, example): - mel, wav = example - frames = mel.shape[-1] - start = np.random.randint(0, frames - self.clip_frames) - wav_clip = wav[start * self.hop_length:(start + self.clip_frames) * - self.hop_length] - return mel, wav_clip, start - - -class DataCollector(object): - def __init__(self, - context_size, - sample_rate, - hop_length, - train_clip_seconds, - valid=False): - frames_per_second = sample_rate // hop_length - train_clip_frames = int( - np.ceil(train_clip_seconds * frames_per_second)) - context_frames = context_size // hop_length - self.num_frames = train_clip_frames + context_frames - - self.sample_rate = sample_rate - self.hop_length = hop_length - self.valid = valid - - def random_crop(self, sample): - audio, mel_spectrogram = sample - audio_frames = int(audio.size) // self.hop_length - max_start_frame = audio_frames - self.num_frames - assert max_start_frame >= 0, "audio is too short to be cropped" - - frame_start = np.random.randint(0, max_start_frame) - # frame_start = 0 # norandom - frame_end = frame_start + self.num_frames - - audio_start = frame_start * self.hop_length - audio_end = frame_end * self.hop_length - - audio = audio[audio_start:audio_end] - return audio, mel_spectrogram, audio_start - - def __call__(self, samples): - # transform them first - if self.valid: - samples = [(audio, mel_spectrogram, 0) - for audio, mel_spectrogram in samples] - else: - samples = [self.random_crop(sample) for sample in samples] - # batch them - audios = [sample[0] for sample in samples] - audio_starts = [sample[2] for sample in samples] - mels = [sample[1] for sample in samples] - - mels = batch_spec(mels) - - if self.valid: - audios = batch_wav(audios, dtype=np.float32) - else: - audios = np.array(audios, dtype=np.float32) - audio_starts = np.array(audio_starts, dtype=np.int64) - return audios, mels, audio_starts diff --git a/examples/wavenet/train.py b/examples/wavenet/train.py index 8a42e6f..642d4ff 100644 --- a/examples/wavenet/train.py +++ b/examples/wavenet/train.py @@ -30,9 +30,13 @@ from parakeet.utils import scheduler, mp_tools from parakeet.training.cli import 
default_argument_parser from parakeet.training.experiment import ExperimentBase from parakeet.utils.mp_tools import rank_zero_only +from parakeet.datasets import AudioDataset, AudioSegmentDataset +from parakeet.data import batch_wav + +from parakeet.modules.audio import STFT, MelScale from config import get_cfg_defaults -from ljspeech import LJSpeech, LJSpeechClipCollector, LJSpeechCollector +from ljspeech import LJSpeech class Experiment(ExperimentBase): @@ -60,39 +64,48 @@ class Experiment(ExperimentBase): parameters=model.parameters(), grad_clip=paddle.nn.ClipGradByGlobalNorm( config.training.gradient_max_norm)) - + + self.stft = STFT(config.data.n_fft, config.data.hop_length, config.data.win_length) + self.mel_scale = MelScale(config.data.sample_rate, config.data.n_fft, config.data.n_mels, config.data.fmin, config.data.fmax) + self.model = model self.model_core = model._layers if self.parallel else model self.optimizer = optimizer + def setup_dataloader(self): config = self.config args = self.args - ljspeech_dataset = LJSpeech(args.data) - valid_set, train_set = dataset.split(ljspeech_dataset, - config.data.valid_size) - # convolutional net's causal padding size context_size = config.model.n_stack \ * sum([(config.model.filter_size - 1) * 2**i for i in range(config.model.n_loop)]) \ + 1 - context_frames = context_size // config.data.hop_length # frames used to compute loss - frames_per_second = config.data.sample_rate // config.data.hop_length - train_clip_frames = math.ceil(config.data.train_clip_seconds * - frames_per_second) + train_clip_size = int(config.data.train_clip_seconds * config.data.sample_rate) + length = context_size + train_clip_size + + root = Path(args.data).expanduser() + file_paths = sorted(list((root / "wavs").rglob("*.wav"))) + train_set = AudioSegmentDataset( + file_paths[config.data.valid_size:], + config.data.sample_rate, + length, + top_db=config.data.top_db) + valid_set = AudioDataset( + file_paths[:config.data.valid_size], + config.data.sample_rate, + top_db=config.data.top_db) - num_frames = train_clip_frames + context_frames - batch_fn = LJSpeechClipCollector(num_frames, config.data.hop_length) if not self.parallel: train_loader = DataLoader( train_set, batch_size=config.data.batch_size, shuffle=True, drop_last=True, - collate_fn=batch_fn) + num_workers=1, + ) else: sampler = DistributedBatchSampler( train_set, @@ -100,25 +113,36 @@ shuffle=True, drop_last=True) train_loader = DataLoader( - train_set, batch_sampler=sampler, collate_fn=batch_fn) + train_set, batch_sampler=sampler, num_workers=1) - valid_batch_fn = LJSpeechCollector() valid_loader = DataLoader( - valid_set, batch_size=1, collate_fn=valid_batch_fn) + valid_set, + batch_size=config.data.batch_size, + num_workers=1, + collate_fn=batch_wav) self.train_loader = train_loader self.valid_loader = valid_loader def train_batch(self): + # load data start = time.time() batch = self.read_batch() data_loader_time = time.time() - start self.model.train() self.optimizer.clear_grad() - mel, wav, audio_starts = batch + wav = batch + + # data preprocessing + S = self.stft.magnitude(wav) + mel = self.mel_scale(S) + logmel = 20 * paddle.log10(paddle.clip(mel, min=1e-5)) + logmel = paddle.clip((logmel + 80) / 100, min=0.0, max=1.0) + + # forward & backward - y = self.model(wav, mel, audio_starts) + y = self.model(wav, logmel) loss = self.model_core.loss(y, wav) loss.backward() self.optimizer.step() @@ -129,24 +153,43 @@ msg += "step: {}, 
".format(self.iteration) msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time, iteration_time) - msg += "loss: {:>.6f}".format(loss_value) + msg += "train/loss: {:>.6f}, ".format(loss_value) + msg += "lr: {:>.6f}".format(self.optimizer.get_lr()) self.logger.info(msg) if dist.get_rank() == 0: self.visualizer.add_scalar( - "train/loss", loss_value, global_step=self.iteration) + "train/loss", loss_value, self.iteration) + self.visualizer.add_scalar( + "train/lr", self.optimizer.get_lr(), self.iteration) + + # now we have to call learning rate scheduler.step() manually + self.optimizer._learning_rate.step() @mp_tools.rank_zero_only @paddle.no_grad() def valid(self): - valid_iterator = iter(self.valid_loader) valid_losses = [] - mel, wav, audio_starts = next(valid_iterator) - y = self.model(wav, mel, audio_starts) - loss = self.model_core.loss(y, wav) - valid_losses.append(float(loss)) - valid_loss = np.mean(valid_losses) + + for batch in self.valid_loader: + wav, length = batch + # data preprocessing + S = self.stft.magnitude(wav) + mel = self.mel_scale(S) + logmel = 20 * paddle.log10(paddle.clip(mel, min=1e-5)) + logmel = paddle.clip((logmel + 80) / 100, min=0.0, max=1.0) + + y = self.model(wav, logmel) + loss = self.model_core.loss(y, wav) + valid_losses.append(float(loss)) + valid_loss = np.mean(valid_losses) + + msg = "Rank: {}, ".format(dist.get_rank()) + msg += "step: {}, ".format(self.iteration) + msg += "valid/loss: {:>.6f}".format(valid_loss) + self.logger.info(msg) + self.visualizer.add_scalar( - "valid/loss", valid_loss, global_step=self.iteration) + "valid/loss", valid_loss, self.iteration) def main_sp(config, args): diff --git a/parakeet/data/batch.py b/parakeet/data/batch.py index 4c5be61..1397f55 100644 --- a/parakeet/data/batch.py +++ b/parakeet/data/batch.py @@ -65,7 +65,7 @@ def batch_text_id(minibatch, pad_id=0, dtype=np.int64): mode='constant', constant_values=pad_id)) - return np.array(batch, dtype=dtype) + return np.array(batch, dtype=dtype), np.array(lengths, dtype=np.int64) class WavBatcher(object): @@ -106,7 +106,7 @@ def batch_wav(minibatch, pad_value=0., dtype=np.float32): np.pad(example, [(0, pad_len)], mode='constant', constant_values=pad_value)) - return np.array(batch, dtype=dtype) + return np.array(batch, dtype=dtype), np.array(lengths, dtype=np.int64) class SpecBatcher(object): @@ -160,4 +160,4 @@ def batch_spec(minibatch, pad_value=0., time_major=False, dtype=np.float32): np.pad(example, [(0, 0), (0, pad_len)], mode='constant', constant_values=pad_value)) - return np.array(batch, dtype=dtype) + return np.array(batch, dtype=dtype), np.array(lengths, dtype=np.int64) diff --git a/parakeet/datasets/common.py b/parakeet/datasets/common.py index a1d16d6..78bfc2b 100644 --- a/parakeet/datasets/common.py +++ b/parakeet/datasets/common.py @@ -15,24 +15,75 @@ from paddle.io import Dataset import os import librosa +from pathlib import Path +import numpy as np +from typing import List -__all__ = ["AudioFolderDataset"] +__all__ = ["AudioSegmentDataset", "AudioDataset", "AudioFolderDataset"] -class AudioFolderDataset(Dataset): - def __init__(self, path, sample_rate, extension="wav"): - self.root = os.path.expanduser(path) - self.sample_rate = sample_rate - self.extension = extension - self.file_names = [ - os.path.join(self.root, x) for x in os.listdir(self.root) \ - if os.path.splitext(x)[-1] == self.extension] - self.length = len(self.file_names) - - def __len__(self): - return self.length +class AudioSegmentDataset(Dataset): + """A simple dataset adaptor for
audio files to train vocoders. + Read -> trim silence -> normalize -> extract a segment + """ + def __init__(self, file_paths: List[Path], sample_rate: int, length: int, + top_db: float): + self.file_paths = file_paths + self.sr = sample_rate + self.top_db = top_db + self.length = length # samples in the clip def __getitem__(self, i): - file_name = self.file_names[i] - y, _ = librosa.load(file_name, sr=self.sample_rate) # pylint: disable=unused-variable + fpath = self.file_paths[i] + y, sr = librosa.load(fpath, self.sr) + y, _ = librosa.effects.trim(y, top_db=self.top_db) + y = librosa.util.normalize(y) + y = y.astype(np.float32) + + # pad or trim + if y.size <= self.length: + y = np.pad(y, [0, self.length - len(y)], mode='constant') + else: + start = np.random.randint(0, 1 + len(y) - self.length) + y = y[start:start + self.length] return y + + def __len__(self): + return len(self.file_paths) + + +class AudioDataset(Dataset): + """A simple dataset adaptor for the audio files. + Read -> trim silence -> normalize + """ + def __init__(self, + file_paths: List[Path], + sample_rate: int, + top_db: float = 60): + self.file_paths = file_paths + self.sr = sample_rate + self.top_db = top_db + + def __getitem__(self, i): + fpath = self.file_paths[i] + y, sr = librosa.load(fpath, self.sr) + y, _ = librosa.effects.trim(y, top_db=self.top_db) + y = librosa.util.normalize(y) + y = y.astype(np.float32) + return y + + def __len__(self): + return len(self.file_paths) + + +class AudioFolderDataset(AudioDataset): + def __init__( + self, + root, + sample_rate, + top_db=60, + extension=".wav", + ): + root = Path(root).expanduser() + file_paths = sorted(list(root.rglob("*{}".format(extension)))) + super().__init__(file_paths, sample_rate, top_db) diff --git a/parakeet/models/wavenet.py b/parakeet/models/wavenet.py index 5ff3435..1914942 100644 --- a/parakeet/models/wavenet.py +++ b/parakeet/models/wavenet.py @@ -101,9 +101,7 @@ class UpsampleNet(nn.LayerList): def __init__(self, upscale_factors=[16, 16]): super(UpsampleNet, self).__init__() self.upscale_factors = list(upscale_factors) - self.upscale_factor = 1 - for item in upscale_factors: - self.upscale_factor *= item + self.upscale_factor = np.prod(upscale_factors) for factor in self.upscale_factors: self.append( @@ -224,13 +222,15 @@ class ResidualBlock(nn.Layer): other ResidualBlocks. """ h = x - + length = x.shape[-1] + # dilated conv h = self.conv(h) # condition + # NOTE: expanded condition may have a larger timesteps than x if condition is not None: - h += self.condition_proj(condition) + h += self.condition_proj(condition)[:, :, :length] # gated tanh content, gate = paddle.split(h, 2, axis=1) @@ -822,7 +822,7 @@ class ConditionalWaveNet(nn.Layer): loss_type=loss_type, log_scale_min=log_scale_min) - def forward(self, audio, mel, audio_start): + def forward(self, audio, mel): """Compute the output distribution given the mel spectrogram and the input(for teacher force training). 
Parameters @@ -845,13 +845,13 @@ class ConditionalWaveNet(nn.Layer): """ audio_length = audio.shape[1] # audio clip's length condition = self.encoder(mel) - condition_slice = crop(condition, audio_start, audio_length) + # shifting 1 step audio = audio[:, :-1] - condition_slice = condition_slice[:, :, 1:] + condition = condition[:, :, 1:] - y = self.decoder(audio, condition_slice) + y = self.decoder(audio, condition) return y def loss(self, y, t): diff --git a/parakeet/modules/audio.py b/parakeet/modules/audio.py index 03e42b0..cfa215f 100644 --- a/parakeet/modules/audio.py +++ b/parakeet/modules/audio.py @@ -16,6 +16,8 @@ import paddle from paddle import nn from paddle.nn import functional as F from scipy import signal +import librosa +from librosa.util import pad_center import numpy as np __all__ = ["quantize", "dequantize", "STFT"] @@ -88,6 +90,19 @@ class STFT(nn.Layer): Name of window function, see `scipy.signal.get_window` for more details. Defaults to "hanning". + center : bool + If True, the signal y is padded so that frame D[:, t] is centered + at y[t * hop_length]. If False, then D[:, t] begins at y[t * hop_length]. + Defaults to True. + + pad_mode : string or function + If center=True, this argument is passed to np.pad for padding the edges + of the signal y. By default (pad_mode="reflect"), y is padded on both + sides with its own reflection, mirrored around its first and last + sample respectively. If center=False, this argument is ignored. + + + Notes ----------- It behaves like ``librosa.core.stft``. See ``librosa.core.stft`` for more @@ -101,29 +116,47 @@ class STFT(nn.Layer): """ - def __init__(self, n_fft, hop_length, win_length, window="hanning"): - super(STFT, self).__init__() + def __init__(self, n_fft, hop_length=None, win_length=None, window="hanning", center=True, pad_mode="reflect"): + super().__init__() + # By default, use the entire frame + if win_length is None: + win_length = n_fft + + # Set the default hop, if it's not already specified + if hop_length is None: + hop_length = int(win_length // 4) + self.hop_length = hop_length self.n_bin = 1 + n_fft // 2 self.n_fft = n_fft + self.center = center + self.pad_mode = pad_mode # calculate window - window = signal.get_window(window, win_length) + window = signal.get_window(window, win_length, fftbins=True) + + # pad window to n_fft size if n_fft != win_length: - pad = (n_fft - win_length) // 2 - window = np.pad(window, ((pad, pad), ), 'constant') + window = pad_center(window, n_fft, mode="constant") + #lpad = (n_fft - win_length) // 2 + #rpad = n_fft - win_length - lpad + #window = np.pad(window, ((lpad, pad), ), 'constant') # calculate weights - r = np.arange(0, n_fft) - M = np.expand_dims(r, -1) * np.expand_dims(r, 0) - w_real = np.reshape(window * - np.cos(2 * np.pi * M / n_fft)[:self.n_bin], - (self.n_bin, 1, 1, self.n_fft)) - w_imag = np.reshape(window * - np.sin(-2 * np.pi * M / n_fft)[:self.n_bin], - (self.n_bin, 1, 1, self.n_fft)) - + #r = np.arange(0, n_fft) + #M = np.expand_dims(r, -1) * np.expand_dims(r, 0) + #w_real = np.reshape(window * + #np.cos(2 * np.pi * M / n_fft)[:self.n_bin], + #(self.n_bin, 1, self.n_fft)) + #w_imag = np.reshape(window * + #np.sin(-2 * np.pi * M / n_fft)[:self.n_bin], + #(self.n_bin, 1, self.n_fft)) + weight = np.fft.fft(np.eye(n_fft))[:self.n_bin] + w_real = weight.real + w_imag = weight.imag w = np.concatenate([w_real, w_imag], axis=0) + w = w * window + w = np.expand_dims(w, 1) self.weight = paddle.cast( paddle.to_tensor(w), paddle.get_default_dtype()) @@ -137,23 +170,20 @@ 
class STFT(nn.Layer): Returns ------------ - real : Tensor [shape=(B, C, 1, frames)] The real part of the spectrogram. - imag : Tensor [shape=(B, C, 1, frames)] + real : Tensor [shape=(B, C, frames)] The real part of the spectrogram. + imag : Tensor [shape=(B, C, frames)] The image part of the spectrogram. """ - # x(batch_size, time_steps) - # pad it first with reflect mode - # TODO(chenfeiyu): report an issue on paddle.flip - pad_start = paddle.reverse(x[:, 1:1 + self.n_fft // 2], axis=[1]) - pad_stop = paddle.reverse(x[:, -(1 + self.n_fft // 2):-1], axis=[1]) - x = paddle.concat([pad_start, x, pad_stop], axis=-1) + x = paddle.unsqueeze(x, axis=1) + if self.center: + x = F.pad(x, [self.n_fft // 2, self.n_fft // 2], + data_format='NCL', mode=self.pad_mode) - # to BC1T, C=1 - x = paddle.unsqueeze(x, axis=[1, 2]) - out = F.conv2d(x, self.weight, stride=(1, self.hop_length)) - real, imag = paddle.chunk(out, 2, axis=1) # BC1T + # to BCT, C=1 + out = F.conv1d(x, self.weight, stride=self.hop_length) + real, imag = paddle.chunk(out, 2, axis=1) # BCT return real, imag def power(self, x): @@ -166,7 +196,7 @@ class STFT(nn.Layer): Returns ------------ - Tensor [shape=(B, C, 1, T)] + Tensor [shape=(B, C, T)] The power spectrum. """ real, imag = self(x) @@ -183,9 +213,22 @@ class STFT(nn.Layer): Returns ------------ - Tensor [shape=(B, C, 1, T)] + Tensor [shape=(B, C, T)] The magnitude of the spectrum. """ power = self.power(x) magnitude = paddle.sqrt(power) return magnitude + + +class MelScale(nn.Layer): + def __init__(self, sr, n_fft, n_mels, fmin, fmax): + super().__init__() + mel_basis = librosa.filters.mel(sr, n_fft, n_mels, fmin, fmax) + self.weight = paddle.to_tensor(mel_basis) + + def forward(self, spec): + # (n_mels, n_freq) * (batch_size, n_freq, n_frames) + mel = paddle.matmul(self.weight, spec) + return mel \ No newline at end of file
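[Note: the sketch below is illustrative only and is not part of the patch. It shows the whole on-the-fly conditioning pipeline that train.py now uses (conv1d-based STFT magnitude, MelScale, log compression and clipping) in one place, and a possible sanity check of the STFT against librosa on a random signal. It assumes the values from examples/wavenet/config.py (sr=22050, n_fft=1024, hop_length=256, win_length=1024, n_mels=80, fmin=0, fmax=8000); both paths use a periodic Hann window with reflect padding, so close agreement is expected, but the tolerance is an assumption rather than something verified here.]

    import numpy as np
    import paddle
    import librosa
    from parakeet.modules.audio import STFT, MelScale

    x = np.random.randn(22050).astype(np.float32)                    # 1 second of noise at 22050 Hz

    # Paddle path, mirroring train_batch() in examples/wavenet/train.py
    stft = STFT(n_fft=1024, hop_length=256, win_length=1024)
    mel_scale = MelScale(sr=22050, n_fft=1024, n_mels=80, fmin=0, fmax=8000)
    mag = stft.magnitude(paddle.unsqueeze(paddle.to_tensor(x), 0))   # (1, 513, frames)
    mel = mel_scale(mag)                                             # (1, 80, frames)
    logmel = 20 * paddle.log10(paddle.clip(mel, min=1e-5))           # amplitude -> dB, floored at -100 dB
    logmel = paddle.clip((logmel + 80) / 100, min=0.0, max=1.0)      # conditioning input, roughly in [0, 1]

    # librosa reference for the magnitude spectrogram
    ref = np.abs(librosa.stft(x, n_fft=1024, hop_length=256, win_length=1024, window="hann"))
    print(np.abs(mag.numpy()[0] - ref).max())                        # expected to be small (float32 noise)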