test
This commit is contained in:
parent
115f6ca852
commit
c07eaa4ada
|
@ -11,7 +11,7 @@ from torch.utils.data import DataLoader
|
|||
from deepke.name_entity_re.few_shot.models.model import PromptBartModel, PromptGeneratorModel
|
||||
from deepke.name_entity_re.few_shot.module.datasets import ConllNERProcessor, ConllNERDataset
|
||||
from deepke.name_entity_re.few_shot.module.train import Trainer
|
||||
from deepke.name_entity_re.few_shot.utils.utils import set_seed
|
||||
from deepke.name_entity_re.few_shot.utils.util import set_seed
|
||||
from deepke.name_entity_re.few_shot.module.mapping_type import mit_movie_mapping, mit_restaurant_mapping, atis_mapping
|
||||
|
||||
import warnings
|
||||
|
|
|
@ -13,7 +13,7 @@ from deepke.name_entity_re.few_shot.models.model import PromptBartModel, PromptG
|
|||
from deepke.name_entity_re.few_shot.module.datasets import ConllNERProcessor, ConllNERDataset
|
||||
from deepke.name_entity_re.few_shot.module.train import Trainer
|
||||
from deepke.name_entity_re.few_shot.module.metrics import Seq2SeqSpanMetric
|
||||
from deepke.name_entity_re.few_shot.utils.utils import get_loss, set_seed
|
||||
from deepke.name_entity_re.few_shot.utils.util import get_loss, set_seed
|
||||
from deepke.name_entity_re.few_shot.module.mapping_type import mit_movie_mapping, mit_restaurant_mapping, atis_mapping
|
||||
|
||||
import warnings
|
||||
|
|
2
setup.py
2
setup.py
|
@ -1,7 +1,7 @@
|
|||
from setuptools import setup, find_packages
|
||||
setup(
|
||||
name='deepke', # 打包后的包文件名
|
||||
version='0.2.40', #版本号
|
||||
version='0.2.60', #版本号
|
||||
keywords=["pip", "RE","NER","AE"], # 关键字
|
||||
description='DeepKE 是基于 Pytorch 的深度学习中文关系抽取处理套件。', # 说明
|
||||
long_description="client", #详细说明
|
||||
|
|
|
@ -1 +1,3 @@
|
|||
from .few_shot import *
|
||||
from .models import *
|
||||
from .module import *
|
||||
from .utils import *
|
|
@ -4,7 +4,7 @@ from torch.nn import functional as F
|
|||
from transformers.configuration_bart import BartConfig
|
||||
from .modeling_bart import BartModel, _prepare_bart_decoder_inputs
|
||||
|
||||
from ..utils.utils import avg_token_embeddings, seq_to_mask, _get_model_device
|
||||
from ..utils import avg_token_embeddings, seq_to_mask,get_model_device
|
||||
from functools import partial
|
||||
from typing import Union
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
from torch import optim
|
||||
from tqdm import tqdm
|
||||
from ..utils.utils import convert_preds_to_outputs, write_predictions
|
||||
from ..utils import convert_preds_to_outputs, write_predictions
|
||||
import random
|
||||
|
||||
class Trainer(object):
|
||||
|
|
|
@ -1 +1 @@
|
|||
from utils import *
|
||||
from .util import *
|
|
@ -63,7 +63,7 @@ def get_loss(tgt_tokens, tgt_seq_len, pred):
|
|||
loss = F.cross_entropy(target=tgt_tokens, input=pred.transpose(1, 2))
|
||||
return loss
|
||||
|
||||
def _get_model_device(model):
|
||||
def get_model_device(model):
|
||||
assert isinstance(model, nn.Module)
|
||||
|
||||
parameters = list(model.parameters())
|
|
@ -0,0 +1,2 @@
|
|||
from .models import *
|
||||
from .tools import *
|
|
@ -1,2 +1,2 @@
|
|||
from .BasicNer import TrainNer
|
||||
from .InferBert import InferNer
|
||||
from .BasicNer import *
|
||||
from .InferBert import *
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
from .dataset import *
|
||||
from .preprocess import *
|
||||
from .trainer import train
|
||||
from .trainer import *
|
||||
|
|
|
@ -32,7 +32,7 @@ from hydra import utils
|
|||
from .dataset import *
|
||||
from .preprocess import *
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
|
||||
from models.BasicNer import TrainNer
|
||||
from ..models import TrainNer
|
||||
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
|
|
Loading…
Reference in New Issue