From 95f64c4f02cf05516f729cb229745e5c38a931a0 Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Mon, 14 Jun 2021 17:21:45 +0800 Subject: [PATCH] WIP: add some trainig info --- examples/parallelwave_gan/baker/conf/default.yaml | 2 +- parakeet/training/trainer.py | 13 ++++++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/examples/parallelwave_gan/baker/conf/default.yaml b/examples/parallelwave_gan/baker/conf/default.yaml index 1cb1487..29d6c0f 100644 --- a/examples/parallelwave_gan/baker/conf/default.yaml +++ b/examples/parallelwave_gan/baker/conf/default.yaml @@ -86,7 +86,7 @@ lambda_adv: 4.0 # Loss balancing coefficient. batch_size: 6 # Batch size. batch_max_steps: 25500 # Length of each audio in batch. Make sure dividable by hop_size. pin_memory: true # Whether to pin memory in Pytorch DataLoader. -num_workers: 0 # Number of workers in Pytorch DataLoader. +num_workers: 4 # Number of workers in Pytorch DataLoader. remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. allow_cache: true # Whether to allow cache in dataset. If true, it requires cpu memory. diff --git a/parakeet/training/trainer.py b/parakeet/training/trainer.py index 548df64..ad3661f 100644 --- a/parakeet/training/trainer.py +++ b/parakeet/training/trainer.py @@ -76,6 +76,8 @@ class Trainer(object): else: max_iteration = self.stop_trigger.period + p = tqdm.tqdm() + while True: self.observation = {} # set observation as the report target @@ -84,12 +86,13 @@ class Trainer(object): # updating parameters and state with scope(self.observation): update() - print(self.observation) + p.update() + print(self.observation) - # execute extension when necessary - for name, entry in extensions: - if entry.trigger(self): - entry.extension(self) + # execute extension when necessary + for name, entry in extensions: + if entry.trigger(self): + entry.extension(self) if stop_trigger(self): print("Training Done!")