ujwal55 commited on
Commit
b53f983
·
1 Parent(s): 527ef02

Updated app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -31
app.py CHANGED
@@ -6,27 +6,29 @@ from sklearn.feature_extraction.text import TfidfVectorizer
6
  import faiss
7
  import numpy as np
8
 
9
# -------------------- MODEL SETUP --------------------
# Local paths for the GGUF-quantized Gemma model pulled from the Hub.
MODEL_REPO = "google/gemma-3-1b-it-qat-q4_0-gguf"
MODEL_PATH = "./gemma-3-1b-it-qat-q4_0/gemma-3-1b-it-q4_0.gguf"
MODEL_DIR = "./gemma-3-1b-it-qat-q4_0"

# Persona instructions used as the default system message for every prompt.
DEFAULT_SYSTEM_PROMPT = (
    "You are a Wise Mentor. Speak in a calm and concise manner. "
    "If asked for advice, give a maximum of 3 actionable steps. "
    "Avoid unnecessary elaboration. Decline unethical or harmful requests."
)

# Hugging Face token and one-time model download (skipped when the file exists).
hf_token = os.environ.get("HF_TOKEN")
if not os.path.exists(MODEL_PATH):
    if not hf_token:
        raise ValueError("HF_TOKEN is missing. Set it as an environment variable")
    login(hf_token)
    # NOTE(review): local_dir_use_symlinks is deprecated in recent
    # huggingface_hub releases -- confirm the pinned version still accepts it.
    snapshot_download(repo_id=MODEL_REPO, local_dir=MODEL_DIR, local_dir_use_symlinks=False)

# -------------------- RAG SETUP ------------------------
documents = []  # stores all chat turns as "user: ... bot: ..." strings
vectorizer = TfidfVectorizer()
index = None  # FAISS index; (re)built by update_rag_index()
 
def update_rag_index():
    """Refit the TF-IDF vectorizer and rebuild the FAISS L2 index over `documents`."""
    global index
    if not documents:
        return  # nothing to index yet
    doc_matrix = vectorizer.fit_transform(documents).toarray().astype('float32')
    new_index = faiss.IndexFlatL2(doc_matrix.shape[1])
    new_index.add(doc_matrix)
    index = new_index
 
41
def retrive_relvant_docs(query, k=2):
    """Return up to `k` stored chat-turn strings most similar to `query`.

    Hits are joined with newlines; returns "" when nothing is indexed yet.
    (Function name kept as-is, typo included, so existing callers still work.)
    """
    if not documents or index is None:
        return ""
    query_vec = vectorizer.transform([query]).toarray().astype('float32')
    _, neighbor_ids = index.search(query_vec, k)
    # BUG FIX: FAISS pads results with -1 when k exceeds the index size; the old
    # `i < len(documents)` check let -1 through and silently returned the *last*
    # document. Require 0 <= i as well.
    return "\n".join(documents[i] for i in neighbor_ids[0] if 0 <= i < len(documents))
48
 
 
49
 
50
# ----------------------- CONTEXT LENGTH ESTIMATION ---------------------
def estimate_n_ctx(full_prompt, buffer=500):
    """Rough context-window size: whitespace word count plus a safety buffer, capped at 3500."""
    word_count = len(full_prompt.split())
    return min(3500, word_count + buffer)
 
 
54
 
55
# ----------------------- CHAT FUNCTION -----------------------
def chat(user_input, history, system_prompt):
    """Run one chat turn: retrieve context, prompt the LLM, store the turn.

    Returns ("", history) so the Gradio textbox is cleared and the chatbot
    component is refreshed with the appended (user, bot) pair.
    """
    # Retrieved context is a single newline-joined string (see retrive_relvant_docs).
    relevent_context = retrive_relvant_docs(user_input)

    # BUG FIX: the old code iterated `relevent_context` (a str) as (u, b) pairs,
    # which raises ValueError as soon as any context is retrieved. Prior turns
    # actually live in `history` as (user, bot) tuples -- format those instead.
    formatted_turns = "".join(f"<user>{u}</user><bot>{b}</bot>" for u, b in history)

    full_prompt = (
        f"<s>[INST] <<SYS>>\n{system_prompt}\nContext:\n{relevent_context}\n<</SYS>>\n"
        f"{formatted_turns}<user>{user_input}[/INST]"
    )

    # Size the context window to this prompt instead of always paying for the max.
    n_ctx = estimate_n_ctx(full_prompt=full_prompt)

    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=n_ctx,
        n_threads=2,
        n_batch=128
    )

    output = llm(full_prompt, max_tokens=256, stop=["</s>", "<user>"])
    bot_reply = output["choices"][0]["text"].strip()

    # Persist the turn for future retrieval, then rebuild the index.
    documents.append(f"user: {user_input} bot: {bot_reply}")
    update_rag_index()

    history.append((user_input, bot_reply))
    return "", history
83
 
84
# ----------------------- UI ---------------------
with gr.Blocks() as demo:
    # BUG FIX: user-facing emoji were mojibake (UTF-8 decoded as Latin-1,
    # e.g. "πŸ€–"); restored to the intended characters.
    gr.Markdown("# 🤖 Persona Agent with Mini-RAG + Dynamic Context")
    with gr.Row():
        system_prompt_box = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=3)
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Your Message")
    clear = gr.Button("🗑️ Clear")

    # Submitting the textbox runs one chat turn; the button clears the display only
    # (the RAG `documents` memory is intentionally left untouched).
    msg.submit(chat, [msg, chatbot, system_prompt_box], [msg, chatbot])
    clear.click(lambda: [], None, chatbot)

demo.launch()
 
6
  import faiss
7
  import numpy as np
8
 
9
# ------------------ Model Setup ------------------

# Hub repo and local paths for the GGUF-quantized Gemma model.
MODEL_REPO = "google/gemma-3-1b-it-qat-q4_0-gguf"
MODEL_PATH = "./gemma-3-1b-it-qat-q4_0/gemma-3-1b-it-q4_0.gguf"
MODEL_DIR = "./gemma-3-1b-it-qat-q4_0"

# Default persona used as the system message.
DEFAULT_SYSTEM_PROMPT = (
    "You are a Wise Mentor. Speak in a calm and concise manner. "
    "If asked for advice, give a maximum of 3 actionable steps. "
    "Avoid unnecessary elaboration. Decline unethical or harmful requests."
)

# Hugging Face token and model download (only when the GGUF file is absent).
hf_token = os.environ.get("HF_TOKEN")
if not os.path.exists(MODEL_PATH):
    if not hf_token:
        raise ValueError("HF_TOKEN is missing. Set it as an environment variable.")
    login(hf_token)
    # NOTE(review): local_dir_use_symlinks is deprecated in recent
    # huggingface_hub releases -- confirm the pinned version still accepts it.
    snapshot_download(repo_id=MODEL_REPO, local_dir=MODEL_DIR, local_dir_use_symlinks=False)

# ------------------ RAG Setup ------------------

documents = []  # stores (user, bot) tuples
vectorizer = TfidfVectorizer()
index = None  # FAISS index; (re)built by update_rag_index()
34
 
 
def update_rag_index():
    """Flatten stored (user, bot) turns and rebuild the FAISS L2 index over them."""
    global index
    if not documents:
        return  # nothing to index yet
    corpus = [f"user: {u} bot: {b}" for u, b in documents]
    matrix = vectorizer.fit_transform(corpus).toarray().astype('float32')
    index = faiss.IndexFlatL2(matrix.shape[1])
    index.add(matrix)
43
 
44
def retrieve_relevant_docs(query, k=3):
    """Return up to `k` stored (user, bot) turns most similar to `query`.

    Returns [] when nothing has been indexed yet.
    """
    if not documents or index is None:
        return []
    # Removed the dead `flat_docs` list the old body built and never used.
    query_vec = vectorizer.transform([query]).toarray().astype('float32')
    _, neighbor_ids = index.search(query_vec, k)
    # BUG FIX: FAISS pads results with -1 when k exceeds the index size;
    # `i < len(documents)` alone let -1 through and silently returned the
    # *last* turn. Require 0 <= i as well.
    return [documents[i] for i in neighbor_ids[0] if 0 <= i < len(documents)]
51
 
52
# ------------------ Context Estimation ------------------

def estimate_n_ctx(full_prompt, buffer=500):
    """Approximate the needed context window: word count plus a buffer, capped at 3500."""
    return min(3500, len(full_prompt.split()) + buffer)
57
+
58
# ------------------ Chat Function ------------------

def chat(user_input, history, system_prompt):
    """Run one chat turn: retrieve similar past turns, query the model, record the exchange.

    Returns ("", history): clears the Gradio textbox and refreshes the chatbot.
    """
    # Similar prior (user, bot) turns, formatted as pseudo-tagged dialogue.
    past_turns = retrieve_relevant_docs(user_input)
    formatted_turns = "".join(f"<user>{u}</user><bot>{b}</bot>" for u, b in past_turns)

    full_prompt = (
        f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n"
        f"{formatted_turns}<user>{user_input}[/INST]"
    )

    # Size the context window to this prompt rather than always using the max.
    n_ctx = estimate_n_ctx(full_prompt)

    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=n_ctx,
        n_threads=2,
        n_batch=128
    )

    output = llm(full_prompt, max_tokens=256, stop=["</s>", "<user>"])
    bot_reply = output["choices"][0]["text"].strip()

    # Persist the turn for future retrieval, then rebuild the index.
    documents.append((user_input, bot_reply))
    update_rag_index()

    history.append((user_input, bot_reply))
    return "", history
86
 
87
# ------------------ UI ------------------

with gr.Blocks() as demo:
    # BUG FIX: user-facing emoji were mojibake (UTF-8 decoded as Latin-1,
    # e.g. "πŸ€–"); restored to the intended characters.
    gr.Markdown("## 🤖 Prompt-Engineered Persona Agent with Mini-RAG")
    system_prompt_box = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=3)

    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Your Message")
    clear = gr.Button("🗑️ Clear Chat")

    # Submitting the textbox runs one chat turn; the button clears the display only
    # (the RAG `documents` memory is intentionally left untouched).
    msg.submit(chat, [msg, chatbot, system_prompt_box], [msg, chatbot])
    clear.click(lambda: [], None, chatbot)

demo.launch()