DrishtiSharma commited on
Commit
a97c223
·
verified ·
1 Parent(s): ae885de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -4
app.py CHANGED
@@ -308,16 +308,18 @@ class DocumentRAG:
308
  def run_multiagent_storygraph(self, topic: str, context: str):
309
  self.llm = ChatOpenAI(model_name="gpt-4", temperature=0.7, api_key=self.api_key)
310
 
311
- # Story subgraph
312
  story_graph = StateGraph(StoryState)
313
  story_graph.add_node("Retrieve", self.retrieve_docs)
 
314
  story_graph.add_node("Generate", self.generate_story)
315
  story_graph.set_entry_point("Retrieve")
316
- story_graph.add_edge("Retrieve", "Generate")
 
317
  story_graph.set_finish_point("Generate")
318
  story_subgraph = story_graph.compile()
319
 
320
- # Main graph
321
  graph = StateGraph(MultiAgentState)
322
  graph.add_node("beginner_topic", self.beginner_topic)
323
  graph.add_node("middle_topic", self.middle_topic)
@@ -334,17 +336,24 @@ class DocumentRAG:
334
  ["story_generator"])
335
  graph.add_edge("story_generator", END)
336
 
337
-
338
  compiled = graph.compile(checkpointer=MemorySaver())
339
  thread = {"configurable": {"thread_id": "storygraph-session"}}
340
 
 
341
  result = compiled.invoke({"topic": [topic], "context": [context]}, thread)
342
 
 
 
 
 
 
 
343
  return result
344
 
345
 
346
 
347
 
 
348
  # Initialize RAG system in session state
349
  if "rag_system" not in st.session_state:
350
  st.session_state.rag_system = DocumentRAG()
 
308
  def run_multiagent_storygraph(self, topic: str, context: str):
309
  self.llm = ChatOpenAI(model_name="gpt-4", temperature=0.7, api_key=self.api_key)
310
 
311
+ # Define the story subgraph with reranking
312
  story_graph = StateGraph(StoryState)
313
  story_graph.add_node("Retrieve", self.retrieve_docs)
314
+ story_graph.add_node("Rerank", self.rerank_docs) # Add rerank step
315
  story_graph.add_node("Generate", self.generate_story)
316
  story_graph.set_entry_point("Retrieve")
317
+ story_graph.add_edge("Retrieve", "Rerank")
318
+ story_graph.add_edge("Rerank", "Generate")
319
  story_graph.set_finish_point("Generate")
320
  story_subgraph = story_graph.compile()
321
 
322
+ # Main graph setup
323
  graph = StateGraph(MultiAgentState)
324
  graph.add_node("beginner_topic", self.beginner_topic)
325
  graph.add_node("middle_topic", self.middle_topic)
 
336
  ["story_generator"])
337
  graph.add_edge("story_generator", END)
338
 
 
339
  compiled = graph.compile(checkpointer=MemorySaver())
340
  thread = {"configurable": {"thread_id": "storygraph-session"}}
341
 
342
+ # Initial run to extract subtopics
343
  result = compiled.invoke({"topic": [topic], "context": [context]}, thread)
344
 
345
+ # Fallback if no subtopics were extracted
346
+ if not result.get("sub_topic_list"):
347
+ fallback_subs = ["Neural Networks", "Reinforcement Learning", "Supervised vs Unsupervised"]
348
+ compiled.update_state(thread, {"sub_topic_list": fallback_subs})
349
+ result = compiled.invoke(None, thread, stream_mode="values")
350
+
351
  return result
352
 
353
 
354
 
355
 
356
+
357
  # Initialize RAG system in session state
358
  if "rag_system" not in st.session_state:
359
  st.session_state.rag_system = DocumentRAG()