Merge branch 'master' of upstream
This commit is contained in:
commit
cd59c98637
|
@ -0,0 +1,103 @@
|
|||
# Clarinet
|
||||
|
||||
Paddle implementation of clarinet in dynamic graph, a convolutional network based vocoder. The implementation is based on the paper [ClariNet: Parallel Wave Generation in End-to-End Text-to-Speech](arxiv.org/abs/1807.07281).
|
||||
|
||||
|
||||
## Dataset
|
||||
|
||||
We experiment with the LJSpeech dataset. Download and unzip [LJSpeech](https://keithito.com/LJ-Speech-Dataset/).
|
||||
|
||||
```bash
|
||||
wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
|
||||
tar xjvf LJSpeech-1.1.tar.bz2
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
|
||||
```text
|
||||
├── data.py data_processing
|
||||
├── configs/ (example) configuration file
|
||||
├── synthesis.py script to synthesize waveform from mel_spectrogram
|
||||
├── train.py script to train a model
|
||||
└── utils.py utility functions
|
||||
```
|
||||
|
||||
## Train
|
||||
|
||||
Train the model using train.py, follow the usage displayed by `python train.py --help`.
|
||||
|
||||
```text
|
||||
usage: train.py [-h] [--config CONFIG] [--device DEVICE] [--output OUTPUT]
|
||||
[--data DATA] [--resume RESUME] [--wavenet WAVENET]
|
||||
|
||||
train a clarinet model with LJspeech and a trained wavenet model.
|
||||
|
||||
optional arguments:
|
||||
-h, --help show this help message and exit
|
||||
--config CONFIG path of the config file.
|
||||
--device DEVICE device to use.
|
||||
--output OUTPUT path to save student.
|
||||
--data DATA path of LJspeech dataset.
|
||||
--resume RESUME checkpoint to load from.
|
||||
--wavenet WAVENET wavenet checkpoint to use.
|
||||
```
|
||||
|
||||
1. `--config` is the configuration file to use. The provided configurations can be used directly. And you can change some values in the configuration file and train the model with a different config.
|
||||
2. `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt).
|
||||
3. `--resume` is the path of the checkpoint. If it is provided, the model would load the checkpoint before trainig.
|
||||
4. `--output` is the directory to save results, all result are saved in this directory. The structure of the output directory is shown below.
|
||||
|
||||
```text
|
||||
├── checkpoints # checkpoint
|
||||
├── states # audio files generated at validation
|
||||
└── log # tensorboard log
|
||||
```
|
||||
|
||||
5. `--device` is the device (gpu id) to use for training. `-1` means CPU.
|
||||
6. `--wavenet` is the path of the wavenet checkpoint to load. if you do not specify `--resume`, then this must be provided.
|
||||
|
||||
|
||||
Before you start training a clarinet model, you should have trained a wavenet model with single gaussian as output distribution. Make sure the config for teacher matches that for the trained model.
|
||||
|
||||
example script:
|
||||
|
||||
```bash
|
||||
python train.py --config=./configs/clarinet_ljspeech.yaml --data=./LJSpeech-1.1/ --output=experiment --device=0 --conditioner=wavenet_checkpoint/conditioner --conditioner=wavenet_checkpoint/teacher
|
||||
```
|
||||
|
||||
You can monitor training log via tensorboard, using the script below.
|
||||
|
||||
```bash
|
||||
cd experiment/log
|
||||
tensorboard --logdir=.
|
||||
```
|
||||
|
||||
## Synthesis
|
||||
```text
|
||||
usage: synthesis.py [-h] [--config CONFIG] [--device DEVICE] [--data DATA]
|
||||
checkpoint output
|
||||
|
||||
train a clarinet model with LJspeech and a trained wavenet model.
|
||||
|
||||
positional arguments:
|
||||
checkpoint checkpoint to load from.
|
||||
output path to save student.
|
||||
|
||||
optional arguments:
|
||||
-h, --help show this help message and exit
|
||||
--config CONFIG path of the config file.
|
||||
--device DEVICE device to use.
|
||||
--data DATA path of LJspeech dataset.
|
||||
```
|
||||
|
||||
1. `--config` is the configuration file to use. You should use the same configuration with which you train you model.
|
||||
2. `--data` is the path of the LJspeech dataset. A dataset is not needed for synthesis, but since the input is mel spectrogram, we need to get mel spectrogram from audio files.
|
||||
3. `checkpoint` is the checkpoint to load.
|
||||
4. `output_path` is the directory to save results. The output path contains the generated audio files (`*.wav`).
|
||||
5. `--device` is the device (gpu id) to use for training. `-1` means CPU.
|
||||
|
||||
example script:
|
||||
|
||||
```bash
|
||||
python synthesis.py --config=./configs/wavenet_single_gaussian.yaml --data=./LJSpeech-1.1/ --device=0 experiment/checkpoints/step_500000 generated
|
||||
```
|
|
@ -0,0 +1,52 @@
|
|||
data:
|
||||
batch_size: 4
|
||||
train_clip_seconds: 0.5
|
||||
sample_rate: 22050
|
||||
hop_length: 256
|
||||
win_length: 1024
|
||||
n_fft: 2048
|
||||
|
||||
n_mels: 80
|
||||
valid_size: 16
|
||||
|
||||
|
||||
conditioner:
|
||||
upsampling_factors: [16, 16]
|
||||
|
||||
teacher:
|
||||
n_loop: 10
|
||||
n_layer: 3
|
||||
filter_size: 2
|
||||
residual_channels: 128
|
||||
loss_type: "mog"
|
||||
output_dim: 3
|
||||
log_scale_min: -9
|
||||
|
||||
student:
|
||||
n_loops: [10, 10, 10, 10, 10, 10]
|
||||
n_layers: [1, 1, 1, 1, 1, 1]
|
||||
filter_size: 3
|
||||
residual_channels: 64
|
||||
log_scale_min: -7
|
||||
|
||||
stft:
|
||||
n_fft: 2048
|
||||
win_length: 1024
|
||||
hop_length: 256
|
||||
|
||||
loss:
|
||||
lmd: 4
|
||||
|
||||
train:
|
||||
learning_rate: 0.0005
|
||||
anneal_rate: 0.5
|
||||
anneal_interval: 200000
|
||||
gradient_max_norm: 100.0
|
||||
|
||||
checkpoint_interval: 1000
|
||||
eval_interval: 1000
|
||||
|
||||
max_iterations: 2000000
|
||||
|
||||
|
||||
|
|
@ -1,5 +1,4 @@
|
|||
data:
|
||||
root: "/workspace/datasets/LJSpeech-1.1/"
|
||||
batch_size: 4
|
||||
train_clip_seconds: 0.5
|
||||
sample_rate: 22050
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
data:
|
||||
root: "/workspace/datasets/LJSpeech-1.1/"
|
||||
batch_size: 4
|
||||
train_clip_seconds: 0.5
|
||||
sample_rate: 22050
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
data:
|
||||
root: "/workspace/datasets/LJSpeech-1.1/"
|
||||
batch_size: 4
|
||||
train_clip_seconds: 0.5
|
||||
sample_rate: 22050
|
||||
|
|
|
@ -56,7 +56,7 @@ def eval_model(model, valid_loader, output_dir, sample_rate):
|
|||
audio_clips, mel_specs, audio_starts = batch
|
||||
wav_var = model.synthesis(mel_specs)
|
||||
wav_np = wav_var.numpy()[0]
|
||||
sf.write(wav_np, path, samplerate=sample_rate)
|
||||
sf.write(path, wav_np, samplerate=sample_rate)
|
||||
print("generated {}".format(path))
|
||||
|
||||
|
||||
|
|
|
@ -134,7 +134,7 @@ class SliceDataset(DatasetMixin):
|
|||
format(len(order), len(dataset)))
|
||||
self._order = order
|
||||
|
||||
def len(self):
|
||||
def __len__(self):
|
||||
return self._size
|
||||
|
||||
def get_example(self, i):
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .net import *
|
||||
from .parallel_wavenet import *
|
|
@ -0,0 +1,169 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import itertools
|
||||
import numpy as np
|
||||
from scipy import signal
|
||||
from tqdm import trange
|
||||
|
||||
import paddle.fluid.layers as F
|
||||
import paddle.fluid.dygraph as dg
|
||||
import paddle.fluid.initializer as I
|
||||
import paddle.fluid.layers.distributions as D
|
||||
|
||||
from parakeet.modules.weight_norm import Conv2DTranspose
|
||||
from parakeet.models.wavenet import crop, WaveNet, UpsampleNet
|
||||
from parakeet.models.clarinet.parallel_wavenet import ParallelWaveNet
|
||||
from parakeet.models.clarinet.utils import conv2d
|
||||
|
||||
|
||||
# Gaussian IAF model
|
||||
class Clarinet(dg.Layer):
|
||||
def __init__(self,
|
||||
encoder,
|
||||
teacher,
|
||||
student,
|
||||
stft,
|
||||
min_log_scale=-6.0,
|
||||
lmd=4.0):
|
||||
super(Clarinet, self).__init__()
|
||||
self.lmd = lmd
|
||||
self.encoder = encoder
|
||||
self.teacher = teacher
|
||||
self.student = student
|
||||
|
||||
self.min_log_scale = min_log_scale
|
||||
self.stft = stft
|
||||
|
||||
def forward(self, audio, mel, audio_start, clip_kl=True):
|
||||
"""Compute loss for a distill model
|
||||
|
||||
Arguments:
|
||||
audio {Variable} -- shape(batch_size, time_steps), target waveform.
|
||||
mel {Variable} -- shape(batch_size, condition_dim, time_steps // hop_length), original mel spectrogram, not upsampled yet.
|
||||
audio_starts {Variable} -- shape(batch_size, ), the index of the start sample.
|
||||
clip_kl (bool) -- whether to clip kl divergence if it is greater than 10.0.
|
||||
|
||||
Returns:
|
||||
Variable -- shape(1,), loss
|
||||
"""
|
||||
|
||||
batch_size, audio_length = audio.shape # audio clip's length
|
||||
|
||||
z = F.gaussian_random(audio.shape)
|
||||
condition = self.encoder(mel) # (B, C, T)
|
||||
condition_slice = crop(condition, audio_start, audio_length)
|
||||
|
||||
x, s_means, s_scales = self.student(z, condition_slice) # all [0: T]
|
||||
s_means = s_means[:, 1:] # (B, T-1), time steps [1: T]
|
||||
s_scales = s_scales[:, 1:] # (B, T-1), time steps [1: T]
|
||||
s_clipped_scales = F.clip(s_scales, self.min_log_scale, 100.)
|
||||
|
||||
# teacher outputs single gaussian
|
||||
y = self.teacher(x[:, :-1], condition_slice[:, :, 1:])
|
||||
_, t_means, t_scales = F.split(y, 3, -1) # time steps [1: T]
|
||||
t_means = F.squeeze(t_means, [-1]) # (B, T-1), time steps [1: T]
|
||||
t_scales = F.squeeze(t_scales, [-1]) # (B, T-1), time steps [1: T]
|
||||
t_clipped_scales = F.clip(t_scales, self.min_log_scale, 100.)
|
||||
|
||||
s_distribution = D.Normal(s_means, F.exp(s_clipped_scales))
|
||||
t_distribution = D.Normal(t_means, F.exp(t_clipped_scales))
|
||||
|
||||
# kl divergence loss, so we only need to sample once? no MC
|
||||
kl = s_distribution.kl_divergence(t_distribution)
|
||||
if clip_kl:
|
||||
kl = F.clip(kl, -100., 10.)
|
||||
# context size dropped
|
||||
kl = F.reduce_mean(kl[:, self.teacher.context_size:])
|
||||
# major diff here
|
||||
regularization = F.mse_loss(t_scales[:, self.teacher.context_size:],
|
||||
s_scales[:, self.teacher.context_size:])
|
||||
|
||||
# introduce information from real target
|
||||
spectrogram_frame_loss = F.mse_loss(
|
||||
self.stft.magnitude(audio), self.stft.magnitude(x))
|
||||
loss = kl + self.lmd * regularization + spectrogram_frame_loss
|
||||
loss_dict = {
|
||||
"loss": loss,
|
||||
"kl_divergence": kl,
|
||||
"regularization": regularization,
|
||||
"stft_loss": spectrogram_frame_loss
|
||||
}
|
||||
return loss_dict
|
||||
|
||||
@dg.no_grad
|
||||
def synthesis(self, mel):
|
||||
"""Synthesize waveform conditioned on the mel spectrogram.
|
||||
|
||||
Arguments:
|
||||
mel {Variable} -- shape(batch_size, frequqncy_bands, frames)
|
||||
|
||||
Returns:
|
||||
Variable -- shape(batch_size, frames * upsample_factor)
|
||||
"""
|
||||
condition = self.encoder(mel)
|
||||
samples_shape = (condition.shape[0], condition.shape[-1])
|
||||
z = F.gaussian_random(samples_shape)
|
||||
x, s_means, s_scales = self.student(z, condition)
|
||||
return x
|
||||
|
||||
|
||||
class STFT(dg.Layer):
|
||||
def __init__(self, n_fft, hop_length, win_length, window="hanning"):
|
||||
super(STFT, self).__init__()
|
||||
self.hop_length = hop_length
|
||||
self.n_bin = 1 + n_fft // 2
|
||||
self.n_fft = n_fft
|
||||
|
||||
# calculate window
|
||||
window = signal.get_window(window, win_length)
|
||||
if n_fft != win_length:
|
||||
pad = (n_fft - win_length) // 2
|
||||
window = np.pad(window, ((pad, pad), ), 'constant')
|
||||
|
||||
# calculate weights
|
||||
r = np.arange(0, n_fft)
|
||||
M = np.expand_dims(r, -1) * np.expand_dims(r, 0)
|
||||
w_real = np.reshape(window *
|
||||
np.cos(2 * np.pi * M / n_fft)[:self.n_bin],
|
||||
(self.n_bin, 1, 1, self.n_fft)).astype("float32")
|
||||
w_imag = np.reshape(window *
|
||||
np.sin(-2 * np.pi * M / n_fft)[:self.n_bin],
|
||||
(self.n_bin, 1, 1, self.n_fft)).astype("float32")
|
||||
|
||||
w = np.concatenate([w_real, w_imag], axis=0)
|
||||
self.weight = dg.to_variable(w)
|
||||
|
||||
def forward(self, x):
|
||||
# x(batch_size, time_steps)
|
||||
# pad it first with reflect mode
|
||||
pad_start = F.reverse(x[:, 1:1 + self.n_fft // 2], axis=1)
|
||||
pad_stop = F.reverse(x[:, -(1 + self.n_fft // 2):-1], axis=1)
|
||||
x = F.concat([pad_start, x, pad_stop], axis=-1)
|
||||
|
||||
# to BC1T, C=1
|
||||
x = F.unsqueeze(x, axes=[1, 2])
|
||||
out = conv2d(x, self.weight, stride=(1, self.hop_length))
|
||||
real, imag = F.split(out, 2, dim=1) # BC1T
|
||||
return real, imag
|
||||
|
||||
def power(self, x):
|
||||
real, imag = self(x)
|
||||
power = real**2 + imag**2
|
||||
return power
|
||||
|
||||
def magnitude(self, x):
|
||||
power = self.power(x)
|
||||
magnitude = F.sqrt(power)
|
||||
return magnitude
|
|
@ -0,0 +1,69 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import time
|
||||
import itertools
|
||||
import numpy as np
|
||||
|
||||
import paddle.fluid.layers as F
|
||||
import paddle.fluid.dygraph as dg
|
||||
import paddle.fluid.initializer as I
|
||||
import paddle.fluid.layers.distributions as D
|
||||
|
||||
from parakeet.modules.weight_norm import Linear, Conv1D, Conv1DCell, Conv2DTranspose
|
||||
from parakeet.models.wavenet import WaveNet
|
||||
|
||||
|
||||
class ParallelWaveNet(dg.Layer):
|
||||
def __init__(self, n_loops, n_layers, residual_channels, condition_dim,
|
||||
filter_size):
|
||||
super(ParallelWaveNet, self).__init__()
|
||||
self.flows = dg.LayerList()
|
||||
for n_loop, n_layer in zip(n_loops, n_layers):
|
||||
# teacher's log_scale_min does not matter herem, -100 is a dummy value
|
||||
self.flows.append(
|
||||
WaveNet(n_loop, n_layer, residual_channels, 3, condition_dim,
|
||||
filter_size, "mog", -100.0))
|
||||
|
||||
def forward(self, z, condition=None):
|
||||
"""Inverse Autoregressive Flow. Several wavenets.
|
||||
|
||||
Arguments:
|
||||
z {Variable} -- shape(batch_size, time_steps), hidden variable, sampled from a standard normal distribution.
|
||||
|
||||
Keyword Arguments:
|
||||
condition {Variable} -- shape(batch_size, condition_dim, time_steps), condition, basically upsampled mel spectrogram. (default: {None})
|
||||
|
||||
Returns:
|
||||
Variable -- shape(batch_size, time_steps), transformed z.
|
||||
Variable -- shape(batch_size, time_steps), output distribution's mu.
|
||||
Variable -- shape(batch_size, time_steps), output distribution's log_std.
|
||||
"""
|
||||
|
||||
for i, flow in enumerate(self.flows):
|
||||
theta = flow(z, condition) # w, mu, log_std [0: T]
|
||||
w, mu, log_std = F.split(theta, 3, dim=-1) # (B, T, 1) for each
|
||||
mu = F.squeeze(mu, [-1]) #[0: T]
|
||||
log_std = F.squeeze(log_std, [-1]) #[0: T]
|
||||
z = z * F.exp(log_std) + mu #[0: T]
|
||||
|
||||
if i == 0:
|
||||
out_mu = mu
|
||||
out_log_std = log_std
|
||||
else:
|
||||
out_mu = out_mu * F.exp(log_std) + mu
|
||||
out_log_std += log_std
|
||||
|
||||
return z, out_mu, out_log_std
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from paddle import fluid
|
||||
from paddle.fluid.core import ops
|
||||
|
||||
|
||||
@fluid.framework.dygraph_only
|
||||
def conv2d(input,
|
||||
weight,
|
||||
stride=(1, 1),
|
||||
padding=((0, 0), (0, 0)),
|
||||
dilation=(1, 1),
|
||||
groups=1,
|
||||
use_cudnn=True,
|
||||
data_format="NCHW"):
|
||||
padding = tuple(pad for pad_dim in padding for pad in pad_dim)
|
||||
|
||||
inputs = {
|
||||
'Input': [input],
|
||||
'Filter': [weight],
|
||||
}
|
||||
attrs = {
|
||||
'strides': stride,
|
||||
'paddings': padding,
|
||||
'dilations': dilation,
|
||||
'groups': groups,
|
||||
'use_cudnn': use_cudnn,
|
||||
'use_mkldnn': False,
|
||||
'fuse_relu_before_depthwise_conv': False,
|
||||
"padding_algorithm": "EXPLICIT",
|
||||
"data_format": data_format,
|
||||
}
|
||||
|
||||
outputs = ops.conv2d(inputs, attrs)
|
||||
out = outputs["Output"][0]
|
||||
return out
|
|
@ -57,7 +57,7 @@ class UpsampleNet(dg.Layer):
|
|||
"""
|
||||
|
||||
def __init__(self, upscale_factors=[16, 16]):
|
||||
super().__init__()
|
||||
super(UpsampleNet, self).__init__()
|
||||
self.upscale_factors = list(upscale_factors)
|
||||
self.upsample_convs = dg.LayerList()
|
||||
for i, factor in enumerate(upscale_factors):
|
||||
|
@ -92,7 +92,7 @@ class UpsampleNet(dg.Layer):
|
|||
# AutoRegressive Model
|
||||
class ConditionalWavenet(dg.Layer):
|
||||
def __init__(self, encoder: UpsampleNet, decoder: WaveNet):
|
||||
super().__init__()
|
||||
super(ConditionalWavenet, self).__init__()
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@ def dequantize(quantized, n_bands):
|
|||
class ResidualBlock(dg.Layer):
|
||||
def __init__(self, residual_channels, condition_dim, filter_size,
|
||||
dilation):
|
||||
super().__init__()
|
||||
super(ResidualBlock, self).__init__()
|
||||
dilated_channels = 2 * residual_channels
|
||||
# following clarinet's implementation, we do not have parametric residual
|
||||
# & skip connection.
|
||||
|
@ -135,7 +135,7 @@ class ResidualBlock(dg.Layer):
|
|||
class ResidualNet(dg.Layer):
|
||||
def __init__(self, n_loop, n_layer, residual_channels, condition_dim,
|
||||
filter_size):
|
||||
super().__init__()
|
||||
super(ResidualNet, self).__init__()
|
||||
# double the dilation at each layer in a loop(n_loop layers)
|
||||
dilations = [2**i for i in range(n_loop)] * n_layer
|
||||
self.context_size = 1 + sum(dilations)
|
||||
|
@ -198,7 +198,7 @@ class ResidualNet(dg.Layer):
|
|||
class WaveNet(dg.Layer):
|
||||
def __init__(self, n_loop, n_layer, residual_channels, output_dim,
|
||||
condition_dim, filter_size, loss_type, log_scale_min):
|
||||
super().__init__()
|
||||
super(WaveNet, self).__init__()
|
||||
if loss_type not in ["softmax", "mog"]:
|
||||
raise ValueError("loss_type {} is not supported".format(loss_type))
|
||||
if loss_type == "softmax":
|
||||
|
|
Loading…
Reference in New Issue