import os
import shutil

from transformers import PreTrainedTokenizerFast
from tokenizers import Tokenizer, pre_tokenizers, processors, decoders
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer

from utils import batch_dataset_iterator
from base_datasets import base_datasets
from base_instruct_datasets import base_instruct_datasets


tokenizer_path = '../tokenizer'

if os.path.exists(tokenizer_path):
    shutil.rmtree(tokenizer_path)

os.makedirs(tokenizer_path, exist_ok=True)

#
# special_tokens
#
bos_token = '<|endoftext|>'  # sequence start
eos_token = '<|im_end|>'     # end of a chat turn (see CHAT_TEMPLATE below)
pad_token = '<|pad|>'
unk_token = '<|unk|>'        # reserved; rarely produced by a byte-level BPE

special_tokens = [
    bos_token,
    eos_token,
    pad_token,
    unk_token,
    '<|im_start|>',
    '<|im_sep|>',
    'system',
    'user',
    'assistant',
    '<tools>',
    '</tools>',
    '<tool>',
    '</tool>',
    '<tool_call>',
    '</tool_call>',
    '<tool_response>',
    '</tool_response>',
    '<question>',
    '</question>',
    '<think>',
    '</think>',
    '<answer>',
    '</answer>',
]

# Pad the list out to exactly 64 special tokens with reserved placeholders,
# so token ids stay stable if more special tokens are added later.
for i in range(64 - len(special_tokens)):
    special_tokens.append(f'<|reserved_{i}|>')

#
# BPE Tokenizer
#
# BPE model without a dedicated unk token; byte_fallback lets unknown
# characters decompose into raw byte tokens instead of <unk>.
bpe = BPE(unk_token=None, byte_fallback=True)
tokenizer = Tokenizer(bpe)

# normalizer: none, the input text is passed through unchanged
tokenizer.normalizer = None

# pre-tokenizer: byte-level split with the GPT-2 regex, no prefix space added
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False, trim_offsets=True, use_regex=True)

# post-processor: byte-level offsets are left untrimmed (trim_offsets=False)
tokenizer.post_processor = processors.ByteLevel(add_prefix_space=True, trim_offsets=False, use_regex=True)

# decoder: maps byte-level token strings back to the original text
tokenizer.decoder = decoders.ByteLevel(add_prefix_space=True, trim_offsets=True, use_regex=True)
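# Note: with the ByteLevel pipeline, token strings use the printable byte
# alphabet (a leading space shows up as 'Ġ'); the decoder above converts
# them back to the original bytes.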

#
# BPE Trainer
#
trainer = BpeTrainer(
    vocab_size=65536, # 64 * 1024
    min_frequency=3,
    special_tokens=special_tokens,
    max_token_length=16,
)
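# Note: the 64 special tokens count toward vocab_size, leaving 65,472 ids for
# the learned byte-level alphabet and merges.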

tokenizer_datasets = base_datasets + base_instruct_datasets

# Train over all datasets; each element handed to train_from_iterator is the
# text iterator returned by batch_dataset_iterator for one dataset.
tokenizer.train_from_iterator(
    (batch_dataset_iterator(n) for n in tokenizer_datasets),
    trainer=trainer,
)

# tokenizer.json stores the full pipeline; model.save() writes vocab.json and merges.txt
tokenizer.save(os.path.join(tokenizer_path, 'tokenizer.json'))
tokenizer.model.save(tokenizer_path)

#
# PreTrainedTokenizerFast
#
CHAT_TEMPLATE = (
    "{% for message in messages %}"
        "{{'<|im_start|>' + message['role'] + '<|im_sep|>' + message['content'] + '<|im_end|>'}}"
    "{% endfor %}"

    "{% if add_generation_prompt %}"
        "{{ '<|im_start|>assistant<|im_sep|>' }}"
    "{% endif %}"
)
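# Rendered example for messages = [{'role': 'user', 'content': 'Hi'}] with
# add_generation_prompt=True (illustrative):
#   <|im_start|>user<|im_sep|>Hi<|im_end|><|im_start|>assistant<|im_sep|>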

fast_tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=tokenizer,
    chat_template=CHAT_TEMPLATE,
    bos_token=bos_token,
    eos_token=eos_token,
    pad_token=pad_token,
    unk_token=unk_token,
    clean_up_tokenization_spaces=False,
)

fast_tokenizer.save_pretrained(tokenizer_path)
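
#
# Sanity check (optional): a minimal sketch that round-trips an illustrative
# chat through the freshly built tokenizer; the messages are examples only.
#
messages = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'Hello!'},
]
prompt = fast_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)

# With skip_special_tokens left at its default (False), the decoded text
# should match the rendered prompt exactly.
print(fast_tokenizer.decode(fast_tokenizer(prompt)['input_ids']))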