File size: 3,864 Bytes
a8ab4cd 2e5aca2 a8ab4cd 2e5aca2 a8ab4cd 2e5aca2 a8ab4cd 2e5aca2 a8ab4cd 2e5aca2 a8ab4cd 2e5aca2 a8ab4cd 2e5aca2 a8ab4cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
from typing import List
from langchain.schema import Document
from langgraph.graph import END, StateGraph
from nodes.generator import rag_chain
from nodes.grader import retrieval_grader
from nodes.retriever import retriever
from nodes.rewriter import question_rewriter
from tools.search_wikipedia import wikipedia
from typing_extensions import TypedDict
# DEFINE STATE GRAPH
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: the question being answered (overwritten by the rewritten
            question when the rewrite node runs)
        generation: LLM generation
        wiki_search: whether document grading flagged a Wikipedia search
        documents: list of documents
    """

    # TODO: Update this to work with memory in a better way.
    question: str
    generation: str
    # grade_documents stores a boolean flag under this key, not a string.
    wiki_search: bool
    # The nodes store Document objects here (they read ``doc.page_content``).
    # The name is quoted so it is only resolved when type hints are inspected.
    documents: List["Document"]
# DEFINE NODES
def retrieve(state):
    """Look up documents for the question held in the graph state.

    Args:
        state: graph state; only ``state["question"]`` is read.

    Returns:
        A state update with the unchanged ``question`` and the result of
        ``retriever.invoke`` under ``documents``.
    """
    print("Retrieving documents...")
    query = state["question"]
    return {"question": query, "documents": retriever.invoke(query)}
def generate(state):
    """Produce an answer from the question and documents in the state.

    Args:
        state: graph state; ``question`` and ``documents`` are read.

    Returns:
        A state update carrying ``documents`` and ``question`` through
        unchanged, plus the output of ``rag_chain.invoke`` as ``generation``.
    """
    print("Generating answer...")
    query = state["question"]
    context_docs = state["documents"]

    # The chain receives the documents as "context" alongside the question.
    chain_input = {"context": context_docs, "question": query}
    answer = rag_chain.invoke(chain_input)

    update = {"documents": context_docs, "question": query}
    update["generation"] = answer
    return update
def grade_documents(state):
    """Keep only the documents the grader marks as relevant to the question.

    Each document's ``page_content`` is sent to ``retrieval_grader`` together
    with the question; a document is kept when the returned
    ``binary_score`` equals ``"yes"``.

    Args:
        state: graph state; ``question`` and ``documents`` are read.

    Returns:
        A state update with the kept ``documents``, the unchanged
        ``question`` and a boolean ``wiki_search`` flag. The flag is True
        when any document was rejected or when no relevant document remains
        (including the case where nothing was retrieved at all).
    """
    print("Grading documents...")
    question = state["question"]
    documents = state["documents"]

    filtered_docs = []
    # Deliberately not named ``search_wikipedia``: that would shadow the
    # module-level node function of the same name.
    needs_wiki_search = False
    for doc in documents:
        score = retrieval_grader.invoke(
            {"question": question, "document": doc.page_content}
        )
        if score.binary_score == "yes":
            print("Document is relevant to the question.")
            filtered_docs.append(doc)
        else:
            print("Document is not relevant to the question.")
            needs_wiki_search = True

    # An empty retrieval never enters the loop above; flag it as well so the
    # graph does not go on to generate an answer with no context.
    if not filtered_docs:
        needs_wiki_search = True

    return {
        "documents": filtered_docs,
        "question": question,
        "wiki_search": needs_wiki_search,
    }
def rewrite_query(state):
    """Replace the question in the state with a rewritten version.

    Args:
        state: graph state; ``question`` and ``documents`` are read.

    Returns:
        A state update whose ``question`` is the output of
        ``question_rewriter.invoke`` and whose ``documents`` are passed
        through unchanged.
    """
    print("Rewriting question...")
    original_question = state["question"]
    current_docs = state["documents"]
    improved_question = question_rewriter.invoke({"question": original_question})
    return dict(question=improved_question, documents=current_docs)
def search_wikipedia(state):
    """Add a Wikipedia lookup for the question to the documents.

    The output of ``wikipedia.invoke`` is wrapped in a ``Document`` (as its
    ``page_content``) and placed after the documents already in the state.

    Args:
        state: graph state; ``question`` and ``documents`` are read.

    Returns:
        A state update with the unchanged ``question`` and a new
        ``documents`` list ending with the Wikipedia document.
    """
    print("Searching Wikipedia...")
    question = state["question"]
    documents = state["documents"]

    wiki_search = wikipedia.invoke(question)
    wiki_results = Document(page_content=wiki_search)

    # Build a fresh list rather than appending in place, so the list object
    # held by the incoming state is not mutated behind the graph's back.
    return {"question": question, "documents": [*documents, wiki_results]}
# DEFINE CONDITIONAL EDGES
def generate_or_not(state):
    """Choose the node that follows document grading.

    Args:
        state: graph state; ``wiki_search`` and ``documents`` are read.

    Returns:
        ``"rewrite_query"`` when no documents remain and the ``wiki_search``
        flag is set; ``"generate"`` in every other case.
    """
    print("Determining whether to query Wikipedia...")
    wiki_flag = state["wiki_search"]
    remaining = state["documents"]

    # Guard clause: anything short of "no documents and flag set" generates.
    if len(remaining) != 0 or not wiki_flag:
        print("Relevant documents found.")
        return "generate"

    print("Rewriting query and supplementing information from Wikipedia...")
    return "rewrite_query"
def create_graph():
    """Assemble and compile the retrieval/grading/generation workflow.

    Wiring: ``retrieve`` -> ``grade_documents`` -> (``generate_or_not``
    picks ``rewrite_query`` or ``generate``); ``rewrite_query`` ->
    ``search_wikipedia`` -> ``generate`` -> ``END``.

    Returns:
        The object returned by ``StateGraph.compile()``.
    """
    graph = StateGraph(GraphState)

    # NODES (registered in this order)
    node_table = (
        ("retrieve", retrieve),
        ("grade_documents", grade_documents),
        ("rewrite_query", rewrite_query),
        ("search_wikipedia", search_wikipedia),
        ("generate", generate),
    )
    for node_name, node_fn in node_table:
        graph.add_node(node_name, node_fn)

    # EDGES
    graph.set_entry_point("retrieve")
    graph.add_edge("retrieve", "grade_documents")

    # generate_or_not returns one of the keys of this mapping.
    routes = {"rewrite_query": "rewrite_query", "generate": "generate"}
    graph.add_conditional_edges("grade_documents", generate_or_not, routes)

    tail_edges = (
        ("rewrite_query", "search_wikipedia"),
        ("search_wikipedia", "generate"),
        ("generate", END),
    )
    for source, target in tail_edges:
        graph.add_edge(source, target)

    # COMPILE GRAPH
    return graph.compile()
|