Spaces:
Running
Running
| import argparse | |
| import os | |
| from difflib import SequenceMatcher | |
| import Levenshtein | |
| import numpy as np | |
| from tqdm import tqdm | |
| from helpers import write_lines, read_parallel_lines, encode_verb_form, \ | |
| apply_reverse_transformation, SEQ_DELIMETERS, START_TOKEN | |
| def perfect_align(t, T, insertions_allowed=0, | |
| cost_function=Levenshtein.distance): | |
| # dp[i, j, k] is a minimal cost of matching first `i` tokens of `t` with | |
| # first `j` tokens of `T`, after making `k` insertions after last match of | |
| # token from `t`. In other words t[:i] aligned with T[:j]. | |
| # Initialize with INFINITY (unknown) | |
| shape = (len(t) + 1, len(T) + 1, insertions_allowed + 1) | |
| dp = np.ones(shape, dtype=int) * int(1e9) | |
| come_from = np.ones(shape, dtype=int) * int(1e9) | |
| come_from_ins = np.ones(shape, dtype=int) * int(1e9) | |
| dp[0, 0, 0] = 0 # The only known starting point. Nothing matched to nothing. | |
| for i in range(len(t) + 1): # Go inclusive | |
| for j in range(len(T) + 1): # Go inclusive | |
| for q in range(insertions_allowed + 1): # Go inclusive | |
| if i < len(t): | |
| # Given matched sequence of t[:i] and T[:j], match token | |
| # t[i] with following tokens T[j:k]. | |
| for k in range(j, len(T) + 1): | |
| transform = \ | |
| apply_transformation(t[i], ' '.join(T[j:k])) | |
| if transform: | |
| cost = 0 | |
| else: | |
| cost = cost_function(t[i], ' '.join(T[j:k])) | |
| current = dp[i, j, q] + cost | |
| if dp[i + 1, k, 0] > current: | |
| dp[i + 1, k, 0] = current | |
| come_from[i + 1, k, 0] = j | |
| come_from_ins[i + 1, k, 0] = q | |
| if q < insertions_allowed: | |
| # Given matched sequence of t[:i] and T[:j], create | |
| # insertion with following tokens T[j:k]. | |
| for k in range(j, len(T) + 1): | |
| cost = len(' '.join(T[j:k])) | |
| current = dp[i, j, q] + cost | |
| if dp[i, k, q + 1] > current: | |
| dp[i, k, q + 1] = current | |
| come_from[i, k, q + 1] = j | |
| come_from_ins[i, k, q + 1] = q | |
| # Solution is in the dp[len(t), len(T), *]. Backtracking from there. | |
| alignment = [] | |
| i = len(t) | |
| j = len(T) | |
| q = dp[i, j, :].argmin() | |
| while i > 0 or q > 0: | |
| is_insert = (come_from_ins[i, j, q] != q) and (q != 0) | |
| j, k, q = come_from[i, j, q], j, come_from_ins[i, j, q] | |
| if not is_insert: | |
| i -= 1 | |
| if is_insert: | |
| alignment.append(['INSERT', T[j:k], (i, i)]) | |
| else: | |
| alignment.append([f'REPLACE_{t[i]}', T[j:k], (i, i + 1)]) | |
| assert j == 0 | |
| return dp[len(t), len(T)].min(), list(reversed(alignment)) | |
| def _split(token): | |
| if not token: | |
| return [] | |
| parts = token.split() | |
| return parts or [token] | |
| def apply_merge_transformation(source_tokens, target_words, shift_idx): | |
| edits = [] | |
| if len(source_tokens) > 1 and len(target_words) == 1: | |
| # check merge | |
| transform = check_merge(source_tokens, target_words) | |
| if transform: | |
| for i in range(len(source_tokens) - 1): | |
| edits.append([(shift_idx + i, shift_idx + i + 1), transform]) | |
| return edits | |
| if len(source_tokens) == len(target_words) == 2: | |
| # check swap | |
| transform = check_swap(source_tokens, target_words) | |
| if transform: | |
| edits.append([(shift_idx, shift_idx + 1), transform]) | |
| return edits | |
| def is_sent_ok(sent, delimeters=SEQ_DELIMETERS): | |
| for del_val in delimeters.values(): | |
| if del_val in sent and del_val != delimeters["tokens"]: | |
| return False | |
| return True | |
| def check_casetype(source_token, target_token): | |
| if source_token.lower() != target_token.lower(): | |
| return None | |
| if source_token.lower() == target_token: | |
| return "$TRANSFORM_CASE_LOWER" | |
| elif source_token.capitalize() == target_token: | |
| return "$TRANSFORM_CASE_CAPITAL" | |
| elif source_token.upper() == target_token: | |
| return "$TRANSFORM_CASE_UPPER" | |
| elif source_token[1:].capitalize() == target_token[1:] and source_token[0] == target_token[0]: | |
| return "$TRANSFORM_CASE_CAPITAL_1" | |
| elif source_token[:-1].upper() == target_token[:-1] and source_token[-1] == target_token[-1]: | |
| return "$TRANSFORM_CASE_UPPER_-1" | |
| else: | |
| return None | |
| def check_equal(source_token, target_token): | |
| if source_token == target_token: | |
| return "$KEEP" | |
| else: | |
| return None | |
| def check_split(source_token, target_tokens): | |
| if source_token.split("-") == target_tokens: | |
| return "$TRANSFORM_SPLIT_HYPHEN" | |
| else: | |
| return None | |
| def check_merge(source_tokens, target_tokens): | |
| if "".join(source_tokens) == "".join(target_tokens): | |
| return "$MERGE_SPACE" | |
| elif "-".join(source_tokens) == "-".join(target_tokens): | |
| return "$MERGE_HYPHEN" | |
| else: | |
| return None | |
| def check_swap(source_tokens, target_tokens): | |
| if source_tokens == [x for x in reversed(target_tokens)]: | |
| return "$MERGE_SWAP" | |
| else: | |
| return None | |
| def check_plural(source_token, target_token): | |
| if source_token.endswith("s") and source_token[:-1] == target_token: | |
| return "$TRANSFORM_AGREEMENT_SINGULAR" | |
| elif target_token.endswith("s") and source_token == target_token[:-1]: | |
| return "$TRANSFORM_AGREEMENT_PLURAL" | |
| else: | |
| return None | |
| def check_verb(source_token, target_token): | |
| encoding = encode_verb_form(source_token, target_token) | |
| if encoding: | |
| return f"$TRANSFORM_VERB_{encoding}" | |
| else: | |
| return None | |
| def apply_transformation(source_token, target_token): | |
| target_tokens = target_token.split() | |
| if len(target_tokens) > 1: | |
| # check split | |
| transform = check_split(source_token, target_tokens) | |
| if transform: | |
| return transform | |
| checks = [check_equal, check_casetype, check_verb, check_plural] | |
| for check in checks: | |
| transform = check(source_token, target_token) | |
| if transform: | |
| return transform | |
| return None | |
| def align_sequences(source_sent, target_sent): | |
| # check if sent is OK | |
| if not is_sent_ok(source_sent) or not is_sent_ok(target_sent): | |
| return None | |
| source_tokens = source_sent.split() | |
| target_tokens = target_sent.split() | |
| matcher = SequenceMatcher(None, source_tokens, target_tokens) | |
| diffs = list(matcher.get_opcodes()) | |
| all_edits = [] | |
| for diff in diffs: | |
| tag, i1, i2, j1, j2 = diff | |
| source_part = _split(" ".join(source_tokens[i1:i2])) | |
| target_part = _split(" ".join(target_tokens[j1:j2])) | |
| if tag == 'equal': | |
| continue | |
| elif tag == 'delete': | |
| # delete all words separatly | |
| for j in range(i2 - i1): | |
| edit = [(i1 + j, i1 + j + 1), '$DELETE'] | |
| all_edits.append(edit) | |
| elif tag == 'insert': | |
| # append to the previous word | |
| for target_token in target_part: | |
| edit = ((i1 - 1, i1), f"$APPEND_{target_token}") | |
| all_edits.append(edit) | |
| else: | |
| # check merge first of all | |
| edits = apply_merge_transformation(source_part, target_part, | |
| shift_idx=i1) | |
| if edits: | |
| all_edits.extend(edits) | |
| continue | |
| # normalize alignments if need (make them singleton) | |
| _, alignments = perfect_align(source_part, target_part, | |
| insertions_allowed=0) | |
| for alignment in alignments: | |
| new_shift = alignment[2][0] | |
| edits = convert_alignments_into_edits(alignment, | |
| shift_idx=i1 + new_shift) | |
| all_edits.extend(edits) | |
| # get labels | |
| labels = convert_edits_into_labels(source_tokens, all_edits) | |
| # match tags to source tokens | |
| sent_with_tags = add_labels_to_the_tokens(source_tokens, labels) | |
| return sent_with_tags | |
| def convert_edits_into_labels(source_tokens, all_edits): | |
| # make sure that edits are flat | |
| flat_edits = [] | |
| for edit in all_edits: | |
| (start, end), edit_operations = edit | |
| if isinstance(edit_operations, list): | |
| for operation in edit_operations: | |
| new_edit = [(start, end), operation] | |
| flat_edits.append(new_edit) | |
| elif isinstance(edit_operations, str): | |
| flat_edits.append(edit) | |
| else: | |
| raise Exception("Unknown operation type") | |
| all_edits = flat_edits[:] | |
| labels = [] | |
| total_labels = len(source_tokens) + 1 | |
| if not all_edits: | |
| labels = [["$KEEP"] for x in range(total_labels)] | |
| else: | |
| for i in range(total_labels): | |
| edit_operations = [x[1] for x in all_edits if x[0][0] == i - 1 | |
| and x[0][1] == i] | |
| if not edit_operations: | |
| labels.append(["$KEEP"]) | |
| else: | |
| labels.append(edit_operations) | |
| return labels | |
| def convert_alignments_into_edits(alignment, shift_idx): | |
| edits = [] | |
| action, target_tokens, new_idx = alignment | |
| source_token = action.replace("REPLACE_", "") | |
| # check if delete | |
| if not target_tokens: | |
| edit = [(shift_idx, 1 + shift_idx), "$DELETE"] | |
| return [edit] | |
| # check splits | |
| for i in range(1, len(target_tokens)): | |
| target_token = " ".join(target_tokens[:i + 1]) | |
| transform = apply_transformation(source_token, target_token) | |
| if transform: | |
| edit = [(shift_idx, shift_idx + 1), transform] | |
| edits.append(edit) | |
| target_tokens = target_tokens[i + 1:] | |
| for target in target_tokens: | |
| edits.append([(shift_idx, shift_idx + 1), f"$APPEND_{target}"]) | |
| return edits | |
| transform_costs = [] | |
| transforms = [] | |
| for target_token in target_tokens: | |
| transform = apply_transformation(source_token, target_token) | |
| if transform: | |
| cost = 0 | |
| transforms.append(transform) | |
| else: | |
| cost = Levenshtein.distance(source_token, target_token) | |
| transforms.append(None) | |
| transform_costs.append(cost) | |
| min_cost_idx = transform_costs.index(min(transform_costs)) | |
| # append to the previous word | |
| for i in range(0, min_cost_idx): | |
| target = target_tokens[i] | |
| edit = [(shift_idx - 1, shift_idx), f"$APPEND_{target}"] | |
| edits.append(edit) | |
| # replace/transform target word | |
| transform = transforms[min_cost_idx] | |
| target = transform if transform is not None \ | |
| else f"$REPLACE_{target_tokens[min_cost_idx]}" | |
| edit = [(shift_idx, 1 + shift_idx), target] | |
| edits.append(edit) | |
| # append to this word | |
| for i in range(min_cost_idx + 1, len(target_tokens)): | |
| target = target_tokens[i] | |
| edit = [(shift_idx, 1 + shift_idx), f"$APPEND_{target}"] | |
| edits.append(edit) | |
| return edits | |
| def add_labels_to_the_tokens(source_tokens, labels, delimeters=SEQ_DELIMETERS): | |
| tokens_with_all_tags = [] | |
| source_tokens_with_start = [START_TOKEN] + source_tokens | |
| for token, label_list in zip(source_tokens_with_start, labels): | |
| all_tags = delimeters['operations'].join(label_list) | |
| comb_record = token + delimeters['labels'] + all_tags | |
| tokens_with_all_tags.append(comb_record) | |
| return delimeters['tokens'].join(tokens_with_all_tags) | |
| def convert_data_from_raw_files(source_file, target_file, output_file, chunk_size): | |
| tagged = [] | |
| source_data, target_data = read_parallel_lines(source_file, target_file) | |
| print(f"The size of raw dataset is {len(source_data)}") | |
| cnt_total, cnt_all, cnt_tp = 0, 0, 0 | |
| for source_sent, target_sent in tqdm(zip(source_data, target_data)): | |
| try: | |
| aligned_sent = align_sequences(source_sent, target_sent) | |
| except Exception: | |
| aligned_sent = align_sequences(source_sent, target_sent) | |
| if source_sent != target_sent: | |
| cnt_tp += 1 | |
| alignments = [aligned_sent] | |
| cnt_all += len(alignments) | |
| try: | |
| check_sent = convert_tagged_line(aligned_sent) | |
| except Exception: | |
| # debug mode | |
| aligned_sent = align_sequences(source_sent, target_sent) | |
| check_sent = convert_tagged_line(aligned_sent) | |
| if "".join(check_sent.split()) != "".join( | |
| target_sent.split()): | |
| # do it again for debugging | |
| aligned_sent = align_sequences(source_sent, target_sent) | |
| check_sent = convert_tagged_line(aligned_sent) | |
| print(f"Incorrect pair: \n{target_sent}\n{check_sent}") | |
| continue | |
| if alignments: | |
| cnt_total += len(alignments) | |
| tagged.extend(alignments) | |
| if len(tagged) > chunk_size: | |
| write_lines(output_file, tagged, mode='a') | |
| tagged = [] | |
| print(f"Overall extracted {cnt_total}. " | |
| f"Original TP {cnt_tp}." | |
| f" Original TN {cnt_all - cnt_tp}") | |
| if tagged: | |
| write_lines(output_file, tagged, 'a') | |
| def convert_labels_into_edits(labels): | |
| all_edits = [] | |
| for i, label_list in enumerate(labels): | |
| if label_list == ["$KEEP"]: | |
| continue | |
| else: | |
| edit = [(i - 1, i), label_list] | |
| all_edits.append(edit) | |
| return all_edits | |
| def get_target_sent_by_levels(source_tokens, labels): | |
| relevant_edits = convert_labels_into_edits(labels) | |
| target_tokens = source_tokens[:] | |
| leveled_target_tokens = {} | |
| if not relevant_edits: | |
| target_sentence = " ".join(target_tokens) | |
| return leveled_target_tokens, target_sentence | |
| max_level = max([len(x[1]) for x in relevant_edits]) | |
| for level in range(max_level): | |
| rest_edits = [] | |
| shift_idx = 0 | |
| for edits in relevant_edits: | |
| (start, end), label_list = edits | |
| label = label_list[0] | |
| target_pos = start + shift_idx | |
| source_token = target_tokens[target_pos] if target_pos >= 0 else START_TOKEN | |
| if label == "$DELETE": | |
| del target_tokens[target_pos] | |
| shift_idx -= 1 | |
| elif label.startswith("$APPEND_"): | |
| word = label.replace("$APPEND_", "") | |
| target_tokens[target_pos + 1: target_pos + 1] = [word] | |
| shift_idx += 1 | |
| elif label.startswith("$REPLACE_"): | |
| word = label.replace("$REPLACE_", "") | |
| target_tokens[target_pos] = word | |
| elif label.startswith("$TRANSFORM"): | |
| word = apply_reverse_transformation(source_token, label) | |
| if word is None: | |
| word = source_token | |
| target_tokens[target_pos] = word | |
| elif label.startswith("$MERGE_"): | |
| # apply merge only on last stage | |
| if level == (max_level - 1): | |
| target_tokens[target_pos + 1: target_pos + 1] = [label] | |
| shift_idx += 1 | |
| else: | |
| rest_edit = [(start + shift_idx, end + shift_idx), [label]] | |
| rest_edits.append(rest_edit) | |
| rest_labels = label_list[1:] | |
| if rest_labels: | |
| rest_edit = [(start + shift_idx, end + shift_idx), rest_labels] | |
| rest_edits.append(rest_edit) | |
| leveled_tokens = target_tokens[:] | |
| # update next step | |
| relevant_edits = rest_edits[:] | |
| if level == (max_level - 1): | |
| leveled_tokens = replace_merge_transforms(leveled_tokens) | |
| leveled_labels = convert_edits_into_labels(leveled_tokens, | |
| relevant_edits) | |
| leveled_target_tokens[level + 1] = {"tokens": leveled_tokens, | |
| "labels": leveled_labels} | |
| target_sentence = " ".join(leveled_target_tokens[max_level]["tokens"]) | |
| return leveled_target_tokens, target_sentence | |
| def replace_merge_transforms(tokens): | |
| if all(not x.startswith("$MERGE_") for x in tokens): | |
| return tokens | |
| target_tokens = tokens[:] | |
| allowed_range = (1, len(tokens) - 1) | |
| for i in range(len(tokens)): | |
| target_token = tokens[i] | |
| if target_token.startswith("$MERGE"): | |
| if target_token.startswith("$MERGE_SWAP") and i in allowed_range: | |
| target_tokens[i - 1] = tokens[i + 1] | |
| target_tokens[i + 1] = tokens[i - 1] | |
| target_tokens[i: i + 1] = [] | |
| target_line = " ".join(target_tokens) | |
| target_line = target_line.replace(" $MERGE_HYPHEN ", "-") | |
| target_line = target_line.replace(" $MERGE_SPACE ", "") | |
| return target_line.split() | |
| def convert_tagged_line(line, delimeters=SEQ_DELIMETERS): | |
| label_del = delimeters['labels'] | |
| source_tokens = [x.split(label_del)[0] | |
| for x in line.split(delimeters['tokens'])][1:] | |
| labels = [x.split(label_del)[1].split(delimeters['operations']) | |
| for x in line.split(delimeters['tokens'])] | |
| assert len(source_tokens) + 1 == len(labels) | |
| levels_dict, target_line = get_target_sent_by_levels(source_tokens, labels) | |
| return target_line | |
| def main(args): | |
| convert_data_from_raw_files(args.source, args.target, args.output_file, args.chunk_size) | |
| if __name__ == '__main__': | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument('-s', '--source', | |
| help='Path to the source file', | |
| required=True) | |
| parser.add_argument('-t', '--target', | |
| help='Path to the target file', | |
| required=True) | |
| parser.add_argument('-o', '--output_file', | |
| help='Path to the output file', | |
| required=True) | |
| parser.add_argument('--chunk_size', | |
| type=int, | |
| help='Dump each chunk size.', | |
| default=1000000) | |
| args = parser.parse_args() | |
| main(args) | |