diff --git a/parakeet/models/wavenet.py b/parakeet/models/wavenet.py index 135c8e4..fed3dd8 100644 --- a/parakeet/models/wavenet.py +++ b/parakeet/models/wavenet.py @@ -599,14 +599,13 @@ class ConditionalWaveNet(nn.Layer): self.decoder.start_sequence() x_t = paddle.zeros((batch_size, ), dtype=mel.dtype) for i in trange(time_steps): - c_t = condition[:, :, i] - y_t = self.decoder.add_input(x_t, c_t) + c_t = condition[:, :, i] # (B, C) + y_t = self.decoder.add_input(x_t, c_t) #(B, C) y_t = paddle.unsqueeze(y_t, 1) - x_t = self.sample(y_t) - x_t = paddle.squeeze(x_t, 1) + x_t = self.sample(y_t) # (B, 1) + x_t = paddle.squeeze(x_t, 1) #(B,) samples.append(x_t) - - samples = paddle.concat(samples, -1) + samples = paddle.stack(samples, -1) return samples @paddle.no_grad()