localsavageai commited on
Commit
26ce4aa
·
verified ·
1 Parent(s): 9b574d9

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +34 -88
  2. requirements.txt +3 -1
app.py CHANGED
@@ -5,8 +5,8 @@ import numpy as np
5
  from typing import List, Optional
6
  from langchain_community.vectorstores import FAISS
7
  from langchain.embeddings.base import Embeddings
8
- from gradio_client import Client
9
  import gradio as gr
 
10
 
11
  # Configuration
12
  DATA_FILE = "data-mtc.txt" # This file is no longer used in the Space
@@ -38,30 +38,17 @@ logging.basicConfig(
38
  ]
39
  )
40
 
41
- # Example Questions Pool
42
- EXAMPLE_QUESTIONS = [
43
- "Comment intégrer les enseignements MTC dans la vie quotidienne ?",
44
- "Comment se préparer à une discussion de groupe MTC ?",
45
- "Quels sont les obstacles courants à la compréhension des Chroniques ?"
46
- ]
47
 
48
- class GradioEmbeddings(Embeddings):
49
- """Embedding management using Gradio API"""
50
-
51
- def __init__(self):
52
- super().__init__()
53
- self.client = Client("localsavageai/embijiji3")
54
 
55
  def _generate_embedding(self, text: str) -> np.ndarray:
56
- """Generate an embedding via the Gradio API"""
57
  try:
58
- result = self.client.predict(
59
- document=text.strip(),
60
- api_name="/embed"
61
- )
62
- if not isinstance(result, list):
63
- raise ValueError("Invalid embedding response from Gradio API")
64
- return np.array(result, dtype=np.float32)
65
  except Exception as e:
66
  logging.error(f"Embedding error: {str(e)}")
67
  raise RuntimeError("Failed to generate embedding") from e
@@ -75,7 +62,7 @@ class GradioEmbeddings(Embeddings):
75
 
76
  def initialize_vector_store() -> FAISS:
77
  """Robust initialization of the vector store"""
78
- embeddings = GradioEmbeddings()
79
 
80
  try:
81
  logging.info("Loading existing database...")
@@ -114,21 +101,28 @@ def generate_response(user_input: str, vector_store: FAISS) -> Optional[str]:
114
  for i, doc in enumerate(best_docs)
115
  )
116
 
117
- response = Client("Qwen/Qwen2.5-Max-Demo").predict(
118
- query=user_input,
119
- history=[],
120
- system=BASE_SYSTEM_PROMPT.format(context=context),
121
- api_name="/model_chat"
122
- )
123
 
124
- if isinstance(response, tuple) and len(response) >= 2:
125
- chat_history = response[1]
126
- if isinstance(chat_history, list) and len(chat_history) > 0:
127
- last_message = chat_history[-1]
128
- if isinstance(last_message, (list, tuple)) and len(last_message) >= 2:
129
- return last_message[1]
 
 
 
 
130
 
131
- return "Réponse indisponible - Veuillez reformuler votre question."
 
 
 
 
 
 
 
 
132
 
133
  except Exception as e:
134
  logging.error(f"Generation error: {str(e)}", exc_info=True)
@@ -146,66 +140,18 @@ def chatbot(query):
146
  return f"Une erreur s'est produite : {str(e)}"
147
 
148
 
149
- # Rotating Example Questions Functionality
150
- def get_random_questions():
151
- """Selects three random example questions"""
152
- return random.sample(EXAMPLE_QUESTIONS, 3)
153
-
154
  # Gradio Interface Setup with Enhanced UI
155
  with gr.Blocks(title="MTC Chatbot") as demo:
156
  gr.Markdown("# Apprenez-en plus sur le savoir MTC!")
157
 
158
  chatbot_ui = gr.Chatbot(label="MTC Assistant", type="messages")
159
 
160
- with gr.Row():
161
- input_box = gr.Textbox(
162
- placeholder="Posez votre question ici...",
163
- label="Votre question"
164
- )
165
 
166
- def respond(message, history):
167
- vs = initialize_vector_store()
168
- response = generate_response(message, vs)
169
-
170
- history.append({"role": "user", "content": message})
171
- history.append({"role": "assistant", "content": response})
172
-
173
- # After every interaction, get new random questions
174
- example_questions = get_random_questions()
175
-
176
- # Recreate the buttons with new questions
177
- example_buttons = []
178
- for question in example_questions:
179
- btn = gr.Button(question)
180
- btn.click(
181
- process_example_click,
182
- inputs=[gr.Textbox(value=question, visible=False), chatbot_ui],
183
- outputs=chatbot_ui
184
- )
185
- example_buttons.append(btn)
186
-
187
- return history
188
-
189
- def process_example_click(example_query, history):
190
- response = chatbot(example_query)
191
- history.append({"role": "user", "content": example_query})
192
- history.append({"role": "assistant", "content": response})
193
- return history
194
-
195
- # Initial example questions
196
- example_questions = get_random_questions()
197
- with gr.Row():
198
- example_buttons = []
199
- for question in example_questions:
200
- btn = gr.Button(question)
201
- btn.click(
202
- process_example_click,
203
- inputs=[gr.Textbox(value=question, visible=False), chatbot_ui],
204
- outputs=chatbot_ui
205
- )
206
- example_buttons.append(btn)
207
-
208
- input_box.submit(respond, [input_box, chatbot_ui], chatbot_ui)
209
 
210
  if __name__ == "__main__":
211
  demo.launch()
 
5
  from typing import List, Optional
6
  from langchain_community.vectorstores import FAISS
7
  from langchain.embeddings.base import Embeddings
 
8
  import gradio as gr
9
+ from sentence_transformers import SentenceTransformer
10
 
11
  # Configuration
12
  DATA_FILE = "data-mtc.txt" # This file is no longer used in the Space
 
38
  ]
39
  )
40
 
41
+ # Embedding Model Integration
42
+ device = torch.device("cpu")
43
+ embedding_model = SentenceTransformer("Snowflake/snowflake-arctic-embed-l", device=device, trust_remote_code=True)
 
 
 
44
 
45
+ class HuggingFaceEmbeddings(Embeddings):
46
+ """Embedding management using Hugging Face SentenceTransformer"""
 
 
 
 
47
 
48
  def _generate_embedding(self, text: str) -> np.ndarray:
49
+ """Generate an embedding via the Hugging Face model"""
50
  try:
51
+ return np.array(embedding_model.encode(text.strip()), dtype=np.float32)
 
 
 
 
 
 
52
  except Exception as e:
53
  logging.error(f"Embedding error: {str(e)}")
54
  raise RuntimeError("Failed to generate embedding") from e
 
62
 
63
  def initialize_vector_store() -> FAISS:
64
  """Robust initialization of the vector store"""
65
+ embeddings = HuggingFaceEmbeddings()
66
 
67
  try:
68
  logging.info("Loading existing database...")
 
101
  for i, doc in enumerate(best_docs)
102
  )
103
 
104
+ from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
 
105
 
106
+ model_name = "Qwen/Qwen2.5-72B-Instruct"
107
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
108
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
109
+
110
+ prompt = BASE_SYSTEM_PROMPT.format(context=context)
111
+
112
+ messages = [
113
+ {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
114
+ {"role": "user", "content": user_input}
115
+ ]
116
 
117
+ text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
118
+
119
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
120
+
121
+ generated_ids = model.generate(**model_inputs, max_new_tokens=512)
122
+
123
+ response = tokenizer.batch_decode(generated_ids[:, model_inputs.input_ids.shape[-1]:], skip_special_tokens=True)
124
+
125
+ return response[0] if response else "Réponse indisponible - Veuillez reformuler votre question."
126
 
127
  except Exception as e:
128
  logging.error(f"Generation error: {str(e)}", exc_info=True)
 
140
  return f"Une erreur s'est produite : {str(e)}"
141
 
142
 
 
 
 
 
 
143
  # Gradio Interface Setup with Enhanced UI
144
  with gr.Blocks(title="MTC Chatbot") as demo:
145
  gr.Markdown("# Apprenez-en plus sur le savoir MTC!")
146
 
147
  chatbot_ui = gr.Chatbot(label="MTC Assistant", type="messages")
148
 
149
+ input_box = gr.Textbox(
150
+ placeholder="Posez votre question ici...",
151
+ label="Votre question"
152
+ )
 
153
 
154
+ input_box.submit(chatbot, inputs=input_box, outputs=chatbot_ui)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  if __name__ == "__main__":
157
  demo.launch()
requirements.txt CHANGED
@@ -4,4 +4,6 @@ faiss-cpu
4
  gradio
5
  gradio_client
6
  numpy
7
-
 
 
 
4
  gradio
5
  gradio_client
6
  numpy
7
+ sentence_transformers
8
+ einops
9
+ torch