2021-05-13 16:22:56 +08:00
|
|
|
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
2021-03-27 17:39:37 +08:00
|
|
|
|
import numpy as np
|
|
|
|
|
import paddle
|
|
|
|
|
from paddle import nn
|
|
|
|
|
from paddle.fluid.param_attr import ParamAttr
|
|
|
|
|
from paddle.nn import functional as F
|
|
|
|
|
from paddle.nn import initializer as I
|
|
|
|
|
|
|
|
|
|
from scipy.interpolate import interp1d
|
|
|
|
|
from sklearn.metrics import roc_curve
|
|
|
|
|
from scipy.optimize import brentq
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LSTMSpeakerEncoder(nn.Layer):
    """LSTM speaker encoder trained with the GE2E softmax loss.

    A stacked LSTM consumes frame features; the final hidden state of the
    last LSTM layer is projected by a linear layer, passed through ReLU and
    L2-normalized, giving one embedding per utterance.

    Args:
        n_mels: number of input features per frame.
        num_layers: number of stacked LSTM layers.
        hidden_size: LSTM hidden size.
        output_size: dimension of the speaker embedding.
    """

    def __init__(self, n_mels, num_layers, hidden_size, output_size):
        super().__init__()
        self.lstm = nn.LSTM(n_mels, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, output_size)
        # Learnable scale (w) and offset (b) applied to the similarities,
        # initialized to 10 and -5 respectively.
        self.similarity_weight = self.create_parameter(
            [1], default_initializer=I.Constant(10.))
        self.similarity_bias = self.create_parameter(
            [1], default_initializer=I.Constant(-5.))

    def forward(self, utterances, num_speakers, initial_states=None):
        """Embed a batch of utterances and compute the GE2E loss and EER.

        Args:
            utterances: LSTM input batch. Assumed to hold
                num_speakers * utterances_per_speaker utterances with each
                speaker's utterances contiguous (TODO confirm with caller).
            num_speakers: number of distinct speakers in the batch.
            initial_states: optional initial LSTM states.

        Returns:
            (loss, eer) as returned by ``loss``.
        """
        normalized_embeds = self.embed_sequences(utterances, initial_states)
        # Group as (speakers, utterances_per_speaker, embed_dim); the last
        # axis is the embedding size, the utterance axis is inferred.
        embed_dim = normalized_embeds.shape[-1]
        embeds = normalized_embeds.reshape([num_speakers, -1, embed_dim])
        loss, eer = self.loss(embeds)
        return loss, eer

    def embed_sequences(self, utterances, initial_states=None, reduce=False):
        """Compute L2-normalized embeddings for a batch of utterances.

        Args:
            utterances: LSTM input, [B, T, C].
            initial_states: optional initial LSTM states.
            reduce: if True, average the per-utterance embeddings over the
                batch axis and re-normalize, returning a single embedding.

        Returns:
            Row-normalized embeddings of size output_size per utterance, or a
            single normalized embedding when ``reduce`` is True.
        """
        _, (h, _) = self.lstm(utterances, initial_states)
        # h[-1]: final hidden state of the last LSTM layer
        # (presumably [B, hidden_size]).
        embeds = F.relu(self.linear(h[-1]))
        normalized_embeds = F.normalize(embeds)
        if not reduce:
            return normalized_embeds
        # Mean of unit vectors is not unit-norm; normalize again.
        embed = paddle.mean(normalized_embeds, 0)
        return F.normalize(embed, axis=0)

    def embed_utterance(self, utterances, initial_states=None):
        """Return one normalized embedding for a batch of partial utterances.

        utterances: [B, T, C] -> embed [C']
        """
        return self.embed_sequences(utterances, initial_states, reduce=True)

    def similarity_matrix(self, embeds):
        """Compute the scaled similarity matrix of GE2E (section 2.1).

        Every utterance is compared with every speaker centroid. Against its
        own speaker the centroid excludes the utterance itself (exclusive
        centroid); against other speakers the inclusive centroid is used.

        Args:
            embeds: (N speakers, M utterances, D) embeddings, presumably
                unit-norm as produced by ``embed_sequences``.

        Returns:
            p: (N*M, N) scaled similarities ``w * s + b``.
            p1: (N*M*N,) unscaled scores against inclusive centroids.
            p2: (N*M,) unscaled scores against exclusive centroids.

        Raises:
            ValueError: if M < 2, since the exclusive centroid is undefined.
        """
        speakers_per_batch, utterances_per_speaker, embed_dim = embeds.shape
        if utterances_per_speaker < 2:
            raise ValueError(
                "GE2E needs at least 2 utterances per speaker, got {}".format(
                    utterances_per_speaker))

        # Inclusive centroids (1 per speaker), L2-normalized: (N, D)
        centroids_incl = paddle.mean(embeds, axis=1)
        centroids_incl_norm = paddle.norm(
            centroids_incl, p=2, axis=1, keepdim=True)
        normalized_centroids_incl = centroids_incl / centroids_incl_norm

        # Exclusive centroids (1 per utterance), L2-normalized: (N, M, D)
        centroids_excl = paddle.broadcast_to(
            paddle.sum(embeds, axis=1, keepdim=True), embeds.shape) - embeds
        centroids_excl /= (utterances_per_speaker - 1)
        centroids_excl_norm = paddle.norm(
            centroids_excl, p=2, axis=2, keepdim=True)
        normalized_centroids_excl = centroids_excl / centroids_excl_norm

        # Every utterance vs. every inclusive centroid: (N*M, N) -> flat
        p1 = paddle.matmul(
            embeds.reshape([-1, embed_dim]),
            normalized_centroids_incl,
            transpose_y=True)
        p1 = p1.reshape([-1])

        # Every utterance vs. its own exclusive centroid: (N*M, 1, 1) -> (N*M,)
        p2 = paddle.bmm(
            embeds.reshape([-1, 1, embed_dim]),
            normalized_centroids_excl.reshape([-1, embed_dim, 1]))
        p2 = p2.reshape([-1])

        # Overwrite the (utterance, own speaker) entries of p1 with p2.
        # Equivalent to paddle.scatter(p1, index, p2), expressed with a 0/1
        # mask so both p1 and p2 enter through plain arithmetic.
        with paddle.no_grad():
            # Flat position in p1 of row (j * M + i), column j.
            rows = paddle.arange(
                0, speakers_per_batch * utterances_per_speaker,
                dtype="int64").reshape(
                    [speakers_per_batch, utterances_per_speaker])
            own_speaker = paddle.arange(
                0, speakers_per_batch, dtype="int64").unsqueeze(-1)
            index = paddle.reshape(
                rows * speakers_per_batch + own_speaker, [-1])
        ones = paddle.ones([
            speakers_per_batch * utterances_per_speaker * speakers_per_batch
        ])
        zeros = paddle.zeros_like(index, dtype=ones.dtype)
        mask_p1 = paddle.scatter(ones, index, zeros)
        p = p1 * mask_p1 + (1 - mask_p1) * paddle.scatter(ones, index, p2)

        # Learnable affine scaling of the similarities.
        p = p * self.similarity_weight + self.similarity_bias
        p = p.reshape(
            [speakers_per_batch * utterances_per_speaker, speakers_per_batch])
        return p, p1, p2

    def do_gradient_ops(self):
        """Scale the gradients of the similarity weight and bias by 0.01.

        Modifies the gradients in place; presumably meant to run after
        ``backward()`` and before the optimizer step. Parameters that have
        no gradient yet are skipped.
        """
        for param in [self.similarity_weight, self.similarity_bias]:
            grad = param._grad_ivar()
            if grad is None:
                continue
            grad[...] = grad * 0.01

    @staticmethod
    def _equal_error_rate(labels, scores):
        """Return the EER: the ROC point where the FPR equals 1 - TPR.

        Args:
            labels: flat 0/1 ground-truth array.
            scores: flat array of scores aligned with ``labels``.
        """
        # Snippet from https://yangcha.github.io/EER-ROC/
        fpr, tpr, _ = roc_curve(labels, scores)
        # Build the interpolant once rather than on every brentq iteration.
        tpr_at = interp1d(fpr, tpr)
        return brentq(lambda x: 1. - x - tpr_at(x), 0., 1.)

    def loss(self, embeds):
        """
        Computes the softmax loss according the section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the loss and the EER for this batch of embeddings.
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Loss: each utterance of speaker j should be classified as j.
        # similarity_matrix already returns shape (N*M, N).
        sim_matrix, *_ = self.similarity_matrix(embeds)
        target = paddle.arange(
            0, speakers_per_batch, dtype="int64").unsqueeze(-1)
        target = paddle.expand(target,
                               [speakers_per_batch, utterances_per_speaker])
        target = paddle.reshape(target, [-1])
        loss = nn.CrossEntropyLoss()(sim_matrix, target)

        # EER (not backpropagated)
        with paddle.no_grad():
            ground_truth = target.numpy()
            # One-hot rows: labels[r, k] == 1 iff utterance r is speaker k.
            # np.int was removed in NumPy 1.24; use an explicit dtype.
            labels = np.eye(speakers_per_batch, dtype=np.int64)[ground_truth]
            preds = sim_matrix.numpy()
            eer = self._equal_error_rate(labels.flatten(), preds.flatten())

        return loss, eer
|