import gc
from typing import Optional, Iterator, Callable

import torch
from datasets import load_dataset
from litgpt.tokenizer import Tokenizer
from transformers import PreTrainedTokenizerBase


def batch_text_iterator(kind: str,
                        path: str,
                        name: Optional[str]=None,
                        data_dir: Optional[str]=None,
                        data_files: Optional[str]=None,
                        keep_in_memory: bool=False,
                        revision: Optional[str]=None,
                        split: str='train',
                        num_proc: Optional[int]=None,
                        format: Optional[Callable|str]=None) -> Iterator[str]:
    # despite the Optional default, format is required: either a str template
    # applied as format.format(**row) or a callable mapping a row dict to text
    assert isinstance(format, str) or callable(format), f'{path=} {format=}'
    assert kind == 'base'

    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)

    for row in dataset:
        text = format(row) if callable(format) else format.format(**row)

        if not text:
            continue

        yield text

    # drop the dataset reference and force a collection so large cached
    # tables can be reclaimed between iterators
    del dataset
    gc.collect()
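
# Example usage (a sketch: the dataset path, config name, and '{text}'
# template below are illustrative assumptions, not pinned by this module):
#
#   for text in batch_text_iterator(kind='base',
#                                   path='wikimedia/wikipedia',
#                                   name='20231101.en',
#                                   split='train',
#                                   format='{text}'):
#       print(text[:80])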


def batch_chat_iterator(kind: str,
                        path: str,
                        name: Optional[str]=None,
                        data_dir: Optional[str]=None,
                        data_files: Optional[str]=None,
                        keep_in_memory: bool=False,
                        revision: Optional[str]=None,
                        split: str='train',
                        num_proc: Optional[int]=None,
                        field: Optional[str]=None,
                        transform: Optional[Callable]=None) -> Iterator[list[dict[str, str]]]:
    assert kind == 'instruct'

    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)

    if not callable(transform) and not field:
        raise ValueError('either a transform callable or a field name is required')

    for row in dataset:
        if callable(transform):
            messages = transform(row[field] if field else row)
        else:
            messages = row[field]

        if not messages:
            continue

        yield messages

    del dataset
    gc.collect()
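
# Example usage (a sketch: 'HuggingFaceH4/ultrachat_200k' and its 'messages'
# column are assumptions about one particular dataset layout):
#
#   for messages in batch_chat_iterator(kind='instruct',
#                                       path='HuggingFaceH4/ultrachat_200k',
#                                       split='train_sft',
#                                       field='messages'):
#       ...  # each item is a list of {'role': ..., 'content': ...} dicts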


# NOTE: used only by tokenizer trainer
def batch_dataset_iterator(dataset_config: dict) -> Iterator[str]:
    if dataset_config['kind'] == 'base':
        yield from batch_text_iterator(**dataset_config)
    elif dataset_config['kind'] == 'instruct':
        for messages in batch_chat_iterator(**dataset_config):
            # flatten a conversation into plain text, one message body per line
            text = '\n'.join(message['content'] for message in messages)
            yield text
    else:
        raise ValueError(dataset_config['kind'])
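
# Sketch of how a tokenizer trainer might consume this iterator; the base
# checkpoint, the train_new_from_iterator call, and the vocab size are
# assumptions about the surrounding training script, not part of this module:
#
#   from transformers import AutoTokenizer
#
#   base_tokenizer = AutoTokenizer.from_pretrained('gpt2')
#   new_tokenizer = base_tokenizer.train_new_from_iterator(
#       batch_dataset_iterator({'kind': 'base',
#                               'path': 'wikimedia/wikipedia',
#                               'name': '20231101.en',
#                               'format': '{text}'}),
#       vocab_size=32_000,
#   )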


def tokenize_text_fn(dataset_config: dict, hf_tokenizer: PreTrainedTokenizerBase, tokenizer: Tokenizer) -> Iterator[torch.Tensor]:
    # hf_tokenizer is unused here; it is accepted only to mirror tokenize_chat_fn's signature
    for text in batch_text_iterator(**dataset_config):
        text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=True)
        yield text_ids


def tokenize_chat_fn(dataset_config: dict, hf_tokenizer: PreTrainedTokenizerBase, tokenizer: Tokenizer) -> Iterator[torch.Tensor]:
    for messages in batch_chat_iterator(**dataset_config):
        # the chat template already injects the model's special tokens,
        # hence bos=False, eos=False on the litgpt encode below
        text: str = hf_tokenizer.apply_chat_template(messages, tokenize=False)
        text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=False)
        yield text_ids
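
# Sketch: pairing the HF chat template with the litgpt tokenizer (the model
# name and checkpoint directory are illustrative assumptions):
#
#   from transformers import AutoTokenizer
#   from litgpt.tokenizer import Tokenizer
#
#   hf_tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2')
#   tokenizer = Tokenizer('checkpoints/mistralai/Mistral-7B-Instruct-v0.2')
#   config = {'kind': 'instruct', 'path': 'HuggingFaceH4/ultrachat_200k',
#             'split': 'train_sft', 'field': 'messages'}
#
#   for text_ids in tokenize_chat_fn(config, hf_tokenizer, tokenizer):
#       ...  # 1-D torch.Tensor of token ids per conversation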


def tokenize_fn(dataset_config: dict, min_len: int, max_len: int, hf_tokenizer: PreTrainedTokenizerBase, tokenizer: Tokenizer) -> Iterator[torch.Tensor]:
    if dataset_config['kind'] == 'base':
        for text in batch_text_iterator(**dataset_config):
            try:
                text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=True)
            except Exception as e:
                print(f'Skip base row: {e=} {type(text)=} {text=}')
                continue

            if min_len <= len(text_ids) <= max_len:
                yield text_ids
    elif dataset_config['kind'] == 'instruct':
        for messages in batch_chat_iterator(**dataset_config):
            try:
                text: str = hf_tokenizer.apply_chat_template(messages, tokenize=False)
                text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=False)
            except Exception as e:
                print(f'Skip instruct row: {e=} {type(messages)=} {messages=}')
                continue

            if min_len <= len(text_ids) <= max_len:
                yield text_ids
    else:
        raise ValueError(dataset_config['kind'])
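
# Sketch: one preprocessing loop over mixed base/instruct configs with a
# length window (the bounds and configs are illustrative assumptions):
#
#   configs = [
#       {'kind': 'base', 'path': 'wikimedia/wikipedia',
#        'name': '20231101.en', 'format': '{text}'},
#       {'kind': 'instruct', 'path': 'HuggingFaceH4/ultrachat_200k',
#        'split': 'train_sft', 'field': 'messages'},
#   ]
#
#   for config in configs:
#       for text_ids in tokenize_fn(config, min_len=64, max_len=4096,
#                                   hf_tokenizer=hf_tokenizer, tokenizer=tokenizer):
#           ...  # e.g. append to a packed pretraining dataset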