add support for channel-last layout in batch_spec and Conv1dBatchNorm

chenfeiyu 2020-10-30 15:13:57 +08:00
parent 36cc543348
commit 57d820f055
5 changed files with 97 additions and 57 deletions

View File

@@ -75,19 +75,16 @@ def batch_wav(minibatch, pad_value=0., dtype=np.float32):
"""pad audios to the largest length and batch them.
Args:
minibatch (List[np.ndarray]): list of rank-1 float arrays(mono-channel audio, shape(T,)) or list of rank-2 float arrays(multi-channel audio, shape(C, T), C stands for number of channels, T stands for length), dtype float.
minibatch (List[np.ndarray]): list of rank-1 float arrays(mono-channel audio, shape(T,)), dtype float.
pad_value (float, optional): the pad value. Defaults to 0..
dtype (np.dtype, optional): the data type of the output. Defaults to np.float32.
Returns:
np.ndarray: the output batch. It is a rank-2 float array of shape(B, T) if the minibatch is a list of mono-channel audios, or a rank-3 float array of shape(B, C, T) if the minibatch is a list of multi-channel audios.
np.ndarray: shape(B, T), the output batch.
"""
peek_example = minibatch[0]
if len(peek_example.shape) == 1:
mono_channel = True
elif len(peek_example.shape) == 2:
mono_channel = False
assert len(peek_example.shape) == 1, "we only handle mono-channel wav"
# assume (channel, n_samples) or (n_samples, )
lengths = [example.shape[-1] for example in minibatch]
@@ -96,33 +93,27 @@ def batch_wav(minibatch, pad_value=0., dtype=np.float32):
batch = []
for example in minibatch:
pad_len = max_len - example.shape[-1]
if mono_channel:
batch.append(
np.pad(example, [(0, pad_len)],
mode='constant',
constant_values=pad_value))
else:
batch.append(
np.pad(example, [(0, 0), (0, pad_len)],
mode='constant',
constant_values=pad_value))
batch.append(
np.pad(example, [(0, pad_len)],
mode='constant',
constant_values=pad_value))
return np.array(batch, dtype=dtype)
class SpecBatcher(object):
"""A wrapper class for `batch_spec`"""
def __init__(self, pad_value=0., dtype=np.float32):
def __init__(self, pad_value=0., time_major=False, dtype=np.float32):
self.pad_value = pad_value
self.dtype = dtype
self.time_major = time_major
def __call__(self, minibatch):
out = batch_spec(minibatch, pad_value=self.pad_value, dtype=self.dtype)
out = batch_spec(minibatch, pad_value=self.pad_value, time_major=self.time_major, dtype=self.dtype)
return out
def batch_spec(minibatch, pad_value=0., dtype=np.float32):
def batch_spec(minibatch, pad_value=0., time_major=False, dtype=np.float32):
"""Pad spectra to the largest length and batch them.
Args:
@@ -131,31 +122,28 @@ def batch_spec(minibatch, pad_value=0., dtype=np.float32):
dtype (np.dtype, optional): data type of the output. Defaults to np.float32.
Returns:
np.ndarray: a rank-3 array of shape(B, F, T) when the minibatch is a list of mono-channel spectrograms, or a rank-4 array of shape(B, C, F, T) when the minibatch is a list of multi-channel spectrograms.
np.ndarray: a rank-3 array of shape(B, F, T), or (B, T, F) when time_major is True.
"""
# assume (F, T) or (C, F, T)
# assume (F, T) or (T, F)
peek_example = minibatch[0]
if len(peek_example.shape) == 2:
mono_channel = True
elif len(peek_example.shape) == 3:
mono_channel = False
assert len(peek_example.shape) == 2, "we only handle mono-channel spectrograms"
# assume (channel, F, n_frame) or (F, n_frame)
lengths = [example.shape[-1] for example in minibatch]
# assume (F, n_frame) or (n_frame, F)
time_idx = 0 if time_major else -1
lengths = [example.shape[time_idx] for example in minibatch]
max_len = np.max(lengths)
batch = []
for example in minibatch:
pad_len = max_len - example.shape[-1]
if mono_channel:
pad_len = max_len - example.shape[time_idx]
if time_major:
batch.append(
np.pad(example, [(0, 0), (0, pad_len)],
mode='constant',
constant_values=pad_value))
np.pad(example, [(0, pad_len), (0, 0)],
mode='constant',
constant_values=pad_value))
else:
batch.append(
np.pad(example, [(0, 0), (0, 0), (0, pad_len)],
mode='constant',
constant_values=pad_value))
np.pad(example, [(0, 0), (0, pad_len)],
mode='constant',
constant_values=pad_value))
return np.array(batch, dtype=dtype)
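For reference, a minimal sketch (not part of the diff above) of how the new time_major flag changes which axis gets padded; the toy shapes and the import path are assumptions:

import numpy as np
from parakeet.data.batch import batch_spec  # import path is an assumption

# two toy "spectrograms" of different lengths, stored time-major as (T, F)
a = np.random.randn(10, 80).astype(np.float32)
b = np.random.randn(7, 80).astype(np.float32)

batch = batch_spec([a, b], time_major=True)
print(batch.shape)      # (2, 10, 80): padded along the time axis

# the default (time_major=False) still expects (F, T) inputs and pads the last axis
batch = batch_spec([a.T, b.T])
print(batch.shape)      # (2, 80, 10)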

View File

@@ -288,20 +288,47 @@ class TransformerDecoder(nn.LayerList):
class MLPPreNet(nn.Layer):
def __init__(self, d_input, d_hidden, d_output):
# (lin + relu + dropout) * n + last projection
super(MLPPreNet, self).__init__()
self.lin1 = nn.Linear(d_input, d_hidden)
self.lin2 = nn.Linear(d_hidden, d_output)
self.lin2 = nn.Linear(d_hidden, d_hidden)
self.lin3 = nn.Linear(d_output, d_output)
def forward(self, x, dropout):
# the original implementation also applies dropout at inference time
l1 = F.dropout(F.relu(self.lin1(x)), dropout, training=self.training)
l2 = F.dropout(F.relu(self.lin2(l1)), dropout, training=self.training)
return l2
l3 = self.lin3(l2)
return l3
class CNNPreNet(nn.Layer):
def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers,
dropout=0.):
# (conv + bn + relu + dropout) * n + last projection
super(CNNPreNet, self).__init__()
self.convs = nn.LayerList()
c_in = d_input
for _ in range(n_layers):
self.convs.append(
Conv1dBatchNorm(c_in, d_hidden, kernel_size,
weight_attr=I.XavierUniform(),
padding="same", data_format="NLC"))
c_in = d_hidden
self.affine_out = nn.Linear(d_hidden, d_output)
self.dropout = dropout
def forward(self, x):
for layer in self.convs:
x = F.dropout(F.relu(layer(x)), self.dropout, training=self.training)
x = self.affine_out(x)
return x
class CNNPostNet(nn.Layer):
def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers):
super(CNNPostNet, self).__init__()
self.first_norm = nn.BatchNorm1D(d_output)
self.convs = nn.LayerList()
kernel_size = kernel_size if isinstance(kernel_size, (tuple, list)) else (kernel_size, )
padding = (kernel_size[0] - 1, 0)
@@ -309,14 +336,23 @@ class CNNPostNet(nn.Layer):
c_in = d_input if i == 0 else d_hidden
c_out = d_output if i == n_layers - 1 else d_hidden
self.convs.append(
Conv1dBatchNorm(c_in, c_out, kernel_size, padding=padding))
self.last_norm = nn.BatchNorm1D(d_output)
Conv1dBatchNorm(c_in, c_out, kernel_size,
weight_attr=I.XavierUniform(),
padding=padding))
# if a layer ends with a normalization layer whose target output is not
# zero-centered, it may take a long time to train that normalization
# layer's scale and bias
# NOTE: the convolution here could also be non-causal
def forward(self, x):
# why not use pre norms
x_in = x
for layer in self.convs:
x = paddle.tanh(layer(x))
x = self.last_norm(x + x_in)
x = self.first_norm(x)
for i, layer in enumerate(self.convs):
x = layer(x)
if i != (len(self.convs) - 1):
x = F.tanh(x)
x = x_in + x
return x
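To make the reordered post-net forward easier to follow, here is a minimal, self-contained sketch of the same data flow with plain Paddle layers; the channel count, kernel size, and depth below are illustrative, not the model's configuration:

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

convs = nn.LayerList([nn.Conv1D(80, 80, 5, padding=2) for _ in range(3)])
first_norm = nn.BatchNorm1D(80)

def postnet_forward(x):              # x: (B, C, T)
    x_in = x                         # keep the input for the residual connection
    x = first_norm(x)                # normalize first, as in the new forward()
    for i, layer in enumerate(convs):
        x = layer(x)
        if i != len(convs) - 1:      # tanh on every conv except the last one
            x = F.tanh(x)
    return x_in + x                  # add the residual at the very end

y = postnet_forward(paddle.randn([2, 80, 100]))
print(y.shape)                       # [2, 80, 100]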
@@ -326,7 +362,8 @@ class TransformerTTS(nn.Layer):
postnet_kernel_size, max_reduction_factor, dropout):
super(TransformerTTS, self).__init__()
# encoder
self.encoder_prenet = nn.Embedding(vocab_size, d_encoder, padding_idx)
self.embedding = nn.Embedding(vocab_size, d_encoder, padding_idx)
self.encoder_prenet = CNNPreNet(d_encoder, d_encoder, d_encoder, 5, 3, dropout)
self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder) # it may be extended later
self.encoder_pe_scalar = self.create_parameter([1], attr=I.Constant(1.))
self.encoder = TransformerEncoder(d_encoder, n_heads, d_ffn, encoder_layers, dropout)
@@ -366,7 +403,7 @@ class TransformerTTS(nn.Layer):
def encode(self, text):
T_enc = text.shape[-1]
embed = self.encoder_prenet(text)
embed = self.encoder_prenet(self.embedding(text))
if embed.shape[1] > self.encoder_pe.shape[0]:
new_T = max(embed.shape[1], self.encoder_pe.shape[0] * 2)
self.encoder_pe = pe.positional_encoding(0, new_T, self.d_encoder)
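Not from the commit, but a toy stand-in for the refactored encoder front end (embedding lookup followed by a channel-last conv prenet), using a bare Conv1D where the model uses Conv1dBatchNorm; the vocabulary size, feature size, and kernel size are made up:

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

vocab_size, d_encoder = 10, 32
embedding = nn.Embedding(vocab_size, d_encoder)
conv = nn.Conv1D(d_encoder, d_encoder, 5, padding="same", data_format="NLC")

text = paddle.randint(0, vocab_size, [2, 7])   # (B, T) token ids
embed = embedding(text)                        # (B, T, d_encoder), already channel-last
out = F.relu(conv(embed))                      # the NLC conv keeps (B, T, d_encoder)
print(out.shape)                               # [2, 7, 32]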
@@ -466,7 +503,7 @@ class TransformerTTSLoss(nn.Layer):
mask2 = mask + last_position.scale(self.stop_loss_scale - 1).astype(mask.dtype)
stop_loss = L.masked_softmax_with_cross_entropy(stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
loss = mel_loss1 + mel_loss2 + stop_loss
details = dict(
mel_loss1=mel_loss1, # output mel loss
mel_loss2=mel_loss2, # intermediate mel loss

View File

@@ -83,15 +83,16 @@ class Conv1dCell(nn.Conv1D):
class Conv1dBatchNorm(nn.Layer):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
weight_attr=None, bias_attr=None):
weight_attr=None, bias_attr=None, data_format="NCL"):
super(Conv1dBatchNorm, self).__init__()
# TODO(chenfeiyu): carefully initialize Conv1d's weight
self.conv = nn.Conv1D(in_channels, out_channels, kernel_size, stride,
padding=padding,
weight_attr=weight_attr,
bias_attr=bias_attr)
bias_attr=bias_attr,
data_format=data_format)
# TODO: channel last, but BatchNorm1d does not support channel last layout
self.bn = nn.BatchNorm1D(out_channels)
self.bn = nn.BatchNorm1D(out_channels, data_format=data_format)
def forward(self, x):
return self.bn(self.conv(x))
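A short usage sketch for the new channel-last option (not part of the diff; the import path is an assumption):

import paddle
from parakeet.modules.conv import Conv1dBatchNorm  # import path is an assumption

# channel-last input: (batch, time, channels)
x = paddle.randn([4, 16, 4])
layer = Conv1dBatchNorm(4, 6, 5, padding="same", data_format="NLC")
print(layer(x).shape)   # [4, 16, 6]: time stays on axis 1, channels stay last

# the default data_format="NCL" keeps the old (batch, channels, time) behaviour
x = paddle.randn([4, 4, 16])
layer = Conv1dBatchNorm(4, 6, 5, padding="same")
print(layer(x).shape)   # [4, 6, 16]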

View File

@@ -26,6 +26,14 @@ def summary(layer: nn.Layer):
print("layer has {} parameters, {} elements.".format(num_params,
num_elements))
def gradien_norm(layer: nn.Layer):
grad_norm_dict = {}
for name, param in layer.state_dict().items():
if param.trainable:
grad = param.gradient()
grad_norm_dict[name] = np.linalg.norm(grad) / grad.size
return grad_norm_dict
def freeze(layer: nn.Layer):
for param in layer.parameters():
param.trainable = False
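A hypothetical training-step snippet showing where the new gradien_norm helper fits, with a dummy model and loss (the import path is an assumption):

import paddle
import paddle.nn as nn
from parakeet.utils.layer_tools import gradien_norm  # import path is an assumption

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
optimizer = paddle.optimizer.Adam(parameters=model.parameters())

x = paddle.randn([4, 8])
loss = paddle.mean((model(x) - paddle.zeros([4, 1])) ** 2)
loss.backward()

print(gradien_norm(model))  # per-parameter gradient norms, e.g. for logging or a visualizer
optimizer.step()
optimizer.clear_grad()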

View File

@@ -32,20 +32,25 @@ class TestConv1dCell(unittest.TestCase):
class TestConv1dBatchNorm(unittest.TestCase):
def __init__(self, methodName="runTest", causal=False):
def __init__(self, methodName="runTest", causal=False, channel_last=False):
super(TestConv1dBatchNorm, self).__init__(methodName)
self.causal = causal
self.channel_last = channel_last
def setUp(self):
k = 5
paddding = (k - 1, 0) if self.causal else ((k-1) // 2, k //2)
self.net = conv.Conv1dBatchNorm(4, 6, (k,), 1, padding=paddding)
self.net = conv.Conv1dBatchNorm(4, 6, (k,), 1, padding=paddding,
data_format="NLC" if self.channel_last else "NCL")
def test_input_output(self):
x = paddle.randn([4, 4, 16])
x = paddle.randn([4, 16, 4]) if self.channel_last else paddle.randn([4, 4, 16])
out = self.net(x)
out_np = out.numpy()
self.assertTupleEqual(out_np.shape, (4, 6, 16))
if self.channel_last:
self.assertTupleEqual(out_np.shape, (4, 16, 6))
else:
self.assertTupleEqual(out_np.shape, (4, 6, 16))
def runTest(self):
self.test_input_output()
@@ -53,9 +58,10 @@ class TestConv1dBatchNorm(unittest.TestCase):
def load_tests(loader, standard_tests, pattern):
suite = unittest.TestSuite()
suite.addTest(TestConv1dBatchNorm("runTest", True))
suite.addTest(TestConv1dBatchNorm("runTest", False))
suite.addTest(TestConv1dBatchNorm("runTest", True, True))
suite.addTest(TestConv1dBatchNorm("runTest", False, False))
suite.addTest(TestConv1dBatchNorm("runTest", True, False))
suite.addTest(TestConv1dBatchNorm("runTest", False, True))
suite.addTest(TestConv1dCell("test_equality"))
return suite