diff --git a/examples/parallelwave_gan/baker/batch_fn.py b/examples/parallelwave_gan/baker/batch_fn.py index be22dbb..aff647c 100644 --- a/examples/parallelwave_gan/baker/batch_fn.py +++ b/examples/parallelwave_gan/baker/batch_fn.py @@ -50,7 +50,8 @@ class Clip(object): """Convert into batch tensors. Args: - batch (list): list of tuple of the pair of audio and features. + batch (list): list of tuple of the pair of audio and features. Audio shape + (T, ), features shape(T', C). Returns: Tensor: Auxiliary feature batch (B, C, T'), where @@ -60,13 +61,13 @@ class Clip(object): """ # check length examples = [ - self._adjust_length(b['wave_path'], b['feats_path']) - for b in examples if b['feats_path'].shape[1] > self.mel_threshold + self._adjust_length(b['wave'], b['feats']) for b in examples + if b['feats'].shape[0] > self.mel_threshold ] xs, cs = [b[0] for b in examples], [b[1] for b in examples] # make batch with random cut - c_lengths = [c.shape[1] for c in cs] + c_lengths = [c.shape[0] for c in cs] start_frames = np.array([ np.random.randint(self.start_offset, cl + self.end_offset) for cl in c_lengths @@ -79,12 +80,13 @@ class Clip(object): y_batch = np.stack( [x[start:end] for x, start, end in zip(xs, x_starts, x_ends)]) c_batch = np.stack( - [c[:, start:end] for c, start, end in zip(cs, c_starts, c_ends)]) + [c[start:end] for c, start, end in zip(cs, c_starts, c_ends)]) # convert each batch to tensor, asuume that each item in batch has the same length y_batch = paddle.to_tensor( y_batch, dtype=paddle.float32).unsqueeze(1) # (B, 1, T) - c_batch = paddle.to_tensor(c_batch, dtype=paddle.float32) # (B, C, T') + c_batch = paddle.to_tensor( + c_batch, dtype=paddle.float32).transpose([0, 2, 1]) # (B, C, T') return y_batch, c_batch @@ -103,6 +105,6 @@ class Clip(object): # check the legnth is valid assert len(x) == c.shape[ - 1] * self.hop_size, f"wave length: ({len(x)}), mel length: ({c.shape[1]})" + 0] * self.hop_size, f"wave length: ({len(x)}), mel length: ({c.shape[0]})" return x, c diff --git a/examples/parallelwave_gan/baker/train.py b/examples/parallelwave_gan/baker/train.py index d424c5d..087963e 100644 --- a/examples/parallelwave_gan/baker/train.py +++ b/examples/parallelwave_gan/baker/train.py @@ -20,7 +20,7 @@ import dataclasses from pathlib import Path import yaml -import json +import jsonlines import paddle import numpy as np from paddle import nn @@ -61,23 +61,23 @@ def train_sp(args, config): ) # construct dataset for training and validation - with open(args.train_metadata) as f: - train_metadata = json.load(f) + with jsonlines.open(args.train_metadata, 'r') as reader: + train_metadata = list(reader) train_dataset = DataTable( data=train_metadata, - fields=["wave_path", "feats_path"], + fields=["wave", "feats"], converters={ - "wave_path": np.load, - "feats_path": np.load, + "wave": np.load, + "feats": np.load, }, ) - with open(args.dev_metadata) as f: - dev_metadata = json.load(f) + with jsonlines.open(args.dev_metadata, 'r') as reader: + dev_metadata = list(reader) dev_dataset = DataTable( data=dev_metadata, - fields=["wave_path", "feats_path"], + fields=["wave", "feats"], converters={ - "wave_path": np.load, - "feats_path": np.load, + "wave": np.load, + "feats": np.load, }, ) # collate function and dataloader @@ -169,12 +169,13 @@ def train_sp(args, config): trainer = Trainer( updater, - stop_trigger=(10, "iteration"), # PROFILING + stop_trigger=(config.train_max_steps, "iteration"), # PROFILING out=output_dir, ) - with paddle.fluid.profiler.profiler('All', 'total', - str(output_dir / "profiler.log"), - 'Default') as prof: - trainer.run() + + # with paddle.fluid.profiler.profiler('All', 'total', + # str(output_dir / "profiler.log"), + # 'Default') as prof: + trainer.run() def main(): diff --git a/parakeet/modules/stft_loss.py b/parakeet/modules/stft_loss.py index 6531010..f98a3dd 100644 --- a/parakeet/modules/stft_loss.py +++ b/parakeet/modules/stft_loss.py @@ -35,8 +35,9 @@ class SpectralConvergenceLoss(nn.Layer): Tensor: Spectral convergence loss value. """ return paddle.norm( - y_mag - x_mag, p="fro") / paddle.norm( - y_mag, p="fro") + y_mag - x_mag, p="fro") / paddle.clip( + paddle.norm( + y_mag, p="fro"), min=1e-10) class LogSTFTMagnitudeLoss(nn.Layer): @@ -54,7 +55,11 @@ class LogSTFTMagnitudeLoss(nn.Layer): Returns: Tensor: Log STFT magnitude loss value. """ - return F.l1_loss(paddle.log(y_mag), paddle.log(x_mag)) + return F.l1_loss( + paddle.log(paddle.clip( + y_mag, min=1e-10)), + paddle.log(paddle.clip( + x_mag, min=1e-10))) class STFTLoss(nn.Layer): diff --git a/parakeet/training/reporter.py b/parakeet/training/reporter.py index 3f4d77f..c2f171c 100644 --- a/parakeet/training/reporter.py +++ b/parakeet/training/reporter.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math import contextlib +from collections import defaultdict OBSERVATIONS = None @@ -45,3 +47,113 @@ def report(name, value): return else: observations[name] = value + + +class Summary(object): + """Online summarization of a sequence of scalars. + Summary computes the statistics of given scalars online. + """ + + def __init__(self): + self._x = 0.0 + self._x2 = 0.0 + self._n = 0 + + def add(self, value, weight=1): + """Adds a scalar value. + + Args: + value: Scalar value to accumulate. It is either a NumPy scalar or + a zero-dimensional array (on CPU or GPU). + weight: An optional weight for the value. It is a NumPy scalar or + a zero-dimensional array (on CPU or GPU). + Default is 1 (integer). + + """ + self._x += weight * value + self._x2 += weight * value * value + self._n += weight + + def compute_mean(self): + """Computes the mean.""" + x, n = self._x, self._n + return x / n + + def make_statistics(self): + """Computes and returns the mean and standard deviation values. + + Returns: + tuple: Mean and standard deviation values. + + """ + x, n = self._x, self._n + mean = x / n + var = self._x2 / n - mean * mean + std = math.sqrt(var) + return mean, std + + +class DictSummary(object): + """Online summarization of a sequence of dictionaries. + + ``DictSummary`` computes the statistics of a given set of scalars online. + It only computes the statistics for scalar values and variables of scalar + values in the dictionaries. + + """ + + def __init__(self): + self._summaries = defaultdict(Summary) + + def add(self, d): + """Adds a dictionary of scalars. + + Args: + d (dict): Dictionary of scalars to accumulate. Only elements of + scalars, zero-dimensional arrays, and variables of + zero-dimensional arrays are accumulated. When the value + is a tuple, the second element is interpreted as a weight. + + """ + summaries = self._summaries + for k, v in d.items(): + w = 1 + if isinstance(v, tuple): + w = v[1] + v = v[0] + summaries[k].add(v, weight=w) + + def compute_mean(self): + """Creates a dictionary of mean values. + + It returns a single dictionary that holds a mean value for each entry + added to the summary. + + Returns: + dict: Dictionary of mean values. + + """ + return { + name: summary.compute_mean() + for name, summary in self._summaries.items() + } + + def make_statistics(self): + """Creates a dictionary of statistics. + + It returns a single dictionary that holds mean and standard deviation + values for every entry added to the summary. For an entry of name + ``'key'``, these values are added to the dictionary by names ``'key'`` + and ``'key.std'``, respectively. + + Returns: + dict: Dictionary of statistics of all entries. + + """ + stats = {} + for name, summary in self._summaries.items(): + mean, std = summary.make_statistics() + stats[name] = mean + stats[name + '.std'] = std + + return stats diff --git a/tests/test_reporter.py b/tests/test_reporter.py new file mode 100644 index 0000000..cd40364 --- /dev/null +++ b/tests/test_reporter.py @@ -0,0 +1,51 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from parakeet.training.reporter import report, scope +from parakeet.training.reporter import Summary, DictSummary + + +def test_reporter_scope(): + first = {} + second = {} + third = {} + + with scope(first): + report("first_begin", 1) + with scope(second): + report("second_begin", 2) + with scope(third): + report("third_begin", 3) + report("third_end", 4) + report("seconf_end", 5) + report("first_end", 6) + + assert first == {'first_begin': 1, 'first_end': 6} + assert second == {'second_begin': 2, 'seconf_end': 5} + assert third == {'third_begin': 3, 'third_end': 4} + print(first) + print(second) + print(third) + + +def test_summary(): + summary = Summary() + summary.add(1) + summary.add(2) + summary.add(3) + state = summary.make_statistics() + print(state) + np.testing.assert_allclose( + np.array(list(state)), np.array([2.0, np.std([1, 2, 3])]))