Update app.py
app.py (CHANGED)
@@ -406,7 +406,8 @@ class DocumentRAG:
 
 
 
-    def run_multiagent_storygraph(self, topic: str, context: str):
+    def run_multiagent_storygraph(self, topic: str, context: str, language: str = "English"):
+
         if self.embedding_choice == "OpenAI":
             self.llm = ChatOpenAI(model_name="gpt-4", temperature=0.7, api_key=self.api_key)
         elif self.embedding_choice == "Cohere":
@@ -421,7 +422,7 @@ class DocumentRAG:
         story_graph = StateGraph(StoryState)
         story_graph.add_node("Retrieve", self.retrieve_node)
         story_graph.add_node("Rerank", self.rerank_node)
-        story_graph.add_node("Generate", self.generate_story_node)
+        story_graph.add_node("Generate", lambda state: self.generate_story_node(state, language=state.get("language", "English")))
         story_graph.set_entry_point("Retrieve")
         story_graph.add_edge("Retrieve", "Rerank")
         story_graph.add_edge("Rerank", "Generate")
@@ -442,7 +443,7 @@ class DocumentRAG:
         graph.add_edge("advanced_topic", "topic_extractor")
         graph.add_conditional_edges(
             "topic_extractor",
-            lambda state: [Send("story_generator", {"story_topic": t}) for t in state.sub_topic_list],
+            lambda state: [Send("story_generator", {"story_topic": t, "language": language}) for t in state.sub_topic_list],
             ["story_generator"]
         )
         graph.add_edge("story_generator", END)
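For context, the edited lambda uses LangGraph's Send API to fan the extracted sub-topics out to parallel story_generator runs, and this commit simply adds a "language" field to each Send payload so every branch knows which language to generate in. The sketch below is a minimal, self-contained illustration of that fan-out pattern, not the Space's actual code; the state fields, stub node bodies, and graph wiring are assumptions for demonstration.

import operator
from typing import Annotated, List, TypedDict

from langgraph.constants import Send
from langgraph.graph import END, StateGraph


class StoryState(TypedDict, total=False):
    topic: str
    language: str
    sub_topic_list: List[str]
    # Reducer merges the lists written by parallel story_generator branches.
    stories: Annotated[List[str], operator.add]


def topic_extractor(state: StoryState) -> dict:
    # Stand-in for the real extractor node: derive two sub-topics from the topic.
    return {"sub_topic_list": [f'{state["topic"]} - part {i}' for i in (1, 2)]}


def story_generator(state: dict) -> dict:
    # Each Send delivers its own payload dict, so "language" arrives per branch.
    return {"stories": [f'[{state.get("language", "English")}] {state["story_topic"]}']}


def fan_out(state: StoryState):
    # Mirrors the commit: forward the language alongside each sub-topic.
    lang = state.get("language", "English")
    return [Send("story_generator", {"story_topic": t, "language": lang})
            for t in state["sub_topic_list"]]


graph = StateGraph(StoryState)
graph.add_node("topic_extractor", topic_extractor)
graph.add_node("story_generator", story_generator)
graph.set_entry_point("topic_extractor")
graph.add_conditional_edges("topic_extractor", fan_out, ["story_generator"])
graph.add_edge("story_generator", END)

app = graph.compile()
print(app.invoke({"topic": "space travel", "language": "French"}))

Because the language travels inside each Send payload rather than as a global, the same compiled graph can serve requests in different languages concurrently; the Generate lambda in the commit reads it back out of the per-branch state with state.get("language", "English").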