add ioutil
This commit is contained in:
parent
f419eab2fe
commit
4c23fa5fff
|
@ -21,6 +21,8 @@ class Embedding(nn.Module):
|
|||
self.wordEmbed = nn.Embedding(self.vocab_size, self.word_dim, padding_idx=0)
|
||||
self.headPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)
|
||||
self.tailPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)
|
||||
|
||||
self.layer_norm = nn.LayerNorm(self.word_dim)
|
||||
|
||||
def forward(self, *x):
|
||||
word, head, tail = x
|
||||
|
@ -32,6 +34,6 @@ class Embedding(nn.Module):
|
|||
return torch.cat((word_embedding, head_embedding, tail_embedding), -1)
|
||||
elif self.dim_strategy == 'sum':
|
||||
# 此时 pos_dim == word_dim
|
||||
return word_embedding + head_embedding + tail_embedding
|
||||
return self.layer_norm(word_embedding + head_embedding + tail_embedding)
|
||||
else:
|
||||
raise Exception('dim_strategy must choose from [sum, cat]')
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
import csv
|
||||
import json
|
||||
import pickle
|
||||
import logging
|
||||
from typing import NewType, List, Tuple, Dict, Any
|
||||
|
@ -9,6 +10,10 @@ __all__ = [
|
|||
'save_pkl',
|
||||
'load_csv',
|
||||
'save_csv',
|
||||
'load_jsonld',
|
||||
'save_jsonld',
|
||||
'jsonld2csv',
|
||||
'csv2jsonld',
|
||||
]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -54,3 +59,74 @@ def save_csv(data: List[Dict], fp: Path, save_in_tsv: False, write_head=True, ve
|
|||
if write_head:
|
||||
writer.writeheader()
|
||||
writer.writerows(data)
|
||||
|
||||
|
||||
def load_jsonld(fp: Path, verbose: bool = True) -> List:
    """Load a line-delimited JSON file as a list of value lists.

    Every non-blank line is parsed as one JSON object. Only the object's
    values are kept (in the order they appear on the line); the keys are
    discarded.

    :param fp: path of the file to read, opened as UTF-8
    :param verbose: whether to log the path being loaded
    :return: one list of values per parsed line
    """
    if verbose:
        logger.info(f'load jsonld from {fp}')

    datas = []
    with open(fp, encoding='utf-8') as f:
        for raw_line in f:
            # Blank lines (trailing newline, doubled line endings) carry no
            # record and would otherwise make json.loads raise.
            if not raw_line.strip():
                continue
            record = json.loads(raw_line)
            datas.append(list(record.values()))

    return datas
|
||||
|
||||
|
||||
def save_jsonld(fp):
    """Placeholder for a jsonld writer.

    Not implemented yet: the argument is ignored, nothing is written, and
    ``None`` is returned.

    :param fp: unused
    """
    return None
|
||||
|
||||
|
||||
def jsonld2csv(fp: str, verbose: bool = True) -> str:
    """Convert a line-delimited JSON file into a CSV file beside it.

    The output path is ``fp`` with its extension replaced by ``.csv``. The
    CSV header is taken from the keys of the first record. If the input
    holds no records, an empty CSV file is written.

    :param fp: path of the jsonld file
    :param verbose: whether to print progress messages
    :return: path of the CSV file that was written
    """
    root, _ = os.path.splitext(fp)
    fp_new = root + '.csv'

    if verbose:
        print(f'read jsonld file in: {fp}')
    with open(fp, encoding='utf-8') as f:
        # Blank lines carry no record and would make json.loads raise.
        data = [json.loads(row) for row in f if row.strip()]

    if verbose:
        print('saving...')
    # newline='' hands line-ending control to the csv module; without it
    # platforms that translate '\n' emit an extra '\r' per row.
    with open(fp_new, 'w', encoding='utf-8', newline='') as f:
        if data:  # no first record means no header can be derived
            fieldnames = list(data[0].keys())
            writer = csv.DictWriter(f, fieldnames=fieldnames, dialect='excel')
            writer.writeheader()
            writer.writerows(data)

    if verbose:
        print(f'saved csv file in: {fp_new}')
    return fp_new
|
||||
|
||||
|
||||
def csv2jsonld(fp: str, verbose: bool = True) -> str:
    """Convert a CSV file into a line-delimited JSON file beside it.

    The output path is ``fp`` with its extension replaced by ``.jsonld``.
    The first CSV row is used as the header, and every following row is
    written as one JSON object per line.

    :param fp: path of the CSV file
    :param verbose: whether to print progress messages
    :return: path of the jsonld file that was written
    """
    root, _ = os.path.splitext(fp)
    fp_new = root + '.jsonld'

    if verbose:
        print(f'read csv file in: {fp}')
    # newline='' is required by the csv module so that line breaks embedded
    # in quoted fields reach the parser untranslated.
    with open(fp, encoding='utf-8', newline='') as f:
        reader = csv.DictReader(f, fieldnames=None, dialect='excel')
        data = list(reader)

    if verbose:
        print('saving...')
    with open(fp_new, 'w', encoding='utf-8') as f:
        # Join with '\n', not os.linesep: text mode already translates '\n'
        # to the platform line ending, so os.linesep would yield '\r\r\n'
        # on Windows.
        f.write('\n'.join(json.dumps(row, ensure_ascii=False) for row in data))

    if verbose:
        print(f'saved jsonld file in: {fp_new}')
    return fp_new
|
||||
|
|
Loading…
Reference in New Issue