DrishtiSharma committed
Commit 584bb82 · verified · 1 Parent(s): 39ae944

Update app.py

Files changed (1)
  1. app.py  +68 -31
app.py CHANGED
@@ -37,6 +37,7 @@ from typing import List, Annotated, Any
 import re, operator
 
 
+
 class MultiAgentState(BaseModel):
     state: List[str] = []
     messages: Annotated[list[AnyMessage], add_messages]
@@ -54,22 +55,38 @@ class StoryState(BaseModel):
     stories_lst: Annotated[list, operator.add]
 
 class DocumentRAG:
-    def __init__(self):
+    def __init__(self, embedding_choice="OpenAI"):
         self.document_store = None
         self.qa_chain = None
         self.document_summary = ""
         self.chat_history = []
         self.last_processed_time = None
-        self.api_key = os.getenv("OPENAI_API_KEY")  # Fetch the API key from environment variable
+        self.api_key = os.getenv("OPENAI_API_KEY")
         self.init_time = datetime.now(pytz.UTC)
-
-        if not self.api_key:
-            raise ValueError("API Key not found. Make sure to set the 'OPENAI_API_KEY' environment variable.")
+        self.embedding_choice = embedding_choice
 
         # Persistent directory for Chroma to avoid tenant-related errors
         self.chroma_persist_dir = "./chroma_storage"
         os.makedirs(self.chroma_persist_dir, exist_ok=True)
 
+    def _get_embedding_model(self):
+        if self.embedding_choice == "OpenAI":
+            return OpenAIEmbeddings(api_key=self.api_key)
+        else:
+            from langchain.embeddings import CohereEmbeddings
+            return CohereEmbeddings(
+                model="embed-multilingual-light-v3.0",
+                cohere_api_key=os.getenv("COHERE_API_KEY")
+            )
+
+    if not self.api_key:
+        raise ValueError("API Key not found. Make sure to set the 'OPENAI_API_KEY' environment variable.")
+
     def process_documents(self, uploaded_files):
         """Process uploaded files by saving them temporarily and extracting content."""
         if not self.api_key:
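The new _get_embedding_model helper is what lets the constructor's embedding_choice flag switch the vector backend between OpenAI and Cohere. For context, a minimal sketch of how a caller might drive that flag from the Streamlit UI; the selectbox below is an assumption for illustration, not the widget app.py actually defines:

    import streamlit as st

    # Hypothetical sidebar control; app.py defines its own widget for
    # picking the embedding backend, with "OpenAI" as the default.
    embedding_choice = st.sidebar.selectbox("Embedding model", ["OpenAI", "Cohere"])

    rag = DocumentRAG(embedding_choice=embedding_choice)
    embeddings = rag._get_embedding_model()  # OpenAIEmbeddings or CohereEmbeddings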
@@ -118,7 +135,7 @@ class DocumentRAG:
         self.document_text = " ".join([doc.page_content for doc in documents])  # Store for later use
 
         # Create embeddings and initialize retrieval chain
-        embeddings = OpenAIEmbeddings(api_key=self.api_key)
+        embeddings = self._get_embedding_model()
         self.document_store = Chroma.from_documents(
             documents,
             embeddings,
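With process_documents now indexing through _get_embedding_model(), note that a Chroma store persisted under one backend cannot be queried meaningfully with the other: OpenAI and Cohere produce vectors in different spaces and of different dimensions. A defensive sketch, offered as a suggestion rather than what this commit does, is to key the persist directory on the backend:

    import os

    def persist_dir_for(choice: str) -> str:
        # Hypothetical helper: one Chroma directory per embedding backend,
        # so OpenAI and Cohere vectors never share an index.
        path = os.path.join("./chroma_storage", choice.lower())
        os.makedirs(path, exist_ok=True)
        return path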
@@ -294,29 +311,53 @@ class DocumentRAG:
     def topic_extractor(self, state: MultiAgentState):
         return {"sub_topic_list": self.extract_subtopics(state.sub_topics)}
 
-    def retrieve_docs(self, state: StoryState):
-        retriever = self.document_store.as_retriever(search_kwargs={"k": 20})
-        docs = retriever.get_relevant_documents(f"information about {state.story_topic}")
-        return {"retrieved_docs": docs}
-
-    def generate_story(self, state: StoryState):
-        context = "\n\n".join([doc.page_content for doc in state.retrieved_docs[:5]])
-        prompt = f"""You're a witty science storyteller. Create a short, child-friendly story that explains **{state.story_topic}** based on this:\n\n{context}"""
-        msg = self.llm.invoke([SystemMessage("Use humor. Be clear."), HumanMessage(prompt)])
-        return {"stories": msg}
-
-    def rerank_docs(self, state: StoryState):
+    def retrieve_node(self, state: StoryState):
+        embedding = self._get_embedding_model()
+        retriever = Chroma(
+            persist_directory=self.chroma_persist_dir,
+            embedding_function=embedding
+        ).as_retriever(search_kwargs={"k": 20})
+
         topic = state.story_topic
+        query = f"information about {topic}"
+        docs = retriever.get_relevant_documents(query)
+        return {"retrieved_docs": docs, "question": query}
+
+    def rerank_node(self, state: StoryState):
+        topic = state.story_topic
+        query = f"Rerank documents based on how well they explain the topic {topic}"
         docs = state.retrieved_docs
         texts = [doc.page_content for doc in docs]
 
-        # Fallback: return top 5 if no reranker available
         if not texts:
-            return {"reranked_docs": []}
+            return {"reranked_docs": [], "question": query}
+
+        # Quick fallback: rank by length
+        top_docs = sorted(texts, key=lambda t: -len(t))[:5]
+        return {"reranked_docs": top_docs, "question": query}
 
-        # Quick ranking by doc length (or use a real reranker if you have access)
-        ranked = sorted(texts, key=lambda t: -len(t))[:5]
-        return {"reranked_docs": ranked}
+    def generate_story_node(self, state: StoryState):
+        context = "\n\n".join(state.reranked_docs)
+        topic = state.story_topic
+
+        system_message = f"""
+        Suppose you're a brilliant science storyteller.
+        You write stories that help middle schoolers understand complex science topics with fun and clarity.
+        Add subtle humor and make it engaging.
+        """
+        prompt = f"""
+        Use the following context to write a fun and simple story explaining **{topic}** to a middle schooler:\n
+        Context:\n{context}\n\n
+        Story:
+        """
+
+        msg = self.llm.invoke([SystemMessage(system_message), HumanMessage(prompt)])
+        return {"stories": msg}
 
 
     def run_multiagent_storygraph(self, topic: str, context: str):
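rerank_node keeps the length-based ordering as a quick fallback, so the rerank step is still a placeholder signal rather than a semantic ranking. Since the commit already reads COHERE_API_KEY for embeddings, a real reranker could slot into the same spot; a sketch assuming the cohere SDK is installed (the model name is illustrative):

    import os
    import cohere

    def rerank_with_cohere(query: str, texts: list, top_n: int = 5) -> list:
        # Assumption: classic cohere.Client interface; adjust to the installed SDK.
        co = cohere.Client(os.getenv("COHERE_API_KEY"))
        resp = co.rerank(model="rerank-english-v3.0", query=query,
                         documents=texts, top_n=top_n)
        # Each result carries the index of the document in the original list.
        return [texts[r.index] for r in resp.results]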
@@ -324,9 +365,9 @@ class DocumentRAG:
 
         # Define the story subgraph with reranking
         story_graph = StateGraph(StoryState)
-        story_graph.add_node("Retrieve", self.retrieve_docs)
-        story_graph.add_node("Rerank", self.rerank_docs)  # Add rerank step
-        story_graph.add_node("Generate", self.generate_story)
+        story_graph.add_node("Retrieve", self.retrieve_node)
+        story_graph.add_node("Rerank", self.rerank_node)
+        story_graph.add_node("Generate", self.generate_story_node)
         story_graph.set_entry_point("Retrieve")
         story_graph.add_edge("Retrieve", "Rerank")
         story_graph.add_edge("Rerank", "Generate")
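The rewiring only renames the nodes; the Retrieve -> Rerank -> Generate topology is unchanged. For reference, a subgraph of this shape is compiled and invoked roughly as below; the END edge and the input fields are assumptions based on LangGraph conventions and the StoryState fields visible in this diff:

    from langgraph.graph import END

    story_graph.add_edge("Generate", END)  # assumed finish edge
    story_flow = story_graph.compile()

    # "stories_lst" is passed explicitly since Annotated[list, operator.add]
    # fields typically need an initial value; other defaults are assumed.
    result = story_flow.invoke({"story_topic": "photosynthesis", "stories_lst": []})
    print(result["stories"].content)  # AIMessage from generate_story_node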
@@ -365,13 +406,9 @@ class DocumentRAG:
         return result
 
 
-
-
-
-# Initialize RAG system in session state
-if "rag_system" not in st.session_state:
-    st.session_state.rag_system = DocumentRAG()
-
+if "rag_system" not in st.session_state or st.session_state.embedding_model != embedding_choice:
+    st.session_state.embedding_model = embedding_choice
+    st.session_state.rag_system = DocumentRAG(embedding_choice=embedding_choice)
 
 
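One sharp edge in the new session-state guard: if rag_system is already present but embedding_model is not (e.g. after a partial rerun or an older session), st.session_state.embedding_model raises an AttributeError before the comparison happens. A slightly more defensive variant, offered as a suggestion rather than what the commit does:

    # Recreate the RAG system whenever the embedding backend changes.
    if ("rag_system" not in st.session_state
            or st.session_state.get("embedding_model") != embedding_choice):
        st.session_state.embedding_model = embedding_choice
        st.session_state.rag_system = DocumentRAG(embedding_choice=embedding_choice)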