fix weight init bug
This commit is contained in:
parent
397b247976
commit
7adfd02931
|
@ -33,7 +33,7 @@ class PromptBartDecoder(nn.Module):
|
|||
self.averge_weights = nn.ParameterList(parameters=None)
|
||||
for id in label_ids:
|
||||
if len(id) > 1:
|
||||
self.averge_weights.append(nn.Parameter(torch.FloatTensor(len(id))))
|
||||
self.averge_weights.append(nn.Parameter(torch.FloatTensor(len(id)).uniform_(1.0, 2.5)))
|
||||
print(self.averge_weights)
|
||||
mapping = [0, 2]
|
||||
for id in label_ids:
|
||||
|
|
Loading…
Reference in New Issue