reconstruct

This commit is contained in:
tlk-dsg 2021-05-11 16:09:14 +08:00
parent 55ab9f29c9
commit 7ceb1252e7
11 changed files with 15 additions and 77 deletions

View File

@ -125,17 +125,3 @@ class MultiHeadAttention(nn.Module):
return attention_out,
if __name__ == '__main__':
from utils import seq_len_to_mask
q = torch.randn(4, 6, 20) # [B, L, H]
k = v = torch.randn(4, 5, 20) # [B, S, H]
key_padding_mask = seq_len_to_mask([5, 4, 3, 2], max_len=5)
attention_mask = torch.tensor([1, 0, 0, 1, 0]) # 为1 的地方 mask 掉
head_mask = torch.tensor([0, 1]) # 为1 的地方 mask 掉
m = MultiHeadAttention(embed_dim=20, num_heads=2, dropout=0.0, output_attentions=True)
ao, aw = m(q, k, v, key_padding_mask=key_padding_mask, attention_mask=attention_mask, head_mask=head_mask)
print(ao.shape, aw.shape) # [B, L, H] [B, N, L, S]
print(ao)
print(aw.unbind(1))

View File

@ -132,17 +132,3 @@ def pad_adj(adj, max_len):
if __name__ == '__main__':
class Config():
num_layers = 3
input_size = 50
hidden_size = 100
dropout = 0.3
cfg = Config()
x = torch.randn(1, 10, 50)
adj = torch.empty(1, 10, 10).random_(2)
m = GCN(cfg)
print(m)
out = m(x, adj)
print(out.shape)
print(out)

View File

@ -72,27 +72,3 @@ class RNN(nn.Module):
return output, hn
if __name__ == '__main__':
class Config(object):
type_rnn = 'LSTM'
input_size = 5
hidden_size = 4
num_layers = 3
dropout = 0.0
last_layer_hn = False
bidirectional = True
config = Config()
model = RNN(config)
print(model)
torch.manual_seed(1)
x = torch.tensor([[4, 3, 2, 1], [5, 6, 7, 0], [8, 10, 0, 0]])
x = torch.nn.Embedding(11, 5, padding_idx=0)(x) # B,L,H = 3,4,5
x_len = torch.tensor([4, 3, 2])
o, h = model(x, x_len)
print(o.shape, h.shape, sep='\n\n')
print(o[-1].data, h[-1].data, sep='\n\n')

7
tools/__init__.py Normal file
View File

@ -0,0 +1,7 @@
from .dataset import *
from .metrics import *
from .predict import *
from .preprocess import *
from .serializer import *
from .trainer import *
from .vocab import *

View File

@ -3,13 +3,15 @@ import sys
import torch
import logging
import hydra
import models
from hydra import utils
from utils import load_pkl, load_csv
from serializer import Serializer
from preprocess import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_relation_data
import matplotlib.pyplot as plt
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
import models
from utils import load_pkl, load_csv
logger = logging.getLogger(__name__)

View File

@ -5,8 +5,11 @@ from typing import List, Dict
from transformers import BertTokenizer
from serializer import Serializer
from vocab import Vocab
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
from utils import save_pkl, load_csv
logger = logging.getLogger(__name__)
@ -221,5 +224,4 @@ def preprocess(cfg):
logger.info('===== end preprocess data =====')
if __name__ == '__main__':
pass

View File

@ -270,15 +270,4 @@ class Serializer():
return False
if __name__ == '__main__':
text1 = "\t\n你 好呀, I\'m his pupp\'peer,\n\t"
text2 = '你孩子的爱情叫 Stam\'s 的打到天啊呢哦'
serializer = Serializer(do_chinese_split=False)
print(serializer.serialize(text1))
print(serializer.serialize(text2))
text3 = "good\'s head pupp\'er, "
# print: ["good's", 'pupp', "'", 'er', ',']
# true: ["good's", "pupp'er", ","]
print(serializer.serialize(text3, never_split=["pupp\'er"]))

View File

@ -113,13 +113,3 @@ class Vocab(object):
self.add_words(new_words)
if __name__ == '__main__':
vocab = Vocab('test')
sent = ' 我是中国人,我爱中国。'
sent = list(sent)
print(sent)
vocab.add_words(sent)
print(vocab.word2count)
vocab.trim(2)
print(vocab.word2count)