change wavenet to use on-the-fly prepeocessing
This commit is contained in:
parent
e06c6cdfe1
commit
3741cc49ca
|
@ -92,8 +92,8 @@ class LJSpeechCollector(object):
|
|||
text_lens = np.array(sorted(text_lens, reverse=True), dtype=np.int64)
|
||||
|
||||
# Pad sequence with largest len of the batch
|
||||
texts = batch_text_id(texts, pad_id=self.padding_idx)
|
||||
mels = np.transpose(
|
||||
texts, _ = batch_text_id(texts, pad_id=self.padding_idx)
|
||||
mels, _ = np.transpose(
|
||||
batch_spec(
|
||||
mels, pad_value=self.padding_value), axes=(0, 2, 1))
|
||||
|
||||
|
|
|
@ -44,9 +44,9 @@ def collate_aishell3_examples(examples):
|
|||
spec_lengths = np.array([item.shape[1] for item in mel], dtype=np.int64)
|
||||
T_dec = np.max(spec_lengths)
|
||||
stop_tokens = (np.arange(T_dec) >= np.expand_dims(spec_lengths, -1)).astype(np.float32)
|
||||
phones = batch_text_id(phones)
|
||||
tones = batch_text_id(tones)
|
||||
mel = batch_spec(mel)
|
||||
phones, _ = batch_text_id(phones)
|
||||
tones, _ = batch_text_id(tones)
|
||||
mel, _ = batch_spec(mel)
|
||||
mel = np.transpose(mel, (0, 2, 1))
|
||||
embed = np.stack(embed)
|
||||
# 7 fields
|
||||
|
|
|
@ -40,9 +40,9 @@ def collate_baker_examples(examples):
|
|||
spec_lengths = np.array([item.shape[1] for item in mel], dtype=np.int64)
|
||||
T_dec = np.max(spec_lengths)
|
||||
stop_tokens = (np.arange(T_dec) >= np.expand_dims(spec_lengths, -1)).astype(np.float32)
|
||||
phones = batch_text_id(phones)
|
||||
tones = batch_text_id(tones)
|
||||
mel = batch_spec(mel)
|
||||
phone, _ = batch_text_id(phones)
|
||||
tones, _ = batch_text_id(tones)
|
||||
mel, _ = batch_spec(mel)
|
||||
mel = np.transpose(mel, (0, 2, 1))
|
||||
|
||||
return phones, tones, mel, text_lengths, spec_lengths, stop_tokens
|
||||
|
|
|
@ -53,6 +53,6 @@ def collate_vctk_examples(examples):
|
|||
slens = np.array([item.shape[1] for item in mels], dtype=np.int64)
|
||||
speaker_ids = np.array(speaker_ids, dtype=np.int64)
|
||||
|
||||
phonemes = batch_text_id(phonemes, pad_id=0)
|
||||
mels = np.transpose(batch_spec(mels, pad_value=0.), [0, 2, 1])
|
||||
phonemes, _ = batch_text_id(phonemes, pad_id=0)
|
||||
mels, _ = np.transpose(batch_spec(mels, pad_value=0.), [0, 2, 1])
|
||||
return phonemes, plens, mels, slens, speaker_ids
|
||||
|
|
|
@ -76,9 +76,9 @@ class LJSpeechCollector(object):
|
|||
mels = [example[1] for example in examples]
|
||||
stop_probs = [example[2] for example in examples]
|
||||
|
||||
ids = batch_text_id(ids, pad_id=self.padding_idx)
|
||||
mels = batch_spec(mels, pad_value=self.padding_value)
|
||||
stop_probs = batch_text_id(stop_probs, pad_id=self.padding_idx)
|
||||
ids, _ = batch_text_id(ids, pad_id=self.padding_idx)
|
||||
mels, _ = batch_spec(mels, pad_value=self.padding_value)
|
||||
stop_probs, _ = batch_text_id(stop_probs, pad_id=self.padding_idx)
|
||||
return ids, np.transpose(mels, [0, 2, 1]), stop_probs
|
||||
|
||||
|
||||
|
|
|
@ -61,8 +61,8 @@ class LJSpeechCollector(object):
|
|||
def __call__(self, examples):
|
||||
mels = [example[0] for example in examples]
|
||||
wavs = [example[1] for example in examples]
|
||||
mels = batch_spec(mels, pad_value=self.padding_value)
|
||||
wavs = batch_wav(wavs, pad_value=self.padding_value)
|
||||
mels, _ = batch_spec(mels, pad_value=self.padding_value)
|
||||
wavs, _ = batch_wav(wavs, pad_value=self.padding_value)
|
||||
return mels, wavs
|
||||
|
||||
|
||||
|
|
|
@ -20,10 +20,12 @@ _C.data = CN(
|
|||
batch_size=8, # batch size
|
||||
valid_size=16, # the first N examples are reserved for validation
|
||||
sample_rate=22050, # Hz, sample rate
|
||||
n_fft=2048, # fft frame size
|
||||
n_fft=1024, # fft frame size
|
||||
win_length=1024, # window size
|
||||
hop_length=256, # hop size between ajacent frame
|
||||
# f_max=8000, # Hz, max frequency when converting to mel
|
||||
top_db=60, # db, used to trim silence
|
||||
fmin = 0, # Hz, max frequency when converting to mel
|
||||
fmax=8000, # Hz, max frequency when converting to mel
|
||||
n_mels=80, # mel bands
|
||||
train_clip_seconds=0.5, # audio clip length(in seconds)
|
||||
))
|
||||
|
|
|
@ -16,136 +16,43 @@ import os
|
|||
from pathlib import Path
|
||||
import pickle
|
||||
import numpy as np
|
||||
import librosa
|
||||
import pandas
|
||||
from paddle.io import Dataset, DataLoader
|
||||
|
||||
from parakeet.data.batch import batch_spec, batch_wav
|
||||
from parakeet.data import dataset
|
||||
from parakeet.audio import AudioProcessor
|
||||
|
||||
|
||||
class LJSpeech(Dataset):
|
||||
"""A simple dataset adaptor for the processed ljspeech dataset."""
|
||||
|
||||
def __init__(self, root):
|
||||
def __init__(self, root, sample_rate, length, top_db):
|
||||
self.root = Path(root).expanduser()
|
||||
meta_data = pandas.read_csv(
|
||||
self.metadata = pandas.read_csv(
|
||||
str(self.root / "metadata.csv"),
|
||||
sep="\t",
|
||||
sep="|",
|
||||
header=None,
|
||||
names=["fname", "frames", "samples"])
|
||||
|
||||
records = []
|
||||
for row in meta_data.itertuples():
|
||||
mel_path = str(self.root / "mel" / (row.fname + ".npy"))
|
||||
wav_path = str(self.root / "wav" / (row.fname + ".npy"))
|
||||
records.append((mel_path, wav_path))
|
||||
self.records = records
|
||||
names=["fname", "text", "normalized_text"])
|
||||
self.wav_dir = self.root / "wavs"
|
||||
self.sr = sample_rate
|
||||
self.top_db = top_db
|
||||
self.length = length # samples in the clip
|
||||
|
||||
def __getitem__(self, i):
|
||||
mel_name, wav_name = self.records[i]
|
||||
mel = np.load(mel_name)
|
||||
wav = np.load(wav_name)
|
||||
return mel, wav
|
||||
fname = self.metadata.iloc[0].fname
|
||||
fpath = (self.wav_dir / fname).with_suffix(".wav")
|
||||
y, sr = librosa.load(fpath, self.sr)
|
||||
y, _ = librosa.effects.trim(y, top_db=self.top_db)
|
||||
y = librosa.util.normalize(y)
|
||||
y = y.astype(np.float32)
|
||||
|
||||
# pad or trim
|
||||
if y.size <= self.length:
|
||||
y = np.pad(y, [0, self.length - len(y)], mode='constant')
|
||||
else:
|
||||
start = np.random.randint(0, 1 + len(y) - self.length)
|
||||
y = y[start: start + self.length]
|
||||
return y
|
||||
|
||||
def __len__(self):
|
||||
return len(self.records)
|
||||
return len(self.metadata)
|
||||
|
||||
|
||||
class LJSpeechCollector(object):
|
||||
"""A simple callable to batch LJSpeech examples."""
|
||||
|
||||
def __init__(self, padding_value=0.):
|
||||
self.padding_value = padding_value
|
||||
|
||||
def __call__(self, examples):
|
||||
batch_size = len(examples)
|
||||
mels = [example[0] for example in examples]
|
||||
wavs = [example[1] for example in examples]
|
||||
mels = batch_spec(mels, pad_value=self.padding_value)
|
||||
wavs = batch_wav(wavs, pad_value=self.padding_value)
|
||||
audio_starts = np.zeros((batch_size, ), dtype=np.int64)
|
||||
return mels, wavs, audio_starts
|
||||
|
||||
|
||||
class LJSpeechClipCollector(object):
|
||||
def __init__(self, clip_frames=65, hop_length=256):
|
||||
self.clip_frames = clip_frames
|
||||
self.hop_length = hop_length
|
||||
|
||||
def __call__(self, examples):
|
||||
mels = []
|
||||
wavs = []
|
||||
starts = []
|
||||
for example in examples:
|
||||
mel, wav_clip, start = self.clip(example)
|
||||
mels.append(mel)
|
||||
wavs.append(wav_clip)
|
||||
starts.append(start)
|
||||
mels = batch_spec(mels)
|
||||
wavs = np.stack(wavs)
|
||||
starts = np.array(starts, dtype=np.int64)
|
||||
return mels, wavs, starts
|
||||
|
||||
def clip(self, example):
|
||||
mel, wav = example
|
||||
frames = mel.shape[-1]
|
||||
start = np.random.randint(0, frames - self.clip_frames)
|
||||
wav_clip = wav[start * self.hop_length:(start + self.clip_frames) *
|
||||
self.hop_length]
|
||||
return mel, wav_clip, start
|
||||
|
||||
|
||||
class DataCollector(object):
|
||||
def __init__(self,
|
||||
context_size,
|
||||
sample_rate,
|
||||
hop_length,
|
||||
train_clip_seconds,
|
||||
valid=False):
|
||||
frames_per_second = sample_rate // hop_length
|
||||
train_clip_frames = int(
|
||||
np.ceil(train_clip_seconds * frames_per_second))
|
||||
context_frames = context_size // hop_length
|
||||
self.num_frames = train_clip_frames + context_frames
|
||||
|
||||
self.sample_rate = sample_rate
|
||||
self.hop_length = hop_length
|
||||
self.valid = valid
|
||||
|
||||
def random_crop(self, sample):
|
||||
audio, mel_spectrogram = sample
|
||||
audio_frames = int(audio.size) // self.hop_length
|
||||
max_start_frame = audio_frames - self.num_frames
|
||||
assert max_start_frame >= 0, "audio is too short to be cropped"
|
||||
|
||||
frame_start = np.random.randint(0, max_start_frame)
|
||||
# frame_start = 0 # norandom
|
||||
frame_end = frame_start + self.num_frames
|
||||
|
||||
audio_start = frame_start * self.hop_length
|
||||
audio_end = frame_end * self.hop_length
|
||||
|
||||
audio = audio[audio_start:audio_end]
|
||||
return audio, mel_spectrogram, audio_start
|
||||
|
||||
def __call__(self, samples):
|
||||
# transform them first
|
||||
if self.valid:
|
||||
samples = [(audio, mel_spectrogram, 0)
|
||||
for audio, mel_spectrogram in samples]
|
||||
else:
|
||||
samples = [self.random_crop(sample) for sample in samples]
|
||||
# batch them
|
||||
audios = [sample[0] for sample in samples]
|
||||
audio_starts = [sample[2] for sample in samples]
|
||||
mels = [sample[1] for sample in samples]
|
||||
|
||||
mels = batch_spec(mels)
|
||||
|
||||
if self.valid:
|
||||
audios = batch_wav(audios, dtype=np.float32)
|
||||
else:
|
||||
audios = np.array(audios, dtype=np.float32)
|
||||
audio_starts = np.array(audio_starts, dtype=np.int64)
|
||||
return audios, mels, audio_starts
|
||||
|
|
|
@ -30,9 +30,13 @@ from parakeet.utils import scheduler, mp_tools
|
|||
from parakeet.training.cli import default_argument_parser
|
||||
from parakeet.training.experiment import ExperimentBase
|
||||
from parakeet.utils.mp_tools import rank_zero_only
|
||||
from parakeet.datasets import AudioDataset, AudioSegmentDataset
|
||||
from parakeet.data import batch_wav
|
||||
|
||||
from parakeet.modules.audio import STFT, MelScale
|
||||
|
||||
from config import get_cfg_defaults
|
||||
from ljspeech import LJSpeech, LJSpeechClipCollector, LJSpeechCollector
|
||||
from ljspeech import LJSpeech
|
||||
|
||||
|
||||
class Experiment(ExperimentBase):
|
||||
|
@ -61,38 +65,47 @@ class Experiment(ExperimentBase):
|
|||
grad_clip=paddle.nn.ClipGradByGlobalNorm(
|
||||
config.training.gradient_max_norm))
|
||||
|
||||
self.stft = STFT(config.data.n_fft, config.data.hop_length, config.data.win_length)
|
||||
self.mel_scale = MelScale(config.data.sample_rate, config.data.n_fft, config.data.n_mels, config.data.fmin, config.data.fmax)
|
||||
|
||||
self.model = model
|
||||
self.model_core = model._layers if self.parallel else model
|
||||
self.optimizer = optimizer
|
||||
|
||||
|
||||
def setup_dataloader(self):
|
||||
config = self.config
|
||||
args = self.args
|
||||
|
||||
ljspeech_dataset = LJSpeech(args.data)
|
||||
valid_set, train_set = dataset.split(ljspeech_dataset,
|
||||
config.data.valid_size)
|
||||
|
||||
# convolutional net's causal padding size
|
||||
context_size = config.model.n_stack \
|
||||
* sum([(config.model.filter_size - 1) * 2**i for i in range(config.model.n_loop)]) \
|
||||
+ 1
|
||||
context_frames = context_size // config.data.hop_length
|
||||
|
||||
# frames used to compute loss
|
||||
frames_per_second = config.data.sample_rate // config.data.hop_length
|
||||
train_clip_frames = math.ceil(config.data.train_clip_seconds *
|
||||
frames_per_second)
|
||||
train_clip_size = int(config.data.train_clip_seconds * config.data.sample_rate)
|
||||
length = context_size + train_clip_size
|
||||
|
||||
root = Path(args.data).expanduser()
|
||||
file_paths = sorted(list((root / "wavs").rglob("*.wav")))
|
||||
train_set = AudioSegmentDataset(
|
||||
file_paths[config.data.valid_size:],
|
||||
config.data.sample_rate,
|
||||
length,
|
||||
top_db=config.data.top_db)
|
||||
valid_set = AudioDataset(
|
||||
file_paths[:config.data.valid_size],
|
||||
config.data.sample_rate,
|
||||
top_db=config.data.top_db)
|
||||
|
||||
num_frames = train_clip_frames + context_frames
|
||||
batch_fn = LJSpeechClipCollector(num_frames, config.data.hop_length)
|
||||
if not self.parallel:
|
||||
train_loader = DataLoader(
|
||||
train_set,
|
||||
batch_size=config.data.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=batch_fn)
|
||||
num_workers=1,
|
||||
)
|
||||
else:
|
||||
sampler = DistributedBatchSampler(
|
||||
train_set,
|
||||
|
@ -100,25 +113,36 @@ class Experiment(ExperimentBase):
|
|||
shuffle=True,
|
||||
drop_last=True)
|
||||
train_loader = DataLoader(
|
||||
train_set, batch_sampler=sampler, collate_fn=batch_fn)
|
||||
train_set, batch_sampler=sampler, num_workers=1)
|
||||
|
||||
valid_batch_fn = LJSpeechCollector()
|
||||
valid_loader = DataLoader(
|
||||
valid_set, batch_size=1, collate_fn=valid_batch_fn)
|
||||
valid_set,
|
||||
batch_size=config.data.batch_size,
|
||||
num_workers=1,
|
||||
collate_fn=batch_wav)
|
||||
|
||||
self.train_loader = train_loader
|
||||
self.valid_loader = valid_loader
|
||||
|
||||
def train_batch(self):
|
||||
# load data
|
||||
start = time.time()
|
||||
batch = self.read_batch()
|
||||
data_loader_time = time.time() - start
|
||||
|
||||
self.model.train()
|
||||
self.optimizer.clear_grad()
|
||||
mel, wav, audio_starts = batch
|
||||
wav = batch
|
||||
|
||||
y = self.model(wav, mel, audio_starts)
|
||||
# data preprocessing
|
||||
S = self.stft.magnitude(wav)
|
||||
mel = self.mel_scale(S)
|
||||
logmel = 20 * paddle.log10(mel, paddle.clip(mel, min=1e-5))
|
||||
logmel = paddle.clip((logmel + 80) / 100, min=0.0, max=1.0)
|
||||
|
||||
# forward & backward
|
||||
|
||||
y = self.model(wav, logmel)
|
||||
loss = self.model_core.loss(y, wav)
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
|
@ -129,24 +153,43 @@ class Experiment(ExperimentBase):
|
|||
msg += "step: {}, ".format(self.iteration)
|
||||
msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time,
|
||||
iteration_time)
|
||||
msg += "loss: {:>.6f}".format(loss_value)
|
||||
msg += "train/loss: {:>.6f}, ".format(loss_value)
|
||||
msg += "lr: {:>.6f}".format(self.optimizer.get_lr())
|
||||
self.logger.info(msg)
|
||||
if dist.get_rank() == 0:
|
||||
self.visualizer.add_scalar(
|
||||
"train/loss", loss_value, global_step=self.iteration)
|
||||
"train/loss", loss_value, self.iteration)
|
||||
self.visualizer.add_scalar(
|
||||
"train/lr", self.optimizer.get_lr(), self.iteration)
|
||||
|
||||
# now we have to call learning rate scheduler.step() mannually
|
||||
self.optimizer._learning_rate.step()
|
||||
|
||||
@mp_tools.rank_zero_only
|
||||
@paddle.no_grad()
|
||||
def valid(self):
|
||||
valid_iterator = iter(self.valid_loader)
|
||||
valid_losses = []
|
||||
mel, wav, audio_starts = next(valid_iterator)
|
||||
y = self.model(wav, mel, audio_starts)
|
||||
|
||||
for batch in self.valid_loader:
|
||||
wav, length = batch
|
||||
# data preprocessing
|
||||
S = self.stft.magnitude(wav)
|
||||
mel = self.mel_scale(S)
|
||||
logmel = 20 * paddle.log10(mel, paddle.clip(mel, min=1e-5))
|
||||
logmel = paddle.clip((logmel + 80) / 100, min=0.0, max=1.0)
|
||||
|
||||
y = self.model(wav, logmel)
|
||||
loss = self.model_core.loss(y, wav)
|
||||
valid_losses.append(float(loss))
|
||||
valid_loss = np.mean(valid_losses)
|
||||
|
||||
msg = "Rank: {}, ".format(dist.get_rank())
|
||||
msg += "step: {}, ".format(self.iteration)
|
||||
msg += "valid/loss: {:>.6f}".format(valid_loss)
|
||||
self.logger.info(msg)
|
||||
|
||||
self.visualizer.add_scalar(
|
||||
"valid/loss", valid_loss, global_step=self.iteration)
|
||||
"valid/loss", valid_loss, self.iteration)
|
||||
|
||||
|
||||
def main_sp(config, args):
|
||||
|
|
|
@ -65,7 +65,7 @@ def batch_text_id(minibatch, pad_id=0, dtype=np.int64):
|
|||
mode='constant',
|
||||
constant_values=pad_id))
|
||||
|
||||
return np.array(batch, dtype=dtype)
|
||||
return np.array(batch, dtype=dtype), np.array(lengths, dtype=np.int64)
|
||||
|
||||
|
||||
class WavBatcher(object):
|
||||
|
@ -106,7 +106,7 @@ def batch_wav(minibatch, pad_value=0., dtype=np.float32):
|
|||
np.pad(example, [(0, pad_len)],
|
||||
mode='constant',
|
||||
constant_values=pad_value))
|
||||
return np.array(batch, dtype=dtype)
|
||||
return np.array(batch, dtype=dtype), np.array(lengths, dtype=np.int64)
|
||||
|
||||
|
||||
class SpecBatcher(object):
|
||||
|
@ -160,4 +160,4 @@ def batch_spec(minibatch, pad_value=0., time_major=False, dtype=np.float32):
|
|||
np.pad(example, [(0, 0), (0, pad_len)],
|
||||
mode='constant',
|
||||
constant_values=pad_value))
|
||||
return np.array(batch, dtype=dtype)
|
||||
return np.array(batch, dtype=dtype), np.array(lengths, dtype=np.int64)
|
||||
|
|
|
@ -15,24 +15,75 @@
|
|||
from paddle.io import Dataset
|
||||
import os
|
||||
import librosa
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from typing import List
|
||||
|
||||
__all__ = ["AudioFolderDataset"]
|
||||
__all__ = ["AudioSegmentDataset", "AudioDataset", "AudioFolderDataset"]
|
||||
|
||||
|
||||
class AudioFolderDataset(Dataset):
|
||||
def __init__(self, path, sample_rate, extension="wav"):
|
||||
self.root = os.path.expanduser(path)
|
||||
self.sample_rate = sample_rate
|
||||
self.extension = extension
|
||||
self.file_names = [
|
||||
os.path.join(self.root, x) for x in os.listdir(self.root) \
|
||||
if os.path.splitext(x)[-1] == self.extension]
|
||||
self.length = len(self.file_names)
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
class AudioSegmentDataset(Dataset):
|
||||
"""A simple dataset adaptor for audio files to train vocoders.
|
||||
Read -> trim silence -> normalize -> extract a segment
|
||||
"""
|
||||
def __init__(self, file_paths: List[Path], sample_rate: int, length: int,
|
||||
top_db: float):
|
||||
self.file_paths = file_paths
|
||||
self.sr = sample_rate
|
||||
self.top_db = top_db
|
||||
self.length = length # samples in the clip
|
||||
|
||||
def __getitem__(self, i):
|
||||
file_name = self.file_names[i]
|
||||
y, _ = librosa.load(file_name, sr=self.sample_rate) # pylint: disable=unused-variable
|
||||
fpath = self.file_paths[i]
|
||||
y, sr = librosa.load(fpath, self.sr)
|
||||
y, _ = librosa.effects.trim(y, top_db=self.top_db)
|
||||
y = librosa.util.normalize(y)
|
||||
y = y.astype(np.float32)
|
||||
|
||||
# pad or trim
|
||||
if y.size <= self.length:
|
||||
y = np.pad(y, [0, self.length - len(y)], mode='constant')
|
||||
else:
|
||||
start = np.random.randint(0, 1 + len(y) - self.length)
|
||||
y = y[start:start + self.length]
|
||||
return y
|
||||
|
||||
def __len__(self):
|
||||
return len(self.file_paths)
|
||||
|
||||
|
||||
class AudioDataset(Dataset):
|
||||
"""A simple dataset adaptor for the audio files.
|
||||
Read -> trim silence -> normalize
|
||||
"""
|
||||
def __init__(self,
|
||||
file_paths: List[Path],
|
||||
sample_rate: int,
|
||||
top_db: float = 60):
|
||||
self.file_paths = file_paths
|
||||
self.sr = sample_rate
|
||||
self.top_db = top_db
|
||||
|
||||
def __getitem__(self, i):
|
||||
fpath = self.file_paths[i]
|
||||
y, sr = librosa.load(fpath, self.sr)
|
||||
y, _ = librosa.effects.trim(y, top_db=self.top_db)
|
||||
y = librosa.util.normalize(y)
|
||||
y = y.astype(np.float32)
|
||||
return y
|
||||
|
||||
def __len__(self):
|
||||
return len(self.file_paths)
|
||||
|
||||
|
||||
class AudioFolderDataset(AudioDataset):
|
||||
def __init__(
|
||||
self,
|
||||
root,
|
||||
sample_rate,
|
||||
top_db=60,
|
||||
extension=".wav",
|
||||
):
|
||||
root = Path(root).expanduser()
|
||||
file_paths = sorted(list(root.rglob("*{}".format(extension))))
|
||||
super().__init__(file_paths, sample_rate, top_db)
|
||||
|
|
|
@ -101,9 +101,7 @@ class UpsampleNet(nn.LayerList):
|
|||
def __init__(self, upscale_factors=[16, 16]):
|
||||
super(UpsampleNet, self).__init__()
|
||||
self.upscale_factors = list(upscale_factors)
|
||||
self.upscale_factor = 1
|
||||
for item in upscale_factors:
|
||||
self.upscale_factor *= item
|
||||
self.upscale_factor = np.prod(upscale_factors)
|
||||
|
||||
for factor in self.upscale_factors:
|
||||
self.append(
|
||||
|
@ -224,13 +222,15 @@ class ResidualBlock(nn.Layer):
|
|||
other ResidualBlocks.
|
||||
"""
|
||||
h = x
|
||||
length = x.shape[-1]
|
||||
|
||||
# dilated conv
|
||||
h = self.conv(h)
|
||||
|
||||
# condition
|
||||
# NOTE: expanded condition may have a larger timesteps than x
|
||||
if condition is not None:
|
||||
h += self.condition_proj(condition)
|
||||
h += self.condition_proj(condition)[:, :, :length]
|
||||
|
||||
# gated tanh
|
||||
content, gate = paddle.split(h, 2, axis=1)
|
||||
|
@ -822,7 +822,7 @@ class ConditionalWaveNet(nn.Layer):
|
|||
loss_type=loss_type,
|
||||
log_scale_min=log_scale_min)
|
||||
|
||||
def forward(self, audio, mel, audio_start):
|
||||
def forward(self, audio, mel):
|
||||
"""Compute the output distribution given the mel spectrogram and the input(for teacher force training).
|
||||
|
||||
Parameters
|
||||
|
@ -845,13 +845,13 @@ class ConditionalWaveNet(nn.Layer):
|
|||
"""
|
||||
audio_length = audio.shape[1] # audio clip's length
|
||||
condition = self.encoder(mel)
|
||||
condition_slice = crop(condition, audio_start, audio_length)
|
||||
|
||||
|
||||
# shifting 1 step
|
||||
audio = audio[:, :-1]
|
||||
condition_slice = condition_slice[:, :, 1:]
|
||||
condition = condition[:, :, 1:]
|
||||
|
||||
y = self.decoder(audio, condition_slice)
|
||||
y = self.decoder(audio, condition)
|
||||
return y
|
||||
|
||||
def loss(self, y, t):
|
||||
|
|
|
@ -16,6 +16,8 @@ import paddle
|
|||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
from scipy import signal
|
||||
import librosa
|
||||
from librosa.util import pad_center
|
||||
import numpy as np
|
||||
|
||||
__all__ = ["quantize", "dequantize", "STFT"]
|
||||
|
@ -88,6 +90,19 @@ class STFT(nn.Layer):
|
|||
Name of window function, see `scipy.signal.get_window` for more
|
||||
details. Defaults to "hanning".
|
||||
|
||||
center : bool
|
||||
If True, the signal y is padded so that frame D[:, t] is centered
|
||||
at y[t * hop_length]. If False, then D[:, t] begins at y[t * hop_length].
|
||||
Defaults to True.
|
||||
|
||||
pad_mode : string or function
|
||||
If center=True, this argument is passed to np.pad for padding the edges
|
||||
of the signal y. By default (pad_mode="reflect"), y is padded on both
|
||||
sides with its own reflection, mirrored around its first and last
|
||||
sample respectively. If center=False, this argument is ignored.
|
||||
|
||||
|
||||
|
||||
Notes
|
||||
-----------
|
||||
It behaves like ``librosa.core.stft``. See ``librosa.core.stft`` for more
|
||||
|
@ -101,29 +116,47 @@ class STFT(nn.Layer):
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, n_fft, hop_length, win_length, window="hanning"):
|
||||
super(STFT, self).__init__()
|
||||
def __init__(self, n_fft, hop_length=None, win_length=None, window="hanning", center=True, pad_mode="reflect"):
|
||||
super().__init__()
|
||||
# By default, use the entire frame
|
||||
if win_length is None:
|
||||
win_length = n_fft
|
||||
|
||||
# Set the default hop, if it's not already specified
|
||||
if hop_length is None:
|
||||
hop_length = int(win_length // 4)
|
||||
|
||||
self.hop_length = hop_length
|
||||
self.n_bin = 1 + n_fft // 2
|
||||
self.n_fft = n_fft
|
||||
self.center = center
|
||||
self.pad_mode = pad_mode
|
||||
|
||||
# calculate window
|
||||
window = signal.get_window(window, win_length)
|
||||
window = signal.get_window(window, win_length, fftbins=True)
|
||||
|
||||
# pad window to n_fft size
|
||||
if n_fft != win_length:
|
||||
pad = (n_fft - win_length) // 2
|
||||
window = np.pad(window, ((pad, pad), ), 'constant')
|
||||
window = pad_center(window, n_fft, mode="constant")
|
||||
#lpad = (n_fft - win_length) // 2
|
||||
#rpad = n_fft - win_length - lpad
|
||||
#window = np.pad(window, ((lpad, pad), ), 'constant')
|
||||
|
||||
# calculate weights
|
||||
r = np.arange(0, n_fft)
|
||||
M = np.expand_dims(r, -1) * np.expand_dims(r, 0)
|
||||
w_real = np.reshape(window *
|
||||
np.cos(2 * np.pi * M / n_fft)[:self.n_bin],
|
||||
(self.n_bin, 1, 1, self.n_fft))
|
||||
w_imag = np.reshape(window *
|
||||
np.sin(-2 * np.pi * M / n_fft)[:self.n_bin],
|
||||
(self.n_bin, 1, 1, self.n_fft))
|
||||
|
||||
#r = np.arange(0, n_fft)
|
||||
#M = np.expand_dims(r, -1) * np.expand_dims(r, 0)
|
||||
#w_real = np.reshape(window *
|
||||
#np.cos(2 * np.pi * M / n_fft)[:self.n_bin],
|
||||
#(self.n_bin, 1, self.n_fft))
|
||||
#w_imag = np.reshape(window *
|
||||
#np.sin(-2 * np.pi * M / n_fft)[:self.n_bin],
|
||||
#(self.n_bin, 1, self.n_fft))
|
||||
weight = np.fft.fft(np.eye(n_fft))[:self.n_bin]
|
||||
w_real = weight.real
|
||||
w_imag = weight.imag
|
||||
w = np.concatenate([w_real, w_imag], axis=0)
|
||||
w = w * window
|
||||
w = np.expand_dims(w, 1)
|
||||
self.weight = paddle.cast(
|
||||
paddle.to_tensor(w), paddle.get_default_dtype())
|
||||
|
||||
|
@ -137,23 +170,20 @@ class STFT(nn.Layer):
|
|||
|
||||
Returns
|
||||
------------
|
||||
real : Tensor [shape=(B, C, 1, frames)]
|
||||
real : Tensor [shape=(B, C, frames)]
|
||||
The real part of the spectrogram.
|
||||
|
||||
imag : Tensor [shape=(B, C, 1, frames)]
|
||||
imag : Tensor [shape=(B, C, frames)]
|
||||
The image part of the spectrogram.
|
||||
"""
|
||||
# x(batch_size, time_steps)
|
||||
# pad it first with reflect mode
|
||||
# TODO(chenfeiyu): report an issue on paddle.flip
|
||||
pad_start = paddle.reverse(x[:, 1:1 + self.n_fft // 2], axis=[1])
|
||||
pad_stop = paddle.reverse(x[:, -(1 + self.n_fft // 2):-1], axis=[1])
|
||||
x = paddle.concat([pad_start, x, pad_stop], axis=-1)
|
||||
x = paddle.unsqueeze(x, axis=1)
|
||||
if self.center:
|
||||
x = F.pad(x, [self.n_fft // 2, self.n_fft // 2],
|
||||
data_format='NCL', mode=self.pad_mode)
|
||||
|
||||
# to BC1T, C=1
|
||||
x = paddle.unsqueeze(x, axis=[1, 2])
|
||||
out = F.conv2d(x, self.weight, stride=(1, self.hop_length))
|
||||
real, imag = paddle.chunk(out, 2, axis=1) # BC1T
|
||||
# to BCT, C=1
|
||||
out = F.conv1d(x, self.weight, stride=self.hop_length)
|
||||
real, imag = paddle.chunk(out, 2, axis=1) # BCT
|
||||
return real, imag
|
||||
|
||||
def power(self, x):
|
||||
|
@ -166,7 +196,7 @@ class STFT(nn.Layer):
|
|||
|
||||
Returns
|
||||
------------
|
||||
Tensor [shape=(B, C, 1, T)]
|
||||
Tensor [shape=(B, C, T)]
|
||||
The power spectrum.
|
||||
"""
|
||||
real, imag = self(x)
|
||||
|
@ -183,9 +213,22 @@ class STFT(nn.Layer):
|
|||
|
||||
Returns
|
||||
------------
|
||||
Tensor [shape=(B, C, 1, T)]
|
||||
Tensor [shape=(B, C, T)]
|
||||
The magnitude of the spectrum.
|
||||
"""
|
||||
power = self.power(x)
|
||||
magnitude = paddle.sqrt(power)
|
||||
return magnitude
|
||||
|
||||
|
||||
class MelScale(nn.Layer):
|
||||
def __init__(self, sr, n_fft, n_mels, fmin, fmax):
|
||||
super().__init__()
|
||||
mel_basis = librosa.filters.mel(sr, n_fft, n_mels, fmin, fmax)
|
||||
print(mel_basis.shape)
|
||||
self.weight = paddle.to_tensor(mel_basis)
|
||||
|
||||
def forward(self, spec):
|
||||
# (n_mels, n_freq) * (batch_size, n_freq, n_frames)
|
||||
mel = paddle.matmul(self.weight, spec)
|
||||
return mel
|
Loading…
Reference in New Issue