add visualdl for parakeet
This commit is contained in:
parent
e58e927c5e
commit
bf6d9ef06f
|
@ -21,7 +21,6 @@ import random
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import pickle
|
import pickle
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tensorboardX import SummaryWriter
|
|
||||||
|
|
||||||
import paddle.fluid.dygraph as dg
|
import paddle.fluid.dygraph as dg
|
||||||
from paddle import fluid
|
from paddle import fluid
|
||||||
|
|
|
@ -21,7 +21,7 @@ import random
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import pickle
|
import pickle
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tensorboardX import SummaryWriter
|
from visualdl import LogWriter
|
||||||
|
|
||||||
import paddle.fluid.dygraph as dg
|
import paddle.fluid.dygraph as dg
|
||||||
from paddle import fluid
|
from paddle import fluid
|
||||||
|
@ -179,7 +179,7 @@ if __name__ == "__main__":
|
||||||
checkpoint_dir = os.path.join(args.output, "checkpoints")
|
checkpoint_dir = os.path.join(args.output, "checkpoints")
|
||||||
state_dir = os.path.join(args.output, "states")
|
state_dir = os.path.join(args.output, "states")
|
||||||
log_dir = os.path.join(args.output, "log")
|
log_dir = os.path.join(args.output, "log")
|
||||||
writer = SummaryWriter(log_dir)
|
writer = LogWriter(log_dir)
|
||||||
|
|
||||||
if args.checkpoint is not None:
|
if args.checkpoint is not None:
|
||||||
iteration = io.load_parameters(
|
iteration = io.load_parameters(
|
||||||
|
|
|
@ -15,7 +15,6 @@
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
import os
|
import os
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
from tensorboardX import SummaryWriter
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
from paddle import fluid
|
from paddle import fluid
|
||||||
|
|
|
@ -11,7 +11,6 @@ from paddle import fluid
|
||||||
from paddle.fluid import layers as F
|
from paddle.fluid import layers as F
|
||||||
from paddle.fluid import dygraph as dg
|
from paddle.fluid import dygraph as dg
|
||||||
from paddle.fluid.io import DataLoader
|
from paddle.fluid.io import DataLoader
|
||||||
from tensorboardX import SummaryWriter
|
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
|
||||||
from parakeet.data import SliceDataset, DataCargo, PartialyRandomizedSimilarTimeLengthSampler, SequentialSampler
|
from parakeet.data import SliceDataset, DataCargo, PartialyRandomizedSimilarTimeLengthSampler, SequentialSampler
|
||||||
|
|
|
@ -9,7 +9,7 @@ from paddle import fluid
|
||||||
from paddle.fluid import layers as F
|
from paddle.fluid import layers as F
|
||||||
from paddle.fluid import dygraph as dg
|
from paddle.fluid import dygraph as dg
|
||||||
from paddle.fluid.io import DataLoader
|
from paddle.fluid.io import DataLoader
|
||||||
from tensorboardX import SummaryWriter
|
from visualdl import LogWriter
|
||||||
|
|
||||||
from parakeet.models.deepvoice3 import Encoder, Decoder, PostNet, SpectraNet
|
from parakeet.models.deepvoice3 import Encoder, Decoder, PostNet, SpectraNet
|
||||||
from parakeet.data import SliceDataset, DataCargo, PartialyRandomizedSimilarTimeLengthSampler, SequentialSampler
|
from parakeet.data import SliceDataset, DataCargo, PartialyRandomizedSimilarTimeLengthSampler, SequentialSampler
|
||||||
|
@ -181,7 +181,7 @@ if __name__ == "__main__":
|
||||||
global global_step
|
global global_step
|
||||||
global_step = 1
|
global_step = 1
|
||||||
global writer
|
global writer
|
||||||
writer = SummaryWriter()
|
writer = LogWriter()
|
||||||
print("[Training] tensorboard log and checkpoints are save in {}".format(
|
print("[Training] tensorboard log and checkpoints are save in {}".format(
|
||||||
writer.logdir))
|
writer.logdir))
|
||||||
train(args, config)
|
train(args, config)
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import os
|
import os
|
||||||
from tensorboardX import SummaryWriter
|
from visualdl import LogWriter
|
||||||
from scipy.io.wavfile import write
|
from scipy.io.wavfile import write
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import argparse
|
import argparse
|
||||||
|
@ -78,7 +78,7 @@ def synthesis(text_input, args):
|
||||||
if not os.path.exists(args.output):
|
if not os.path.exists(args.output):
|
||||||
os.mkdir(args.output)
|
os.mkdir(args.output)
|
||||||
|
|
||||||
writer = SummaryWriter(os.path.join(args.output, 'log'))
|
writer = LogWriter(os.path.join(args.output, 'log'))
|
||||||
|
|
||||||
model = FastSpeech(cfg['network'], num_mels=cfg['audio']['num_mels'])
|
model = FastSpeech(cfg['network'], num_mels=cfg['audio']['num_mels'])
|
||||||
# Load parameters.
|
# Load parameters.
|
||||||
|
|
|
@ -22,7 +22,7 @@ from ruamel import yaml
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from matplotlib import cm
|
from matplotlib import cm
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from tensorboardX import SummaryWriter
|
from visualdl import LogWriter
|
||||||
import paddle.fluid.dygraph as dg
|
import paddle.fluid.dygraph as dg
|
||||||
import paddle.fluid.layers as layers
|
import paddle.fluid.layers as layers
|
||||||
import paddle.fluid as fluid
|
import paddle.fluid as fluid
|
||||||
|
@ -69,8 +69,8 @@ def main(args):
|
||||||
if not os.path.exists(args.output):
|
if not os.path.exists(args.output):
|
||||||
os.mkdir(args.output)
|
os.mkdir(args.output)
|
||||||
|
|
||||||
writer = SummaryWriter(os.path.join(args.output,
|
writer = LogWriter(os.path.join(args.output,
|
||||||
'log')) if local_rank == 0 else None
|
'log')) if local_rank == 0 else None
|
||||||
|
|
||||||
model = FastSpeech(cfg['network'], num_mels=cfg['audio']['num_mels'])
|
model = FastSpeech(cfg['network'], num_mels=cfg['audio']['num_mels'])
|
||||||
model.train()
|
model.train()
|
||||||
|
|
|
@ -16,7 +16,7 @@ from scipy.io.wavfile import write
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from matplotlib import cm
|
from matplotlib import cm
|
||||||
from tensorboardX import SummaryWriter
|
from visualdl import LogWriter
|
||||||
from ruamel import yaml
|
from ruamel import yaml
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import argparse
|
import argparse
|
||||||
|
@ -81,7 +81,7 @@ def synthesis(text_input, args):
|
||||||
if not os.path.exists(args.output):
|
if not os.path.exists(args.output):
|
||||||
os.mkdir(args.output)
|
os.mkdir(args.output)
|
||||||
|
|
||||||
writer = SummaryWriter(os.path.join(args.output, 'log'))
|
writer = LogWriter(os.path.join(args.output, 'log'))
|
||||||
|
|
||||||
fluid.enable_dygraph(place)
|
fluid.enable_dygraph(place)
|
||||||
with fluid.unique_name.guard():
|
with fluid.unique_name.guard():
|
||||||
|
@ -121,8 +121,7 @@ def synthesis(text_input, args):
|
||||||
writer.add_image(
|
writer.add_image(
|
||||||
'Attention_%d_0' % global_step,
|
'Attention_%d_0' % global_step,
|
||||||
x,
|
x,
|
||||||
i * 4 + j,
|
i * 4 + j)
|
||||||
dataformats="HWC")
|
|
||||||
|
|
||||||
if args.vocoder == 'griffin-lim':
|
if args.vocoder == 'griffin-lim':
|
||||||
#synthesis use griffin-lim
|
#synthesis use griffin-lim
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import os
|
import os
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from tensorboardX import SummaryWriter
|
from visualdl import LogWriter
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import argparse
|
import argparse
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
|
@ -29,6 +29,41 @@ from parakeet.models.transformer_tts import TransformerTTS
|
||||||
from parakeet.utils import io
|
from parakeet.utils import io
|
||||||
|
|
||||||
|
|
||||||
|
def add_scalars(self, main_tag, tag_scalar_dict, step, walltime=None):
|
||||||
|
"""Add scalars to vdl record file.
|
||||||
|
Args:
|
||||||
|
main_tag (string): The parent name for the tags
|
||||||
|
tag_scalar_dict (dict): Key-value pair storing the tag and corresponding values
|
||||||
|
step (int): Step of scalars
|
||||||
|
walltime (float): Wall time of scalars.
|
||||||
|
Example:
|
||||||
|
for index in range(1, 101):
|
||||||
|
writer.add_scalar(tag="train/loss", value=index*0.2, step=index)
|
||||||
|
writer.add_scalar(tag="train/lr", value=index*0.5, step=index)
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
from visualdl.writer.record_writer import RecordFileWriter
|
||||||
|
from visualdl.component.base_component import scalar
|
||||||
|
|
||||||
|
fw_logdir = self.logdir
|
||||||
|
walltime = round(time.time()) if walltime is None else walltime
|
||||||
|
for tag, value in tag_scalar_dict.items():
|
||||||
|
tag = os.path.join(fw_logdir, main_tag, tag)
|
||||||
|
if '%' in tag:
|
||||||
|
raise RuntimeError("% can't appear in tag!")
|
||||||
|
if tag in self._all_writers:
|
||||||
|
fw = self._all_writers[tag]
|
||||||
|
else:
|
||||||
|
fw = RecordFileWriter(
|
||||||
|
logdir=tag,
|
||||||
|
max_queue_size=self._max_queue,
|
||||||
|
flush_secs=self._flush_secs,
|
||||||
|
filename_suffix=self._filename_suffix)
|
||||||
|
self._all_writers.update({tag: fw})
|
||||||
|
fw.add_record(
|
||||||
|
scalar(tag=main_tag, value=value, step=step, walltime=walltime))
|
||||||
|
|
||||||
|
|
||||||
def add_config_options_to_parser(parser):
|
def add_config_options_to_parser(parser):
|
||||||
parser.add_argument("--config", type=str, help="path of the config file")
|
parser.add_argument("--config", type=str, help="path of the config file")
|
||||||
parser.add_argument("--use_gpu", type=int, default=0, help="device to use")
|
parser.add_argument("--use_gpu", type=int, default=0, help="device to use")
|
||||||
|
@ -62,8 +97,9 @@ def main(args):
|
||||||
if not os.path.exists(args.output):
|
if not os.path.exists(args.output):
|
||||||
os.mkdir(args.output)
|
os.mkdir(args.output)
|
||||||
|
|
||||||
writer = SummaryWriter(os.path.join(args.output,
|
writer = LogWriter(os.path.join(args.output,
|
||||||
'log')) if local_rank == 0 else None
|
'log')) if local_rank == 0 else None
|
||||||
|
writer.add_scalars = add_scalars
|
||||||
|
|
||||||
fluid.enable_dygraph(place)
|
fluid.enable_dygraph(place)
|
||||||
network_cfg = cfg['network']
|
network_cfg = cfg['network']
|
||||||
|
@ -162,8 +198,7 @@ def main(args):
|
||||||
writer.add_image(
|
writer.add_image(
|
||||||
'Attention_%d_0' % global_step,
|
'Attention_%d_0' % global_step,
|
||||||
x,
|
x,
|
||||||
i * 4 + j,
|
i * 4 + j)
|
||||||
dataformats="HWC")
|
|
||||||
|
|
||||||
for i, prob in enumerate(attn_enc):
|
for i, prob in enumerate(attn_enc):
|
||||||
for j in range(cfg['network']['encoder_num_head']):
|
for j in range(cfg['network']['encoder_num_head']):
|
||||||
|
@ -173,8 +208,7 @@ def main(args):
|
||||||
writer.add_image(
|
writer.add_image(
|
||||||
'Attention_enc_%d_0' % global_step,
|
'Attention_enc_%d_0' % global_step,
|
||||||
x,
|
x,
|
||||||
i * 4 + j,
|
i * 4 + j)
|
||||||
dataformats="HWC")
|
|
||||||
|
|
||||||
for i, prob in enumerate(attn_dec):
|
for i, prob in enumerate(attn_dec):
|
||||||
for j in range(cfg['network']['decoder_num_head']):
|
for j in range(cfg['network']['decoder_num_head']):
|
||||||
|
@ -184,8 +218,7 @@ def main(args):
|
||||||
writer.add_image(
|
writer.add_image(
|
||||||
'Attention_dec_%d_0' % global_step,
|
'Attention_dec_%d_0' % global_step,
|
||||||
x,
|
x,
|
||||||
i * 4 + j,
|
i * 4 + j)
|
||||||
dataformats="HWC")
|
|
||||||
|
|
||||||
if parallel:
|
if parallel:
|
||||||
loss = model.scale_loss(loss)
|
loss = model.scale_loss(loss)
|
||||||
|
|
|
@ -11,7 +11,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from tensorboardX import SummaryWriter
|
from visualdl import LogWriter
|
||||||
import os
|
import os
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -60,8 +60,8 @@ def main(args):
|
||||||
if not os.path.exists(args.output):
|
if not os.path.exists(args.output):
|
||||||
os.mkdir(args.output)
|
os.mkdir(args.output)
|
||||||
|
|
||||||
writer = SummaryWriter(os.path.join(args.output,
|
writer = LogWriter(os.path.join(args.output,
|
||||||
'log')) if local_rank == 0 else None
|
'log')) if local_rank == 0 else None
|
||||||
|
|
||||||
fluid.enable_dygraph(place)
|
fluid.enable_dygraph(place)
|
||||||
model = Vocoder(cfg['train']['batch_size'], cfg['vocoder']['hidden_size'],
|
model = Vocoder(cfg['train']['batch_size'], cfg['vocoder']['hidden_size'],
|
||||||
|
|
|
@ -22,7 +22,8 @@ import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import paddle.fluid.dygraph as dg
|
import paddle.fluid.dygraph as dg
|
||||||
from paddle import fluid
|
from paddle import fluid
|
||||||
from tensorboardX import SummaryWriter
|
from visualdl import LogWriter
|
||||||
|
|
||||||
|
|
||||||
import utils
|
import utils
|
||||||
from parakeet.utils import io
|
from parakeet.utils import io
|
||||||
|
@ -78,8 +79,8 @@ def train(config):
|
||||||
os.makedirs(checkpoint_dir)
|
os.makedirs(checkpoint_dir)
|
||||||
|
|
||||||
# Create tensorboard logger.
|
# Create tensorboard logger.
|
||||||
tb = SummaryWriter(os.path.join(run_dir, "logs")) \
|
vdl = LogWriter(os.path.join(run_dir, "logs")) \
|
||||||
if rank == 0 else None
|
if rank == 0 else None
|
||||||
|
|
||||||
# Configurate device
|
# Configurate device
|
||||||
place = fluid.CUDAPlace(rank) if use_gpu else fluid.CPUPlace()
|
place = fluid.CUDAPlace(rank) if use_gpu else fluid.CPUPlace()
|
||||||
|
@ -94,7 +95,7 @@ def train(config):
|
||||||
print("Random Seed: ", seed)
|
print("Random Seed: ", seed)
|
||||||
|
|
||||||
# Build model.
|
# Build model.
|
||||||
model = WaveFlow(config, checkpoint_dir, parallel, rank, nranks, tb)
|
model = WaveFlow(config, checkpoint_dir, parallel, rank, nranks, vdl)
|
||||||
iteration = model.build()
|
iteration = model.build()
|
||||||
|
|
||||||
while iteration < config.max_iterations:
|
while iteration < config.max_iterations:
|
||||||
|
@ -113,7 +114,7 @@ def train(config):
|
||||||
|
|
||||||
# Close TensorBoard.
|
# Close TensorBoard.
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
tb.close()
|
vdl.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -42,7 +42,7 @@ class WaveFlow():
|
||||||
rank (int, optional): the rank of the process in a multi-process
|
rank (int, optional): the rank of the process in a multi-process
|
||||||
scenario. Defaults to 0.
|
scenario. Defaults to 0.
|
||||||
nranks (int, optional): the total number of processes. Defaults to 1.
|
nranks (int, optional): the total number of processes. Defaults to 1.
|
||||||
tb_logger (obj, optional): logger to visualize metrics.
|
vdl_logger (obj, optional): logger to visualize metrics.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -55,13 +55,13 @@ class WaveFlow():
|
||||||
parallel=False,
|
parallel=False,
|
||||||
rank=0,
|
rank=0,
|
||||||
nranks=1,
|
nranks=1,
|
||||||
tb_logger=None):
|
vdl_logger=None):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.checkpoint_dir = checkpoint_dir
|
self.checkpoint_dir = checkpoint_dir
|
||||||
self.parallel = parallel
|
self.parallel = parallel
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
self.nranks = nranks
|
self.nranks = nranks
|
||||||
self.tb_logger = tb_logger
|
self.vdl_logger = vdl_logger
|
||||||
self.dtype = "float16" if config.use_fp16 else "float32"
|
self.dtype = "float16" if config.use_fp16 else "float32"
|
||||||
|
|
||||||
def build(self, training=True):
|
def build(self, training=True):
|
||||||
|
@ -161,8 +161,8 @@ class WaveFlow():
|
||||||
load_time - start_time, graph_time - load_time)
|
load_time - start_time, graph_time - load_time)
|
||||||
print(log)
|
print(log)
|
||||||
|
|
||||||
tb = self.tb_logger
|
vdl_writer = self.vdl_logger
|
||||||
tb.add_scalar("Train-Loss-Rank-0", loss_val, iteration)
|
vdl_writer.add_scalar("Train-Loss-Rank-0", loss_val, iteration)
|
||||||
|
|
||||||
@dg.no_grad
|
@dg.no_grad
|
||||||
def valid_step(self, iteration):
|
def valid_step(self, iteration):
|
||||||
|
@ -175,7 +175,7 @@ class WaveFlow():
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
self.waveflow.eval()
|
self.waveflow.eval()
|
||||||
tb = self.tb_logger
|
vdl_writer = self.vdl_logger
|
||||||
|
|
||||||
total_loss = []
|
total_loss = []
|
||||||
sample_audios = []
|
sample_audios = []
|
||||||
|
@ -188,10 +188,12 @@ class WaveFlow():
|
||||||
|
|
||||||
# Visualize latent z and scale log_s.
|
# Visualize latent z and scale log_s.
|
||||||
if self.rank == 0 and i == 0:
|
if self.rank == 0 and i == 0:
|
||||||
tb.add_histogram("Valid-Latent_z", valid_z.numpy(), iteration)
|
vdl_writer.add_histogram("Valid-Latent_z", valid_z.numpy(),
|
||||||
|
iteration)
|
||||||
for j, valid_log_s in enumerate(valid_log_s_list):
|
for j, valid_log_s in enumerate(valid_log_s_list):
|
||||||
hist_name = "Valid-{}th-Flow-Log_s".format(j)
|
hist_name = "Valid-{}th-Flow-Log_s".format(j)
|
||||||
tb.add_histogram(hist_name, valid_log_s.numpy(), iteration)
|
vdl_writer.add_histogram(hist_name, valid_log_s.numpy(),
|
||||||
|
iteration)
|
||||||
|
|
||||||
valid_loss = self.criterion(valid_outputs)
|
valid_loss = self.criterion(valid_outputs)
|
||||||
total_loss.append(float(valid_loss.numpy()))
|
total_loss.append(float(valid_loss.numpy()))
|
||||||
|
@ -202,7 +204,7 @@ class WaveFlow():
|
||||||
log = "Test | Rank: {} AvgLoss: {:<8.3f} Time {:<8.3f}".format(
|
log = "Test | Rank: {} AvgLoss: {:<8.3f} Time {:<8.3f}".format(
|
||||||
self.rank, loss_val, total_time)
|
self.rank, loss_val, total_time)
|
||||||
print(log)
|
print(log)
|
||||||
tb.add_scalar("Valid-Avg-Loss", loss_val, iteration)
|
vdl_writer.add_scalar("Valid-Avg-Loss", loss_val, iteration)
|
||||||
|
|
||||||
@dg.no_grad
|
@dg.no_grad
|
||||||
def infer(self, iteration):
|
def infer(self, iteration):
|
||||||
|
|
|
@ -17,7 +17,6 @@ import os
|
||||||
import ruamel.yaml
|
import ruamel.yaml
|
||||||
import argparse
|
import argparse
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from tensorboardX import SummaryWriter
|
|
||||||
from paddle import fluid
|
from paddle import fluid
|
||||||
fluid.require_version('1.8.0')
|
fluid.require_version('1.8.0')
|
||||||
import paddle.fluid.dygraph as dg
|
import paddle.fluid.dygraph as dg
|
||||||
|
|
|
@ -17,7 +17,7 @@ import os
|
||||||
import ruamel.yaml
|
import ruamel.yaml
|
||||||
import argparse
|
import argparse
|
||||||
import tqdm
|
import tqdm
|
||||||
from tensorboardX import SummaryWriter
|
from visualdl import LogWriter
|
||||||
from paddle import fluid
|
from paddle import fluid
|
||||||
fluid.require_version('1.8.0')
|
fluid.require_version('1.8.0')
|
||||||
import paddle.fluid.dygraph as dg
|
import paddle.fluid.dygraph as dg
|
||||||
|
@ -154,7 +154,7 @@ if __name__ == "__main__":
|
||||||
eval_interval = train_config["eval_interval"]
|
eval_interval = train_config["eval_interval"]
|
||||||
checkpoint_dir = os.path.join(args.output, "checkpoints")
|
checkpoint_dir = os.path.join(args.output, "checkpoints")
|
||||||
log_dir = os.path.join(args.output, "log")
|
log_dir = os.path.join(args.output, "log")
|
||||||
writer = SummaryWriter(log_dir)
|
writer = LogWriter(log_dir)
|
||||||
|
|
||||||
# load parameters and optimizer, and update iterations done so far
|
# load parameters and optimizer, and update iterations done so far
|
||||||
if args.checkpoint is not None:
|
if args.checkpoint is not None:
|
||||||
|
|
Loading…
Reference in New Issue