Update Embedding.py
This commit is contained in:
parent
b100404be0
commit
8e7c15d914
|
@ -1,6 +1,5 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
|
||||||
|
|
||||||
|
|
||||||
class Embedding(nn.Module):
|
class Embedding(nn.Module):
|
||||||
|
@ -19,10 +18,9 @@ class Embedding(nn.Module):
|
||||||
self.pos_dim = config.pos_dim if config.dim_strategy == 'cat' else config.word_dim
|
self.pos_dim = config.pos_dim if config.dim_strategy == 'cat' else config.word_dim
|
||||||
self.dim_strategy = config.dim_strategy
|
self.dim_strategy = config.dim_strategy
|
||||||
|
|
||||||
self.wordEmbed = nn.Embedding(self.vocab_size,self.word_dim,padding_idx=0)
|
self.wordEmbed = nn.Embedding(self.vocab_size, self.word_dim, padding_idx=0)
|
||||||
self.headPosEmbed = nn.Embedding(self.pos_size,self.pos_dim,padding_idx=0)
|
self.headPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)
|
||||||
self.tailPosEmbed = nn.Embedding(self.pos_size,self.pos_dim,padding_idx=0)
|
self.tailPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, *x):
|
def forward(self, *x):
|
||||||
word, head, tail = x
|
word, head, tail = x
|
||||||
|
@ -31,7 +29,7 @@ class Embedding(nn.Module):
|
||||||
tail_embedding = self.tailPosEmbed(tail)
|
tail_embedding = self.tailPosEmbed(tail)
|
||||||
|
|
||||||
if self.dim_strategy == 'cat':
|
if self.dim_strategy == 'cat':
|
||||||
return torch.cat((word_embedding,head_embedding, tail_embedding), -1)
|
return torch.cat((word_embedding, head_embedding, tail_embedding), -1)
|
||||||
elif self.dim_strategy == 'sum':
|
elif self.dim_strategy == 'sum':
|
||||||
# 此时 pos_dim == word_dim
|
# 此时 pos_dim == word_dim
|
||||||
return word_embedding + head_embedding + tail_embedding
|
return word_embedding + head_embedding + tail_embedding
|
||||||
|
|
Loading…
Reference in New Issue