DrishtiSharma commited on
Commit
cf28062
·
verified ·
1 Parent(s): c375485

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -406,7 +406,8 @@ class DocumentRAG:
406
 
407
 
408
 
409
- def run_multiagent_storygraph(self, topic: str, context: str):
 
410
  if self.embedding_choice == "OpenAI":
411
  self.llm = ChatOpenAI(model_name="gpt-4", temperature=0.7, api_key=self.api_key)
412
  elif self.embedding_choice == "Cohere":
@@ -421,7 +422,7 @@ class DocumentRAG:
421
  story_graph = StateGraph(StoryState)
422
  story_graph.add_node("Retrieve", self.retrieve_node)
423
  story_graph.add_node("Rerank", self.rerank_node)
424
- story_graph.add_node("Generate", self.generate_story_node)
425
  story_graph.set_entry_point("Retrieve")
426
  story_graph.add_edge("Retrieve", "Rerank")
427
  story_graph.add_edge("Rerank", "Generate")
@@ -442,7 +443,7 @@ class DocumentRAG:
442
  graph.add_edge("advanced_topic", "topic_extractor")
443
  graph.add_conditional_edges(
444
  "topic_extractor",
445
- lambda state: [Send("story_generator", {"story_topic": t}) for t in state.sub_topic_list],
446
  ["story_generator"]
447
  )
448
  graph.add_edge("story_generator", END)
 
406
 
407
 
408
 
409
+ def run_multiagent_storygraph(self, topic: str, context: str, language: str = "English"):
410
+
411
  if self.embedding_choice == "OpenAI":
412
  self.llm = ChatOpenAI(model_name="gpt-4", temperature=0.7, api_key=self.api_key)
413
  elif self.embedding_choice == "Cohere":
 
422
  story_graph = StateGraph(StoryState)
423
  story_graph.add_node("Retrieve", self.retrieve_node)
424
  story_graph.add_node("Rerank", self.rerank_node)
425
+ story_graph.add_node("Generate", lambda state: self.generate_story_node(state, language=state.get("language", "English")))
426
  story_graph.set_entry_point("Retrieve")
427
  story_graph.add_edge("Retrieve", "Rerank")
428
  story_graph.add_edge("Rerank", "Generate")
 
443
  graph.add_edge("advanced_topic", "topic_extractor")
444
  graph.add_conditional_edges(
445
  "topic_extractor",
446
+ lambda state: [Send("story_generator", {"story_topic": t, "language": language}) for t in state.sub_topic_list],
447
  ["story_generator"]
448
  )
449
  graph.add_edge("story_generator", END)