Update Embedding.py
parent b100404be0
commit 8e7c15d914
@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


 class Embedding(nn.Module):
@@ -19,10 +18,9 @@ class Embedding(nn.Module):
         self.pos_dim = config.pos_dim if config.dim_strategy == 'cat' else config.word_dim
         self.dim_strategy = config.dim_strategy

-        self.wordEmbed = nn.Embedding(self.vocab_size,self.word_dim,padding_idx=0)
-        self.headPosEmbed = nn.Embedding(self.pos_size,self.pos_dim,padding_idx=0)
-        self.tailPosEmbed = nn.Embedding(self.pos_size,self.pos_dim,padding_idx=0)
-
+        self.wordEmbed = nn.Embedding(self.vocab_size, self.word_dim, padding_idx=0)
+        self.headPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)
+        self.tailPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)

     def forward(self, *x):
         word, head, tail = x
@@ -31,9 +29,9 @@ class Embedding(nn.Module):
         tail_embedding = self.tailPosEmbed(tail)

         if self.dim_strategy == 'cat':
-            return torch.cat((word_embedding,head_embedding, tail_embedding), -1)
+            return torch.cat((word_embedding, head_embedding, tail_embedding), -1)
         elif self.dim_strategy == 'sum':
             # pos_dim == word_dim in this case
             return word_embedding + head_embedding + tail_embedding
         else:
-            raise Exception('dim_strategy must choose from [sum, cat]')
+            raise Exception('dim_strategy must choose from [sum, cat]')
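
For reference, here is a minimal, self-contained usage sketch of how the module reads after this commit. The class is a stand-in reconstructed from the hunks above; the __init__ signature, the vocab_size / word_dim / pos_size attribute assignments, the word and head lookups in forward, and the SimpleNamespace config with its example sizes are assumptions for illustration, not code from the repository.

# Sketch of the Embedding module after this change. Parts outside the diff
# hunks (attribute names, config fields, example sizes) are assumed.
from types import SimpleNamespace

import torch
import torch.nn as nn


class Embedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.vocab_size = config.vocab_size   # assumed attribute names
        self.word_dim = config.word_dim
        self.pos_size = config.pos_size
        # With 'sum', position embeddings must share the word dimension.
        self.pos_dim = config.pos_dim if config.dim_strategy == 'cat' else config.word_dim
        self.dim_strategy = config.dim_strategy

        self.wordEmbed = nn.Embedding(self.vocab_size, self.word_dim, padding_idx=0)
        self.headPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)
        self.tailPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)

    def forward(self, *x):
        word, head, tail = x
        word_embedding = self.wordEmbed(word)
        head_embedding = self.headPosEmbed(head)
        tail_embedding = self.tailPosEmbed(tail)

        if self.dim_strategy == 'cat':
            # output width: word_dim + 2 * pos_dim
            return torch.cat((word_embedding, head_embedding, tail_embedding), -1)
        elif self.dim_strategy == 'sum':
            # output width: word_dim (pos_dim == word_dim in this branch)
            return word_embedding + head_embedding + tail_embedding
        else:
            raise Exception('dim_strategy must choose from [sum, cat]')


# Hypothetical config and inputs, just to exercise the module.
config = SimpleNamespace(vocab_size=100, word_dim=50, pos_size=62, pos_dim=10,
                         dim_strategy='cat')
embed = Embedding(config)

word = torch.randint(1, 100, (2, 7))  # (batch, seq_len) token ids
head = torch.randint(1, 62, (2, 7))   # relative position to the head entity
tail = torch.randint(1, 62, (2, 7))   # relative position to the tail entity

print(embed(word, head, tail).shape)  # torch.Size([2, 7, 70]) with 'cat'

With dim_strategy='cat' the three lookups are concatenated to a width of word_dim + 2 * pos_dim, while 'sum' forces pos_dim to word_dim so the lookups can be added elementwise.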