deepke/serializer.py

204 lines
7.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
import unicodedata
import jieba
import logging
from typing import List
# Module-level logger for this file.
# NOTE(review): `logger` is not referenced anywhere else in the visible source —
# confirm whether it is still needed.
logger = logging.getLogger(__name__)
# Set jieba's log level to INFO, presumably to silence jieba's DEBUG-level
# messages (e.g. while it loads its dictionary) — confirm against jieba's defaults.
jieba.setLogLevel(logging.INFO)
class Serializer():
    """Rule-based tokenizer for mixed Chinese / English text.

    Two modes are supported:

    * ``do_chinese_split=True``: the cleaned text is segmented with jieba.
    * ``do_chinese_split=False`` (default): a BERT-style basic tokenizer. Every
      CJK ideograph becomes its own token. The text is split on whitespace and
      sentence punctuation, optionally lower-cased and accent-stripped, and
      finally split on any punctuation character.

    Tokens listed in ``never_split`` are never lower-cased in either mode. In
    the rule-based mode they are also never split on punctuation.
    """

    # Sentence-level punctuation that _orig_tokenize pads with spaces so each
    # mark becomes a separate token. "\\“" is a literal backslash followed by “
    # (same runtime value as the former invalid escape "\“", which Python
    # 3.12+ reports as a SyntaxWarning).
    _SENTENCE_PUNC = """,.?!;: 、|,。?!;:《》「」【】/<>|\\“ ”‘ """
    # Compiled once instead of on every _orig_tokenize call.
    _SENTENCE_PUNC_RE = re.compile('|'.join(map(re.escape, _SENTENCE_PUNC)))

    def __init__(self, never_split: List = None, do_lower_case=True, do_chinese_split=False):
        """
        Args:
            never_split: tokens that are protected on every ``serialize`` call
                (see the class docstring). Defaults to an empty list.
            do_lower_case: lower-case unprotected tokens. In the rule-based
                mode, accents are stripped from them as well.
            do_chinese_split: segment with jieba instead of the rule-based path.
        """
        self.never_split = never_split if never_split is not None else []
        self.do_lower_case = do_lower_case
        self.do_chinese_split = do_chinese_split

    def serialize(self, text, never_split: List = None):
        """Tokenize ``text`` into a list of token strings.

        Args:
            text: the input string.
            never_split: extra protected tokens for this call only. They are
                combined with ``self.never_split``, and any iterable of
                strings (list, tuple, set) is accepted.

        Returns:
            The list of tokens.
        """
        # list(...) on both sides so tuples/sets work (list + tuple would raise).
        never_split = list(self.never_split) + (list(never_split) if never_split is not None else [])
        text = self._clean_text(text)
        if self.do_chinese_split:
            return self._use_jieba_cut(text, never_split)

        protected = set(never_split)  # O(1) membership tests below
        text = self._tokenize_chinese_chars(text)
        split_tokens = []
        for token in self._orig_tokenize(text):
            if self.do_lower_case and token not in protected:
                token = token.lower()
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token, never_split=protected))
        # Re-split so that any whitespace-only pieces disappear.
        return self._whitespace_tokenize(" ".join(split_tokens))

    def _clean_text(self, text):
        """Remove NUL, U+FFFD and control characters; map whitespace to ' '."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xFFFD or self.is_control(char):
                continue
            output.append(" " if self.is_whitespace(char) else char)
        return "".join(output)

    def _use_jieba_cut(self, text, never_split):
        """Segment ``text`` with jieba.

        Each word in ``never_split`` is first passed to
        ``jieba.suggest_freq(word, True)`` so that jieba prefers to keep it
        whole. This call goes through the module-level ``jieba`` API, so it
        also affects later jieba calls in the process. Single-space tokens
        are dropped. When ``do_lower_case`` is set, tokens are lower-cased,
        except those in ``never_split``, which matches the rule-based path.
        """
        for word in never_split:
            jieba.suggest_freq(word, True)
        protected = set(never_split)
        tokens = []
        for token in jieba.lcut(text):
            if token == ' ':
                continue
            if self.do_lower_case and token not in protected:
                token = token.lower()
            tokens.append(token)
        return tokens

    def _tokenize_chinese_chars(self, text):
        """Add whitespace around every CJK ideograph (see ``is_chinese_char``)."""
        return "".join(
            " " + char + " " if self.is_chinese_char(ord(char)) else char
            for char in text
        )

    def _orig_tokenize(self, text):
        """Split ``text`` on whitespace and on the marks in ``_SENTENCE_PUNC``.

        The punctuation marks themselves are kept as separate tokens.
        """
        text = text.strip()
        if not text:
            return []
        padded = self._SENTENCE_PUNC_RE.sub(lambda m: ' ' + m.group() + ' ', text)
        return padded.split()

    def _whitespace_tokenize(self, text):
        """Split ``text`` on runs of whitespace; return [] for blank text."""
        return text.split()

    def _run_strip_accents(self, text):
        """Strip accents by NFD-decomposing and dropping nonspacing marks (Mn)."""
        text = unicodedata.normalize("NFD", text)
        return "".join(char for char in text if unicodedata.category(char) != "Mn")

    def _run_split_on_punc(self, text, never_split=None):
        """Split ``text`` so that every punctuation character is its own token.

        ``text`` is returned unchanged (as a one-element list) when it appears
        in ``never_split``.
        """
        if never_split is not None and text in never_split:
            return [text]
        output = []
        start_new_word = True
        for char in text:
            if self.is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                    start_new_word = False
                output[-1].append(char)
        return ["".join(chars) for chars in output]

    @staticmethod
    def is_control(char):
        """Check whether `char` is a control character."""
        # \t, \n and \r are technically control characters, but they are
        # treated as whitespace instead (see is_whitespace).
        if char in ("\t", "\n", "\r"):
            return False
        return unicodedata.category(char).startswith("C")

    @staticmethod
    def is_whitespace(char):
        """Check whether `char` is a whitespace character."""
        # \t, \n and \r are technically control characters, but they are
        # treated as whitespace since they are generally considered as such.
        if char in (" ", "\t", "\n", "\r"):
            return True
        return unicodedata.category(char) == "Zs"

    @staticmethod
    def is_chinese_char(cp):
        """Check whether code point `cp` is a CJK ideograph.

        A "Chinese character" here is anything in the CJK Unified Ideographs
        blocks (and their extensions and compatibility blocks listed below):
        https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)

        Despite the name, these blocks do not cover all Japanese and Korean
        characters. Modern Korean Hangul and Japanese Hiragana/Katakana live in
        other blocks. Those scripts are written with space-separated words, so
        they are handled like all other languages.
        """
        return (
            0x4E00 <= cp <= 0x9FFF          # CJK Unified Ideographs
            or 0x3400 <= cp <= 0x4DBF       # Extension A
            or 0x20000 <= cp <= 0x2A6DF     # Extension B
            or 0x2A700 <= cp <= 0x2B73F     # Extension C
            or 0x2B740 <= cp <= 0x2B81F     # Extension D
            or 0x2B820 <= cp <= 0x2CEAF     # Extensions E/F
            or 0xF900 <= cp <= 0xFAFF       # Compatibility Ideographs
            or 0x2F800 <= cp <= 0x2FA1F     # Compatibility Ideographs Supplement
        )

    @staticmethod
    def is_punctuation(char):
        """Check whether `char` is a punctuation character."""
        cp = ord(char)
        # All non-letter/number ASCII is treated as punctuation. Characters
        # such as "^", "$" and "`" are not in the Unicode Punctuation class,
        # but they are treated as punctuation anyway, for consistency.
        if 33 <= cp <= 47 or 58 <= cp <= 64 or 91 <= cp <= 96 or 123 <= cp <= 126:
            return True
        return unicodedata.category(char).startswith("P")
if __name__ == '__main__':
    # Quick manual demo of the rule-based (non-jieba) tokenization path.
    text1 = "\t\n你 好呀, I\'m his pupp\'peer,\n\t"
    text2 = '你孩子的爱情叫 Stam\'s 的打到天啊呢哦'
    serializer = Serializer(do_chinese_split=False)
    print(serializer.serialize(text1))
    print(serializer.serialize(text2))
    text3 = "good\'s head pupp\'er, "
    # Tokens passed via never_split are kept whole and are not lower-cased.
    # Other tokens are still split on punctuation such as the apostrophe.
    # prints: ['good', "'", 's', 'head', "pupp'er", ',']
    print(serializer.serialize(text3, never_split=["pupp\'er"]))