deepke/serializer.py

import re
import unicodedata
import jieba
import logging
from typing import List
logger = logging.getLogger(__name__)
jieba.setLogLevel(logging.INFO)


class Serializer:
    def __init__(self, never_split: List = None, do_lower_case=True, do_chinese_split=False):
        self.never_split = never_split if never_split is not None else []
        self.do_lower_case = do_lower_case
        self.do_chinese_split = do_chinese_split

    def serialize(self, text, never_split: List = None):
        """Cleans, tokenizes and (optionally) jieba-segments a piece of text."""
        never_split = self.never_split + (never_split if never_split is not None else [])
        text = self._clean_text(text)

        if self.do_chinese_split:
            output_tokens = self._use_jieba_cut(text, never_split)
            return output_tokens

        text = self._tokenize_chinese_chars(text)
        orig_tokens = self._orig_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if self.do_lower_case and token not in never_split:
                token = token.lower()
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token, never_split=never_split))
        output_tokens = self._whitespace_tokenize(" ".join(split_tokens))
        return output_tokens
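
    # Illustrative walk-through (not from the original source): with the default
    # do_chinese_split=False, serialize("Résumé，你好") should pass through
    # _clean_text -> _tokenize_chinese_chars -> _orig_tokenize -> lowercasing and
    # accent stripping -> _run_split_on_punc, yielding ['resume', '，', '你', '好'].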

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or self.is_control(char):
                continue
            if self.is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _use_jieba_cut(self, text, never_split):
        """Segments text with jieba, keeping words in `never_split` unsplit."""
        for word in never_split:
            jieba.suggest_freq(word, True)
        tokens = jieba.lcut(text)
        if self.do_lower_case:
            tokens = [i.lower() for i in tokens]
        # Drop the bare-space tokens that jieba leaves behind.
        tokens = [token for token in tokens if token != ' ']
        return tokens
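
    # Illustrative example (actual segmentation depends on jieba's dictionary):
    # _use_jieba_cut("南京市长江大桥", []) typically returns ['南京市', '长江大桥'];
    # words passed in never_split are boosted via jieba.suggest_freq so that jieba
    # prefers to keep them as single tokens.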

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self.is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _orig_tokenize(self, text):
        """Splits text on whitespace and common punctuation such as commas and periods."""
        text = text.strip()
        if not text:
            return []
        # Common sentence-delimiting punctuation.
        punc = """,.?!;: 、|,。?!;:《》「」【】/<>|\“ ”‘ """
        punc_re = '|'.join(re.escape(x) for x in punc)
        tokens = re.sub(punc_re, lambda x: ' ' + x.group() + ' ', text)
        tokens = tokens.split()
        return tokens
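
    # Illustrative example (added for clarity): _orig_tokenize("你好,world。hi")
    # pads each listed punctuation mark with spaces and then splits on whitespace,
    # giving ['你好', ',', 'world', '。', 'hi'].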

    def _whitespace_tokenize(self, text):
        """Runs basic whitespace cleaning and splitting on a piece of text."""
        text = text.strip()
        if not text:
            return []
        tokens = text.split()
        return tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        if never_split is not None and text in never_split:
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if self.is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                    start_new_word = False
                output[-1].append(char)
            i += 1
        return ["".join(x) for x in output]

    @staticmethod
    def is_control(char):
        """Checks whether `char` is a control character."""
        # These are technically control characters but we count them as whitespace
        # characters.
        if char == "\t" or char == "\n" or char == "\r":
            return False
        cat = unicodedata.category(char)
        if cat.startswith("C"):
            return True
        return False

    @staticmethod
    def is_whitespace(char):
        """Checks whether `char` is a whitespace character."""
        # \t, \n, and \r are technically control characters but we treat them
        # as whitespace since they are generally considered as such.
        if char == " " or char == "\t" or char == "\n" or char == "\r":
            return True
        cat = unicodedata.category(char)
        if cat == "Zs":
            return True
        return False

    @staticmethod
    def is_chinese_char(cp):
        """Checks whether `cp` is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as are Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and are handled
        # like all of the other languages.
        if ((cp >= 0x4E00 and cp <= 0x9FFF) or    # CJK Unified Ideographs
                (cp >= 0x3400 and cp <= 0x4DBF) or    # Extension A
                (cp >= 0x20000 and cp <= 0x2A6DF) or  # Extension B
                (cp >= 0x2A700 and cp <= 0x2B73F) or  # Extension C
                (cp >= 0x2B740 and cp <= 0x2B81F) or  # Extension D
                (cp >= 0x2B820 and cp <= 0x2CEAF) or  # Extension E
                (cp >= 0xF900 and cp <= 0xFAFF) or    # CJK Compatibility Ideographs
                (cp >= 0x2F800 and cp <= 0x2FA1F)):   # Compatibility Ideographs Supplement
            return True
        return False
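
    # Illustrative example (added for clarity): is_chinese_char(ord("中")) is True
    # (U+4E2D lies in the CJK Unified Ideographs block), while is_chinese_char(ord("あ"))
    # is False, because Hiragana falls outside the ranges listed above.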

    @staticmethod
    def is_punctuation(char):
        """Checks whether `char` is a punctuation character."""
        cp = ord(char)
        # We treat all non-letter/number ASCII as punctuation.
        # Characters such as "^", "$", and "`" are not in the Unicode
        # Punctuation class but we treat them as punctuation anyways, for
        # consistency.
        if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96)
                or (cp >= 123 and cp <= 126)):
            return True
        cat = unicodedata.category(char)
        if cat.startswith("P"):
            return True
        return False
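
    # Illustrative example (added for clarity): is_punctuation("^") is True via the
    # ASCII range check, is_punctuation("，") is True via Unicode category "Po", and
    # is_punctuation("a") is False.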


if __name__ == '__main__':
    text1 = "\t\n你 好呀, I\'m his pupp\'peer,\n\t"
    text2 = '你孩子的爱情叫 Stam\'s 的打到天啊呢哦'
    serializer = Serializer(do_chinese_split=False)
    print(serializer.serialize(text1))
    print(serializer.serialize(text2))

    text3 = "good\'s head pupp\'er, "
    # Observed output: ["good's", 'pupp', "'", 'er', ',']
    # Expected output: ["good's", "pupp'er", ","]
    print(serializer.serialize(text3, never_split=["pupp\'er"]))
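
    # A minimal sketch (not in the original demo) of the jieba-based path; the
    # exact segmentation depends on jieba's dictionary.
    serializer_cn = Serializer(do_chinese_split=True)
    print(serializer_cn.serialize(text2))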