fix wavenet inference shape

This commit is contained in:
chenfeiyu 2020-12-16 00:22:43 +08:00
parent ab56eac676
commit bdf60bec39
1 changed files with 5 additions and 6 deletions

View File

@ -599,14 +599,13 @@ class ConditionalWaveNet(nn.Layer):
self.decoder.start_sequence()
x_t = paddle.zeros((batch_size, ), dtype=mel.dtype)
for i in trange(time_steps):
c_t = condition[:, :, i]
y_t = self.decoder.add_input(x_t, c_t)
c_t = condition[:, :, i] # (B, C)
y_t = self.decoder.add_input(x_t, c_t) #(B, C)
y_t = paddle.unsqueeze(y_t, 1)
x_t = self.sample(y_t)
x_t = paddle.squeeze(x_t, 1)
x_t = self.sample(y_t) # (B, 1)
x_t = paddle.squeeze(x_t, 1) #(B,)
samples.append(x_t)
samples = paddle.concat(samples, -1)
samples = paddle.stack(samples, -1)
return samples
@paddle.no_grad()