deepke/example/ae/standard
tlk-dsg 7dcc4378cb fix bug 2021-12-12 14:33:09 +08:00
..
conf update readme 2021-10-24 14:18:09 +08:00
README.md Update README.md 2021-11-23 16:28:19 +08:00
predict.py fix bug 2021-12-12 14:33:09 +08:00
requirements.txt test 2021-09-16 14:30:03 +08:00
run.py add wandb 2021-11-30 21:39:29 +08:00

README.md

快速上手

环境依赖

python == 3.8

  • torch == 1.5
  • hydra-core == 1.0.6
  • tensorboard == 2.4.1
  • matplotlib == 3.4.1
  • scikit-learn == 0.24.1
  • transformers == 3.4.0
  • jieba == 0.42.1
  • deepke

克隆代码

git clone git@github.com:zjunlp/DeepKE.git

使用pip安装

首先创建python虚拟环境再进入虚拟环境

  • 安装依赖: pip install -r requirements.txt

使用数据进行训练预测

  • 存放数据: 可先下载数据 wget 120.27.214.45/Data/ae/standard/data.tar.gz至此目录下

    解压后data/origin 文件夹下存放来训练数据。训练文件主要有三个文件。

    • train.csv:存放训练数据集

    • valid.csv:存放验证数据集

    • test.csv:存放测试数据集

    • attribute.csv:存放属性种类

  • 开始训练:python run.py (训练所用到参数都在conf文件夹中修改即可使用lm时可修改'lm_file'使用下载至本地的模型)

  • 每次训练的日志保存在 logs 文件夹内,模型结果保存在 checkpoints 文件夹内。

  • 进行预测 python predict.py

模型内容

1、CNN

2、RNN

3、Capsule

4、GCN

5、Transformer

6、预训练模型