from typing import Optional, Iterator, Callable, Any

import torch
from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer


def load_text_dataset(tokenizer: AutoTokenizer,
                      kind: str,
                      path: str,
                      name: Optional[str]=None,
                      data_dir: Optional[str]=None,
                      data_files: Optional[str]=None,
                      keep_in_memory: bool=False,
                      revision: Optional[str]=None,
                      split: str='train',
                      num_proc: Optional[int]=None,
                      format: Optional[Callable|str]=None) -> Any:
    """Load a Hugging Face dataset and render every row into a ``text`` column.

    Each row is rendered to a string by ``format``, an empty rendering is
    replaced by ``'[NONE]'``, and the tokenizer's EOS token is appended.

    Args:
        tokenizer: Tokenizer whose ``eos_token`` terminates every text.
        kind: Dataset kind; only ``'base'`` is accepted.
        path, name, data_dir, data_files, keep_in_memory, revision, split,
        num_proc: Forwarded unchanged to ``datasets.load_dataset``
            (``trust_remote_code=True`` is always set).
        format: Either a callable ``row -> str`` or a ``str.format`` template
            whose placeholders are column names. Despite the ``None`` default
            it is required.

    Returns:
        The result of ``dataset.map(...)``: the loaded split with a ``text``
        column added (or overwritten).

    Raises:
        AssertionError: If ``format`` is neither a string nor callable, or if
            ``kind`` is not ``'base'``.
        ValueError: If the tokenizer has no EOS token.
    """
    assert isinstance(format, str) or callable(format), f'{path=} {format=}'
    assert kind == 'base'

    eos_token = tokenizer.eos_token
    if eos_token is None:
        # Fail before the (possibly expensive) download rather than with a
        # ``str + None`` TypeError from inside ``Dataset.map``.
        raise ValueError(f'tokenizer has no eos_token; cannot terminate texts for {path=}')

    # Pick the row renderer once instead of branching on every batch.
    render: Callable[[dict], Any] = (
        format if callable(format) else (lambda row: format.format(**row))
    )

    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)

    def format_dataset(batch):
        # Batched ``map`` hands over columns ({column: [values...]});
        # transpose them back into one dict per row.
        rows = [dict(zip(batch.keys(), values)) for values in zip(*batch.values())]
        # Empty/falsy renderings become a visible placeholder, never a blank sample.
        texts = [(render(row) or '[NONE]') + eos_token for row in rows]
        return {'text': texts}

    return dataset.map(format_dataset, batched=True)


def load_chat_dataset(tokenizer: AutoTokenizer,
                      kind: str,
                      path: str,
                      name: Optional[str]=None,
                      data_dir: Optional[str]=None,
                      data_files: Optional[str]=None,
                      keep_in_memory: bool=False,
                      revision: Optional[str]=None,
                      split: str='train',
                      num_proc: Optional[int]=None,
                      field: Optional[str]=None,
                      transform: Optional[Callable]=None) -> Any:
    """Load a chat dataset and render each conversation into a ``text`` column.

    For every row, a messages object is obtained and passed to
    ``tokenizer.apply_chat_template(..., tokenize=False)``. The tokenizer's
    EOS token is then appended. The messages object comes from:

    * ``transform(row[field])`` if ``transform`` is callable and ``field`` is set,
    * ``transform(row)`` if ``transform`` is callable and ``field`` is falsy,
    * ``row[field]`` if there is no callable ``transform``.

    Args:
        tokenizer: Tokenizer providing the chat template and ``eos_token``.
        kind: Dataset kind; only ``'instruct'`` is accepted.
        path, name, data_dir, data_files, keep_in_memory, revision, split,
        num_proc: Forwarded unchanged to ``datasets.load_dataset``
            (``trust_remote_code=True`` is always set).
        field: Column holding the conversation (or the transform's input).
        transform: Optional callable producing the messages for a row.

    Returns:
        The result of ``dataset.map(...)``: the loaded split with a ``text``
        column added (or overwritten).

    Raises:
        AssertionError: If ``kind`` is not ``'instruct'``.
        ValueError: If there is neither a callable ``transform`` nor a
            ``field``, or if the tokenizer has no EOS token.
    """
    assert kind == 'instruct'

    use_transform = callable(transform)
    if not use_transform and not field:
        # Previously this surfaced only from inside ``Dataset.map`` after the
        # download, and never for an empty split; validate up front instead.
        raise ValueError(
            f'{field=}: a message column (field) is required when no callable transform is given')

    eos_token = tokenizer.eos_token
    if eos_token is None:
        raise ValueError(f'tokenizer has no eos_token; cannot terminate texts for {path=}')

    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)

    def to_messages(row: dict) -> Any:
        # Select the column (if any), then optionally run the user transform.
        source = row[field] if field else row
        return transform(source) if use_transform else source

    def format_dataset(batch):
        # Batched ``map`` hands over columns; transpose into per-row dicts.
        rows = [dict(zip(batch.keys(), values)) for values in zip(*batch.values())]
        texts = [
            tokenizer.apply_chat_template(to_messages(row), tokenize=False) + eos_token
            for row in rows
        ]
        return {'text': texts}

    return dataset.map(format_dataset, batched=True)