rnn支持导出
This commit is contained in:
parent
2f9f258ff4
commit
65d3dfc729
|
@ -28,8 +28,9 @@ class Im2Seq(nn.Layer):
|
|||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
x = x.reshape((B, -1, W))
|
||||
x = x.transpose((0, 2, 1)) # (NTC)(batch, width, channels)
|
||||
assert H == 1
|
||||
x = x.squeeze(axis=2)
|
||||
x = x.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -76,7 +77,8 @@ class SequenceEncoder(nn.Layer):
|
|||
'fc': EncoderWithFC,
|
||||
'rnn': EncoderWithRNN
|
||||
}
|
||||
assert encoder_type in support_encoder_dict, '{} must in {}'.format(encoder_type, support_encoder_dict.keys())
|
||||
assert encoder_type in support_encoder_dict, '{} must in {}'.format(
|
||||
encoder_type, support_encoder_dict.keys())
|
||||
|
||||
self.encoder = support_encoder_dict[encoder_type](
|
||||
self.encoder_reshape.out_channels, hidden_size)
|
||||
|
|
Loading…
Reference in New Issue