fix wavenet inference shape

This commit is contained in:
chenfeiyu 2020-12-16 00:22:43 +08:00
parent ab56eac676
commit bdf60bec39
1 changed files with 5 additions and 6 deletions

View File

@ -599,14 +599,13 @@ class ConditionalWaveNet(nn.Layer):
self.decoder.start_sequence() self.decoder.start_sequence()
x_t = paddle.zeros((batch_size, ), dtype=mel.dtype) x_t = paddle.zeros((batch_size, ), dtype=mel.dtype)
for i in trange(time_steps): for i in trange(time_steps):
c_t = condition[:, :, i] c_t = condition[:, :, i] # (B, C)
y_t = self.decoder.add_input(x_t, c_t) y_t = self.decoder.add_input(x_t, c_t) #(B, C)
y_t = paddle.unsqueeze(y_t, 1) y_t = paddle.unsqueeze(y_t, 1)
x_t = self.sample(y_t) x_t = self.sample(y_t) # (B, 1)
x_t = paddle.squeeze(x_t, 1) x_t = paddle.squeeze(x_t, 1) #(B,)
samples.append(x_t) samples.append(x_t)
samples = paddle.stack(samples, -1)
samples = paddle.concat(samples, -1)
return samples return samples
@paddle.no_grad() @paddle.no_grad()