Merge pull request #31 from ShenYuhan/add_vdl
add visualdl for parakeet

Commit ce8fad5412
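
The merge swaps the logging backend of Parakeet's example scripts from tensorboardX to VisualDL. Every hunk below follows the same substitution; as a minimal sketch (illustration only, not part of the commit, assuming VisualDL 2.x is installed via pip install visualdl):

    # old backend:
    #   from tensorboardX import SummaryWriter
    #   writer = SummaryWriter(log_dir)
    # new backend: LogWriter covers the scalar/image/histogram calls used here
    from visualdl import LogWriter

    log_dir = "output/log"  # stand-in path for this sketch
    writer = LogWriter(log_dir)
    writer.add_scalar("training_loss/mel_loss", 0.5, 1)  # tag, value, step
    writer.close()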
@@ -21,7 +21,6 @@ import random
 from tqdm import tqdm
 import pickle
 import numpy as np
-from tensorboardX import SummaryWriter

 import paddle.fluid.dygraph as dg
 from paddle import fluid
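In this file the SummaryWriter import is simply dropped rather than replaced, presumably because the module never constructs a writer itself; a few later hunks show the same deletion-only pattern.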
@@ -21,7 +21,7 @@ import random
 from tqdm import tqdm
 import pickle
 import numpy as np
-from tensorboardX import SummaryWriter
+from visualdl import LogWriter

 import paddle.fluid.dygraph as dg
 from paddle import fluid
@@ -179,7 +179,7 @@ if __name__ == "__main__":
     checkpoint_dir = os.path.join(args.output, "checkpoints")
     state_dir = os.path.join(args.output, "states")
     log_dir = os.path.join(args.output, "log")
-    writer = SummaryWriter(log_dir)
+    writer = LogWriter(log_dir)

     if args.checkpoint is not None:
         iteration = io.load_parameters(
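Once a script has written records through LogWriter, the dashboard is served with VisualDL's CLI (8040 is the documented default port; the path mirrors the args.output layout above):

    visualdl --logdir output/log --port 8040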
@@ -15,7 +15,6 @@
 from __future__ import division
 import os
 import soundfile as sf
-from tensorboardX import SummaryWriter
 from collections import OrderedDict

 from paddle import fluid
@@ -11,7 +11,6 @@ from paddle import fluid
 from paddle.fluid import layers as F
 from paddle.fluid import dygraph as dg
 from paddle.fluid.io import DataLoader
-from tensorboardX import SummaryWriter
 import soundfile as sf

 from parakeet.data import SliceDataset, DataCargo, PartialyRandomizedSimilarTimeLengthSampler, SequentialSampler
@@ -10,7 +10,7 @@ from paddle.fluid import layers as F
 from paddle.fluid import initializer as I
 from paddle.fluid import dygraph as dg
 from paddle.fluid.io import DataLoader
-from tensorboardX import SummaryWriter
+from visualdl import LogWriter

 from parakeet.models.deepvoice3 import Encoder, Decoder, PostNet, SpectraNet
 from parakeet.data import SliceDataset, DataCargo, SequentialSampler, RandomSampler
@@ -181,7 +181,7 @@ if __name__ == "__main__":
     global global_step
     global_step = 1
     global writer
-    writer = SummaryWriter()
+    writer = LogWriter()
     print("[Training] tensorboard log and checkpoints are save in {}".format(
         writer.logdir))
     train(args, config)
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from tensorboardX import SummaryWriter
+from visualdl import LogWriter
 from scipy.io.wavfile import write
 from collections import OrderedDict
 import argparse
@@ -78,7 +78,7 @@ def synthesis(text_input, args):
     if not os.path.exists(args.output):
         os.mkdir(args.output)

-    writer = SummaryWriter(os.path.join(args.output, 'log'))
+    writer = LogWriter(os.path.join(args.output, 'log'))

     model = FastSpeech(cfg['network'], num_mels=cfg['audio']['num_mels'])
     # Load parameters.
@@ -22,7 +22,7 @@ from ruamel import yaml
 from tqdm import tqdm
 from matplotlib import cm
 from collections import OrderedDict
-from tensorboardX import SummaryWriter
+from visualdl import LogWriter
 import paddle.fluid.dygraph as dg
 import paddle.fluid.layers as layers
 import paddle.fluid as fluid
@@ -69,8 +69,8 @@ def main(args):
     if not os.path.exists(args.output):
         os.mkdir(args.output)

-    writer = SummaryWriter(os.path.join(args.output,
+    writer = LogWriter(os.path.join(args.output,
                            'log')) if local_rank == 0 else None

     model = FastSpeech(cfg['network'], num_mels=cfg['audio']['num_mels'])
     model.train()
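In the multi-process branches only rank 0 gets a writer and the other ranks hold None, so every logging call must stay behind a rank check. A short sketch of that guard (the environment-variable lookup for local_rank is an assumption of this sketch, not taken from the diff):

    import os

    from visualdl import LogWriter

    # hypothetical lookup for this sketch; the real scripts obtain local_rank elsewhere
    local_rank = int(os.environ.get("PADDLE_TRAINER_ID", "0"))
    writer = LogWriter(os.path.join("output", "log")) if local_rank == 0 else None

    if writer is not None:  # non-zero ranks skip logging entirely
        writer.add_scalar("training_loss/mel_loss", 0.5, 1)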
@@ -16,7 +16,7 @@ from scipy.io.wavfile import write
 import numpy as np
 from tqdm import tqdm
 from matplotlib import cm
-from tensorboardX import SummaryWriter
+from visualdl import LogWriter
 from ruamel import yaml
 from pathlib import Path
 import argparse
@@ -81,7 +81,7 @@ def synthesis(text_input, args):
     if not os.path.exists(args.output):
         os.mkdir(args.output)

-    writer = SummaryWriter(os.path.join(args.output, 'log'))
+    writer = LogWriter(os.path.join(args.output, 'log'))

     fluid.enable_dygraph(place)
     with fluid.unique_name.guard():
@@ -121,8 +121,7 @@ def synthesis(text_input, args):
             writer.add_image(
                 'Attention_%d_0' % global_step,
                 x,
-                i * 4 + j,
-                dataformats="HWC")
+                i * 4 + j)

     if args.vocoder == 'griffin-lim':
         #synthesis use griffin-lim
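The dataformats="HWC" keyword disappears from every add_image call, presumably because the targeted VisualDL release takes only HWC-shaped uint8 arrays and has no such parameter. A sketch of the call as used above, with a random matrix standing in for an attention map (the viridis colormapping mirrors the surrounding code):

    import numpy as np
    from matplotlib import cm
    from visualdl import LogWriter

    writer = LogWriter("output/log")
    prob = np.random.rand(64, 64)            # stand-in attention weights
    x = np.uint8(cm.viridis(prob) * 255)     # HWC uint8, as LogWriter expects
    writer.add_image("Attention_0_0", x, 0)  # tag, image array, step
    writer.close()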
@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 from tqdm import tqdm
-from tensorboardX import SummaryWriter
+from visualdl import LogWriter
 from collections import OrderedDict
 import argparse
 from pprint import pprint
@@ -62,8 +62,8 @@ def main(args):
     if not os.path.exists(args.output):
         os.mkdir(args.output)

-    writer = SummaryWriter(os.path.join(args.output,
+    writer = LogWriter(os.path.join(args.output,
                            'log')) if local_rank == 0 else None

     fluid.enable_dygraph(place)
     network_cfg = cfg['network']
@@ -131,23 +131,28 @@ def main(args):
             loss = loss + stop_loss

             if local_rank == 0:
-                writer.add_scalars('training_loss', {
-                    'mel_loss': mel_loss.numpy(),
-                    'post_mel_loss': post_mel_loss.numpy()
-                }, global_step)
+                writer.add_scalar('training_loss/mel_loss',
+                                  mel_loss.numpy(),
+                                  global_step)
+                writer.add_scalar('training_loss/post_mel_loss',
+                                  post_mel_loss.numpy(),
+                                  global_step)
                 writer.add_scalar('stop_loss', stop_loss.numpy(), global_step)

                 if parallel:
-                    writer.add_scalars('alphas', {
-                        'encoder_alpha': model._layers.encoder.alpha.numpy(),
-                        'decoder_alpha': model._layers.decoder.alpha.numpy(),
-                    }, global_step)
+                    writer.add_scalar('alphas/encoder_alpha',
+                                      model._layers.encoder.alpha.numpy(),
+                                      global_step)
+                    writer.add_scalar('alphas/decoder_alpha',
+                                      model._layers.decoder.alpha.numpy(),
+                                      global_step)
                 else:
-                    writer.add_scalars('alphas', {
-                        'encoder_alpha': model.encoder.alpha.numpy(),
-                        'decoder_alpha': model.decoder.alpha.numpy(),
-                    }, global_step)
+                    writer.add_scalar('alphas/encoder_alpha',
+                                      model.encoder.alpha.numpy(),
+                                      global_step)
+                    writer.add_scalar('alphas/decoder_alpha',
+                                      model.decoder.alpha.numpy(),
+                                      global_step)

                 writer.add_scalar('learning_rate',
                                   optimizer._learning_rate.step().numpy(),
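LogWriter apparently lacks tensorboardX's add_scalars, so the commit splits each grouped dictionary into one add_scalar call per curve, folding the old group name into a slash-separated tag that still groups the curves in the UI. The pattern, as a minimal sketch:

    from visualdl import LogWriter

    writer = LogWriter("output/log")
    step = 1
    mel_loss, post_mel_loss = 0.9, 0.7  # stand-ins for the loss.numpy() values

    # one tag per curve; the shared 'training_loss/' prefix keeps them together
    writer.add_scalar("training_loss/mel_loss", mel_loss, step)
    writer.add_scalar("training_loss/post_mel_loss", post_mel_loss, step)
    writer.close()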
@@ -162,8 +167,7 @@ def main(args):
                         writer.add_image(
                             'Attention_%d_0' % global_step,
                             x,
-                            i * 4 + j,
-                            dataformats="HWC")
+                            i * 4 + j)

                 for i, prob in enumerate(attn_enc):
                     for j in range(cfg['network']['encoder_num_head']):
@@ -173,8 +177,7 @@ def main(args):
                         writer.add_image(
                             'Attention_enc_%d_0' % global_step,
                             x,
-                            i * 4 + j,
-                            dataformats="HWC")
+                            i * 4 + j)

                 for i, prob in enumerate(attn_dec):
                     for j in range(cfg['network']['decoder_num_head']):
@@ -184,8 +187,7 @@ def main(args):
                         writer.add_image(
                             'Attention_dec_%d_0' % global_step,
                             x,
-                            i * 4 + j,
-                            dataformats="HWC")
+                            i * 4 + j)

             if parallel:
                 loss = model.scale_loss(loss)
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from tensorboardX import SummaryWriter
+from visualdl import LogWriter
 import os
 from tqdm import tqdm
 from pathlib import Path
@@ -60,8 +60,8 @@ def main(args):
     if not os.path.exists(args.output):
         os.mkdir(args.output)

-    writer = SummaryWriter(os.path.join(args.output,
+    writer = LogWriter(os.path.join(args.output,
                            'log')) if local_rank == 0 else None

     fluid.enable_dygraph(place)
     model = Vocoder(cfg['train']['batch_size'], cfg['vocoder']['hidden_size'],
@@ -121,7 +121,7 @@ def main(args):
             model.clear_gradients()

             if local_rank == 0:
-                writer.add_scalars('training_loss', {'loss': loss.numpy(), },
+                writer.add_scalar('training_loss/loss', loss.numpy(),
                                    global_step)

             # save checkpoint
@@ -22,7 +22,8 @@ import argparse
 import numpy as np
 import paddle.fluid.dygraph as dg
 from paddle import fluid
-from tensorboardX import SummaryWriter
+from visualdl import LogWriter
+

 import utils
 from parakeet.utils import io
@@ -78,8 +79,8 @@ def train(config):
         os.makedirs(checkpoint_dir)

     # Create tensorboard logger.
-    tb = SummaryWriter(os.path.join(run_dir, "logs")) \
+    vdl = LogWriter(os.path.join(run_dir, "logs")) \
         if rank == 0 else None

     # Configurate device
     place = fluid.CUDAPlace(rank) if use_gpu else fluid.CPUPlace()
@@ -94,7 +95,7 @@ def train(config):
     print("Random Seed: ", seed)

     # Build model.
-    model = WaveFlow(config, checkpoint_dir, parallel, rank, nranks, tb)
+    model = WaveFlow(config, checkpoint_dir, parallel, rank, nranks, vdl)
     iteration = model.build()

     while iteration < config.max_iterations:
@@ -113,7 +114,7 @@ def train(config):

     # Close TensorBoard.
     if rank == 0:
-        tb.close()
+        vdl.close()


 if __name__ == "__main__":
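As with SummaryWriter, records are buffered, so the writer is closed on the logging rank to flush them. VisualDL's LogWriter can also be used as a context manager, which makes the flush automatic; a small sketch:

    from visualdl import LogWriter

    # the with-block closes (and flushes) the log even if training raises
    with LogWriter("runs/logs") as writer:
        writer.add_scalar("Train-Loss-Rank-0", 1.23, 0)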
@@ -42,7 +42,7 @@ class WaveFlow():
         rank (int, optional): the rank of the process in a multi-process
             scenario. Defaults to 0.
         nranks (int, optional): the total number of processes. Defaults to 1.
-        tb_logger (obj, optional): logger to visualize metrics.
+        vdl_logger (obj, optional): logger to visualize metrics.
             Defaults to None.

     Returns:
@@ -55,13 +55,13 @@ class WaveFlow():
                  parallel=False,
                  rank=0,
                  nranks=1,
-                 tb_logger=None):
+                 vdl_logger=None):
         self.config = config
         self.checkpoint_dir = checkpoint_dir
         self.parallel = parallel
         self.rank = rank
         self.nranks = nranks
-        self.tb_logger = tb_logger
+        self.vdl_logger = vdl_logger
         self.dtype = "float16" if config.use_fp16 else "float32"

     def build(self, training=True):
@@ -161,8 +161,8 @@ class WaveFlow():
                 load_time - start_time, graph_time - load_time)
             print(log)

-        tb = self.tb_logger
-        tb.add_scalar("Train-Loss-Rank-0", loss_val, iteration)
+        vdl_writer = self.vdl_logger
+        vdl_writer.add_scalar("Train-Loss-Rank-0", loss_val, iteration)

     @dg.no_grad
     def valid_step(self, iteration):
@@ -175,7 +175,7 @@ class WaveFlow():
             None
         """
         self.waveflow.eval()
-        tb = self.tb_logger
+        vdl_writer = self.vdl_logger

         total_loss = []
         sample_audios = []
@@ -188,10 +188,12 @@ class WaveFlow():

             # Visualize latent z and scale log_s.
             if self.rank == 0 and i == 0:
-                tb.add_histogram("Valid-Latent_z", valid_z.numpy(), iteration)
+                vdl_writer.add_histogram("Valid-Latent_z", valid_z.numpy(),
+                                         iteration)
                 for j, valid_log_s in enumerate(valid_log_s_list):
                     hist_name = "Valid-{}th-Flow-Log_s".format(j)
-                    tb.add_histogram(hist_name, valid_log_s.numpy(), iteration)
+                    vdl_writer.add_histogram(hist_name, valid_log_s.numpy(),
+                                             iteration)

             valid_loss = self.criterion(valid_outputs)
             total_loss.append(float(valid_loss.numpy()))
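A sketch of the renamed histogram calls, with random data standing in for the latent z and per-flow log_s tensors (shapes here are arbitrary):

    import numpy as np
    from visualdl import LogWriter

    vdl_writer = LogWriter("output/log")
    valid_z = np.random.randn(8, 1600)  # stand-in for valid_z.numpy()
    vdl_writer.add_histogram("Valid-Latent_z", valid_z.flatten(), 0)
    for j in range(2):                  # stand-in for valid_log_s_list
        vdl_writer.add_histogram("Valid-{}th-Flow-Log_s".format(j),
                                 np.random.randn(1000), 0)
    vdl_writer.close()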
@@ -202,7 +204,7 @@ class WaveFlow():
         log = "Test | Rank: {} AvgLoss: {:<8.3f} Time {:<8.3f}".format(
             self.rank, loss_val, total_time)
         print(log)
-        tb.add_scalar("Valid-Avg-Loss", loss_val, iteration)
+        vdl_writer.add_scalar("Valid-Avg-Loss", loss_val, iteration)

     @dg.no_grad
     def infer(self, iteration):
@@ -17,7 +17,6 @@ import os
 import ruamel.yaml
 import argparse
 from tqdm import tqdm
-from tensorboardX import SummaryWriter
 from paddle import fluid
 fluid.require_version('1.8.0')
 import paddle.fluid.dygraph as dg
@@ -17,7 +17,7 @@ import os
 import ruamel.yaml
 import argparse
 import tqdm
-from tensorboardX import SummaryWriter
+from visualdl import LogWriter
 from paddle import fluid
 fluid.require_version('1.8.0')
 import paddle.fluid.dygraph as dg
|
@ -154,7 +154,7 @@ if __name__ == "__main__":
|
||||||
eval_interval = train_config["eval_interval"]
|
eval_interval = train_config["eval_interval"]
|
||||||
checkpoint_dir = os.path.join(args.output, "checkpoints")
|
checkpoint_dir = os.path.join(args.output, "checkpoints")
|
||||||
log_dir = os.path.join(args.output, "log")
|
log_dir = os.path.join(args.output, "log")
|
||||||
writer = SummaryWriter(log_dir)
|
writer = LogWriter(log_dir)
|
||||||
|
|
||||||
# load parameters and optimizer, and update iterations done so far
|
# load parameters and optimizer, and update iterations done so far
|
||||||
if args.checkpoint is not None:
|
if args.checkpoint is not None:
|
||||||
|
|