This commit is contained in:
tlk-dsg 2021-10-29 18:34:51 +08:00
parent 1405fb744d
commit 3002fe112f
1 changed files with 4 additions and 1 deletions

View File

@ -4,7 +4,7 @@ import random
from torch import nn from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import BartModel, BartTokenizer from transformers import BartModel, BartTokenizer
import os
def avg_token_embeddings(tokenizer: BartTokenizer, bart_model: BartModel, bart_name, num_tokens): def avg_token_embeddings(tokenizer: BartTokenizer, bart_model: BartModel, bart_name, num_tokens):
"""when initial added tokens, use their averge token emebddings """when initial added tokens, use their averge token emebddings
@ -149,6 +149,9 @@ def write_predictions(path, texts, labels):
""" """
print(len(texts), len(labels)) print(len(texts), len(labels))
assert len(texts) == len(labels) assert len(texts) == len(labels)
if not os.path.exists(path):
os.system(r"touch {}".format(path))
with open(path, "w", encoding="utf-8") as f: with open(path, "w", encoding="utf-8") as f:
f.writelines("-DOCSTART- O\n\n") f.writelines("-DOCSTART- O\n\n")
for i in range(len(texts)): for i in range(len(texts)):