ParakeetRebeccaRosario/parakeet/models/lstm_speaker_encoder.py

150 lines
6.1 KiB
Python
Raw Normal View History

2021-05-13 16:22:56 +08:00
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2021-03-27 17:39:37 +08:00
import numpy as np
import paddle
from paddle import nn
from paddle.fluid.param_attr import ParamAttr
from paddle.nn import functional as F
from paddle.nn import initializer as I
from scipy.interpolate import interp1d
from sklearn.metrics import roc_curve
from scipy.optimize import brentq
class LSTMSpeakerEncoder(nn.Layer):
    """LSTM speaker encoder trained with the GE2E softmax loss.

    A stacked LSTM runs over the input frames; the final hidden state of the
    last layer is projected by a linear layer, passed through ReLU and
    L2-normalized to give a speaker embedding.

    Args:
        n_mels (int): size of each input frame (LSTM input size).
        num_layers (int): number of stacked LSTM layers.
        hidden_size (int): LSTM hidden size.
        output_size (int): dimension of the speaker embedding.
    """

    def __init__(self, n_mels, num_layers, hidden_size, output_size):
        super().__init__()
        self.lstm = nn.LSTM(n_mels, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, output_size)
        # Learnable scale (w) and offset (b) applied to the cosine
        # similarities, initialised to 10 and -5 respectively.
        self.similarity_weight = self.create_parameter(
            [1], default_initializer=I.Constant(10.))
        self.similarity_bias = self.create_parameter(
            [1], default_initializer=I.Constant(-5.))

    def forward(self, utterances, num_speakers, initial_states=None):
        """Compute the GE2E loss and EER for a batch of utterances.

        Args:
            utterances (Tensor): batch of utterances, [B, T, C]. Presumably
                B = num_speakers * utterances_per_speaker with utterances
                grouped by speaker — TODO confirm against the data loader.
            num_speakers (int): number of speakers in the batch.
            initial_states: optional initial LSTM states.

        Returns:
            (loss, eer): the loss tensor and the equal error rate (float).
        """
        normalized_embeds = self.embed_sequences(utterances, initial_states)
        embed_dim = normalized_embeds.shape[-1]
        # Group embeddings by speaker: (N*M, C) -> (N, M, C). The last axis
        # is the embedding dimension (it was wrongly `num_speakers` before).
        embeds = normalized_embeds.reshape([num_speakers, -1, embed_dim])
        loss, eer = self.loss(embeds)
        return loss, eer

    def embed_sequences(self, utterances, initial_states=None, reduce=False):
        """Embed a batch of sequences.

        Args:
            utterances (Tensor): [B, T, C] input frames.
            initial_states: optional initial LSTM states.
            reduce (bool): if True, average the per-utterance embeddings and
                re-normalize, returning a single embedding.

        Returns:
            Tensor: [B, C'] L2-normalized embeddings, or [C'] if ``reduce``.
        """
        out, (h, c) = self.lstm(utterances, initial_states)
        # h[-1] is the final hidden state of the last LSTM layer.
        embeds = F.relu(self.linear(h[-1]))
        normalized_embeds = F.normalize(embeds)
        if reduce:
            embed = paddle.mean(normalized_embeds, 0)
            embed = F.normalize(embed, axis=0)
            return embed
        return normalized_embeds

    def embed_utterance(self, utterances, initial_states=None):
        """Embed several partial utterances into one normalized embedding."""
        # utterances: [B, T, C] -> embed [C']
        embed = self.embed_sequences(utterances, initial_states, reduce=True)
        return embed

    def similarity_matrix(self, embeds):
        """Compute the scaled GE2E similarity matrix.

        Args:
            embeds (Tensor): (N, M, C) embeddings, N speakers with M
                utterances each; expected to be L2-normalized (as produced by
                ``embed_sequences``).

        Returns:
            (p, p1, p2): ``p`` is the (N*M, N) similarity matrix after the
            learnable scale/bias; ``p1`` is the flattened (N*M*N,) similarity
            to the inclusive centroids; ``p2`` is the (N*M,) similarity of
            each utterance to its own speaker's exclusive centroid.

        Raises:
            ValueError: if M < 2, since exclusive centroids are undefined.
        """
        speakers_per_batch, utterances_per_speaker, embed_dim = embeds.shape
        if utterances_per_speaker < 2:
            raise ValueError(
                "similarity_matrix requires at least 2 utterances per "
                "speaker, got {}".format(utterances_per_speaker))

        # Inclusive centroids (1 per speaker), L2-normalized.
        centroids_incl = paddle.mean(embeds, axis=1)
        centroids_incl_norm = paddle.norm(
            centroids_incl, p=2, axis=1, keepdim=True)
        normalized_centroids_incl = centroids_incl / centroids_incl_norm

        # Exclusive centroids (1 per utterance): mean of the speaker's other
        # M - 1 utterances, L2-normalized.
        centroids_excl = paddle.broadcast_to(
            paddle.sum(embeds, axis=1, keepdim=True), embeds.shape) - embeds
        centroids_excl /= (utterances_per_speaker - 1)
        centroids_excl_norm = paddle.norm(
            centroids_excl, p=2, axis=2, keepdim=True)
        normalized_centroids_excl = centroids_excl / centroids_excl_norm

        # Every utterance against every inclusive centroid: (N*M, N).
        p1 = paddle.matmul(
            embeds.reshape([-1, embed_dim]),
            normalized_centroids_incl,
            transpose_y=True)
        p1 = p1.reshape([-1])  # (N*M*N,)

        # Every utterance against its own exclusive centroid: (N*M, 1, 1).
        p2 = paddle.bmm(
            embeds.reshape([-1, 1, embed_dim]),
            normalized_centroids_excl.reshape([-1, embed_dim, 1]))
        p2 = p2.reshape([-1])  # (N*M,)

        # Replace the own-speaker entries of p1 with p2. Equivalent to
        # `paddle.scatter(p1, index, p2)`, written with masks instead.
        with paddle.no_grad():
            # Row r = i*M + j (speaker i, utterance j); its own-speaker
            # column is i, i.e. flat position r*N + i.
            index = paddle.arange(
                0, speakers_per_batch * utterances_per_speaker,
                dtype="int64").reshape(
                    [speakers_per_batch, utterances_per_speaker])
            index = index * speakers_per_batch + paddle.arange(
                0, speakers_per_batch, dtype="int64").unsqueeze(-1)
            index = paddle.reshape(index, [-1])
            ones = paddle.ones([
                speakers_per_batch * utterances_per_speaker * speakers_per_batch
            ])
            zeros = paddle.zeros_like(index, dtype=ones.dtype)
            # 0 at own-speaker positions, 1 elsewhere.
            mask_p1 = paddle.scatter(ones, index, zeros)
        p = p1 * mask_p1 + (1 - mask_p1) * paddle.scatter(ones, index, p2)

        p = p * self.similarity_weight + self.similarity_bias
        p = p.reshape(
            [speakers_per_batch * utterances_per_speaker, speakers_per_batch])
        return p, p1, p2

    def do_gradient_ops(self):
        """Scale the gradients of the similarity weight and bias by 0.01.

        Presumably called between ``backward()`` and the optimizer step.
        Parameters that do not have a gradient yet are skipped instead of
        raising.
        """
        for param in (self.similarity_weight, self.similarity_bias):
            grad = param._grad_ivar()
            if grad is None:
                continue
            grad[...] = grad * 0.01

    def loss(self, embeds):
        """Compute the softmax loss of GE2E (section 2.1) and the batch EER.

        Args:
            embeds (Tensor): embeddings of shape (speakers_per_batch,
                utterances_per_speaker, embedding_size).

        Returns:
            (loss, eer): the cross-entropy loss tensor and the equal error
            rate (float) of this batch. The EER is not backpropagated.
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Loss. similarity_matrix already returns shape (N*M, N).
        sim_matrix, *_ = self.similarity_matrix(embeds)

        # Row i*M + j belongs to speaker i.
        target = paddle.arange(
            0, speakers_per_batch, dtype="int64").unsqueeze(-1)
        target = paddle.expand(target,
                               [speakers_per_batch, utterances_per_speaker])
        target = paddle.reshape(target, [-1])
        loss = nn.CrossEntropyLoss()(sim_matrix, target)

        # EER (not backpropagated)
        with paddle.no_grad():
            ground_truth = target.numpy()
            # One-hot labels, one row per utterance. An explicit dtype is
            # used because the `np.int` alias was removed in NumPy 1.24.
            labels = np.eye(speakers_per_batch, dtype=np.int64)[ground_truth]
            preds = sim_matrix.numpy()

            # Snippet from https://yangcha.github.io/EER-ROC/
            fpr, tpr, _ = roc_curve(labels.flatten(), preds.flatten())
            tpr_at = interp1d(fpr, tpr)
            # EER is the point where the false positive rate equals the
            # false negative rate (1 - tpr).
            eer = brentq(lambda x: 1. - x - tpr_at(x), 0., 1.)

        return loss, eer