localsavageai committed · Commit 943806d · verified · 1 parent: 78b2249

Upload 2 files

Files changed (2):
  1. app.py +88 -140
  2. requirements.txt +4 -8
app.py CHANGED
@@ -1,23 +1,17 @@
import os
import logging
import numpy as np
- from typing import List, Optional, Tuple
- import torch
- import gradio as gr
- import spaces
- from sentence_transformers import SentenceTransformer
+ from typing import List, Optional
from langchain_community.vectorstores import FAISS
from langchain.embeddings.base import Embeddings
from gradio_client import Client
- import requests
- from tqdm import tqdm
+ import gradio as gr

# Configuration
- DATA_FILE = "data-mtc.txt"
- DATABASE_DIR = "semantic_memory"
- QWEN_API_URL = "Qwen/Qwen2.5-Max-Demo"  # Gradio API for Qwen2.5 chat
+ DATA_FILE = "data-mtc.txt"  # This file is no longer used in the Space
+ DATABASE_DIR = "."  # Database files are in the root directory
CHUNK_SIZE = 800
- TOP_K_RESULTS = 150
+ TOP_K_RESULTS = 100
SIMILARITY_THRESHOLD = 0.4

BASE_SYSTEM_PROMPT = """
@@ -34,176 +28,130 @@ Contexte :
{context}
"""

- # Configure logging
+ # Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
-         logging.FileHandler("mtc_chat.log"),
-         logging.StreamHandler()
+         logging.StreamHandler()  # Output to console in the Space
    ]
)

- class LocalEmbeddings(Embeddings):
-     """Local sentence-transformers embeddings"""
-     def __init__(self, model):
-         self.model = model
-
-     def embed_documents(self, texts: List[str]) -> List[List[float]]:
-         embeddings = []
-         for text in tqdm(texts, desc="Creating embeddings"):
-             embeddings.append(self.model.encode(text).tolist())
-         return embeddings
-
-     def embed_query(self, text: str) -> List[float]:
-         return self.model.encode(text).tolist()
-
- def split_text_into_chunks(text: str) -> List[str]:
-     """Split text with overlap and sentence preservation"""
-     chunks = []
-     start = 0
-     text_length = len(text)
-
-     while start < text_length:
-         end = min(start + CHUNK_SIZE, text_length)
-         chunk = text[start:end]
-
-         # Find last complete punctuation
-         last_punct = max(
-             chunk.rfind('.'),
-             chunk.rfind('!'),
-             chunk.rfind('?'),
-             chunk.rfind('\n\n')
-         )
-
-         if last_punct != -1 and (end - start) > CHUNK_SIZE//2:
-             end = start + last_punct + 1
-
-         chunks.append(text[start:end].strip())
-         start = end if end > start else start + CHUNK_SIZE
-
-     return chunks
-
- def initialize_vector_store(embeddings: Embeddings) -> FAISS:
-     """Initialize FAISS vector store"""
-     if os.path.exists(DATABASE_DIR):
-         try:
-             logging.info("Loading existing database...")
-             return FAISS.load_local(
-                 DATABASE_DIR,
-                 embeddings,
-                 allow_dangerous_deserialization=True
-             )
-         except Exception as e:
-             logging.error(f"FAISS load error: {str(e)}")
-             raise
-
-     logging.info("Creating new vector database...")
-     if not os.path.exists(DATA_FILE):
-         raise FileNotFoundError(f"{DATA_FILE} not found")
-
-     try:
-         with open(DATA_FILE, "r", encoding="utf-8") as f:
-             text = f.read()
-
-         chunks = split_text_into_chunks(text)
-         if not chunks:
-             raise ValueError("No valid chunks generated")
-
-         logging.info(f"Creating {len(chunks)} chunks...")
-         vector_store = FAISS.from_texts(chunks, embeddings)
-         vector_store.save_local(DATABASE_DIR)
-         logging.info("Vector store initialized successfully")
-         return vector_store
-
-     except Exception as e:
-         logging.error(f"Initialization failed: {str(e)}")
-         raise
-
+
+ class GradioEmbeddings(Embeddings):
+     """Embedding management using Gradio API"""
+
+     def __init__(self):
+         super().__init__()
+         self.client = Client("localsavageai/embijiji3")
+
+     def _generate_embedding(self, text: str) -> np.ndarray:
+         """Generate an embedding via the Gradio API"""
+         try:
+             result = self.client.predict(
+                 document=text.strip(),
+                 api_name="/embed"
+             )
+             if not isinstance(result, list):
+                 raise ValueError("Invalid embedding response from Gradio API")
+             return np.array(result, dtype=np.float32)
+         except Exception as e:
+             logging.error(f"Embedding error: {str(e)}")
+             raise RuntimeError("Failed to generate embedding") from e
+
+     def embed_documents(self, texts: List[str]) -> List[List[float]]:
+         return [self._generate_embedding(text).tolist() for text in texts]
+
+     def embed_query(self, text: str) -> List[float]:
+         return self._generate_embedding(text).tolist()
+
+
+ def initialize_vector_store() -> FAISS:
+     """Robust initialization of the vector store"""
+     embeddings = GradioEmbeddings()
+
+     try:
+         logging.info("Loading existing database...")
+         return FAISS.load_local(
+             DATABASE_DIR,
+             embeddings,
+             allow_dangerous_deserialization=True
+         )
+     except Exception as e:
+         logging.error(f"FAISS loading error: {str(e)}")
+         raise
+
+
def generate_response(user_input: str, vector_store: FAISS) -> Optional[str]:
-     """Generate response using Qwen API"""
+     """Generate a response with complete error handling"""
    try:
-         # Contextual search
        docs_scores = vector_store.similarity_search_with_score(
            user_input,
-             k=TOP_K_RESULTS*3
+             k=TOP_K_RESULTS * 3
        )

-         # Filter results
        filtered_docs = [
            (doc, score) for doc, score in docs_scores
            if score < SIMILARITY_THRESHOLD
        ]
        filtered_docs.sort(key=lambda x: x[1])

        if not filtered_docs:
-             return "Aucune correspondance trouvée. Essayez des termes plus spécifiques."
+             return ("No matches found in MTC texts. "
+                     "Try using more specific terms.")

        best_docs = [doc for doc, _ in filtered_docs[:TOP_K_RESULTS]]

-         # Build context
        context = "\n".join(
-             f"=== Source {i+1} ===\n{doc.page_content}\n"
+             f"=== Source {i + 1} ===\n{doc.page_content}\n"
            for i, doc in enumerate(best_docs)
        )

-         # Call Qwen API
-         client = Client(QWEN_API_URL, verbose=False)
-         response = client.predict(
+         response = Client("Qwen/Qwen2.5-Max-Demo").predict(
            query=user_input,
            history=[],
            system=BASE_SYSTEM_PROMPT.format(context=context),
            api_name="/model_chat"
        )

-         # Extract response
        if isinstance(response, tuple) and len(response) >= 2:
            chat_history = response[1]
-             if chat_history and len(chat_history[-1]) >= 2:
-                 return chat_history[-1][1]
+             if isinstance(chat_history, list) and len(chat_history) > 0:
+                 last_message = chat_history[-1]
+                 if isinstance(last_message, (list, tuple)) and len(last_message) >= 2:
+                     return last_message[1]

-         return "Réponse indisponible - Veuillez reformuler votre question."
+         return "Response unavailable - Please rephrase your question."

    except Exception as e:
        logging.error(f"Generation error: {str(e)}", exc_info=True)
-         return None
+         return "An error occurred while generating the response."

- # Initialize models and vector store
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- model = SentenceTransformer("cnmoro/snowflake-arctic-embed-m-v2.0-cpu", device=device, trust_remote_code=True)
- embeddings = LocalEmbeddings(model)
- vector_store = initialize_vector_store(embeddings)
-
- # Gradio interface
- @spaces.GPU
- def embed(document: str):
-     return model.encode(document).tolist()
-
- def chat_response(message: str, history: List[Tuple[str, str]]):
-     response = generate_response(message, vector_store)
-     return response or "Erreur de génération - Veuillez réessayer."
-
- with gr.Blocks() as app:
-     gr.Markdown("# MTC Knowledge Assistant")
-
-     with gr.Tab("Embeddings"):
-         gr.Markdown("## Text Embedding Demo")
-         text_input = gr.Textbox(label="Enter text to embed")
-         output = gr.JSON(label="Embedding Vector")
-         text_input.submit(embed, inputs=text_input, outputs=output)
-
-     with gr.Tab("MTC Chat"):
-         gr.Markdown("## Posez vos questions sur la médecine traditionnelle chinoise")
-         chatbot = gr.Chatbot(height=500)
-         msg = gr.Textbox(label="Votre question")
-         clear = gr.ClearButton([msg, chatbot])
-
-         msg.submit(
-             chat_response,
-             inputs=[msg, chatbot],
-             outputs=[msg, chatbot],
-             queue=True
-         )

+ def chatbot(query):
+     """Main function to run the chatbot"""
+     try:
+         vs = initialize_vector_store()
+         response = generate_response(query, vs)
+         return response or "No response generated."
+     except Exception as e:
+         logging.error(f"Chatbot error: {str(e)}")
+         return f"An error occurred: {str(e)}"
+
+
+ # Gradio Interface
if __name__ == "__main__":
-     app.launch(server_name="0.0.0.0", server_port=7860)
+     try:
+         interface = gr.Interface(
+             fn=chatbot,
+             inputs=gr.Textbox(lines=7, placeholder="Enter your query here..."),
+             outputs=gr.Textbox(lines=7, placeholder="Response from MTC will appear here..."),
+             title="MTC Chatbot",
+             description="Ask questions about MTC and get answers based on the provided data."
+         )
+         interface.launch()
+
+     except Exception as e:
+         logging.critical(f"CRITICAL ERROR: {str(e)}")
+         print("Failed to launch Gradio interface. Check logs.")
+
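With this commit the Space no longer builds the index at startup: initialize_vector_store() only loads index.faiss / index.pkl from the repository root, so those files must be uploaded alongside app.py. A minimal sketch of how they could be regenerated offline with the same embedding client (the script name, the naive fixed-size chunking, and the local copy of data-mtc.txt are assumptions, not part of this commit):

# build_index.py - hypothetical offline helper, not shipped in this commit.
# Reuses GradioEmbeddings, CHUNK_SIZE and DATA_FILE from the new app.py and
# writes index.faiss / index.pkl into the repo root, where DATABASE_DIR = "."
# expects to find them.
from langchain_community.vectorstores import FAISS
from app import GradioEmbeddings, CHUNK_SIZE, DATA_FILE

with open(DATA_FILE, "r", encoding="utf-8") as f:
    text = f.read()

# Naive fixed-size chunking; the removed split_text_into_chunks additionally
# tried to cut chunks on sentence boundaries.
chunks = [text[i:i + CHUNK_SIZE].strip() for i in range(0, len(text), CHUNK_SIZE)]
chunks = [c for c in chunks if c]

store = FAISS.from_texts(chunks, GradioEmbeddings())  # one /embed call per chunk
store.save_local(".")  # produces index.faiss and index.pkl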
requirements.txt CHANGED
@@ -1,11 +1,7 @@
- gradio>=5.23.2
- sentence-transformers
- torch
langchain
- langchain-community
+ langchain_community
faiss-cpu
- gradio-client
- tqdm
- requests
+ gradio
+ gradio_client
numpy
- einops==0.7.0
+
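Once the index files are in place, a quick end-to-end check of the trimmed dependency set could look like this (hypothetical snippet; it calls the public localsavageai/embijiji3 and Qwen/Qwen2.5-Max-Demo Spaces, so it needs network access):

# smoke_test.py - hypothetical check, not shipped in this commit.
# Loads the root-level FAISS index and round-trips one query through the
# embedding Space and the Qwen chat endpoint.
from app import initialize_vector_store, generate_response

vs = initialize_vector_store()  # reads index.faiss / index.pkl from "."
print(generate_response("Qu'est-ce que le Qi ?", vs))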