From 2475da3322c60413369705fd151ecd88c4dbf47b Mon Sep 17 00:00:00 2001 From: iclementine Date: Sat, 27 Mar 2021 17:39:37 +0800 Subject: [PATCH] add ge2e --- examples/ge2e/audio_processor.py | 226 ++++++++++++++++++ examples/ge2e/config.py | 48 ++++ examples/ge2e/dataset_processors.py | 139 +++++++++++ examples/ge2e/inference.py | 126 ++++++++++ examples/ge2e/preprocess.py | 88 +++++++ examples/ge2e/random_cycle.py | 23 ++ examples/ge2e/speaker_verification_dataset.py | 114 +++++++++ examples/ge2e/train.py | 105 ++++++++ parakeet/models/lstm_speaker_encoder.py | 120 ++++++++++ parakeet/modules/g2e2_loss.py | 77 ++++++ parakeet/training/experiment.py | 2 +- 11 files changed, 1067 insertions(+), 1 deletion(-) create mode 100644 examples/ge2e/audio_processor.py create mode 100644 examples/ge2e/config.py create mode 100644 examples/ge2e/dataset_processors.py create mode 100644 examples/ge2e/inference.py create mode 100644 examples/ge2e/preprocess.py create mode 100644 examples/ge2e/random_cycle.py create mode 100644 examples/ge2e/speaker_verification_dataset.py create mode 100644 examples/ge2e/train.py create mode 100644 parakeet/models/lstm_speaker_encoder.py create mode 100644 parakeet/modules/g2e2_loss.py diff --git a/examples/ge2e/audio_processor.py b/examples/ge2e/audio_processor.py new file mode 100644 index 0000000..b10db34 --- /dev/null +++ b/examples/ge2e/audio_processor.py @@ -0,0 +1,226 @@ +from scipy.ndimage.morphology import binary_dilation +from config import get_cfg_defaults +from pathlib import Path +from typing import Optional, Union +from warnings import warn +import numpy as np +import librosa +import struct + +try: + import webrtcvad +except: + warn("Unable to import 'webrtcvad'." + "This package enables noise removal and is recommended.") + webrtcvad = None + +int16_max = (2**15) - 1 + + +def normalize_volume(wav, + target_dBFS, + increase_only=False, + decrease_only=False): + # this function implements Loudness normalization, instead of peak + # normalization, See https://en.wikipedia.org/wiki/Audio_normalization + # dBFS: Decibels relative to full scale + # See https://en.wikipedia.org/wiki/DBFS for more details + # for 16Bit PCM audio, minimal level is -96dB + # compute the mean dBFS and adjust to target dBFS, with by increasing + # or decreasing + if increase_only and decrease_only: + raise ValueError("Both increase only and decrease only are set") + dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav**2)) + if ((dBFS_change < 0 and increase_only) + or (dBFS_change > 0 and decrease_only)): + return wav + gain = 10**(dBFS_change / 20) + return wav * gain + + +def trim_long_silences(wav, vad_window_length: int, + vad_moving_average_width: int, + vad_max_silence_length: int, sampling_rate: int): + """ + Ensures that segments without voice in the waveform remain no longer than a + threshold determined by the VAD parameters in params.py. 
+
+    :param wav: the raw waveform as a numpy array of floats
+    :param vad_window_length: window size of the VAD, in milliseconds
+    :param vad_moving_average_width: width (in frames) of the moving average used to
+    smooth the VAD decisions
+    :param vad_max_silence_length: maximum number of consecutive silent frames a
+    voiced segment can contain
+    :param sampling_rate: sampling rate of the waveform, in Hz
+    :return: the same waveform with silences trimmed away (length <= original wav length)
+    """
+    # Compute the voice detection window size
+    samples_per_window = (vad_window_length * sampling_rate) // 1000
+
+    # Trim the end of the audio to have a multiple of the window size
+    wav = wav[:len(wav) - (len(wav) % samples_per_window)]
+
+    # Convert the float waveform to 16-bit mono PCM
+    pcm_wave = struct.pack("%dh" % len(wav),
+                           *(np.round(wav * int16_max)).astype(np.int16))
+
+    # Perform voice activity detection
+    voice_flags = []
+    vad = webrtcvad.Vad(mode=3)
+    for window_start in range(0, len(wav), samples_per_window):
+        window_end = window_start + samples_per_window
+        voice_flags.append(
+            vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
+                          sample_rate=sampling_rate))
+    voice_flags = np.array(voice_flags)
+
+    # Smooth the voice detection with a moving average
+    def moving_average(array, width):
+        array_padded = np.concatenate((np.zeros(
+            (width - 1) // 2), array, np.zeros(width // 2)))
+        ret = np.cumsum(array_padded, dtype=float)
+        ret[width:] = ret[width:] - ret[:-width]
+        return ret[width - 1:] / width
+
+    audio_mask = moving_average(voice_flags, vad_moving_average_width)
+    audio_mask = np.round(audio_mask).astype(bool)
+
+    # Dilate the voiced regions
+    audio_mask = binary_dilation(audio_mask,
+                                 np.ones(vad_max_silence_length + 1))
+    audio_mask = np.repeat(audio_mask, samples_per_window)
+
+    return wav[audio_mask]
+
+
+def compute_partial_slices(n_samples: int,
+                           partial_utterance_n_frames: int,
+                           hop_length: int,
+                           min_pad_coverage: float = 0.75,
+                           overlap: float = 0.5):
+    """
+    Computes where to split an utterance waveform and its corresponding mel spectrogram to
+    obtain partial utterances of each. Both the waveform and the mel spectrogram slices are
+    returned, so as to make each partial utterance waveform correspond to its spectrogram.
+    This function assumes that the mel spectrogram parameters used are those passed in
+    (in particular hop_length).
+
+    The returned ranges may be indexing further than the length of the waveform. It is
+    recommended that you pad the waveform with zeros up to wav_slices[-1].stop.
+
+    :param n_samples: the number of samples in the waveform
+    :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
+    utterance
+    :param hop_length: the hop length (in samples) used to compute the mel spectrogram
+    :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
+    enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
+    then the last partial utterance will be considered, as if we padded the audio. Otherwise,
+    it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
+    utterance, this parameter is ignored so that the function always returns at least 1 slice.
+    :param overlap: by how much the partial utterances should overlap. If set to 0, the partial
+    utterances are entirely disjoint.
+    :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
+    respectively the waveform and the mel spectrogram with these slices to obtain the partial
+    utterances.
+ """ + assert 0 <= overlap < 1 + assert 0 < min_pad_coverage <= 1 + + # librosa's function to compute num_frames from num_samples + n_frames = int(np.ceil((n_samples + 1) / hop_length)) + # frame shift between ajacent partials + frame_step = max(1, + int(np.round(partial_utterance_n_frames * (1 - overlap)))) + + # Compute the slices + wav_slices, mel_slices = [], [] + steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1) + for i in range(0, steps, frame_step): + mel_range = np.array([i, i + partial_utterance_n_frames]) + wav_range = mel_range * hop_length + mel_slices.append(slice(*mel_range)) + wav_slices.append(slice(*wav_range)) + + # Evaluate whether extra padding is warranted or not + last_wav_range = wav_slices[-1] + coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - + last_wav_range.start) + if coverage < min_pad_coverage and len(mel_slices) > 1: + mel_slices = mel_slices[:-1] + wav_slices = wav_slices[:-1] + + return wav_slices, mel_slices + + +class SpeakerVerificationPreprocessor(object): + def __init__(self, + sampling_rate: int, + audio_norm_target_dBFS:float, + vad_window_length, + vad_moving_average_width, + vad_max_silence_length, + mel_window_length, + mel_window_step, + n_mels, + partial_n_frames: int, + min_pad_coverage: float = 0.75, + partial_overlap_ratio: float = 0.5): + self.sampling_rate = sampling_rate + self.audio_norm_target_dBFS = audio_norm_target_dBFS + + self.vad_window_length = vad_window_length + self.vad_moving_average_width = vad_moving_average_width + self.vad_max_silence_length = vad_max_silence_length + + self.n_fft = int(mel_window_length * sampling_rate / 1000) + self.hop_length = int(mel_window_step * sampling_rate / 1000) + self.n_mels = n_mels + + self.partial_n_frames = partial_n_frames + self.min_pad_coverage = min_pad_coverage + self.partial_overlap_ratio = partial_overlap_ratio + + def preprocess_wav(self, fpath_or_wav, source_sr=None): + # Load the wav from disk if needed + if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path): + wav, source_sr = librosa.load(str(fpath_or_wav), sr=None) + else: + wav = fpath_or_wav + + # Resample if numpy.array is passed and sr does not match + if source_sr is not None and source_sr != self.sampling_rate: + wav = librosa.resample(wav, source_sr, self.sampling_rate) + + # loudness normalization + wav = normalize_volume(wav, + self.audio_norm_target_dBFS, + increase_only=True) + + # trim long silence + if webrtcvad: + wav = trim_long_silences(wav, self.vad_window_length, + self.vad_moving_average_width, + self.vad_max_silence_length, + self.sampling_rate) + return wav + + def melspectrogram(self, wav): + mel = librosa.feature.melspectrogram(wav, + sr=self.sampling_rate, + n_fft=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels) + mel = mel.astype(np.float32).T + return mel + + def extract_mel_partials(self, wav): + wav_slices, mel_slices = compute_partial_slices( + len(wav), + self.partial_n_frames, + self.hop_length, + self.min_pad_coverage, + self.partial_overlap_ratio) + + # pad audio if needed + max_wave_length = wav_slices[-1].stop + if max_wave_length >= len(wav): + wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant") + + # Split the utterance into partials + frames = self.melspectrogram(wav) + frames_batch = np.array([frames[s] for s in mel_slices]) + return frames_batch # [B, T, C] + diff --git a/examples/ge2e/config.py b/examples/ge2e/config.py new file mode 100644 index 0000000..662f1b4 --- /dev/null +++ 
b/examples/ge2e/config.py
@@ -0,0 +1,48 @@
+from yacs.config import CfgNode
+
+_C = CfgNode()
+
+data_config = _C.data = CfgNode()
+
+## Audio volume normalization
+data_config.audio_norm_target_dBFS = -30
+
+## Audio sample rate
+data_config.sampling_rate = 16000  # Hz
+
+## Voice Activity Detection
+# Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
+# This sets the granularity of the VAD. Should not need to be changed.
+data_config.vad_window_length = 30  # In milliseconds
+# Number of frames to average together when performing the moving average smoothing.
+# The larger this value, the larger the VAD variations must be to not get smoothed out.
+data_config.vad_moving_average_width = 8
+# Maximum number of consecutive silent frames a segment can have.
+data_config.vad_max_silence_length = 6
+
+## Mel-filterbank
+data_config.mel_window_length = 25  # In milliseconds
+data_config.mel_window_step = 10  # In milliseconds
+data_config.n_mels = 40  # mel bands
+
+# Number of spectrogram frames in a partial utterance
+data_config.partial_n_frames = 160  # 1600 ms
+data_config.min_pad_coverage = 0.75  # at least 75% of the audio is valid in a partial
+data_config.partial_overlap_ratio = 0.5  # overlap ratio between adjacent partials
+
+model_config = _C.model = CfgNode()
+model_config.num_layers = 3
+model_config.hidden_size = 256
+model_config.embedding_size = 256  # output size
+
+training_config = _C.training = CfgNode()
+training_config.learning_rate_init = 1e-4
+training_config.speakers_per_batch = 64
+training_config.utterances_per_speaker = 10
+training_config.max_iteration = 1560000
+training_config.save_interval = 10000
+training_config.valid_interval = 10000
+
+
+def get_cfg_defaults():
+    return _C.clone()
diff --git a/examples/ge2e/dataset_processors.py b/examples/ge2e/dataset_processors.py
new file mode 100644
index 0000000..d892003
--- /dev/null
+++ b/examples/ge2e/dataset_processors.py
@@ -0,0 +1,139 @@
+import numpy as np
+from pathlib import Path
+import multiprocessing as mp
+from audio_processor import SpeakerVerificationPreprocessor
+from tqdm import tqdm
+from functools import partial
+from typing import List
+
+
+def _process_utterance(path_pair, processor: SpeakerVerificationPreprocessor):
+    # Load and preprocess the waveform
+    input_path, output_path = path_pair
+    wav = processor.preprocess_wav(input_path)
+    if len(wav) == 0:
+        return
+
+    # Create the mel spectrogram, discard those that are too short
+    frames = processor.melspectrogram(wav)
+    if len(frames) < processor.partial_n_frames:
+        return
+
+    np.save(output_path, frames)
+
+
+def _process_speaker(speaker_dir: Path,
+                     processor,
+                     datasets_root: Path,
+                     output_dir: Path,
+                     pattern: str,
+                     skip_existing: bool = False):
+    # datasets_root: a reference path used to compute speaker_name
+    # we prepend the dataset name to speaker_id because we are mixing several
+    # multispeaker datasets together
+    speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
+    speaker_output_dir = output_dir / speaker_name
+    speaker_output_dir.mkdir(parents=True, exist_ok=True)
+
+    # load the existing file set
+    sources_fpath = speaker_output_dir / "_sources.txt"
+    if sources_fpath.exists():
+        try:
+            with sources_fpath.open("rt") as sources_file:
+                existing_names = {line.split(",")[0] for line in sources_file}
+        except Exception:
+            existing_names = set()
+    else:
+        existing_names = set()
+
+    sources_file = sources_fpath.open("at" if skip_existing else "wt")
+    for in_fpath in speaker_dir.rglob(pattern):
+        out_name = "_".join(
+            in_fpath.relative_to(speaker_dir).with_suffix(".npy").parts)
+        if skip_existing and out_name in existing_names:
+            continue
+        out_fpath = speaker_output_dir / out_name
+        _process_utterance((in_fpath, out_fpath), processor)
+        sources_file.write(f"{out_name},{in_fpath}\n")
+
+    sources_file.close()
+
+
+def _process_dataset(processor,
+                     datasets_root: Path,
+                     speaker_dirs: List[Path],
+                     dataset_name: str,
+                     output_dir: Path,
+                     pattern: str,
+                     skip_existing: bool = False):
+    print(f"{dataset_name}: Preprocessing data for {len(speaker_dirs)} speakers.")
+
+    _func = partial(_process_speaker,
+                    processor=processor,
+                    datasets_root=datasets_root,
+                    output_dir=output_dir,
+                    pattern=pattern,
+                    skip_existing=skip_existing)
+
+    with mp.Pool(16) as pool:
+        list(
+            tqdm(pool.imap(_func, speaker_dirs),
+                 dataset_name,
+                 len(speaker_dirs),
+                 unit="speakers"))
+    print(f"Done preprocessing {dataset_name}.")
+
+
+def process_librispeech(processor,
+                        datasets_root,
+                        output_dir,
+                        skip_existing=False):
+    dataset_name = "LibriSpeech/train-other-500"
+    dataset_root = datasets_root / dataset_name
+    speaker_dirs = list(dataset_root.glob("*"))
+    _process_dataset(processor, datasets_root, speaker_dirs, dataset_name,
+                     output_dir, "*.flac", skip_existing)
+
+
+def process_voxceleb1(processor,
+                      datasets_root,
+                      output_dir,
+                      skip_existing=False):
+    dataset_name = "VoxCeleb1"
+    dataset_root = datasets_root / dataset_name
+
+    anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
+    with dataset_root.joinpath("vox1_meta.csv").open("rt") as metafile:
+        metadata = [line.strip().split("\t") for line in metafile][1:]
+
+    # speaker id -> nationality
+    nationalities = {
+        line[0]: line[3]
+        for line in metadata if line[-1] == "dev"
+    }
+    keep_speaker_ids = [
+        speaker_id for speaker_id, nationality in nationalities.items()
+        if nationality.lower() in anglophone_nationalites
+    ]
+    print(
+        "VoxCeleb1: using samples from {} (presumed anglophone) speakers out of {}.".format(
+            len(keep_speaker_ids), len(nationalities)))
+
+    speaker_dirs = list((dataset_root / "wav").glob("*"))
+    speaker_dirs = [
+        speaker_dir for speaker_dir in speaker_dirs
+        if speaker_dir.name in keep_speaker_ids
+    ]
+    _process_dataset(processor, datasets_root, speaker_dirs, dataset_name,
+                     output_dir, "*.wav", skip_existing)
+
+
+def process_voxceleb2(processor,
+                      datasets_root,
+                      output_dir,
+                      skip_existing=False):
+    dataset_name = "VoxCeleb2"
+    dataset_root = datasets_root / dataset_name
+    # There is no nationality in the metadata for VoxCeleb2
+    speaker_dirs = list((dataset_root / "wav").glob("*"))
+    _process_dataset(processor, datasets_root, speaker_dirs, dataset_name,
+                     output_dir, "*.wav", skip_existing)
diff --git a/examples/ge2e/inference.py b/examples/ge2e/inference.py
new file mode 100644
index 0000000..76ee9a5
--- /dev/null
+++ b/examples/ge2e/inference.py
@@ -0,0 +1,126 @@
+from parakeet.models.lstm_speaker_encoder import LSTMSpeakerEncoder
+from audio_processor import SpeakerVerificationPreprocessor
+from matplotlib import cm
+from pathlib import Path
+import matplotlib.pyplot as plt
+import numpy as np
+from config import get_cfg_defaults
+import argparse
+import paddle
+import tqdm
+from functools import partial
+from multiprocessing import Pool
+
+
+def embed_utterance(processor, model, fpath_or_wav):
+    # audio processor
+    wav = processor.preprocess_wav(fpath_or_wav)
+    mel_partials = processor.extract_mel_partials(wav)
+
+    # speaker encoder
+    model.eval()
+    with paddle.no_grad():
+        mel_partials = paddle.to_tensor(mel_partials)
+        embed = model.embed_utterance(mel_partials)
+    embed = embed.numpy()
+    return embed
+
+
+def _process_utterance(ifpath: Path, input_dir: Path, output_dir: Path,
+                       processor: SpeakerVerificationPreprocessor,
+                       model: LSTMSpeakerEncoder):
+    rel_path = ifpath.relative_to(input_dir)
+    ofpath = (output_dir / rel_path).with_suffix(".npy")
+    ofpath.parent.mkdir(parents=True, exist_ok=True)
+    embed = embed_utterance(processor, model, ifpath)
+    np.save(ofpath, embed)
+
+
+def main(config, args):
+    paddle.set_device(args.device)
+
+    # load model
+    model = LSTMSpeakerEncoder(config.data.n_mels, config.model.num_layers,
+                               config.model.hidden_size,
+                               config.model.embedding_size)
+    weights_fpath = str(Path(args.checkpoint_path).expanduser())
+    model_state_dict = paddle.load(weights_fpath + ".pdparams")
+    model.set_state_dict(model_state_dict)
+    model.eval()
+    print(f"Loaded encoder {weights_fpath}")
+
+    # create audio processor
+    c = config.data
+    processor = SpeakerVerificationPreprocessor(
+        sampling_rate=c.sampling_rate,
+        audio_norm_target_dBFS=c.audio_norm_target_dBFS,
+        vad_window_length=c.vad_window_length,
+        vad_moving_average_width=c.vad_moving_average_width,
+        vad_max_silence_length=c.vad_max_silence_length,
+        mel_window_length=c.mel_window_length,
+        mel_window_step=c.mel_window_step,
+        n_mels=c.n_mels,
+        partial_n_frames=c.partial_n_frames,
+        min_pad_coverage=c.min_pad_coverage,
+        partial_overlap_ratio=c.partial_overlap_ratio,
+    )
+
+    # input output preparation
+    input_dir = Path(args.input).expanduser()
+    ifpaths = list(input_dir.rglob(args.pattern))
+    print(f"{len(ifpaths)} utterances in total")
+    output_dir = Path(args.output).expanduser()
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    for ifpath in tqdm.tqdm(ifpaths, unit="utterance"):
+        _process_utterance(ifpath, input_dir, output_dir, processor, model)
+
+
+if __name__ == "__main__":
+    config = get_cfg_defaults()
+    parser = argparse.ArgumentParser(description="compute utterance embeddings.")
+    parser.add_argument(
+        "--config",
+        metavar="FILE",
+        help="path of the config file to overwrite the default config with.")
+    parser.add_argument("--input",
+                        type=str,
+                        help="path of the audio_file folder.")
+    parser.add_argument("--pattern",
+                        type=str,
+                        default="*.wav",
+                        help="pattern to filter audio files.")
+    parser.add_argument("--output",
+                        metavar="OUTPUT_DIR",
+                        help="path to save the utterance embeddings.")
+
+    # load from saved checkpoint
+    parser.add_argument("--checkpoint_path",
+                        type=str,
+                        help="path of the checkpoint to load")
+
+    # running
+    parser.add_argument("--device",
+                        type=str,
+                        choices=["cpu", "gpu"],
+                        help="device type to use, cpu and gpu are supported.")
+
+    # overwrite extra config and default config
+    parser.add_argument(
+        "--opts",
+        nargs=argparse.REMAINDER,
+        help=
+        "options to overwrite --config file and the default config, passing in KEY VALUE pairs"
+    )
+
+    args = parser.parse_args()
+    if args.config:
+        config.merge_from_file(args.config)
+    if args.opts:
+        config.merge_from_list(args.opts)
+    config.freeze()
+    print(config)
+    print(args)
+
+    main(config, args)
diff --git a/examples/ge2e/preprocess.py b/examples/ge2e/preprocess.py
new file mode 100644
index 0000000..a601715
--- /dev/null
+++ b/examples/ge2e/preprocess.py
@@ -0,0 +1,88 @@
+import argparse
+from pathlib import Path
+from config import get_cfg_defaults
+from audio_processor import SpeakerVerificationPreprocessor
+from dataset_processors import process_librispeech, process_voxceleb1, process_voxceleb2
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="preprocess dataset for speaker verification task")
+    parser.add_argument(
+        "--datasets_root",
+        type=Path,
+        help=
+        "Path to the directory containing your LibriSpeech, LibriTTS and VoxCeleb datasets."
+    )
+    parser.add_argument("--output_dir",
+                        type=Path,
+                        help="Path to save the processed dataset.")
+    parser.add_argument(
+        "--dataset_names",
+        type=str,
+        default="librispeech_other,voxceleb1,voxceleb2",
+        help=
+        "comma-separated list of names of the datasets you want to preprocess. only "
+        "the train set of these datasets will be used. Possible names: librispeech_other, "
+        "voxceleb1, voxceleb2.")
+    parser.add_argument(
+        "--skip_existing",
+        action="store_true",
+        help=
+        "Whether to skip output files with the same name. Useful if this script was interrupted."
+    )
+    parser.add_argument(
+        "--no_trim",
+        action="store_true",
+        help="Preprocess audio without trimming silences (not recommended).")
+
+    args = parser.parse_args()
+
+    if not args.no_trim:
+        try:
+            import webrtcvad
+        except ImportError:
+            raise ModuleNotFoundError(
+                "Package 'webrtcvad' not found. This package enables "
+                "noise removal and is recommended. Please install and try again. If installation fails, "
+                "use --no_trim to disable this error message.")
+    del args.no_trim
+
+    args.datasets = [item.strip() for item in args.dataset_names.split(",")]
+    if args.output_dir is None:
+        args.output_dir = args.datasets_root / "SV2TTS" / "encoder"
+
+    args.output_dir = args.output_dir.expanduser()
+    args.datasets_root = args.datasets_root.expanduser()
+    assert args.datasets_root.exists()
+    args.output_dir.mkdir(exist_ok=True, parents=True)
+
+    config = get_cfg_defaults()
+    # TODO: nice print
+    print(args)
+
+    c = config.data
+    processor = SpeakerVerificationPreprocessor(
+        sampling_rate=c.sampling_rate,
+        audio_norm_target_dBFS=c.audio_norm_target_dBFS,
+        vad_window_length=c.vad_window_length,
+        vad_moving_average_width=c.vad_moving_average_width,
+        vad_max_silence_length=c.vad_max_silence_length,
+        mel_window_length=c.mel_window_length,
+        mel_window_step=c.mel_window_step,
+        n_mels=c.n_mels,
+        partial_n_frames=c.partial_n_frames,
+        min_pad_coverage=c.min_pad_coverage,
+        partial_overlap_ratio=c.partial_overlap_ratio,
+    )
+
+    preprocess_func = {
+        "librispeech_other": process_librispeech,
+        "voxceleb1": process_voxceleb1,
+        "voxceleb2": process_voxceleb2,
+    }
+
+    for dataset in args.datasets:
+        print("Preprocessing %s" % dataset)
+        preprocess_func[dataset](processor, args.datasets_root,
+                                 args.output_dir, args.skip_existing)
diff --git a/examples/ge2e/random_cycle.py b/examples/ge2e/random_cycle.py
new file mode 100644
index 0000000..42620ae
--- /dev/null
+++ b/examples/ge2e/random_cycle.py
@@ -0,0 +1,23 @@
+import random
+
+
+def cycle(iterable):
+    # cycle('ABCD') --> A B C D A B C D A B C D ...
+    saved = []
+    for element in iterable:
+        yield element
+        saved.append(element)
+    while saved:
+        for element in saved:
+            yield element
+
+
+def random_cycle(iterable):
+    # random_cycle('ABCD') --> A B C D B C D A A D B C ...
+    saved = []
+    for element in iterable:
+        yield element
+        saved.append(element)
+    random.shuffle(saved)
+    while saved:
+        for element in saved:
+            yield element
+        random.shuffle(saved)
\ No newline at end of file
diff --git a/examples/ge2e/speaker_verification_dataset.py b/examples/ge2e/speaker_verification_dataset.py
new file mode 100644
index 0000000..bfa3601
--- /dev/null
+++ b/examples/ge2e/speaker_verification_dataset.py
@@ -0,0 +1,114 @@
+import random
+import numpy as np
+import paddle
+from paddle.io import Dataset, BatchSampler
+from pathlib import Path
+from random_cycle import random_cycle
+
+
+class MultiSpeakerMelDataset(Dataset):
+    """A 2-layer directory that contains mel spectrograms in *.npy format.
+    An example file structure tree is shown below. We prefer to preprocess
+    raw datasets and organize them like this.
+
+    dataset_root/
+      speaker1/
+        utterance1.npy
+        utterance2.npy
+        utterance3.npy
+      speaker2/
+        utterance1.npy
+        utterance2.npy
+        utterance3.npy
+    """
+    def __init__(self, dataset_root: Path):
+        self.root = Path(dataset_root).expanduser()
+        speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
+
+        speaker_utterances = {
+            speaker_dir: [f for f in speaker_dir.glob("*.npy")]
+            for speaker_dir in speaker_dirs
+        }
+
+        self.speaker_dirs = speaker_dirs
+        self.speaker_to_utterances = speaker_utterances
+
+        # meta data
+        self.num_speakers = len(self.speaker_dirs)
+        self.num_utterances = sum(
+            len(utterances)
+            for speaker, utterances in self.speaker_to_utterances.items())
+
+    def get_example_by_index(self, speaker_index, utterance_index):
+        speaker_dir = self.speaker_dirs[speaker_index]
+        fpath = self.speaker_to_utterances[speaker_dir][utterance_index]
+        return self[fpath]
+
+    def __getitem__(self, fpath):
+        return np.load(fpath)
+
+    def __len__(self):
+        return int(self.num_utterances)
+
+
+class MultiSpeakerSampler(BatchSampler):
+    """A two-level sampler designed for the speaker verification task.
+    First, N speakers are sampled randomly from all speakers. Then, for each
+    of these speakers, M utterances are sampled randomly from their utterances.
+    """
+    def __init__(self,
+                 dataset: MultiSpeakerMelDataset,
+                 speakers_per_batch: int,
+                 utterances_per_speaker: int):
+        self._speakers = list(dataset.speaker_dirs)
+        self._speaker_to_utterances = dataset.speaker_to_utterances
+
+        self.speakers_per_batch = speakers_per_batch
+        self.utterances_per_speaker = utterances_per_speaker
+
+    def __iter__(self):
+        # yield list of Paths
+        speaker_generator = iter(random_cycle(self._speakers))
+        speaker_utterances_generator = {
+            s: iter(random_cycle(us))
+            for s, us in self._speaker_to_utterances.items()
+        }
+
+        while True:
+            speakers = []
+            for _ in range(self.speakers_per_batch):
+                speakers.append(next(speaker_generator))
+
+            utterances = []
+            for s in speakers:
+                us = speaker_utterances_generator[s]
+                for _ in range(self.utterances_per_speaker):
+                    utterances.append(next(us))
+            yield utterances
+
+
+class RandomClip(object):
+    def __init__(self, frames):
+        self.frames = frames
+
+    def __call__(self, spec):
+        # spec [T, C]
+        T = spec.shape[0]
+        start = random.randint(0, T - self.frames)
+        return spec[start:start + self.frames, :]
+
+
+class Collate(object):
+    def __init__(self, num_frames):
+        self.random_crop = RandomClip(num_frames)
+
+    def __call__(self, examples):
+        frame_clips = [self.random_crop(mel) for mel in examples]
+        batched_clips = np.stack(frame_clips)
+        return batched_clips
+
+
+if __name__ == "__main__":
+    mydataset = MultiSpeakerMelDataset(
+        Path("/home/chenfeiyu/datasets/SV2TTS/encoder"))
+    print(mydataset.get_example_by_index(0, 10))
diff --git a/examples/ge2e/train.py b/examples/ge2e/train.py
new file mode 100644
index 0000000..332aec7
--- /dev/null
+++ b/examples/ge2e/train.py
@@ -0,0 +1,105 @@
+from parakeet.training import ExperimentBase, default_argument_parser
+from config import get_cfg_defaults
+from paddle import distributed as dist
+import time
+from paddle.nn.clip import ClipGradByGlobalNorm
+
+from parakeet.models.lstm_speaker_encoder import LSTMSpeakerEncoder
+from speaker_verification_dataset import MultiSpeakerMelDataset, MultiSpeakerSampler, Collate
+from paddle.optimizer import Adam
+from paddle import DataParallel
+from paddle.io import DataLoader
+
+
+class Ge2eExperiment(ExperimentBase):
+    def setup_model(self):
+        config = self.config
+        model = LSTMSpeakerEncoder(config.data.n_mels, config.model.num_layers,
+                                   config.model.hidden_size,
+                                   config.model.embedding_size)
+        optimizer = Adam(config.training.learning_rate_init,
+                         parameters=model.parameters(),
+                         grad_clip=ClipGradByGlobalNorm(3))
+        self.model = DataParallel(model) if self.parallel else model
+        self.model_core = model
+        self.optimizer = optimizer
+
+    def setup_dataloader(self):
+        config = self.config
+        train_dataset = MultiSpeakerMelDataset(self.args.data)
+        sampler = MultiSpeakerSampler(train_dataset,
+                                      config.training.speakers_per_batch,
+                                      config.training.utterances_per_speaker)
+        train_loader = DataLoader(train_dataset,
+                                  batch_sampler=sampler,
+                                  collate_fn=Collate(
+                                      config.data.partial_n_frames),
+                                  num_workers=16)
+
+        self.train_dataset = train_dataset
+        self.train_loader = train_loader
+
+    def train_batch(self):
+        start = time.time()
+        batch = self.read_batch()
+        data_loader_time = time.time() - start
+
+        self.optimizer.clear_grad()
+        self.model.train()
+        specs = batch
+        loss, eer = self.model(specs, self.config.training.speakers_per_batch)
+        loss.backward()
+        self.model_core.do_gradient_ops()
+        self.optimizer.step()
+        iteration_time = time.time() - start
+
+        # logging
+        loss_value = float(loss)
+        msg = "Rank: {}, ".format(dist.get_rank())
+        msg += "step: {}, ".format(self.iteration)
+        msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time,
+                                                  iteration_time)
+        msg += 'loss: {:>.6f} err: {:>.6f}'.format(loss_value, eer)
+        self.logger.info(msg)
+
+        if dist.get_rank() == 0:
+            self.visualizer.add_scalar("train/loss", loss_value,
+                                       self.iteration)
+            self.visualizer.add_scalar("train/eer", eer, self.iteration)
+            self.visualizer.add_scalar("param/w",
+                                       float(self.model_core.similarity_weight),
+                                       self.iteration)
+            self.visualizer.add_scalar("param/b",
+                                       float(self.model_core.similarity_bias),
+                                       self.iteration)
+
+    def valid(self):
+        pass
+
+
+def main_sp(config, args):
+    exp = Ge2eExperiment(config, args)
+    exp.setup()
+    exp.run()
+
+
+def main(config, args):
+    if args.nprocs > 1 and args.device == "gpu":
+        dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
+    else:
+        main_sp(config, args)
+
+
+if __name__ == "__main__":
+    config = get_cfg_defaults()
+    parser = default_argument_parser()
+    args = parser.parse_args()
+    if args.config:
+        config.merge_from_file(args.config)
+    if args.opts:
+        config.merge_from_list(args.opts)
+    config.freeze()
+    print(config)
+    print(args)
+
+    main(config, args)
diff --git a/parakeet/models/lstm_speaker_encoder.py b/parakeet/models/lstm_speaker_encoder.py
new file mode 100644
index 0000000..96a3667
--- /dev/null
+++ b/parakeet/models/lstm_speaker_encoder.py
@@ -0,0 +1,120 @@
+import numpy as np
+import paddle
+from paddle import nn
+from paddle.fluid.param_attr import ParamAttr
+from paddle.nn import functional as F
+from paddle.nn import initializer as I
+
+from scipy.interpolate import interp1d
+from sklearn.metrics import roc_curve
+from scipy.optimize import brentq
+
+
+class LSTMSpeakerEncoder(nn.Layer):
+    def __init__(self, n_mels, num_layers, hidden_size, output_size):
+        super().__init__()
+        self.lstm = nn.LSTM(n_mels, hidden_size, num_layers)
+        self.linear = nn.Linear(hidden_size, output_size)
+        self.similarity_weight = self.create_parameter(
+            [1], default_initializer=I.Constant(10.))
+        self.similarity_bias = self.create_parameter(
+            [1], default_initializer=I.Constant(-5.))
+
+    def forward(self, utterances, num_speakers, initial_states=None):
+        normalized_embeds = self.embed_sequences(utterances, initial_states)
+        # [N*M, C] -> [N, M, C], utterances are grouped by speaker
+        embeds = normalized_embeds.reshape(
+            [num_speakers, -1, normalized_embeds.shape[-1]])
+        loss, eer = self.loss(embeds)
+        return loss, eer
+
+    def embed_sequences(self, utterances, initial_states=None, reduce=False):
+        out, (h, c) = self.lstm(utterances, initial_states)
+        embeds = F.relu(self.linear(h[-1]))
+        normalized_embeds = F.normalize(embeds)
+        if reduce:
+            embed = paddle.mean(normalized_embeds, 0)
+            embed = F.normalize(embed, axis=0)
+            return embed
+        return normalized_embeds
+
+    def embed_utterance(self, utterances, initial_states=None):
+        # utterances: [B, T, C] -> embed [C']
+        embed = self.embed_sequences(utterances, initial_states, reduce=True)
+        return embed
+
+    def similarity_matrix(self, embeds):
+        # (N, M, C)
+        speakers_per_batch, utterances_per_speaker, embed_dim = embeds.shape
+
+        # Inclusive centroids (1 per speaker)
+        centroids_incl = paddle.mean(embeds, axis=1)
+        centroids_incl_norm = paddle.norm(centroids_incl,
+                                          p=2,
+                                          axis=1,
+                                          keepdim=True)
+        normalized_centroids_incl = centroids_incl / centroids_incl_norm
+
+        # Exclusive centroids (1 per utterance)
+        centroids_excl = paddle.broadcast_to(
+            paddle.sum(embeds, axis=1, keepdim=True), embeds.shape) - embeds
+        centroids_excl /= (utterances_per_speaker - 1)
+        centroids_excl_norm = paddle.norm(centroids_excl,
+                                          p=2,
+                                          axis=2,
+                                          keepdim=True)
+        normalized_centroids_excl = centroids_excl / centroids_excl_norm
+
+        # inter-speaker similarity: NM embeds against N inclusive centroids
+        p1 = paddle.matmul(embeds.reshape([-1, embed_dim]),
+                           normalized_centroids_incl,
+                           transpose_y=True)  # (NM, N)
+        p1 = p1.reshape([-1])
+        # intra-speaker similarity: NM embeds, 1 exclusive centroid per embed
+        p2 = paddle.bmm(embeds.reshape([-1, 1, embed_dim]),
+                        normalized_centroids_excl.reshape(
+                            [-1, embed_dim, 1]))  # (NM, 1, 1)
+        p2 = p2.reshape([-1])  # (NM)
+
+        # begin: alternative implementation for scatter
+        with paddle.no_grad():
+            index = paddle.arange(
+                0, speakers_per_batch * utterances_per_speaker,
+                dtype="int64").reshape(
+                    [speakers_per_batch, utterances_per_speaker])
+            index = index * speakers_per_batch + paddle.arange(
+                0, speakers_per_batch, dtype="int64").unsqueeze(-1)
+            index = paddle.reshape(index, [-1])
+        ones = paddle.ones(
+            [speakers_per_batch * utterances_per_speaker * speakers_per_batch])
+        zeros = paddle.zeros_like(index, dtype=ones.dtype)
+        mask_p1 = paddle.scatter(ones, index, zeros)
+        p = p1 * mask_p1 + (1 - mask_p1) * paddle.scatter(ones, index, p2)
+        # end: alternative implementation for scatter
+        # p = paddle.scatter(p1, index, p2)
+
+        p = p * self.similarity_weight + self.similarity_bias
+        p = p.reshape(
+            [speakers_per_batch * utterances_per_speaker, speakers_per_batch])
+        return p, p1, p2
+
+    def do_gradient_ops(self):
+        # Scale down the gradients of the similarity weight and bias,
+        # as prescribed by the GE2E paper.
+        for p in [self.similarity_weight, self.similarity_bias]:
+            g = p._grad_ivar()
+            g[...] = g * 0.01
+
+    def loss(self, embeds):
+        """
+        Computes the softmax loss according to section 2.1 of GE2E.
+
+        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
+        utterances_per_speaker, embedding_size)
+        :return: the loss and the EER for this batch of embeddings.
+        """
+        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
+
+        # Loss
+        sim_matrix, *_ = self.similarity_matrix(embeds)
+        sim_matrix = sim_matrix.reshape(
+            [speakers_per_batch * utterances_per_speaker, speakers_per_batch])
+        target = paddle.arange(0, speakers_per_batch,
+                               dtype="int64").unsqueeze(-1)
+        target = paddle.expand(target,
+                               [speakers_per_batch, utterances_per_speaker])
+        target = paddle.reshape(target, [-1])
+
+        loss = nn.CrossEntropyLoss()(sim_matrix, target)
+
+        # EER (not backpropagated)
+        with paddle.no_grad():
+            ground_truth = target.numpy()
+            inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=int)[0]
+            labels = np.array([inv_argmax(i) for i in ground_truth])
+            preds = sim_matrix.numpy()
+
+            # Snippet from https://yangcha.github.io/EER-ROC/
+            fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
+            eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
+
+        return loss, eer
diff --git a/parakeet/modules/g2e2_loss.py b/parakeet/modules/g2e2_loss.py
new file mode 100644
index 0000000..b694dec
--- /dev/null
+++ b/parakeet/modules/g2e2_loss.py
@@ -0,0 +1,77 @@
+import paddle
+from paddle import nn
+from paddle.nn import functional as F
+from paddle.nn import initializer as I
+
+
+class GE2ELoss(nn.Layer):
+    def __init__(self, init_w=10., init_b=-5.):
+        super().__init__()
+        self.w = self.create_parameter(
+            [1], default_initializer=I.Constant(init_w))
+        self.b = self.create_parameter(
+            [1], default_initializer=I.Constant(init_b))
+
+    def forward(self, embeds):
+        # embeds [N, M, C]
+        # N - speakers_per_batch
+        # M - utterances_per_speaker
+        # C - embed_dim
+        sim_matrix = self._build_sim_matrix(embeds)
+        _, M, N = sim_matrix.shape
+        target = paddle.arange(0, N, dtype="int64").unsqueeze(-1)
+        target = paddle.expand(target, [N, M])
+        target = paddle.reshape(target, [-1])  # [NM]
+
+        criterion = nn.CrossEntropyLoss()
+        loss = criterion(sim_matrix.reshape([-1, N]), target)
+        return loss
+
+    def _build_sim_matrix(self, embeds):
+        N, M, C = embeds.shape
+
+        # Inclusive centroids (1 per speaker). [N, C]
+        centroids_incl = paddle.mean(embeds, axis=1)
+        centroids_incl_norm = paddle.norm(centroids_incl,
+                                          p=2,
+                                          axis=1,
+                                          keepdim=True)
+        normalized_centroids_incl = centroids_incl / centroids_incl_norm
+
+        # Exclusive centroids (1 per utterance) [N, M, C]
+        centroids_excl = paddle.broadcast_to(
+            paddle.sum(embeds, axis=1, keepdim=True), embeds.shape) - embeds
+        centroids_excl /= (M - 1)
+        centroids_excl_norm = paddle.norm(centroids_excl,
+                                          p=2,
+                                          axis=2,
+                                          keepdim=True)
+        normalized_centroids_excl = centroids_excl / centroids_excl_norm
+
+        # inter-speaker similarity, NM embeds ~ N centroids
+        # [NM, N]
+        p1 = paddle.matmul(embeds.reshape([-1, C]),
+                           normalized_centroids_incl,
+                           transpose_y=True)
+        p1 = p1.reshape([-1])  # [NMN]
+
+        # intra-speaker similarity, NM embeds, 1 centroid per embed
+        p2 = paddle.bmm(embeds.reshape([-1, 1, C]),
+                        normalized_centroids_excl.reshape([-1, C,
+                                                           1]))  # (NM, 1, 1)
+        p2 = p2.reshape([-1])  # [NM]
+
+        with paddle.no_grad():
+            index = paddle.arange(0, N * M, dtype="int64").reshape([N, M])
+            index = index * N + paddle.arange(0, N,
+                                              dtype="int64").unsqueeze(-1)
+            index = paddle.reshape(index, [-1])
+        # begin: alternative implementation for scatter
+        ones = paddle.ones([N * M * N])
+        zeros = paddle.zeros_like(index, dtype=ones.dtype)
+        mask_p1 = paddle.scatter(ones, index, zeros)
+        p = p1 * mask_p1 + (1 - mask_p1) * paddle.scatter(ones, index, p2)
+        # end: alternative implementation for scatter
+        # p = paddle.scatter(p1, index, p2)  there is a backward bug in scatter
+
+        p = p * self.w + self.b
+        p = p.reshape([N, M, N])
+        return p
diff --git a/parakeet/training/experiment.py b/parakeet/training/experiment.py
index bc08f4e..420df46 100644
--- a/parakeet/training/experiment.py
+++ b/parakeet/training/experiment.py
@@ -165,7 +165,7 @@ class ExperimentBase(object):
         """Reset the train loader and increment ``epoch``.
         """
         self.epoch += 1
-        if self.parallel:
+        if self.parallel and isinstance(self.train_loader.batch_sampler, DistributedBatchSampler):
             self.train_loader.batch_sampler.set_epoch(self.epoch)
         self.iterator = iter(self.train_loader)
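
Usage note (reviewer sketch, not part of the patch): the snippet below shows how the new preprocessor and encoder are meant to be combined to embed a single utterance, assuming it is run from examples/ge2e/ with a trained checkpoint available. The checkpoint and wav paths are hypothetical placeholders.

    import paddle
    from config import get_cfg_defaults
    from audio_processor import SpeakerVerificationPreprocessor
    from parakeet.models.lstm_speaker_encoder import LSTMSpeakerEncoder

    config = get_cfg_defaults()
    c = config.data

    # build the preprocessor from the data config added in this patch
    processor = SpeakerVerificationPreprocessor(
        sampling_rate=c.sampling_rate,
        audio_norm_target_dBFS=c.audio_norm_target_dBFS,
        vad_window_length=c.vad_window_length,
        vad_moving_average_width=c.vad_moving_average_width,
        vad_max_silence_length=c.vad_max_silence_length,
        mel_window_length=c.mel_window_length,
        mel_window_step=c.mel_window_step,
        n_mels=c.n_mels,
        partial_n_frames=c.partial_n_frames,
        min_pad_coverage=c.min_pad_coverage,
        partial_overlap_ratio=c.partial_overlap_ratio)

    # build the encoder and load trained weights (hypothetical path)
    model = LSTMSpeakerEncoder(c.n_mels, config.model.num_layers,
                               config.model.hidden_size,
                               config.model.embedding_size)
    model.set_state_dict(paddle.load("ge2e_checkpoint.pdparams"))
    model.eval()

    # wav -> mel partials -> averaged, L2-normalized speaker embedding
    wav = processor.preprocess_wav("some_utterance.wav")  # hypothetical path
    mel_partials = paddle.to_tensor(processor.extract_mel_partials(wav))  # [B, T, n_mels]
    with paddle.no_grad():
        embed = model.embed_utterance(mel_partials).numpy()  # [embedding_size]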