fix bug
parent 1405fb744d
commit 3002fe112f
@@ -4,7 +4,7 @@ import random
from torch import nn
import torch.nn.functional as F
from transformers import BartModel, BartTokenizer

import os

def avg_token_embeddings(tokenizer: BartTokenizer, bart_model: BartModel, bart_name, num_tokens):
"""when initial added tokens, use their averge token emebddings
|
||||
|
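The hunk above touches avg_token_embeddings, whose docstring describes initializing newly added tokens from the average of their sub-token embeddings instead of random values. A minimal, self-contained sketch of that technique, assuming a facebook/bart-base checkpoint and a hypothetical added token <person> (neither appears in the diff):

import torch
from transformers import BartModel, BartTokenizer

bart_name = "facebook/bart-base"               # assumed checkpoint
tokenizer = BartTokenizer.from_pretrained(bart_name)
model = BartModel.from_pretrained(bart_name)

# Hypothetical new special token; the diff does not show the actual tokens.
tokenizer.add_tokens(["<person>"])
model.resize_token_embeddings(len(tokenizer))

embed = model.get_input_embeddings().weight    # (vocab_size, d_model)
with torch.no_grad():
    # Tokenize the token's surface form into existing subwords and average
    # their rows, so the new embedding starts near related tokens.
    piece_ids = tokenizer("person", add_special_tokens=False)["input_ids"]
    new_id = tokenizer.convert_tokens_to_ids("<person>")
    embed[new_id] = embed[piece_ids].mean(dim=0)

Averaging keeps the new row inside the distribution of the pretrained embedding matrix, which typically converges faster than random initialization.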
@@ -149,6 +149,9 @@ def write_predictions(path, texts, labels):
    """
    print(len(texts), len(labels))
    assert len(texts) == len(labels)
    if not os.path.exists(path):
        os.system(r"touch {}".format(path))

    with open(path, "w", encoding="utf-8") as f:
        f.writelines("-DOCSTART- O\n\n")
        for i in range(len(texts)):