import json
import os
import random
import shutil
from datetime import datetime

import langdetect
import nltk
import pandas as pd
from langdetect import DetectorFactory, LangDetectException

# Input JSON file with the crash events; loaded by process_events()
DATA_FILE = "data/crashes/thecrashes_data_all_text.json"

# Each event lands in the dev split when random.random() < DEV_PORTION, else in the main split
DEV_PORTION = .10

# Fixed seeds: random drives the dev/main split; seeding DetectorFactory is
# langdetect's documented way of getting reproducible detection results
random.seed(2001)
DetectorFactory.seed = 0


def is_a_real_time(timestamp):
    """Helper function, tells whether a timestamp carries an actual time of day"""

    # exactly 00:00:00 (midnight) is the "empty" timestamp, so it does not count;
    # any non-zero hour, minute or second makes it a real time
    clock_fields = (timestamp.hour, timestamp.minute, timestamp.second)
    return any(field != 0 for field in clock_fields)


def main():
    """Script entry point: runs process_events()"""
    process_events()


def detect_language(article):
    """Detect the language of an article with langdetect

    The text sample is the first non-empty field among "alltext", "summary"
    and "title" (in that order).

    :param article: article dict with the keys "id", "alltext", "summary" and "title"
    :return: the language code returned by langdetect.detect, or "UNK_LANG" when
        there is no text to look at or langdetect cannot make a decision
    """
    sample = article["alltext"] or article["summary"] or article["title"]

    # only call langdetect when there is text: a missing (None) sample would
    # raise a TypeError inside langdetect instead of a LangDetectException
    if sample:
        try:
            return langdetect.detect(sample)
        except LangDetectException:
            # reported below, together with the "no text at all" case
            pass

    print(f"\tCould not detect language for text_id={article['id']}")
    print(f"\tSample={sample}")
    print()
    return "UNK_LANG"


def extract_text_info(event):
    """Collect sentence lines, line identifiers and metadata for all articles of an event

    :param event: event dict with an "id" and a list of "articles"; every article
        is read through the keys "id", "publishedtime", "url", "sitename", "title",
        "summary" and "alltext"
    :return: a tuple (ev_text_lines, ev_id_lines, ev_meta_rows):
        - ev_text_lines: per article, the lines produced by segment() for the
          title, the summary and the body (in that order)
        - ev_id_lines: per article, one tab-separated identifier per text line
          ("event <id>", "text <id>", section name)
        - ev_meta_rows: per article, a dict with the metadata columns
    """
    event_id = event["id"]
    ev_text_lines = []
    ev_id_lines = []
    ev_meta_rows = []

    for article in event["articles"]:
        text_id = article["id"]
        try:
            pubdate = datetime.fromisoformat(article["publishedtime"]).strftime("%Y-%m-%d %H:%M:%S")
        except (TypeError, ValueError):
            # ValueError: not an ISO date string; TypeError: not a string at all (e.g. null in the JSON)
            print(f"\t\tcould not parse date {article['publishedtime']}")
            pubdate = None
        title = article["title"]
        language = detect_language(article)
        ev_meta_rows.append({
            "event_id": event_id,
            "text_id": text_id,
            "pubdate": pubdate,
            "language": language,
            "url": article["url"],
            "provider": article["sitename"],
            "title": title
        })

        # section label (as written to the id lines) -> raw text of that section
        sections = (
            ("title", title),
            ("summary", article["summary"]),
            ("body", article["alltext"]),
        )

        text_lines = []
        id_lines = []
        for section, section_text in sections:
            # every text line gets a matching id line, so both lists stay aligned
            for line in segment(section_text, language):
                text_lines.append(line)
                id_lines.append(f"event {event_id}\ttext {text_id}\t{section}")

        ev_text_lines.append(text_lines)
        ev_id_lines.append(id_lines)

    return ev_text_lines, ev_id_lines, ev_meta_rows


def segment(text, language):
    """Split a text into a list of sentences

    :param text: the text to split; may be empty or None
    :param language: language code of the text (as produced by detect_language)
    :return: list of sentence strings; an empty list for empty/None text, and a
        single-element list for languages that are deliberately left unsplit
    """
    # nothing to segment (also covers None, e.g. a missing summary)
    if not text:
        return []

    # don't split Hebrew and Vietnamese (because we don't have a segmenter for it);
    # wrap the text in a list so that callers looping over the result get one
    # line rather than one character at a time
    if language in ("he", "vi"):
        return [text]

    lang_map = {
        "nl": "dutch",
        "en": "english",
        "es": "spanish",
        "de": "german",
        "fr": "french",
        "ru": "russian",
        "pt": "portuguese",
        # no Afrikaans sent tokenizer in NLTK: treat Afrikaans as Dutch
        "af": "dutch",
    }

    nltk_lang = lang_map.get(language)

    # what to do with the remaining languages without sent tokenizer in NLTK:
    if not nltk_lang:
        print(f"Found an article with unsupported language={language}, falling back to English NLTK")
        nltk_lang = "english"

    return nltk.sent_tokenize(text, nltk_lang)


def write_to_text_by_event(text_lines, text_meta_lines, event_id, split_to_dir, split):
    """Write one text file per article into the event's directory of a split

    Files are written to <split_to_dir[split]>/<event_id>/<text_id>.txt, one
    line of text per file line.

    :param text_lines: per article, the list of its text lines
    :param text_meta_lines: per article, a dict with at least "text_id"; paired
        with text_lines by position
    :param event_id: id of the event, used as directory name
    :param split_to_dir: mapping from split name to its base directory
    :param split: name of the split to write to (a key of split_to_dir)
    """
    event_dir = f"{split_to_dir[split]}/{event_id}"
    os.makedirs(event_dir, exist_ok=True)
    for art_lines, row in zip(text_lines, text_meta_lines):
        text_file = f"{event_dir}/{row['text_id']}.txt"
        with open(text_file, "w", encoding="utf-8") as f:
            # always write "\n": text mode already translates it to the platform's
            # line ending, writing os.linesep would give "\r\r\n" on Windows
            f.writelines(line + "\n" for line in art_lines)


def process_events():
    """Convert the crash data file into text files and CSV tables, split into dev/main

    Reads the events from DATA_FILE and, for the splits "all", "dev" and "main",
    writes below output/crashes/split_data:
    - <prefix>.texts.text.txt / <prefix>.texts.ids.txt: all text lines and their
      identifier lines
    - <prefix>_texts_by_event/<event_id>/<text_id>.txt: one file per article
    - <prefix>.events.csv: one row per event (date, time, coordinates, outcomes)
    - <prefix>.texts.meta.csv: one row per article (metadata)
    where <prefix> is "all", "split_dev10" or "split_main".

    Every event goes to "all"; in addition it goes to "dev" with probability
    DEV_PORTION and to "main" otherwise (one random.random() draw per event).
    """
    out_dir = "output/crashes/split_data"
    # split name -> prefix of all output files/directories of that split
    split_prefixes = {"all": "all", "dev": "split_dev10", "main": "split_main"}

    print("Loading data file...")
    with open(DATA_FILE, encoding="utf-8") as f:
        data = json.load(f)

    event_rows = {split: [] for split in split_prefixes}
    text_rows = {split: [] for split in split_prefixes}

    # the output directory might not exist yet (e.g. on a first run)
    os.makedirs(out_dir, exist_ok=True)

    # make empty text files
    text_file_basenames = {split: f"{out_dir}/{prefix}.texts" for split, prefix in split_prefixes.items()}
    for bn in text_file_basenames.values():
        for ext in (".text.txt", ".ids.txt"):
            with open(f"{bn}{ext}", "w", encoding="utf-8"):
                pass

    # clear & make text file directories
    text_files_by_event_dir = {}
    for split, prefix in split_prefixes.items():
        text_dir = f"{out_dir}/{prefix}_texts_by_event"
        text_files_by_event_dir[split] = text_dir
        if os.path.exists(text_dir):
            shutil.rmtree(text_dir)
        os.mkdir(text_dir)

    # helper function for writing text files
    def append_to_txt(txt_file, lines):
        with open(txt_file, "a", encoding="utf-8") as f_out:
            for art_lines in lines:
                # "\n" rather than os.linesep: text mode already translates it to
                # the platform's line ending (os.linesep would give "\r\r\n" on Windows)
                f_out.writelines(line + "\n" for line in art_lines)

    # helper function for adding the results of one event to one split
    def add_to_split(split, event_id, event_row, text_lines, text_id_lines, text_meta_rows):
        event_rows[split].append(event_row)
        text_rows[split].extend(text_meta_rows)
        append_to_txt(text_file_basenames[split] + ".text.txt", text_lines)
        append_to_txt(text_file_basenames[split] + ".ids.txt", text_id_lines)
        write_to_text_by_event(text_lines, text_meta_rows, event_id, text_files_by_event_dir, split)

    # transportation mode codes that are counted as "vehicle"
    vehicle_modes = range(5, 14)

    print("Processing events...")
    for event in data:
        event_id = event["id"]
        print(f"\tevent_id={event_id}")
        try:
            timestamp = datetime.fromisoformat(event["date"])
        except (TypeError, ValueError):
            # ValueError: not an ISO date string; TypeError: not a string at all (e.g. null in the JSON)
            timestamp = None

        persons = event["persons"]
        event_row = {
            "event:id": event_id,
            "event:date": timestamp.strftime("%Y-%m-%d") if timestamp else None,
            # a midnight timestamp means that no time was recorded
            "event:time": timestamp.strftime("%H-%M-%S") if timestamp and is_a_real_time(timestamp) else None,
            "event:coordinates": f"{event['latitude'], event['longitude']}",
            "vehicle_involved": int(any(p["transportationmode"] in vehicle_modes for p in persons))
        }

        # count the persons per outcome (health code), in total and per subgroup
        for health, health_code in (("dead", 3), ("injured", 2)):
            all_with_health = [p for p in persons if p["health"] == health_code]
            event_row[f"outcomes:{health}:total"] = len(all_with_health)
            event_row[f"outcomes:{health}:child"] = len([p for p in all_with_health if p["child"] == 1])
            for mode, mode_codes in (("pedestrian", [1]), ("cyclist", [2]), ("vehicle", vehicle_modes)):
                event_row[f"outcomes:{health}:{mode}"] = len([p for p in all_with_health
                                                              if p["transportationmode"] in mode_codes])

        text_lines, text_id_lines, text_meta_rows = extract_text_info(event)

        add_to_split("all", event_id, event_row, text_lines, text_id_lines, text_meta_rows)

        # exactly one random draw per event decides between dev and main
        split = "dev" if random.random() < DEV_PORTION else "main"
        add_to_split(split, event_id, event_row, text_lines, text_id_lines, text_meta_rows)

    for split, prefix in split_prefixes.items():
        pd.DataFrame(event_rows[split]).to_csv(f"{out_dir}/{prefix}.events.csv")
        pd.DataFrame(text_rows[split]).to_csv(f"{out_dir}/{prefix}.texts.meta.csv")


if __name__ == '__main__':
    main()