diff --git a/src/deepke/name_entity_re/few_shot/utils/util.py b/src/deepke/name_entity_re/few_shot/utils/util.py index 9cdbdfa..d0d9bcf 100644 --- a/src/deepke/name_entity_re/few_shot/utils/util.py +++ b/src/deepke/name_entity_re/few_shot/utils/util.py @@ -4,7 +4,7 @@ import random from torch import nn import torch.nn.functional as F from transformers import BartModel, BartTokenizer - +import os def avg_token_embeddings(tokenizer: BartTokenizer, bart_model: BartModel, bart_name, num_tokens): """when initial added tokens, use their averge token emebddings @@ -149,6 +149,9 @@ def write_predictions(path, texts, labels): """ print(len(texts), len(labels)) assert len(texts) == len(labels) + if not os.path.exists(path): + os.system(r"touch {}".format(path)) + with open(path, "w", encoding="utf-8") as f: f.writelines("-DOCSTART- O\n\n") for i in range(len(texts)):