Update Embedding.py

This commit is contained in:
leo 2019-12-03 22:43:06 +08:00
parent b100404be0
commit 8e7c15d914
1 changed file with 5 additions and 7 deletions

@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
 class Embedding(nn.Module):
@@ -19,10 +18,9 @@ class Embedding(nn.Module):
         self.pos_dim = config.pos_dim if config.dim_strategy == 'cat' else config.word_dim
         self.dim_strategy = config.dim_strategy
-        self.wordEmbed = nn.Embedding(self.vocab_size,self.word_dim,padding_idx=0)
-        self.headPosEmbed = nn.Embedding(self.pos_size,self.pos_dim,padding_idx=0)
-        self.tailPosEmbed = nn.Embedding(self.pos_size,self.pos_dim,padding_idx=0)
+        self.wordEmbed = nn.Embedding(self.vocab_size, self.word_dim, padding_idx=0)
+        self.headPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)
+        self.tailPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)
     def forward(self, *x):
         word, head, tail = x
@@ -31,7 +29,7 @@ class Embedding(nn.Module):
         tail_embedding = self.tailPosEmbed(tail)
         if self.dim_strategy == 'cat':
-            return torch.cat((word_embedding,head_embedding, tail_embedding), -1)
+            return torch.cat((word_embedding, head_embedding, tail_embedding), -1)
         elif self.dim_strategy == 'sum':
             # pos_dim == word_dim in this case
             return word_embedding + head_embedding + tail_embedding
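
The net effect of this commit is cosmetic: it drops the unused pack_padded_sequence / pad_packed_sequence import and normalizes spacing in the nn.Embedding calls; behavior is unchanged. For reference, below is a minimal usage sketch of the module. The config fields and the constructor call are assumptions inferred from the attributes visible in the diff (vocab_size, word_dim, pos_size, pos_dim, dim_strategy) and from forward(self, *x); the repo's actual config class and import path may differ.

# Minimal usage sketch for the Embedding module in this file.
# The config fields below are assumptions inferred from the diff,
# not the repo's real config class.
from types import SimpleNamespace

import torch

from Embedding import Embedding  # assumed import path; the class lives in Embedding.py

config = SimpleNamespace(
    vocab_size=5000,     # size of the word vocabulary (assumed)
    word_dim=50,         # word embedding dimension
    pos_size=100,        # number of distinct relative-position ids (assumed)
    pos_dim=50,          # position embedding dim (ignored for 'sum', where word_dim is used)
    dim_strategy='sum',  # 'cat' concatenates the three embeddings, 'sum' adds them
)

embed = Embedding(config)

batch, seq_len = 4, 20
word = torch.randint(1, config.vocab_size, (batch, seq_len))  # token ids (0 is padding)
head = torch.randint(1, config.pos_size, (batch, seq_len))    # relative positions to head entity
tail = torch.randint(1, config.pos_size, (batch, seq_len))    # relative positions to tail entity

out = embed(word, head, tail)
# 'sum' -> (batch, seq_len, word_dim)
# 'cat' -> (batch, seq_len, word_dim + 2 * pos_dim)
print(out.shape)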