1. change default data layout to channel last in preprocessing;
2. add Summary and DictSummary for aggregation of evaluation losses;
3. add unittest for report and scope.
parent 3977632b07
commit a738954001
@@ -50,7 +50,8 @@ class Clip(object):
         """Convert into batch tensors.

         Args:
-            batch (list): list of tuple of the pair of audio and features.
+            batch (list): list of tuple of the pair of audio and features. Audio shape
+                (T, ), features shape (T', C).

         Returns:
             Tensor: Auxiliary feature batch (B, C, T'), where
@@ -60,13 +61,13 @@ class Clip(object):
         """
         # check length
         examples = [
-            self._adjust_length(b['wave_path'], b['feats_path'])
-            for b in examples if b['feats_path'].shape[1] > self.mel_threshold
+            self._adjust_length(b['wave'], b['feats']) for b in examples
+            if b['feats'].shape[0] > self.mel_threshold
         ]
         xs, cs = [b[0] for b in examples], [b[1] for b in examples]

         # make batch with random cut
-        c_lengths = [c.shape[1] for c in cs]
+        c_lengths = [c.shape[0] for c in cs]
         start_frames = np.array([
             np.random.randint(self.start_offset, cl + self.end_offset)
             for cl in c_lengths
@@ -79,12 +80,13 @@ class Clip(object):
         y_batch = np.stack(
             [x[start:end] for x, start, end in zip(xs, x_starts, x_ends)])
         c_batch = np.stack(
-            [c[:, start:end] for c, start, end in zip(cs, c_starts, c_ends)])
+            [c[start:end] for c, start, end in zip(cs, c_starts, c_ends)])

         # convert each batch to tensor, assume that each item in batch has the same length
         y_batch = paddle.to_tensor(
             y_batch, dtype=paddle.float32).unsqueeze(1)  # (B, 1, T)
-        c_batch = paddle.to_tensor(c_batch, dtype=paddle.float32)  # (B, C, T')
+        c_batch = paddle.to_tensor(
+            c_batch, dtype=paddle.float32).transpose([0, 2, 1])  # (B, C, T')

         return y_batch, c_batch
@@ -103,6 +105,6 @@ class Clip(object):

         # check the length is valid
         assert len(x) == c.shape[
-            1] * self.hop_size, f"wave length: ({len(x)}), mel length: ({c.shape[1]})"
+            0] * self.hop_size, f"wave length: ({len(x)}), mel length: ({c.shape[0]})"

         return x, c
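Note: the Clip changes above all follow from storing features channel-last, (T', C), in preprocessing, and transposing to the channel-first (B, C, T') layout only once, at collate time. A minimal sketch of that flow (sizes and random data are illustrative, not part of the commit):

import numpy as np
import paddle

# hypothetical sizes: T' mel frames, C mel channels
T_prime, C = 100, 80
feats = np.random.randn(T_prime, C).astype(np.float32)  # channel-last, as saved on disk

# stacking clips of equal length gives (B, T', C)
c_batch = np.stack([feats, feats])

# transpose once, at the end of collate, to channel-first (B, C, T')
c_tensor = paddle.to_tensor(c_batch).transpose([0, 2, 1])
print(c_tensor.shape)  # [2, 80, 100]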
@@ -20,7 +20,7 @@ import dataclasses
 from pathlib import Path

 import yaml
-import json
+import jsonlines
 import paddle
 import numpy as np
 from paddle import nn
@@ -61,23 +61,23 @@ def train_sp(args, config):
    )

    # construct dataset for training and validation
-    with open(args.train_metadata) as f:
-        train_metadata = json.load(f)
+    with jsonlines.open(args.train_metadata, 'r') as reader:
+        train_metadata = list(reader)
    train_dataset = DataTable(
        data=train_metadata,
-        fields=["wave_path", "feats_path"],
+        fields=["wave", "feats"],
        converters={
-            "wave_path": np.load,
-            "feats_path": np.load,
+            "wave": np.load,
+            "feats": np.load,
        }, )
-    with open(args.dev_metadata) as f:
-        dev_metadata = json.load(f)
+    with jsonlines.open(args.dev_metadata, 'r') as reader:
+        dev_metadata = list(reader)
    dev_dataset = DataTable(
        data=dev_metadata,
-        fields=["wave_path", "feats_path"],
+        fields=["wave", "feats"],
        converters={
-            "wave_path": np.load,
-            "feats_path": np.load,
+            "wave": np.load,
+            "feats": np.load,
        }, )

    # collate function and dataloader
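Note: the metadata files are now expected to be JSON Lines (one JSON object per line) rather than a single JSON document. A small sketch of writing and reading that format with the jsonlines package (field names mirror the diff; the paths are illustrative):

import jsonlines

records = [
    {"wave": "dump/train/wave/0001.npy", "feats": "dump/train/feats/0001.npy"},
    {"wave": "dump/train/wave/0002.npy", "feats": "dump/train/feats/0002.npy"},
]
with jsonlines.open("metadata.jsonl", "w") as writer:
    writer.write_all(records)  # one JSON object per line

with jsonlines.open("metadata.jsonl", "r") as reader:
    metadata = list(reader)  # list of dicts, as train_sp now consumes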
@@ -169,12 +169,13 @@ def train_sp(args, config):

    trainer = Trainer(
        updater,
-        stop_trigger=(10, "iteration"),  # PROFILING
+        stop_trigger=(config.train_max_steps, "iteration"),  # PROFILING
        out=output_dir, )
-    with paddle.fluid.profiler.profiler('All', 'total',
-                                        str(output_dir / "profiler.log"),
-                                        'Default') as prof:
-        trainer.run()
+    # with paddle.fluid.profiler.profiler('All', 'total',
+    #                                     str(output_dir / "profiler.log"),
+    #                                     'Default') as prof:
+    trainer.run()


def main():
@@ -35,8 +35,9 @@ class SpectralConvergenceLoss(nn.Layer):
            Tensor: Spectral convergence loss value.
        """
        return paddle.norm(
-            y_mag - x_mag, p="fro") / paddle.norm(
-                y_mag, p="fro")
+            y_mag - x_mag, p="fro") / paddle.clip(
+                paddle.norm(
+                    y_mag, p="fro"), min=1e-10)


class LogSTFTMagnitudeLoss(nn.Layer):
@@ -54,7 +55,11 @@ class LogSTFTMagnitudeLoss(nn.Layer):
        Returns:
            Tensor: Log STFT magnitude loss value.
        """
-        return F.l1_loss(paddle.log(y_mag), paddle.log(x_mag))
+        return F.l1_loss(
+            paddle.log(paddle.clip(
+                y_mag, min=1e-10)),
+            paddle.log(paddle.clip(
+                x_mag, min=1e-10)))


class STFTLoss(nn.Layer):
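Note: both loss fixes apply the same guard: clip a magnitude that can reach zero before it enters a log or a denominator, so the loss stays finite. A standalone sketch of the effect (illustrative, not part of the commit):

import paddle

x = paddle.zeros([4])
# paddle.log(x) alone would yield -inf; clipping bounds the result
safe = paddle.log(paddle.clip(x, min=1e-10))  # every element equals log(1e-10)
print(safe.numpy())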
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import math
 import contextlib
+from collections import defaultdict

 OBSERVATIONS = None
@@ -45,3 +47,113 @@ def report(name, value):
         return
     else:
         observations[name] = value
+
+
+class Summary(object):
+    """Online summarization of a sequence of scalars.
+
+    Summary computes the statistics of given scalars online.
+    """
+
+    def __init__(self):
+        self._x = 0.0
+        self._x2 = 0.0
+        self._n = 0
+
+    def add(self, value, weight=1):
+        """Adds a scalar value.
+
+        Args:
+            value: Scalar value to accumulate. It is either a NumPy scalar or
+                a zero-dimensional array (on CPU or GPU).
+            weight: An optional weight for the value. It is a NumPy scalar or
+                a zero-dimensional array (on CPU or GPU).
+                Default is 1 (integer).
+        """
+        self._x += weight * value
+        self._x2 += weight * value * value
+        self._n += weight
+
+    def compute_mean(self):
+        """Computes the mean."""
+        x, n = self._x, self._n
+        return x / n
+
+    def make_statistics(self):
+        """Computes and returns the mean and standard deviation values.
+
+        Returns:
+            tuple: Mean and standard deviation values.
+        """
+        x, n = self._x, self._n
+        mean = x / n
+        var = self._x2 / n - mean * mean
+        std = math.sqrt(var)
+        return mean, std
+
+
+class DictSummary(object):
+    """Online summarization of a sequence of dictionaries.
+
+    ``DictSummary`` computes the statistics of a given set of scalars online.
+    It only computes the statistics for scalar values and variables of scalar
+    values in the dictionaries.
+    """
+
+    def __init__(self):
+        self._summaries = defaultdict(Summary)
+
+    def add(self, d):
+        """Adds a dictionary of scalars.
+
+        Args:
+            d (dict): Dictionary of scalars to accumulate. Only elements of
+                scalars, zero-dimensional arrays, and variables of
+                zero-dimensional arrays are accumulated. When the value
+                is a tuple, the second element is interpreted as a weight.
+        """
+        summaries = self._summaries
+        for k, v in d.items():
+            w = 1
+            if isinstance(v, tuple):
+                w = v[1]
+                v = v[0]
+            summaries[k].add(v, weight=w)
+
+    def compute_mean(self):
+        """Creates a dictionary of mean values.
+
+        It returns a single dictionary that holds a mean value for each entry
+        added to the summary.
+
+        Returns:
+            dict: Dictionary of mean values.
+        """
+        return {
+            name: summary.compute_mean()
+            for name, summary in self._summaries.items()
+        }
+
+    def make_statistics(self):
+        """Creates a dictionary of statistics.
+
+        It returns a single dictionary that holds mean and standard deviation
+        values for every entry added to the summary. For an entry of name
+        ``'key'``, these values are added to the dictionary by names ``'key'``
+        and ``'key.std'``, respectively.
+
+        Returns:
+            dict: Dictionary of statistics of all entries.
+        """
+        stats = {}
+        for name, summary in self._summaries.items():
+            mean, std = summary.make_statistics()
+            stats[name] = mean
+            stats[name + '.std'] = std
+
+        return stats
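Note: Summary keeps only the weighted running sums sum(w*x), sum(w*x*x), and sum(w), so the mean is sum(w*x)/sum(w) and the (population) variance is sum(w*x*x)/sum(w) - mean^2. The same numbers computed by hand, matching the formulas above (plain-Python check, not part of the commit):

import math

values = [1.0, 2.0, 3.0]
n = len(values)
mean = sum(values) / n                              # 2.0
var = sum(v * v for v in values) / n - mean * mean  # population variance, 2/3
std = math.sqrt(var)                                # ~0.8165, equals np.std([1, 2, 3])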
@@ -0,0 +1,51 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from parakeet.training.reporter import report, scope
+from parakeet.training.reporter import Summary, DictSummary
+
+
+def test_reporter_scope():
+    first = {}
+    second = {}
+    third = {}
+
+    with scope(first):
+        report("first_begin", 1)
+        with scope(second):
+            report("second_begin", 2)
+            with scope(third):
+                report("third_begin", 3)
+                report("third_end", 4)
+            report("second_end", 5)
+        report("first_end", 6)
+
+    assert first == {'first_begin': 1, 'first_end': 6}
+    assert second == {'second_begin': 2, 'second_end': 5}
+    assert third == {'third_begin': 3, 'third_end': 4}
+    print(first)
+    print(second)
+    print(third)
+
+
+def test_summary():
+    summary = Summary()
+    summary.add(1)
+    summary.add(2)
+    summary.add(3)
+    state = summary.make_statistics()
+    print(state)
+    np.testing.assert_allclose(
+        np.array(list(state)), np.array([2.0, np.std([1, 2, 3])]))
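Note: the tests cover report/scope and Summary but not DictSummary. A quick sketch of how the class added above would be used (values are illustrative, not part of the commit):

from parakeet.training.reporter import DictSummary

summary = DictSummary()
summary.add({"loss": 0.5, "mel_loss": (0.3, 2)})  # a tuple is (value, weight)
summary.add({"loss": 0.7, "mel_loss": (0.1, 1)})
print(summary.compute_mean())     # {'loss': 0.6, 'mel_loss': 0.2333...}
print(summary.make_statistics())  # also adds 'loss.std' and 'mel_loss.std'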