pasha committed on
Commit f17b98f · 1 Parent(s): 6189dba

Chat and requirements added

Files changed (3)
  1. .gitignore +1 -0
  2. chat_transformers.py +158 -0
  3. requirements.txt +5 -0
.gitignore CHANGED
@@ -1 +1,2 @@
  /.idea/
+ /venv/
chat_transformers.py ADDED
@@ -0,0 +1,158 @@
+ import fire
+ from typing import List, Dict
+ import torch
+ from peft import PeftModel
+ from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, BitsAndBytesConfig
+
+ MODEL_BASE = "t-tech/T-lite-it-1.0"
+ MODEL_ADAPTER = "evilfreelancer/T-lite-it-1.0_lora_thinking"
+
+ SYSTEM_PROMPT = """\
+ Вы — ИИ-помощник. Отформатируйте свои ответы следующим образом: \
+ <Thought> Ваши мысли (понимание, рассуждения) </Thought> \
+ <Output> Ваш ответ </Output>\
+ """
+
+
+ class ChatHistory:
+     def __init__(self, history_limit: int = None, system_prompt: str = None):
+         self.history_limit: int | None = history_limit
+         self.system_prompt: str | None = system_prompt
+         self.messages: List[Dict] = []
+         if self.system_prompt is not None:
+             self.messages.append({"role": "system", "content": self.system_prompt})
+
+     def add_message(self, role: str, message: str):
+         self.messages.append({"role": role, "content": message})
+         self.trim_history()
+
+     def add_user_message(self, message: str):
+         self.add_message("user", message)
+
+     def add_assistant_message(self, message: str):
+         self.add_message("assistant", message)
+
+     def add_function_call(self, message: str):
+         self.add_message("function_call", message)
+
+     def add_function_response(self, message: str):
+         self.add_message("function_response", message)
+
+     def trim_history(self):
+         appendix = 0
+         if self.system_prompt is not None:
+             appendix = 1
+         if self.history_limit is not None and len(self.messages) > self.history_limit + appendix:
+             overflow = len(self.messages) - (self.history_limit + appendix)
+             self.messages = [self.messages[0]] + self.messages[overflow + appendix:]
+
+     def get_messages(self) -> list:
+         return self.messages
+
+
+ def generate(model, tokenizer, prompt, generation_config):
+     data = tokenizer(prompt, return_tensors="pt")
+     data = {k: v.to(model.device) for k, v in data.items()}
+     output_ids = model.generate(**data, generation_config=generation_config)[0]
+     output_ids = output_ids[len(data["input_ids"][0]):]
+     output = tokenizer.decode(output_ids, skip_special_tokens=True)
+     return output.strip()
+
+
+ def get_prompt(tokenizer, messages: List[Dict], add_generation_prompt: bool = False):
+     return tokenizer.apply_chat_template(
+         messages,
+         add_special_tokens=False,
+         tokenize=False,
+         add_generation_prompt=add_generation_prompt,
+     )
+
+
71
+ def chat(
72
+ history_limit: int = 10,
73
+ system_prompt: str | None = SYSTEM_PROMPT,
74
+ max_new_tokens: int = 200,
75
+ repetition_penalty: float = 1.2,
76
+ do_sample: bool = True,
77
+ temperature: float = 0.5,
78
+ top_p: float = 0.6,
79
+ top_k: int = 40,
80
+ ):
81
+ #
82
+ # Tokenizer preparation
83
+ #
84
+
85
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_BASE)
86
+
87
+ #
88
+ # Model preparation
89
+ #
90
+
91
+ # Quantization config
92
+ quantization_config = BitsAndBytesConfig(
93
+ load_in_4bit=True,
94
+ bnb_4bit_compute_dtype=torch.bfloat16,
95
+ bnb_4bit_quant_type="nf4",
96
+ bnb_4bit_use_double_quant=True
97
+ )
98
+
99
+ # Generator config
100
+ generation_config = GenerationConfig.from_pretrained(MODEL_ADAPTER)
101
+ generation_config.max_new_tokens = max_new_tokens
102
+ generation_config.repetition_penalty = repetition_penalty
103
+ generation_config.do_sample = do_sample
104
+ generation_config.temperature = temperature
105
+ generation_config.top_p = top_p
106
+ generation_config.top_k = top_k
107
+
108
+ # Read model from folder with trained checkpoints
109
+ model = AutoModelForCausalLM.from_pretrained(
110
+ MODEL_BASE,
111
+ generation_config=generation_config,
112
+ quantization_config=quantization_config,
113
+ torch_dtype=torch.bfloat16,
114
+ attn_implementation=None
115
+ )
116
+
117
+ # If we've trained a LoRA adapter
118
+ model = PeftModel.from_pretrained(
119
+ model=model,
120
+ model_id=MODEL_ADAPTER,
121
+ torch_dtype=torch.bfloat16,
122
+ )
123
+
124
+ #
125
+ # Chat loop
126
+ #
127
+
128
+ # Start chat loop
129
+ chat_history = ChatHistory(history_limit, system_prompt)
130
+ while True:
131
+ user_message = input("User: ")
132
+
133
+ # Reset chat command
134
+ if user_message.strip() == "/reset":
135
+ chat_history = ChatHistory(history_limit, system_prompt)
136
+ print("History reset completed!")
137
+ continue
138
+
139
+ # Skip empty messages from user
140
+ if user_message.strip() == "":
141
+ continue
142
+
143
+ # Add user message to chat history
144
+ chat_history.add_user_message(user_message)
145
+
146
+ # Get list of messages
147
+ prompt = get_prompt(tokenizer, chat_history.get_messages(), True)
148
+
149
+ # Generate response
150
+ output = generate(model, tokenizer, prompt, generation_config)
151
+
152
+ # Save response to chat history as assistant's message
153
+ chat_history.add_assistant_message(output)
154
+ print("Assistant:", output)
155
+
156
+
157
+ if __name__ == "__main__":
158
+ fire.Fire(chat)
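
For reference, the Russian SYSTEM_PROMPT translates to: "You are an AI assistant. Format your answers as follows: <Thought> Your thoughts (understanding, reasoning) </Thought> <Output> Your answer </Output>".

A minimal sketch of how the ChatHistory trimming above behaves (the values here are illustrative only, not part of the commit): when a system prompt is set, history_limit counts only the non-system messages, and the oldest turns are dropped first.

    history = ChatHistory(history_limit=2, system_prompt="sys")
    history.add_user_message("m1")
    history.add_assistant_message("m2")
    history.add_user_message("m3")
    # The system message is kept; the oldest turn ("m1") is trimmed:
    # [{'role': 'system', ...}, {'role': 'assistant', 'content': 'm2'}, {'role': 'user', 'content': 'm3'}]
    print(history.get_messages())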
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ pyyaml>=6.0.2
+ fire>=0.7.0
+ torch>=2.5.1
+ transformers>=4.47.1
+ peft>=0.14.0
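
A minimal way to try the script (a sketch, assuming a CUDA-capable machine): the 4-bit loading path via BitsAndBytesConfig additionally requires bitsandbytes and accelerate, which are not listed in requirements.txt, and fire exposes the keyword arguments of chat() as command-line flags.

    pip install -r requirements.txt bitsandbytes accelerate
    python chat_transformers.py --temperature 0.5 --max_new_tokens 200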