test
This commit is contained in:
parent
1f001c7d1a
commit
59985352c5
|
@ -27,8 +27,8 @@ class Embedding(nn.Module):
|
|||
def forward(self, *x):
|
||||
word, entity, attribute_key = x
|
||||
word_embedding = self.wordEmbed(word)
|
||||
entity_embedding = self.entityPosEmbed(head)
|
||||
attribute_key_embedding = self.attribute_keyPosEmbed(tail)
|
||||
entity_embedding = self.entityPosEmbed(entity)
|
||||
attribute_key_embedding = self.attribute_keyPosEmbed(attribute_key)
|
||||
|
||||
if self.dim_strategy == 'cat':
|
||||
return torch.cat((word_embedding, entity_embedding, attribute_key_embedding), -1)
|
||||
|
|
|
@ -447,8 +447,8 @@
|
|||
" def forward(self, *x):\n",
|
||||
" word, entity, attribute_key = x\n",
|
||||
" word_embedding = self.wordEmbed(word)\n",
|
||||
" entity_embedding = self.entityPosEmbed(head)\n",
|
||||
" attribute_key_embedding = self.attribute_keyPosEmbed(tail)\n",
|
||||
" entity_embedding = self.entityPosEmbed(entity)\n",
|
||||
" attribute_key_embedding = self.attribute_keyPosEmbed(attribute_key)\n",
|
||||
"\n",
|
||||
" if self.dim_strategy == 'cat':\n",
|
||||
" return torch.cat((word_embedding, entity_embedding, attribute_key_embedding), -1)\n",
|
||||
|
|
Loading…
Reference in New Issue