This commit is contained in:
tlk-dsg 2021-09-28 18:49:58 +08:00
parent 115f6ca852
commit c07eaa4ada
12 changed files with 16 additions and 12 deletions

View File

@ -11,7 +11,7 @@ from torch.utils.data import DataLoader
from deepke.name_entity_re.few_shot.models.model import PromptBartModel, PromptGeneratorModel
from deepke.name_entity_re.few_shot.module.datasets import ConllNERProcessor, ConllNERDataset
from deepke.name_entity_re.few_shot.module.train import Trainer
from deepke.name_entity_re.few_shot.utils.utils import set_seed
from deepke.name_entity_re.few_shot.utils.util import set_seed
from deepke.name_entity_re.few_shot.module.mapping_type import mit_movie_mapping, mit_restaurant_mapping, atis_mapping
import warnings

View File

@ -13,7 +13,7 @@ from deepke.name_entity_re.few_shot.models.model import PromptBartModel, PromptG
from deepke.name_entity_re.few_shot.module.datasets import ConllNERProcessor, ConllNERDataset
from deepke.name_entity_re.few_shot.module.train import Trainer
from deepke.name_entity_re.few_shot.module.metrics import Seq2SeqSpanMetric
from deepke.name_entity_re.few_shot.utils.utils import get_loss, set_seed
from deepke.name_entity_re.few_shot.utils.util import get_loss, set_seed
from deepke.name_entity_re.few_shot.module.mapping_type import mit_movie_mapping, mit_restaurant_mapping, atis_mapping
import warnings

View File

@ -1,7 +1,7 @@
from setuptools import setup, find_packages
setup(
name='deepke', # 打包后的包文件名
version='0.2.40', #版本号
version='0.2.60', #版本号
keywords=["pip", "RE","NER","AE"], # 关键字
description='DeepKE 是基于 Pytorch 的深度学习中文关系抽取处理套件。', # 说明
long_description="client", #详细说明

View File

@ -1 +1,3 @@
from .few_shot import *
from .models import *
from .module import *
from .utils import *

View File

@ -4,7 +4,7 @@ from torch.nn import functional as F
from transformers.configuration_bart import BartConfig
from .modeling_bart import BartModel, _prepare_bart_decoder_inputs
from ..utils.utils import avg_token_embeddings, seq_to_mask, _get_model_device
from ..utils import avg_token_embeddings, seq_to_mask,get_model_device
from functools import partial
from typing import Union

View File

@ -1,7 +1,7 @@
import torch
from torch import optim
from tqdm import tqdm
from ..utils.utils import convert_preds_to_outputs, write_predictions
from ..utils import convert_preds_to_outputs, write_predictions
import random
class Trainer(object):

View File

@ -1 +1 @@
from utils import *
from .util import *

View File

@ -63,7 +63,7 @@ def get_loss(tgt_tokens, tgt_seq_len, pred):
loss = F.cross_entropy(target=tgt_tokens, input=pred.transpose(1, 2))
return loss
def _get_model_device(model):
def get_model_device(model):
assert isinstance(model, nn.Module)
parameters = list(model.parameters())

View File

@ -0,0 +1,2 @@
from .models import *
from .tools import *

View File

@ -1,2 +1,2 @@
from .BasicNer import TrainNer
from .InferBert import InferNer
from .BasicNer import *
from .InferBert import *

View File

@ -1,3 +1,3 @@
from .dataset import *
from .preprocess import *
from .trainer import train
from .trainer import *

View File

@ -32,7 +32,7 @@ from hydra import utils
from .dataset import *
from .preprocess import *
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
from models.BasicNer import TrainNer
from ..models import TrainNer
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',