diff --git a/module/Embedding.py b/module/Embedding.py
index 8568821..25b14be 100644
--- a/module/Embedding.py
+++ b/module/Embedding.py
@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
 
 
 class Embedding(nn.Module):
@@ -19,10 +18,9 @@ class Embedding(nn.Module):
         self.pos_dim = config.pos_dim if config.dim_strategy == 'cat' else config.word_dim
         self.dim_strategy = config.dim_strategy
 
-        self.wordEmbed = nn.Embedding(self.vocab_size,self.word_dim,padding_idx=0)
-        self.headPosEmbed = nn.Embedding(self.pos_size,self.pos_dim,padding_idx=0)
-        self.tailPosEmbed = nn.Embedding(self.pos_size,self.pos_dim,padding_idx=0)
-
+        self.wordEmbed = nn.Embedding(self.vocab_size, self.word_dim, padding_idx=0)
+        self.headPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)
+        self.tailPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)
 
     def forward(self, *x):
         word, head, tail = x
@@ -31,9 +29,9 @@ class Embedding(nn.Module):
         tail_embedding = self.tailPosEmbed(tail)
 
         if self.dim_strategy == 'cat':
-            return torch.cat((word_embedding,head_embedding, tail_embedding), -1)
+            return torch.cat((word_embedding, head_embedding, tail_embedding), -1)
         elif self.dim_strategy == 'sum':
             # at this point pos_dim == word_dim
             return word_embedding + head_embedding + tail_embedding
         else:
-            raise Exception('dim_strategy must choose from [sum, cat]')
\ No newline at end of file
+            raise Exception('dim_strategy must choose from [sum, cat]')
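
For reference, a minimal usage sketch of the patched module, not part of the patch itself. It assumes `__init__` also stores `config.vocab_size`, `config.word_dim`, and `config.pos_size` on `self` (only their reads are visible in the hunks above), and the `SimpleNamespace` config plus tensor shapes are illustrative:

```python
import torch
from types import SimpleNamespace

from module.Embedding import Embedding

# Hypothetical config; field names follow the attributes read in __init__.
config = SimpleNamespace(
    vocab_size=1000,    # word vocabulary size
    word_dim=50,        # word embedding dimension
    pos_size=100,       # relative-position vocabulary size
    pos_dim=10,         # position dim; only used when dim_strategy == 'cat'
    dim_strategy='cat'  # 'cat' concatenates; 'sum' forces pos_dim == word_dim
)

embed = Embedding(config)

batch, seq_len = 4, 20
word = torch.randint(0, config.vocab_size, (batch, seq_len))  # token ids
head = torch.randint(0, config.pos_size, (batch, seq_len))    # positions relative to head entity
tail = torch.randint(0, config.pos_size, (batch, seq_len))    # positions relative to tail entity

out = embed(word, head, tail)
# 'cat': (4, 20, word_dim + 2 * pos_dim) = (4, 20, 70)
# 'sum': (4, 20, word_dim) = (4, 20, 50)
print(out.shape)
```

Note the interaction the patch's context lines make explicit: under `'sum'`, `self.pos_dim` is overridden to `config.word_dim`, so `config.pos_dim` is ignored and the three embeddings can be added elementwise.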