Merge pull request #55 from lfchener/reborn
fix EnglishCharacter frontend and add spectrogram plots
commit b82217f50f
@@ -29,4 +29,4 @@ def normalize(sentence):
     sentence = re.sub(r"[^ a-z'.,?!\-]", "", sentence)
     sentence = sentence.replace("i.e.", "that is")
     sentence = sentence.replace("e.g.", "for example")
-    return sentence.split()
+    return sentence
@@ -79,23 +79,14 @@ class EnglishCharacter(Phonetics):
         self.vocab = Vocab(self.graphemes + self.punctuations)

     def phoneticize(self, sentence):
-        start = self.vocab.start_symbol
-        end = self.vocab.end_symbol
-
-        words = ([] if start is None else [start]) \
-            + normalize(sentence) \
-            + ([] if end is None else [end])
+        words = normalize(sentence)
         return words

-    def numericalize(self, words):
-        ids = []
-        for word in words:
-            if word in self.vocab.stoi:
-                ids.append(self.vocab.lookup(word))
-                continue
-            for char in word:
-                if char in self.vocab.stoi:
-                    ids.append(self.vocab.lookup(char))
+    def numericalize(self, sentence):
+        ids = [
+            self.vocab.lookup(item) for item in sentence
+            if item in self.vocab.stoi
+        ]
         return ids

     def reverse(self, ids):
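A minimal usage sketch of the reworked frontend, assuming EnglishCharacter is importable from this package's frontend module and needs no constructor arguments (both assumptions). After this change, phoneticize returns the normalized character sequence directly (no start/end symbols) and numericalize keeps only items found in the vocab:

    from parakeet.frontend import EnglishCharacter  # import path assumed

    frontend = EnglishCharacter()
    chars = frontend.phoneticize("Hello, world!")  # normalized character sequence
    ids = frontend.numericalize(chars)             # ids for in-vocab characters only
    text = frontend.reverse(ids)                   # map ids back to characters
    print(chars, ids, text)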
@@ -238,10 +238,7 @@ class Tacotron2Decoder(nn.Layer):
         querys = paddle.concat(
             [
                 paddle.zeros(
-                    shape=[
-                        querys.shape[0], 1,
-                        querys.shape[-1] * self.reduction_factor
-                    ],
+                    shape=[querys.shape[0], 1, querys.shape[-1]],
                     dtype=querys.dtype), querys
             ],
             axis=1)
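A toy shape check for the revised padding (tensor sizes are illustrative): the decoder now prepends a single zero "go" frame whose feature width matches querys directly, rather than querys.shape[-1] * reduction_factor:

    import paddle

    querys = paddle.randn([2, 10, 80])  # toy [batch, frames, features] decoder targets
    go_frame = paddle.zeros(
        shape=[querys.shape[0], 1, querys.shape[-1]], dtype=querys.dtype)
    querys = paddle.concat([go_frame, querys], axis=1)
    print(querys.shape)  # [2, 11, 80]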
@@ -266,7 +263,7 @@ class Tacotron2Decoder(nn.Layer):
         return mel_outputs, stop_logits, alignments

     def infer(self, key, stop_threshold=0.5, max_decoder_steps=1000):
-        decoder_input = paddle.zeros(
+        query = paddle.zeros(
             shape=[key.shape[0], self.d_mels * self.reduction_factor],
             dtype=key.dtype) #[B, C]

@@ -275,8 +272,8 @@ class Tacotron2Decoder(nn.Layer):

         mel_outputs, stop_logits, alignments = [], [], []
         while True:
-            decoder_input = self.prenet(decoder_input)
-            mel_output, stop_logit, alignment = self._decode(decoder_input)
+            query = self.prenet(query)
+            mel_output, stop_logit, alignment = self._decode(query)

             mel_outputs += [mel_output]
             stop_logits += [stop_logit]
@@ -288,7 +285,7 @@ class Tacotron2Decoder(nn.Layer):
                 print("Warning! Reached max decoder steps!!!")
                 break

-            decoder_input = mel_output
+            query = mel_output

         alignments = paddle.stack(alignments, axis=1)
         stop_logits = paddle.concat(stop_logits, axis=1)
@@ -350,7 +347,7 @@ class Tacotron2(nn.Layer):
             attention_kernel_size, p_prenet_dropout, p_attention_dropout,
             p_decoder_dropout)
         self.postnet = DecoderPostNet(
-            d_mels=d_mels,
+            d_mels=d_mels * reduction_factor,
            d_hidden=d_postnet,
             kernel_size=postnet_kernel_size,
             padding=int((postnet_kernel_size - 1) / 2),
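The DecoderPostNet change is a channel-width fix: with a reduction factor r, one decoder step covers r mel frames packed along the feature axis (the same d_mels * reduction_factor width used to seed infer above), so the postnet must be built for that width. A trivial check with assumed sizes:

    # Assumed toy sizes; the real values come from the model config.
    d_mels, reduction_factor = 80, 2
    postnet_in_width = d_mels * reduction_factor
    print(postnet_in_width)  # 160 channels expected by DecoderPostNet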
@@ -19,8 +19,11 @@ import matplotlib.pylab as plt
 from matplotlib import cm, pyplot

 __all__ = [
-    "pack_attention_images", "add_attention_plots", "plot_alignment",
-    "min_max_normalize"
+    "pack_attention_images",
+    "add_attention_plots",
+    "plot_alignment",
+    "min_max_normalize",
+    "add_spectrogram_plots",
 ]

@@ -48,6 +51,13 @@ def pack_attention_images(attention_weights, rotate=False):
     return img


+def save_figure_to_numpy(fig):
+    # save it to a numpy array.
+    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
+    return data
+
+
 def plot_alignment(alignment, title=None):
     fig, ax = plt.subplots(figsize=(6, 4))
     im = ax.imshow(
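save_figure_to_numpy is a small matplotlib helper shared by the plotting functions; a self-contained check (figure content is arbitrary, and the canvas must be drawn before reading pixels, as plot_alignment does below):

    import numpy as np
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    ax.plot([0, 1], [0, 1])
    fig.canvas.draw()                # render first so tostring_rgb has pixel data
    img = save_figure_to_numpy(fig)  # HWC uint8 array, ready for writer.add_image
    plt.close(fig)
    print(img.shape, img.dtype)      # e.g. (480, 640, 3) uint8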
@@ -61,8 +71,7 @@ def plot_alignment(alignment, title=None):
     plt.tight_layout()

     fig.canvas.draw()
-    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
-    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
+    data = save_figure_to_numpy(fig)
     plt.close()
     return data

@@ -83,5 +92,20 @@ def add_multi_attention_plots(writer, tag, attention_weights, global_step):
             dataformats="HWC")


+def add_spectrogram_plots(writer, tag, spec, global_step):
+    spec = spec.numpy().T
+    fig, ax = plt.subplots(figsize=(12, 3))
+    im = ax.imshow(spec, aspect="auto", origin="lower", interpolation='none')
+    plt.colorbar(im, ax=ax)
+    plt.xlabel("Frames")
+    plt.ylabel("Channels")
+    plt.tight_layout()
+
+    fig.canvas.draw()
+    data = save_figure_to_numpy(fig)
+    plt.close()
+    writer.add_image(tag, data, global_step, dataformats="HWC")
+
+
 def min_max_normalize(v):
     return (v - v.min()) / (v.max() - v.min())
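A hedged usage sketch for the new add_spectrogram_plots helper. The writer class and spectrogram shape are assumptions: any summary writer exposing a tensorboardX-style add_image(tag, img, step, dataformats=...) should work, and spec is expected to be a paddle tensor of shape [frames, channels]:

    import paddle
    from tensorboardX import SummaryWriter  # assumed writer; the project's own logger may differ

    writer = SummaryWriter("runs/tacotron2")  # log directory is illustrative
    mel = paddle.randn([200, 80])             # toy mel spectrogram, [frames, channels]
    add_spectrogram_plots(writer, "train/mel", mel, global_step=100)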