File size: 5,003 Bytes
36aac00 336e3e6 36aac00 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import fire
from typing import List, Dict
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, BitsAndBytesConfig
# Identifier of the base model; passed to both the tokenizer loader and the
# causal-LM loader in chat().
MODEL_BASE = "t-tech/T-lite-it-1.0"
# Identifier of the LoRA adapter; chat() loads the generation config and the
# PEFT adapter weights from it.
MODEL_ADAPTER = "ZeroAgency/o1_t-lite-it-1.0_lora"
# Default system prompt for chat(). The Russian text reads: "You are an AI
# assistant. Format your answers as follows: <Thought> Your thoughts
# (understanding, reasoning) </Thought> <output> Your answer </output>".
SYSTEM_PROMPT = "Вы — ИИ-помощник. Отформатируйте свои ответы следующим образом: <Thought> Ваши мысли (понимание, рассуждения) </Thought> <output> Ваш ответ </output>"
class ChatHistory:
def __init__(self, history_limit: int = None, system_prompt: str = None):
self.history_limit: int | None = history_limit
self.system_prompt: str | None = system_prompt
self.messages: List[Dict] = []
if self.system_prompt is not None:
self.messages.append({"role": "system", "content": self.system_prompt})
def add_message(self, role: str, message: str):
self.messages.append({"role": role, "content": message})
self.trim_history()
def add_user_message(self, message: str):
self.add_message("user", message)
def add_assistant_message(self, message: str):
self.add_message("assistant", message)
def add_function_call(self, message: str):
self.add_message("function_call", message)
def add_function_response(self, message: str):
self.add_message("function_response", message)
def trim_history(self):
appendix = 0
if self.system_prompt is not None:
appendix = 1
if self.history_limit is not None and len(self.messages) > self.history_limit + appendix:
overflow = len(self.messages) - (self.history_limit + appendix)
self.messages = [self.messages[0]] + self.messages[overflow + appendix:]
def get_messages(self) -> list:
return self.messages
def generate(model, tokenizer, prompt, generation_config):
    """Generate a reply for ``prompt`` and return only the newly produced text.

    The prompt is tokenized, every resulting tensor is moved to the model's
    device, and the model generates with ``generation_config``. The prompt
    tokens are cut off the front of the first returned sequence before
    decoding, and the decoded text is returned with surrounding whitespace
    stripped.
    """
    encoded = {}
    for name, tensor in tokenizer(prompt, return_tensors="pt").items():
        encoded[name] = tensor.to(model.device)

    # Only the first returned sequence is used.
    sequence = model.generate(**encoded, generation_config=generation_config)[0]

    # The returned sequence starts with the prompt tokens; skip past them.
    prompt_length = len(encoded["input_ids"][0])
    completion_ids = sequence[prompt_length:]

    text = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return text.strip()
def get_prompt(tokenizer, messages: List[Dict], add_generation_prompt: bool = False):
    """Render ``messages`` into one prompt string with the tokenizer's chat template.

    Args:
        tokenizer: Object providing ``apply_chat_template``.
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
        add_generation_prompt: Forwarded unchanged to ``apply_chat_template``.

    Returns:
        Whatever ``apply_chat_template`` returns when called with
        ``tokenize=False``.
    """
    template_options = {
        "add_special_tokens": False,
        "tokenize": False,
        "add_generation_prompt": add_generation_prompt,
    }
    return tokenizer.apply_chat_template(messages, **template_options)
def _build_generation_config(
    max_new_tokens,
    repetition_penalty,
    do_sample,
    temperature,
    top_p,
    top_k,
):
    """Load the adapter's generation config and override its sampling settings."""
    generation_config = GenerationConfig.from_pretrained(MODEL_ADAPTER)
    generation_config.max_new_tokens = max_new_tokens
    generation_config.repetition_penalty = repetition_penalty
    generation_config.do_sample = do_sample
    generation_config.temperature = temperature
    generation_config.top_p = top_p
    generation_config.top_k = top_k
    return generation_config


def _load_model(generation_config):
    """Load the 4-bit quantized base model and wrap it with the LoRA adapter."""
    # 4-bit NF4 weights with double quantization; compute runs in bfloat16.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_BASE,
        generation_config=generation_config,
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        attn_implementation=None,
    )
    return PeftModel.from_pretrained(
        model=base_model,
        model_id=MODEL_ADAPTER,
        torch_dtype=torch.bfloat16,
    )


def chat(
    history_limit: int = 1,
    system_prompt: str | None = SYSTEM_PROMPT,
    max_new_tokens: int = 2048,
    repetition_penalty: float = 1.2,
    do_sample: bool = True,
    temperature: float = 0.5,
    top_p: float = 0.6,
    top_k: int = 40,
):
    """Run an interactive console chat with the LoRA-adapted model.

    Loads the tokenizer, generation config and model, then repeatedly reads a
    line from stdin, generates a reply and prints it. Typing ``/reset``
    starts a fresh history; blank lines are ignored. The loop ends when
    stdin is closed or Ctrl+C is pressed at the prompt.

    Args:
        history_limit: Passed to ``ChatHistory`` as its message limit.
        system_prompt: Passed to ``ChatHistory`` as its system prompt.
        max_new_tokens: Value set on the generation config's ``max_new_tokens``.
        repetition_penalty: Value set on the generation config's
            ``repetition_penalty``.
        do_sample: Value set on the generation config's ``do_sample``.
        temperature: Value set on the generation config's ``temperature``.
        top_p: Value set on the generation config's ``top_p``.
        top_k: Value set on the generation config's ``top_k``.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_BASE)
    generation_config = _build_generation_config(
        max_new_tokens,
        repetition_penalty,
        do_sample,
        temperature,
        top_p,
        top_k,
    )
    model = _load_model(generation_config)

    chat_history = ChatHistory(history_limit, system_prompt)
    while True:
        try:
            user_message = input("User: ")
        except (EOFError, KeyboardInterrupt):
            # stdin was closed or the user pressed Ctrl+C at the prompt:
            # finish the session quietly instead of dumping a traceback.
            print()
            break

        command = user_message.strip()

        # Reset chat command: start over with a fresh history.
        if command == "/reset":
            chat_history = ChatHistory(history_limit, system_prompt)
            print("History reset completed!")
            continue

        # Skip empty messages from the user.
        if not command:
            continue

        # The message is stored as typed (not stripped).
        chat_history.add_user_message(user_message)
        prompt = get_prompt(tokenizer, chat_history.get_messages(), True)
        output = generate(model, tokenizer, prompt, generation_config)

        # Keep the reply in the history so later turns can see it.
        chat_history.add_assistant_message(output)
        print("Assistant:", output)
# Script entry point: when the file is executed directly, hand chat() to
# Fire so it can be driven from the command line.
if __name__ == "__main__":
    fire.Fire(chat)
|