fix weight init bug
This commit is contained in:
parent
397b247976
commit
7adfd02931
|
@ -33,7 +33,7 @@ class PromptBartDecoder(nn.Module):
|
||||||
self.averge_weights = nn.ParameterList(parameters=None)
|
self.averge_weights = nn.ParameterList(parameters=None)
|
||||||
for id in label_ids:
|
for id in label_ids:
|
||||||
if len(id) > 1:
|
if len(id) > 1:
|
||||||
self.averge_weights.append(nn.Parameter(torch.FloatTensor(len(id))))
|
self.averge_weights.append(nn.Parameter(torch.FloatTensor(len(id)).uniform_(1.0, 2.5)))
|
||||||
print(self.averge_weights)
|
print(self.averge_weights)
|
||||||
mapping = [0, 2]
|
mapping = [0, 2]
|
||||||
for id in label_ids:
|
for id in label_ids:
|
||||||
|
@ -700,4 +700,4 @@ class BeamHypotheses(object):
|
||||||
elif self.early_stopping:
|
elif self.early_stopping:
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty
|
return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty
|
||||||
|
|
Loading…
Reference in New Issue