DrishtiSharma committed on
Commit
99973bd
·
verified ·
1 Parent(s): 584bb82

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -9
app.py CHANGED
@@ -50,6 +50,7 @@ class MultiAgentState(BaseModel):
50
 
51
  class StoryState(BaseModel):
52
  retrieved_docs: List[Any] = []
 
53
  stories: Annotated[list[AnyMessage], add_messages]
54
  story_topic: str = ""
55
  stories_lst: Annotated[list, operator.add]
@@ -65,7 +66,24 @@ class DocumentRAG:
65
  self.init_time = datetime.now(pytz.UTC)
66
  self.embedding_choice = embedding_choice
67
 
68
- # Persistent directory for Chroma to avoid tenant-related errors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  self.chroma_persist_dir = "./chroma_storage"
70
  os.makedirs(self.chroma_persist_dir, exist_ok=True)
71
 
@@ -324,6 +342,7 @@ class DocumentRAG:
324
  docs = retriever.get_relevant_documents(query)
325
  return {"retrieved_docs": docs, "question": query}
326
 
 
327
  def rerank_node(self, state: StoryState):
328
  topic = state.story_topic
329
  query = f"Rerank documents based on how well they explain the topic {topic}"
@@ -333,12 +352,22 @@ class DocumentRAG:
333
  if not texts:
334
  return {"reranked_docs": [], "question": query}
335
 
336
- # Quick fallback: rank by length
337
- top_docs = sorted(texts, key=lambda t: -len(t))[:5]
 
 
 
 
 
 
 
 
 
338
  return {"reranked_docs": top_docs, "question": query}
339
 
340
 
341
 
 
342
  def generate_story_node(self, state: StoryState):
343
  context = "\n\n".join(state.reranked_docs)
344
  topic = state.story_topic
@@ -361,7 +390,15 @@ class DocumentRAG:
361
 
362
 
363
  def run_multiagent_storygraph(self, topic: str, context: str):
364
- self.llm = ChatOpenAI(model_name="gpt-4", temperature=0.7, api_key=self.api_key)
 
 
 
 
 
 
 
 
365
 
366
  # Define the story subgraph with reranking
367
  story_graph = StateGraph(StoryState)
@@ -374,7 +411,7 @@ class DocumentRAG:
374
  story_graph.set_finish_point("Generate")
375
  story_subgraph = story_graph.compile()
376
 
377
- # Main graph setup
378
  graph = StateGraph(MultiAgentState)
379
  graph.add_node("beginner_topic", self.beginner_topic)
380
  graph.add_node("middle_topic", self.middle_topic)
@@ -386,18 +423,20 @@ class DocumentRAG:
386
  graph.add_edge("beginner_topic", "middle_topic")
387
  graph.add_edge("middle_topic", "advanced_topic")
388
  graph.add_edge("advanced_topic", "topic_extractor")
389
- graph.add_conditional_edges("topic_extractor",
 
390
  lambda state: [Send("story_generator", {"story_topic": t}) for t in state.sub_topic_list],
391
- ["story_generator"])
 
392
  graph.add_edge("story_generator", END)
393
 
394
  compiled = graph.compile(checkpointer=MemorySaver())
395
  thread = {"configurable": {"thread_id": "storygraph-session"}}
396
 
397
- # Initial run to extract subtopics
398
  result = compiled.invoke({"topic": [topic], "context": [context]}, thread)
399
 
400
- # Fallback if no subtopics were extracted
401
  if not result.get("sub_topic_list"):
402
  fallback_subs = ["Neural Networks", "Reinforcement Learning", "Supervised vs Unsupervised"]
403
  compiled.update_state(thread, {"sub_topic_list": fallback_subs})
@@ -406,6 +445,7 @@ class DocumentRAG:
406
  return result
407
 
408
 
 
409
  if "rag_system" not in st.session_state or st.session_state.embedding_model != embedding_choice:
410
  st.session_state.embedding_model = embedding_choice
411
  st.session_state.rag_system = DocumentRAG(embedding_choice=embedding_choice)
 
50
 
51
  class StoryState(BaseModel):
52
  retrieved_docs: List[Any] = []
53
+ reranked_docs: List[str] = []
54
  stories: Annotated[list[AnyMessage], add_messages]
55
  story_topic: str = ""
56
  stories_lst: Annotated[list, operator.add]
 
66
  self.init_time = datetime.now(pytz.UTC)
67
  self.embedding_choice = embedding_choice
68
 
69
+ # Set up appropriate LLM
70
+ if self.embedding_choice == "Cohere":
71
+ from langchain_cohere import ChatCohere
72
+ import cohere
73
+ self.llm = ChatCohere(
74
+ model="command-r-plus-08-2024",
75
+ temperature=0.7,
76
+ cohere_api_key=os.getenv("COHERE_API_KEY")
77
+ )
78
+ self.cohere_client = cohere.Client(os.getenv("COHERE_API_KEY"))
79
+ else:
80
+ self.llm = ChatOpenAI(
81
+ model_name="gpt-4",
82
+ temperature=0.7,
83
+ api_key=self.api_key
84
+ )
85
+
86
+ # Persistent directory for Chroma
87
  self.chroma_persist_dir = "./chroma_storage"
88
  os.makedirs(self.chroma_persist_dir, exist_ok=True)
89
 
 
342
  docs = retriever.get_relevant_documents(query)
343
  return {"retrieved_docs": docs, "question": query}
344
 
345
+
346
  def rerank_node(self, state: StoryState):
347
  topic = state.story_topic
348
  query = f"Rerank documents based on how well they explain the topic {topic}"
 
352
  if not texts:
353
  return {"reranked_docs": [], "question": query}
354
 
355
+ if self.embedding_choice == "Cohere":
356
+ rerank_results = self.cohere_client.rerank(
357
+ query=query,
358
+ documents=texts,
359
+ top_n=5,
360
+ model="rerank-v3.5"
361
+ )
362
+ top_docs = [texts[result.index] for result in rerank_results.results]
363
+ else:
364
+ top_docs = sorted(texts, key=lambda t: -len(t))[:5]
365
+
366
  return {"reranked_docs": top_docs, "question": query}
367
 
368
 
369
 
370
+
371
  def generate_story_node(self, state: StoryState):
372
  context = "\n\n".join(state.reranked_docs)
373
  topic = state.story_topic
 
390
 
391
 
392
  def run_multiagent_storygraph(self, topic: str, context: str):
393
+ if self.embedding_choice == "OpenAI":
394
+ self.llm = ChatOpenAI(model_name="gpt-4", temperature=0.7, api_key=self.api_key)
395
+ elif self.embedding_choice == "Cohere":
396
+ from langchain_cohere import ChatCohere
397
+ self.llm = ChatCohere(
398
+ model="command-r-plus-08-2024",
399
+ temperature=0.7,
400
+ cohere_api_key=os.getenv("COHERE_API_KEY")
401
+ )
402
 
403
  # Define the story subgraph with reranking
404
  story_graph = StateGraph(StoryState)
 
411
  story_graph.set_finish_point("Generate")
412
  story_subgraph = story_graph.compile()
413
 
414
+ # Define the main graph
415
  graph = StateGraph(MultiAgentState)
416
  graph.add_node("beginner_topic", self.beginner_topic)
417
  graph.add_node("middle_topic", self.middle_topic)
 
423
  graph.add_edge("beginner_topic", "middle_topic")
424
  graph.add_edge("middle_topic", "advanced_topic")
425
  graph.add_edge("advanced_topic", "topic_extractor")
426
+ graph.add_conditional_edges(
427
+ "topic_extractor",
428
  lambda state: [Send("story_generator", {"story_topic": t}) for t in state.sub_topic_list],
429
+ ["story_generator"]
430
+ )
431
  graph.add_edge("story_generator", END)
432
 
433
  compiled = graph.compile(checkpointer=MemorySaver())
434
  thread = {"configurable": {"thread_id": "storygraph-session"}}
435
 
436
+ # Initial invocation
437
  result = compiled.invoke({"topic": [topic], "context": [context]}, thread)
438
 
439
+ # Fallback if no subtopics found
440
  if not result.get("sub_topic_list"):
441
  fallback_subs = ["Neural Networks", "Reinforcement Learning", "Supervised vs Unsupervised"]
442
  compiled.update_state(thread, {"sub_topic_list": fallback_subs})
 
445
  return result
446
 
447
 
448
+
449
  if "rag_system" not in st.session_state or st.session_state.embedding_model != embedding_choice:
450
  st.session_state.embedding_model = embedding_choice
451
  st.session_state.rag_system = DocumentRAG(embedding_choice=embedding_choice)