add ioutil

This commit is contained in:
leo 2020-05-05 22:50:07 +08:00
parent f419eab2fe
commit 4c23fa5fff
2 changed files with 79 additions and 1 deletions

View File

@ -21,6 +21,8 @@ class Embedding(nn.Module):
self.wordEmbed = nn.Embedding(self.vocab_size, self.word_dim, padding_idx=0)
self.headPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)
self.tailPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)
self.layer_norm = nn.LayerNorm(self.word_dim)
def forward(self, *x):
word, head, tail = x
@ -32,6 +34,6 @@ class Embedding(nn.Module):
return torch.cat((word_embedding, head_embedding, tail_embedding), -1)
elif self.dim_strategy == 'sum':
# 此时 pos_dim == word_dim
return word_embedding + head_embedding + tail_embedding
return self.layer_norm(word_embedding + head_embedding + tail_embedding)
else:
raise Exception('dim_strategy must choose from [sum, cat]')

View File

@ -1,5 +1,6 @@
import os
import csv
import json
import pickle
import logging
from typing import NewType, List, Tuple, Dict, Any
@ -9,6 +10,10 @@ __all__ = [
'save_pkl',
'load_csv',
'save_csv',
'load_jsonld',
'save_jsonld',
'jsonld2csv',
'csv2jsonld',
]
logger = logging.getLogger(__name__)
@ -54,3 +59,74 @@ def save_csv(data: List[Dict], fp: Path, save_in_tsv: False, write_head=True, ve
if write_head:
writer.writeheader()
writer.writerows(data)
def load_jsonld(fp: Path, verbose: bool = True) -> List:
    '''
    Read a jsonld file (one JSON object per line) into a list of value lists.

    :param fp: path of the jsonld file
    :param verbose: whether to log which file is being loaded
    :return: one list per non-blank line, holding the values of that line's
        JSON object in the order its keys appear
    '''
    if verbose:
        logger.info(f'load jsonld from {fp}')

    with open(fp, encoding='utf-8') as f:
        # Blank lines (e.g. a trailing empty line) are not valid JSON and
        # would make json.loads raise, so they are skipped.
        return [list(json.loads(row).values()) for row in f if row.strip()]
def save_jsonld(fp):
    '''
    Placeholder for saving a jsonld file; not implemented yet.

    :param fp: currently unused
    :return: always None
    '''
    return None
def jsonld2csv(fp: str, verbose: bool = True) -> str:
    '''
    Read a jsonld file (one JSON object per line) and store it as a csv file
    with the same base name in the same location.

    The csv header is the union of the keys of all records, in the order in
    which each key is first seen; a record lacking a key gets an empty cell.
    When the input holds no records, an empty csv file is written.

    :param fp: path of the jsonld file
    :param verbose: whether to print progress messages
    :return: path of the written csv file
    '''
    root, _ = os.path.splitext(fp)
    fp_new = root + '.csv'

    if verbose:
        print(f'read jsonld file in: {fp}')
    with open(fp, encoding='utf-8') as f:
        # Blank lines are not valid JSON; skip them instead of crashing.
        data = [json.loads(row) for row in f if row.strip()]

    # Collect keys from every record, not only the first one, so a later
    # record with an extra key does not make DictWriter raise ValueError.
    fieldnames = list(dict.fromkeys(key for record in data for key in record))

    if verbose:
        print('saving...')
    # newline='' stops text mode from turning the csv writer's '\r\n' row
    # terminator into '\r\r\n' on Windows.
    with open(fp_new, 'w', encoding='utf-8', newline='') as f:
        if fieldnames:
            writer = csv.DictWriter(f, fieldnames=fieldnames, dialect='excel')
            writer.writeheader()
            writer.writerows(data)

    if verbose:
        print(f'saved csv file in: {fp_new}')

    return fp_new
def csv2jsonld(fp: str, verbose: bool = True) -> str:
    '''
    Read a csv file and store it as a jsonld file (one JSON object per line)
    with the same base name in the same location.

    :param fp: path of the csv file; its first row is used as the header
    :param verbose: whether to print progress messages
    :return: path of the written jsonld file
    '''
    root, _ = os.path.splitext(fp)
    fp_new = root + '.jsonld'

    if verbose:
        print(f'read csv file in: {fp}')
    # newline='' leaves line endings untouched so the csv module can parse
    # quoted fields containing embedded newlines without them being rewritten.
    with open(fp, encoding='utf-8', newline='') as f:
        reader = csv.DictReader(f, fieldnames=None, dialect='excel')
        data = list(reader)

    if verbose:
        print('saving...')
    with open(fp_new, 'w', encoding='utf-8') as f:
        # Text mode already translates '\n' to the platform line ending;
        # joining with os.linesep would produce '\r\r\n' on Windows, which
        # reads back as blank lines between records.
        f.write('\n'.join(json.dumps(row, ensure_ascii=False) for row in data))

    if verbose:
        print(f'saved jsonld file in: {fp_new}')

    return fp_new