"""
## convert to gguf

python convert_hf_to_gguf.py /workspace/xusong/huggingface/models/Qwen2-0.5B-Instruct/

## predict
./llama-cli -m /workspace/xusong/huggingface/models/Qwen1.5-0.5B-Chat/Qwen1.5-0.5B-Chat-F16.gguf -p "I believe the meaning of life is" -n 128
./llama-cli -m /workspace/xusong/huggingface/models/Qwen1.5-0.5B-Chat/Qwen1.5-0.5B-Chat-F16.gguf -f prompt.txt -n 128
./llama-cli -m /workspace/xusong/huggingface/models/Qwen1.5-0.5B-Chat/Qwen1.5-0.5B-Chat-F16.gguf -p "You are a helpful assistant" -cnv


## timing


**重庆GPU服务器,cache为空 (Chongqing GPU server, empty cache)**
llama_print_timings:        load time =    1711.48 ms
llama_print_timings:      sample time =      73.89 ms /    41 runs   (    1.80 ms per token,   554.84 tokens per second)
llama_print_timings: prompt eval time =    2621.25 ms /     5 tokens (  524.25 ms per token,     1.91 tokens per second)   # 0.2-0.5秒/token
llama_print_timings:        eval time =    1430.91 ms /    40 runs   (   35.77 ms per token,    27.95 tokens per second)
llama_print_timings:       total time =    4848.09 ms /    45 tokens

llama_print_timings:        load time =    1939.72 ms
llama_print_timings:      sample time =     286.69 ms /   170 runs   (    1.69 ms per token,   592.99 tokens per second)
llama_print_timings: prompt eval time =       0.00 ms /     0 tokens (    -nan ms per token,     -nan tokens per second)  # warmup后,加速明显。
llama_print_timings:        eval time =    5737.50 ms /   170 runs   (   33.75 ms per token,    29.63 tokens per second)
llama_print_timings:       total time =    8219.82 ms /   170 tokens


**hf-space,cache为空 (关闭GGML_BLAS; HF Space, empty cache, GGML_BLAS off)** -----------
llama_print_timings:        load time =   28230.06 ms
llama_print_timings:      sample time =     147.58 ms /     8 runs   (   18.45 ms per token,    54.21 tokens per second)   # 18ms/token
llama_print_timings: prompt eval time =   28864.82 ms /     5 tokens ( 5772.96 ms per token,     0.17 tokens per second)   # 5.7s/token
llama_print_timings:        eval time =    1557.94 ms /     7 runs   (  222.56 ms per token,     4.49 tokens per second)
llama_print_timings:       total time =   30753.48 ms /    12 tokens


**hf-space,cache为空 (开启GGML_BLAS; HF Space, empty cache, GGML_BLAS on)** -----------
llama_print_timings:        load time =   27347.29 ms
llama_print_timings:      sample time =      82.53 ms /    26 runs   (    3.17 ms per token,   315.05 tokens per second)   # 3ms/token
llama_print_timings: prompt eval time =   28855.64 ms /     9 tokens ( 3206.18 ms per token,     0.31 tokens per second)   # 3s/token
llama_print_timings:        eval time =    9810.01 ms /    25 runs   (  392.40 ms per token,     2.55 tokens per second)
llama_print_timings:       total time =   39073.77 ms /    34 tokens

llama_print_timings:        load time =   27347.29 ms
llama_print_timings:      sample time =     272.12 ms /    96 runs   (    2.83 ms per token,   352.79 tokens per second)   # 2.8ms/token
llama_print_timings: prompt eval time =       0.00 ms /     0 tokens (    -nan ms per token,     -nan tokens per second)
llama_print_timings:        eval time =   19974.85 ms /    96 runs   (  208.07 ms per token,     4.81 tokens per second)
llama_print_timings:       total time =   22517.08 ms /    96 tokens


## TODO:

- 解决warmup慢的问题
- 支持cache,并提前对所有预设system进行cache。

## reference

- https://github.com/abetlen/llama-cpp-python/blob/main/examples/gradio_chat/local.py
- https://github.com/awinml/llama-cpp-python-bindings
- https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/llms/llamacpp.py
- https://github.com/abetlen/llama-cpp-python/blob/main/examples/gradio_chat/server.py
- https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/server/model.py
- https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/server/app.py
"""

import json
import copy
import os

from models.base_model import Simulator
import llama_cpp
from transformers import AutoTokenizer
from utils.logging_util import logger
import config


class Qwen2Simulator(Simulator):
    """Chat simulator backed by a Qwen2-0.5B-Instruct GGUF model run through llama_cpp.

    Prompts are assembled directly as token ids in ChatML layout
    (``<|im_start|>{role}`` newline, content, ``<|im_end|>`` newline). After a
    streamed generation finishes, the header of the following turn is evaluated
    ahead of time (see ``post_cache``).
    """

    def __init__(self):
        local_path = "/workspace/xusong/huggingface/models/Qwen2-0.5B-Instruct-GGUF/qwen2-0_5b-instruct-fp16.gguf"
        if os.path.exists(local_path):
            self.hf_tokenizer = AutoTokenizer.from_pretrained(
                "/workspace/xusong/huggingface/models/Qwen2-0.5B-Instruct/")
            self.llm = llama_cpp.Llama(  # n_ctx, n_threads
                model_path=local_path,
                # The default tokenizer is buggy (it yields different token ids),
                # so wrap the Hugging Face tokenizer instead.
                tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer(self.hf_tokenizer),
                n_ctx=config.MAX_SEQUENCE_LENGTH,
                # n_threads=None,  # by default n_threads is chosen from the CPU count
                # use_mlock=True,
                verbose=True,
            )
        else:
            # No local copy: download tokenizer and GGUF weights from the Hub.
            self.hf_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
            self.llm = llama_cpp.Llama.from_pretrained(
                repo_id="Qwen/Qwen2-0.5B-Instruct-GGUF",
                tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer(self.hf_tokenizer),
                filename="*fp16.gguf",
                n_ctx=config.MAX_SEQUENCE_LENGTH,
                # use_mlock=True,
                verbose=True,
            )
        logger.info(f"llm has been initialized: {self.llm}, "
                    f"n_threads={self.llm.n_threads}, n_ctx={self.llm.n_ctx}, "
                    f"env[CACHE]={os.environ.get('CACHE', None)}")

        self.stop_words = [
            "<|im_end|>",
            "<|im_start|>",
            "<|endoftext|>",
        ]
        # Token ids of all stop words (tokenized together), used by strip_stoptokens.
        self.stop_tokens = self.tokenize("".join(self.stop_words))
        self.generation_kwargs = dict(
            temperature=config.DEFAULT_TEMPERATURE,
            top_p=config.DEFAULT_TOP_P,
            top_k=config.DEFAULT_TOP_K,
            max_tokens=config.DEFAULT_MAX_NEW_TOKENS,
            repeat_penalty=1.1,
            # qwen2-0.5b-chat sometimes finishes a message without <|im_end|> and
            # goes straight to <|im_start|>, so every ChatML marker is a stop word.
            stop=self.stop_words,
        )

        self.user_start_tokens = self.tokenize("<|im_start|>user\n")
        self.assistant_start_tokens = self.tokenize("<|im_start|>assistant\n")
        # End-of-message marker ("<|im_end|>" + newline). Shared by generate() and
        # post_cache() so both always use the same ids; the original code
        # hard-coded [151645, 198] in post_cache.
        self.im_end_tokens = self.tokenize("<|im_end|>\n")
        # self.llm.generate  .set_cache   .last_n_tokens_size  .reset  .ctx ._ctx

        # cache = llama_cpp.LlamaDiskCache(capacity_bytes=cache_size)
        cache = llama_cpp.LlamaRAMCache(capacity_bytes=2 << 30)  # 2 GiB
        self.llm.set_cache(cache)

    def tokenize(self, text):
        """Encode ``text`` as UTF-8 and return the model's token ids."""
        return self.llm.tokenize(text.encode("utf-8"))

    def detokenize(self, tokens):
        """Convert token ids back to a UTF-8 decoded string."""
        return self.llm.detokenize(tokens).decode("utf-8")

    def strip_stoptokens(self, tokens):
        """Remove stop tokens from both ends of ``tokens``, in place.

        :param tokens: list of token ids (``None``/empty is returned unchanged).
        :return: the same list object, with leading and trailing stop tokens removed.
        """
        if not tokens:
            return tokens
        # Count leading stop tokens, then drop them with one slice deletion
        # instead of repeated O(n) pop(0) calls.
        head = 0
        while head < len(tokens) and tokens[head] in self.stop_tokens:
            logger.info(f"head-striping {tokens[head]} {self.detokenize([tokens[head]])}")
            head += 1
        del tokens[:head]
        while tokens and tokens[-1] in self.stop_tokens:
            logger.info(f"tail-striping {tokens[-1]} {self.detokenize([tokens[-1]])}")
            tokens.pop()
        return tokens

    def generate(self, history, stream=True):
        """Build a ChatML prompt from ``history`` and generate the next turn.

        The generated turn has the opposite role of the last message. After a
        ``user`` message the model speaks as ``assistant``. After an
        ``assistant`` or ``system`` message it produces a ``user`` turn. When
        streaming, 5 extra tokens are forwarded once generation ends: the
        end-of-message marker plus the next role header (see ``post_cache``).

        Side effect: each message dict without a ``"tokens"`` key gets its
        content's token ids cached under ``message["tokens"]``.

        :param history: list of dicts with ``"role"`` and ``"content"`` keys
            (optionally pre-tokenized ``"tokens"``).
        :param stream: if True, return a generator yielding ``(text, tokens)``.
        :return: the streaming generator, or ``self._generate(input_ids)``.
            NOTE(review): ``_generate`` is not defined in this class;
            presumably it comes from ``Simulator`` — confirm.
        :raises IndexError: if ``history`` is empty.
        :raises ValueError: if the last message's role is not user/assistant/system.
        """
        last_role = history[-1]['role']
        if last_role == "user":
            start_tokens = self.assistant_start_tokens
            suffix_tokens = self.user_start_tokens
        elif last_role in ("assistant", "system"):
            start_tokens = self.user_start_tokens
            suffix_tokens = self.assistant_start_tokens
        else:
            # Previously this fell through and failed with UnboundLocalError.
            raise ValueError(f"unsupported role for the last message: {last_role!r}")

        input_ids = []
        for message in history:
            if "tokens" not in message:  # tokenize once and cache on the message
                message["tokens"] = self.tokenize(message["content"])
            input_ids += self.tokenize(f"<|im_start|>{message['role']}\n") \
                         + message["tokens"] \
                         + self.im_end_tokens
        input_ids += start_tokens
        if stream:
            return self._stream_generate(input_ids, suffix_tokens)
        else:
            return self._generate(input_ids)

    def _stream_generate(self, input_ids, suffix_tokens=None):
        """Stream a completion for ``input_ids``, then pre-evaluate ``suffix_tokens``.

        Yields ``(completion_text, completion_tokens)`` for every chunk whose
        ``finish_reason`` is None; the final chunk is only logged.
        NOTE(review): ``completion_text``/``completion_tokens`` are not fields of
        stock llama-cpp-python chunks; presumably a patched build — confirm.
        """
        logger.info(f"generation_kwargs {self.generation_kwargs}")
        output = self.llm.create_completion(
            input_ids,
            stream=True,
            **self.generation_kwargs
        )
        # TODO: inspect finish_reason; if it is "length", shift the context and keep generating.
        # TODO: return token ids.
        for out in output:
            # Deep copy so the yielded objects are independent of the stream's chunk.
            stream = copy.deepcopy(out)
            if stream["choices"][0]["finish_reason"] is None:
                yield stream["choices"][0]["completion_text"], stream["choices"][0]["completion_tokens"]
            else:
                logger.info(
                    f'finish_reason {stream["choices"][0]["finish_reason"]} with text: {stream["choices"][0]["text"]}')

        # Only reached once the consumer has exhausted the generator.
        self.post_cache(suffix_tokens)

    def pre_cache_system(self, system_list):
        """Warm up the prompt cache for each preset system prompt, then freeze it.

        For every prompt, runs a 1-token greedy completion on
        ``<|im_start|>system ... <|im_end|> <|im_start|>user`` so the resulting
        state is stored in ``self.llm.cache``.

        WARNING: afterwards this monkeypatches, at class level,
        ``llama_cpp.LlamaRAMCache.__setitem__`` and ``llama_cpp.Llama.save_state``
        into no-ops. This affects every instance in the process, not just this one.

        :param system_list: iterable of system prompt strings.
        """
        logger.info(f"cache size {self.llm.cache.cache_size}")
        for system_prompt in system_list:
            logger.info(f"pre caching '{system_prompt}'")
            input_ids = self.tokenize(f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n")
            # Only the side effect (populating the cache) matters; output is discarded.
            self.llm.create_completion(
                input_ids,
                stream=False,
                max_tokens=1,
                top_k=1
            )
            logger.info(f"cache size {self.llm.cache.cache_size}")

        # Disable further cache writes from now on (global patch, see docstring).
        llama_cpp.LlamaRAMCache.__setitem__ = lambda *args: None
        llama_cpp.Llama.save_state = lambda *args: None

    def post_cache(self, suffix_tokens):
        """Warm up for the next turn by evaluating end-of-message + ``suffix_tokens``.

        :param suffix_tokens: token ids of the next role header; no-op if falsy.
        """
        if suffix_tokens:
            logger.info(f"before warmup: n_tokens = {self.llm.n_tokens}")
            # "<|im_end|>\n" (ids 151645, 198 for Qwen2), then the next role header.
            self.llm.eval(self.im_end_tokens + suffix_tokens)
            logger.info(f"after warmup: n_tokens = {self.llm.n_tokens}")


# Module-level singleton; building it loads the model at import time.
bot = Qwen2Simulator()

if __name__ == "__main__":
    # Self-play demo: starting from a system prompt, the model's own output is
    # fed back as alternating user / assistant turns.
    conversation = [{"role": "system", "content": "你是一个导游。"}]
    tokens = None
    print("######## requesting", conversation)
    for text, tokens in bot.generate(conversation, stream=True):
        print(text, tokens)

    for turn in range(3):
        tokens = bot.strip_stoptokens(tokens)
        speaker = "assistant" if turn % 2 else "user"
        conversation.append({"role": speaker, "content": text, "tokens": tokens})
        print("######## requesting", conversation)
        # Drain the stream silently; exhausting it also triggers post_cache.
        for text, tokens in bot.generate(conversation, stream=True):
            pass