Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -308,16 +308,18 @@ class DocumentRAG:
|
|
308 |
def run_multiagent_storygraph(self, topic: str, context: str):
|
309 |
self.llm = ChatOpenAI(model_name="gpt-4", temperature=0.7, api_key=self.api_key)
|
310 |
|
311 |
-
#
|
312 |
story_graph = StateGraph(StoryState)
|
313 |
story_graph.add_node("Retrieve", self.retrieve_docs)
|
|
|
314 |
story_graph.add_node("Generate", self.generate_story)
|
315 |
story_graph.set_entry_point("Retrieve")
|
316 |
-
story_graph.add_edge("Retrieve", "
|
|
|
317 |
story_graph.set_finish_point("Generate")
|
318 |
story_subgraph = story_graph.compile()
|
319 |
|
320 |
-
# Main graph
|
321 |
graph = StateGraph(MultiAgentState)
|
322 |
graph.add_node("beginner_topic", self.beginner_topic)
|
323 |
graph.add_node("middle_topic", self.middle_topic)
|
@@ -334,17 +336,24 @@ class DocumentRAG:
|
|
334 |
["story_generator"])
|
335 |
graph.add_edge("story_generator", END)
|
336 |
|
337 |
-
|
338 |
compiled = graph.compile(checkpointer=MemorySaver())
|
339 |
thread = {"configurable": {"thread_id": "storygraph-session"}}
|
340 |
|
|
|
341 |
result = compiled.invoke({"topic": [topic], "context": [context]}, thread)
|
342 |
|
|
|
|
|
|
|
|
|
|
|
|
|
343 |
return result
|
344 |
|
345 |
|
346 |
|
347 |
|
|
|
348 |
# Initialize RAG system in session state
|
349 |
if "rag_system" not in st.session_state:
|
350 |
st.session_state.rag_system = DocumentRAG()
|
|
|
308 |
def run_multiagent_storygraph(self, topic: str, context: str):
|
309 |
self.llm = ChatOpenAI(model_name="gpt-4", temperature=0.7, api_key=self.api_key)
|
310 |
|
311 |
+
# Define the story subgraph with reranking
|
312 |
story_graph = StateGraph(StoryState)
|
313 |
story_graph.add_node("Retrieve", self.retrieve_docs)
|
314 |
+
story_graph.add_node("Rerank", self.rerank_docs) # Add rerank step
|
315 |
story_graph.add_node("Generate", self.generate_story)
|
316 |
story_graph.set_entry_point("Retrieve")
|
317 |
+
story_graph.add_edge("Retrieve", "Rerank")
|
318 |
+
story_graph.add_edge("Rerank", "Generate")
|
319 |
story_graph.set_finish_point("Generate")
|
320 |
story_subgraph = story_graph.compile()
|
321 |
|
322 |
+
# Main graph setup
|
323 |
graph = StateGraph(MultiAgentState)
|
324 |
graph.add_node("beginner_topic", self.beginner_topic)
|
325 |
graph.add_node("middle_topic", self.middle_topic)
|
|
|
336 |
["story_generator"])
|
337 |
graph.add_edge("story_generator", END)
|
338 |
|
|
|
339 |
compiled = graph.compile(checkpointer=MemorySaver())
|
340 |
thread = {"configurable": {"thread_id": "storygraph-session"}}
|
341 |
|
342 |
+
# Initial run to extract subtopics
|
343 |
result = compiled.invoke({"topic": [topic], "context": [context]}, thread)
|
344 |
|
345 |
+
# Fallback if no subtopics were extracted
|
346 |
+
if not result.get("sub_topic_list"):
|
347 |
+
fallback_subs = ["Neural Networks", "Reinforcement Learning", "Supervised vs Unsupervised"]
|
348 |
+
compiled.update_state(thread, {"sub_topic_list": fallback_subs})
|
349 |
+
result = compiled.invoke(None, thread, stream_mode="values")
|
350 |
+
|
351 |
return result
|
352 |
|
353 |
|
354 |
|
355 |
|
356 |
+
|
357 |
# Initialize RAG system in session state
|
358 |
if "rag_system" not in st.session_state:
|
359 |
st.session_state.rag_system = DocumentRAG()
|