complete the example for speedyspeech; fix several bugs in training module
This commit is contained in:
parent 6c21d80025
commit 4a7888b8c6
@@ -17,21 +17,27 @@ from parakeet.data.batch import batch_sequences
 
 
 def collate_baker_examples(examples):
-    # fields = ["phones", "tones", "num_phones", "num_frames", "feats"]
+    # fields = ["phones", "tones", "num_phones", "num_frames", "feats", "durations"]
     phones = [np.array(item["phones"], dtype=np.int64) for item in examples]
     tones = [np.array(item["tones"], dtype=np.int64) for item in examples]
     feats = [np.array(item["feats"], dtype=np.float32) for item in examples]
+    durations = [
+        np.array(
+            item["durations"], dtype=np.int64) for item in examples
+    ]
     num_phones = np.array([item["num_phones"] for item in examples])
    num_frames = np.array([item["num_frames"] for item in examples])
 
     phones = batch_sequences(phones)
     tones = batch_sequences(tones)
     feats = batch_sequences(feats)
+    durations = batch_sequences(durations)
     batch = {
         "phones": phones,
         "tones": tones,
         "num_phones": num_phones,
         "num_frames": num_frames,
         "feats": feats,
+        "durations": durations,
     }
     return batch
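
Note: the new `durations` field is batched the same way as the other sequence fields. A minimal sketch of the expected batch, assuming `batch_sequences` zero-pads a list of variable-length arrays to the longest item (the helper's padding behavior is not shown in this commit, and the toy values below are illustrative only):

import numpy as np
from batch_fn import collate_baker_examples  # the function patched above

examples = [
    {"phones": [1, 2], "tones": [0, 1], "num_phones": 2, "num_frames": 7,
     "feats": np.zeros((7, 80), dtype=np.float32), "durations": [3, 4]},
    {"phones": [5], "tones": [2], "num_phones": 1, "num_frames": 4,
     "feats": np.zeros((4, 80), dtype=np.float32), "durations": [4]},
]
batch = collate_baker_examples(examples)
# batch["durations"]: int64 array of shape (2, 2); the shorter item is padded,
# so per-position durations stay aligned with batch["phones"].
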
@@ -20,8 +20,8 @@ trim_hop_length: 512 # Hop size in trimming.(in samples)
 ###########################################################
 #                       DATA SETTING                      #
 ###########################################################
-batch_size: 16
-num_workers: 0
+batch_size: 32
+num_workers: 4
 
 
 
@@ -30,8 +30,8 @@ num_workers: 0
 #                      MODEL SETTING                      #
 ###########################################################
 model:
-    vocab_size: 68
-    tone_size: 6
+    vocab_size: 101  # 99 + 2
+    tone_size: 8  # 6 + 2
     encoder_hidden_size: 128
     encoder_kernel_size: 3
     encoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 1]
@@ -39,7 +39,7 @@ model:
     decoder_hidden_size: 128
     decoder_output_size: 80
     decoder_kernel_size: 3
-    decoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 1]
+    decoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 1]
 
 
 ###########################################################
@@ -47,6 +47,12 @@ model:
 ###########################################################
 
 
+###########################################################
+#                     TRAINING SETTING                    #
+###########################################################
+max_epoch: 300
+num_snapshots: 5
+
 
 ###########################################################
 #                      OTHER SETTING                      #
@@ -17,7 +17,7 @@ from paddle.nn import functional as F
 from paddle.fluid.layers import huber_loss
 
 from parakeet.modules.ssim import ssim
-from parakeet.modules.modules.losses import masked_l1_loss, weighted_mean
+from parakeet.modules.losses import masked_l1_loss, weighted_mean
 from parakeet.training.reporter import report
 from parakeet.training.updaters.standard_updater import StandardUpdater
 from parakeet.training.extensions.evaluator import StandardEvaluator
@@ -27,23 +27,25 @@ from parakeet.models.speedyspeech import SpeedySpeech
 class SpeedySpeechUpdater(StandardUpdater):
     def update_core(self, batch):
         decoded, predicted_durations = self.model(
-            text=batch["phonemes"],
+            text=batch["phones"],
             tones=batch["tones"],
-            plens=batch["phoneme_lenghts"],
-            durations=batch["phoneme_durations"])
+            plens=batch["num_phones"],
+            durations=batch["durations"])
 
-        target_mel = batch["mel"]
+        target_mel = batch["feats"]
         spec_mask = F.sequence_mask(
             batch["num_frames"], dtype=target_mel.dtype).unsqueeze(-1)
         text_mask = F.sequence_mask(
-            batch["phoneme_lenghts"], dtype=predicted_durations.dtype)
+            batch["num_phones"], dtype=predicted_durations.dtype)
 
         # spec loss
         l1_loss = masked_l1_loss(decoded, target_mel, spec_mask)
 
         # duration loss
-        target_durations = batch["phoneme_durations"]
-        target_durations = paddle.clip(target_durations, min=1.0)
+        target_durations = batch["durations"]
+        target_durations = paddle.maximum(
+            target_durations.astype(predicted_durations.dtype),
+            paddle.to_tensor([1.0]))
         duration_loss = weighted_mean(
             huber_loss(
                 predicted_durations, paddle.log(target_durations), delta=1.0),
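
The `paddle.clip` → `paddle.maximum` change above guards the log-duration target. A hedged sketch of the intent, assuming `durations` arrives as an int64 frame-count tensor as built in the collate function: casting to the prediction dtype first gives `paddle.maximum` operands of matching dtype, and flooring at 1 keeps `paddle.log` finite for zero-length phones.

import paddle

durations = paddle.to_tensor([0, 2, 5], dtype="int64")  # frame counts; 0 is possible
pred_dtype = "float32"                                   # dtype of the predicted log-durations

# Cast before comparing so both operands of paddle.maximum share a dtype,
# then floor at 1 so paddle.log never sees zero.
target = paddle.maximum(durations.astype(pred_dtype), paddle.to_tensor([1.0]))
log_target = paddle.log(target)  # finite everywhere
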
@@ -53,13 +55,57 @@ class SpeedySpeechUpdater(StandardUpdater):
         ssim_loss = 1.0 - ssim((decoded * spec_mask).unsqueeze(1),
                                (target_mel * spec_mask).unsqueeze(1))
 
-        loss = l1_loss + duration_loss + ssim_loss
+        loss = l1_loss + ssim_loss + duration_loss
 
         optimizer = self.optimizer
         optimizer.clear_grad()
         loss.backward()
         optimizer.step()
 
+        # import pdb; pdb.set_trace()
+        report("train/loss", float(loss))
         report("train/l1_loss", float(l1_loss))
         report("train/duration_loss", float(duration_loss))
         report("train/ssim_loss", float(ssim_loss))
+
+
+class SpeedySpeechEvaluator(StandardEvaluator):
+    def evaluate_core(self, batch):
+        print("fire")
+        decoded, predicted_durations = self.model(
+            text=batch["phones"],
+            tones=batch["tones"],
+            plens=batch["num_phones"],
+            durations=batch["durations"])
+
+        target_mel = batch["feats"]
+        spec_mask = F.sequence_mask(
+            batch["num_frames"], dtype=target_mel.dtype).unsqueeze(-1)
+        text_mask = F.sequence_mask(
+            batch["num_phones"], dtype=predicted_durations.dtype)
+
+        # spec loss
+        l1_loss = masked_l1_loss(decoded, target_mel, spec_mask)
+
+        # duration loss
+        target_durations = batch["durations"]
+        target_durations = paddle.maximum(
+            target_durations.astype(predicted_durations.dtype),
+            paddle.to_tensor([1.0]))
+        duration_loss = weighted_mean(
+            huber_loss(
+                predicted_durations, paddle.log(target_durations), delta=1.0),
+            text_mask, )
+
+        # ssim loss
+        ssim_loss = 1.0 - ssim((decoded * spec_mask).unsqueeze(1),
+                               (target_mel * spec_mask).unsqueeze(1))
+
+        loss = l1_loss + ssim_loss + duration_loss
+
+        # import pdb; pdb.set_trace()
+
+        report("eval/loss", float(loss))
+        report("eval/l1_loss", float(l1_loss))
+        report("eval/duration_loss", float(duration_loss))
+        report("eval/ssim_loss", float(ssim_loss))
@@ -0,0 +1,131 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import logging
+import argparse
+import dataclasses
+from pathlib import Path
+
+import yaml
+import jsonlines
+import paddle
+import numpy as np
+import soundfile as sf
+from paddle import nn
+from paddle.nn import functional as F
+from paddle import distributed as dist
+from paddle.io import DataLoader, DistributedBatchSampler
+from paddle.optimizer import Adam  # No RAdam
+from paddle.optimizer.lr import StepDecay
+from paddle import DataParallel
+from visualdl import LogWriter
+
+from parakeet.datasets.data_table import DataTable
+from parakeet.models.speedyspeech import SpeedySpeech
+from parakeet.models.parallel_wavegan import PWGGenerator
+
+from parakeet.training.updater import UpdaterBase
+from parakeet.training.trainer import Trainer
+from parakeet.training.reporter import report
+from parakeet.training import extension
+from parakeet.training.extensions.snapshot import Snapshot
+from parakeet.training.extensions.visualizer import VisualDL
+from parakeet.training.seeding import seed_everything
+
+from batch_fn import collate_baker_examples
+from speedyspeech_updater import SpeedySpeechUpdater, SpeedySpeechEvaluator
+from config import get_cfg_default
+
+
+def evaluate(args, config):
+    # dataloader has been too verbose
+    logging.getLogger("DataLoader").disabled = True
+
+    # construct dataset for evaluation
+    with jsonlines.open(args.test_metadata, 'r') as reader:
+        test_metadata = list(reader)
+    test_dataset = DataTable(
+        data=test_metadata, fields=["utt_id", "phones", "tones"])
+
+    model = SpeedySpeech(**config["model"])
+    model.set_state_dict(paddle.load(args.checkpoint)["main_params"])
+    model.eval()
+    vocoder_config = yaml.safe_load(
+        open("../../parallelwave_gan/baker/conf/default.yaml"))
+    vocoder = PWGGenerator(**vocoder_config["generator_params"])
+    vocoder.set_state_dict(
+        paddle.load("../../parallelwave_gan/baker/converted.pdparams"))
+    vocoder.remove_weight_norm()
+    vocoder.eval()
+    # print(model)
+    print("model done!")
+
+    stat = np.load("../../speedyspeech/baker/dump/train/stats.npy")
+    mu, std = stat
+    mu = paddle.to_tensor(mu)
+    std = paddle.to_tensor(std)
+
+    stat2 = np.load("../../parallelwave_gan/baker/dump/train/stats.npy")
+    mu2, std2 = stat2
+    mu2 = paddle.to_tensor(mu2)
+    std2 = paddle.to_tensor(std2)
+
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    for datum in test_dataset:
+        utt_id = datum["utt_id"]
+        phones = paddle.to_tensor(datum["phones"])
+        tones = paddle.to_tensor(datum["tones"])
+
+        mel, _ = model.inference(phones, tones)
+        mel = mel * std + mu
+        mel = (mel - mu2) / std2
+
+        wav = vocoder.inference(mel)
+        sf.write(
+            output_dir / (utt_id + ".wav"), wav.numpy(), samplerate=config.sr)
+        print(f"{utt_id} done!")
+
+
+def main():
+    # parse args and config and redirect to evaluate
+    parser = argparse.ArgumentParser(description="Synthesize with a SpeedySpeech "
+                                     "model trained on the Baker Mandarin TTS dataset.")
+    parser.add_argument(
+        "--config", type=str, help="config file to overwrite default config")
+    parser.add_argument("--checkpoint", type=str, help="checkpoint to load.")
+    parser.add_argument("--test-metadata", type=str, help="test metadata")
+    parser.add_argument("--output-dir", type=str, help="output dir")
+    parser.add_argument(
+        "--device", type=str, default="gpu", help="device type to use")
+    parser.add_argument("--verbose", type=int, default=1, help="verbose")
+
+    args = parser.parse_args()
+    config = get_cfg_default()
+    if args.config:
+        config.merge_from_file(args.config)
+
+    print("========Args========")
+    print(yaml.safe_dump(vars(args)))
+    print("========Config========")
+    print(config)
+
+    evaluate(args, config)
+
+
+if __name__ == "__main__":
+    main()
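
For reference, the double normalization in `evaluate` above reads as follows: the acoustic model emits mels in its own normalized space (the speedyspeech `stats.npy`), while the vocoder was trained on mels normalized with its own statistics, so each mel is denormalized and then renormalized. A minimal restatement of that arithmetic:

def renormalize(mel, mu, std, mu2, std2):
    """Map a mel from the acoustic model's normalized space to the vocoder's.

    mu/std are the speedyspeech training stats, mu2/std2 the parallel
    wavegan training stats, as loaded from the two stats.npy files above.
    """
    mel = mel * std + mu         # undo the acoustic model's normalization
    return (mel - mu2) / std2    # apply the vocoder's normalization
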
@@ -44,6 +44,7 @@ from parakeet.training.extensions.visualizer import VisualDL
 from parakeet.training.seeding import seed_everything
 
 from batch_fn import collate_baker_examples
+from speedyspeech_updater import SpeedySpeechUpdater, SpeedySpeechEvaluator
 from config import get_cfg_default
 
 
@@ -73,26 +74,29 @@ def train_sp(args, config):
         train_metadata = list(reader)
     train_dataset = DataTable(
         data=train_metadata,
-        fields=["phones", "tones", "num_phones", "num_frames", "feats"],
+        fields=[
+            "phones", "tones", "num_phones", "num_frames", "feats", "durations"
+        ],
         converters={"feats": np.load, }, )
     with jsonlines.open(args.dev_metadata, 'r') as reader:
         dev_metadata = list(reader)
     dev_dataset = DataTable(
         data=dev_metadata,
-        fields=["phones", "tones", "num_phones", "num_frames", "feats"],
+        fields=[
+            "phones", "tones", "num_phones", "num_frames", "feats", "durations"
+        ],
         converters={"feats": np.load, }, )
 
     # collate function and dataloader
     train_sampler = DistributedBatchSampler(
         train_dataset,
         batch_size=config.batch_size,
-        shuffle=True,
-        drop_last=True)
-    dev_sampler = DistributedBatchSampler(
-        dev_dataset,
-        batch_size=config.batch_size,
         shuffle=False,
-        drop_last=False)
+        drop_last=True)
+    # dev_sampler = DistributedBatchSampler(dev_dataset,
+    #                                       batch_size=config.batch_size,
+    #                                       shuffle=False,
+    #                                       drop_last=False)
     print("samplers done!")
 
     train_dataloader = DataLoader(
@@ -102,16 +106,43 @@ def train_sp(args, config):
         num_workers=config.num_workers)
     dev_dataloader = DataLoader(
         dev_dataset,
-        batch_sampler=dev_sampler,
+        shuffle=False,
+        drop_last=False,
+        batch_size=config.batch_size,
         collate_fn=collate_baker_examples,
         num_workers=config.num_workers)
     print("dataloaders done!")
 
     # batch = collate_baker_examples([train_dataset[i] for i in range(10)])
-    # batch = collate_baker_examples([dev_dataset[i] for i in range(10)])
+    # # batch = collate_baker_examples([dev_dataset[i] for i in range(10)])
     # import pdb; pdb.set_trace()
     model = SpeedySpeech(**config["model"])
-    print(model)
+    if world_size > 1:
+        model = DataParallel(model)  # TODO, do not use vocab size from config
+    # print(model)
+    print("model done!")
+    optimizer = Adam(
+        0.001,
+        parameters=model.parameters(),
+        grad_clip=nn.ClipGradByGlobalNorm(5.0))
+    print("optimizer done!")
+
+    updater = SpeedySpeechUpdater(
+        model=model, optimizer=optimizer, dataloader=train_dataloader)
+
+    output_dir = Path(args.output_dir)
+    trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
+
+    evaluator = SpeedySpeechEvaluator(model, dev_dataloader)
+
+    if dist.get_rank() == 0:
+        trainer.extend(evaluator, trigger=(1, "epoch"))
+        writer = LogWriter(str(output_dir))
+        trainer.extend(VisualDL(writer), trigger=(1, "iteration"))
+        trainer.extend(
+            Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
+    print(trainer.extensions)
+    trainer.run()
+
+
 def main():
@@ -189,7 +189,7 @@ class SpeedySpeech(nn.Layer):
 
         # decode
         # remove positional encoding here
-        _, t_dec, feature_size = encodings.shpae
+        _, t_dec, feature_size = encodings.shape
         encodings += sinusoid_position_encoding(t_dec, feature_size)
         decoded = self.decoder(encodings)
         return decoded, pred_durations
@@ -211,4 +211,4 @@ class SpeedySpeech(nn.Layer):
         t_dec, feature_size = shape[1], shape[2]
         encodings += sinusoid_position_encoding(t_dec, feature_size)
         decoded = self.decoder(encodings)
-        return decoded, pred_durations
+        return decoded[0], pred_durations[0]
@@ -123,8 +123,6 @@ class Trainer(object):
         update = self.updater.update  # training step
         stop_trigger = self.stop_trigger
 
-        print(self.updater.state)
-
         # display only one progress bar
         max_iteration = None
         if isinstance(stop_trigger, LimitTrigger):
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from copy import deepcopy
+
 
 class IntervalTrigger(object):
     """A Predicate to do something every N cycle."""
@@ -23,9 +25,16 @@ class IntervalTrigger(object):
             raise ValueError("period should be a positive integer.")
         self.period = period
         self.unit = unit
+        self.last_index = None
 
     def __call__(self, trainer):
-        state = trainer.updater.state
-        index = getattr(state, self.unit)
-        fire = index % self.period == 0
+        if self.last_index is None:
+            last_index = getattr(trainer.updater.state, self.unit)
+            self.last_index = last_index
+
+        last_index = self.last_index
+        index = getattr(trainer.updater.state, self.unit)
+        fire = index // self.period != last_index // self.period
+
+        self.last_index = index
         return fire
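
The rewritten predicate above changes "fire when the index is a multiple of the period" into "fire when the index crosses a period boundary". The difference matters when the trigger is queried more than once at the same index, or when the index skips past a multiple. A small self-contained comparison, with plain ints standing in for `trainer.updater.state`:

period = 3
last = 1
for index in [2, 3, 3, 4, 7]:
    old_fire = index % period == 0                 # fires at both queries of 3; misses 7 (6 was skipped)
    new_fire = index // period != last // period   # fires once at 3 and once at 7
    last = index
    print(index, old_fire, new_fire)
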
@@ -106,8 +106,8 @@ class StandardUpdater(UpdaterBase):
         self.update_core(batch)
 
         self.state.iteration += 1
-        if self.updaters_per_epoch is not None:
-            if self.state.iteration % self.updaters_per_epoch == 0:
+        if self.updates_per_epoch is not None:
+            if self.state.iteration % self.updates_per_epoch == 0:
                 self.state.epoch += 1
 
     def update_core(self, batch):
@@ -139,7 +139,7 @@ class StandardUpdater(UpdaterBase):
         self.optimizer.update()
 
     @property
-    def updaters_per_epoch(self):
+    def updates_per_epoch(self):
         """Number of updater per epoch, determined by the length of the
         dataloader."""
         length_of_dataloader = None