Parakeet/parakeet/modules/stft_loss.py

139 lines
4.6 KiB
Python

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
from paddle.nn import functional as F
from parakeet.modules.audio import STFT
class SpectralConvergenceLoss(nn.Layer):
    """Spectral convergence loss between two magnitude spectrograms.

    The loss is the Frobenius norm of ``y_mag - x_mag`` divided by the
    Frobenius norm of ``y_mag``.
    """

    def __init__(self):
        """Initialize the spectral convergence loss module."""
        super().__init__()

    def forward(self, x_mag, y_mag):
        """Compute the spectral convergence loss.

        Args:
            x_mag (Tensor): Magnitude spectrogram of the predicted signal.
            y_mag (Tensor): Magnitude spectrogram of the groundtruth signal.
                Subtracted element-wise from ``x_mag``, so the two are
                expected to share a shape.

        Returns:
            Tensor: Spectral convergence loss value.
        """
        # Norm of the prediction error.
        residual_norm = paddle.norm(y_mag - x_mag, p="fro")
        # Norm of the reference; no epsilon is added here, so an all-zero
        # ``y_mag`` yields a zero denominator.
        reference_norm = paddle.norm(y_mag, p="fro")
        return residual_norm / reference_norm
class LogSTFTMagnitudeLoss(nn.Layer):
    """L1 loss between the logarithms of two magnitude spectrograms."""

    def __init__(self):
        """Initialize the log STFT magnitude loss module."""
        super().__init__()

    def forward(self, x_mag, y_mag):
        """Compute the log STFT magnitude loss.

        Args:
            x_mag (Tensor): Magnitude spectrogram of the predicted signal.
            y_mag (Tensor): Magnitude spectrogram of the groundtruth signal.

        Returns:
            Tensor: Log STFT magnitude loss value.
        """
        # The logarithm is applied directly to the inputs (no clipping is
        # done in this module), so both magnitudes should be positive.
        log_reference = paddle.log(y_mag)
        log_predicted = paddle.log(x_mag)
        # Uses the default reduction of ``F.l1_loss``.
        return F.l1_loss(log_reference, log_predicted)
class STFTLoss(nn.Layer):
    """Single-resolution STFT loss.

    Transforms both signals with one shared ``STFT`` layer and evaluates a
    spectral convergence loss and a log STFT magnitude loss on the
    resulting magnitudes.
    """

    def __init__(self,
                 fft_size=1024,
                 shift_size=120,
                 win_length=600,
                 window="hann"):
        """Initialize the STFT loss module.

        Args:
            fft_size (int): FFT size, passed to ``STFT`` as ``n_fft``.
            shift_size (int): Hop size, passed to ``STFT`` as ``hop_length``.
            win_length (int): Window length passed to ``STFT``.
            window (str): Window function type passed to ``STFT``.
        """
        super().__init__()
        # Keep the analysis settings available on the instance.
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length

        stft_kwargs = {
            "n_fft": fft_size,
            "hop_length": shift_size,
            "win_length": win_length,
            "window": window,
        }
        self.stft = STFT(**stft_kwargs)

        self.spectral_convergence_loss = SpectralConvergenceLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()

    def forward(self, x, y):
        """Compute both STFT-based losses for a pair of signals.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            Tensor: Spectral convergence loss value.
            Tensor: Log STFT magnitude loss value.
        """
        predicted_mag = self.stft.magnitude(x)
        reference_mag = self.stft.magnitude(y)
        # The tuple is evaluated left to right: spectral convergence first,
        # then the log-magnitude term.
        return (
            self.spectral_convergence_loss(predicted_mag, reference_mag),
            self.log_stft_magnitude_loss(predicted_mag, reference_mag),
        )
class MultiResolutionSTFTLoss(nn.Layer):
    """Multi resolution STFT loss module.

    Holds one ``STFTLoss`` per (FFT size, hop size, window length) triple
    and averages their outputs over all resolutions.
    """

    def __init__(
            self,
            fft_sizes=(1024, 2048, 512),
            hop_sizes=(120, 240, 50),
            win_lengths=(600, 1200, 240),
            window="hann", ):
        """Initialize Multi resolution STFT loss module.

        Args:
            fft_sizes (list or tuple): FFT sizes, one per resolution.
            hop_sizes (list or tuple): Hop sizes, one per resolution.
            win_lengths (list or tuple): Window lengths, one per resolution.
            window (str): Window function type shared by all resolutions.

        Raises:
            ValueError: If the three sequences differ in length or are
                empty.
        """
        super().__init__()
        # Validate with an explicit exception: ``assert`` statements are
        # removed when Python runs with -O, which would let mismatched
        # sequences be silently truncated by ``zip`` below.
        if not (len(fft_sizes) == len(hop_sizes) == len(win_lengths)):
            raise ValueError(
                "fft_sizes, hop_sizes and win_lengths must have the same "
                "length, got {}, {} and {}.".format(
                    len(fft_sizes), len(hop_sizes), len(win_lengths)))
        # With no resolutions, ``forward`` would divide by zero when
        # averaging, so reject that configuration up front.
        if len(fft_sizes) == 0:
            raise ValueError("At least one STFT resolution is required.")

        self.stft_losses = nn.LayerList()
        for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
            self.stft_losses.append(STFTLoss(fs, ss, wl, window))

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            Tensor: Multi resolution spectral convergence loss value.
            Tensor: Multi resolution log STFT magnitude loss value.
        """
        sc_loss = 0.0
        mag_loss = 0.0
        # Accumulate each resolution's pair of losses.
        for stft_loss in self.stft_losses:
            sc_l, mag_l = stft_loss(x, y)
            sc_loss += sc_l
            mag_loss += mag_l
        # Average over the number of resolutions.
        num_resolutions = len(self.stft_losses)
        sc_loss /= num_resolutions
        mag_loss /= num_resolutions
        return sc_loss, mag_loss