diff --git a/examples/ge2e/dataset_processors.py b/examples/ge2e/dataset_processors.py
index d892003..044097f 100644
--- a/examples/ge2e/dataset_processors.py
+++ b/examples/ge2e/dataset_processors.py
@@ -137,3 +137,26 @@ def process_voxceleb2(processor,
     speaker_dirs = list((dataset_root / "wav").glob("*"))
     _process_dataset(processor, datasets_root, speaker_dirs, dataset_name,
                      output_dir, "*.wav", skip_existing)
+
+def process_aidatatang_200zh(processor,
+                             datasets_root,
+                             output_dir,
+                             skip_existing=False):
+    dataset_name = "aidatatang_200zh/train"
+    dataset_root = datasets_root / dataset_name
+
+    speaker_dirs = list((dataset_root).glob("*"))
+    _process_dataset(processor, datasets_root, speaker_dirs, dataset_name,
+                     output_dir, "*.wav", skip_existing)
+
+
+def process_magicdata(processor,
+                      datasets_root,
+                      output_dir,
+                      skip_existing=False):
+    dataset_name = "magicdata/train"
+    dataset_root = datasets_root / dataset_name
+
+    speaker_dirs = list((dataset_root).glob("*"))
+    _process_dataset(processor, datasets_root, speaker_dirs, dataset_name,
+                     output_dir, "*.wav", skip_existing)
diff --git a/examples/ge2e/preprocess.py b/examples/ge2e/preprocess.py
index a601715..b5e2b0f 100644
--- a/examples/ge2e/preprocess.py
+++ b/examples/ge2e/preprocess.py
@@ -2,7 +2,7 @@ import argparse
 from pathlib import Path
 from config import get_cfg_defaults
 from audio_processor import SpeakerVerificationPreprocessor
-from dataset_processors import process_librispeech, process_voxceleb1, process_voxceleb2
+from dataset_processors import process_librispeech, process_voxceleb1, process_voxceleb2, process_aidatatang_200zh, process_magicdata
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
@@ -23,7 +23,7 @@ if __name__ == "__main__":
         help=
         "comma-separated list of names of the datasets you want to preprocess. only "
         "the train set of these datastes will be used. Possible names: librispeech_other, "
-        "voxceleb1, voxceleb2.")
+        "voxceleb1, voxceleb2, aidatatang_200zh, magicdata.")
     parser.add_argument(
         "--skip_existing",
         action="store_true",
@@ -79,6 +79,8 @@ if __name__ == "__main__":
         "librispeech_other": process_librispeech,
         "voxceleb1": process_voxceleb1,
         "voxceleb2": process_voxceleb2,
+        "aidatatang_200zh": process_aidatatang_200zh,
+        "magicdata": process_magicdata,
     }
 
     for dataset in args.datasets:
diff --git a/parakeet/models/lstm_speaker_encoder.py b/parakeet/models/lstm_speaker_encoder.py
index 0d3f285..82193b7 100644
--- a/parakeet/models/lstm_speaker_encoder.py
+++ b/parakeet/models/lstm_speaker_encoder.py
@@ -32,8 +32,9 @@ class LSTMSpeakerEncoder(nn.Layer):
         normalized_embeds = F.normalize(embeds)
         if reduce:
             embed = paddle.mean(normalized_embeds, 0)
-        embed = F.normalize(embed, axis=0)
-        return embed
+            embed = F.normalize(embed, axis=0)
+            return embed
+        return normalized_embeds
 
     def embed_utterance(self, utterances, initial_states=None):
         # utterances: [B, T, C] -> embed [C']