commit 2475da3322 (parent 0aa7088d36)
iclementine, 2021-03-27 17:39:37 +08:00
11 changed files with 1067 additions and 1 deletion

examples/ge2e/audio_processor.py Normal file

@ -0,0 +1,226 @@
from scipy.ndimage.morphology import binary_dilation
from config import get_cfg_defaults
from pathlib import Path
from typing import Optional, Union
from warnings import warn
import numpy as np
import librosa
import struct
try:
import webrtcvad
except ImportError:
warn("Unable to import 'webrtcvad'. "
"This package enables noise removal and is recommended.")
webrtcvad = None
int16_max = (2**15) - 1
def normalize_volume(wav,
target_dBFS,
increase_only=False,
decrease_only=False):
# This function implements loudness normalization rather than peak
# normalization. See https://en.wikipedia.org/wiki/Audio_normalization
# dBFS: decibels relative to full scale.
# See https://en.wikipedia.org/wiki/DBFS for more details.
# For 16-bit PCM audio, the minimum level is -96 dBFS.
# Compute the mean dBFS and adjust it to the target dBFS by increasing or
# decreasing the gain.
if increase_only and decrease_only:
raise ValueError("Both increase only and decrease only are set")
dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav**2))
if ((dBFS_change < 0 and increase_only)
or (dBFS_change > 0 and decrease_only)):
return wav
gain = 10**(dBFS_change / 20)
return wav * gain
def trim_long_silences(wav, vad_window_length: int,
vad_moving_average_width: int,
vad_max_silence_length: int, sampling_rate: int):
"""
Ensures that segments without voice in the waveform remain no longer than a
threshold determined by the VAD parameters passed in.
:param wav: the raw waveform as a numpy array of floats
:return: the same waveform with silences trimmed away (length <= original wav length)
"""
# Compute the voice detection window size
samples_per_window = (vad_window_length * sampling_rate) // 1000
# Trim the end of the audio to have a multiple of the window size
wav = wav[:len(wav) - (len(wav) % samples_per_window)]
# Convert the float waveform to 16-bit mono PCM
pcm_wave = struct.pack("%dh" % len(wav),
*(np.round(wav * int16_max)).astype(np.int16))
# Perform voice activation detection
voice_flags = []
vad = webrtcvad.Vad(mode=3)
for window_start in range(0, len(wav), samples_per_window):
window_end = window_start + samples_per_window
voice_flags.append(
vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
sample_rate=sampling_rate))
voice_flags = np.array(voice_flags)
# Smooth the voice detection with a moving average
def moving_average(array, width):
array_padded = np.concatenate((np.zeros(
(width - 1) // 2), array, np.zeros(width // 2)))
ret = np.cumsum(array_padded, dtype=float)
ret[width:] = ret[width:] - ret[:-width]
return ret[width - 1:] / width
audio_mask = moving_average(voice_flags, vad_moving_average_width)
audio_mask = np.round(audio_mask).astype(bool)
# Dilate the voiced regions
audio_mask = binary_dilation(audio_mask,
np.ones(vad_max_silence_length + 1))
audio_mask = np.repeat(audio_mask, samples_per_window)
return wav[audio_mask]
def compute_partial_slices(n_samples: int,
partial_utterance_n_frames: int,
hop_length: int,
min_pad_coverage: float = 0.75,
overlap: float = 0.5):
"""
Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
spectrogram slices are returned, so as to make each partial utterance waveform correspond to
its spectrogram. This function assumes that the mel spectrogram is computed
with the same hop_length that is passed in here.
The returned ranges may index past the end of the waveform. It is
recommended that you pad the waveform with zeros up to wav_slices[-1].stop.
:param n_samples: the number of samples in the waveform
:param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
utterance
:param min_pad_coverage: when reaching the last partial utterance, it may or may not have
enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
then the last partial utterance will be considered, as if we padded the audio. Otherwise,
it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
utterance, this parameter is ignored so that the function always returns at least 1 slice.
:param overlap: by how much the partial utterance should overlap. If set to 0, the partial
utterances are entirely disjoint.
:return: the waveform slices and mel spectrogram slices as lists of array slices. Index
respectively the waveform and the mel spectrogram with these slices to obtain the partial
utterances.
"""
assert 0 <= overlap < 1
assert 0 < min_pad_coverage <= 1
# librosa's function to compute num_frames from num_samples
n_frames = int(np.ceil((n_samples + 1) / hop_length))
# frame shift between adjacent partials
frame_step = max(1,
int(np.round(partial_utterance_n_frames * (1 - overlap))))
# Compute the slices
wav_slices, mel_slices = [], []
steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
for i in range(0, steps, frame_step):
mel_range = np.array([i, i + partial_utterance_n_frames])
wav_range = mel_range * hop_length
mel_slices.append(slice(*mel_range))
wav_slices.append(slice(*wav_range))
# Evaluate whether extra padding is warranted or not
last_wav_range = wav_slices[-1]
coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop -
last_wav_range.start)
if coverage < min_pad_coverage and len(mel_slices) > 1:
mel_slices = mel_slices[:-1]
wav_slices = wav_slices[:-1]
return wav_slices, mel_slices
class SpeakerVerificationPreprocessor(object):
def __init__(self,
sampling_rate: int,
audio_norm_target_dBFS: float,
vad_window_length,
vad_moving_average_width,
vad_max_silence_length,
mel_window_length,
mel_window_step,
n_mels,
partial_n_frames: int,
min_pad_coverage: float = 0.75,
partial_overlap_ratio: float = 0.5):
self.sampling_rate = sampling_rate
self.audio_norm_target_dBFS = audio_norm_target_dBFS
self.vad_window_length = vad_window_length
self.vad_moving_average_width = vad_moving_average_width
self.vad_max_silence_length = vad_max_silence_length
self.n_fft = int(mel_window_length * sampling_rate / 1000)
self.hop_length = int(mel_window_step * sampling_rate / 1000)
self.n_mels = n_mels
self.partial_n_frames = partial_n_frames
self.min_pad_coverage = min_pad_coverage
self.partial_overlap_ratio = partial_overlap_ratio
def preprocess_wav(self, fpath_or_wav, source_sr=None):
# Load the wav from disk if needed
if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
wav, source_sr = librosa.load(str(fpath_or_wav), sr=None)
else:
wav = fpath_or_wav
# Resample if numpy.array is passed and sr does not match
if source_sr is not None and source_sr != self.sampling_rate:
wav = librosa.resample(wav, source_sr, self.sampling_rate)
# loudness normalization
wav = normalize_volume(wav,
self.audio_norm_target_dBFS,
increase_only=True)
# trim long silence
if webrtcvad:
wav = trim_long_silences(wav, self.vad_window_length,
self.vad_moving_average_width,
self.vad_max_silence_length,
self.sampling_rate)
return wav
def melspectrogram(self, wav):
mel = librosa.feature.melspectrogram(wav,
sr=self.sampling_rate,
n_fft=self.n_fft,
hop_length=self.hop_length,
n_mels=self.n_mels)
mel = mel.astype(np.float32).T
return mel
def extract_mel_partials(self, wav):
wav_slices, mel_slices = compute_partial_slices(
len(wav),
self.partial_n_frames,
self.hop_length,
self.min_pad_coverage,
self.partial_overlap_ratio)
# pad audio if needed
max_wave_length = wav_slices[-1].stop
if max_wave_length >= len(wav):
wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
# Split the utterance into partials
frames = self.melspectrogram(wav)
frames_batch = np.array([frames[s] for s in mel_slices])
return frames_batch # [B, T, C]
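
A quick sanity check of the slicing logic above, with illustrative numbers that assume the defaults from config.py below (16 kHz audio, 10 ms hop, 160-frame partials, 0.5 overlap, 0.75 minimum coverage):

from audio_processor import compute_partial_slices

# 2.5 s of 16 kHz audio; hop_length = 160 samples (10 ms)
wav_slices, mel_slices = compute_partial_slices(n_samples=40000,
                                                partial_utterance_n_frames=160,
                                                hop_length=160)
# n_frames = ceil(40001 / 160) = 251 and frame_step = 80, so the candidate
# mel slices are [0, 160), [80, 240) and [160, 320). The last one covers only
# (40000 - 25600) / 25600 = 0.5625 of a partial, which is below
# min_pad_coverage = 0.75, so it is dropped.
print(mel_slices)  # slices over frames [0, 160) and [80, 240)
print(wav_slices)  # slices over samples [0, 25600) and [12800, 38400)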

examples/ge2e/config.py Normal file

@ -0,0 +1,48 @@
from yacs.config import CfgNode
_C = CfgNode()
data_config = _C.data = CfgNode()
## Audio volume normalization
data_config.audio_norm_target_dBFS = -30
## Audio sample rate
data_config.sampling_rate = 16000 # Hz
## Voice Activation Detection
# Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
# This sets the granularity of the VAD. Should not need to be changed.
data_config.vad_window_length = 30 # In milliseconds
# Number of frames to average together when performing the moving average smoothing.
# The larger this value, the larger the VAD variations must be to not get smoothed out.
data_config.vad_moving_average_width = 8
# Maximum number of consecutive silent frames a segment can have.
data_config.vad_max_silence_length = 6
## Mel-filterbank
data_config.mel_window_length = 25 # In milliseconds
data_config.mel_window_step = 10 # In milliseconds
data_config.n_mels = 40 # mel bands
# Number of spectrogram frames in a partial utterance
data_config.partial_n_frames = 160 # 1600 ms
data_config.min_pad_coverage = 0.75 # at least 75% of the audio is valid in a partial
data_config.partial_overlap_ratio = 0.5 # overlap ratio between adjacent partials
model_config = _C.model = CfgNode()
model_config.num_layers = 3
model_config.hidden_size = 256
model_config.embedding_size = 256 # output size
training_config = _C.training = CfgNode()
training_config.learning_rate_init = 1e-4
training_config.speakers_per_batch = 64
training_config.utterances_per_speaker = 10
training_config.max_iteration = 1560000
training_config.save_interval = 10000
training_config.valid_interval = 10000
def get_cfg_defaults():
return _C.clone()
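
inference.py and train.py below consume this module by cloning the defaults and then merging a YAML file (--config) and/or KEY VALUE pairs (--opts); a minimal sketch of that flow:

from config import get_cfg_defaults

cfg = get_cfg_defaults()
# the same overrides could come from a YAML file via cfg.merge_from_file(...)
cfg.merge_from_list(["data.n_mels", "80", "training.speakers_per_batch", "32"])
cfg.freeze()
print(cfg.data.n_mels)                  # 80
print(cfg.training.speakers_per_batch)  # 32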

examples/ge2e/dataset_processors.py Normal file

@ -0,0 +1,139 @@
import numpy as np
from pathlib import Path
import multiprocessing as mp
from audio_processor import SpeakerVerificationPreprocessor
from tqdm import tqdm
from functools import partial
from typing import List
def _process_utterance(path_pair, processor: SpeakerVerificationPreprocessor):
# Load and preprocess the waveform
input_path, output_path = path_pair
wav = processor.preprocess_wav(input_path)
if len(wav) == 0:
return
# Create the mel spectrogram, discard those that are too short
frames = processor.melspectrogram(wav)
if len(frames) < processor.partial_n_frames:
return
np.save(output_path, frames)
def _process_speaker(speaker_dir: Path,
processor,
datasets_root: Path,
output_dir: Path,
pattern: str,
skip_existing: bool = False):
# datasets_root: a reference path used to compute speaker_name
# we prepend the dataset name to speaker_id because we are mixing several
# multi-speaker datasets together
speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
speaker_output_dir = output_dir / speaker_name
speaker_output_dir.mkdir(parents=True, exist_ok=True)
# load existing file set
sources_fpath = speaker_output_dir / "_sources.txt"
if sources_fpath.exists():
try:
with sources_fpath.open("rt") as sources_file:
existing_names = {line.split(",")[0] for line in sources_file}
except Exception:
existing_names = set()
else:
existing_names = set()
sources_file = sources_fpath.open("at" if skip_existing else "wt")
for in_fpath in speaker_dir.rglob(pattern):
out_name = "_".join(in_fpath.relative_to(speaker_dir).with_suffix(".npy").parts)
if skip_existing and out_name in existing_names:
continue
out_fpath = speaker_output_dir / out_name
_process_utterance((in_fpath, out_fpath), processor)
sources_file.write(f"{out_name},{in_fpath}\n")
sources_file.close()
def _process_dataset(processor,
datasets_root: Path,
speaker_dirs: List[Path],
dataset_name: str,
output_dir: Path,
pattern: str,
skip_existing: bool = False):
print(f"{dataset_name}: Preprocessing data for {len(speaker_dirs)} speakers.")
_func = partial(_process_speaker,
processor=processor,
datasets_root=datasets_root,
output_dir=output_dir,
pattern=pattern,
skip_existing=skip_existing)
with mp.Pool(16) as pool:
list(
tqdm(pool.imap(_func, speaker_dirs),
dataset_name,
len(speaker_dirs),
unit="speakers"))
print(f"Done preprocessing {dataset_name}.")
def process_librispeech(processor,
datasets_root,
output_dir,
skip_existing=False):
dataset_name = "LibriSpeech/train-other-500"
dataset_root = datasets_root / dataset_name
speaker_dirs = list(dataset_root.glob("*"))
_process_dataset(processor, datasets_root, speaker_dirs, dataset_name,
output_dir, "*.flac", skip_existing)
def process_voxceleb1(processor,
datasets_root,
output_dir,
skip_existing=False):
dataset_name = "VoxCeleb1"
dataset_root = datasets_root / dataset_name
anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
with dataset_root.joinpath("vox1_meta.csv").open("rt") as metafile:
metadata = [line.strip().split("\t") for line in metafile][1:]
# speaker id -> nationality
nationalities = {
line[0]: line[3]
for line in metadata if line[-1] == "dev"
}
keep_speaker_ids = [
speaker_id for speaker_id, nationality in nationalities.items()
if nationality.lower() in anglophone_nationalites
]
print(
"VoxCeleb1: using samples from {} (presumed anglophone) speakers out of {}.".format(
len(keep_speaker_ids), len(nationalities)))
speaker_dirs = list((dataset_root / "wav").glob("*"))
speaker_dirs = [
speaker_dir for speaker_dir in speaker_dirs
if speaker_dir.name in keep_speaker_ids
]
# TODO: filter ansa
_process_dataset(processor, datasets_root, speaker_dirs, dataset_name,
output_dir, "*.wav", skip_existing)
def process_voxceleb2(processor,
datasets_root,
output_dir,
skip_existing=False):
dataset_name = "VoxCeleb2"
dataset_root = datasets_root / dataset_name
# There is no nationality in the metadata for VoxCeleb2
speaker_dirs = list((dataset_root / "wav").glob("*"))
_process_dataset(processor, datasets_root, speaker_dirs, dataset_name,
output_dir, "*.wav", skip_existing)
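
The preprocessing script further below drives these functions from command-line arguments; a rough programmatic equivalent looks like this (the paths are hypothetical and the processor settings mirror the defaults in config.py):

from pathlib import Path
from audio_processor import SpeakerVerificationPreprocessor
from dataset_processors import process_librispeech

datasets_root = Path("~/datasets").expanduser()  # contains LibriSpeech/train-other-500/
output_dir = Path("~/datasets/SV2TTS/encoder").expanduser()
output_dir.mkdir(parents=True, exist_ok=True)

processor = SpeakerVerificationPreprocessor(sampling_rate=16000,
                                            audio_norm_target_dBFS=-30,
                                            vad_window_length=30,
                                            vad_moving_average_width=8,
                                            vad_max_silence_length=6,
                                            mel_window_length=25,
                                            mel_window_step=10,
                                            n_mels=40,
                                            partial_n_frames=160)
process_librispeech(processor, datasets_root, output_dir, skip_existing=True)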

examples/ge2e/inference.py Normal file

@ -0,0 +1,126 @@
from parakeet.models.lstm_speaker_encoder import LSTMSpeakerEncoder
from audio_processor import SpeakerVerificationPreprocessor
from matplotlib import cm
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from config import get_cfg_defaults
import argparse
import paddle
import tqdm
from functools import partial
from multiprocessing import Pool
def embed_utterance(processor, model, fpath_or_wav):
# audio processor
wav = processor.preprocess_wav(fpath_or_wav)
mel_partials = processor.extract_mel_partials(wav)
model.eval()
# speaker encoder
with paddle.no_grad():
mel_partials = paddle.to_tensor(mel_partials)
embed = model.embed_utterance(mel_partials)
embed = embed.numpy()
return embed
def _process_utterance(ifpath: Path, input_dir: Path, output_dir: Path,
processor: SpeakerVerificationPreprocessor,
model: LSTMSpeakerEncoder):
rel_path = ifpath.relative_to(input_dir)
ofpath = (output_dir / rel_path).with_suffix(".npy")
ofpath.parent.mkdir(parents=True, exist_ok=True)
embed = embed_utterance(processor, model, ifpath)
np.save(ofpath, embed)
def main(config, args):
paddle.set_device(args.device)
# load model
model = LSTMSpeakerEncoder(config.data.n_mels, config.model.num_layers,
config.model.hidden_size,
config.model.embedding_size)
weights_fpath = str(Path(args.checkpoint_path).expanduser())
model_state_dict = paddle.load(weights_fpath + ".pdparams")
model.set_state_dict(model_state_dict)
model.eval()
print(f"Loaded encoder {weights_fpath}")
# create audio processor
c = config.data
processor = SpeakerVerificationPreprocessor(
sampling_rate=c.sampling_rate,
audio_norm_target_dBFS=c.audio_norm_target_dBFS,
vad_window_length=c.vad_window_length,
vad_moving_average_width=c.vad_moving_average_width,
vad_max_silence_length=c.vad_max_silence_length,
mel_window_length=c.mel_window_length,
mel_window_step=c.mel_window_step,
n_mels=c.n_mels,
partial_n_frames=c.partial_n_frames,
min_pad_coverage=c.min_pad_coverage,
partial_overlap_ratio=c.partial_overlap_ratio,
)
# input output preparation
input_dir = Path(args.input).expanduser()
ifpaths = list(input_dir.rglob(args.pattern))
print(f"{len(ifpaths)} utterances in total")
output_dir = Path(args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
for ifpath in tqdm.tqdm(ifpaths, unit="utterance"):
_process_utterance(ifpath, input_dir, output_dir, processor, model)
if __name__ == "__main__":
config = get_cfg_defaults()
parser = argparse.ArgumentParser(description="compute utterance embeddings.")
parser.add_argument(
"--config",
metavar="FILE",
help="path of the config file to overwrite to default config with.")
parser.add_argument("--input",
type=str,
help="path of the audio_file folder.")
parser.add_argument("--pattern",
type=str,
default="*.wav",
help="pattern to filter audio files.")
parser.add_argument("--output",
metavar="OUTPUT_DIR",
help="path to save checkpoint and logs.")
# load from saved checkpoint
parser.add_argument("--checkpoint_path",
type=str,
help="path of the checkpoint to load")
# running
parser.add_argument("--device",
type=str,
choices=["cpu", "gpu"],
help="device type to use, cpu and gpu are supported.")
# overwrite extra config and default config
parser.add_argument(
"--opts",
nargs=argparse.REMAINDER,
help=
"options to overwrite --config file and the default config, passing in KEY VALUE pairs"
)
args = parser.parse_args()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
print(args)
main(config, args)


@ -0,0 +1,88 @@
import argparse
from pathlib import Path
from config import get_cfg_defaults
from audio_processor import SpeakerVerificationPreprocessor
from dataset_processors import process_librispeech, process_voxceleb1, process_voxceleb2
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="preprocess dataset for speaker verification task")
parser.add_argument(
"--datasets_root",
type=Path,
help=
"Path to the directory containing your LibriSpeech, LibriTTS and VoxCeleb datasets."
)
parser.add_argument("--output_dir",
type=Path,
help="Path to save processed dataset.")
parser.add_argument(
"--dataset_names",
type=str,
default="librispeech_other,voxceleb1,voxceleb2",
help=
"comma-separated list of names of the datasets you want to preprocess. only "
"the train set of these datastes will be used. Possible names: librispeech_other, "
"voxceleb1, voxceleb2.")
parser.add_argument(
"--skip_existing",
action="store_true",
help=
"Whether to skip ouput files with the same name. Useful if this script was interrupted."
)
parser.add_argument(
"--no_trim",
action="store_true",
help="Preprocess audio without trimming silences (not recommended).")
args = parser.parse_args()
if not args.no_trim:
try:
import webrtcvad
except ImportError:
raise ModuleNotFoundError(
"Package 'webrtcvad' not found. This package enables "
"noise removal and is recommended. Please install and try again. If installation fails, "
"use --no_trim to disable this error message.")
del args.no_trim
args.datasets = [item.strip() for item in args.dataset_names.split(",")]
if args.output_dir is None:
args.output_dir = args.datasets_root / "SV2TTS" / "encoder"
args.output_dir = args.output_dir.expanduser()
args.datasets_root = args.datasets_root.expanduser()
assert args.datasets_root.exists()
args.output_dir.mkdir(exist_ok=True, parents=True)
config = get_cfg_defaults()
# TODO: nice print
print(args)
c = config.data
processor = SpeakerVerificationPreprocessor(
sampling_rate=c.sampling_rate,
audio_norm_target_dBFS=c.audio_norm_target_dBFS,
vad_window_length=c.vad_window_length,
vad_moving_average_width=c.vad_moving_average_width,
vad_max_silence_length=c.vad_max_silence_length,
mel_window_length=c.mel_window_length,
mel_window_step=c.mel_window_step,
n_mels=c.n_mels,
partial_n_frames=c.partial_n_frames,
min_pad_coverage=c.min_pad_coverage,
partial_overlap_ratio=c.partial_overlap_ratio,
)
preprocess_func = {
"librispeech_other": process_librispeech,
"voxceleb1": process_voxceleb1,
"voxceleb2": process_voxceleb2,
}
for dataset in args.datasets:
print("Preprocessing %s" % dataset)
preprocess_func[dataset](processor, args.datasets_root, args.output_dir, args.skip_existing)

examples/ge2e/random_cycle.py Normal file

@ -0,0 +1,23 @@
import random
def cycle(iterable):
# cycle('ABCD') --> A B C D A B C D A B C D ...
saved = []
for element in iterable:
yield element
saved.append(element)
while saved:
for element in saved:
yield element
def random_cycle(iterable):
# cycle('ABCD') --> A B C D B C D A A D B C ...
saved = []
for element in iterable:
yield element
saved.append(element)
random.shuffle(saved)
while saved:
for element in saved:
yield element
random.shuffle(saved)
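
Both generators are infinite; the batch sampler below draws from them with next(). A small illustration:

from itertools import islice
from random_cycle import cycle, random_cycle

print(list(islice(cycle("ABCD"), 8)))
# ['A', 'B', 'C', 'D', 'A', 'B', 'C', 'D']
print(list(islice(random_cycle("ABCD"), 8)))
# first pass in the original order, then a reshuffled pass,
# e.g. ['A', 'B', 'C', 'D', 'C', 'A', 'D', 'B']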

examples/ge2e/speaker_verification_dataset.py Normal file

@ -0,0 +1,114 @@
import random
import numpy as np
import paddle
from paddle.io import Dataset, BatchSampler
from pathlib import Path
from random_cycle import random_cycle
class MultiSpeakerMelDataset(Dataset):
"""A 2 layer directory thatn contains mel spectrograms in *.npy format.
An Example file structure tree is shown below. We prefer to preprocess
raw datasets and organized them like this.
dataset_root/
speaker1/
utterance1.npy
utterance2.npy
utterance3.npy
speaker2/
utterance1.npy
utterance2.npy
utterance3.npy
"""
def __init__(self, dataset_root: Path):
self.root = Path(dataset_root).expanduser()
speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
speaker_utterances = {
speaker_dir: [f for f in speaker_dir.glob("*.npy")]
for speaker_dir in speaker_dirs
}
self.speaker_dirs = speaker_dirs
self.speaker_to_utterances = speaker_utterances
# meta data
self.num_speakers = len(self.speaker_dirs)
self.num_utterances = sum(
len(utterances)
for utterances in self.speaker_to_utterances.values())
def get_example_by_index(self, speaker_index, utterance_index):
speaker_dir = self.speaker_dirs[speaker_index]
fpath = self.speaker_to_utterances[speaker_dir][utterance_index]
return self[fpath]
def __getitem__(self, fpath):
return np.load(fpath)
def __len__(self):
return int(self.num_utterances)
class MultiSpeakerSampler(BatchSampler):
"""A multi-stratal sampler designed for speaker verification task.
First, N speakers from all speakers are sampled randomly. Then, for each
speaker, randomly sample M utterances from their corresponding utterances.
"""
def __init__(self,
dataset: MultiSpeakerMelDataset,
speakers_per_batch: int,
utterances_per_speaker: int):
self._speakers = list(dataset.speaker_dirs)
self._speaker_to_utterances = dataset.speaker_to_utterances
self.speakers_per_batch = speakers_per_batch
self.utterances_per_speaker = utterances_per_speaker
def __iter__(self):
# yield list of Paths
speaker_generator = iter(random_cycle(self._speakers))
speaker_utterances_generator = {
s: iter(random_cycle(us))
for s, us in self._speaker_to_utterances.items()
}
while True:
speakers = []
for _ in range(self.speakers_per_batch):
speakers.append(next(speaker_generator))
utterances = []
for s in speakers:
us = speaker_utterances_generator[s]
for _ in range(self.utterances_per_speaker):
utterances.append(next(us))
yield utterances
class RandomClip(object):
def __init__(self, frames):
self.frames = frames
def __call__(self, spec):
# spec [T, C]
T = spec.shape[0]
start = random.randint(0, T - self.frames)
return spec[start:start + self.frames, :]
class Collate(object):
def __init__(self, num_frames):
self.random_crop = RandomClip(num_frames)
def __call__(self, examples):
frame_clips = [self.random_crop(mel) for mel in examples]
batched_clips = np.stack(frame_clips)
return batched_clips
if __name__ == "__main__":
mydataset = MultiSpeakerMelDataset(
Path("/home/chenfeiyu/datasets/SV2TTS/encoder"))
print(mydataset.get_example_by_index(0, 10))
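
To make the batch layout explicit, a sketch (the dataset path is hypothetical and the feature size assumes n_mels = 40 as in config.py): the sampler yields speakers_per_batch * utterances_per_speaker file paths grouped by speaker, and Collate stacks the randomly clipped mels into a [N * M, T, C] batch, which is the grouping the GE2E loss relies on.

from paddle.io import DataLoader
from speaker_verification_dataset import (Collate, MultiSpeakerMelDataset,
                                          MultiSpeakerSampler)

dataset = MultiSpeakerMelDataset("~/datasets/SV2TTS/encoder")  # hypothetical path
sampler = MultiSpeakerSampler(dataset,
                              speakers_per_batch=4,
                              utterances_per_speaker=3)
loader = DataLoader(dataset,
                    batch_sampler=sampler,
                    collate_fn=Collate(num_frames=160))
batch = next(iter(loader))
if isinstance(batch, (list, tuple)):  # paddle may wrap a single field in a list
    batch = batch[0]
print(batch.shape)  # [12, 160, 40]: 4 speakers x 3 utterances, 160 frames, 40 mels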

examples/ge2e/train.py Normal file

@ -0,0 +1,105 @@
from parakeet.training import ExperimentBase, default_argument_parser
from config import get_cfg_defaults
from paddle import distributed as dist
import time
from paddle.nn.clip import ClipGradByGlobalNorm
from parakeet.models.lstm_speaker_encoder import LSTMSpeakerEncoder
from speaker_verification_dataset import MultiSpeakerMelDataset, MultiSpeakerSampler, Collate
from paddle.optimizer import Adam
from paddle import DataParallel
from paddle.io import DataLoader
class Ge2eExperiment(ExperimentBase):
def setup_model(self):
config = self.config
model = LSTMSpeakerEncoder(config.data.n_mels, config.model.num_layers,
config.model.hidden_size,
config.model.embedding_size)
optimizer = Adam(config.training.learning_rate_init,
parameters=model.parameters(),
grad_clip=ClipGradByGlobalNorm(3))
self.model = DataParallel(model) if self.parallel else model
self.model_core = model
self.optimizer = optimizer
def setup_dataloader(self):
config = self.config
train_dataset = MultiSpeakerMelDataset(self.args.data)
sampler = MultiSpeakerSampler(train_dataset,
config.training.speakers_per_batch,
config.training.utterances_per_speaker)
train_loader = DataLoader(train_dataset,
batch_sampler=sampler,
collate_fn=Collate(
config.data.partial_n_frames),
num_workers=16)
self.train_dataset = train_dataset
self.train_loader = train_loader
def train_batch(self):
start = time.time()
batch = self.read_batch()
data_loader_time = time.time() - start
self.optimizer.clear_grad()
self.model.train()
specs = batch
loss, eer = self.model(specs, self.config.training.speakers_per_batch)
loss.backward()
self.model_core.do_gradient_ops()
self.optimizer.step()
iteration_time = time.time() - start
# logging
loss_value = float(loss)
msg = "Rank: {}, ".format(dist.get_rank())
msg += "step: {}, ".format(self.iteration)
msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time,
iteration_time)
msg += 'loss: {:>.6f} eer: {:>.6f}'.format(loss_value, eer)
self.logger.info(msg)
if dist.get_rank() == 0:
self.visualizer.add_scalar("train/loss", loss_value,
self.iteration)
self.visualizer.add_scalar("train/eer", eer, self.iteration)
self.visualizer.add_scalar("param/w",
float(self.model_core.similarity_weight),
self.iteration)
self.visualizer.add_scalar("param/b",
float(self.model_core.similarity_bias),
self.iteration)
def valid(self):
pass
def main_sp(config, args):
exp = Ge2eExperiment(config, args)
exp.setup()
exp.run()
def main(config, args):
if args.nprocs > 1 and args.device == "gpu":
dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
else:
main_sp(config, args)
if __name__ == "__main__":
config = get_cfg_defaults()
parser = default_argument_parser()
args = parser.parse_args()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
print(args)
main(config, args)

parakeet/models/lstm_speaker_encoder.py Normal file

@ -0,0 +1,120 @@
import numpy as np
import paddle
from paddle import nn
from paddle.fluid.param_attr import ParamAttr
from paddle.nn import functional as F
from paddle.nn import initializer as I
from scipy.interpolate import interp1d
from sklearn.metrics import roc_curve
from scipy.optimize import brentq
class LSTMSpeakerEncoder(nn.Layer):
def __init__(self, n_mels, num_layers, hidden_size, output_size):
super().__init__()
self.lstm = nn.LSTM(n_mels, hidden_size, num_layers)
self.linear = nn.Linear(hidden_size, output_size)
self.similarity_weight = self.create_parameter(
[1], default_initializer=I.Constant(10.))
self.similarity_bias = self.create_parameter(
[1], default_initializer=I.Constant(-5.))
def forward(self, utterances, num_speakers, initial_states=None):
normalized_embeds = self.embed_sequences(utterances, initial_states)
embeds = normalized_embeds.reshape([num_speakers, -1, normalized_embeds.shape[-1]])
loss, eer = self.loss(embeds)
return loss, eer
def embed_sequences(self, utterances, initial_states=None, reduce=False):
out, (h, c) = self.lstm(utterances, initial_states)
embeds = F.relu(self.linear(h[-1]))
normalized_embeds = F.normalize(embeds)
if reduce:
embed = paddle.mean(normalized_embeds, 0)
embed = F.normalize(embed, axis=0)
return embed
return normalized_embeds
def embed_utterance(self, utterances, initial_states=None):
# utterances: [B, T, C] -> embed [C']
embed = self.embed_sequences(utterances, initial_states, reduce=True)
return embed
def similarity_matrix(self, embeds):
# (N, M, C)
speakers_per_batch, utterances_per_speaker, embed_dim = embeds.shape
# Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
centroids_incl = paddle.mean(embeds, axis=1)
centroids_incl_norm = paddle.norm(centroids_incl, p=2, axis=1, keepdim=True)
normalized_centroids_incl = centroids_incl / centroids_incl_norm
# Exclusive centroids (1 per utterance)
centroids_excl = paddle.broadcast_to(paddle.sum(embeds, axis=1, keepdim=True), embeds.shape) - embeds
centroids_excl /= (utterances_per_speaker - 1)
centroids_excl_norm = paddle.norm(centroids_excl, p=2, axis=2, keepdim=True)
normalized_centroids_excl = centroids_excl / centroids_excl_norm
p1 = paddle.matmul(embeds.reshape([-1, embed_dim]),
normalized_centroids_incl, transpose_y=True) # (NM, N)
p1 = p1.reshape([-1]) # (NM * N,)
p2 = paddle.bmm(embeds.reshape([-1, 1, embed_dim]),
normalized_centroids_excl.reshape([-1, embed_dim, 1])) # (NM, 1, 1)
p2 = p2.reshape([-1]) # (NM,)
# begin: alternative implementation for scatter
with paddle.no_grad():
index = paddle.arange(0, speakers_per_batch * utterances_per_speaker, dtype="int64").reshape([speakers_per_batch, utterances_per_speaker])
index = index * speakers_per_batch + paddle.arange(0, speakers_per_batch, dtype="int64").unsqueeze(-1)
index = paddle.reshape(index, [-1])
ones = paddle.ones([speakers_per_batch * utterances_per_speaker * speakers_per_batch])
zeros = paddle.zeros_like(index, dtype=ones.dtype)
mask_p1 = paddle.scatter(ones, index, zeros)
p = p1 * mask_p1 + (1 - mask_p1) * paddle.scatter(ones, index, p2)
# end: alternative implementation for scatter
# p = paddle.scatter(p1, index, p2)
p = p * self.similarity_weight + self.similarity_bias # neg
p = p.reshape([speakers_per_batch * utterances_per_speaker, speakers_per_batch])
return p, p1, p2
def do_gradient_ops(self):
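# Scale down the gradients of the similarity scaling parameters (w, b),
# as suggested in the GE2E paper, so they change more slowly than the rest.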
for p in [self.similarity_weight, self.similarity_bias]:
g = p._grad_ivar()
g[...] = g * 0.01
def loss(self, embeds):
"""
Computes the softmax loss described in Section 2.1 of the GE2E paper.
:param embeds: the embeddings as a tensor of shape (speakers_per_batch,
utterances_per_speaker, embedding_size)
:return: the loss and the EER for this batch of embeddings.
"""
speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
# Loss
sim_matrix, *_ = self.similarity_matrix(embeds)
sim_matrix = sim_matrix.reshape(
[speakers_per_batch * utterances_per_speaker, speakers_per_batch])
target = paddle.arange(0, speakers_per_batch, dtype="int64").unsqueeze(-1)
target = paddle.expand(target, [speakers_per_batch, utterances_per_speaker])
target = paddle.reshape(target, [-1])
loss = nn.CrossEntropyLoss()(sim_matrix, target)
# EER (not backpropagated)
with paddle.no_grad():
ground_truth = target.numpy()
inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=int)[0]
labels = np.array([inv_argmax(i) for i in ground_truth])
preds = sim_matrix.numpy()
# Snippet from https://yangcha.github.io/EER-ROC/
fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
return loss, eer
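
The index/mask construction in similarity_matrix works around a backward issue in paddle.scatter; its intent is easier to see in a tiny NumPy sketch (N = 2 speakers, M = 2 utterances, toy similarity values):

import numpy as np

N, M = 2, 2                                  # speakers, utterances per speaker
index = np.arange(N * M).reshape([N, M])
index = index * N + np.arange(N).reshape([N, 1])
index = index.reshape([-1])
print(index)                                 # [0 2 5 7]: the "own speaker" positions
                                             # in the flattened (N*M, N) matrix

p1 = np.arange(N * M * N, dtype=float)       # stand-in for inclusive-centroid similarities
p2 = np.full(N * M, -1.0)                    # stand-in for exclusive-centroid similarities
ones = np.ones(N * M * N)
mask_p1 = ones.copy()
mask_p1[index] = 0.0                         # what paddle.scatter(ones, index, zeros) builds
p2_full = ones.copy()
p2_full[index] = p2                          # what paddle.scatter(ones, index, p2) builds
p = p1 * mask_p1 + (1 - mask_p1) * p2_full
print(p.reshape([N * M, N]))
# [[-1.  1.]   each row keeps the inclusive-centroid similarities to other speakers
#  [-1.  3.]   and replaces the entry for the utterance's own speaker with p2
#  [ 4. -1.]
#  [ 6. -1.]]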


@ -0,0 +1,77 @@
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
class GE2ELoss(nn.Layer):
def __init__(self, init_w=10., init_b=-5.):
super().__init__()
self.w = self.create_parameter([1], attr=I.Constant(init_w))
self.b = self.create_parameter([1], attr=I.Constant(init_b))
def forward(self, embeds):
# embeds [N, M, C]
# N - speakers_per_batch,
# M - utterances_per_speaker
# C - embed_dim
sim_matrix = self._build_sim_matrix(embeds)
_, M, N = sim_matrix.shape
target = paddle.arange(0, N, dtype="int64").unsqueeze(-1)
target = paddle.expand(target, [N, M])
target = paddle.reshape(target, [-1]) # [NM]
criterion = nn.CrossEntropyLoss()
loss = criterion(sim_matrix.reshape([-1, N]), target)
return loss
def _build_sim_matrix(self, embeds):
N, M, C = embeds.shape
# Inclusive centroids (1 per speaker). [N, C]
centroids_incl = paddle.mean(embeds, axis=1)
centroids_incl_norm = paddle.norm(centroids_incl,
p=2,
axis=1,
keepdim=True)
normalized_centroids_incl = centroids_incl / centroids_incl_norm
# Exclusive centroids (1 per utterance) [N, M, C]
centroids_excl = paddle.broadcast_to(
paddle.sum(embeds, axis=1, keepdim=True), embeds.shape) - embeds
centroids_excl /= (M - 1)
centroids_excl_norm = paddle.norm(centroids_excl,
p=2,
axis=2,
keepdim=True)
normalized_centroids_excl = centroids_excl / centroids_excl_norm
# inter-speaker similarity, NM embeds ~ N centroids
# [NM, N]
p1 = paddle.matmul(embeds.reshape([-1, C]),
normalized_centroids_incl,
transpose_y=True)
p1 = p1.reshape([-1]) # [NMN]
# intra-similarity, NM embeds, 1 centroid per embed
p2 = paddle.bmm(embeds.reshape([-1, 1, C]),
normalized_centroids_excl.reshape([-1, C,
1])) # (NM, 1, 1)
p2 = p2.reshape([-1]) # [NM]
with paddle.no_grad():
index = paddle.arange(0, N * M, dtype="int64").reshape([N, M])
index = index * N + paddle.arange(0, N,
dtype="int64").unsqueeze(-1)
index = paddle.reshape(index, [-1])
# begin: alternative implementation for scatter
ones = paddle.ones([N * M * N])
zeros = paddle.zeros_like(index, dtype=ones.dtype)
mask_p1 = paddle.scatter(ones, index, zeros)
p = p1 * mask_p1 + (1 - mask_p1) * paddle.scatter(ones, index, p2)
# end: alternative implementation for scatter
# p = paddle.scatter(p1, index, p2) there is a backward bug in scatter
p = p * self.w + self.b
p = p.reshape([N, M, N])
return p
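
Given the class above, usage reduces to a single forward call; a sketch with random stand-in embeddings (4 speakers, 5 utterances each, 256-dimensional):

import paddle
from paddle.nn import functional as F

criterion = GE2ELoss()  # the class defined above
embeds = F.normalize(paddle.randn([4, 5, 256]), axis=-1)  # [N, M, C]
loss = criterion(embeds)
print(float(loss))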


@ -165,7 +165,7 @@ class ExperimentBase(object):
"""Reset the train loader and increment ``epoch``. """Reset the train loader and increment ``epoch``.
""" """
self.epoch += 1 self.epoch += 1
if self.parallel: if self.parallel and isinstance(self.train_loader.batch_sampler, DistributedBatchSampler):
self.train_loader.batch_sampler.set_epoch(self.epoch) self.train_loader.batch_sampler.set_epoch(self.epoch)
self.iterator = iter(self.train_loader) self.iterator = iter(self.train_loader)