mwalker22 commited on
Commit
8055f0d
·
1 Parent(s): c01c987

Replaced the stubbed evolve.py code with an actual LLM-driven process for evolving the question.

Browse files
app.py CHANGED
@@ -5,6 +5,7 @@ from preprocess.html_to_documents import extract_documents_from_html
5
  from preprocess.embed_documents import create_or_load_vectorstore
6
  from graph.build_graph import build_sdg_graph
7
  from graph.types import SDGState
 
8
 
9
  # Configure logging
10
  logging.basicConfig(level=logging.DEBUG)
@@ -37,8 +38,11 @@ def initialize_resources():
37
  # Create vectorstore
38
  vectorstore = create_or_load_vectorstore(docs)
39
 
 
 
 
40
  # Build graph
41
- graph = build_sdg_graph(docs, vectorstore)
42
 
43
  st.success("Resources initialized successfully!")
44
  return docs, vectorstore, graph
 
5
  from preprocess.embed_documents import create_or_load_vectorstore
6
  from graph.build_graph import build_sdg_graph
7
  from graph.types import SDGState
8
+ from langchain_openai import ChatOpenAI
9
 
10
  # Configure logging
11
  logging.basicConfig(level=logging.DEBUG)
 
38
  # Create vectorstore
39
  vectorstore = create_or_load_vectorstore(docs)
40
 
41
+ # Initialize LLM client
42
+ llm = ChatOpenAI(model="gpt-3.5-turbo", openai_api_key=None) # None will use env var
43
+
44
  # Build graph
45
+ graph = build_sdg_graph(docs, vectorstore, llm)
46
 
47
  st.success("Resources initialized successfully!")
48
  return docs, vectorstore, graph
graph/build_graph.py CHANGED
@@ -5,12 +5,12 @@ from graph.nodes.retrieve import retrieve_relevant_context
5
  from graph.nodes.answer import generate_answer
6
 
7
 
8
- def build_sdg_graph(docs, vectorstore) -> StateGraph:
9
  # Create a new graph with our state type
10
  builder = StateGraph(SDGState)
11
 
12
  # Add nodes with explicit state handling
13
- builder.add_node("evolve", evolve_question)
14
  builder.add_node("retrieve", lambda state: retrieve_relevant_context(state, vectorstore))
15
  builder.add_node("generate_answer", generate_answer)
16
 
 
5
  from graph.nodes.answer import generate_answer
6
 
7
 
8
+ def build_sdg_graph(docs, vectorstore, llm) -> StateGraph:
9
  # Create a new graph with our state type
10
  builder = StateGraph(SDGState)
11
 
12
  # Add nodes with explicit state handling
13
+ builder.add_node("evolve", lambda state: evolve_question(state, llm))
14
  builder.add_node("retrieve", lambda state: retrieve_relevant_context(state, vectorstore))
15
  builder.add_node("generate_answer", generate_answer)
16
 
graph/nodes/evolve.py CHANGED
@@ -4,14 +4,18 @@ import logging
4
 
5
  logger = logging.getLogger(__name__)
6
 
7
- def evolve_question(state: SDGState) -> SDGState:
8
  logger.debug(f"Evolve node received state: {state}")
9
 
10
- # Create a new state with the evolved question
 
 
 
 
11
  new_state = SDGState(
12
  input=state.input,
13
  documents=state.documents,
14
- evolved_question=f"Evolved version of: {state.input}",
15
  context=state.context,
16
  answer=state.answer
17
  )
 
4
 
5
  logger = logging.getLogger(__name__)
6
 
7
+ def evolve_question(state: SDGState, llm) -> SDGState:
8
  logger.debug(f"Evolve node received state: {state}")
9
 
10
+ # Use the LLM to generate an evolved question
11
+ prompt = f"Rewrite or evolve the following question to be more challenging or insightful:\n\n{state.input}"
12
+ response = llm.invoke(prompt)
13
+ evolved_question = response.content if hasattr(response, 'content') else str(response)
14
+
15
  new_state = SDGState(
16
  input=state.input,
17
  documents=state.documents,
18
+ evolved_question=evolved_question,
19
  context=state.context,
20
  answer=state.answer
21
  )
main.py CHANGED
@@ -7,6 +7,7 @@ from pathlib import Path
7
  from graph.types import SDGState
8
  from preprocess.embed_documents import create_or_load_vectorstore
9
  from graph.build_graph import build_sdg_graph
 
10
 
11
 
12
  class DocumentEncoder(json.JSONEncoder):
@@ -70,7 +71,8 @@ def main():
70
 
71
  vectorstore = create_or_load_vectorstore(docs)
72
 
73
- graph = build_sdg_graph(docs, vectorstore)
 
74
  initial_state = SDGState(input="How did LLMs evolve in 2023?")
75
 
76
  result = graph.invoke(initial_state)
 
7
  from graph.types import SDGState
8
  from preprocess.embed_documents import create_or_load_vectorstore
9
  from graph.build_graph import build_sdg_graph
10
+ from langchain_openai import ChatOpenAI
11
 
12
 
13
  class DocumentEncoder(json.JSONEncoder):
 
71
 
72
  vectorstore = create_or_load_vectorstore(docs)
73
 
74
+ llm = ChatOpenAI(model="gpt-3.5-turbo", openai_api_key=None) # None will use env var
75
+ graph = build_sdg_graph(docs, vectorstore, llm)
76
  initial_state = SDGState(input="How did LLMs evolve in 2023?")
77
 
78
  result = graph.invoke(initial_state)
tests/graph/nodes/test_evolve.py CHANGED
@@ -1,9 +1,12 @@
1
  from graph.types import SDGState
2
  from graph.nodes.evolve import evolve_question
 
3
 
4
  def test_evolve_question_modifies_state():
5
  state = SDGState(input="What were the top LLMs in 2023?")
6
- updated_state = evolve_question(state)
 
 
7
 
8
- assert updated_state.evolved_question.startswith("Evolved version of: ")
9
  assert updated_state.evolved_question.endswith("2023?")
 
1
  from graph.types import SDGState
2
  from graph.nodes.evolve import evolve_question
3
+ from unittest.mock import MagicMock
4
 
5
  def test_evolve_question_modifies_state():
6
  state = SDGState(input="What were the top LLMs in 2023?")
7
+ mock_llm = MagicMock()
8
+ mock_llm.invoke.return_value = MagicMock(content="Evolved: What were the top LLMs in 2023?")
9
+ updated_state = evolve_question(state, mock_llm)
10
 
11
+ assert updated_state.evolved_question.startswith("Evolved:")
12
  assert updated_state.evolved_question.endswith("2023?")
tests/graph/test_build_graph.py CHANGED
@@ -9,8 +9,10 @@ def test_build_sdg_graph_runs():
9
  mock_vectorstore.similarity_search.return_value = [
10
  Document(page_content="Relevant content", metadata={})
11
  ]
 
 
12
 
13
- graph = build_sdg_graph(docs, mock_vectorstore)
14
  state = SDGState(input="Test input", documents=docs)
15
  result = graph.invoke(state)
16
 
 
9
  mock_vectorstore.similarity_search.return_value = [
10
  Document(page_content="Relevant content", metadata={})
11
  ]
12
+ mock_llm = MagicMock()
13
+ mock_llm.invoke.return_value = MagicMock(content="Evolved test question")
14
 
15
+ graph = build_sdg_graph(docs, mock_vectorstore, mock_llm)
16
  state = SDGState(input="Test input", documents=docs)
17
  result = graph.invoke(state)
18
 
tests/test_main.py CHANGED
@@ -113,5 +113,5 @@ def test_main_runs_dev_mode(mock_dev, mock_docs, mock_vectorstore, mock_graph):
113
 
114
  mock_docs.assert_called_once()
115
  mock_vectorstore.assert_called_once()
116
- mock_graph.assert_called_once_with(mock_docs.return_value, mock_vectorstore.return_value)
117
  mock_graph.return_value.invoke.assert_called_once()
 
113
 
114
  mock_docs.assert_called_once()
115
  mock_vectorstore.assert_called_once()
116
+ mock_graph.assert_called_once()
117
  mock_graph.return_value.invoke.assert_called_once()