"""Interactive console chat with T-lite plus the o1-style LoRA adapter.

Run as a script; the parameters of ``chat`` are exposed as command-line flags
via ``fire.Fire``.
"""
from typing import List, Dict

import fire
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, BitsAndBytesConfig

MODEL_BASE = "t-tech/T-lite-it-1.0"
MODEL_ADAPTER = "evilfreelancer/o1_t-lite-it-1.0_lora"

# Russian: "You are an AI assistant. Format your answers as follows:
# Your thoughts (understanding, reasoning) Your answer".
# NOTE(review): the text reads as if structural markers around the
# "thoughts" / "answer" sections may have been lost; confirm against the
# prompt the adapter was trained with.
SYSTEM_PROMPT = "Вы — ИИ-помощник. Отформатируйте свои ответы следующим образом: Ваши мысли (понимание, рассуждения) Ваш ответ "


class ChatHistory:
    """Rolling list of chat messages with an optional pinned system prompt.

    Messages are dicts with ``role`` and ``content`` keys. When a system
    prompt is given it is stored at index 0, is never trimmed, and does not
    count toward ``history_limit``.
    """

    def __init__(self, history_limit: int | None = None, system_prompt: str | None = None):
        # Max number of non-system messages to keep; None means unlimited.
        self.history_limit: int | None = history_limit
        self.system_prompt: str | None = system_prompt
        self.messages: List[Dict] = []
        if self.system_prompt is not None:
            self.messages.append({"role": "system", "content": self.system_prompt})

    def add_message(self, role: str, message: str):
        """Append a message with the given role, then enforce the history limit."""
        self.messages.append({"role": role, "content": message})
        self.trim_history()

    def add_user_message(self, message: str):
        self.add_message("user", message)

    def add_assistant_message(self, message: str):
        self.add_message("assistant", message)

    def add_function_call(self, message: str):
        self.add_message("function_call", message)

    def add_function_response(self, message: str):
        self.add_message("function_response", message)

    def trim_history(self):
        """Drop the oldest non-system messages beyond ``history_limit``.

        The system prompt (if configured) always stays at index 0. Without a
        system prompt no message is pinned, so the oldest user message is
        trimmed like any other.
        """
        if self.history_limit is None:
            return
        pinned = 1 if self.system_prompt is not None else 0
        head = self.messages[:pinned]
        body = self.messages[pinned:]
        if len(body) > self.history_limit:
            # Explicit start index (not body[-limit:]) so a limit of 0
            # yields an empty tail instead of the whole list.
            self.messages = head + body[len(body) - self.history_limit:]

    def get_messages(self) -> list:
        """Return the live message list (not a copy)."""
        return self.messages


def generate(model, tokenizer, prompt, generation_config):
    """Generate a reply for an already-rendered prompt string.

    Returns the decoded continuation only (prompt tokens removed), with
    special tokens skipped and surrounding whitespace stripped.
    """
    # The prompt comes from the chat template and normally already contains
    # the model's special tokens; adding them again (e.g. a second BOS)
    # would corrupt the prompt.
    data = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    data = {k: v.to(model.device) for k, v in data.items()}
    output_ids = model.generate(**data, generation_config=generation_config)[0]
    # generate() returns prompt + continuation; keep only the new tokens.
    output_ids = output_ids[len(data["input_ids"][0]):]
    output = tokenizer.decode(output_ids, skip_special_tokens=True)
    return output.strip()


def get_prompt(tokenizer, messages: List[Dict], add_generation_prompt: bool = False):
    """Render ``messages`` to a prompt string with the tokenizer's chat template.

    ``add_generation_prompt=True`` appends the template's assistant-turn
    opener so the model continues as the assistant.
    """
    return tokenizer.apply_chat_template(
        messages,
        add_special_tokens=False,
        tokenize=False,
        add_generation_prompt=add_generation_prompt,
    )


def chat(
    history_limit: int = 1,
    system_prompt: str | None = SYSTEM_PROMPT,
    max_new_tokens: int = 2048,
    repetition_penalty: float = 1.2,
    do_sample: bool = True,
    temperature: float = 0.5,
    top_p: float = 0.6,
    top_k: int = 40,
):
    """Run an interactive REPL chat with the base model plus LoRA adapter.

    Commands: ``/reset`` clears the history; blank input is ignored;
    end-of-input (Ctrl-D) or Ctrl-C at the prompt exits.

    Args:
        history_limit: Max number of non-system messages kept in the prompt.
            With the default of 1, each prompt contains only the latest user
            message (no earlier turns).
        system_prompt: System message pinned at the start of every prompt,
            or None for no system message.
        max_new_tokens, repetition_penalty, do_sample, temperature, top_p,
        top_k: Overrides applied on top of the adapter's generation config.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_BASE)

    # 4-bit NF4 weights with double quantization; compute dtype is bfloat16.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )

    # Start from the generation config published with the adapter, then
    # apply the CLI overrides.
    generation_config = GenerationConfig.from_pretrained(MODEL_ADAPTER)
    generation_config.max_new_tokens = max_new_tokens
    generation_config.repetition_penalty = repetition_penalty
    generation_config.do_sample = do_sample
    generation_config.temperature = temperature
    generation_config.top_p = top_p
    generation_config.top_k = top_k

    # Load the quantized base model.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_BASE,
        generation_config=generation_config,
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        attn_implementation=None,
    )

    # Attach the LoRA adapter on top of the base model.
    model = PeftModel.from_pretrained(
        model=model,
        model_id=MODEL_ADAPTER,
        torch_dtype=torch.bfloat16,
    )

    chat_history = ChatHistory(history_limit, system_prompt)
    while True:
        try:
            user_message = input("User: ")
        except (EOFError, KeyboardInterrupt):
            # Closed stdin / Ctrl-D / Ctrl-C at the prompt: leave cleanly
            # instead of dying with a traceback.
            print()
            break

        if user_message.strip() == "/reset":
            chat_history = ChatHistory(history_limit, system_prompt)
            print("History reset completed!")
            continue

        if user_message.strip() == "":
            continue

        chat_history.add_user_message(user_message)
        prompt = get_prompt(tokenizer, chat_history.get_messages(), True)
        output = generate(model, tokenizer, prompt, generation_config)

        # Keep the reply so later turns can see it (subject to history_limit).
        chat_history.add_assistant_message(output)
        print("Assistant:", output)


if __name__ == "__main__":
    fire.Fire(chat)