# ParakeetRebeccaRosario/parakeet/utils/display.py
#
# 111 lines
# 3.3 KiB
# Python

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import librosa.display
import matplotlib.pylab as plt
# Names exported by `from <this module> import *`; each entry is a plotting
# helper defined below in this file.
__all__ = [
"plot_alignment",
"plot_spectrogram",
"plot_waveform",
"plot_multihead_alignments",
"plot_multilayer_multihead_alignments",
]
def plot_alignment(alignment, title=None):
    """Plot a single attention alignment matrix as a heatmap.

    Args:
        alignment: 2-D image-like data passed straight to ``Axes.imshow``.
            Per the original author's note the layout is
            ``[encoder_steps, decoder_steps]`` (rows are drawn on the y axis,
            columns on the x axis).
        title: Optional text appended, after a blank line, to the x-axis
            label. ``None`` leaves the label as is.

    Returns:
        The newly created matplotlib figure containing the heatmap and its
        colorbar.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(
        alignment, aspect='auto', origin='lower', interpolation='none')
    fig.colorbar(im, ax=ax)

    xlabel = 'Decoder timestep'
    if title is not None:
        xlabel += '\n\n' + title

    # Label and lay out through the explicit Axes/Figure objects instead of
    # pyplot's implicit "current axes", matching the multi-head helpers and
    # avoiding any dependence on global pyplot state.
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Encoder timestep')
    fig.tight_layout()
    return fig
def plot_multihead_alignments(alignments, title=None):
    """Plot one alignment heatmap per attention head, side by side.

    Args:
        alignments: Array-like exposing ``.shape`` and integer indexing.
            Per the original author's note the layout is
            ``[N, encoder_steps, decoder_steps]``; ``alignments[i]`` is drawn
            in the i-th column.
        title: Optional text appended, after a blank line, to every
            subplot's x-axis label.

    Returns:
        The newly created matplotlib figure with ``N`` subplots in one row.
    """
    num_subplots = alignments.shape[0]

    # squeeze=False keeps `axes` a 2-D array of shape (1, N). With
    # squeeze=True a single head (N == 1) yields a bare Axes object, which
    # is not iterable and made the loop below raise TypeError.
    fig, axes = plt.subplots(
        figsize=(6 * num_subplots, 4),
        ncols=num_subplots,
        sharey=True,
        squeeze=False)

    # The label text is identical for every subplot, so build it once.
    xlabel = 'Decoder timestep'
    if title is not None:
        xlabel += '\n\n' + title

    for i, ax in enumerate(axes[0]):
        im = ax.imshow(
            alignments[i], aspect='auto', origin='lower', interpolation='none')
        fig.colorbar(im, ax=ax)
        ax.set_xlabel(xlabel)
        # The y axis is shared, so only the left-most subplot is labelled.
        if i == 0:
            ax.set_ylabel('Encoder timestep')

    fig.tight_layout()
    return fig
def plot_multilayer_multihead_alignments(alignments, title=None):
    """Plot a grid of alignment heatmaps: one row per layer, one column per head.

    Args:
        alignments: Array-like exposing ``.shape`` and tuple indexing.
            Per the original author's note the layout is
            ``[num_layers, num_heads, encoder_steps, decoder_steps]``;
            ``alignments[i, j]`` is drawn at row ``i``, column ``j``.
        title: Optional text appended, after a blank line, to the x-axis
            label of the bottom row.

    Returns:
        The newly created matplotlib figure holding the subplot grid.
    """
    num_layers, num_heads, *_ = alignments.shape

    # squeeze=False guarantees a 2-D (num_layers, num_heads) array of Axes.
    # With squeeze=True a single layer or a single head collapses the array
    # to 1-D (or to a bare Axes), which broke the nested iteration below.
    fig, axes = plt.subplots(
        figsize=(6 * num_heads, 4 * num_layers),
        nrows=num_layers,
        ncols=num_heads,
        sharex=True,
        sharey=True,
        squeeze=False)

    # The label text does not depend on the subplot, so build it once.
    xlabel = 'Decoder timestep'
    if title is not None:
        xlabel += '\n\n' + title

    for i, row in enumerate(axes):
        for j, ax in enumerate(row):
            im = ax.imshow(
                alignments[i, j],
                aspect='auto',
                origin='lower',
                interpolation='none')
            fig.colorbar(im, ax=ax)
            # Axes are shared, so only the bottom row carries the x label
            # and only the left-most column carries the y label.
            if i == num_layers - 1:
                ax.set_xlabel(xlabel)
            if j == 0:
                ax.set_ylabel('Encoder timestep')

    fig.tight_layout()
    return fig
def plot_spectrogram(spec):
    """Render a spectrogram as a heatmap and return the figure.

    Args:
        spec: 2-D image-like data passed to ``Axes.imshow``. Per the
            original author's note it follows the librosa ``[C, T]``
            convention (channels on the y axis, frames on the x axis).

    Returns:
        The newly created matplotlib figure with the heatmap and a colorbar.
    """
    figure, axis = plt.subplots(figsize=(12, 3))
    heatmap = axis.imshow(
        spec,
        aspect="auto",
        origin="lower",
        interpolation='none',
    )
    plt.colorbar(heatmap, ax=axis)
    plt.ylabel("Channels")
    plt.xlabel("Frames")
    plt.tight_layout()
    return figure
def plot_waveform(wav, sr=22050):
    """Plot an audio waveform and return the figure.

    Args:
        wav: Audio samples handed to ``librosa.display.waveplot``.
        sr: Sampling rate forwarded to ``waveplot``. Defaults to 22050.

    Returns:
        The newly created matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(12, 3))
    # Forward the caller's sampling rate; the previous code hard-coded 22050
    # here and silently ignored the `sr` argument.
    im = librosa.display.waveplot(wav, sr=sr)
    # NOTE(review): a waveform envelope carries no color-mapped data, so this
    # colorbar is probably unintended. It is kept to preserve the existing
    # figure layout; confirm before removing.
    plt.colorbar(im, ax=ax)
    plt.tight_layout()
    return fig