priyanshu23456 committed
Commit efe75c0 · verified · 1 Parent(s): 6a4c2d1

Update app.py

Files changed (1)
app.py +59 -58
app.py CHANGED
@@ -12,6 +12,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from sentence_transformers import SentenceTransformer
 import faiss
 import numpy as np
+import google.generativeai as genai
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -31,14 +32,26 @@ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
 
 # Global model variables
 embeddings_model = None
+gemini_model = None
 
 def initialize_models():
-    """Initialize lightweight embedding model only"""
-    global embeddings_model
+    """Initialize embedding model and Gemini API"""
+    global embeddings_model, gemini_model
     try:
-        logger.info("Initializing embedding model...")
+        logger.info("Initializing models...")
 
-        # Load small embeddings model (only 22MB)
+        # Get Gemini API key from environment
+        gemini_api_key = os.environ.get("GEMINI_API_KEY")
+        if not gemini_api_key:
+            logger.error("GEMINI_API_KEY not found in environment variables!")
+            return False
+
+        # Configure Gemini
+        genai.configure(api_key=gemini_api_key)
+        gemini_model = genai.GenerativeModel('gemini-2.0-flash-exp')
+        logger.info("Gemini API configured successfully")
+
+        # Load embeddings model (only 22MB!)
         logger.info("Loading all-MiniLM-L6-v2...")
         embeddings_model = SentenceTransformer(
             "all-MiniLM-L6-v2",
@@ -46,11 +59,11 @@ def initialize_models():
             cache_folder="/tmp"
         )
 
-        logger.info("Model initialized successfully")
+        logger.info("Models initialized successfully")
         return True
 
     except Exception as e:
-        logger.error(f"Error initializing model: {str(e)}")
+        logger.error(f"Error initializing models: {str(e)}")
         import traceback
         traceback.print_exc()
         return False
@@ -66,7 +79,6 @@ def load_pdf(filepath: str) -> List[str]:
             logger.warning("No pages extracted from PDF")
             return []
 
-        # Combine page content
         docs = [page.page_content for page in pages if page.page_content.strip()]
         logger.info(f"Loaded {len(pages)} pages")
         return docs
@@ -82,7 +94,6 @@ def create_faiss_index(chunks: List[str]):
     try:
         logger.info(f"Creating FAISS index for {len(chunks)} chunks")
 
-        # Encode chunks in batches to save memory
         batch_size = 32
         embeddings_list = []
 
@@ -93,14 +104,12 @@ def create_faiss_index(chunks: List[str]):
 
         embeddings = np.vstack(embeddings_list).astype('float32')
 
-        # Create FAISS index
         dim = embeddings.shape[1]
         index = faiss.IndexFlatL2(dim)
         index.add(embeddings)
 
         logger.info(f"FAISS index created with dimension {dim}")
 
-        # Clean up
         del embeddings_list
         gc.collect()
 
@@ -112,19 +121,16 @@ def create_faiss_index(chunks: List[str]):
         traceback.print_exc()
         raise
 
-def retrieve_context(question: str, chunks: List[str], index, k: int = 3) -> str:
+def retrieve_context(question: str, chunks: List[str], index, k: int = 5) -> str:
     """Retrieve relevant context for question"""
     try:
-        # Encode question
         q_embedding = embeddings_model.encode([question])
         q_embedding = np.array(q_embedding).astype('float32')
 
-        # Search index
         distances, indices = index.search(q_embedding, k)
 
-        # Get relevant chunks with distances
         relevant_chunks = []
-        for i, dist in zip(indices[0], distances[0]):
+        for i in indices[0]:
             if i < len(chunks):
                 relevant_chunks.append(chunks[i])
 
@@ -137,37 +143,37 @@ def retrieve_context(question: str, chunks: List[str], index, k: int = 3) -> str
         logger.error(f"Error retrieving context: {str(e)}")
         return ""
 
-def generate_extractive_answer(question: str, context: str) -> str:
-    """Generate extractive answer from context (no LLM needed)"""
+def generate_answer_with_gemini(question: str, context: str) -> str:
+    """Generate answer using Gemini API"""
     try:
-        # Simple extractive approach: return most relevant sentences
-        sentences = context.split('.')
-
-        # Score sentences by keyword overlap with question
-        question_words = set(question.lower().split())
-        scored_sentences = []
-
-        for sent in sentences:
-            if len(sent.strip()) > 20:  # Only consider substantial sentences
-                sent_words = set(sent.lower().split())
-                score = len(question_words & sent_words)
-                scored_sentences.append((score, sent.strip()))
-
-        # Sort by score and get top sentences
-        scored_sentences.sort(reverse=True)
-        top_sentences = [sent for score, sent in scored_sentences[:3] if score > 0]
-
-        if top_sentences:
-            answer = ". ".join(top_sentences) + "."
-            logger.info(f"Generated extractive answer")
-            return answer
-        else:
-            # Return first part of context if no good match
-            return context[:500] + "..."
+        logger.info(f"Generating answer with Gemini for: {question}")
+
+        prompt = f"""You are a helpful AI assistant that answers questions based on the provided context from a PDF document.
+
+Context from PDF:
+{context}
+
+Question: {question}
+
+Instructions:
+- Answer the question clearly and concisely based ONLY on the context provided
+- If the context doesn't contain enough information to answer, say so
+- Provide a well-structured, informative answer
+- If asked to summarize, provide a comprehensive summary
+
+Answer:"""
+
+        response = gemini_model.generate_content(prompt)
+        answer = response.text.strip()
+
+        logger.info(f"Generated answer: {answer[:100]}...")
+        return answer
 
     except Exception as e:
-        logger.error(f"Error generating answer: {str(e)}")
-        return context[:500] + "..."
+        logger.error(f"Error generating answer with Gemini: {str(e)}")
+        import traceback
+        traceback.print_exc()
+        return "Sorry, I couldn't generate an answer. Please try again."
 
 def cleanup_temp_files(filepath):
     """Clean up temporary files"""
@@ -181,9 +187,10 @@ def cleanup_temp_files(filepath):
 @app.route('/')
 def home():
     return jsonify({
-        "message": "PDF QA API is running!",
+        "message": "PDF QA API with Gemini 2.0 Flash is running!",
         "status": "healthy",
-        "model": "all-MiniLM-L6-v2 (extractive)"
+        "model": "Google Gemini 2.0 Flash",
+        "embeddings": "all-MiniLM-L6-v2"
     })
 
 @app.route('/health')
@@ -213,8 +220,8 @@ def ask():
 
         # Split into chunks
         splitter = RecursiveCharacterTextSplitter(
-            chunk_size=500,
-            chunk_overlap=50,
+            chunk_size=800,
+            chunk_overlap=100,
             separators=["\n\n", "\n", ". ", " ", ""]
         )
 
@@ -222,11 +229,6 @@ def ask():
         for doc in docs:
             chunks.extend(splitter.split_text(doc))
 
-        # Limit chunks to avoid memory issues
-        if len(chunks) > 200:
-            logger.warning(f"Too many chunks ({len(chunks)}), limiting to 200")
-            chunks = chunks[:200]
-
        logger.info(f"Created {len(chunks)} chunks")
 
        if not chunks:
@@ -235,14 +237,14 @@
         # Create FAISS index
         index, embeddings = create_faiss_index(chunks)
 
-        # Retrieve context
-        context = retrieve_context(question, chunks, index, k=5)
+        # Retrieve context (get more chunks for Gemini since it can handle it)
+        context = retrieve_context(question, chunks, index, k=7)
 
         if not context:
             return jsonify({"error": "Failed to retrieve context from PDF"}), 500
 
-        # Generate extractive answer
-        answer = generate_extractive_answer(question, context)
+        # Generate answer with Gemini
+        answer = generate_answer_with_gemini(question, context)
 
         if not answer or len(answer.strip()) < 10:
             return jsonify({"error": "Failed to generate answer from PDF content"}), 500
@@ -254,7 +256,7 @@ def ask():
 
         return jsonify({
             "answer": answer,
-            "method": "extractive"
+            "model": "gemini-2.0-flash-exp"
         })
 
     except Exception as e:
@@ -265,7 +267,6 @@ def ask():
     finally:
         if filepath:
             cleanup_temp_files(filepath)
-        # Force garbage collection
         gc.collect()
 
 if __name__ == "__main__":
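
After this change the service reads GEMINI_API_KEY from its environment and initialize_models() returns False if it is missing, so the key must be exported before the app starts. Below is a minimal usage sketch for the updated question-answering endpoint; the host/port (http://localhost:7860) and the multipart field names ("file", "question") are assumptions for illustration and are not shown in the hunks above.

# Usage sketch only: host/port and form field names are assumed, not taken from the diff.
import requests

with open("sample.pdf", "rb") as f:
    resp = requests.post(
        "http://localhost:7860/ask",
        files={"file": f},                                  # assumed upload field name
        data={"question": "What is this document about?"},  # assumed question field name
        timeout=120,
    )

print(resp.json())  # per the diff, a successful response includes "answer" and "model"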