Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -27,7 +27,29 @@ import os
|
|
27 |
import tempfile
|
28 |
from datetime import datetime
|
29 |
import pytz
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
class DocumentRAG:
|
33 |
def __init__(self):
|
@@ -248,10 +270,82 @@ class DocumentRAG:
|
|
248 |
except Exception as e:
|
249 |
return history + [("System", f"Error: {str(e)}")]
|
250 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
251 |
# Initialize RAG system in session state
|
252 |
if "rag_system" not in st.session_state:
|
253 |
st.session_state.rag_system = DocumentRAG()
|
254 |
|
|
|
|
|
|
|
255 |
# Sidebar
|
256 |
with st.sidebar:
|
257 |
st.title("About")
|
@@ -337,7 +431,31 @@ if st.session_state.rag_system.qa_chain:
|
|
337 |
else:
|
338 |
st.info("Please process documents first to enable Q&A.")
|
339 |
|
340 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
341 |
st.subheader("Step 4: Generate Podcast")
|
342 |
st.write("Select Podcast Language:")
|
343 |
podcast_language_options = ["English", "Hindi", "Spanish", "French", "German", "Chinese", "Japanese"]
|
@@ -348,6 +466,7 @@ podcast_language = st.radio(
|
|
348 |
key="podcast_language"
|
349 |
)
|
350 |
|
|
|
351 |
if st.session_state.rag_system.document_summary:
|
352 |
if st.button("Generate Podcast"):
|
353 |
with st.spinner("Generating podcast, please wait..."):
|
|
|
27 |
import tempfile
|
28 |
from datetime import datetime
|
29 |
import pytz
|
30 |
+
from langgraph.graph import StateGraph, START, END, Send, add_messages
|
31 |
+
from langgraph.checkpoint.memory import MemorySaver
|
32 |
+
from langchain_core.messages import HumanMessage, SystemMessage, AnyMessage
|
33 |
+
from pydantic import BaseModel
|
34 |
+
from typing import List, Annotated, Any
|
35 |
+
import re, operator
|
36 |
+
|
37 |
+
|
38 |
+
class MultiAgentState(BaseModel):
    """Shared state for the multi-agent topic/story graph.

    Fields annotated with a reducer (``add_messages`` / ``operator.add``)
    are accumulated across node updates rather than overwritten.
    """

    # Free-form state tags (not read by the visible nodes -- TODO confirm use).
    state: List[str] = []
    # Conversation transcript; ``add_messages`` appends each node's output.
    messages: Annotated[list[AnyMessage], add_messages]
    # Main topic(s) supplied by the caller of the graph.
    topic: List[str] = []
    # Learning context(s) supplied by the caller of the graph.
    context: List[str] = []
    # Plain-string subtopic titles parsed out of ``sub_topics`` messages.
    sub_topic_list: List[str] = []
    # Raw LLM replies from the beginner/middle/advanced brainstorm nodes.
    sub_topics: Annotated[list[AnyMessage], add_messages]
    # Story messages produced by the story subgraph; appended, not replaced.
    stories: Annotated[list[AnyMessage], add_messages]
    # Story payloads accumulated via list concatenation.
    stories_lst: Annotated[list, operator.add]
|
47 |
+
|
48 |
+
class StoryState(BaseModel):
    """Per-subtopic state for the story-generation subgraph."""

    # Documents returned by the retriever for ``story_topic``.
    retrieved_docs: List[Any] = []
    # Generated story messages; ``add_messages`` appends across updates.
    stories: Annotated[list[AnyMessage], add_messages]
    # The single subtopic this subgraph invocation writes a story about.
    story_topic: str = ""
    # Accumulated story payloads (list-concatenation reducer).
    stories_lst: Annotated[list, operator.add]
|
53 |
|
54 |
class DocumentRAG:
|
55 |
def __init__(self):
|
|
|
270 |
except Exception as e:
|
271 |
return history + [("System", f"Error: {str(e)}")]
|
272 |
|
273 |
+
def extract_subtopics(self, messages):
    """Return the bolded bullet titles (``- **title**``) found in *messages*.

    Each message's ``.content`` is scanned independently; this matches the
    original join-then-scan behavior because ``.`` never crosses the newline
    that separated messages.
    """
    pattern = re.compile(r'- \*\*(.*?)\*\*')
    subtopics = []
    for msg in messages:
        subtopics.extend(pattern.findall(msg.content))
    return subtopics
|
276 |
+
|
277 |
+
def beginner_topic(self, state: MultiAgentState):
    """Ask the LLM for beginner-level subtopics of ``state.topic``.

    Returns a state update appending the reply to both ``messages`` and
    ``sub_topics`` (both fields use the ``add_messages`` reducer).
    """
    prompt = f"What are the beginner-level topics you can learn about {', '.join(state.topic)} in {', '.join(state.context)}?"
    msg = self.llm.invoke([SystemMessage("Suppose you're a middle grader..."), HumanMessage(prompt)])
    # BUG FIX: the update key must match the declared state field "messages";
    # "message" is not a field of MultiAgentState and LangGraph rejects
    # updates for unknown state keys.
    return {"messages": msg, "sub_topics": msg}
|
281 |
+
|
282 |
+
def middle_topic(self, state: MultiAgentState):
    """Ask the LLM for intermediate subtopics, avoiding earlier ones.

    Returns a state update appending the reply to both ``messages`` and
    ``sub_topics`` (both fields use the ``add_messages`` reducer).
    """
    prompt = f"What are the middle-level topics for {', '.join(state.topic)} in {', '.join(state.context)}? Avoid previous."
    msg = self.llm.invoke([SystemMessage("Suppose you're a college student..."), HumanMessage(prompt)])
    # BUG FIX: "message" is not a declared MultiAgentState field; the
    # correct update key is "messages".
    return {"messages": msg, "sub_topics": msg}
|
286 |
+
|
287 |
+
def advanced_topic(self, state: MultiAgentState):
    """Ask the LLM for advanced subtopics, avoiding earlier ones.

    Returns a state update appending the reply to both ``messages`` and
    ``sub_topics`` (both fields use the ``add_messages`` reducer).
    """
    prompt = f"What are the advanced-level topics for {', '.join(state.topic)} in {', '.join(state.context)}? Avoid previous."
    msg = self.llm.invoke([SystemMessage("Suppose you're a teacher..."), HumanMessage(prompt)])
    # BUG FIX: "message" is not a declared MultiAgentState field; the
    # correct update key is "messages".
    return {"messages": msg, "sub_topics": msg}
|
291 |
+
|
292 |
+
def topic_extractor(self, state: MultiAgentState):
    """Parse the accumulated sub-topic messages into a flat list of titles."""
    extracted = self.extract_subtopics(state.sub_topics)
    return {"sub_topic_list": extracted}
|
294 |
+
|
295 |
+
def retrieve_docs(self, state: StoryState):
    """Fetch up to 20 documents relevant to the current story topic."""
    retriever = self.document_store.as_retriever(search_kwargs={"k": 20})
    # ``get_relevant_documents`` is deprecated in modern LangChain;
    # retrievers are Runnables, so ``invoke`` is the supported call and
    # returns the same list of Documents.
    docs = retriever.invoke(f"information about {state.story_topic}")
    return {"retrieved_docs": docs}
|
299 |
+
|
300 |
+
def generate_story(self, state: StoryState):
    """Write a short, child-friendly story about ``state.story_topic``.

    Only the first five retrieved documents are used as grounding context.
    """
    top_docs = state.retrieved_docs[:5]
    context = "\n\n".join(doc.page_content for doc in top_docs)
    prompt = f"""You're a witty science storyteller. Create a short, child-friendly story that explains **{state.story_topic}** based on this:\n\n{context}"""
    msg = self.llm.invoke([SystemMessage("Use humor. Be clear."), HumanMessage(prompt)])
    return {"stories": msg}
|
305 |
+
|
306 |
+
def run_multiagent_storygraph(self, topic: str, context: str):
    """Build and run the multi-agent story graph for *topic* / *context*.

    Pipeline: beginner -> middle -> advanced brainstorming, subtopic
    extraction, then a ``Send`` fan-out into the story subgraph once per
    extracted subtopic. Returns the final accumulated state mapping.
    """
    self.llm = ChatOpenAI(model_name="gpt-4", temperature=0.7, api_key=self.api_key)

    # Story subgraph: retrieve supporting docs, then generate one story.
    story_graph = StateGraph(StoryState)
    story_graph.add_node("Retrieve", self.retrieve_docs)
    story_graph.add_node("Generate", self.generate_story)
    story_graph.set_entry_point("Retrieve")
    story_graph.add_edge("Retrieve", "Generate")
    story_graph.set_finish_point("Generate")
    story_subgraph = story_graph.compile()

    # Main graph: difficulty-levelled brainstorming, then fan-out.
    graph = StateGraph(MultiAgentState)
    graph.add_node("beginner_topic", self.beginner_topic)
    graph.add_node("middle_topic", self.middle_topic)
    graph.add_node("advanced_topic", self.advanced_topic)
    graph.add_node("topic_extractor", self.topic_extractor)
    graph.add_node("story_generator", story_subgraph)

    graph.add_edge(START, "beginner_topic")
    graph.add_edge("beginner_topic", "middle_topic")
    graph.add_edge("middle_topic", "advanced_topic")
    graph.add_edge("advanced_topic", "topic_extractor")
    # One Send per extracted subtopic so each runs the subgraph in parallel.
    graph.add_conditional_edges(
        "topic_extractor",
        lambda state: [Send("story_generator", {"story_topic": t}) for t in state.sub_topic_list],
        ["story_generator"],
    )
    graph.add_edge("story_generator", END)

    compiled = graph.compile(checkpointer=MemorySaver())
    # BUG FIX: a graph compiled with a checkpointer requires a
    # ``configurable.thread_id`` in the invocation config; calling
    # ``invoke`` without one raises at runtime.
    config = {"configurable": {"thread_id": "story-graph"}}
    result = compiled.invoke({"topic": [topic], "context": [context]}, config)
    return result
|
338 |
+
|
339 |
+
|
340 |
+
|
341 |
+
|
342 |
# Initialize RAG system in session state
# Streamlit re-runs this script on every interaction; guarding on the key
# ensures DocumentRAG (and whatever setup its __init__ performs) is
# constructed once per browser session rather than on every rerun.
if "rag_system" not in st.session_state:
    st.session_state.rag_system = DocumentRAG()
|
345 |
|
346 |
+
|
347 |
+
|
348 |
+
|
349 |
# Sidebar
|
350 |
with st.sidebar:
|
351 |
st.title("About")
|
|
|
431 |
else:
|
432 |
st.info("Please process documents first to enable Q&A.")
|
433 |
|
434 |
+
|
435 |
+
# Step 5: Multi-Agent Story Explorer
# (BUG FIX: the comment previously said "Step 4", contradicting the
# visible "Step 5" header just below.)
st.subheader("Step 5: Explore Subtopics via Multi-Agent Graph")
story_topic = st.text_input("Enter main topic:", value="Machine Learning")
story_context = st.text_input("Enter learning context:", value="Education")

if st.button("Run Story Graph"):
    with st.spinner("Generating subtopics and stories..."):
        result = st.session_state.rag_system.run_multiagent_storygraph(topic=story_topic, context=story_context)

        subtopics = result.get("sub_topic_list", [])
        # NOTE(review): the original headings contained mojibake ("๐ง"/"๐")
        # from a lost emoji encoding; restored to the most plausible emoji --
        # confirm against the original source.
        st.markdown("### 🧠 Extracted Subtopics")
        for sub in subtopics:
            st.markdown(f"- {sub}")

        stories = result.get("stories", [])
        if stories:
            st.markdown("### 📖 Generated Stories")
            for i, story in enumerate(stories):
                st.markdown(f"**Story {i+1}:**")
                st.markdown(story.content)
        else:
            st.warning("No stories were generated.")
|
457 |
+
|
458 |
+
# Step 4: Generate Podcast
# NOTE(review): this comment previously said "Step 5" while the visible
# header says "Step 4"; with the new "Step 5" explorer section added above,
# the user-facing step numbering is now inconsistent -- renumber deliberately
# (the header string is user-facing, so it is left unchanged here).
st.subheader("Step 4: Generate Podcast")
st.write("Select Podcast Language:")
podcast_language_options = ["English", "Hindi", "Spanish", "French", "German", "Chinese", "Japanese"]
|
|
|
466 |
key="podcast_language"
|
467 |
)
|
468 |
|
469 |
+
|
470 |
if st.session_state.rag_system.document_summary:
|
471 |
if st.button("Generate Podcast"):
|
472 |
with st.spinner("Generating podcast, please wait..."):
|