Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -37,6 +37,7 @@ from typing import List, Annotated, Any
|
|
37 |
import re, operator
|
38 |
|
39 |
|
|
|
40 |
class MultiAgentState(BaseModel):
|
41 |
state: List[str] = []
|
42 |
messages: Annotated[list[AnyMessage], add_messages]
|
@@ -54,22 +55,38 @@ class StoryState(BaseModel):
|
|
54 |
stories_lst: Annotated[list, operator.add]
|
55 |
|
56 |
class DocumentRAG:
|
57 |
-
def __init__(self):
|
58 |
self.document_store = None
|
59 |
self.qa_chain = None
|
60 |
self.document_summary = ""
|
61 |
self.chat_history = []
|
62 |
self.last_processed_time = None
|
63 |
-
self.api_key = os.getenv("OPENAI_API_KEY")
|
64 |
self.init_time = datetime.now(pytz.UTC)
|
65 |
-
|
66 |
-
if not self.api_key:
|
67 |
-
raise ValueError("API Key not found. Make sure to set the 'OPENAI_API_KEY' environment variable.")
|
68 |
|
69 |
# Persistent directory for Chroma to avoid tenant-related errors
|
70 |
self.chroma_persist_dir = "./chroma_storage"
|
71 |
os.makedirs(self.chroma_persist_dir, exist_ok=True)
|
72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
def process_documents(self, uploaded_files):
|
74 |
"""Process uploaded files by saving them temporarily and extracting content."""
|
75 |
if not self.api_key:
|
@@ -118,7 +135,7 @@ class DocumentRAG:
|
|
118 |
self.document_text = " ".join([doc.page_content for doc in documents]) # Store for later use
|
119 |
|
120 |
# Create embeddings and initialize retrieval chain
|
121 |
-
embeddings =
|
122 |
self.document_store = Chroma.from_documents(
|
123 |
documents,
|
124 |
embeddings,
|
@@ -294,29 +311,53 @@ class DocumentRAG:
|
|
294 |
def topic_extractor(self, state: MultiAgentState):
|
295 |
return {"sub_topic_list": self.extract_subtopics(state.sub_topics)}
|
296 |
|
297 |
-
def retrieve_docs(self, state: StoryState):
|
298 |
-
retriever = self.document_store.as_retriever(search_kwargs={"k": 20})
|
299 |
-
docs = retriever.get_relevant_documents(f"information about {state.story_topic}")
|
300 |
-
return {"retrieved_docs": docs}
|
301 |
|
302 |
-
def
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
|
|
307 |
|
308 |
-
def rerank_docs(self, state: StoryState):
|
309 |
topic = state.story_topic
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
310 |
docs = state.retrieved_docs
|
311 |
texts = [doc.page_content for doc in docs]
|
312 |
|
313 |
-
# Fallback: return top 5 if no reranker available
|
314 |
if not texts:
|
315 |
-
return {"reranked_docs": []}
|
|
|
|
|
|
|
|
|
|
|
|
|
316 |
|
317 |
-
|
318 |
-
|
319 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
320 |
|
321 |
|
322 |
def run_multiagent_storygraph(self, topic: str, context: str):
|
@@ -324,9 +365,9 @@ class DocumentRAG:
|
|
324 |
|
325 |
# Define the story subgraph with reranking
|
326 |
story_graph = StateGraph(StoryState)
|
327 |
-
story_graph.add_node("Retrieve", self.
|
328 |
-
story_graph.add_node("Rerank", self.
|
329 |
-
story_graph.add_node("Generate", self.
|
330 |
story_graph.set_entry_point("Retrieve")
|
331 |
story_graph.add_edge("Retrieve", "Rerank")
|
332 |
story_graph.add_edge("Rerank", "Generate")
|
@@ -365,13 +406,9 @@ class DocumentRAG:
|
|
365 |
return result
|
366 |
|
367 |
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
# Initialize RAG system in session state
|
372 |
-
if "rag_system" not in st.session_state:
|
373 |
-
st.session_state.rag_system = DocumentRAG()
|
374 |
-
|
375 |
|
376 |
|
377 |
|
|
|
37 |
import re, operator
|
38 |
|
39 |
|
40 |
+
|
41 |
class MultiAgentState(BaseModel):
|
42 |
state: List[str] = []
|
43 |
messages: Annotated[list[AnyMessage], add_messages]
|
|
|
55 |
stories_lst: Annotated[list, operator.add]
|
56 |
|
57 |
class DocumentRAG:
|
58 |
+
def __init__(self, embedding_choice="OpenAI"):
|
59 |
self.document_store = None
|
60 |
self.qa_chain = None
|
61 |
self.document_summary = ""
|
62 |
self.chat_history = []
|
63 |
self.last_processed_time = None
|
64 |
+
self.api_key = os.getenv("OPENAI_API_KEY")
|
65 |
self.init_time = datetime.now(pytz.UTC)
|
66 |
+
self.embedding_choice = embedding_choice
|
|
|
|
|
67 |
|
68 |
# Persistent directory for Chroma to avoid tenant-related errors
|
69 |
self.chroma_persist_dir = "./chroma_storage"
|
70 |
os.makedirs(self.chroma_persist_dir, exist_ok=True)
|
71 |
|
72 |
+
|
73 |
+
def _get_embedding_model(self):
|
74 |
+
if self.embedding_choice == "OpenAI":
|
75 |
+
return OpenAIEmbeddings(api_key=self.api_key)
|
76 |
+
else:
|
77 |
+
from langchain.embeddings import CohereEmbeddings
|
78 |
+
return CohereEmbeddings(
|
79 |
+
model="embed-multilingual-light-v3.0",
|
80 |
+
cohere_api_key=os.getenv("COHERE_API_KEY")
|
81 |
+
)
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
if not self.api_key:
|
86 |
+
raise ValueError("API Key not found. Make sure to set the 'OPENAI_API_KEY' environment variable.")
|
87 |
+
|
88 |
+
|
89 |
+
|
90 |
def process_documents(self, uploaded_files):
|
91 |
"""Process uploaded files by saving them temporarily and extracting content."""
|
92 |
if not self.api_key:
|
|
|
135 |
self.document_text = " ".join([doc.page_content for doc in documents]) # Store for later use
|
136 |
|
137 |
# Create embeddings and initialize retrieval chain
|
138 |
+
embeddings = self._get_embedding_model()
|
139 |
self.document_store = Chroma.from_documents(
|
140 |
documents,
|
141 |
embeddings,
|
|
|
311 |
def topic_extractor(self, state: MultiAgentState):
|
312 |
return {"sub_topic_list": self.extract_subtopics(state.sub_topics)}
|
313 |
|
|
|
|
|
|
|
|
|
314 |
|
315 |
+
def retrieve_node(self, state: StoryState):
|
316 |
+
embedding = self._get_embedding_model()
|
317 |
+
retriever = Chroma(
|
318 |
+
persist_directory=self.chroma_persist_dir,
|
319 |
+
embedding_function=embedding
|
320 |
+
).as_retriever(search_kwargs={"k": 20})
|
321 |
|
|
|
322 |
topic = state.story_topic
|
323 |
+
query = f"information about {topic}"
|
324 |
+
docs = retriever.get_relevant_documents(query)
|
325 |
+
return {"retrieved_docs": docs, "question": query}
|
326 |
+
|
327 |
+
def rerank_node(self, state: StoryState):
|
328 |
+
topic = state.story_topic
|
329 |
+
query = f"Rerank documents based on how well they explain the topic {topic}"
|
330 |
docs = state.retrieved_docs
|
331 |
texts = [doc.page_content for doc in docs]
|
332 |
|
|
|
333 |
if not texts:
|
334 |
+
return {"reranked_docs": [], "question": query}
|
335 |
+
|
336 |
+
# Quick fallback: rank by length
|
337 |
+
top_docs = sorted(texts, key=lambda t: -len(t))[:5]
|
338 |
+
return {"reranked_docs": top_docs, "question": query}
|
339 |
+
|
340 |
+
|
341 |
|
342 |
+
def generate_story_node(self, state: StoryState):
|
343 |
+
context = "\n\n".join(state.reranked_docs)
|
344 |
+
topic = state.story_topic
|
345 |
+
|
346 |
+
system_message = f"""
|
347 |
+
Suppose you're a brilliant science storyteller.
|
348 |
+
You write stories that help middle schoolers understand complex science topics with fun and clarity.
|
349 |
+
Add subtle humor and make it engaging.
|
350 |
+
"""
|
351 |
+
prompt = f"""
|
352 |
+
Use the following context to write a fun and simple story explaining **{topic}** to a middle schooler:\n
|
353 |
+
Context:\n{context}\n\n
|
354 |
+
Story:
|
355 |
+
"""
|
356 |
+
|
357 |
+
msg = self.llm.invoke([SystemMessage(system_message), HumanMessage(prompt)])
|
358 |
+
return {"stories": msg}
|
359 |
+
|
360 |
+
|
361 |
|
362 |
|
363 |
def run_multiagent_storygraph(self, topic: str, context: str):
|
|
|
365 |
|
366 |
# Define the story subgraph with reranking
|
367 |
story_graph = StateGraph(StoryState)
|
368 |
+
story_graph.add_node("Retrieve", self.retrieve_node)
|
369 |
+
story_graph.add_node("Rerank", self.rerank_node)
|
370 |
+
story_graph.add_node("Generate", self.generate_story_node)
|
371 |
story_graph.set_entry_point("Retrieve")
|
372 |
story_graph.add_edge("Retrieve", "Rerank")
|
373 |
story_graph.add_edge("Rerank", "Generate")
|
|
|
406 |
return result
|
407 |
|
408 |
|
409 |
+
if "rag_system" not in st.session_state or st.session_state.embedding_model != embedding_choice:
|
410 |
+
st.session_state.embedding_model = embedding_choice
|
411 |
+
st.session_state.rag_system = DocumentRAG(embedding_choice=embedding_choice)
|
|
|
|
|
|
|
|
|
412 |
|
413 |
|
414 |
|