# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
from scipy.interpolate import interp1d
from scipy.optimize import brentq
from sklearn.metrics import roc_curve


class LSTMSpeakerEncoder(nn.Layer):
    """LSTM-based speaker encoder trained with a GE2E-style softmax loss.

    An LSTM stack followed by a linear projection and ReLU maps a sequence of
    mel frames to a single L2-normalized embedding. Training compares every
    utterance embedding with per-speaker centroids through a learnable affine
    transform (``similarity_weight`` / ``similarity_bias``) of the similarity.
    """

    def __init__(self, n_mels, num_layers, hidden_size, output_size):
        """Build the encoder.

        :param n_mels: input feature size passed to the LSTM.
        :param num_layers: number of stacked LSTM layers.
        :param hidden_size: hidden size of the LSTM.
        :param output_size: size of the embedding produced by the projection.
        """
        super().__init__()
        self.lstm = nn.LSTM(n_mels, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, output_size)
        # Learnable scale and offset applied to the similarity matrix.
        self.similarity_weight = self.create_parameter(
            [1], default_initializer=I.Constant(10.))
        self.similarity_bias = self.create_parameter(
            [1], default_initializer=I.Constant(-5.))

    def forward(self, utterances, num_speakers, initial_states=None):
        """Embed a batch of utterances and compute the training loss and EER.

        :param utterances: batch of utterances fed to the LSTM; the batch is
            expected to be ordered speaker by speaker with the same number of
            utterances per speaker so it can be regrouped below.
        :param num_speakers: number of speakers in the batch.
        :param initial_states: optional initial LSTM states.
        :return: the loss and the EER for this batch.
        """
        normalized_embeds = self.embed_sequences(utterances, initial_states)
        # Regroup to (speakers, utterances_per_speaker, embed_dim). The
        # trailing dimension is read from the embeddings themselves so the
        # grouping is right for any ``output_size``, not only when it happens
        # to equal ``num_speakers``.
        embed_dim = normalized_embeds.shape[-1]
        embeds = normalized_embeds.reshape([num_speakers, -1, embed_dim])
        loss, eer = self.loss(embeds)
        return loss, eer

    def embed_sequences(self, utterances, initial_states=None, reduce=False):
        """Compute L2-normalized embeddings for a batch of utterances.

        :param utterances: batch of utterances fed to the LSTM.
        :param initial_states: optional initial LSTM states.
        :param reduce: when True, average the per-utterance embeddings over
            axis 0 and re-normalize, returning one embedding for the batch.
        :return: per-utterance normalized embeddings, or a single normalized
            embedding when ``reduce`` is True.
        """
        out, (h, c) = self.lstm(utterances, initial_states)
        # h[-1]: last entry of the final hidden states (the top LSTM layer
        # in paddle.nn.LSTM's documented layout).
        embeds = F.relu(self.linear(h[-1]))
        normalized_embeds = F.normalize(embeds)
        if reduce:
            embed = paddle.mean(normalized_embeds, 0)
            # The mean of unit vectors is generally not unit length.
            embed = F.normalize(embed, axis=0)
            return embed
        return normalized_embeds

    def embed_utterance(self, utterances, initial_states=None):
        """Return one embedding summarizing all given utterances.

        utterances: [B, T, C] -> embed [C']
        """
        embed = self.embed_sequences(utterances, initial_states, reduce=True)
        return embed

    def similarity_matrix(self, embeds):
        """Compute the scaled similarity of each utterance to each centroid.

        For an utterance and its own speaker, the centroid excludes that
        utterance; for every other speaker the centroid includes all of that
        speaker's utterances.

        :param embeds: the embeddings as a tensor of shape (N, M, C) =
            (speakers_per_batch, utterances_per_speaker, embedding_size).
        :return: a tuple ``(p, p1, p2)`` where ``p`` is the (N * M, N)
            similarity matrix after the learnable scale and bias, ``p1`` is
            the flattened similarity to the inclusive centroids, and ``p2``
            is the flattened similarity to the exclusive centroids.
        """
        # (N, M, C)
        speakers_per_batch, utterances_per_speaker, embed_dim = embeds.shape

        # Inclusive centroids (1 per speaker).
        centroids_incl = paddle.mean(embeds, axis=1)
        centroids_incl_norm = paddle.norm(
            centroids_incl, p=2, axis=1, keepdim=True)
        normalized_centroids_incl = centroids_incl / centroids_incl_norm

        # Exclusive centroids (1 per utterance): the speaker's sum minus the
        # utterance itself, averaged over the remaining M - 1 utterances.
        centroids_excl = paddle.broadcast_to(
            paddle.sum(embeds, axis=1, keepdim=True), embeds.shape) - embeds
        centroids_excl /= (utterances_per_speaker - 1)
        centroids_excl_norm = paddle.norm(
            centroids_excl, p=2, axis=2, keepdim=True)
        normalized_centroids_excl = centroids_excl / centroids_excl_norm

        # Similarity of every utterance to every inclusive centroid.
        p1 = paddle.matmul(
            embeds.reshape([-1, embed_dim]),
            normalized_centroids_incl,
            transpose_y=True)  # (NM, N)
        p1 = p1.reshape([-1])  # (NMN)

        # Similarity of every utterance to its own exclusive centroid.
        p2 = paddle.bmm(
            embeds.reshape([-1, 1, embed_dim]),
            normalized_centroids_excl.reshape([-1, embed_dim, 1]))  # (NM, 1, 1)
        p2 = p2.reshape([-1])  # (NM)

        # begin: alternative implementation for scatter
        # (stands in for the direct form `p = paddle.scatter(p1, index, p2)`)
        with paddle.no_grad():
            # Flat positions, in the (NM, N) matrix, of each utterance's
            # own-speaker column: (j * M + i) * N + j.
            index = paddle.arange(
                0, speakers_per_batch * utterances_per_speaker,
                dtype="int64").reshape(
                    [speakers_per_batch, utterances_per_speaker])
            index = index * speakers_per_batch + paddle.arange(
                0, speakers_per_batch, dtype="int64").unsqueeze(-1)
            index = paddle.reshape(index, [-1])
            ones = paddle.ones(
                [speakers_per_batch * utterances_per_speaker *
                 speakers_per_batch])
            zeros = paddle.zeros_like(index, dtype=ones.dtype)
            # 1 everywhere except at the own-speaker positions.
            mask_p1 = paddle.scatter(ones, index, zeros)
        # Keep p1 off the own-speaker positions and take p2 on them.
        p = p1 * mask_p1 + (1 - mask_p1) * paddle.scatter(ones, index, p2)
        # end: alternative implementation for scatter

        p = p * self.similarity_weight + self.similarity_bias
        p = p.reshape(
            [speakers_per_batch * utterances_per_speaker, speakers_per_batch])
        return p, p1, p2

    def do_gradient_ops(self):
        """Scale the gradients of the similarity weight and bias by 0.01.

        The gradients are modified in place.
        """
        for p in [self.similarity_weight, self.similarity_bias]:
            g = p._grad_ivar()
            g[...] = g * 0.01

    def inv_argmax(self, i, num):
        """Return a length-``num`` integer one-hot vector with a 1 at ``i``."""
        # Builtin ``int`` replaces the ``np.int`` alias, which NumPy removed.
        return np.eye(1, num, i, dtype=int)[0]

    def loss(self, embeds):
        """
        Computes the softmax loss according the section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the loss and the EER for this batch of embeddings.
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Loss
        # similarity_matrix already returns shape (N * M, N).
        sim_matrix, *_ = self.similarity_matrix(embeds)
        # Row j * M + i belongs to speaker j, so the target is j repeated M
        # times for each speaker.
        target = paddle.arange(
            0, speakers_per_batch, dtype="int64").unsqueeze(-1)
        target = paddle.expand(target,
                               [speakers_per_batch, utterances_per_speaker])
        target = paddle.reshape(target, [-1])
        loss = nn.CrossEntropyLoss()(sim_matrix, target)

        # EER (not backpropagated)
        with paddle.no_grad():
            ground_truth = target.numpy()
            labels = np.array(
                [self.inv_argmax(i, speakers_per_batch) for i in ground_truth])
            preds = sim_matrix.numpy()

            # Snippet from https://yangcha.github.io/EER-ROC/
            fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
            # Build the interpolator once rather than on every brentq step.
            roc = interp1d(fpr, tpr)
            eer = brentq(lambda x: 1. - x - roc(x), 0., 1.)

        return loss, eer