reconstruct

This commit is contained in:
parent 55ab9f29c9
commit 7ceb1252e7
@@ -125,17 +125,3 @@ class MultiHeadAttention(nn.Module):
         return attention_out,
 
 
-if __name__ == '__main__':
-    from utils import seq_len_to_mask
-
-    q = torch.randn(4, 6, 20)  # [B, L, H]
-    k = v = torch.randn(4, 5, 20)  # [B, S, H]
-    key_padding_mask = seq_len_to_mask([5, 4, 3, 2], max_len=5)
-    attention_mask = torch.tensor([1, 0, 0, 1, 0])  # positions marked 1 are masked out
-    head_mask = torch.tensor([0, 1])  # heads marked 1 are masked out
-
-    m = MultiHeadAttention(embed_dim=20, num_heads=2, dropout=0.0, output_attentions=True)
-    ao, aw = m(q, k, v, key_padding_mask=key_padding_mask, attention_mask=attention_mask, head_mask=head_mask)
-    print(ao.shape, aw.shape)  # [B, L, H] [B, N, L, S]
-    print(ao)
-    print(aw.unbind(1))
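Note: the removed smoke test relied on utils.seq_len_to_mask to build the key padding mask. A minimal sketch of what such a helper typically does, assuming (not confirmed by this diff) the usual lengths-to-boolean-mask semantics:

import torch

def seq_len_to_mask(seq_len, max_len=None):
    # seq_len: per-example valid lengths, e.g. [5, 4, 3, 2].
    seq_len = torch.as_tensor(seq_len)
    max_len = max_len or int(seq_len.max())
    # True where the position index falls inside the valid length.
    return torch.arange(max_len).unsqueeze(0) < seq_len.unsqueeze(1)

print(seq_len_to_mask([5, 4, 3, 2], max_len=5))  # [4, 5] boolean mask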
@@ -132,17 +132,3 @@ def pad_adj(adj, max_len):
 
 
 
-if __name__ == '__main__':
-    class Config():
-        num_layers = 3
-        input_size = 50
-        hidden_size = 100
-        dropout = 0.3
-    cfg = Config()
-    x = torch.randn(1, 10, 50)
-    adj = torch.empty(1, 10, 10).random_(2)
-    m = GCN(cfg)
-    print(m)
-    out = m(x, adj)
-    print(out.shape)
-    print(out)
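Note: GCN(cfg)'s internals are not part of this diff. For orientation, a single graph-convolution layer under the standard normalized-adjacency propagation rule (an assumption, not the repo's confirmed implementation):

import torch
import torch.nn as nn

class GCNLayer(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x, adj):
        # Add self-loops, then row-normalize the adjacency matrix.
        adj = adj + torch.eye(adj.size(-1), device=adj.device)
        adj = adj / adj.sum(-1, keepdim=True).clamp(min=1)
        return torch.relu(self.linear(adj @ x))

# Shapes match the removed test: x [1, 10, 50], adj [1, 10, 10].
print(GCNLayer(50, 100)(torch.randn(1, 10, 50), torch.empty(1, 10, 10).random_(2)).shape)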
@@ -72,27 +72,3 @@ class RNN(nn.Module):
         return output, hn
 
 
-if __name__ == '__main__':
-
-    class Config(object):
-        type_rnn = 'LSTM'
-        input_size = 5
-        hidden_size = 4
-        num_layers = 3
-        dropout = 0.0
-        last_layer_hn = False
-        bidirectional = True
-
-    config = Config()
-    model = RNN(config)
-    print(model)
-
-    torch.manual_seed(1)
-    x = torch.tensor([[4, 3, 2, 1], [5, 6, 7, 0], [8, 10, 0, 0]])
-    x = torch.nn.Embedding(11, 5, padding_idx=0)(x)  # B,L,H = 3,4,5
-    x_len = torch.tensor([4, 3, 2])
-
-    o, h = model(x, x_len)
-
-    print(o.shape, h.shape, sep='\n\n')
-    print(o[-1].data, h[-1].data, sep='\n\n')
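Note: the removed test passes a padded batch together with x_len, the usual signature for a wrapper around PyTorch's packing utilities. A sketch of that mechanism in plain PyTorch, assuming (not confirming) that RNN(config) packs internally:

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

lstm = nn.LSTM(input_size=5, hidden_size=4, num_layers=3,
               bidirectional=True, batch_first=True)
x = torch.randn(3, 4, 5)            # B, L, H as in the removed test
x_len = torch.tensor([4, 3, 2])     # sorted descending, as enforce_sorted=True requires
packed = pack_padded_sequence(x, x_len, batch_first=True)
out, (hn, cn) = lstm(packed)
out, _ = pad_packed_sequence(out, batch_first=True)
print(out.shape)                    # [3, 4, 8]; hidden size doubled by bidirectional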
@@ -0,0 +1,7 @@
+from .dataset import *
+from .metrics import *
+from .predict import *
+from .preprocess import *
+from .serializer import *
+from .trainer import *
+from .vocab import *
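Note: wildcard re-exports in an __init__.py like this stay predictable only if each submodule defines __all__; otherwise every top-level name, including transitive imports, leaks into the package namespace. For example, vocab.py could declare (illustrative, not confirmed by this diff):

__all__ = ['Vocab']  # restricts what `from .vocab import *` re-exports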
@@ -3,13 +3,15 @@ import sys
 import torch
 import logging
 import hydra
-import models
 from hydra import utils
-from utils import load_pkl, load_csv
 from serializer import Serializer
 from preprocess import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_relation_data
 import matplotlib.pyplot as plt
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
+import models
+from utils import load_pkl, load_csv
 
 logger = logging.getLogger(__name__)
 
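Note: the import order here is deliberate: `import models` and `from utils import load_pkl, load_csv` resolve against the parent directory, so they only work after the sys.path.append runs. The pattern in isolation:

import os, sys
# Extend the module search path first...
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
# ...then import packages that live one directory up.
import models
from utils import load_pkl, load_csv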
@@ -5,8 +5,11 @@ from typing import List, Dict
 from transformers import BertTokenizer
 from serializer import Serializer
 from vocab import Vocab
+
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
 from utils import save_pkl, load_csv
 
 
 logger = logging.getLogger(__name__)
 
@@ -221,5 +224,4 @@ def preprocess(cfg):
     logger.info('===== end preprocess data =====')
 
-
 if __name__ == '__main__':
     pass
@@ -270,15 +270,4 @@ class Serializer():
         return False
 
 
 
-if __name__ == '__main__':
-    text1 = "\t\n你 好呀, I\'m his pupp\'peer,\n\t"
-    text2 = '你孩子的爱情叫 Stam\'s 的打到天啊呢哦'
-    serializer = Serializer(do_chinese_split=False)
-    print(serializer.serialize(text1))
-    print(serializer.serialize(text2))
-
-    text3 = "good\'s head pupp\'er, "
-    # prints: ["good's", 'pupp', "'", 'er', ',']
-    # expected: ["good's", "pupp'er", ","]
-    print(serializer.serialize(text3, never_split=["pupp\'er"]))
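Note: the comment in the removed test records a known gap: tokens passed via never_split should survive tokenization intact. One way to get the expected output is to match protected tokens before the general word/punctuation split; a regex sketch (illustrative only, not the Serializer's actual logic):

import re

def tokenize(text, never_split=()):
    # Longest protected tokens first, then words (with optional apostrophes), then punctuation.
    protected = sorted(map(re.escape, never_split), key=len, reverse=True)
    pattern = "|".join(protected + [r"\w+(?:'\w+)?", r"[^\w\s]"])
    return re.findall(pattern, text)

print(tokenize("good's head pupp'er, ", never_split=["pupp'er"]))
# -> ["good's", 'head', "pupp'er", ',']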
@@ -113,13 +113,3 @@ class Vocab(object):
         self.add_words(new_words)
 
 
-if __name__ == '__main__':
-    vocab = Vocab('test')
-    sent = ' 我是中国人,我爱中国。'
-    sent = list(sent)
-    print(sent)
-
-    vocab.add_words(sent)
-    print(vocab.word2count)
-    vocab.trim(2)
-    print(vocab.word2count)
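Note: the removed test exercises Vocab.trim(2), which by its usage drops words counted fewer than min_freq times from word2count. A minimal sketch of that semantics, inferred from the test rather than the full Vocab implementation:

from collections import Counter

word2count = Counter(' 我是中国人,我爱中国。')  # character-level counts, as in the test
min_freq = 2
word2count = {w: c for w, c in word2count.items() if c >= min_freq}
print(word2count)  # only 我, 中, 国 occur at least twice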