replace add_scalar to add_scalars
This commit is contained in:
parent
bf6d9ef06f
commit
2933eb7e57
|
@ -29,41 +29,6 @@ from parakeet.models.transformer_tts import TransformerTTS
|
|||
from parakeet.utils import io
|
||||
|
||||
|
||||
def add_scalars(self, main_tag, tag_scalar_dict, step, walltime=None):
|
||||
"""Add scalars to vdl record file.
|
||||
Args:
|
||||
main_tag (string): The parent name for the tags
|
||||
tag_scalar_dict (dict): Key-value pair storing the tag and corresponding values
|
||||
step (int): Step of scalars
|
||||
walltime (float): Wall time of scalars.
|
||||
Example:
|
||||
for index in range(1, 101):
|
||||
writer.add_scalar(tag="train/loss", value=index*0.2, step=index)
|
||||
writer.add_scalar(tag="train/lr", value=index*0.5, step=index)
|
||||
"""
|
||||
import time
|
||||
from visualdl.writer.record_writer import RecordFileWriter
|
||||
from visualdl.component.base_component import scalar
|
||||
|
||||
fw_logdir = self.logdir
|
||||
walltime = round(time.time()) if walltime is None else walltime
|
||||
for tag, value in tag_scalar_dict.items():
|
||||
tag = os.path.join(fw_logdir, main_tag, tag)
|
||||
if '%' in tag:
|
||||
raise RuntimeError("% can't appear in tag!")
|
||||
if tag in self._all_writers:
|
||||
fw = self._all_writers[tag]
|
||||
else:
|
||||
fw = RecordFileWriter(
|
||||
logdir=tag,
|
||||
max_queue_size=self._max_queue,
|
||||
flush_secs=self._flush_secs,
|
||||
filename_suffix=self._filename_suffix)
|
||||
self._all_writers.update({tag: fw})
|
||||
fw.add_record(
|
||||
scalar(tag=main_tag, value=value, step=step, walltime=walltime))
|
||||
|
||||
|
||||
def add_config_options_to_parser(parser):
|
||||
parser.add_argument("--config", type=str, help="path of the config file")
|
||||
parser.add_argument("--use_gpu", type=int, default=0, help="device to use")
|
||||
|
@ -99,7 +64,6 @@ def main(args):
|
|||
|
||||
writer = LogWriter(os.path.join(args.output,
|
||||
'log')) if local_rank == 0 else None
|
||||
writer.add_scalars = add_scalars
|
||||
|
||||
fluid.enable_dygraph(place)
|
||||
network_cfg = cfg['network']
|
||||
|
@ -167,23 +131,28 @@ def main(args):
|
|||
loss = loss + stop_loss
|
||||
|
||||
if local_rank == 0:
|
||||
writer.add_scalars('training_loss', {
|
||||
'mel_loss': mel_loss.numpy(),
|
||||
'post_mel_loss': post_mel_loss.numpy()
|
||||
}, global_step)
|
||||
|
||||
writer.add_scalar('training_loss/mel_loss',
|
||||
mel_loss.numpy(),
|
||||
global_step)
|
||||
writer.add_scalar('training_loss/post_mel_loss',
|
||||
post_mel_loss.numpy(),
|
||||
global_step)
|
||||
writer.add_scalar('stop_loss', stop_loss.numpy(), global_step)
|
||||
|
||||
if parallel:
|
||||
writer.add_scalars('alphas', {
|
||||
'encoder_alpha': model._layers.encoder.alpha.numpy(),
|
||||
'decoder_alpha': model._layers.decoder.alpha.numpy(),
|
||||
}, global_step)
|
||||
writer.add_scalar('alphas/encoder_alpha',
|
||||
model._layers.encoder.alpha.numpy(),
|
||||
global_step)
|
||||
writer.add_scalar('alphas/decoder_alpha',
|
||||
model._layers.decoder.alpha.numpy(),
|
||||
global_step)
|
||||
else:
|
||||
writer.add_scalars('alphas', {
|
||||
'encoder_alpha': model.encoder.alpha.numpy(),
|
||||
'decoder_alpha': model.decoder.alpha.numpy(),
|
||||
}, global_step)
|
||||
writer.add_scalar('alphas/encoder_alpha',
|
||||
model.encoder.alpha.numpy(),
|
||||
global_step)
|
||||
writer.add_scalar('alphas/decoder_alpha',
|
||||
model.decoder.alpha.numpy(),
|
||||
global_step)
|
||||
|
||||
writer.add_scalar('learning_rate',
|
||||
optimizer._learning_rate.step().numpy(),
|
||||
|
|
|
@ -121,7 +121,7 @@ def main(args):
|
|||
model.clear_gradients()
|
||||
|
||||
if local_rank == 0:
|
||||
writer.add_scalars('training_loss', {'loss': loss.numpy(), },
|
||||
writer.add_scalar('training_loss/loss', loss.numpy(),
|
||||
global_step)
|
||||
|
||||
# save checkpoint
|
||||
|
|
Loading…
Reference in New Issue