#!/usr/bin/env python
# coding=utf-8
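"""Build a 5-gram KenLM language model from Norwegian text.

The script loads and normalizes text from the NbAiLab/NPSC and NbAiLab/NCC
datasets, trains an ARPA model with KenLM's lmplz, patches the ARPA file so it
works with Hugging Face decoding, and converts it to KenLM's binary format.
"""
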
import glob
import os
import re
import subprocess
from tqdm import tqdm

from datasets import load_dataset, concatenate_datasets

TEXT_COLUMN_NAME = "text"
AUDIO_COLUMN_NAME = "audio"
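# Punctuation, quotes, special characters and digits stripped from the text
# before language-model training.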
CHARS_TO_IGNORE_REGEX = r'[\,\?\.\!\-\;\:\"\“\%\‘\”\�\'\–\_\\\+\#\/0-9]'

# Pre-processing dataset
def replace_hatted_characters(batch):
    text = batch["text"]
    text = re.sub(CHARS_TO_IGNORE_REGEX, '', text).lower() + ' '
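    # Fold accented characters onto their Norwegian or base-letter equivalents.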
    text = re.sub('[áàâ]', 'a', text)
    text = re.sub('[ä]', 'æ', text)
    text = re.sub('[éèëê]', 'e', text)
    text = re.sub('[íìïî]', 'i', text)
    text = re.sub('[óòô]', 'o', text)
    text = re.sub('[ö]', 'ø', text)
    text = re.sub('[ç]', 'c', text)
    text = re.sub('[úùüû]', 'u', text)
    text = re.sub('\xa0', ' ', text)
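    # Replace the hesitation and noise markers used in the transcripts
    # (<ee>, <qq>, <mm>, <inaudible>) with plain letter tokens.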
    text = re.sub('<ee>', 'eee', text)
    text = re.sub('<qq>', 'qqq', text)
    text = re.sub('<mm>', 'mmm', text)
    text = re.sub('<inaudible>', 'xxx', text)
    text = re.sub('[<>]', '', text)
    text = re.sub(r'\s+', ' ', text)
    return {"text": text}


def main():
    npsc = load_dataset(
        "NbAiLab/NPSC",
        "16K_mp3",
        split="train+validation",
        use_auth_token=True,
    )
    ncc = load_dataset(
        "NbAiLab/NCC",
        split="train+validation",
        use_auth_token=True
    )
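    # Combine both corpora into a single text dataset for language-model training.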
    dataset = concatenate_datasets([npsc, ncc])
    dataset = dataset.map(
        replace_hatted_characters,
        desc="replacing hesitations and homophones",
    )

    # Create file with all text together
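    # lmplz trains on a plain text file, so the normalized corpus is written
    # out as one whitespace-separated stream.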
    text_count = len(dataset)
    with open("text.txt", "w") as text_file:
        for idx, text in tqdm(enumerate(dataset["text"]), total=text_count, desc="Writing text"):
            if idx == text_count - 1:
                text_file.write(text)
            else:
                text_file.write(text + " ")

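    # KenLM's lmplz (expected at ~/bin/lmplz) builds an n-gram language model in
    # ARPA format: -o 5 selects a 5-gram model, --text/--arpa are the input and
    # output files, and -T points the temporary files at the current directory.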
    # Create KenLM model
    subprocess.run([os.path.expanduser("~/bin/lmplz"), "-o", "5", "--text", "text.txt",
                    "--arpa", "5gram.arpa.orig", "-T", os.getcwd()], check=True)

    # Adjusting for Huggingface decoding
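    # lmplz writes a begin-of-sentence <s> unigram but no end-of-sentence </s>
    # unigram, which the Hugging Face / pyctcdecode decoder expects. Duplicate
    # the <s> entry as </s> and increase the "ngram 1=" count by one.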
    with open("5gram.arpa.orig", "r") as read_file, open("5gram.arpa", "w") as write_file:
        has_added_eos = False
        for line in read_file:
          if not has_added_eos and "ngram 1=" in line:
            count=line.strip().split("=")[-1]
            write_file.write(line.replace(f"{count}", f"{int(count)+1}"))
          elif not has_added_eos and "<s>" in line:
            write_file.write(line)
            write_file.write(line.replace("<s>", "</s>"))
            has_added_eos = True
          else:
            write_file.write(line)

    # Compress as binary
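    # build_binary converts the ARPA file to KenLM's binary format, which loads
    # faster and takes less disk space; the intermediate ARPA files are removed.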
    subprocess.run([os.path.expanduser("~/bin/build_binary"), "5gram.arpa", "5gram.bin",
                    "-T", os.getcwd()], check=True)
    for arpa_file in glob.glob("5gram.arpa*"):
        os.remove(arpa_file)


if __name__ == "__main__":
    main()