fix bug
This commit is contained in:
parent
1405fb744d
commit
3002fe112f
|
@ -4,7 +4,7 @@ import random
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from transformers import BartModel, BartTokenizer
|
from transformers import BartModel, BartTokenizer
|
||||||
|
import os
|
||||||
|
|
||||||
def avg_token_embeddings(tokenizer: BartTokenizer, bart_model: BartModel, bart_name, num_tokens):
|
def avg_token_embeddings(tokenizer: BartTokenizer, bart_model: BartModel, bart_name, num_tokens):
|
||||||
"""when initial added tokens, use their averge token emebddings
|
"""when initial added tokens, use their averge token emebddings
|
||||||
|
@ -149,6 +149,9 @@ def write_predictions(path, texts, labels):
|
||||||
"""
|
"""
|
||||||
print(len(texts), len(labels))
|
print(len(texts), len(labels))
|
||||||
assert len(texts) == len(labels)
|
assert len(texts) == len(labels)
|
||||||
|
if not os.path.exists(path):
|
||||||
|
os.system(r"touch {}".format(path))
|
||||||
|
|
||||||
with open(path, "w", encoding="utf-8") as f:
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
f.writelines("-DOCSTART- O\n\n")
|
f.writelines("-DOCSTART- O\n\n")
|
||||||
for i in range(len(texts)):
|
for i in range(len(texts)):
|
||||||
|
|
Loading…
Reference in New Issue