import fire
from typing import List, Dict
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, BitsAndBytesConfig
# Model id of the base model; chat() loads it with from_pretrained and then
# attaches the LoRA adapter below on top of it.
MODEL_BASE = "t-tech/T-lite-it-1.0"
# Model id of the LoRA adapter; chat() also reads its GenerationConfig from here.
MODEL_ADAPTER = "evilfreelancer/o1_t-lite-it-1.0_lora"
# Default system message. English gloss: "You are an AI assistant. Format your
# answers as follows: your thoughts (understanding, reasoning) ".
# NOTE(review): the text stops mid-instruction with a trailing space; the
# intended output-format markup may have been lost. Confirm against the prompt
# the adapter was trained with.
SYSTEM_PROMPT = (
    "Вы — ИИ-помощник. Отформатируйте свои ответы следующим образом: "
    "Ваши мысли (понимание, рассуждения) "
)
class ChatHistory:
def __init__(self, history_limit: int = None, system_prompt: str = None):
self.history_limit: int | None = history_limit
self.system_prompt: str | None = system_prompt
self.messages: List[Dict] = []
if self.system_prompt is not None:
self.messages.append({"role": "system", "content": self.system_prompt})
def add_message(self, role: str, message: str):
self.messages.append({"role": role, "content": message})
self.trim_history()
def add_user_message(self, message: str):
self.add_message("user", message)
def add_assistant_message(self, message: str):
self.add_message("assistant", message)
def add_function_call(self, message: str):
self.add_message("function_call", message)
def add_function_response(self, message: str):
self.add_message("function_response", message)
def trim_history(self):
appendix = 0
if self.system_prompt is not None:
appendix = 1
if self.history_limit is not None and len(self.messages) > self.history_limit + appendix:
overflow = len(self.messages) - (self.history_limit + appendix)
self.messages = [self.messages[0]] + self.messages[overflow + appendix:]
def get_messages(self) -> list:
return self.messages
def generate(model, tokenizer, prompt, generation_config):
    """Generate a completion for an already-formatted prompt string.

    Args:
        model: causal LM exposing ``.device`` and
            ``.generate(**inputs, generation_config=...)``.
        tokenizer: tokenizer used to encode ``prompt`` and decode the output.
        prompt: text produced by ``get_prompt`` (chat template already applied).
        generation_config: passed through unchanged to ``model.generate``.

    Returns:
        Only the newly generated text: the prompt tokens are cut off, special
        tokens are skipped while decoding, and surrounding whitespace is stripped.
    """
    # The chat template already emitted the special tokens the prompt needs.
    # Letting the tokenizer add its own (e.g. a BOS) would duplicate them.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}
    # generate() returns a batch; this function always encodes a single prompt.
    sequence = model.generate(**inputs, generation_config=generation_config)[0]
    # The returned sequence starts with the prompt tokens, so skip past them.
    prompt_length = len(inputs["input_ids"][0])
    completion_ids = sequence[prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
def get_prompt(tokenizer, messages: List[Dict], add_generation_prompt: bool = False):
    """Render ``messages`` into a prompt string using the tokenizer's chat template.

    Args:
        tokenizer: object providing ``apply_chat_template``.
        messages: list of ``{"role": ..., "content": ...}`` dicts.
        add_generation_prompt: when True, ask the template to append the
            tokens that open the assistant's reply.

    Returns:
        The result of ``apply_chat_template`` with ``tokenize=False``, i.e.
        formatted text rather than token ids.
    """
    template_options = dict(
        tokenize=False,
        add_generation_prompt=add_generation_prompt,
        # NOTE(review): with tokenize=False this flag presumably does not
        # affect the rendered string. Check whether it was meant for the
        # later tokenizer(...) call instead.
        add_special_tokens=False,
    )
    return tokenizer.apply_chat_template(messages, **template_options)
def _build_generation_config(
    max_new_tokens: int,
    repetition_penalty: float,
    do_sample: bool,
    temperature: float,
    top_p: float,
    top_k: int,
) -> GenerationConfig:
    """Load the adapter's GenerationConfig and override its sampling settings."""
    generation_config = GenerationConfig.from_pretrained(MODEL_ADAPTER)
    generation_config.max_new_tokens = max_new_tokens
    generation_config.repetition_penalty = repetition_penalty
    generation_config.do_sample = do_sample
    generation_config.temperature = temperature
    generation_config.top_p = top_p
    generation_config.top_k = top_k
    return generation_config


def _load_model(generation_config: GenerationConfig):
    """Load the base model in 4-bit NF4 (bfloat16 compute) and attach the LoRA adapter."""
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_BASE,
        generation_config=generation_config,
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        attn_implementation=None,
    )
    return PeftModel.from_pretrained(
        model=model,
        model_id=MODEL_ADAPTER,
        torch_dtype=torch.bfloat16,
    )


def _read_user_message() -> str | None:
    """Prompt for one line of input; return None on EOF (Ctrl-D / end of piped input) or Ctrl-C."""
    try:
        return input("User: ")
    except (EOFError, KeyboardInterrupt):
        # Finish the dangling "User: " line so the shell prompt starts cleanly.
        print()
        return None


def chat(
    history_limit: int = 1,
    system_prompt: str | None = SYSTEM_PROMPT,
    max_new_tokens: int = 2048,
    repetition_penalty: float = 1.2,
    do_sample: bool = True,
    temperature: float = 0.5,
    top_p: float = 0.6,
    top_k: int = 40,
):
    """Run an interactive console chat with the LoRA-adapted model.

    Type ``/reset`` to clear the history. Empty lines are ignored. End of
    input (Ctrl-D) or Ctrl-C at the prompt exits the loop.

    Args:
        history_limit: passed to ``ChatHistory``; caps how many non-system
            messages are kept in the prompt.
        system_prompt: system message for every conversation; ``None`` for none.
        max_new_tokens: generation override.
        repetition_penalty: generation override.
        do_sample: generation override.
        temperature: generation override.
        top_p: generation override.
        top_k: generation override.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_BASE)
    generation_config = _build_generation_config(
        max_new_tokens, repetition_penalty, do_sample, temperature, top_p, top_k
    )
    model = _load_model(generation_config)

    chat_history = ChatHistory(history_limit, system_prompt)
    while True:
        user_message = _read_user_message()
        if user_message is None:
            break
        command = user_message.strip()
        if command == "/reset":
            chat_history = ChatHistory(history_limit, system_prompt)
            print("History reset completed!")
            continue
        if not command:
            continue
        # The unstripped text is what enters the history, as before.
        chat_history.add_user_message(user_message)
        prompt = get_prompt(tokenizer, chat_history.get_messages(), True)
        output = generate(model, tokenizer, prompt, generation_config)
        chat_history.add_assistant_message(output)
        print("Assistant:", output)
if __name__ == "__main__":
    # Hand chat() to python-fire so its keyword arguments become CLI flags.
    fire.Fire(component=chat)