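# Interactive console chat with the t-tech/T-lite-it-1.0 base model plus the
# evilfreelancer/o1_t-lite-it-1.0_lora adapter, loaded in 4-bit (NF4) via bitsandbytes.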
import fire
from typing import List, Dict
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, BitsAndBytesConfig

MODEL_BASE = "t-tech/T-lite-it-1.0"
MODEL_ADAPTER = "evilfreelancer/o1_t-lite-it-1.0_lora"

# System prompt kept in Russian, since the model is Russian-language. English translation:
# "You are an AI assistant. Format your responses as follows:
# <Thought> Your thoughts (understanding, reasoning) </Thought> <output> Your answer </output>"
SYSTEM_PROMPT = "Вы — ИИ-помощник. Отформатируйте свои ответы следующим образом: <Thought> Ваши мысли (понимание, рассуждения) </Thought> <output> Ваш ответ </output>"


class ChatHistory:
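    """Keeps the running list of chat messages, optionally pinned to a system
    prompt, and trims the oldest turns once history_limit is exceeded."""
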
    def __init__(self, history_limit: int | None = None, system_prompt: str | None = None):
        self.history_limit: int | None = history_limit
        self.system_prompt: str | None = system_prompt
        self.messages: List[Dict] = []
        if self.system_prompt is not None:
            self.messages.append({"role": "system", "content": self.system_prompt})

    def add_message(self, role: str, message: str):
        self.messages.append({"role": role, "content": message})
        self.trim_history()

    def add_user_message(self, message: str):
        self.add_message("user", message)

    def add_assistant_message(self, message: str):
        self.add_message("assistant", message)

    def add_function_call(self, message: str):
        self.add_message("function_call", message)

    def add_function_response(self, message: str):
        self.add_message("function_response", message)

    def trim_history(self):
        # Keep the system prompt (if any) pinned and drop the oldest turns
        # once the number of conversation messages exceeds history_limit.
        appendix = 1 if self.system_prompt is not None else 0
        if self.history_limit is not None and len(self.messages) > self.history_limit + appendix:
            overflow = len(self.messages) - (self.history_limit + appendix)
            self.messages = self.messages[:appendix] + self.messages[overflow + appendix:]

    def get_messages(self) -> list:
        return self.messages


def generate(model, tokenizer, prompt, generation_config):
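    """Tokenize the prompt, run generation, and return only the newly generated
    text (the prompt tokens are stripped from the output)."""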
    data = tokenizer(prompt, return_tensors="pt")
    data = {k: v.to(model.device) for k, v in data.items()}
    output_ids = model.generate(**data, generation_config=generation_config)[0]
    output_ids = output_ids[len(data["input_ids"][0]):]
    output = tokenizer.decode(output_ids, skip_special_tokens=True)
    return output.strip()


def get_prompt(tokenizer, messages: List[Dict], add_generation_prompt: bool = False):
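    """Render the message list into a single prompt string using the model's chat template."""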
    return tokenizer.apply_chat_template(
        messages,
        add_special_tokens=False,
        tokenize=False,
        add_generation_prompt=add_generation_prompt,
    )


def chat(
    history_limit: int = 1,
    system_prompt: str | None = SYSTEM_PROMPT,
    max_new_tokens: int = 2048,
    repetition_penalty: float = 1.2,
    do_sample: bool = True,
    temperature: float = 0.5,
    top_p: float = 0.6,
    top_k: int = 40,
):
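    """Interactive chat loop. All keyword arguments are exposed as CLI flags via fire.Fire."""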
    #
    # Tokenizer preparation
    #

    tokenizer = AutoTokenizer.from_pretrained(MODEL_BASE)

    #
    # Model preparation
    #

    # Quantization config
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True
    )

    # Generation config (base values from the adapter repo, overridden by the CLI arguments)
    generation_config = GenerationConfig.from_pretrained(MODEL_ADAPTER)
    generation_config.max_new_tokens = max_new_tokens
    generation_config.repetition_penalty = repetition_penalty
    generation_config.do_sample = do_sample
    generation_config.temperature = temperature
    generation_config.top_p = top_p
    generation_config.top_k = top_k

    # Load the base model with 4-bit quantization
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_BASE,
        generation_config=generation_config,
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        attn_implementation=None
    )

    # Attach the LoRA adapter on top of the quantized base model
    model = PeftModel.from_pretrained(
        model=model,
        model_id=MODEL_ADAPTER,
        torch_dtype=torch.bfloat16,
    )

    #
    # Chat loop
    #

    # Initialize chat history with the configured limit and system prompt
    chat_history = ChatHistory(history_limit, system_prompt)
    while True:
        user_message = input("User: ")

        # Reset chat command
        if user_message.strip() == "/reset":
            chat_history = ChatHistory(history_limit, system_prompt)
            print("History reset completed!")
            continue

        # Skip empty messages from user
        if user_message.strip() == "":
            continue

        # Add user message to chat history
        chat_history.add_user_message(user_message)

        # Render the chat history into a prompt via the chat template
        prompt = get_prompt(tokenizer, chat_history.get_messages(), True)

        # Generate response
        output = generate(model, tokenizer, prompt, generation_config)

        # Save response to chat history as assistant's message
        chat_history.add_assistant_message(output)
        print("Assistant:", output)


if __name__ == "__main__":
    fire.Fire(chat)
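
# Example invocation (assuming this file is saved as chat.py; the flag names are
# simply the keyword arguments of chat() exposed by fire.Fire):
#   python chat.py --history_limit=5 --temperature=0.3 --max_new_tokens=1024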