Replaced the stubbed evolve.py code with an actual LLM-driven process for evolving the question.
Files changed:

- app.py +5 -1
- graph/build_graph.py +2 -2
- graph/nodes/evolve.py +7 -3
- main.py +3 -1
- tests/graph/nodes/test_evolve.py +5 -2
- tests/graph/test_build_graph.py +3 -1
- tests/test_main.py +1 -1
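For orientation, the end-to-end wiring after this commit looks roughly like the minimal sketch below. It is illustrative rather than copied from the repo: the `docs = ...` placeholder stands in for the output of the existing preprocessing step, and an OPENAI_API_KEY environment variable is assumed to be set.

from langchain_openai import ChatOpenAI
from preprocess.embed_documents import create_or_load_vectorstore
from graph.build_graph import build_sdg_graph
from graph.types import SDGState

docs = ...  # placeholder: documents produced by the existing preprocessing step
vectorstore = create_or_load_vectorstore(docs)

# ChatOpenAI falls back to the OPENAI_API_KEY environment variable when no key is passed
llm = ChatOpenAI(model="gpt-3.5-turbo")

graph = build_sdg_graph(docs, vectorstore, llm)
result = graph.invoke(SDGState(input="How did LLMs evolve in 2023?"))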
app.py
CHANGED

@@ -5,6 +5,7 @@ from preprocess.html_to_documents import extract_documents_from_html
 from preprocess.embed_documents import create_or_load_vectorstore
 from graph.build_graph import build_sdg_graph
 from graph.types import SDGState
+from langchain_openai import ChatOpenAI
 
 # Configure logging
 logging.basicConfig(level=logging.DEBUG)

@@ -37,8 +38,11 @@ def initialize_resources():
     # Create vectorstore
     vectorstore = create_or_load_vectorstore(docs)
 
+    # Initialize LLM client
+    llm = ChatOpenAI(model="gpt-3.5-turbo", openai_api_key=None)  # None will use env var
+
     # Build graph
-    graph = build_sdg_graph(docs, vectorstore)
+    graph = build_sdg_graph(docs, vectorstore, llm)
 
     st.success("Resources initialized successfully!")
     return docs, vectorstore, graph
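A note on the new ChatOpenAI call: in langchain_openai, passing openai_api_key=None (or omitting it) makes the client fall back to the OPENAI_API_KEY environment variable, which is what the `# None will use env var` comment relies on, so no key ever appears in the source.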
graph/build_graph.py
CHANGED

@@ -5,12 +5,12 @@ from graph.nodes.retrieve import retrieve_relevant_context
 from graph.nodes.answer import generate_answer
 
 
-def build_sdg_graph(docs, vectorstore) -> StateGraph:
+def build_sdg_graph(docs, vectorstore, llm) -> StateGraph:
     # Create a new graph with our state type
     builder = StateGraph(SDGState)
 
     # Add nodes with explicit state handling
-    builder.add_node("evolve", evolve_question)
+    builder.add_node("evolve", lambda state: evolve_question(state, llm))
     builder.add_node("retrieve", lambda state: retrieve_relevant_context(state, vectorstore))
     builder.add_node("generate_answer", generate_answer)
 
graph/nodes/evolve.py
CHANGED

@@ -4,14 +4,18 @@ import logging
 
 logger = logging.getLogger(__name__)
 
-def evolve_question(state: SDGState) -> SDGState:
+def evolve_question(state: SDGState, llm) -> SDGState:
     logger.debug(f"Evolve node received state: {state}")
 
-    #
+    # Use the LLM to generate an evolved question
+    prompt = f"Rewrite or evolve the following question to be more challenging or insightful:\n\n{state.input}"
+    response = llm.invoke(prompt)
+    evolved_question = response.content if hasattr(response, 'content') else str(response)
+
     new_state = SDGState(
         input=state.input,
         documents=state.documents,
-        evolved_question=f"Evolved: {state.input}",
+        evolved_question=evolved_question,
         context=state.context,
         answer=state.answer
     )
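Design note: chat models such as ChatOpenAI return an AIMessage whose text lives in .content, while other LLM wrappers may return a bare string; the `hasattr(response, 'content')` fallback lets evolve_question accept either shape, and it is also what makes the mock-based tests below work.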
main.py
CHANGED

@@ -7,6 +7,7 @@ from pathlib import Path
 from graph.types import SDGState
 from preprocess.embed_documents import create_or_load_vectorstore
 from graph.build_graph import build_sdg_graph
+from langchain_openai import ChatOpenAI
 
 
 class DocumentEncoder(json.JSONEncoder):

@@ -70,7 +71,8 @@ def main():
 
     vectorstore = create_or_load_vectorstore(docs)
 
-    graph = build_sdg_graph(docs, vectorstore)
+    llm = ChatOpenAI(model="gpt-3.5-turbo", openai_api_key=None)  # None will use env var
+    graph = build_sdg_graph(docs, vectorstore, llm)
     initial_state = SDGState(input="How did LLMs evolve in 2023?")
 
     result = graph.invoke(initial_state)
tests/graph/nodes/test_evolve.py
CHANGED

@@ -1,9 +1,12 @@
 from graph.types import SDGState
 from graph.nodes.evolve import evolve_question
+from unittest.mock import MagicMock
 
 def test_evolve_question_modifies_state():
     state = SDGState(input="What were the top LLMs in 2023?")
-    updated_state = evolve_question(state)
+    mock_llm = MagicMock()
+    mock_llm.invoke.return_value = MagicMock(content="Evolved: What were the top LLMs in 2023?")
+    updated_state = evolve_question(state, mock_llm)
 
-    assert updated_state.evolved_question.startswith("Evolved")
+    assert updated_state.evolved_question.startswith("Evolved:")
     assert updated_state.evolved_question.endswith("2023?")
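Here MagicMock(content="...") mimics just enough of an AIMessage (an object with a .content attribute) to satisfy the hasattr branch in evolve_question, so the test exercises the node without a real API call.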
tests/graph/test_build_graph.py
CHANGED

@@ -9,8 +9,10 @@ def test_build_sdg_graph_runs():
     mock_vectorstore.similarity_search.return_value = [
         Document(page_content="Relevant content", metadata={})
     ]
+    mock_llm = MagicMock()
+    mock_llm.invoke.return_value = MagicMock(content="Evolved test question")
 
-    graph = build_sdg_graph(docs, mock_vectorstore)
+    graph = build_sdg_graph(docs, mock_vectorstore, mock_llm)
     state = SDGState(input="Test input", documents=docs)
     result = graph.invoke(state)
 
tests/test_main.py
CHANGED

@@ -113,5 +113,5 @@ def test_main_runs_dev_mode(mock_dev, mock_docs, mock_vectorstore, mock_graph):
 
     mock_docs.assert_called_once()
     mock_vectorstore.assert_called_once()
-    mock_graph.
+    mock_graph.assert_called_once()
     mock_graph.return_value.invoke.assert_called_once()