mwalker22 committed on
Commit
bc82499
·
1 Parent(s): bb469cc

Converted the evolved_question field of the state object into a list (evolved_questions) and added an evolved_question property that returns the last evolved question.

Browse files
app.py CHANGED
@@ -88,8 +88,8 @@ if st.button("Generate Synthetic Data"):
88
  # Display evolved questions
89
  st.markdown("### Evolved Questions")
90
  evolved_questions = [
91
- {"id": f"q{i}", "question": q, "evolution_type": "simple"}
92
- for i, q in enumerate([result.evolved_question]) # Currently only one question
93
  ]
94
  st.json(evolved_questions)
95
 
 
88
  # Display evolved questions
89
  st.markdown("### Evolved Questions")
90
  evolved_questions = [
91
+ {"id": f"q{i}", "question": q, "evolution_type": "simple"}
92
+ for i, q in enumerate(result.evolved_questions)
93
  ]
94
  st.json(evolved_questions)
95
 
graph/nodes/answer.py CHANGED
@@ -17,9 +17,10 @@ def generate_answer(state: SDGState) -> SDGState:
17
  new_state = SDGState(
18
  input=state.input,
19
  documents=state.documents,
20
- evolved_question=state.evolved_question,
21
  context=state.context,
22
- answer=f"Based on the retrieved context:\n{context_snippet}"
 
23
  )
24
 
25
  logger.debug(f"Answer node returning state: {new_state}")
 
17
  new_state = SDGState(
18
  input=state.input,
19
  documents=state.documents,
20
+ evolved_questions=state.evolved_questions,
21
  context=state.context,
22
+ answer=f"Based on the retrieved context:\n{context_snippet}",
23
+ num_evolve_passes=state.num_evolve_passes
24
  )
25
 
26
  logger.debug(f"Answer node returning state: {new_state}")
graph/nodes/evolve.py CHANGED
@@ -10,17 +10,20 @@ def evolve_question(state: SDGState, llm) -> SDGState:
10
  "Rewrite or evolve the following question to be more challenging or insightful:\n\n{}",
11
  "Rewrite or evolve the following question to be more creative or original:\n\n{}"
12
  ]
13
- evolved = state.input
 
14
  for i in range(num_passes):
15
  prompt = prompts[i % len(prompts)].format(evolved)
16
  response = llm.invoke(prompt)
17
  evolved = response.content if hasattr(response, 'content') else str(response)
 
18
  new_state = SDGState(
19
  input=state.input,
20
  documents=state.documents,
21
- evolved_question=evolved,
22
  context=state.context,
23
- answer=state.answer
 
24
  )
25
  logger.debug(f"Evolve node returning state: {new_state}")
26
  return new_state
 
10
  "Rewrite or evolve the following question to be more challenging or insightful:\n\n{}",
11
  "Rewrite or evolve the following question to be more creative or original:\n\n{}"
12
  ]
13
+ evolved_questions = list(state.evolved_questions) if state.evolved_questions else [state.input]
14
+ evolved = evolved_questions[-1]
15
  for i in range(num_passes):
16
  prompt = prompts[i % len(prompts)].format(evolved)
17
  response = llm.invoke(prompt)
18
  evolved = response.content if hasattr(response, 'content') else str(response)
19
+ evolved_questions.append(evolved)
20
  new_state = SDGState(
21
  input=state.input,
22
  documents=state.documents,
23
+ evolved_questions=evolved_questions,
24
  context=state.context,
25
+ answer=state.answer,
26
+ num_evolve_passes=state.num_evolve_passes
27
  )
28
  logger.debug(f"Evolve node returning state: {new_state}")
29
  return new_state
graph/nodes/retrieve.py CHANGED
@@ -14,9 +14,10 @@ def retrieve_relevant_context(state: SDGState, vectorstore) -> SDGState:
14
  new_state = SDGState(
15
  input=state.input,
16
  documents=state.documents,
17
- evolved_question=state.evolved_question,
18
  context=[doc.page_content for doc in retrieved_docs],
19
- answer=state.answer
 
20
  )
21
 
22
  logger.debug(f"Retrieve node returning state: {new_state}")
 
14
  new_state = SDGState(
15
  input=state.input,
16
  documents=state.documents,
17
+ evolved_questions=state.evolved_questions,
18
  context=[doc.page_content for doc in retrieved_docs],
19
+ answer=state.answer,
20
+ num_evolve_passes=state.num_evolve_passes
21
  )
22
 
23
  logger.debug(f"Retrieve node returning state: {new_state}")
graph/types.py CHANGED
@@ -5,7 +5,11 @@ from pydantic import BaseModel, Field
5
  class SDGState(BaseModel):
6
  input: str = Field(default="")
7
  documents: List[Document] = Field(default_factory=list)
8
- evolved_question: str = Field(default="")
9
  context: List[str] = Field(default_factory=list)
10
  answer: str = Field(default="")
11
- num_evolve_passes: int = Field(default=2)
 
 
 
 
 
5
  class SDGState(BaseModel):
6
  input: str = Field(default="")
7
  documents: List[Document] = Field(default_factory=list)
8
+ evolved_questions: List[str] = Field(default_factory=list)
9
  context: List[str] = Field(default_factory=list)
10
  answer: str = Field(default="")
11
+ num_evolve_passes: int = Field(default=2)
12
+
13
+ @property
14
+ def evolved_question(self):
15
+ return self.evolved_questions[-1] if self.evolved_questions else ""
main.py CHANGED
@@ -20,6 +20,7 @@ class DocumentEncoder(json.JSONEncoder):
20
  if isinstance(obj, SDGState):
21
  return {
22
  "input": obj.input,
 
23
  "evolved_question": obj.evolved_question,
24
  "context": obj.context,
25
  "answer": obj.answer
 
20
  if isinstance(obj, SDGState):
21
  return {
22
  "input": obj.input,
23
+ "evolved_questions": obj.evolved_questions,
24
  "evolved_question": obj.evolved_question,
25
  "context": obj.context,
26
  "answer": obj.answer
tests/graph/nodes/test_evolve.py CHANGED
@@ -54,4 +54,40 @@ def test_evolve_question_three_passes():
54
  call("Rewrite or evolve the following question to be more creative or original:\n\nChallenging: What were the top LLMs in 2023?"),
55
  call("Rewrite or evolve the following question to be more challenging or insightful:\n\nCreative: What were the top LLMs in 2023?")
56
  ]
57
- mock_llm.invoke.assert_has_calls(expected_calls)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  call("Rewrite or evolve the following question to be more creative or original:\n\nChallenging: What were the top LLMs in 2023?"),
55
  call("Rewrite or evolve the following question to be more challenging or insightful:\n\nCreative: What were the top LLMs in 2023?")
56
  ]
57
+ mock_llm.invoke.assert_has_calls(expected_calls)
58
+
59
+ def test_evolved_questions_list_populated_correctly():
60
+ state = SDGState(input="Base question", num_evolve_passes=3)
61
+ mock_llm = MagicMock()
62
+ mock_llm.invoke.side_effect = [
63
+ MagicMock(content="Challenging: Base question"),
64
+ MagicMock(content="Creative: Challenging: Base question"),
65
+ MagicMock(content="Challenging Again: Creative: Challenging: Base question")
66
+ ]
67
+ updated_state = evolve_question(state, mock_llm)
68
+ # The evolved_questions list should contain the initial input plus one entry per pass
69
+ assert updated_state.evolved_questions == [
70
+ "Base question",
71
+ "Challenging: Base question",
72
+ "Creative: Challenging: Base question",
73
+ "Challenging Again: Creative: Challenging: Base question"
74
+ ]
75
+ # The property should return the last one
76
+ assert updated_state.evolved_question == "Challenging Again: Creative: Challenging: Base question"
77
+
78
+ def test_evolved_questions_list_with_existing_evolutions():
79
+ # If the state already has evolved_questions, it should continue from the last
80
+ state = SDGState(input="Base question", evolved_questions=["Base question", "First evolution"], num_evolve_passes=2)
81
+ mock_llm = MagicMock()
82
+ mock_llm.invoke.side_effect = [
83
+ MagicMock(content="Second evolution"),
84
+ MagicMock(content="Third evolution")
85
+ ]
86
+ updated_state = evolve_question(state, mock_llm)
87
+ assert updated_state.evolved_questions == [
88
+ "Base question",
89
+ "First evolution",
90
+ "Second evolution",
91
+ "Third evolution"
92
+ ]
93
+ assert updated_state.evolved_question == "Third evolution"
tests/graph/test_build_graph.py CHANGED
@@ -17,6 +17,8 @@ def test_build_sdg_graph_runs():
17
  result = graph.invoke(state)
18
 
19
  assert isinstance(result, dict)
20
- assert "evolved_question" in result
 
 
21
  assert result["context"]
22
  assert "Relevant content" in result["context"][0]
 
17
  result = graph.invoke(state)
18
 
19
  assert isinstance(result, dict)
20
+ assert "evolved_questions" in result
21
+ if result["evolved_questions"]:
22
+ assert result["evolved_questions"][-1] == "Evolved test question"
23
  assert result["context"]
24
  assert "Relevant content" in result["context"][0]