fix weight init bug

This commit is contained in:
lilei 2021-11-25 20:19:22 +08:00 committed by GitHub
parent 397b247976
commit 7adfd02931
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -33,7 +33,7 @@ class PromptBartDecoder(nn.Module):
self.averge_weights = nn.ParameterList(parameters=None)
for id in label_ids:
if len(id) > 1:
self.averge_weights.append(nn.Parameter(torch.FloatTensor(len(id))))
self.averge_weights.append(nn.Parameter(torch.FloatTensor(len(id)).uniform_(1.0, 2.5)))
print(self.averge_weights)
mapping = [0, 2]
for id in label_ids:
@ -700,4 +700,4 @@ class BeamHypotheses(object):
elif self.early_stopping:
return True
else:
return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty
return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty