Converted the evolved_question field of the state object into an evolved_questions list and added a property that returns the last evolved question.
- app.py +2 -2
- graph/nodes/answer.py +3 -2
- graph/nodes/evolve.py +6 -3
- graph/nodes/retrieve.py +3 -2
- graph/types.py +6 -2
- main.py +1 -0
- tests/graph/nodes/test_evolve.py +37 -1
- tests/graph/test_build_graph.py +3 -1
app.py
CHANGED
@@ -88,8 +88,8 @@ if st.button("Generate Synthetic Data"):
     # Display evolved questions
     st.markdown("### Evolved Questions")
     evolved_questions = [
-        {"id": f"q{i}", "question": q, "evolution_type": "simple"}
-        for i, q in enumerate(
+        {"id": f"q{i}", "question": q, "evolution_type": "simple"}
+        for i, q in enumerate(result.evolved_questions)
     ]
     st.json(evolved_questions)
graph/nodes/answer.py
CHANGED
@@ -17,9 +17,10 @@ def generate_answer(state: SDGState) -> SDGState:
     new_state = SDGState(
         input=state.input,
         documents=state.documents,
-
+        evolved_questions=state.evolved_questions,
         context=state.context,
-        answer=f"Based on the retrieved context:\n{context_snippet}"
+        answer=f"Based on the retrieved context:\n{context_snippet}",
+        num_evolve_passes=state.num_evolve_passes
     )

     logger.debug(f"Answer node returning state: {new_state}")
graph/nodes/evolve.py
CHANGED
@@ -10,17 +10,20 @@ def evolve_question(state: SDGState, llm) -> SDGState:
         "Rewrite or evolve the following question to be more challenging or insightful:\n\n{}",
         "Rewrite or evolve the following question to be more creative or original:\n\n{}"
     ]
-
+    evolved_questions = list(state.evolved_questions) if state.evolved_questions else [state.input]
+    evolved = evolved_questions[-1]
     for i in range(num_passes):
         prompt = prompts[i % len(prompts)].format(evolved)
         response = llm.invoke(prompt)
         evolved = response.content if hasattr(response, 'content') else str(response)
+        evolved_questions.append(evolved)
     new_state = SDGState(
         input=state.input,
         documents=state.documents,
-
+        evolved_questions=evolved_questions,
         context=state.context,
-        answer=state.answer
+        answer=state.answer,
+        num_evolve_passes=state.num_evolve_passes
     )
     logger.debug(f"Evolve node returning state: {new_state}")
     return new_state
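For reference, a rough standalone sketch of the accumulation pattern this node now follows: seed the list from the input, evolve the latest entry on every pass, and append each result. The StubLLM class and the shortened prompt texts are placeholders, not the project's real model client or prompts.

# Sketch of the evolve loop above, with a hypothetical stub in place of the real LLM.
class StubLLM:
    def invoke(self, prompt: str) -> str:
        # Echo the question portion of the prompt with an "evolved" wrapper.
        return f"Evolved[{prompt.splitlines()[-1]}]"

llm = StubLLM()
prompts = ["Make it harder:\n\n{}", "Make it more creative:\n\n{}"]  # shortened stand-ins
evolved_questions = ["What were the top LLMs in 2023?"]  # seeded from state.input
evolved = evolved_questions[-1]
for i in range(3):  # num_evolve_passes
    response = llm.invoke(prompts[i % len(prompts)].format(evolved))
    evolved = response.content if hasattr(response, "content") else str(response)
    evolved_questions.append(evolved)

print(evolved_questions)      # base question plus one entry per pass
print(evolved_questions[-1])  # what the evolved_question property will return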
graph/nodes/retrieve.py
CHANGED
@@ -14,9 +14,10 @@ def retrieve_relevant_context(state: SDGState, vectorstore) -> SDGState:
     new_state = SDGState(
         input=state.input,
         documents=state.documents,
-
+        evolved_questions=state.evolved_questions,
         context=[doc.page_content for doc in retrieved_docs],
-        answer=state.answer
+        answer=state.answer,
+        num_evolve_passes=state.num_evolve_passes
     )

     logger.debug(f"Retrieve node returning state: {new_state}")
graph/types.py
CHANGED
@@ -5,7 +5,11 @@ from pydantic import BaseModel, Field
 class SDGState(BaseModel):
     input: str = Field(default="")
     documents: List[Document] = Field(default_factory=list)
-
+    evolved_questions: List[str] = Field(default_factory=list)
     context: List[str] = Field(default_factory=list)
     answer: str = Field(default="")
-    num_evolve_passes: int = Field(default=2)
+    num_evolve_passes: int = Field(default=2)
+
+    @property
+    def evolved_question(self):
+        return self.evolved_questions[-1] if self.evolved_questions else ""
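The evolved_question property added above simply exposes the last entry of the list, or an empty string before any evolution has run. A minimal sketch of that behavior, with the Document field omitted so the snippet runs with pydantic alone:

from typing import List
from pydantic import BaseModel, Field

class MiniState(BaseModel):
    # Trimmed-down stand-in for SDGState; only the field relevant here.
    evolved_questions: List[str] = Field(default_factory=list)

    @property
    def evolved_question(self) -> str:
        # Last evolution wins; empty string when the list is empty.
        return self.evolved_questions[-1] if self.evolved_questions else ""

assert MiniState().evolved_question == ""
assert MiniState(evolved_questions=["base", "harder"]).evolved_question == "harder"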
main.py
CHANGED
@@ -20,6 +20,7 @@ class DocumentEncoder(json.JSONEncoder):
         if isinstance(obj, SDGState):
             return {
                 "input": obj.input,
+                "evolved_questions": obj.evolved_questions,
                 "evolved_question": obj.evolved_question,
                 "context": obj.context,
                 "answer": obj.answer
tests/graph/nodes/test_evolve.py
CHANGED
@@ -54,4 +54,40 @@ def test_evolve_question_three_passes():
         call("Rewrite or evolve the following question to be more creative or original:\n\nChallenging: What were the top LLMs in 2023?"),
         call("Rewrite or evolve the following question to be more challenging or insightful:\n\nCreative: What were the top LLMs in 2023?")
     ]
-    mock_llm.invoke.assert_has_calls(expected_calls)
+    mock_llm.invoke.assert_has_calls(expected_calls)
+
+def test_evolved_questions_list_populated_correctly():
+    state = SDGState(input="Base question", num_evolve_passes=3)
+    mock_llm = MagicMock()
+    mock_llm.invoke.side_effect = [
+        MagicMock(content="Challenging: Base question"),
+        MagicMock(content="Creative: Challenging: Base question"),
+        MagicMock(content="Challenging Again: Creative: Challenging: Base question")
+    ]
+    updated_state = evolve_question(state, mock_llm)
+    # The evolved_questions list should contain the initial input plus one entry per pass
+    assert updated_state.evolved_questions == [
+        "Base question",
+        "Challenging: Base question",
+        "Creative: Challenging: Base question",
+        "Challenging Again: Creative: Challenging: Base question"
+    ]
+    # The property should return the last one
+    assert updated_state.evolved_question == "Challenging Again: Creative: Challenging: Base question"
+
+def test_evolved_questions_list_with_existing_evolutions():
+    # If the state already has evolved_questions, it should continue from the last
+    state = SDGState(input="Base question", evolved_questions=["Base question", "First evolution"], num_evolve_passes=2)
+    mock_llm = MagicMock()
+    mock_llm.invoke.side_effect = [
+        MagicMock(content="Second evolution"),
+        MagicMock(content="Third evolution")
+    ]
+    updated_state = evolve_question(state, mock_llm)
+    assert updated_state.evolved_questions == [
+        "Base question",
+        "First evolution",
+        "Second evolution",
+        "Third evolution"
+    ]
+    assert updated_state.evolved_question == "Third evolution"
tests/graph/test_build_graph.py
CHANGED
@@ -17,6 +17,8 @@ def test_build_sdg_graph_runs():
     result = graph.invoke(state)

     assert isinstance(result, dict)
-    assert "
+    assert "evolved_questions" in result
+    if result["evolved_questions"]:
+        assert result["evolved_questions"][-1] == "Evolved test question"
     assert result["context"]
     assert "Relevant content" in result["context"][0]