|
|
|
|
|
import re |
|
import bw2ar |
|
import torch |
|
import xer |
|
|
|
|
|
# Arabic diacritic code points (Unicode names in the trailing comments).
FATHATAN = u'\u064b'  # ARABIC FATHATAN
DAMMATAN = u'\u064c'  # ARABIC DAMMATAN
KASRATAN = u'\u064d'  # ARABIC KASRATAN
FATHA = u'\u064e'     # ARABIC FATHA
DAMMA = u'\u064f'     # ARABIC DAMMA
KASRA = u'\u0650'     # ARABIC KASRA
SHADDA = u'\u0651'    # ARABIC SHADDA
SUKUN = u'\u0652'     # ARABIC SUKUN
TATWEEL = u'\u0640'   # ARABIC TATWEEL (elongation character, not a diacritic)

# Character class matching any single diacritic listed above (TATWEEL is
# deliberately not part of it). The member order is kept as-is so the
# resulting pattern string does not change.
_HARAKAT_CHARS = (
    FATHATAN, DAMMATAN, KASRATAN,
    FATHA, DAMMA, KASRA, SUKUN,
    SHADDA,
)
HARAKAT_PAT = re.compile(u"[%s]" % u"".join(_HARAKAT_CHARS))
|
|
|
|
|
class TashkeelTokenizer:
    """Convert diacritized Arabic text to (letter, tashkeel) id sequences and back.

    Text is cleaned, transliterated with ``bw2ar.transliterate_text(..., 'ar2bw')``
    (presumably Buckwalter transliteration -- confirm against ``bw2ar``), and then
    split so that every letter carries exactly one tashkeel tag. A ``'~'`` mark
    followed by one of ``a/u/i/F/N/K`` is folded into a single tag such as
    ``'<SF>'`` (see ``inverse_tags``); letters without any mark get ``'<NT>'``.

    Attributes:
        letters: letter vocabulary; position in the list is the token id.
        tashkeel_list: tashkeel vocabulary; position in the list is the token id.
        letters_map / tashkeel_map: token -> id lookups for the two vocabularies.
        inverse_tags / tags: two-character mark <-> combined tag lookups.
        shaddah_last / shaddah_first: parallel lists used to move ``'~'`` in
            front of the mark it is combined with.
        tahkeel_chars: single-character marks that get de-duplicated.
    """

    # Compiled once instead of on every clean_text call; the pattern strings
    # themselves are unchanged.
    _TATWEEL_PAT = re.compile(u'[%s]' % u'\u0640')
    _DISALLOWED_PAT = re.compile(
        u"[^\u0621-\u063A\u0640-\u0652\u0670\u0671\ufefb\ufef7\ufef5\ufef9 ]",
        flags=re.UNICODE,
    )
    _SPECIAL_TOKENS = ('<BOS>', '<EOS>', '<PAD>')

    def __init__(self):
        """Build the letter / tashkeel vocabularies and their lookup tables."""
        self.letters = [' ', '$', '&', "'", '*', '<', '>', 'A', 'D', 'E', 'H', 'S', 'T', 'Y', 'Z',
                        'b', 'd', 'f', 'g', 'h', 'j', 'k', 'l', 'm', 'n', 'p', 'q', 'r', 's', 't',
                        'v', 'w', 'x', 'y', 'z', '|', '}'
                        ]
        self.letters = ['<PAD>', '<BOS>', '<EOS>'] + self.letters + ['<MASK>']

        self.no_tashkeel_tag = '<NT>'
        self.tashkeel_list = ['<NT>', '<SD>', '<SDD>', '<SF>', '<SFF>', '<SK>',
                              '<SKK>', 'F', 'K', 'N', 'a', 'i', 'o', 'u', '~']
        self.tashkeel_list = ['<PAD>', '<BOS>', '<EOS>'] + self.tashkeel_list

        self.tashkeel_map = {c: i for i, c in enumerate(self.tashkeel_list)}
        self.letters_map = {c: i for i, c in enumerate(self.letters)}
        self.inverse_tags = {
            '~a': '<SF>',
            '~u': '<SD>',
            '~i': '<SK>',
            '~F': '<SFF>',
            '~N': '<SDD>',
            '~K': '<SKK>'
        }
        self.tags = {v: k for k, v in self.inverse_tags.items()}
        self.shaddah_last = ['a~', 'u~', 'i~', 'F~', 'N~', 'K~']
        self.shaddah_first = ['~a', '~u', '~i', '~F', '~N', '~K']
        self.tahkeel_chars = ['F', 'N', 'K', 'a', 'u', 'i', '~', 'o']

    def clean_text(self, text):
        """Return ``text`` reduced to the supported characters, single-spaced.

        Tatweel (U+0640) is removed, alef wasla is replaced by plain alef, every
        character outside the allowed set becomes a space, and runs of whitespace
        are collapsed (leading/trailing whitespace is dropped).
        """
        text = self._TATWEEL_PAT.sub('', text)
        text = text.replace('ٱ', 'ا')
        return ' '.join(self._DISALLOWED_PAT.sub(" ", text).split())

    def check_match(self, text_with_tashkeel, letter_n_tashkeel_pairs):
        """Return True if re-joining the pairs reproduces ``text_with_tashkeel``.

        The comparison is made against the stripped text and, failing that,
        against the text with the ``'~'`` position unified.
        """
        text_with_tashkeel = text_with_tashkeel.strip()
        syn_text = self.combine_tashkeel_with_text(letter_n_tashkeel_pairs)
        if syn_text == text_with_tashkeel:
            return True
        return syn_text == self.unify_shaddah_position(text_with_tashkeel)

    def unify_shaddah_position(self, text_with_tashkeel):
        """Rewrite every ``<mark>~`` sequence as ``~<mark>``."""
        for last, first in zip(self.shaddah_last, self.shaddah_first):
            text_with_tashkeel = text_with_tashkeel.replace(last, first)
        return text_with_tashkeel

    def _pair_letters_with_tashkeel(self, text):
        """Pair each non-mark character of ``text`` with the tag that follows it.

        ``text`` is the already transliterated, normalized string. Mark
        characters themselves never start a pair; a letter that is not
        directly followed by a mark gets ``self.no_tashkeel_tag``.
        """
        # Only single-character vocabulary entries can ever equal one character
        # of ``text``, so this set is equivalent to testing against the full
        # ``tashkeel_list`` while avoiding a list scan per character.
        marks = frozenset(t for t in self.tashkeel_list if len(t) == 1)
        size = len(text)
        pairs = []
        for i, char in enumerate(text):
            if char in marks:
                continue
            following = text[i + 1] if i + 1 < size else None
            if following not in marks:
                tag = self.no_tashkeel_tag
            elif following == '~':
                # '~' plus a combinable mark collapses into one tag; a lone
                # '~' (or '~' followed by anything else) stays as '~'.
                combined = f'~{text[i + 2]}' if i + 2 < size else None
                tag = self.inverse_tags.get(combined, '~')
            else:
                tag = following
            pairs.append((char, tag))
        return pairs

    def split_tashkeel_from_text(self, text_with_tashkeel, test_match=True):
        """Split diacritized text into ``(letter, tashkeel_tag)`` pairs.

        Args:
            text_with_tashkeel: raw input text; it is cleaned and transliterated
                here, so callers pass the original string.
            test_match: when True, verify that the pairs re-join to the
                normalized text.

        Returns:
            A list of pairs wrapped in ``('<BOS>', '<BOS>')`` and
            ``('<EOS>', '<EOS>')``.

        Raises:
            AssertionError: if ``test_match`` is True and the round trip fails.
        """
        text_with_tashkeel = self.clean_text(text_with_tashkeel)
        text_with_tashkeel = bw2ar.transliterate_text(text_with_tashkeel, 'ar2bw')
        text_with_tashkeel = text_with_tashkeel.replace('`', '')

        text_with_tashkeel = self.unify_shaddah_position(text_with_tashkeel)

        # Collapse a doubled mark into a single one (one pass per mark).
        for mark in self.tahkeel_chars:
            text_with_tashkeel = text_with_tashkeel.replace(mark * 2, mark)

        letter_n_tashkeel_pairs = self._pair_letters_with_tashkeel(text_with_tashkeel)

        # Raised explicitly instead of using ``assert`` so the check still runs
        # under ``python -O``; the exception type is what callers already catch.
        if test_match and not self.check_match(text_with_tashkeel, letter_n_tashkeel_pairs):
            raise AssertionError(
                'letter/tashkeel pairs do not reproduce the normalized text'
            )
        return [('<BOS>', '<BOS>')] + letter_n_tashkeel_pairs + [('<EOS>', '<EOS>')]

    def combine_tashkeel_with_text(self, letter_n_tashkeel_pairs):
        """Join ``(letter, tag)`` pairs back into one string.

        Combined tags are expanded through ``self.tags``, the no-tashkeel tag
        contributes nothing, and any other tag is appended verbatim.
        """
        pieces = []
        for letter, tashkeel in letter_n_tashkeel_pairs:
            pieces.append(letter)
            if tashkeel in self.tags:
                pieces.append(self.tags[tashkeel])
            elif tashkeel != self.no_tashkeel_tag:
                pieces.append(tashkeel)
        return ''.join(pieces)

    def encode(self, text_with_tashkeel, test_match=True):
        """Return ``(input_ids, target_ids)`` LongTensors for the given text.

        Raises:
            KeyError: if a letter or tag is missing from the vocabularies.
            AssertionError: propagated from ``split_tashkeel_from_text``.
        """
        letter_n_tashkeel_pairs = self.split_tashkeel_from_text(text_with_tashkeel, test_match)
        text, tashkeel = zip(*letter_n_tashkeel_pairs)
        input_ids = [self.letters_map[c] for c in text]
        target_ids = [self.tashkeel_map[c] for c in tashkeel]
        return torch.LongTensor(input_ids), torch.LongTensor(target_ids)

    def filter_tashkeel(self, tashkeel):
        """Replace misplaced ``<BOS>`` / ``<EOS>`` tags with the no-tashkeel tag.

        ``<BOS>`` is only kept at index 0 and ``<EOS>`` only at the last index.
        Returns a new list; the input is not modified.
        """
        last = len(tashkeel) - 1
        return [
            self.no_tashkeel_tag
            if (tag == '<BOS>' and i != 0) or (tag == '<EOS>' and i != last)
            else tag
            for i, tag in enumerate(tashkeel)
        ]

    def decode(self, input_ids, target_ids):
        """Turn batches of letter ids and tashkeel ids back into text.

        Both arguments must offer ``.cpu().tolist()``; each is iterated as a
        sequence of id rows (assumes 2-D ``(batch, seq)`` tensors -- TODO confirm).

        Returns:
            One string per row, as produced by
            ``bw2ar.transliterate_text(..., 'bw2ar')``.
        """
        input_ids = input_ids.cpu().tolist()
        target_ids = target_ids.cpu().tolist()
        specials = self._SPECIAL_TOKENS
        ar_texts = []
        for j, row in enumerate(input_ids):
            letters = [self.letters[i] for i in row]
            tashkeel = [self.tashkeel_list[i] for i in target_ids[j]]

            letters = [x for x in letters if x not in specials]
            # Neutralize stray <BOS>/<EOS> first so that dropping the special
            # tokens removes at most the real boundary tags (plus padding).
            tashkeel = self.filter_tashkeel(tashkeel)
            tashkeel = [x for x in tashkeel if x not in specials]

            bw_text = self.combine_tashkeel_with_text(zip(letters, tashkeel))
            ar_texts.append(bw2ar.transliterate_text(bw_text, 'bw2ar'))
        return ar_texts

    def get_tashkeel_with_case_ending(self, text, case_ending=True):
        """Return ``(letters, tashkeel)`` tuples, optionally without case endings.

        When ``case_ending`` is False, the tag of every element that directly
        precedes an undiacritized space is replaced by ``'<NT>'``.

        NOTE(review): the last letter of the final word is followed by
        ``<EOS>`` rather than a space, so its tag is kept even when
        ``case_ending`` is False -- confirm this is intended.
        """
        text_split = self.split_tashkeel_from_text(text, test_match=False)
        if case_ending:
            letters, tashkeel = zip(*text_split)
            return letters, tashkeel

        # A set keeps the per-element membership test O(1).
        space_positions = {i for i, el in enumerate(text_split) if el == (' ', '<NT>')}
        new_text_split = [
            (el[0], '<NT>') if (i + 1) in space_positions else el
            for i, el in enumerate(text_split)
        ]
        letters, tashkeel = zip(*new_text_split)
        return letters, tashkeel

    def compute_der(self, ref, hyp, case_ending=True):
        """Score the tashkeel tag sequences of ``ref`` and ``hyp`` with ``xer.wer``.

        Each tag is treated as one space-separated token of the compared strings.
        """
        _, ref_tashkeel = self.get_tashkeel_with_case_ending(ref, case_ending=case_ending)
        _, hyp_tashkeel = self.get_tashkeel_with_case_ending(hyp, case_ending=case_ending)
        return xer.wer(' '.join(ref_tashkeel), ' '.join(hyp_tashkeel))

    def compute_wer(self, ref, hyp, case_ending=True):
        """Score the re-joined letter+tashkeel strings of ``ref`` and ``hyp`` with ``xer.wer``."""
        ref_letters, ref_tashkeel = self.get_tashkeel_with_case_ending(ref, case_ending=case_ending)
        hyp_letters, hyp_tashkeel = self.get_tashkeel_with_case_ending(hyp, case_ending=case_ending)
        ref_text_combined = self.combine_tashkeel_with_text(zip(ref_letters, ref_tashkeel))
        hyp_text_combined = self.combine_tashkeel_with_text(zip(hyp_letters, hyp_tashkeel))
        return xer.wer(ref_text_combined, hyp_text_combined)

    def remove_tashkeel(self, text):
        """Return ``text`` with every ``HARAKAT_PAT`` match and U+0671 removed.

        ``HARAKAT_PAT`` already covers U+064E, so no separate pass is needed
        for it.
        """
        text = HARAKAT_PAT.sub('', text)
        # NOTE(review): this deletes U+0671 outright, whereas clean_text
        # replaces the same character with plain alef -- confirm which of the
        # two behaviors is intended.
        return text.replace(u'\u0671', '')
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: read every *.txt file under the training folder,
    # clean each line, and count how many lines survive the tokenizer's
    # round-trip check.
    import utils
    from tqdm import tqdm

    tokenizer = TashkeelTokenizer()

    txt_folder_path = 'dataset/train'
    prepared_lines = []
    for filepath in utils.get_files(txt_folder_path, '*.txt'):
        print(f'Reading file: {filepath}')
        # Decode explicitly instead of relying on the platform default, which
        # is not UTF-8 everywhere (assumes the corpus files are UTF-8 --
        # TODO confirm).
        with open(filepath, encoding='utf-8') as f1:
            for line in f1:
                clean_line = tokenizer.clean_text(line)
                if clean_line != '':
                    prepared_lines.append(clean_line)
        print(f'completed file: {filepath}')

    good_sentences = []
    bad_sentences = []
    tokenized_sentences = []
    for line in tqdm(prepared_lines):
        try:
            letter_n_tashkeel_pairs = tokenizer.split_tashkeel_from_text(line, test_match=True)
        except AssertionError:
            # The pairs did not reproduce the normalized text.
            bad_sentences.append(line)
        else:
            tokenized_sentences.append(letter_n_tashkeel_pairs)
            good_sentences.append(line)

    print('len(good_sentences), len(bad_sentences):', len(good_sentences), len(bad_sentences))
|
|
|
|
|
|
|
|