diff --git a/README.md b/README.md index 1d5027f..c53337a 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ DeepKE 提供了多种知识抽取模型。 1. RE ``` - 1.REGULAR + 1.STANDARD 2.FEW-SHOT @@ -53,12 +53,12 @@ DeepKE 提供了多种知识抽取模型。 2. NER ``` - REGULAR + STANDARD ``` 3. AE ``` - REGULAR + STANDARD ``` @@ -76,7 +76,7 @@ DeepKE 提供了多种知识抽取模型。 具体流程请进入详细的README中,RE包括了以下三个子功能 - **[REGULAR](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/re/regular/README.md)** + **[STANDARD](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/re/regular/README.md)** FEW-SHORT @@ -94,7 +94,7 @@ DeepKE 提供了多种知识抽取模型。 具体流程请进入详细的README中: - **[REGULAR](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/ner/regular/README.md)** + **[STANDARD](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/ner/regular/README.md)** 3. AE @@ -108,7 +108,7 @@ DeepKE 提供了多种知识抽取模型。 具体流程请进入详细的README中: - **[REGULAR](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/ae/regular/README.md)** + **[STANDARD](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/ae/regular/README.md)** @@ -127,6 +127,8 @@ Deepke的架构图如下所示 1. 安装后提示 `ModuleNotFoundError: No module named 'past'`,输入命令 `pip install future` 即可解决。 +1. 使用语言预训练模型时,在线安装下载模型比较慢,更建议提前下载好,存放到 pretrained 文件夹内。具体存放文件要求见文件夹内的 readme.md。 + ## 致谢 diff --git a/example/ae/regular/README.md b/example/ae/standard/README.md similarity index 100% rename from example/ae/regular/README.md rename to example/ae/standard/README.md diff --git a/example/ae/regular/conf/config.yaml b/example/ae/standard/conf/config.yaml similarity index 100% rename from example/ae/regular/conf/config.yaml rename to example/ae/standard/conf/config.yaml diff --git a/example/ae/regular/conf/embedding.yaml b/example/ae/standard/conf/embedding.yaml similarity index 100% rename from example/ae/regular/conf/embedding.yaml rename to example/ae/standard/conf/embedding.yaml diff --git a/example/ae/regular/conf/hydra/output/custom.yaml b/example/ae/standard/conf/hydra/output/custom.yaml similarity index 100% rename from example/ae/regular/conf/hydra/output/custom.yaml rename to example/ae/standard/conf/hydra/output/custom.yaml diff --git a/example/ae/regular/conf/model/capsule.yaml b/example/ae/standard/conf/model/capsule.yaml similarity index 100% rename from example/ae/regular/conf/model/capsule.yaml rename to example/ae/standard/conf/model/capsule.yaml diff --git a/example/ae/regular/conf/model/cnn.yaml b/example/ae/standard/conf/model/cnn.yaml similarity index 100% rename from example/ae/regular/conf/model/cnn.yaml rename to example/ae/standard/conf/model/cnn.yaml diff --git a/example/ae/regular/conf/model/gcn.yaml b/example/ae/standard/conf/model/gcn.yaml similarity index 100% rename from example/ae/regular/conf/model/gcn.yaml rename to example/ae/standard/conf/model/gcn.yaml diff --git a/example/ae/regular/conf/model/lm.yaml b/example/ae/standard/conf/model/lm.yaml similarity index 100% rename from example/ae/regular/conf/model/lm.yaml rename to example/ae/standard/conf/model/lm.yaml diff --git a/example/ae/regular/conf/model/rnn.yaml b/example/ae/standard/conf/model/rnn.yaml similarity index 100% rename from example/ae/regular/conf/model/rnn.yaml rename to example/ae/standard/conf/model/rnn.yaml diff --git a/example/ae/regular/conf/model/transformer.yaml b/example/ae/standard/conf/model/transformer.yaml similarity index 100% rename from example/ae/regular/conf/model/transformer.yaml rename to example/ae/standard/conf/model/transformer.yaml diff --git 
a/example/ae/regular/conf/predict.yaml b/example/ae/standard/conf/predict.yaml similarity index 100% rename from example/ae/regular/conf/predict.yaml rename to example/ae/standard/conf/predict.yaml diff --git a/example/ae/regular/conf/preprocess.yaml b/example/ae/standard/conf/preprocess.yaml similarity index 100% rename from example/ae/regular/conf/preprocess.yaml rename to example/ae/standard/conf/preprocess.yaml diff --git a/example/ae/regular/conf/train.yaml b/example/ae/standard/conf/train.yaml similarity index 100% rename from example/ae/regular/conf/train.yaml rename to example/ae/standard/conf/train.yaml diff --git a/example/ae/regular/data/origin/attribute.csv b/example/ae/standard/data/origin/attribute.csv similarity index 100% rename from example/ae/regular/data/origin/attribute.csv rename to example/ae/standard/data/origin/attribute.csv diff --git a/example/ae/regular/data/origin/test.csv b/example/ae/standard/data/origin/test.csv similarity index 100% rename from example/ae/regular/data/origin/test.csv rename to example/ae/standard/data/origin/test.csv diff --git a/example/ae/regular/data/origin/train.csv b/example/ae/standard/data/origin/train.csv similarity index 100% rename from example/ae/regular/data/origin/train.csv rename to example/ae/standard/data/origin/train.csv diff --git a/example/ae/regular/data/origin/valid.csv b/example/ae/standard/data/origin/valid.csv similarity index 100% rename from example/ae/regular/data/origin/valid.csv rename to example/ae/standard/data/origin/valid.csv diff --git a/example/ae/regular/predict.py b/example/ae/standard/predict.py similarity index 98% rename from example/ae/regular/predict.py rename to example/ae/standard/predict.py index f296a6c..0a6a456 100644 --- a/example/ae/regular/predict.py +++ b/example/ae/standard/predict.py @@ -8,8 +8,8 @@ from deepke.ae_re_tools import Serializer from deepke.ae_re_tools import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_attribute_data import matplotlib.pyplot as plt sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))) -from deepke.ae_re_utils import load_pkl, load_csv -import deepke.ae_re_models as models +from deepke.ae_st_utils import load_pkl, load_csv +import deepke.ae_st_models as models logger = logging.getLogger(__name__) diff --git a/example/ae/regular/requirements.txt b/example/ae/standard/requirements.txt similarity index 100% rename from example/ae/regular/requirements.txt rename to example/ae/standard/requirements.txt diff --git a/example/ae/regular/run.py b/example/ae/standard/run.py similarity index 97% rename from example/ae/regular/run.py rename to example/ae/standard/run.py index ca48525..d24a088 100644 --- a/example/ae/regular/run.py +++ b/example/ae/standard/run.py @@ -11,9 +11,9 @@ from torch.utils.tensorboard import SummaryWriter # self import sys sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))) -import deepke.ae_re_models as models -from deepke.ae_re_tools import preprocess , CustomDataset, collate_fn ,train, validate -from deepke.ae_re_utils import manual_seed, load_pkl +import deepke.ae_st_models as models +from deepke.ae_st_tools import preprocess , CustomDataset, collate_fn ,train, validate +from deepke.ae_st_utils import manual_seed, load_pkl logger = logging.getLogger(__name__) diff --git a/example/ner/regular/README.md b/example/ner/standard/README.md similarity index 100% rename from example/ner/regular/README.md rename to example/ner/standard/README.md diff --git 
a/example/ner/regular/data/nltk_data/tokenizers/punkt.zip b/example/ner/standard/data/nltk_data/tokenizers/punkt.zip similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt.zip rename to example/ner/standard/data/nltk_data/tokenizers/punkt.zip diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/README b/example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/README similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/README rename to example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/README diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/czech.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/czech.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/czech.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/czech.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/danish.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/danish.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/danish.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/danish.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/dutch.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/dutch.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/dutch.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/dutch.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/english.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/english.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/english.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/english.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/estonian.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/estonian.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/estonian.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/estonian.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/finnish.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/finnish.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/finnish.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/finnish.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/french.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/french.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/french.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/french.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/german.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/german.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/german.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/german.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/greek.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/greek.pickle similarity index 100% rename from 
example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/greek.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/greek.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/italian.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/italian.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/italian.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/italian.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/norwegian.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/norwegian.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/norwegian.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/norwegian.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/polish.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/polish.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/polish.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/polish.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/portuguese.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/portuguese.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/portuguese.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/portuguese.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/russian.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/russian.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/russian.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/russian.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/slovene.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/slovene.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/slovene.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/slovene.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/spanish.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/spanish.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/spanish.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/spanish.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/swedish.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/swedish.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/swedish.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/swedish.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/turkish.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/turkish.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/PY3/turkish.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/PY3/turkish.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/README b/example/ner/standard/data/nltk_data/tokenizers/punkt/README similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/README rename to 
example/ner/standard/data/nltk_data/tokenizers/punkt/README diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/czech.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/czech.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/czech.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/czech.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/danish.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/danish.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/danish.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/danish.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/dutch.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/dutch.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/dutch.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/dutch.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/english.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/english.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/english.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/english.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/estonian.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/estonian.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/estonian.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/estonian.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/finnish.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/finnish.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/finnish.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/finnish.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/french.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/french.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/french.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/french.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/german.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/german.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/german.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/german.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/greek.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/greek.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/greek.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/greek.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/italian.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/italian.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/italian.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/italian.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/norwegian.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/norwegian.pickle similarity index 100% rename from 
example/ner/regular/data/nltk_data/tokenizers/punkt/norwegian.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/norwegian.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/polish.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/polish.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/polish.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/polish.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/portuguese.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/portuguese.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/portuguese.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/portuguese.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/russian.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/russian.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/russian.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/russian.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/slovene.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/slovene.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/slovene.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/slovene.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/spanish.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/spanish.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/spanish.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/spanish.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/swedish.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/swedish.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/swedish.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/swedish.pickle diff --git a/example/ner/regular/data/nltk_data/tokenizers/punkt/turkish.pickle b/example/ner/standard/data/nltk_data/tokenizers/punkt/turkish.pickle similarity index 100% rename from example/ner/regular/data/nltk_data/tokenizers/punkt/turkish.pickle rename to example/ner/standard/data/nltk_data/tokenizers/punkt/turkish.pickle diff --git a/example/ner/regular/data/test.txt b/example/ner/standard/data/test.txt similarity index 100% rename from example/ner/regular/data/test.txt rename to example/ner/standard/data/test.txt diff --git a/example/ner/regular/data/train.txt b/example/ner/standard/data/train.txt similarity index 100% rename from example/ner/regular/data/train.txt rename to example/ner/standard/data/train.txt diff --git a/example/ner/regular/data/valid.txt b/example/ner/standard/data/valid.txt similarity index 100% rename from example/ner/regular/data/valid.txt rename to example/ner/standard/data/valid.txt diff --git a/example/ner/regular/predict.py b/example/ner/standard/predict.py similarity index 100% rename from example/ner/regular/predict.py rename to example/ner/standard/predict.py diff --git a/example/ner/regular/requirements.txt b/example/ner/standard/requirements.txt similarity index 100% rename from example/ner/regular/requirements.txt rename to example/ner/standard/requirements.txt diff --git a/example/ner/regular/run.py b/example/ner/standard/run.py similarity index 100% rename 
from example/ner/regular/run.py rename to example/ner/standard/run.py diff --git a/example/re/regular/README.md b/example/re/standard/README.md similarity index 100% rename from example/re/regular/README.md rename to example/re/standard/README.md diff --git a/example/re/regular/conf/config.yaml b/example/re/standard/conf/config.yaml similarity index 100% rename from example/re/regular/conf/config.yaml rename to example/re/standard/conf/config.yaml diff --git a/example/re/regular/conf/embedding.yaml b/example/re/standard/conf/embedding.yaml similarity index 100% rename from example/re/regular/conf/embedding.yaml rename to example/re/standard/conf/embedding.yaml diff --git a/example/re/regular/conf/hydra/output/custom.yaml b/example/re/standard/conf/hydra/output/custom.yaml similarity index 100% rename from example/re/regular/conf/hydra/output/custom.yaml rename to example/re/standard/conf/hydra/output/custom.yaml diff --git a/example/re/regular/conf/model/capsule.yaml b/example/re/standard/conf/model/capsule.yaml similarity index 100% rename from example/re/regular/conf/model/capsule.yaml rename to example/re/standard/conf/model/capsule.yaml diff --git a/example/re/regular/conf/model/cnn.yaml b/example/re/standard/conf/model/cnn.yaml similarity index 100% rename from example/re/regular/conf/model/cnn.yaml rename to example/re/standard/conf/model/cnn.yaml diff --git a/example/re/regular/conf/model/gcn.yaml b/example/re/standard/conf/model/gcn.yaml similarity index 100% rename from example/re/regular/conf/model/gcn.yaml rename to example/re/standard/conf/model/gcn.yaml diff --git a/example/re/regular/conf/model/lm.yaml b/example/re/standard/conf/model/lm.yaml similarity index 100% rename from example/re/regular/conf/model/lm.yaml rename to example/re/standard/conf/model/lm.yaml diff --git a/example/re/regular/conf/model/rnn.yaml b/example/re/standard/conf/model/rnn.yaml similarity index 100% rename from example/re/regular/conf/model/rnn.yaml rename to example/re/standard/conf/model/rnn.yaml diff --git a/example/re/regular/conf/model/transformer.yaml b/example/re/standard/conf/model/transformer.yaml similarity index 100% rename from example/re/regular/conf/model/transformer.yaml rename to example/re/standard/conf/model/transformer.yaml diff --git a/example/re/regular/conf/predict.yaml b/example/re/standard/conf/predict.yaml similarity index 100% rename from example/re/regular/conf/predict.yaml rename to example/re/standard/conf/predict.yaml diff --git a/example/re/regular/conf/preprocess.yaml b/example/re/standard/conf/preprocess.yaml similarity index 100% rename from example/re/regular/conf/preprocess.yaml rename to example/re/standard/conf/preprocess.yaml diff --git a/example/re/regular/conf/train.yaml b/example/re/standard/conf/train.yaml similarity index 100% rename from example/re/regular/conf/train.yaml rename to example/re/standard/conf/train.yaml diff --git a/example/re/regular/data/origin/relation.csv b/example/re/standard/data/origin/relation.csv similarity index 100% rename from example/re/regular/data/origin/relation.csv rename to example/re/standard/data/origin/relation.csv diff --git a/example/re/regular/data/origin/test.csv b/example/re/standard/data/origin/test.csv similarity index 100% rename from example/re/regular/data/origin/test.csv rename to example/re/standard/data/origin/test.csv diff --git a/example/re/regular/data/origin/train.csv b/example/re/standard/data/origin/train.csv similarity index 100% rename from example/re/regular/data/origin/train.csv rename to 
example/re/standard/data/origin/train.csv diff --git a/example/re/regular/data/origin/valid.csv b/example/re/standard/data/origin/valid.csv similarity index 100% rename from example/re/regular/data/origin/valid.csv rename to example/re/standard/data/origin/valid.csv diff --git a/example/re/regular/predict.py b/example/re/standard/predict.py similarity index 96% rename from example/re/regular/predict.py rename to example/re/standard/predict.py index 867f862..3bc8c1d 100644 --- a/example/re/regular/predict.py +++ b/example/re/standard/predict.py @@ -4,12 +4,12 @@ import torch import logging import hydra from hydra import utils -from deepke.re_re_tools import Serializer -from deepke.re_re_tools import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_relation_data +from deepke.re_st_tools import Serializer +from deepke.re_st_tools import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_relation_data import matplotlib.pyplot as plt sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))) -from deepke.re_re_utils import load_pkl, load_csv -import deepke.re_re_models as models +from deepke.re_st_utils import load_pkl, load_csv +import deepke.re_st_models as models logger = logging.getLogger(__name__) diff --git a/example/re/regular/requirements.txt b/example/re/standard/requirements.txt similarity index 100% rename from example/re/regular/requirements.txt rename to example/re/standard/requirements.txt diff --git a/example/re/regular/run.py b/example/re/standard/run.py similarity index 97% rename from example/re/regular/run.py rename to example/re/standard/run.py index 21de135..5dad919 100644 --- a/example/re/regular/run.py +++ b/example/re/standard/run.py @@ -11,9 +11,9 @@ from torch.utils.tensorboard import SummaryWriter # self import sys sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))) -import deepke.re_re_models as models -from deepke.re_re_tools import preprocess , CustomDataset, collate_fn ,train, validate -from deepke.re_re_utils import manual_seed, load_pkl +import deepke.re_st_models as models +from deepke.re_st_tools import preprocess , CustomDataset, collate_fn ,train, validate +from deepke.re_st_utils import manual_seed, load_pkl logger = logging.getLogger(__name__) diff --git a/src/deepke/ae/regular/tools/main.py b/src/deepke/ae/regular/tools/main.py deleted file mode 100644 index 8263a7a..0000000 --- a/src/deepke/ae/regular/tools/main.py +++ /dev/null @@ -1,142 +0,0 @@ -import os -import hydra -import torch -import logging -import torch.nn as nn -from torch import optim -from hydra import utils -import matplotlib.pyplot as plt -from torch.utils.data import DataLoader -from torch.utils.tensorboard import SummaryWriter -# self -import sys -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))) -import models as models -from tools import preprocess , CustomDataset, collate_fn ,train, validate -from utils import manual_seed, load_pkl - - -logger = logging.getLogger(__name__) - -@hydra.main(config_path="../conf/config.yaml") -def main(cfg): - cwd = utils.get_original_cwd() - cwd = cwd[0:-5] - cfg.cwd = cwd - cfg.pos_size = 2 * cfg.pos_limit + 2 - logger.info(f'\n{cfg.pretty()}') - - __Model__ = { - 'cnn': models.PCNN, - 'rnn': models.BiLSTM, - 'transformer': models.Transformer, - 'gcn': models.GCN, - 'capsule': models.Capsule, - 'lm': models.LM, - } - - # device - if cfg.use_gpu and torch.cuda.is_available(): - device = torch.device('cuda', cfg.gpu_id) - else: - 
device = torch.device('cpu') - logger.info(f'device: {device}') - - # 如果不修改预处理的过程,这一步最好注释掉,不用每次运行都预处理数据一次 - if cfg.preprocess: - preprocess(cfg) - - train_data_path = os.path.join(cfg.cwd, cfg.out_path, 'train.pkl') - valid_data_path = os.path.join(cfg.cwd, cfg.out_path, 'valid.pkl') - test_data_path = os.path.join(cfg.cwd, cfg.out_path, 'test.pkl') - vocab_path = os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl') - - if cfg.model_name == 'lm': - vocab_size = None - else: - vocab = load_pkl(vocab_path) - vocab_size = vocab.count - cfg.vocab_size = vocab_size - - train_dataset = CustomDataset(train_data_path) - valid_dataset = CustomDataset(valid_data_path) - test_dataset = CustomDataset(test_data_path) - - train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg)) - valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg)) - test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg)) - - model = __Model__[cfg.model_name](cfg) - model.to(device) - logger.info(f'\n {model}') - - optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay) - scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=cfg.lr_factor, patience=cfg.lr_patience) - criterion = nn.CrossEntropyLoss() - - best_f1, best_epoch = -1, 0 - es_loss, es_f1, es_epoch, es_patience, best_es_epoch, best_es_f1, es_path, best_es_path = 1e8, -1, 0, 0, 0, -1, '', '' - train_losses, valid_losses = [], [] - - if cfg.show_plot and cfg.plot_utils == 'tensorboard': - writer = SummaryWriter('tensorboard') - else: - writer = None - - logger.info('=' * 10 + ' Start training ' + '=' * 10) - - for epoch in range(1, cfg.epoch + 1): - manual_seed(cfg.seed + epoch) - train_loss = train(epoch, model, train_dataloader, optimizer, criterion, device, writer, cfg) - valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion, device, cfg) - scheduler.step(valid_loss) - model_path = model.save(epoch, cfg) - # logger.info(model_path) - - train_losses.append(train_loss) - valid_losses.append(valid_loss) - if best_f1 < valid_f1: - best_f1 = valid_f1 - best_epoch = epoch - # 使用 valid loss 做 early stopping 的判断标准 - if es_loss > valid_loss: - es_loss = valid_loss - es_f1 = valid_f1 - es_epoch = epoch - es_patience = 0 - es_path = model_path - else: - es_patience += 1 - if es_patience >= cfg.early_stopping_patience: - best_es_epoch = es_epoch - best_es_f1 = es_f1 - best_es_path = es_path - - if cfg.show_plot: - if cfg.plot_utils == 'matplot': - plt.plot(train_losses, 'x-') - plt.plot(valid_losses, '+-') - plt.legend(['train', 'valid']) - plt.title('train/valid comparison loss') - plt.show() - - if cfg.plot_utils == 'tensorboard': - for i in range(len(train_losses)): - writer.add_scalars('train/valid_comparison_loss', { - 'train': train_losses[i], - 'valid': valid_losses[i] - }, i) - writer.close() - - logger.info(f'best(valid loss quota) early stopping epoch: {best_es_epoch}, ' - f'this epoch macro f1: {best_es_f1:0.4f}') - logger.info(f'this model save path: {best_es_path}') - logger.info(f'total {cfg.epoch} epochs, best(valid macro f1) epoch: {best_epoch}, ' - f'this epoch macro f1: {best_f1:.4f}') - - logger.info('=====end of training====') - logger.info('') - logger.info('=====start test performance====') - validate(-1, model, test_dataloader, criterion, device, cfg) - logger.info('=====ending====') - diff --git 
a/src/deepke/ae/regular/tools/predict.py b/src/deepke/ae/regular/tools/predict.py deleted file mode 100644 index a79bc99..0000000 --- a/src/deepke/ae/regular/tools/predict.py +++ /dev/null @@ -1,147 +0,0 @@ -import os -import sys -import torch -import logging -import hydra -from hydra import utils -from serializer import Serializer -from preprocess import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_attribute_data -import matplotlib.pyplot as plt -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))) -from utils import load_pkl, load_csv -import models as models - - -logger = logging.getLogger(__name__) - - -def _preprocess_data(data, cfg): - vocab = load_pkl(os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl'), verbose=False) - attribute_data = load_csv(os.path.join(cfg.cwd, cfg.data_path, 'attribute.csv'), verbose=False) - atts = _handle_attribute_data(attribute_data) - cfg.vocab_size = vocab.count - serializer = Serializer(do_chinese_split=cfg.chinese_split) - serial = serializer.serialize - - _serialize_sentence(data, serial, cfg) - _convert_tokens_into_index(data, vocab) - _add_pos_seq(data, cfg) - logger.info('start sentence preprocess...') - formats = '\nsentence: {}\nchinese_split: {}\n' \ - 'tokens: {}\ntoken2idx: {}\nlength: {}\nentity_index: {}\nattribute_value_index: {}' - logger.info( - formats.format(data[0]['sentence'], cfg.chinese_split, - data[0]['tokens'], data[0]['token2idx'], data[0]['seq_len'], - data[0]['entity_index'], data[0]['attribute_value_index'])) - return data, atts - - -def _get_predict_instance(cfg): - flag = input('是否使用范例[y/n],退出请输入: exit .... ') - flag = flag.strip().lower() - if flag == 'y' or flag == 'yes': - sentence = '张冬梅,女,汉族,1968年2月生,河南淇县人,1988年7月加入中国共产党,1989年9月参加工作,中央党校经济管理专业毕业,中央党校研究生学历' - entity = '张冬梅' - attribute_value = '汉族' - elif flag == 'n' or flag == 'no': - sentence = input('请输入句子:') - entity = input('请输入句中需要预测的实体:') - attribute_value = input('请输入句中需要预测的属性值:') - elif flag == 'exit': - sys.exit(0) - else: - print('please input yes or no, or exit!') - _get_predict_instance(cfg) - - - instance = dict() - instance['sentence'] = sentence.strip() - instance['entity'] = entity.strip() - instance['attribute_value'] = attribute_value.strip() - instance['entity_offset'] = sentence.find(entity) - instance['attribute_value_offset'] = sentence.find(attribute_value) - - return instance - - - - -@hydra.main(config_path='../conf/config.yaml') -def main(cfg): - cwd = utils.get_original_cwd() - cwd = cwd[0:-5] - cfg.cwd = cwd - cfg.pos_size = 2 * cfg.pos_limit + 2 - print(cfg.pretty()) - - # get predict instance - instance = _get_predict_instance(cfg) - data = [instance] - - # preprocess data - data, rels = _preprocess_data(data, cfg) - - # model - __Model__ = { - 'cnn': models.PCNN, - 'rnn': models.BiLSTM, - 'transformer': models.Transformer, - 'gcn': models.GCN, - 'capsule': models.Capsule, - 'lm': models.LM, - } - - # 最好在 cpu 上预测 - cfg.use_gpu = False - if cfg.use_gpu and torch.cuda.is_available(): - device = torch.device('cuda', cfg.gpu_id) - else: - device = torch.device('cpu') - logger.info(f'device: {device}') - - model = __Model__[cfg.model_name](cfg) - logger.info(f'model name: {cfg.model_name}') - logger.info(f'\n {model}') - model.load(cfg.fp, device=device) - model.to(device) - model.eval() - - x = dict() - x['word'], x['lens'] = torch.tensor([data[0]['token2idx']]), torch.tensor([data[0]['seq_len']]) - - if cfg.model_name != 'lm': - x['entity_pos'], x['attribute_value_pos'] = 
torch.tensor([data[0]['entity_pos']]), torch.tensor([data[0]['attribute_value_pos']]) - if cfg.model_name == 'cnn': - if cfg.use_pcnn: - x['pcnn_mask'] = torch.tensor([data[0]['entities_pos']]) - if cfg.model_name == 'gcn': - # 没找到合适的做 parsing tree 的工具,暂时随机初始化 - adj = torch.empty(1,data[0]['seq_len'],data[0]['seq_len']).random_(2) - x['adj'] = adj - - - for key in x.keys(): - x[key] = x[key].to(device) - - with torch.no_grad(): - y_pred = model(x) - y_pred = torch.softmax(y_pred, dim=-1)[0] - prob = y_pred.max().item() - prob_att = list(rels.keys())[y_pred.argmax().item()] - logger.info(f"\"{data[0]['entity']}\" 和 \"{data[0]['attribute_value']}\" 在句中属性为:\"{prob_att}\",置信度为{prob:.2f}。") - - if cfg.predict_plot: - plt.rcParams["font.family"] = 'Arial Unicode MS' - x = list(rels.keys()) - height = list(y_pred.cpu().numpy()) - plt.bar(x, height) - for x, y in zip(x, height): - plt.text(x, y, '%.2f' % y, ha="center", va="bottom") - plt.xlabel('关系') - plt.ylabel('置信度') - plt.xticks(rotation=315) - plt.show() - - -if __name__ == '__main__': - main() diff --git a/src/deepke/ae/regular/models/BasicModule.py b/src/deepke/ae/standard/models/BasicModule.py similarity index 100% rename from src/deepke/ae/regular/models/BasicModule.py rename to src/deepke/ae/standard/models/BasicModule.py diff --git a/src/deepke/ae/regular/models/BiLSTM.py b/src/deepke/ae/standard/models/BiLSTM.py similarity index 100% rename from src/deepke/ae/regular/models/BiLSTM.py rename to src/deepke/ae/standard/models/BiLSTM.py diff --git a/src/deepke/ae/regular/models/Capsule.py b/src/deepke/ae/standard/models/Capsule.py similarity index 100% rename from src/deepke/ae/regular/models/Capsule.py rename to src/deepke/ae/standard/models/Capsule.py diff --git a/src/deepke/ae/regular/models/GCN.py b/src/deepke/ae/standard/models/GCN.py similarity index 100% rename from src/deepke/ae/regular/models/GCN.py rename to src/deepke/ae/standard/models/GCN.py diff --git a/src/deepke/ae/regular/models/LM.py b/src/deepke/ae/standard/models/LM.py similarity index 100% rename from src/deepke/ae/regular/models/LM.py rename to src/deepke/ae/standard/models/LM.py diff --git a/src/deepke/ae/regular/models/PCNN.py b/src/deepke/ae/standard/models/PCNN.py similarity index 100% rename from src/deepke/ae/regular/models/PCNN.py rename to src/deepke/ae/standard/models/PCNN.py diff --git a/src/deepke/ae/regular/models/Transformer.py b/src/deepke/ae/standard/models/Transformer.py similarity index 100% rename from src/deepke/ae/regular/models/Transformer.py rename to src/deepke/ae/standard/models/Transformer.py diff --git a/src/deepke/ae/regular/models/__init__.py b/src/deepke/ae/standard/models/__init__.py similarity index 100% rename from src/deepke/ae/regular/models/__init__.py rename to src/deepke/ae/standard/models/__init__.py diff --git a/src/deepke/ae/regular/module/Attention.py b/src/deepke/ae/standard/module/Attention.py similarity index 100% rename from src/deepke/ae/regular/module/Attention.py rename to src/deepke/ae/standard/module/Attention.py diff --git a/src/deepke/ae/regular/module/CNN.py b/src/deepke/ae/standard/module/CNN.py similarity index 100% rename from src/deepke/ae/regular/module/CNN.py rename to src/deepke/ae/standard/module/CNN.py diff --git a/src/deepke/ae/regular/module/Capsule.py b/src/deepke/ae/standard/module/Capsule.py similarity index 100% rename from src/deepke/ae/regular/module/Capsule.py rename to src/deepke/ae/standard/module/Capsule.py diff --git a/src/deepke/ae/regular/module/Embedding.py 
b/src/deepke/ae/standard/module/Embedding.py similarity index 100% rename from src/deepke/ae/regular/module/Embedding.py rename to src/deepke/ae/standard/module/Embedding.py diff --git a/src/deepke/ae/regular/module/GCN.py b/src/deepke/ae/standard/module/GCN.py similarity index 100% rename from src/deepke/ae/regular/module/GCN.py rename to src/deepke/ae/standard/module/GCN.py diff --git a/src/deepke/ae/regular/module/RNN.py b/src/deepke/ae/standard/module/RNN.py similarity index 100% rename from src/deepke/ae/regular/module/RNN.py rename to src/deepke/ae/standard/module/RNN.py diff --git a/src/deepke/ae/regular/module/Transformer.py b/src/deepke/ae/standard/module/Transformer.py similarity index 100% rename from src/deepke/ae/regular/module/Transformer.py rename to src/deepke/ae/standard/module/Transformer.py diff --git a/src/deepke/ae/regular/module/__init__.py b/src/deepke/ae/standard/module/__init__.py similarity index 100% rename from src/deepke/ae/regular/module/__init__.py rename to src/deepke/ae/standard/module/__init__.py diff --git a/src/deepke/ae/regular/tools/__init__.py b/src/deepke/ae/standard/tools/__init__.py similarity index 100% rename from src/deepke/ae/regular/tools/__init__.py rename to src/deepke/ae/standard/tools/__init__.py diff --git a/src/deepke/ae/regular/tools/dataset.py b/src/deepke/ae/standard/tools/dataset.py similarity index 100% rename from src/deepke/ae/regular/tools/dataset.py rename to src/deepke/ae/standard/tools/dataset.py diff --git a/src/deepke/ae/regular/tools/metrics.py b/src/deepke/ae/standard/tools/metrics.py similarity index 100% rename from src/deepke/ae/regular/tools/metrics.py rename to src/deepke/ae/standard/tools/metrics.py diff --git a/src/deepke/ae/regular/tools/preprocess.py b/src/deepke/ae/standard/tools/preprocess.py similarity index 100% rename from src/deepke/ae/regular/tools/preprocess.py rename to src/deepke/ae/standard/tools/preprocess.py diff --git a/src/deepke/ae/regular/tools/serializer.py b/src/deepke/ae/standard/tools/serializer.py similarity index 100% rename from src/deepke/ae/regular/tools/serializer.py rename to src/deepke/ae/standard/tools/serializer.py diff --git a/src/deepke/ae/regular/tools/trainer.py b/src/deepke/ae/standard/tools/trainer.py similarity index 100% rename from src/deepke/ae/regular/tools/trainer.py rename to src/deepke/ae/standard/tools/trainer.py diff --git a/src/deepke/ae/regular/tools/vocab.py b/src/deepke/ae/standard/tools/vocab.py similarity index 100% rename from src/deepke/ae/regular/tools/vocab.py rename to src/deepke/ae/standard/tools/vocab.py diff --git a/src/deepke/ae/regular/utils/__init__.py b/src/deepke/ae/standard/utils/__init__.py similarity index 100% rename from src/deepke/ae/regular/utils/__init__.py rename to src/deepke/ae/standard/utils/__init__.py diff --git a/src/deepke/ae/regular/utils/ioUtils.py b/src/deepke/ae/standard/utils/ioUtils.py similarity index 100% rename from src/deepke/ae/regular/utils/ioUtils.py rename to src/deepke/ae/standard/utils/ioUtils.py diff --git a/src/deepke/ae/regular/utils/nnUtils.py b/src/deepke/ae/standard/utils/nnUtils.py similarity index 100% rename from src/deepke/ae/regular/utils/nnUtils.py rename to src/deepke/ae/standard/utils/nnUtils.py diff --git a/src/deepke/ner/regular/models/BERTNER.py b/src/deepke/ner/standard/models/BERTNER.py similarity index 100% rename from src/deepke/ner/regular/models/BERTNER.py rename to src/deepke/ner/standard/models/BERTNER.py diff --git a/src/deepke/ner/regular/models/NER.py b/src/deepke/ner/standard/models/NER.py 
similarity index 100% rename from src/deepke/ner/regular/models/NER.py rename to src/deepke/ner/standard/models/NER.py diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt.zip b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt.zip similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt.zip rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt.zip diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/README b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/README similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/README rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/README diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/czech.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/czech.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/czech.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/czech.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/danish.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/danish.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/danish.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/danish.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/dutch.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/dutch.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/dutch.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/dutch.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/english.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/english.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/english.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/english.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/estonian.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/estonian.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/estonian.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/estonian.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/finnish.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/finnish.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/finnish.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/finnish.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/french.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/french.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/french.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/french.pickle diff --git 
a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/german.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/german.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/german.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/german.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/greek.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/greek.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/greek.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/greek.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/italian.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/italian.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/italian.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/italian.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/norwegian.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/norwegian.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/norwegian.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/norwegian.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/polish.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/polish.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/polish.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/polish.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/portuguese.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/portuguese.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/portuguese.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/portuguese.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/russian.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/russian.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/russian.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/russian.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/slovene.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/slovene.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/slovene.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/slovene.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/spanish.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/spanish.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/spanish.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/spanish.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/swedish.pickle 
b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/swedish.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/swedish.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/swedish.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/turkish.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/turkish.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/PY3/turkish.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/PY3/turkish.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/README b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/README similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/README rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/README diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/czech.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/czech.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/czech.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/czech.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/danish.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/danish.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/danish.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/danish.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/dutch.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/dutch.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/dutch.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/dutch.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/english.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/english.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/english.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/english.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/estonian.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/estonian.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/estonian.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/estonian.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/finnish.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/finnish.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/finnish.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/finnish.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/french.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/french.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/french.pickle rename to 
src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/french.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/german.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/german.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/german.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/german.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/greek.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/greek.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/greek.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/greek.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/italian.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/italian.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/italian.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/italian.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/norwegian.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/norwegian.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/norwegian.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/norwegian.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/polish.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/polish.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/polish.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/polish.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/portuguese.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/portuguese.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/portuguese.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/portuguese.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/russian.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/russian.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/russian.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/russian.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/slovene.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/slovene.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/slovene.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/slovene.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/spanish.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/spanish.pickle similarity index 100% rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/spanish.pickle rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/spanish.pickle diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/swedish.pickle 
b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/swedish.pickle
similarity index 100%
rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/swedish.pickle
rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/swedish.pickle
diff --git a/src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/turkish.pickle b/src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/turkish.pickle
similarity index 100%
rename from src/deepke/ner/regular/module/data/nltk_data/tokenizers/punkt/turkish.pickle
rename to src/deepke/ner/standard/module/data/nltk_data/tokenizers/punkt/turkish.pickle
diff --git a/src/deepke/ner/regular/module/data/test.txt b/src/deepke/ner/standard/module/data/test.txt
similarity index 100%
rename from src/deepke/ner/regular/module/data/test.txt
rename to src/deepke/ner/standard/module/data/test.txt
diff --git a/src/deepke/ner/regular/module/data/train.txt b/src/deepke/ner/standard/module/data/train.txt
similarity index 100%
rename from src/deepke/ner/regular/module/data/train.txt
rename to src/deepke/ner/standard/module/data/train.txt
diff --git a/src/deepke/ner/regular/module/data/valid.txt b/src/deepke/ner/standard/module/data/valid.txt
similarity index 100%
rename from src/deepke/ner/regular/module/data/valid.txt
rename to src/deepke/ner/standard/module/data/valid.txt
diff --git a/src/deepke/ner/regular/module/dataset.py b/src/deepke/ner/standard/module/dataset.py
similarity index 100%
rename from src/deepke/ner/regular/module/dataset.py
rename to src/deepke/ner/standard/module/dataset.py
diff --git a/src/deepke/ner/regular/module/finetune.py b/src/deepke/ner/standard/module/finetune.py
similarity index 100%
rename from src/deepke/ner/regular/module/finetune.py
rename to src/deepke/ner/standard/module/finetune.py
diff --git a/src/deepke/ner/regular/module/predict.py b/src/deepke/ner/standard/module/predict.py
similarity index 100%
rename from src/deepke/ner/regular/module/predict.py
rename to src/deepke/ner/standard/module/predict.py
diff --git a/src/deepke/ner/regular/module/preprocess.py b/src/deepke/ner/standard/module/preprocess.py
similarity index 100%
rename from src/deepke/ner/regular/module/preprocess.py
rename to src/deepke/ner/standard/module/preprocess.py
diff --git a/src/deepke/re/regular/tools/main.py b/src/deepke/re/regular/tools/main.py
deleted file mode 100644
index 1615dcf..0000000
--- a/src/deepke/re/regular/tools/main.py
+++ /dev/null
@@ -1,145 +0,0 @@
-import os
-import hydra
-import torch
-import logging
-import torch.nn as nn
-from torch import optim
-from hydra import utils
-import matplotlib.pyplot as plt
-from torch.utils.data import DataLoader
-from torch.utils.tensorboard import SummaryWriter
-# self
-import sys
-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
-import models as models
-from tools import preprocess , CustomDataset, collate_fn ,train, validate
-from utils import manual_seed, load_pkl
-
-logger = logging.getLogger(__name__)
-
-@hydra.main(config_path='../conf/config.yaml')
-def main(cfg):
-    cwd = utils.get_original_cwd()
-    cwd = cwd[0:-5]
-    cfg.cwd = cwd
-    cfg.pos_size = 2 * cfg.pos_limit + 2
-    logger.info(f'\n{cfg.pretty()}')
-
-    __Model__ = {
-        'cnn': models.PCNN,
-        'rnn': models.BiLSTM,
-        'transformer': models.Transformer,
-        'gcn': models.GCN,
-        'capsule': models.Capsule,
-        'lm': models.LM,
-    }
-
-    # device
-    if cfg.use_gpu and torch.cuda.is_available():
-        device = torch.device('cuda', cfg.gpu_id)
-    else:
-        device = torch.device('cpu')
-    logger.info(f'device: {device}')
-
-    # 如果不修改预处理的过程,这一步最好注释掉,不用每次运行都预处理数据一次
-    if cfg.preprocess:
-        preprocess(cfg)
-
-    train_data_path = os.path.join(cfg.cwd, cfg.out_path, 'train.pkl')
-    valid_data_path = os.path.join(cfg.cwd, cfg.out_path, 'valid.pkl')
-    test_data_path = os.path.join(cfg.cwd, cfg.out_path, 'test.pkl')
-    vocab_path = os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl')
-
-    if cfg.model_name == 'lm':
-        vocab_size = None
-    else:
-        vocab = load_pkl(vocab_path)
-        vocab_size = vocab.count
-    cfg.vocab_size = vocab_size
-
-    train_dataset = CustomDataset(train_data_path)
-    valid_dataset = CustomDataset(valid_data_path)
-    test_dataset = CustomDataset(test_data_path)
-
-    train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
-    valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
-    test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
-
-    model = __Model__[cfg.model_name](cfg)
-    model.to(device)
-    logger.info(f'\n {model}')
-
-    optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
-    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=cfg.lr_factor, patience=cfg.lr_patience)
-    criterion = nn.CrossEntropyLoss()
-
-    best_f1, best_epoch = -1, 0
-    es_loss, es_f1, es_epoch, es_patience, best_es_epoch, best_es_f1, es_path, best_es_path = 1e8, -1, 0, 0, 0, -1, '', ''
-    train_losses, valid_losses = [], []
-
-    if cfg.show_plot and cfg.plot_utils == 'tensorboard':
-        writer = SummaryWriter('tensorboard')
-    else:
-        writer = None
-
-    logger.info('=' * 10 + ' Start training ' + '=' * 10)
-
-    for epoch in range(1, cfg.epoch + 1):
-        manual_seed(cfg.seed + epoch)
-        train_loss = train(epoch, model, train_dataloader, optimizer, criterion, device, writer, cfg)
-        valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion, device, cfg)
-        scheduler.step(valid_loss)
-        model_path = model.save(epoch, cfg)
-        # logger.info(model_path)
-
-        train_losses.append(train_loss)
-        valid_losses.append(valid_loss)
-        if best_f1 < valid_f1:
-            best_f1 = valid_f1
-            best_epoch = epoch
-        # 使用 valid loss 做 early stopping 的判断标准
-        if es_loss > valid_loss:
-            es_loss = valid_loss
-            es_f1 = valid_f1
-            es_epoch = epoch
-            es_patience = 0
-            es_path = model_path
-        else:
-            es_patience += 1
-            if es_patience >= cfg.early_stopping_patience:
-                best_es_epoch = es_epoch
-                best_es_f1 = es_f1
-                best_es_path = es_path
-
-    if cfg.show_plot:
-        if cfg.plot_utils == 'matplot':
-            plt.plot(train_losses, 'x-')
-            plt.plot(valid_losses, '+-')
-            plt.legend(['train', 'valid'])
-            plt.title('train/valid comparison loss')
-            plt.show()
-
-        if cfg.plot_utils == 'tensorboard':
-            for i in range(len(train_losses)):
-                writer.add_scalars('train/valid_comparison_loss', {
-                    'train': train_losses[i],
-                    'valid': valid_losses[i]
-                }, i)
-            writer.close()
-
-    logger.info(f'best(valid loss quota) early stopping epoch: {best_es_epoch}, '
-                f'this epoch macro f1: {best_es_f1:0.4f}')
-    logger.info(f'this model save path: {best_es_path}')
-    logger.info(f'total {cfg.epoch} epochs, best(valid macro f1) epoch: {best_epoch}, '
-                f'this epoch macro f1: {best_f1:.4f}')
-
-    logger.info('=====end of training====')
-    logger.info('')
-    logger.info('=====start test performance====')
-    validate(-1, model, test_dataloader, criterion, device, cfg)
-    logger.info('=====ending====')
-
-
-if __name__ == '__main__':
-    main()
-
\ No newline at end of file
diff --git a/src/deepke/re/regular/tools/predict.py b/src/deepke/re/regular/tools/predict.py
deleted file mode 100644
index 1d94454..0000000
--- a/src/deepke/re/regular/tools/predict.py
+++ /dev/null
@@ -1,156 +0,0 @@
-import os
-import sys
-import torch
-import logging
-import hydra
-from hydra import utils
-from serializer import Serializer
-from preprocess import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_relation_data
-import matplotlib.pyplot as plt
-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
-from utils import load_pkl, load_csv
-import models as models
-
-
-logger = logging.getLogger(__name__)
-
-
-def _preprocess_data(data, cfg):
-    vocab = load_pkl(os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl'), verbose=False)
-    relation_data = load_csv(os.path.join(cfg.cwd, cfg.data_path, 'relation.csv'), verbose=False)
-    rels = _handle_relation_data(relation_data)
-    cfg.vocab_size = vocab.count
-    serializer = Serializer(do_chinese_split=cfg.chinese_split)
-    serial = serializer.serialize
-
-    _serialize_sentence(data, serial, cfg)
-    _convert_tokens_into_index(data, vocab)
-    _add_pos_seq(data, cfg)
-    logger.info('start sentence preprocess...')
-    formats = '\nsentence: {}\nchinese_split: {}\nreplace_entity_with_type: {}\nreplace_entity_with_scope: {}\n' \
-              'tokens: {}\ntoken2idx: {}\nlength: {}\nhead_idx: {}\ntail_idx: {}'
-    logger.info(
-        formats.format(data[0]['sentence'], cfg.chinese_split, cfg.replace_entity_with_type,
-                       cfg.replace_entity_with_scope, data[0]['tokens'], data[0]['token2idx'], data[0]['seq_len'],
-                       data[0]['head_idx'], data[0]['tail_idx']))
-    return data, rels
-
-
-def _get_predict_instance(cfg):
-    flag = input('是否使用范例[y/n],退出请输入: exit .... ')
-    flag = flag.strip().lower()
-    if flag == 'y' or flag == 'yes':
-        sentence = '《乡村爱情》是一部由知名导演赵本山在1985年所拍摄的农村青春偶像剧。'
-        head = '乡村爱情'
-        tail = '赵本山'
-        head_type = '电视剧'
-        tail_type = '人物'
-    elif flag == 'n' or flag == 'no':
-        sentence = input('请输入句子:')
-        head = input('请输入句中需要预测关系的头实体:')
-        head_type = input('请输入头实体类型(可以为空,按enter跳过):')
-        tail = input('请输入句中需要预测关系的尾实体:')
-        tail_type = input('请输入尾实体类型(可以为空,按enter跳过):')
-    elif flag == 'exit':
-        sys.exit(0)
-    else:
-        print('please input yes or no, or exit!')
-        _get_predict_instance()
-
-    instance = dict()
-    instance['sentence'] = sentence.strip()
-    instance['head'] = head.strip()
-    instance['tail'] = tail.strip()
-    if head_type.strip() == '' or tail_type.strip() == '':
-        cfg.replace_entity_with_type = False
-        instance['head_type'] = 'None'
-        instance['tail_type'] = 'None'
-    else:
-        instance['head_type'] = head_type.strip()
-        instance['tail_type'] = tail_type.strip()
-
-    return instance
-
-
-
-
-@hydra.main(config_path='../conf/config.yaml')
-def main(cfg):
-    cwd = utils.get_original_cwd()
-    cwd = cwd[0:-5]
-    cfg.cwd = cwd
-    cfg.pos_size = 2 * cfg.pos_limit + 2
-    print(cfg.pretty())
-
-    # get predict instance
-    instance = _get_predict_instance(cfg)
-    data = [instance]
-
-    # preprocess data
-    data, rels = _preprocess_data(data, cfg)
-
-    # model
-    __Model__ = {
-        'cnn': models.PCNN,
-        'rnn': models.BiLSTM,
-        'transformer': models.Transformer,
-        'gcn': models.GCN,
-        'capsule': models.Capsule,
-        'lm': models.LM,
-    }
-
-    # 最好在 cpu 上预测
-    cfg.use_gpu = False
-    if cfg.use_gpu and torch.cuda.is_available():
-        device = torch.device('cuda', cfg.gpu_id)
-    else:
-        device = torch.device('cpu')
-    logger.info(f'device: {device}')
-
-    model = __Model__[cfg.model_name](cfg)
-    logger.info(f'model name: {cfg.model_name}')
-    logger.info(f'\n {model}')
-    model.load(cfg.fp, device=device)
-    model.to(device)
-    model.eval()
-
-    x = dict()
-    x['word'], x['lens'] = torch.tensor([data[0]['token2idx']]), torch.tensor([data[0]['seq_len']])
-
-    if cfg.model_name != 'lm':
-        x['head_pos'], x['tail_pos'] = torch.tensor([data[0]['head_pos']]), torch.tensor([data[0]['tail_pos']])
-        if cfg.model_name == 'cnn':
-            if cfg.use_pcnn:
-                x['pcnn_mask'] = torch.tensor([data[0]['entities_pos']])
-        if cfg.model_name == 'gcn':
-            # 没找到合适的做 parsing tree 的工具,暂时随机初始化
-            adj = torch.empty(1,data[0]['seq_len'],data[0]['seq_len']).random_(2)
-            x['adj'] = adj
-
-
-    for key in x.keys():
-        x[key] = x[key].to(device)
-
-    with torch.no_grad():
-        y_pred = model(x)
-        y_pred = torch.softmax(y_pred, dim=-1)[0]
-        prob = y_pred.max().item()
-        prob_rel = list(rels.keys())[y_pred.argmax().item()]
-        logger.info(f"\"{data[0]['head']}\" 和 \"{data[0]['tail']}\" 在句中关系为:\"{prob_rel}\",置信度为{prob:.2f}。")
-
-        if cfg.predict_plot:
-            # maplot 默认显示不支持中文
-            plt.rcParams["font.family"] = 'Arial Unicode MS'
-            x = list(rels.keys())
-            height = list(y_pred.cpu().numpy())
-            plt.bar(x, height)
-            for x, y in zip(x, height):
-                plt.text(x, y, '%.2f' % y, ha="center", va="bottom")
-            plt.xlabel('关系')
-            plt.ylabel('置信度')
-            plt.xticks(rotation=315)
-            plt.show()
-
-
-if __name__ == '__main__':
-    main()
\ No newline at end of file
diff --git a/src/deepke/re/regular/models/BasicModule.py b/src/deepke/re/standard/models/BasicModule.py
similarity index 100%
rename from src/deepke/re/regular/models/BasicModule.py
rename to src/deepke/re/standard/models/BasicModule.py
diff --git a/src/deepke/re/regular/models/BiLSTM.py b/src/deepke/re/standard/models/BiLSTM.py
similarity index 100%
rename from src/deepke/re/regular/models/BiLSTM.py
rename to src/deepke/re/standard/models/BiLSTM.py
diff --git a/src/deepke/re/regular/models/Capsule.py b/src/deepke/re/standard/models/Capsule.py
similarity index 100%
rename from src/deepke/re/regular/models/Capsule.py
rename to src/deepke/re/standard/models/Capsule.py
diff --git a/src/deepke/re/regular/models/GCN.py b/src/deepke/re/standard/models/GCN.py
similarity index 100%
rename from src/deepke/re/regular/models/GCN.py
rename to src/deepke/re/standard/models/GCN.py
diff --git a/src/deepke/re/regular/models/LM.py b/src/deepke/re/standard/models/LM.py
similarity index 100%
rename from src/deepke/re/regular/models/LM.py
rename to src/deepke/re/standard/models/LM.py
diff --git a/src/deepke/re/regular/models/PCNN.py b/src/deepke/re/standard/models/PCNN.py
similarity index 100%
rename from src/deepke/re/regular/models/PCNN.py
rename to src/deepke/re/standard/models/PCNN.py
diff --git a/src/deepke/re/regular/models/Transformer.py b/src/deepke/re/standard/models/Transformer.py
similarity index 100%
rename from src/deepke/re/regular/models/Transformer.py
rename to src/deepke/re/standard/models/Transformer.py
diff --git a/src/deepke/re/regular/models/__init__.py b/src/deepke/re/standard/models/__init__.py
similarity index 100%
rename from src/deepke/re/regular/models/__init__.py
rename to src/deepke/re/standard/models/__init__.py
diff --git a/src/deepke/re/regular/module/Attention.py b/src/deepke/re/standard/module/Attention.py
similarity index 100%
rename from src/deepke/re/regular/module/Attention.py
rename to src/deepke/re/standard/module/Attention.py
diff --git a/src/deepke/re/regular/module/CNN.py b/src/deepke/re/standard/module/CNN.py
similarity index 100%
rename from src/deepke/re/regular/module/CNN.py
rename to src/deepke/re/standard/module/CNN.py
diff --git a/src/deepke/re/regular/module/Capsule.py b/src/deepke/re/standard/module/Capsule.py
similarity index 100%
rename from src/deepke/re/regular/module/Capsule.py
rename to src/deepke/re/standard/module/Capsule.py
diff --git a/src/deepke/re/regular/module/Embedding.py b/src/deepke/re/standard/module/Embedding.py
similarity index 100%
rename from src/deepke/re/regular/module/Embedding.py
rename to src/deepke/re/standard/module/Embedding.py
diff --git a/src/deepke/re/regular/module/GCN.py b/src/deepke/re/standard/module/GCN.py
similarity index 100%
rename from src/deepke/re/regular/module/GCN.py
rename to src/deepke/re/standard/module/GCN.py
diff --git a/src/deepke/re/regular/module/RNN.py b/src/deepke/re/standard/module/RNN.py
similarity index 100%
rename from src/deepke/re/regular/module/RNN.py
rename to src/deepke/re/standard/module/RNN.py
diff --git a/src/deepke/re/regular/module/Transformer.py b/src/deepke/re/standard/module/Transformer.py
similarity index 100%
rename from src/deepke/re/regular/module/Transformer.py
rename to src/deepke/re/standard/module/Transformer.py
diff --git a/src/deepke/re/regular/module/__init__.py b/src/deepke/re/standard/module/__init__.py
similarity index 100%
rename from src/deepke/re/regular/module/__init__.py
rename to src/deepke/re/standard/module/__init__.py
diff --git a/src/deepke/re/regular/tools/__init__.py b/src/deepke/re/standard/tools/__init__.py
similarity index 100%
rename from src/deepke/re/regular/tools/__init__.py
rename to src/deepke/re/standard/tools/__init__.py
diff --git a/src/deepke/re/regular/tools/dataset.py b/src/deepke/re/standard/tools/dataset.py
similarity index 100%
rename from src/deepke/re/regular/tools/dataset.py
rename to src/deepke/re/standard/tools/dataset.py
diff --git a/src/deepke/re/regular/tools/metrics.py b/src/deepke/re/standard/tools/metrics.py
similarity index 100%
rename from src/deepke/re/regular/tools/metrics.py
rename to src/deepke/re/standard/tools/metrics.py
diff --git a/src/deepke/re/regular/tools/preprocess.py b/src/deepke/re/standard/tools/preprocess.py
similarity index 100%
rename from src/deepke/re/regular/tools/preprocess.py
rename to src/deepke/re/standard/tools/preprocess.py
diff --git a/src/deepke/re/regular/tools/serializer.py b/src/deepke/re/standard/tools/serializer.py
similarity index 100%
rename from src/deepke/re/regular/tools/serializer.py
rename to src/deepke/re/standard/tools/serializer.py
diff --git a/src/deepke/re/regular/tools/trainer.py b/src/deepke/re/standard/tools/trainer.py
similarity index 100%
rename from src/deepke/re/regular/tools/trainer.py
rename to src/deepke/re/standard/tools/trainer.py
diff --git a/src/deepke/re/regular/tools/vocab.py b/src/deepke/re/standard/tools/vocab.py
similarity index 100%
rename from src/deepke/re/regular/tools/vocab.py
rename to src/deepke/re/standard/tools/vocab.py
diff --git a/src/deepke/re/regular/utils/__init__.py b/src/deepke/re/standard/utils/__init__.py
similarity index 100%
rename from src/deepke/re/regular/utils/__init__.py
rename to src/deepke/re/standard/utils/__init__.py
diff --git a/src/deepke/re/regular/utils/ioUtils.py b/src/deepke/re/standard/utils/ioUtils.py
similarity index 100%
rename from src/deepke/re/regular/utils/ioUtils.py
rename to src/deepke/re/standard/utils/ioUtils.py
diff --git a/src/deepke/re/regular/utils/nnUtils.py b/src/deepke/re/standard/utils/nnUtils.py
similarity index 100%
rename from src/deepke/re/regular/utils/nnUtils.py
rename to src/deepke/re/standard/utils/nnUtils.py