DrishtiSharma committed on
Commit
bbdd34f
·
verified ·
1 Parent(s): a2bd7e5

Update app.py

Files changed (1)
  1. app.py +121 -2
app.py CHANGED
@@ -27,7 +27,29 @@ import os
27
  import tempfile
28
  from datetime import datetime
29
  import pytz
30
+ from langgraph.graph import StateGraph, START, END
+ from langgraph.graph.message import add_messages
+ from langgraph.types import Send
31
+ from langgraph.checkpoint.memory import MemorySaver
32
+ from langchain_core.messages import HumanMessage, SystemMessage, AnyMessage
33
+ from pydantic import BaseModel
34
+ from typing import List, Annotated, Any
35
+ import re, operator
36
+
37
+
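+ # Shared graph state: fields annotated with add_messages or operator.add are
+ # accumulated across nodes instead of being overwritten on each update.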
38
+ class MultiAgentState(BaseModel):
39
+ state: List[str] = []
40
+ messages: Annotated[list[AnyMessage], add_messages]
41
+ topic: List[str] = []
42
+ context: List[str] = []
43
+ sub_topic_list: List[str] = []
44
+ sub_topics: Annotated[list[AnyMessage], add_messages]
45
+ stories: Annotated[list[AnyMessage], add_messages]
46
+ stories_lst: Annotated[list, operator.add]
47
+
48
+ class StoryState(BaseModel):
49
+ retrieved_docs: List[Any] = []
50
+ stories: Annotated[list[AnyMessage], add_messages]
51
+ story_topic: str = ""
52
+ stories_lst: Annotated[list, operator.add]
53
 
54
  class DocumentRAG:
55
  def __init__(self):

@@ -248,10 +270,82 @@ class DocumentRAG:
270
  except Exception as e:
271
  return history + [("System", f"Error: {str(e)}")]
272
 
273
+ def extract_subtopics(self, messages):
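+ # Pull sub-topic names out of the LLM responses; expects markdown bullets of the form "- **Topic**"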
274
+ text = "\n".join([msg.content for msg in messages])
275
+ return re.findall(r'- \*\*(.*?)\*\*', text)
276
+
277
+ def beginner_topic(self, state: MultiAgentState):
278
+ prompt = f"What are the beginner-level topics you can learn about {', '.join(state.topic)} in {', '.join(state.context)}?"
279
+ msg = self.llm.invoke([SystemMessage("Suppose you're a middle grader..."), HumanMessage(prompt)])
280
+ return {"messages": msg, "sub_topics": msg}
281
+
282
+ def middle_topic(self, state: MultiAgentState):
283
+ prompt = f"What are the middle-level topics for {', '.join(state.topic)} in {', '.join(state.context)}? Avoid previous."
284
+ msg = self.llm.invoke([SystemMessage("Suppose you're a college student..."), HumanMessage(prompt)])
285
+ return {"messages": msg, "sub_topics": msg}
286
+
287
+ def advanced_topic(self, state: MultiAgentState):
288
+ prompt = f"What are the advanced-level topics for {', '.join(state.topic)} in {', '.join(state.context)}? Avoid previous."
289
+ msg = self.llm.invoke([SystemMessage("Suppose you're a teacher..."), HumanMessage(prompt)])
290
+ return {"messages": msg, "sub_topics": msg}
291
+
292
+ def topic_extractor(self, state: MultiAgentState):
293
+ return {"sub_topic_list": self.extract_subtopics(state.sub_topics)}
294
+
295
+ def retrieve_docs(self, state: StoryState):
296
+ retriever = self.document_store.as_retriever(search_kwargs={"k": 20})
297
+ docs = retriever.get_relevant_documents(f"information about {state.story_topic}")
298
+ return {"retrieved_docs": docs}
299
+
300
+ def generate_story(self, state: StoryState):
301
+ context = "\n\n".join([doc.page_content for doc in state.retrieved_docs[:5]])
302
+ prompt = f"""You're a witty science storyteller. Create a short, child-friendly story that explains **{state.story_topic}** based on this:\n\n{context}"""
303
+ msg = self.llm.invoke([SystemMessage("Use humor. Be clear."), HumanMessage(prompt)])
304
+ return {"stories": msg}
305
+
306
+ def run_multiagent_storygraph(self, topic: str, context: str):
307
+ self.llm = ChatOpenAI(model_name="gpt-4", temperature=0.7, api_key=self.api_key)
308
+
309
+ # Story subgraph
310
+ story_graph = StateGraph(StoryState)
311
+ story_graph.add_node("Retrieve", self.retrieve_docs)
312
+ story_graph.add_node("Generate", self.generate_story)
313
+ story_graph.set_entry_point("Retrieve")
314
+ story_graph.add_edge("Retrieve", "Generate")
315
+ story_graph.set_finish_point("Generate")
316
+ story_subgraph = story_graph.compile()
317
+
318
+ # Main graph
319
+ graph = StateGraph(MultiAgentState)
320
+ graph.add_node("beginner_topic", self.beginner_topic)
321
+ graph.add_node("middle_topic", self.middle_topic)
322
+ graph.add_node("advanced_topic", self.advanced_topic)
323
+ graph.add_node("topic_extractor", self.topic_extractor)
324
+ graph.add_node("story_generator", story_subgraph)
325
+
326
+ graph.add_edge(START, "beginner_topic")
327
+ graph.add_edge("beginner_topic", "middle_topic")
328
+ graph.add_edge("middle_topic", "advanced_topic")
329
+ graph.add_edge("advanced_topic", "topic_extractor")
330
+ graph.add_conditional_edges("topic_extractor",
331
+ lambda state: [Send("story_generator", {"story_topic": t}) for t in state.sub_topic_list],
332
+ ["story_generator"])
333
+ graph.add_edge("story_generator", END)
334
+
335
+ compiled = graph.compile(checkpointer=MemorySaver())
336
+ # A thread_id is required because the graph is compiled with a MemorySaver checkpointer
+ result = compiled.invoke({"topic": [topic], "context": [context]}, config={"configurable": {"thread_id": "story_graph"}})
337
+ return result
338
+
339
+
340
+
341
+
342
  # Initialize RAG system in session state
343
  if "rag_system" not in st.session_state:
344
  st.session_state.rag_system = DocumentRAG()
345
 
346
+
347
+
348
+
349
  # Sidebar
350
  with st.sidebar:
351
  st.title("About")

@@ -337,7 +431,31 @@ if st.session_state.rag_system.qa_chain:
431
  else:
432
  st.info("Please process documents first to enable Q&A.")
433

- # Step 4: Generate Podcast
434
+
435
+ # Step 5: Multi-Agent Story Explorer
436
+ st.subheader("Step 5: Explore Subtopics via Multi-Agent Graph")
437
+ story_topic = st.text_input("Enter main topic:", value="Machine Learning")
438
+ story_context = st.text_input("Enter learning context:", value="Education")
439
+
440
+ if st.button("Run Story Graph"):
441
+ with st.spinner("Generating subtopics and stories..."):
442
+ result = st.session_state.rag_system.run_multiagent_storygraph(topic=story_topic, context=story_context)
443
+
444
+ subtopics = result.get("sub_topic_list", [])
445
+ st.markdown("### 🧠 Extracted Subtopics")
446
+ for sub in subtopics:
447
+ st.markdown(f"- {sub}")
448
+
449
+ stories = result.get("stories", [])
450
+ if stories:
451
+ st.markdown("### 📚 Generated Stories")
452
+ for i, story in enumerate(stories):
453
+ st.markdown(f"**Story {i+1}:**")
454
+ st.markdown(story.content)
455
+ else:
456
+ st.warning("No stories were generated.")
457
+
458
+ # Generate Podcast
459
  st.subheader("Step 4: Generate Podcast")
460
  st.write("Select Podcast Language:")
461
  podcast_language_options = ["English", "Hindi", "Spanish", "French", "German", "Chinese", "Japanese"]

@@ -348,6 +466,7 @@ podcast_language = st.radio(
466
  key="podcast_language"
467
  )
468
 
469
+
470
  if st.session_state.rag_system.document_summary:
471
  if st.button("Generate Podcast"):
472
  with st.spinner("Generating podcast, please wait..."):
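
For reference, a minimal sketch of how the new multi-agent flow could be exercised outside the Streamlit UI. It assumes DocumentRAG is available in the current context (it is defined in app.py), that the API key is picked up exactly as in the existing __init__, and that documents have already been processed so document_store is populated; the result keys read below are the same ones the new Step 5 UI block uses:

    rag = DocumentRAG()
    # assumes the existing document-processing step has already built rag.document_store
    result = rag.run_multiagent_storygraph(topic="Machine Learning", context="Education")

    print("Subtopics:", result.get("sub_topic_list", []))
    for i, story in enumerate(result.get("stories", []), start=1):
        print(f"--- Story {i} ---")
        print(story.content)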