122 lines
5.3 KiB
Python
122 lines
5.3 KiB
Python
import numpy as np
|
||
import paddle
|
||
from paddle import nn
|
||
from paddle.fluid.param_attr import ParamAttr
|
||
from paddle.nn import functional as F
|
||
from paddle.nn import initializer as I
|
||
|
||
from scipy.interpolate import interp1d
|
||
from sklearn.metrics import roc_curve
|
||
from scipy.optimize import brentq
|
||
|
||
|
||
class LSTMSpeakerEncoder(nn.Layer):
    """LSTM-based speaker encoder trained with the GE2E softmax loss.

    Utterances are run through a multi-layer LSTM; the last layer's final
    hidden state is projected by a linear layer, passed through ReLU and
    L2-normalized to give one embedding per utterance.

    Args:
        n_mels: Feature size of each input frame (input size of the LSTM).
        num_layers: Number of stacked LSTM layers.
        hidden_size: Hidden size of the LSTM.
        output_size: Size of the produced speaker embedding.
    """

    def __init__(self, n_mels, num_layers, hidden_size, output_size):
        super().__init__()
        self.lstm = nn.LSTM(n_mels, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, output_size)
        # Learnable affine transform applied to the cosine similarities
        # before the softmax (see ``similarity_matrix``).
        self.similarity_weight = self.create_parameter(
            [1], default_initializer=I.Constant(10.))
        self.similarity_bias = self.create_parameter(
            [1], default_initializer=I.Constant(-5.))

    def forward(self, utterances, num_speakers, initial_states=None):
        """Embed a batch of utterances and compute the training loss.

        Args:
            utterances: Batch of utterances, grouped by speaker so that the
                first axis can be reshaped to
                ``(num_speakers, utterances_per_speaker)``.
            num_speakers: Number of distinct speakers in the batch.
            initial_states: Optional initial states forwarded to the LSTM.

        Returns:
            Tuple ``(loss, eer)`` as returned by :meth:`loss`.
        """
        normalized_embeds = self.embed_sequences(utterances, initial_states)
        # The last axis must stay the embedding size; using ``num_speakers``
        # there would scramble embeddings whenever the two sizes differ.
        embed_dim = normalized_embeds.shape[-1]
        embeds = normalized_embeds.reshape([num_speakers, -1, embed_dim])
        loss, eer = self.loss(embeds)
        return loss, eer

    def embed_sequences(self, utterances, initial_states=None, reduce=False):
        """Compute L2-normalized embeddings for a batch of sequences.

        Args:
            utterances: Input batch fed to the LSTM.
            initial_states: Optional initial states forwarded to the LSTM.
            reduce: When true, average the per-sequence embeddings over the
                batch axis and re-normalize, returning a single embedding.

        Returns:
            One normalized embedding per sequence, or a single normalized
            embedding when ``reduce`` is true.
        """
        _, (h, _) = self.lstm(utterances, initial_states)
        # h[-1] is the final hidden state of the last LSTM layer.
        embeds = F.relu(self.linear(h[-1]))
        normalized_embeds = F.normalize(embeds)
        if not reduce:
            return normalized_embeds

        embed = paddle.mean(normalized_embeds, 0)
        return F.normalize(embed, axis=0)

    def embed_utterance(self, utterances, initial_states=None):
        """Return a single embedding for a batch of partial utterances.

        utterances: [B, T, C] -> embed [C']
        """
        return self.embed_sequences(utterances, initial_states, reduce=True)

    def similarity_matrix(self, embeds):
        """Compute the scaled similarity of every utterance to every centroid.

        For an utterance's own speaker, the centroid excludes that utterance;
        for every other speaker, the centroid includes all of that speaker's
        utterances.

        Args:
            embeds: Tensor of shape (speakers_per_batch,
                utterances_per_speaker, embedding_size). The exclusive
                centroid divides by ``utterances_per_speaker - 1``, so at
                least two utterances per speaker are required.

        Returns:
            Tuple ``(p, p1, p2)`` where ``p`` has shape
            (speakers_per_batch * utterances_per_speaker, speakers_per_batch)
            and holds ``similarity * weight + bias``; ``p1`` is the flattened
            similarity to the inclusive centroids and ``p2`` the flattened
            similarity of each utterance to its own exclusive centroid.
        """
        # (N, M, C)
        speakers_per_batch, utterances_per_speaker, embed_dim = embeds.shape
        num_utterances = speakers_per_batch * utterances_per_speaker

        # Inclusive centroids (1 per speaker), L2-normalized.
        centroids_incl = paddle.mean(embeds, axis=1)
        centroids_incl_norm = paddle.norm(
            centroids_incl, p=2, axis=1, keepdim=True)
        normalized_centroids_incl = centroids_incl / centroids_incl_norm

        # Exclusive centroids (1 per utterance): the speaker's sum minus the
        # utterance itself, averaged over the remaining utterances.
        speaker_sums = paddle.broadcast_to(
            paddle.sum(embeds, axis=1, keepdim=True), embeds.shape)
        centroids_excl = (speaker_sums - embeds) / (utterances_per_speaker - 1)
        centroids_excl_norm = paddle.norm(
            centroids_excl, p=2, axis=2, keepdim=True)
        normalized_centroids_excl = centroids_excl / centroids_excl_norm

        # Similarity of every utterance to every inclusive centroid.
        p1 = paddle.matmul(
            embeds.reshape([-1, embed_dim]),
            normalized_centroids_incl,
            transpose_y=True)  # (NM, N)
        p1 = p1.reshape([-1])  # (NMN)

        # Similarity of every utterance to its own exclusive centroid.
        p2 = paddle.bmm(
            embeds.reshape([-1, 1, embed_dim]),
            normalized_centroids_excl.reshape([-1, embed_dim, 1]))  # (NM, 1, 1)
        p2 = p2.reshape([-1])  # (NM)

        # Replace the own-speaker entries of p1 with p2. This is the effect
        # of ``paddle.scatter(p1, index, p2)``, expressed as a masked blend.
        with paddle.no_grad():
            # Flat position of (utterance row, own-speaker column) in the
            # flattened (NM, N) matrix: row * N + speaker.
            index = paddle.arange(
                0, num_utterances, dtype="int64").reshape(
                    [speakers_per_batch, utterances_per_speaker])
            index = index * speakers_per_batch + paddle.arange(
                0, speakers_per_batch, dtype="int64").unsqueeze(-1)
            index = paddle.reshape(index, [-1])
            ones = paddle.ones([num_utterances * speakers_per_batch])
            zeros = paddle.zeros_like(index, dtype=ones.dtype)
            # 0 at own-speaker positions, 1 everywhere else.
            mask_p1 = paddle.scatter(ones, index, zeros)
        p = p1 * mask_p1 + (1 - mask_p1) * paddle.scatter(ones, index, p2)

        p = p * self.similarity_weight + self.similarity_bias  # neg
        p = p.reshape([num_utterances, speakers_per_batch])
        return p, p1, p2

    def do_gradient_ops(self):
        """Scale the gradients of the similarity weight and bias by 0.01.

        Parameters that currently have no gradient are skipped.
        """
        for p in [self.similarity_weight, self.similarity_bias]:
            g = p._grad_ivar()
            if g is None:
                continue
            g[...] = g * 0.01

    def loss(self, embeds):
        """
        Computes the softmax loss according the section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the loss and the EER for this batch of embeddings.
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Loss: each utterance row is classified against its own speaker.
        # ``similarity_matrix`` already returns shape (N * M, N).
        sim_matrix, *_ = self.similarity_matrix(embeds)
        target = paddle.arange(
            0, speakers_per_batch, dtype="int64").unsqueeze(-1)
        target = paddle.expand(
            target, [speakers_per_batch, utterances_per_speaker])
        target = paddle.reshape(target, [-1])

        loss = nn.CrossEntropyLoss()(sim_matrix, target)

        # EER (not backpropagated)
        with paddle.no_grad():
            ground_truth = target.numpy()
            # One-hot rows; builtin ``int`` replaces the ``np.int`` alias
            # that was removed from NumPy.
            labels = np.eye(speakers_per_batch, dtype=int)[ground_truth]
            preds = sim_matrix.numpy()

            # Snippet from https://yangcha.github.io/EER-ROC/
            fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
            # EER is the point on the ROC curve where FPR == 1 - TPR.
            eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)

        return loss, eer
|
||
|
||
|