fix wavenet inference shape
This commit is contained in:
parent
ab56eac676
commit
bdf60bec39
|
@ -599,14 +599,13 @@ class ConditionalWaveNet(nn.Layer):
|
||||||
self.decoder.start_sequence()
|
self.decoder.start_sequence()
|
||||||
x_t = paddle.zeros((batch_size, ), dtype=mel.dtype)
|
x_t = paddle.zeros((batch_size, ), dtype=mel.dtype)
|
||||||
for i in trange(time_steps):
|
for i in trange(time_steps):
|
||||||
c_t = condition[:, :, i]
|
c_t = condition[:, :, i] # (B, C)
|
||||||
y_t = self.decoder.add_input(x_t, c_t)
|
y_t = self.decoder.add_input(x_t, c_t) #(B, C)
|
||||||
y_t = paddle.unsqueeze(y_t, 1)
|
y_t = paddle.unsqueeze(y_t, 1)
|
||||||
x_t = self.sample(y_t)
|
x_t = self.sample(y_t) # (B, 1)
|
||||||
x_t = paddle.squeeze(x_t, 1)
|
x_t = paddle.squeeze(x_t, 1) #(B,)
|
||||||
samples.append(x_t)
|
samples.append(x_t)
|
||||||
|
samples = paddle.stack(samples, -1)
|
||||||
samples = paddle.concat(samples, -1)
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
@paddle.no_grad()
|
@paddle.no_grad()
|
||||||
|
|
Loading…
Reference in New Issue