localsavageai committed
Commit b0c410e · verified · 1 Parent(s): 900c5e5

Upload 2 files

Files changed (2)
  1. app.py +209 -0
  2. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,209 @@
+ import os
+ import logging
+ import numpy as np
+ from typing import List, Optional, Tuple
+ import torch
+ import gradio as gr
+ import spaces
+ from sentence_transformers import SentenceTransformer
+ from langchain_community.vectorstores import FAISS
+ from langchain.embeddings.base import Embeddings
+ from gradio_client import Client
+ import requests
+ from tqdm import tqdm
+
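+ # Note: numpy and requests are not used directly in this module but are kept,
+ # since requirements.txt lists them explicitly.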
+ # Configuration
+ DATA_FILE = "data-mtc.txt"
+ DATABASE_DIR = "semantic_memory"
+ QWEN_API_URL = "Qwen/Qwen2.5-Max-Demo"  # Gradio API for Qwen2.5 chat
+ CHUNK_SIZE = 800
+ TOP_K_RESULTS = 150
+ SIMILARITY_THRESHOLD = 0.4
+
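+ # Note: with the default index built by FAISS.from_texts, the scores returned
+ # by similarity_search_with_score are L2 distances, so lower means more
+ # similar; SIMILARITY_THRESHOLD is therefore an upper bound on distance,
+ # not a cosine-similarity floor.
+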
+ BASE_SYSTEM_PROMPT = """
+ Répondez en français selon ces règles :
+
+ 1. Utilisez EXCLUSIVEMENT le contexte fourni
+ 2. Structurez la réponse en :
+    - Définition principale
+    - Caractéristiques clés (3 points maximum)
+    - Relations avec d'autres concepts
+ 3. Si aucune information pertinente, indiquez-le clairement
+
+ Contexte :
+ {context}
+ """
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s',
+     handlers=[
+         logging.FileHandler("mtc_chat.log"),
+         logging.StreamHandler()
+     ]
+ )
+
+ class LocalEmbeddings(Embeddings):
+     """Local sentence-transformers embeddings"""
+     def __init__(self, model):
+         self.model = model
+
+     def embed_documents(self, texts: List[str]) -> List[List[float]]:
+         embeddings = []
+         for text in tqdm(texts, desc="Creating embeddings"):
+             embeddings.append(self.model.encode(text).tolist())
+         return embeddings
+
+     def embed_query(self, text: str) -> List[float]:
+         return self.model.encode(text).tolist()
+
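+ # Note: SentenceTransformer.encode also accepts a list of texts and batches
+ # them internally, which is usually faster; the per-text loop above trades
+ # speed for a tqdm progress bar.
+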
+ def split_text_into_chunks(text: str) -> List[str]:
+     """Split text into consecutive chunks, breaking at sentence boundaries where possible"""
+     chunks = []
+     start = 0
+     text_length = len(text)
+
+     while start < text_length:
+         end = min(start + CHUNK_SIZE, text_length)
+         chunk = text[start:end]
+
+         # Find the last complete punctuation mark in this window
+         last_punct = max(
+             chunk.rfind('.'),
+             chunk.rfind('!'),
+             chunk.rfind('?'),
+             chunk.rfind('\n\n')
+         )
+
+         # Only cut at punctuation if that keeps the chunk at least half-size;
+         # otherwise a period near the window start would produce tiny fragments
+         if last_punct > CHUNK_SIZE // 2:
+             end = start + last_punct + 1
+
+         stripped = text[start:end].strip()
+         if stripped:
+             chunks.append(stripped)
+         start = end if end > start else start + CHUNK_SIZE
+
+     return chunks
+
+ def initialize_vector_store(embeddings: Embeddings) -> FAISS:
+     """Initialize FAISS vector store"""
+     if os.path.exists(DATABASE_DIR):
+         try:
+             logging.info("Loading existing database...")
+             return FAISS.load_local(
+                 DATABASE_DIR,
+                 embeddings,
+                 allow_dangerous_deserialization=True
+             )
+         except Exception as e:
+             logging.error(f"FAISS load error: {str(e)}")
+             raise
+
+     logging.info("Creating new vector database...")
+     if not os.path.exists(DATA_FILE):
+         raise FileNotFoundError(f"{DATA_FILE} not found")
+
+     try:
+         with open(DATA_FILE, "r", encoding="utf-8") as f:
+             text = f.read()
+
+         chunks = split_text_into_chunks(text)
+         if not chunks:
+             raise ValueError("No valid chunks generated")
+
+         logging.info(f"Creating {len(chunks)} chunks...")
+         vector_store = FAISS.from_texts(chunks, embeddings)
+         vector_store.save_local(DATABASE_DIR)
+         logging.info("Vector store initialized successfully")
+         return vector_store
+
+     except Exception as e:
+         logging.error(f"Initialization failed: {str(e)}")
+         raise
+
+ def generate_response(user_input: str, vector_store: FAISS) -> Optional[str]:
+     """Generate response using Qwen API"""
+     try:
+         # Contextual search
+         docs_scores = vector_store.similarity_search_with_score(
+             user_input,
+             k=TOP_K_RESULTS * 3
+         )
+
+         # Filter results
+         filtered_docs = [
+             (doc, score) for doc, score in docs_scores
+             if score < SIMILARITY_THRESHOLD
+         ]
+         filtered_docs.sort(key=lambda x: x[1])
+
+         if not filtered_docs:
+             return "Aucune correspondance trouvée. Essayez des termes plus spécifiques."
+
+         best_docs = [doc for doc, _ in filtered_docs[:TOP_K_RESULTS]]
+
+         # Build context
+         context = "\n".join(
+             f"=== Source {i+1} ===\n{doc.page_content}\n"
+             for i, doc in enumerate(best_docs)
+         )
+
+         # Call Qwen API
+         client = Client(QWEN_API_URL, verbose=False)
+         response = client.predict(
+             query=user_input,
+             history=[],
+             system=BASE_SYSTEM_PROMPT.format(context=context),
+             api_name="/model_chat"
+         )
+
+         # Extract response
+         if isinstance(response, tuple) and len(response) >= 2:
+             chat_history = response[1]
+             if chat_history and len(chat_history[-1]) >= 2:
+                 return chat_history[-1][1]
+
+         return "Réponse indisponible - Veuillez reformuler votre question."
+
+     except Exception as e:
+         logging.error(f"Generation error: {str(e)}", exc_info=True)
+         return None
+
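+ # Note: the extraction above assumes the remote /model_chat endpoint returns
+ # a tuple whose second element is the chat history as (user, assistant) pairs;
+ # if the Space changes its API shape, the French fallback message is returned
+ # instead of raising.
+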
+ # Initialize models and vector store
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = SentenceTransformer("cnmoro/snowflake-arctic-embed-m-v2.0-cpu", device=device, trust_remote_code=True)
+ embeddings = LocalEmbeddings(model)
+ vector_store = initialize_vector_store(embeddings)
+
+ # Gradio interface
+ @spaces.GPU
+ def embed(document: str):
+     return model.encode(document).tolist()
+
+ def chat_response(message: str, history: List[Tuple[str, str]]):
+     response = generate_response(message, vector_store)
+     response = response or "Erreur de génération - Veuillez réessayer."
+     # Clear the textbox and append the new exchange, matching the
+     # [msg, chatbot] outputs wired to msg.submit below
+     return "", history + [(message, response)]
+
+ with gr.Blocks() as app:
+     gr.Markdown("# MTC Knowledge Assistant")
+
+     with gr.Tab("Embeddings"):
+         gr.Markdown("## Text Embedding Demo")
+         text_input = gr.Textbox(label="Enter text to embed")
+         output = gr.JSON(label="Embedding Vector")
+         text_input.submit(embed, inputs=text_input, outputs=output)
+
+     with gr.Tab("MTC Chat"):
+         gr.Markdown("## Posez vos questions sur la médecine traditionnelle chinoise")
+         chatbot = gr.Chatbot(height=500)
+         msg = gr.Textbox(label="Votre question")
+         clear = gr.ClearButton([msg, chatbot])
+
+         msg.submit(
+             chat_response,
+             inputs=[msg, chatbot],
+             outputs=[msg, chatbot],
+             queue=True
+         )
+
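+ # For reference, a deployed Space can be queried with gradio_client, e.g.
+ # (hypothetical Space id; Gradio names the endpoint after the function):
+ #   from gradio_client import Client
+ #   c = Client("localsavageai/SPACE_NAME")
+ #   vec = c.predict("texte à encoder", api_name="/embed")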
+ if __name__ == "__main__":
+     app.launch(server_name="0.0.0.0", server_port=7860)
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ gradio>=5.23.2
+ sentence-transformers
+ torch
+ langchain
+ langchain-community
+ faiss-cpu
+ gradio-client
+ tqdm
+ requests
+ numpy
+ einops==0.7.0