diff --git a/parakeet/data/batch.py b/parakeet/data/batch.py
index 355e570..a5be9f7 100644
--- a/parakeet/data/batch.py
+++ b/parakeet/data/batch.py
@@ -75,19 +75,16 @@ def batch_wav(minibatch, pad_value=0., dtype=np.float32):
     """pad audios to the largest length and batch them.
 
     Args:
-        minibatch (List[np.ndarray]): list of rank-1 float arrays(mono-channel audio, shape(T,)) or list of rank-2 float arrays(multi-channel audio, shape(C, T), C stands for numer of channels, T stands for length), dtype float.
+        minibatch (List[np.ndarray]): list of rank-1 float arrays(mono-channel audio, shape(T,)), dtype float.
         pad_value (float, optional): the pad value. Defaults to 0..
         dtype (np.dtype, optional): the data type of the output. Defaults to np.float32.
 
     Returns:
-        np.ndarray: the output batch. It is a rank-2 float array of shape(B, T) if the minibatch is a list of mono-channel audios, or a rank-3 float array of shape(B, C, T) if the minibatch is a list of multi-channel audios.
+        np.ndarray: shape(B, T), the output batch.
     """
     peek_example = minibatch[0]
-    if len(peek_example.shape) == 1:
-        mono_channel = True
-    elif len(peek_example.shape) == 2:
-        mono_channel = False
+    assert len(peek_example.shape) == 1, "we only handle mono-channel wav"
 
     # assume (channel, n_samples) or (n_samples, )
     lengths = [example.shape[-1] for example in minibatch]
@@ -96,33 +93,27 @@ def batch_wav(minibatch, pad_value=0., dtype=np.float32):
     batch = []
     for example in minibatch:
         pad_len = max_len - example.shape[-1]
-        if mono_channel:
-            batch.append(
-                np.pad(example, [(0, pad_len)],
-                       mode='constant',
-                       constant_values=pad_value))
-        else:
-            batch.append(
-                np.pad(example, [(0, 0), (0, pad_len)],
-                       mode='constant',
-                       constant_values=pad_value))
-
+        batch.append(
+            np.pad(example, [(0, pad_len)],
+                   mode='constant',
+                   constant_values=pad_value))
     return np.array(batch, dtype=dtype)
 
 
 class SpecBatcher(object):
     """A wrapper class for `batch_spec`"""
 
-    def __init__(self, pad_value=0., dtype=np.float32):
+    def __init__(self, pad_value=0., time_major=False, dtype=np.float32):
         self.pad_value = pad_value
         self.dtype = dtype
+        self.time_major = time_major
 
     def __call__(self, minibatch):
-        out = batch_spec(minibatch, pad_value=self.pad_value, dtype=self.dtype)
+        out = batch_spec(minibatch, pad_value=self.pad_value, time_major=self.time_major, dtype=self.dtype)
         return out
 
 
-def batch_spec(minibatch, pad_value=0., dtype=np.float32):
+def batch_spec(minibatch, pad_value=0., time_major=False, dtype=np.float32):
     """Pad spectra to the largest length and batch them.
 
     Args:
@@ -131,31 +122,28 @@ def batch_spec(minibatch, pad_value=0., dtype=np.float32):
         dtype (np.dtype, optional): data type of the output. Defaults to np.float32.
 
     Returns:
-        np.ndarray: a rank-3 array of shape(B, F, T) when the minibatch is a list of mono-channel spectrograms, or a rank-4 array of shape(B, C, F, T) when the minibatch is a list of multi-channel spectorgrams.
+        np.ndarray: a rank-3 array of shape(B, F, T), or (B, T, F) if time_major is True.
     """
-    # assume (F, T) or (C, F, T)
+    # assume (F, T) or (T, F)
     peek_example = minibatch[0]
-    if len(peek_example.shape) == 2:
-        mono_channel = True
-    elif len(peek_example.shape) == 3:
-        mono_channel = False
+    assert len(peek_example.shape) == 2, "we only handle mono-channel spectrograms"
 
-    # assume (channel, F, n_frame) or (F, n_frame)
-    lengths = [example.shape[-1] for example in minibatch]
+    # assume (F, n_frame) or (n_frame, F)
+    time_idx = 0 if time_major else -1
+    lengths = [example.shape[time_idx] for example in minibatch]
     max_len = np.max(lengths)
 
     batch = []
    for example in minibatch:
-        pad_len = max_len - example.shape[-1]
-        if mono_channel:
+        pad_len = max_len - example.shape[time_idx]
+        if time_major:
             batch.append(
-                np.pad(example, [(0, 0), (0, pad_len)],
-                       mode='constant',
-                       constant_values=pad_value))
+                np.pad(example, [(0, pad_len), (0, 0)],
+                       mode='constant',
+                       constant_values=pad_value))
         else:
             batch.append(
-                np.pad(example, [(0, 0), (0, 0), (0, pad_len)],
-                       mode='constant',
-                       constant_values=pad_value))
-
+                np.pad(example, [(0, 0), (0, pad_len)],
+                       mode='constant',
+                       constant_values=pad_value))
     return np.array(batch, dtype=dtype)
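
Note: the padding behaviour behind the new time_major switch can be checked in isolation with plain numpy; the shapes below are made up for illustration, and SpecBatcher(time_major=True) wraps the same logic.

import numpy as np

# two hypothetical time-major spectrograms, shape (n_frame, F)
a = np.ones((3, 80))
b = np.ones((5, 80))

max_len = max(x.shape[0] for x in (a, b))
batch = np.stack([
    np.pad(x, [(0, max_len - x.shape[0]), (0, 0)],
           mode='constant', constant_values=0.)
    for x in (a, b)
])
print(batch.shape)  # (2, 5, 80), i.e. (B, T, F)
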
""" - # assume (F, T) or (C, F, T) + # assume (F, T) or (T, F) peek_example = minibatch[0] - if len(peek_example.shape) == 2: - mono_channel = True - elif len(peek_example.shape) == 3: - mono_channel = False + assert len(peek_example.shape) == 2, "we only handles mono channel spectrogram" - # assume (channel, F, n_frame) or (F, n_frame) - lengths = [example.shape[-1] for example in minibatch] + # assume (F, n_frame) or (n_frame, F) + time_idx = 0 if time_major else -1 + lengths = [example.shape[time_idx] for example in minibatch] max_len = np.max(lengths) batch = [] for example in minibatch: - pad_len = max_len - example.shape[-1] - if mono_channel: + pad_len = max_len - example.shape[time_idx] + if time_major: batch.append( - np.pad(example, [(0, 0), (0, pad_len)], - mode='constant', - constant_values=pad_value)) + np.pad(example, [(0, pad_len), (0, 0)], + mode='constant', + constant_values=pad_value)) else: batch.append( - np.pad(example, [(0, 0), (0, 0), (0, pad_len)], - mode='constant', - constant_values=pad_value)) - + np.pad(example, [(0, 0), (0, pad_len)], + mode='constant', + constant_values=pad_value)) return np.array(batch, dtype=dtype) diff --git a/parakeet/models/transformer_tts.py b/parakeet/models/transformer_tts.py index d1ddca6..c05e197 100644 --- a/parakeet/models/transformer_tts.py +++ b/parakeet/models/transformer_tts.py @@ -288,20 +288,47 @@ class TransformerDecoder(nn.LayerList): class MLPPreNet(nn.Layer): def __init__(self, d_input, d_hidden, d_output): + # (lin + relu + dropout) * n + last projection super(MLPPreNet, self).__init__() self.lin1 = nn.Linear(d_input, d_hidden) - self.lin2 = nn.Linear(d_hidden, d_output) + self.lin2 = nn.Linear(d_hidden, d_hidden) + self.lin3 = nn.Linear(d_output, d_output) def forward(self, x, dropout): # the original code said also use dropout in inference l1 = F.dropout(F.relu(self.lin1(x)), dropout, training=self.training) l2 = F.dropout(F.relu(self.lin2(l1)), dropout, training=self.training) - return l2 + l3 = self.lin3(l2) + return l3 + + +class CNNPreNet(nn.Layer): + def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers, + dropout=0.): + # (conv + bn + relu + dropout) * n + last projection + super(CNNPreNet, self).__init__() + self.convs = nn.LayerList() + c_in = d_input + for _ in range(n_layers): + self.convs.append( + Conv1dBatchNorm(c_in, d_hidden, kernel_size, + weight_attr=I.XavierUniform(), + padding="same", data_format="NLC")) + c_in = d_hidden + self.affine_out = nn.Linear(d_hidden, d_output) + self.dropout = dropout + + def forward(self, x): + for layer in self.convs: + x = F.dropout(F.relu(layer(x)), self.dropout, training=self.training) + x = self.affine_out(x) + return x class CNNPostNet(nn.Layer): def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers): super(CNNPostNet, self).__init__() + self.first_norm = nn.BatchNorm1D(d_output) self.convs = nn.LayerList() kernel_size = kernel_size if isinstance(kernel_size, (tuple, list)) else (kernel_size, ) padding = (kernel_size[0] - 1, 0) @@ -309,14 +336,23 @@ class CNNPostNet(nn.Layer): c_in = d_input if i == 0 else d_hidden c_out = d_output if i == n_layers - 1 else d_hidden self.convs.append( - Conv1dBatchNorm(c_in, c_out, kernel_size, padding=padding)) - self.last_norm = nn.BatchNorm1D(d_output) + Conv1dBatchNorm(c_in, c_out, kernel_size, + weight_attr=I.XavierUniform(), + padding=padding)) + # for a layer that ends with a normalization layer that is targeted to + # output a non zero-central output, it may take a long time to + # train the 
diff --git a/parakeet/modules/conv.py b/parakeet/modules/conv.py
index e50d95a..c8f854c 100644
--- a/parakeet/modules/conv.py
+++ b/parakeet/modules/conv.py
@@ -83,15 +83,16 @@ class Conv1dCell(nn.Conv1D):
 
 class Conv1dBatchNorm(nn.Layer):
     def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
-                 weight_attr=None, bias_attr=None):
+                 weight_attr=None, bias_attr=None, data_format="NCL"):
         super(Conv1dBatchNorm, self).__init__()
         # TODO(chenfeiyu): carefully initialize Conv1d's weight
         self.conv = nn.Conv1D(in_channels, out_channels, kernel_size, stride,
                               padding=padding,
                               weight_attr=weight_attr,
-                              bias_attr=bias_attr)
+                              bias_attr=bias_attr,
+                              data_format=data_format)
         # TODO: channel last, but BatchNorm1d does not support channel last layout
-        self.bn = nn.BatchNorm1D(out_channels)
+        self.bn = nn.BatchNorm1D(out_channels, data_format=data_format)
 
     def forward(self, x):
         return self.bn(self.conv(x))
diff --git a/parakeet/utils/layer_tools.py b/parakeet/utils/layer_tools.py
index 82ec20a..cefcfaa 100644
--- a/parakeet/utils/layer_tools.py
+++ b/parakeet/utils/layer_tools.py
@@ -26,6 +26,14 @@ def summary(layer: nn.Layer):
     print("layer has {} parameters, {} elements.".format(num_params, num_elements))
 
 
+def gradient_norm(layer: nn.Layer):
+    grad_norm_dict = {}
+    for name, param in layer.state_dict().items():
+        if param.trainable:
+            grad = param.gradient()
+            grad_norm_dict[name] = np.linalg.norm(grad) / grad.size
+    return grad_norm_dict
+
+
 def freeze(layer: nn.Layer):
     for param in layer.parameters():
         param.trainable = False
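
Note: a rough usage sketch for the gradient_norm helper added above. The Linear layer, the shapes, and the 'weight'/'bias' keys are illustrative; it assumes a backward pass has already populated the gradients.

import paddle
from parakeet.utils.layer_tools import gradient_norm

net = paddle.nn.Linear(4, 2)
x = paddle.randn([8, 4])
loss = paddle.mean(net(x) ** 2)
loss.backward()

# one size-normalized gradient norm per trainable parameter,
# e.g. something like {'weight': 0.12, 'bias': 0.03} for a Linear layer
print(gradient_norm(net))
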
diff --git a/tests/test_conv.py b/tests/test_conv.py
index 72731a7..b76e719 100644
--- a/tests/test_conv.py
+++ b/tests/test_conv.py
@@ -32,20 +32,25 @@ class TestConv1dCell(unittest.TestCase):
 
 
 class TestConv1dBatchNorm(unittest.TestCase):
-    def __init__(self, methodName="runTest", causal=False):
+    def __init__(self, methodName="runTest", causal=False, channel_last=False):
         super(TestConv1dBatchNorm, self).__init__(methodName)
         self.causal = causal
+        self.channel_last = channel_last
 
     def setUp(self):
         k = 5
         paddding = (k - 1, 0) if self.causal else ((k-1) // 2, k //2)
-        self.net = conv.Conv1dBatchNorm(4, 6, (k,), 1, padding=paddding)
+        self.net = conv.Conv1dBatchNorm(4, 6, (k,), 1, padding=paddding,
+                                        data_format="NLC" if self.channel_last else "NCL")
 
     def test_input_output(self):
-        x = paddle.randn([4, 4, 16])
+        x = paddle.randn([4, 16, 4]) if self.channel_last else paddle.randn([4, 4, 16])
         out = self.net(x)
         out_np = out.numpy()
-        self.assertTupleEqual(out_np.shape, (4, 6, 16))
+        if self.channel_last:
+            self.assertTupleEqual(out_np.shape, (4, 16, 6))
+        else:
+            self.assertTupleEqual(out_np.shape, (4, 6, 16))
 
     def runTest(self):
         self.test_input_output()
@@ -53,9 +58,10 @@ class TestConv1dBatchNorm(unittest.TestCase):
 
 def load_tests(loader, standard_tests, pattern):
     suite = unittest.TestSuite()
-    suite.addTest(TestConv1dBatchNorm("runTest", True))
-    suite.addTest(TestConv1dBatchNorm("runTest", False))
-
+    suite.addTest(TestConv1dBatchNorm("runTest", True, True))
+    suite.addTest(TestConv1dBatchNorm("runTest", False, False))
+    suite.addTest(TestConv1dBatchNorm("runTest", True, False))
+    suite.addTest(TestConv1dBatchNorm("runTest", False, True))
 
     suite.addTest(TestConv1dCell("test_equality"))
     return suite
\ No newline at end of file
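
Note: the four addTest calls above cover every (causal, channel_last) combination; an equivalent way to build the same suite, assuming the test classes defined in tests/test_conv.py, is sketched below.

import itertools
import unittest

def load_tests(loader, standard_tests, pattern):
    suite = unittest.TestSuite()
    # same four (causal, channel_last) combinations as the explicit calls above
    for causal, channel_last in itertools.product((True, False), repeat=2):
        suite.addTest(TestConv1dBatchNorm("runTest", causal, channel_last))
    suite.addTest(TestConv1dCell("test_equality"))
    return suite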