File size: 10,963 Bytes
bcc0c7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235


import re
import bw2ar
import torch
import xer

# Unicode code points of the Arabic diacritic marks (tashkeel) and the tatweel
FATHATAN = '\u064b'
DAMMATAN = '\u064c'
KASRATAN = '\u064d'
FATHA = '\u064e'
DAMMA = '\u064f'
KASRA = '\u0650'
SHADDA = '\u0651'
SUKUN = '\u0652'
TATWEEL = '\u0640'

# Matches exactly one diacritic mark; TATWEEL is deliberately not part of the class.
_HARAKAT_CHARS = (FATHATAN, DAMMATAN, KASRATAN, FATHA, DAMMA, KASRA, SUKUN, SHADDA)
HARAKAT_PAT = re.compile("[{}]".format("".join(_HARAKAT_CHARS)))


class TashkeelTokenizer:
    """Convert diacritized Arabic text to (letter, tashkeel) pairs and back.

    Internally the text is handled in the transliteration produced by
    ``bw2ar.transliterate_text(text, 'ar2bw')``: every base letter is paired
    with exactly one tashkeel tag, where ``<NT>`` means "no tashkeel" and the
    ``<S..>`` tags stand for a shaddah combined with another mark.
    """

    # Every character NOT listed here is replaced by a space in clean_text().
    _DISALLOWED_CHARS = re.compile(
        u"[^\u0621-\u063A\u0640-\u0652\u0670\u0671\ufefb\ufef7\ufef5\ufef9 ]",
        flags=re.UNICODE)

    def __init__(self):
        """Build the letter / tashkeel vocabularies and their lookup tables."""
        self.letters = [' ', '$', '&', "'", '*', '<', '>', 'A', 'D', 'E', 'H', 'S', 'T', 'Y', 'Z',
                        'b', 'd', 'f', 'g', 'h', 'j', 'k', 'l', 'm', 'n', 'p', 'q', 'r', 's', 't',
                        'v', 'w', 'x', 'y', 'z', '|', '}'
                       ]
        self.letters = ['<PAD>', '<BOS>', '<EOS>'] + self.letters + ['<MASK>']

        self.no_tashkeel_tag = '<NT>'
        self.tashkeel_list = ['<NT>', '<SD>', '<SDD>', '<SF>', '<SFF>', '<SK>',
                               '<SKK>', 'F', 'K', 'N', 'a', 'i', 'o', 'u', '~']

        self.tashkeel_list = ['<PAD>', '<BOS>', '<EOS>'] + self.tashkeel_list

        # token -> id lookup tables (ids are the positions in the lists above)
        self.tashkeel_map = {c: i for i, c in enumerate(self.tashkeel_list)}
        self.letters_map = {c: i for i, c in enumerate(self.letters)}
        self.inverse_tags = {
                 '~a': '<SF>',  # shaddah and fatHa
                 '~u': '<SD>',  # shaddah and Damma
                 '~i': '<SK>',  # shaddah and kasra
                 '~F': '<SFF>', # shaddah and fatHatayn
                 '~N': '<SDD>', # shaddah and Dammatayn
                 '~K': '<SKK>'  # shaddah and kasratayn
        }
        self.tags = {v: k for k, v in self.inverse_tags.items()}
        self.shaddah_last  = ['a~', 'u~', 'i~', 'F~', 'N~', 'K~']
        self.shaddah_first = ['~a', '~u', '~i', '~F', '~N', '~K']
        self.tahkeel_chars = ['F','N','K','a', 'u', 'i', '~', 'o']

        # Single-character tashkeel symbols, kept as a set for O(1) membership
        # tests in the per-character scan (multi-character tags such as '<SF>'
        # can never equal one character of the text, so they are left out).
        self._tashkeel_char_set = frozenset(t for t in self.tashkeel_list if len(t) == 1)
        # A run of two or more identical tashkeel characters.
        self._repeated_tashkeel_pat = re.compile(
            '([' + re.escape(''.join(self.tahkeel_chars)) + r'])\1+')

    def clean_text(self, text):
        """Return `text` without tatweel, with unsupported characters turned
        into spaces and all whitespace runs collapsed to a single space."""
        text = text.replace(u'\u0640', '') # strip tatweel
        text = text.replace('ٱ', 'ا')
        return ' '.join(self._DISALLOWED_CHARS.sub(' ', text).split())

    def check_match(self, text_with_tashkeel, letter_n_tashkeel_pairs):
        """Return True when `letter_n_tashkeel_pairs` re-combine into
        `text_with_tashkeel` (stripped), either verbatim or after moving every
        shaddah in front of its accompanying mark."""
        text_with_tashkeel = text_with_tashkeel.strip()
        syn_text = self.combine_tashkeel_with_text(letter_n_tashkeel_pairs)
        if syn_text == text_with_tashkeel:
            return True
        return syn_text == self.unify_shaddah_position(text_with_tashkeel)

    def unify_shaddah_position(self, text_with_tashkeel):
        """Rewrite every "<mark>~" sequence as "~<mark>" so that the shaddah
        always comes first."""
        for mark_then_shaddah, shaddah_then_mark in zip(self.shaddah_last, self.shaddah_first):
            text_with_tashkeel = text_with_tashkeel.replace(mark_then_shaddah, shaddah_then_mark)
        return text_with_tashkeel

    def _collapse_repeated_tashkeel(self, bw_text):
        """Collapse each run of one repeated tashkeel character to a single one."""
        # A single str.replace(c*2, c) pass would still leave duplicates behind
        # for runs of three or more characters; the regex handles any length.
        return self._repeated_tashkeel_pat.sub(r'\1', bw_text)

    def _pair_letters_with_tashkeel(self, bw_text):
        """Pair every non-tashkeel character of `bw_text` with its tashkeel tag."""
        marks = self._tashkeel_char_set
        last = len(bw_text) - 1
        pairs = []
        for i, char in enumerate(bw_text):
            if char in marks:
                # Tashkeel characters are consumed by look-ahead from the letter
                # before them; any additional consecutive mark is dropped here.
                continue
            following = bw_text[i + 1] if i < last else ''
            if following not in marks:
                # the letter carries no tashkeel (or it is the last character)
                pairs.append((char, self.no_tashkeel_tag))
            elif following == '~':
                # Shaddah may be accompanied by a second mark; the shaddah is
                # always first at this point. Fall back to a bare shaddah when
                # there is no next character or the combination is unknown.
                combo = '~' + bw_text[i + 2] if i + 2 <= last else '~'
                pairs.append((char, self.inverse_tags.get(combo, '~')))
            else:
                pairs.append((char, following))
        return pairs

    def split_tashkeel_from_text(self, text_with_tashkeel, test_match=True):
        """Split diacritized Arabic text into (letter, tashkeel) pairs.

        The text is cleaned, transliterated with ``bw2ar`` ('ar2bw'), stripped
        of the '`' character (dagger 'alif), and normalised (shaddah first,
        repeated marks collapsed) before pairing.

        Args:
            text_with_tashkeel: the Arabic text to split.
            test_match: when true, verify that the pairs re-combine into the
                normalised text.

        Returns:
            The list of pairs, wrapped in ('<BOS>', '<BOS>') and
            ('<EOS>', '<EOS>').

        Raises:
            AssertionError: if `test_match` is true and the check fails.
        """
        text_with_tashkeel = self.clean_text(text_with_tashkeel)
        text_with_tashkeel = bw2ar.transliterate_text(text_with_tashkeel, 'ar2bw')
        text_with_tashkeel = text_with_tashkeel.replace('`', '') # remove dagger 'alif

        # unify the order of shaddah and the harakah to make shaddah always at the beginning
        text_with_tashkeel = self.unify_shaddah_position(text_with_tashkeel)
        # remove duplicated harakat
        text_with_tashkeel = self._collapse_repeated_tashkeel(text_with_tashkeel)

        letter_n_tashkeel_pairs = self._pair_letters_with_tashkeel(text_with_tashkeel)

        # Raised explicitly (instead of using an `assert` statement) so that the
        # check is not stripped when Python runs with -O; the exception type is
        # kept so callers catching AssertionError keep working.
        if test_match and not self.check_match(text_with_tashkeel, letter_n_tashkeel_pairs):
            raise AssertionError('tashkeel split does not reconstruct the original text')
        return [('<BOS>', '<BOS>')] + letter_n_tashkeel_pairs + [('<EOS>', '<EOS>')]

    def combine_tashkeel_with_text(self, letter_n_tashkeel_pairs):
        """Join (letter, tashkeel) pairs back into one string.

        Combined shaddah tags are expanded to their two-character form and the
        no-tashkeel tag contributes nothing; any other tag is appended as is.
        """
        pieces = []
        for letter, tashkeel in letter_n_tashkeel_pairs:
            pieces.append(letter)
            if tashkeel in self.tags:
                pieces.append(self.tags[tashkeel])
            elif tashkeel != self.no_tashkeel_tag:
                pieces.append(tashkeel)
        return ''.join(pieces)

    def encode(self, text_with_tashkeel, test_match=True):
        """Return (input_ids, target_ids) LongTensors for `text_with_tashkeel`.

        Raises KeyError if a produced letter or tag is missing from the
        vocabularies, and AssertionError as in split_tashkeel_from_text().
        """
        letter_n_tashkeel_pairs = self.split_tashkeel_from_text(text_with_tashkeel, test_match)
        text, tashkeel = zip(*letter_n_tashkeel_pairs)
        input_ids = [self.letters_map[c] for c in text]
        target_ids = [self.tashkeel_map[c] for c in tashkeel]
        return torch.LongTensor(input_ids), torch.LongTensor(target_ids)

    def filter_tashkeel(self, tashkeel):
        """Return a copy of `tashkeel` in which a '<BOS>' that is not the first
        element, or an '<EOS>' that is not the last one, becomes the
        no-tashkeel tag."""
        last = len(tashkeel) - 1
        return [
            self.no_tashkeel_tag
            if (tag == '<BOS>' and i != 0) or (tag == '<EOS>' and i != last)
            else tag
            for i, tag in enumerate(tashkeel)
        ]

    def decode(self, input_ids, target_ids):
        """Turn batches of letter ids and tashkeel ids into Arabic strings.

        Args:
            input_ids: batch of letter-id sequences; must support
                ``.cpu().tolist()`` (presumably a 2-D torch tensor - confirm
                against callers).
            target_ids: batch of tashkeel-id sequences of the same kind.

        Returns:
            One string per row of `input_ids`, transliterated with ``bw2ar``
            ('bw2ar').
        """
        special_tokens = ('<BOS>', '<EOS>', '<PAD>')
        input_ids = input_ids.cpu().tolist()
        target_ids = target_ids.cpu().tolist()
        ar_texts = []
        for letter_ids, tashkeel_ids in zip(input_ids, target_ids):
            letters = [self.letters[i] for i in letter_ids]
            letters = [c for c in letters if c not in special_tokens]

            tashkeel = self.filter_tashkeel([self.tashkeel_list[i] for i in tashkeel_ids])
            tashkeel = [t for t in tashkeel if t not in special_tokens]

            # VERY IMPORTANT NOTE: zip takes min(len(letters), len(tashkeel)) and discards
            # the rest of the letters / tashkeel
            bw_text = self.combine_tashkeel_with_text(zip(letters, tashkeel))
            ar_texts.append(bw2ar.transliterate_text(bw_text, 'bw2ar'))
        return ar_texts

    def get_tashkeel_with_case_ending(self, text, case_ending=True):
        """Split `text` and return its (letters, tashkeel) tuples.

        With ``case_ending=False`` the tashkeel of every element directly
        followed by a space is replaced with '<NT>'.
        NOTE(review): the element before '<EOS>' (last letter of the last
        word) is not followed by a space, so it keeps its tashkeel - confirm
        that this is intended.
        """
        text_split = self.split_tashkeel_from_text(text, test_match=False)
        # a set keeps the look-up below O(1) per element
        space_positions = {i for i, el in enumerate(text_split) if el == (' ', '<NT>')}
        new_text_split = []
        for i, el in enumerate(text_split):
            if not case_ending and (i + 1) in space_positions:
                el = (el[0], '<NT>') # no case ending
            new_text_split.append(el)
        letters, tashkeel = zip(*new_text_split)
        return letters, tashkeel

    def compute_der(self, ref, hyp, case_ending=True):
        """Return ``xer.wer`` computed over the space-joined tashkeel tag
        sequences of `ref` and `hyp` (each tag acts as one "word")."""
        _, ref_tashkeel = self.get_tashkeel_with_case_ending(ref, case_ending=case_ending)
        _, hyp_tashkeel = self.get_tashkeel_with_case_ending(hyp, case_ending=case_ending)
        return xer.wer(' '.join(ref_tashkeel), ' '.join(hyp_tashkeel))

    def compute_wer(self, ref, hyp, case_ending=True):
        """Return ``xer.wer`` computed over the re-combined (letters plus
        tashkeel) texts of `ref` and `hyp`."""
        ref_letters, ref_tashkeel = self.get_tashkeel_with_case_ending(ref, case_ending=case_ending)
        hyp_letters, hyp_tashkeel = self.get_tashkeel_with_case_ending(hyp, case_ending=case_ending)
        ref_text_combined = self.combine_tashkeel_with_text(zip(ref_letters, ref_tashkeel))
        hyp_text_combined = self.combine_tashkeel_with_text(zip(hyp_letters, hyp_tashkeel))
        return xer.wer(ref_text_combined, hyp_text_combined)

    def remove_tashkeel(self, text):
        """Return `text` with every character matched by HARAKAT_PAT removed,
        as well as every U+0671 (waSla) character."""
        # HARAKAT_PAT already contains the fatha, so no separate pass is needed for it.
        text = HARAKAT_PAT.sub('', text)
        # NOTE(review): this deletes the alef-waSla letter entirely, whereas
        # clean_text() maps it to a plain alef - confirm which one is intended.
        return text.replace(u'\u0671', '') # waSla



if __name__ == '__main__':
    # Sanity-check script: read every *.txt file of the training folder and count
    # how many lines survive the split / re-combine round trip of the tokenizer.
    import utils
    from tqdm import tqdm
    tokenizer = TashkeelTokenizer()

    txt_folder_path = 'dataset/train'
    prepared_lines = []
    for filepath in utils.get_files(txt_folder_path, '*.txt'):
        print(f'Reading file: {filepath}')
        # The files hold Arabic text: decode them explicitly as UTF-8 instead of
        # relying on the platform's default encoding.
        with open(filepath, encoding='utf-8') as f1:
            for line in f1:
                clean_line = tokenizer.clean_text(line)
                if clean_line != '':
                    prepared_lines.append(clean_line)
        print(f'completed file: {filepath}')

    good_sentences = []
    bad_sentences = []
    tokenized_sentences = []
    for line in tqdm(prepared_lines):
        try:
            letter_n_tashkeel_pairs = tokenizer.split_tashkeel_from_text(line, test_match=True)
        except AssertionError:
            # the pairs could not be re-combined into the original line
            bad_sentences.append(line)
        else:
            tokenized_sentences.append(letter_n_tashkeel_pairs)
            good_sentences.append(line)

    print('len(good_sentences), len(bad_sentences):', len(good_sentences), len(bad_sentences))