Updated app.py
app.py CHANGED
@@ -6,27 +6,29 @@ from sklearn.feature_extraction.text import TfidfVectorizer
 import faiss
 import numpy as np
 
-
+# ------------------ Model Setup ------------------
+
 MODEL_REPO = "google/gemma-3-1b-it-qat-q4_0-gguf"
 MODEL_PATH = "./gemma-3-1b-it-qat-q4_0/gemma-3-1b-it-q4_0.gguf"
 MODEL_DIR = "./gemma-3-1b-it-qat-q4_0"
+
 DEFAULT_SYSTEM_PROMPT = (
     "You are a Wise Mentor. Speak in a calm and concise manner. "
     "If asked for advice, give a maximum of 3 actionable steps. "
     "Avoid unnecessary elaboration. Decline unethical or harmful requests."
 )
 
-#
+# Hugging Face Token and model download
 hf_token = os.environ.get("HF_TOKEN")
 if not os.path.exists(MODEL_PATH):
     if not hf_token:
-        raise ValueError("HF_TOKEN is missing. Set it as an environment variable")
-
+        raise ValueError("HF_TOKEN is missing. Set it as an environment variable.")
     login(hf_token)
     snapshot_download(repo_id=MODEL_REPO, local_dir=MODEL_DIR, local_dir_use_symlinks=False)
 
-
-
+# ------------------ RAG Setup ------------------
+
+documents = []  # stores (user, bot) tuples
 vectorizer = TfidfVectorizer()
 index = None
 
@@ -34,40 +36,41 @@ def update_rag_index():
     global index
     if not documents:
         return
-
+    flat_docs = [f"user: {u} bot: {b}" for u, b in documents]
+    vectors = vectorizer.fit_transform(flat_docs).toarray().astype('float32')
     index = faiss.IndexFlatL2(vectors.shape[1])
     index.add(vectors)
 
-def
+def retrieve_relevant_docs(query, k=3):
     if not documents or index is None:
-        return
-
-
-    D, I = index.search(
-    return
+        return []
+    flat_docs = [f"user: {u} bot: {b}" for u, b in documents]
+    query_vec = vectorizer.transform([query]).toarray().astype('float32')
+    D, I = index.search(query_vec, k)
+    return [documents[i] for i in I[0] if i < len(documents)]
 
+# ------------------ Context Estimation ------------------
 
-
-def estimate_n_ctx(full_prompt, buffer = 500):
+def estimate_n_ctx(full_prompt, buffer=500):
     total_tokens = len(full_prompt.split())
-    return min(3500, total_tokens+buffer)
+    return min(3500, total_tokens + buffer)
+
+# ------------------ Chat Function ------------------
 
-#-----------------------CHAT FUNCTION-----------------------
 def chat(user_input, history, system_prompt):
-
-    formatted_turns = "".join([f"<user>{u}</user><bot>{b}</bot>" for u, b in
+    relevant_context = retrieve_relevant_docs(user_input)
+    formatted_turns = "".join([f"<user>{u}</user><bot>{b}</bot>" for u, b in relevant_context])
 
     full_prompt = (
-        f"<s>[INST] <<SYS>>\n{system_prompt}\
+        f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n"
        f"{formatted_turns}<user>{user_input}[/INST]"
     )
 
-
-    n_ctx = estimate_n_ctx(full_prompt=full_prompt)
+    n_ctx = estimate_n_ctx(full_prompt)
 
     llm = Llama(
-        model_path=
-        n_ctx
+        model_path=MODEL_PATH,
+        n_ctx=n_ctx,
         n_threads=2,
         n_batch=128
     )
@@ -75,22 +78,22 @@ def chat(user_input, history, system_prompt):
     output = llm(full_prompt, max_tokens=256, stop=["</s>", "<user>"])
     bot_reply = output["choices"][0]["text"].strip()
 
-    documents.append(
+    documents.append((user_input, bot_reply))
     update_rag_index()
 
     history.append((user_input, bot_reply))
     return "", history
 
-
+# ------------------ UI ------------------
+
 with gr.Blocks() as demo:
-    gr.Markdown("
-
-    system_prompt_box = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=3)
+    gr.Markdown("## 🤖 Prompt-Engineered Persona Agent with Mini-RAG")
+    system_prompt_box = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=3)
     chatbot = gr.Chatbot()
     msg = gr.Textbox(label="Your Message")
-    clear = gr.Button("🗑️ Clear")
+    clear = gr.Button("🗑️ Clear Chat")
 
     msg.submit(chat, [msg, chatbot, system_prompt_box], [msg, chatbot])
     clear.click(lambda: [], None, chatbot)
 
-demo.launch()
+demo.launch()
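
For reference, a minimal standalone sketch of the TF-IDF + FAISS round trip that update_rag_index and retrieve_relevant_docs implement; the stored turns and the query below are invented for illustration, and only faiss-cpu and scikit-learn are assumed installed.

# Standalone sketch of the retrieval path above; the turns and query are
# made-up examples, not part of the app.
import faiss
from sklearn.feature_extraction.text import TfidfVectorizer

documents = [
    ("How do I focus better?", "Silence notifications. Work in short blocks. Rest."),
    ("What should I read first?", "Pick one book. Read daily. Take notes."),
]

flat_docs = [f"user: {u} bot: {b}" for u, b in documents]
vectorizer = TfidfVectorizer()
vectors = vectorizer.fit_transform(flat_docs).toarray().astype("float32")

index = faiss.IndexFlatL2(vectors.shape[1])  # exact (brute-force) L2 search
index.add(vectors)

query_vec = vectorizer.transform(["any tips to focus?"]).toarray().astype("float32")
D, I = index.search(query_vec, 2)  # distances and row indices, each shaped (1, k)
print([documents[i] for i in I[0] if i < len(documents)])

Two properties of this setup worth noting: TfidfVectorizer l2-normalizes its rows by default, so ranking by L2 distance is equivalent to ranking by cosine similarity, and because update_rag_index refits the vectorizer on every turn, the index dimensionality and the transform output stay consistent. Also, estimate_n_ctx counts whitespace-separated words, which tends to undercount subword tokens; the buffer term (or a multiplier on the word count) is what guards against overflowing the context window.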