# Dynamic to static
Parent: d96e2828b8 · Commit: ea5cb8e71f
````diff
@@ -58,10 +58,10 @@ For more help on arguments
 
 ## Synthesis
 
-After training Tacotron2, spectrograms can be synthesized by running ``synthesis.py``.
+After training Tacotron2, spectrograms can be synthesized by running ``synthesize.py``.
 
 ```bash
-python synthesis.py \
+python synthesize.py \
     --config=${CONFIGPATH} \
     --checkpoint_path=${CHECKPOINTPATH} \
     --input=${TEXTPATH} \
````
```diff
@@ -44,7 +44,8 @@ def fold(x, n_group):
     Tensor : [shape=(\*, time_steps // n_group, group)]
         Folded tensor.
     """
-    *spatial_shape, time_steps = x.shape
+    spatial_shape = list(x.shape[:-1])
+    time_steps = paddle.shape(x)[-1]
     new_shape = spatial_shape + [time_steps // n_group, n_group]
     return paddle.reshape(x, new_shape)
 
```
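Why this hunk helps, as I understand the dygraph-to-static machinery: `x.shape` is evaluated when the program is captured and can contain `-1` placeholders for dynamic axes, while `paddle.shape(x)` returns a tensor that resolves at run time. A minimal sketch of the converted `fold`, assuming Paddle 2.x (the demo shapes are made up):

```python
import paddle

def fold(x, n_group):
    # Leading dims stay as Python ints; the (possibly dynamic) time axis
    # is read through paddle.shape so it resolves at run time.
    spatial_shape = list(x.shape[:-1])
    time_steps = paddle.shape(x)[-1]
    new_shape = spatial_shape + [time_steps // n_group, n_group]
    return paddle.reshape(x, new_shape)

x = paddle.randn([2, 4, 64])      # (batch, channels, time); demo values
print(fold(x, n_group=8).shape)   # [2, 4, 8, 8]
```

`paddle.reshape` accepts a target shape that mixes Python ints with scalar tensors, so only the dynamic axis pays the cost of a runtime shape read.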
```diff
@@ -232,7 +233,7 @@ class ResidualBlock(nn.Layer):
         """
         if self.training:
             raise ValueError("Only use start sequence at evaluation mode.")
-        self._conv_buffer = None
+        self._conv_buffer = paddle.zeros([1])
 
         # NOTE: call self.conv's weight norm hook explicitly since
         # its weight will be visited directly in `add_input` without
```
```diff
@@ -263,10 +264,9 @@ class ResidualBlock(nn.Layer):
             A row of the skip output.
         """
         x_row_in = x_row
-        if self._conv_buffer is None:
+        if len(paddle.shape(self._conv_buffer)) == 1:
             self._init_buffer(x_row)
         self._update_buffer(x_row)
-
         rw = self.rw
         x_row = F.conv2d(
             self._conv_buffer,
```
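Context for the two `_conv_buffer` hunks above: program capture cannot branch on Python `None`, so the uninitialized convolution cache is encoded as a rank-1 placeholder tensor, and a rank check replaces the `is None` test. A standalone sketch of the pattern, with illustrative shapes (`ensure_buffer` is a made-up helper):

```python
import paddle

buf = paddle.zeros([1])                 # rank-1 placeholder = "no cache yet"

def ensure_buffer(buf, x_row):
    # The real cache is 4-D (B, C, T, W), so rank 1 uniquely marks the
    # uninitialized state without comparing against None.
    if len(paddle.shape(buf)) == 1:
        buf = paddle.zeros_like(x_row)  # lazy init on the first row
    return buf

x_row = paddle.randn([1, 64, 1, 100])
print(ensure_buffer(buf, x_row).shape)  # [1, 64, 1, 100]
```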
```diff
@@ -275,7 +275,6 @@ class ResidualBlock(nn.Layer):
             padding=[0, 0, rw // 2, (rw - 1) // 2],
             dilation=self.dilations)
         x_row += self.condition_proj(condition_row)
 
         content, gate = paddle.chunk(x_row, 2, axis=1)
         x_row = paddle.tanh(content) * F.sigmoid(gate)
-
```
```diff
@@ -329,7 +328,7 @@ class ResidualNet(nn.LayerList):
         if len(dilations_h) != n_layer:
             raise ValueError(
                 "number of dilations_h should equal the number of layers")
-        super().__init__()
+        super(ResidualNet, self).__init__()
         for i in range(n_layer):
             dilation = (dilations_h[i], 2**i)
             layer = ResidualBlock(residual_channels, condition_channels,
```
```diff
@@ -539,27 +538,21 @@ class Flow(nn.Layer):
             transformation from x to z.
         """
         z_0 = z[:, :, :1, :]
-        x = []
-        logs_list = []
-        b_list = []
-        x.append(z_0)
+        x = paddle.zeros_like(z)
+        x[:, :, :1, :] = z_0
 
         self._start_sequence()
-        for i in range(1, self.n_group):
-            x_row = x[-1]  # actually i-1:i
+        num_step = paddle.ones([1], dtype='int32') * (self.n_group)
+        for i in range(1, num_step):
+            x_row = x[:, :, i - 1:i, :]
             z_row = z[:, :, i:i + 1, :]
             condition_row = condition[:, :, i:i + 1, :]
 
-            x_next_row, (logs, b) = self._inverse_row(z_row, x_row,
-                                                      condition_row)
-            x.append(x_next_row)
-            logs_list.append(logs)
-            b_list.append(b)
-
-        x = paddle.concat(x, 2)
-        logs = paddle.concat(logs_list, 2)
-        b = paddle.concat(b_list, 2)
-        return x, (logs, b)
+            x_next_row, (logs, b) = self._inverse_row(z_row, x_row,
+                                                      condition_row)
+            x[:, :, i:i + 1, :] = x_next_row
+
+        return x
 
 
 class WaveFlow(nn.LayerList):
```
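The core of this hunk replaces a Python list plus `paddle.concat` with a tensor that is pre-allocated and filled row by row: list appends live outside the captured program, while slice assignment stays inside it, and turning the loop bound into a tensor (`num_step`) is what lets the converter lower the autoregressive loop into the graph. A toy version of the write-in-place pattern, with an identity transform standing in for `_inverse_row`:

```python
import paddle

z = paddle.randn([4, 1, 8, 100])   # (B, C, n_group, T // n_group); demo shape
n_group = 8

x = paddle.zeros_like(z)           # pre-allocate instead of x = []
x[:, :, :1, :] = z[:, :, :1, :]    # seed the first row

for i in range(1, n_group):
    x_row = x[:, :, i - 1:i, :]    # read the previously generated row
    # a real model would run _inverse_row here; identity keeps it minimal
    x[:, :, i:i + 1, :] = x_row

print(x.shape)                     # [4, 1, 8, 100]
```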
```diff
@@ -611,16 +604,18 @@ class WaveFlow(nn.LayerList):
         perms = []
         for i in range(n_flows):
             if i < n_flows // 2:
-                perms.append(indices[::-1])
+                perm = indices[::-1]
             else:
                 perm = list(reversed(indices[:half])) + list(
                     reversed(indices[half:]))
-                perms.append(perm)
+            perm = paddle.to_tensor(perm)
+            self.register_buffer(perm.name, perm)
+            perms.append(perm)
         return perms
 
     def _trim(self, x, condition):
         assert condition.shape[-1] >= x.shape[-1]
-        pruned_len = int(x.shape[-1] // self.n_group * self.n_group)
+        pruned_len = int(paddle.shape(x)[-1] // self.n_group * self.n_group)
 
         if x.shape[-1] > pruned_len:
             x = x[:, :pruned_len]
```
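Registering each permutation with `register_buffer` makes it part of the layer's persistable state, so it is saved with the model and carried into the exported program instead of living in a plain Python list. A hypothetical miniature of the idea (`Shuffler` is not from this codebase):

```python
import paddle

class Shuffler(paddle.nn.Layer):
    def __init__(self, n):
        super().__init__()
        perm = paddle.to_tensor(list(reversed(range(n))), dtype='int64')
        # Buffers are saved in the state dict and captured by paddle.jit,
        # unlike a plain Python list of indices.
        self.register_buffer("perm", perm)

    def forward(self, x):
        # apply the stored permutation along axis 1
        return paddle.index_select(x, self.perm, axis=1)

layer = Shuffler(4)
x = paddle.arange(8, dtype='float32').reshape([2, 4])
print(layer(x))   # columns reversed
```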
```diff
@@ -707,7 +702,7 @@ class WaveFlow(nn.LayerList):
         for i in reversed(range(self.n_flows)):
             z = geo.shuffle_dim(z, 2, perm=self.perms[i])
             condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])
-            z, (logs, b) = self[i].inverse(z, condition)
+            z = self[i].inverse(z, condition)
 
         x = paddle.squeeze(z, 1)  # (B, H, W)
         batch_size = x.shape[0]
```
```diff
@@ -893,3 +888,21 @@ class WaveFlowLoss(nn.Layer):
         ) - log_det_jacobian
         loss = loss / np.prod(z.shape)
         return loss + self.const
+
+
+class ConditionalWaveFlow2Infer(ConditionalWaveFlow):
+    def forward(self, mel):
+        """Generate raw audio given mel spectrogram.
+
+        Parameters
+        ----------
+        mel : np.ndarray [shape=(C_mel, T_mel)]
+            Mel spectrogram of an utterance (in log-magnitude).
+
+        Returns
+        -------
+        np.ndarray [shape=(T,)]
+            The synthesized audio.
+        """
+        audio = self.predict(mel)
+        return audio
```
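The new `ConditionalWaveFlow2Infer` routes inference through `forward`, which is the method `paddle.jit` traces. A hedged sketch of how such a wrapper could then be exported; `TinyVocoder`, the input spec, and the save path are all assumptions, not part of this commit:

```python
import paddle

class TinyVocoder(paddle.nn.Layer):
    """Stand-in for ConditionalWaveFlow2Infer: forward(mel) -> audio."""

    def forward(self, mel):
        # A real vocoder synthesizes audio; here each mel frame is just
        # collapsed and tiled so the example stays tiny but traceable.
        return paddle.tile(paddle.mean(mel, axis=0), [256])

model = TinyVocoder()
model.eval()
static_model = paddle.jit.to_static(
    model,
    input_spec=[paddle.static.InputSpec(shape=[80, None], dtype='float32')])
paddle.jit.save(static_model, "waveflow_infer")  # hypothetical export path
```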