This commit is contained in:
tlk-dsg 2021-09-14 20:51:00 +08:00
parent 1f001c7d1a
commit 59985352c5
2 changed files with 4 additions and 4 deletions
src/deepke/ae/regular/module
tutorial-notebooks/ae/regular

View File

@ -27,8 +27,8 @@ class Embedding(nn.Module):
def forward(self, *x):
word, entity, attribute_key = x
word_embedding = self.wordEmbed(word)
entity_embedding = self.entityPosEmbed(head)
attribute_key_embedding = self.attribute_keyPosEmbed(tail)
entity_embedding = self.entityPosEmbed(entity)
attribute_key_embedding = self.attribute_keyPosEmbed(attribute_key)
if self.dim_strategy == 'cat':
return torch.cat((word_embedding, entity_embedding, attribute_key_embedding), -1)

View File

@ -447,8 +447,8 @@
" def forward(self, *x):\n",
" word, entity, attribute_key = x\n",
" word_embedding = self.wordEmbed(word)\n",
" entity_embedding = self.entityPosEmbed(head)\n",
" attribute_key_embedding = self.attribute_keyPosEmbed(tail)\n",
" entity_embedding = self.entityPosEmbed(entity)\n",
" attribute_key_embedding = self.attribute_keyPosEmbed(attribute_key)\n",
"\n",
" if self.dim_strategy == 'cat':\n",
" return torch.cat((word_embedding, entity_embedding, attribute_key_embedding), -1)\n",