mwalker22 committed
Commit e9f7aa8 · unverified · 2 Parent(s): 8ee27c8 ed80a59

Merge pull request #17 from mwalker-tmd/feature/evolve-instruct
Dockerfile CHANGED
@@ -1,6 +1,10 @@
 # Use Python 3.11 as base image
 FROM python:3.11-slim
 
+# Add build argument for version tracking
+ARG BUILD_VERSION=1.0.0
+ENV BUILD_VERSION=${BUILD_VERSION}
+
 # Set working directory
 WORKDIR /app
 
@@ -26,18 +30,21 @@ COPY data/ data/
 
 # Create a shell script to run the application
 RUN echo '#!/bin/bash\n\
+echo "Starting application version ${BUILD_VERSION}"\n\
 source /app/.venv/bin/activate\n\
-exec /app/.venv/bin/streamlit run app.py --server.port=8501 --server.address=0.0.0.0' > /app/run.sh && \
+PORT=${PORT:-8501}\n\
+exec /app/.venv/bin/streamlit run app.py --server.port=${PORT} --server.address=0.0.0.0' > /app/run.sh && \
 chmod +x /app/run.sh
 
-# Expose the port Streamlit runs on
-EXPOSE 8501
+# Expose the default port Streamlit runs on
+EXPOSE ${PORT:-8501}
 
 # Set environment variables
 ENV PYTHONUNBUFFERED=1
 ENV ENVIRONMENT=development
 ENV LANGCHAIN_TRACING_V2=false
 ENV PATH="/app/.venv/bin:$PATH"
+ENV PORT=8501
 
 # Command to run the application
 CMD ["/app/run.sh"]
README.md CHANGED
@@ -14,11 +14,48 @@ This project reproduces the RAGAS Synthetic Data Generation steps using LangGraph
 
 ## Features
 
-- Synthetic data generation using Evol Instruct method
-- Three evolution types: Simple, Multi-Context, and Reasoning
-- Output includes evolved questions, answers, and relevant contexts
+- Synthetic data generation using the Evol Instruct methodology
+- Iterative question evolution with alternating prompts:
+  - Even iterations: more challenging and insightful questions
+  - Odd iterations: more creative and original questions
+- Consistent state management across iterations
+- Standardized JSON output format with linked questions, answers, and contexts
 - Deployed as a Streamlit app on Hugging Face Spaces
 
+## Evol Instruct Implementation
+
+This project implements the Evol Instruct methodology for evolving questions through multiple iterations. The implementation has several key aspects that should be considered when modifying the code:
+
+### Core Principles
+
+1. **Single Evolution Per Pass**: Each graph invocation performs one evolution step, maintaining clarity and control over the evolution process.
+2. **Alternating Prompts**: The system alternates between:
+   - Challenging/insightful prompts (even-numbered iterations)
+   - Creative/original prompts (odd-numbered iterations)
+3. **State Management**: Evolution history is preserved between iterations of the question-evolution process, and each node in the chain processes only the latest evolved question.
+4. **Configurable Evolution Count**: The number of evolution passes can be controlled through the UI or an environment variable, allowing flexibility in the evolution process.
+
+### Implementation Details
+
+- The evolution logic is implemented in `graph/nodes/evolve.py`
+- Prompt selection is based on the number of existing evolutions
+- State management ensures each evolution builds upon previous results
+- Results maintain consistent IDs (`q0`, `q1`, etc.) across questions, answers, and contexts
+
+### Configuration
+
+- The number of evolution passes can be controlled via:
+  - Streamlit UI number input (web interface)
+  - `NUM_EVOLVE_PASSES` environment variable (CLI)
+
+### ⚠️ Important Considerations
+
+When modifying this codebase, please keep in mind:
+1. The evolution process is intentionally sequential and builds upon previous iterations
+2. Maintaining the alternating prompt pattern is crucial for question diversity
+3. State management between iterations must preserve the evolution history
+4. The ID system (`q0`, `q1`, etc.) must remain consistent across all collections
+
 ## Quick Start
 
 ### Local Development
@@ -68,13 +105,20 @@ The following environment variables need to be set in your HuggingFace Space settings:
 - `OPENAI_API_KEY`: Your OpenAI API key
 - `LANGCHAIN_API_KEY`: Your LangChain API key (optional)
 - `LANGCHAIN_PROJECT`: Your LangChain project name (optional)
+- `LANGCHAIN_TRACING_V2`: Set to "true" to enable tracing
 - `ENVIRONMENT`: Set to "production" for production mode
+- `NUM_EVOLVE_PASSES`: Number of evolution iterations (default: 2)
+- `VECTORSTORE_PATH`: Path to store vectors (default: /tmp/vectorstore)
 
 ## Project Structure
 
 - `app.py`: Streamlit application for the Hugging Face deployment
+- `main.py`: CLI interface with the same functionality as the web app
 - `preprocess/`: Code for preprocessing HTML files and creating embeddings
 - `graph/`: LangGraph implementation for synthetic data generation
+  - `nodes/`: Individual graph nodes (evolve, retrieve, answer)
+  - `types.py`: State management and data structures
+  - `build_graph.py`: Graph construction and configuration
 - `data/`: HTML files containing LLM evolution data
-- `tests/`: Test files
-- `generated/`: Generated documents and vectorstore
+- `tests/`: Test files ensuring correct implementation
+- `generated/`: Generated documents, vectorstore, and results
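Note: the alternation rule and the `q0`/`q1` ID linking described in the README are small enough to sketch. This mirrors the logic that lands in `graph/nodes/evolve.py` and `main.py` below; the helper names here (`pick_prompt`, `link_results`) are illustrative, not functions in the repo:

```python
PROMPTS = [
    "Rewrite or evolve the following question to be more challenging or insightful:\n\n{}",
    "Rewrite or evolve the following question to be more creative or original:\n\n{}",
]

def pick_prompt(evolved_questions: list[str], question: str) -> str:
    # Even count of prior evolutions -> challenging prompt, odd count -> creative prompt
    return PROMPTS[len(evolved_questions) % 2].format(question)

def link_results(questions: list[str], answers: list[str], contexts: list[list[str]]) -> dict:
    # The i-th question, answer, and context list all share the id f"q{i}"
    return {
        "evolved_questions": [{"id": f"q{i}", "question": q} for i, q in enumerate(questions)],
        "answers": [{"id": f"q{i}", "answer": a} for i, a in enumerate(answers)],
        "contexts": [{"id": f"q{i}", "contexts": c} for i, c in enumerate(contexts)],
    }
```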
app.py CHANGED
@@ -52,24 +52,46 @@ def initialize_resources():
 # Initialize resources
 docs, vectorstore, graph = initialize_resources()
 
+# Add a number input for evolution passes
+num_evolve_passes = st.number_input(
+    label="Number of Evolution Passes",
+    min_value=1,
+    max_value=10,
+    value=2,
+    step=1,
+    help="How many times to evolve the question (alternates between challenging and creative prompts)."
+)
+
 # Generate synthetic data button
 if st.button("Generate Synthetic Data"):
     with st.spinner("Generating synthetic data..."):
         # Create initial state
-        initial_state = SDGState(
+        state = SDGState(
            input="Generate synthetic data about LLM evolution",
            documents=[],
-            evolved_question="",
+            evolved_questions=[],
            context=[],
-            answer=""
+            answer="",
+            num_evolve_passes=num_evolve_passes
        )
-        logger.debug(f"Initial state before invoke: {initial_state}")
 
-        # Invoke the graph with the SDGState object
-        result = graph.invoke(initial_state)
-        logger.debug(f"Graph result: {result}")
-        if not isinstance(result, SDGState):
-            result = SDGState(**dict(result))
+        # Run the graph for each evolution pass
+        all_results = []
+        for i in range(num_evolve_passes):
+            logger.debug(f"Running evolution pass {i+1}/{num_evolve_passes}")
+            result = graph.invoke(state)
+            if not isinstance(result, SDGState):
+                result = SDGState(**dict(result))
+            all_results.append(result)
+            # Update state for next iteration with evolved questions
+            state = SDGState(
+                input=state.input,
+                documents=state.documents,
+                evolved_questions=result.evolved_questions,  # Pass forward all evolved questions
+                context=[],  # Reset context for next iteration
+                answer="",  # Reset answer for next iteration
+                num_evolve_passes=num_evolve_passes
+            )
 
         # Display results
         st.subheader("Generated Data")
@@ -77,22 +99,24 @@ if st.button("Generate Synthetic Data"):
         # Display evolved questions
         st.markdown("### Evolved Questions")
         evolved_questions = [
-            {"id": f"q{i}", "question": q, "evolution_type": "simple"}
-            for i, q in enumerate([result.evolved_question])  # Currently only one question
+            {"id": f"q{i}", "question": result.evolved_questions[-1], "evolution_type": "simple"}
+            for i, result in enumerate(all_results)
        ]
         st.json(evolved_questions)
 
         # Display answers
         st.markdown("### Answers")
         answers = [
-            {"id": "q0", "answer": result.answer}
+            {"id": f"q{i}", "answer": result.answer}
+            for i, result in enumerate(all_results)
        ]
         st.json(answers)
 
         # Display contexts
         st.markdown("### Contexts")
         contexts = [
-            {"id": "q0", "contexts": result.context}
+            {"id": f"q{i}", "contexts": result.context}
+            for i, result in enumerate(all_results)
        ]
         st.json(contexts)
 
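Note: each loop iteration carries `evolved_questions` forward while resetting `context` and `answer`, so record `q{i}` pairs pass i's newest question with the answer and contexts retrieved for it. After two passes, `st.json` renders data shaped like this (the content shown is hypothetical):

```python
evolved_questions = [
    {"id": "q0", "question": "Challenging rewrite of the seed question", "evolution_type": "simple"},
    {"id": "q1", "question": "Creative rewrite of the q0 question", "evolution_type": "simple"},
]
answers = [
    {"id": "q0", "answer": "Based on the retrieved context:\n..."},
    {"id": "q1", "answer": "Based on the retrieved context:\n..."},
]
contexts = [
    {"id": "q0", "contexts": ["retrieved chunk A", "retrieved chunk B"]},
    {"id": "q1", "contexts": ["retrieved chunk C"]},
]
```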
graph/nodes/answer.py CHANGED
@@ -17,9 +17,10 @@ def generate_answer(state: SDGState) -> SDGState:
     new_state = SDGState(
         input=state.input,
         documents=state.documents,
-        evolved_question=state.evolved_question,
+        evolved_questions=state.evolved_questions,
         context=state.context,
-        answer=f"Based on the retrieved context:\n{context_snippet}"
+        answer=f"Based on the retrieved context:\n{context_snippet}",
+        num_evolve_passes=state.num_evolve_passes
     )
 
     logger.debug(f"Answer node returning state: {new_state}")
graph/nodes/evolve.py CHANGED
@@ -5,20 +5,27 @@ import logging
 logger = logging.getLogger(__name__)
 
 def evolve_question(state: SDGState, llm) -> SDGState:
-    logger.debug(f"Evolve node received state: {state}")
+    prompts = [
+        "Rewrite or evolve the following question to be more challenging or insightful:\n\n{}",
+        "Rewrite or evolve the following question to be more creative or original:\n\n{}"
+    ]
 
-    # Use the LLM to generate an evolved question
-    prompt = f"Rewrite or evolve the following question to be more challenging or insightful:\n\n{state.input}"
+    # Choose prompt based on number of existing evolutions (even/odd)
+    prompt_idx = len(state.evolved_questions) % len(prompts)
+    prompt = prompts[prompt_idx].format(state.evolved_question)
+
+    # Generate new evolution
     response = llm.invoke(prompt)
-    evolved_question = response.content if hasattr(response, 'content') else str(response)
-
+    evolved = response.content if hasattr(response, 'content') else str(response)
+
+    # Create new state with appended evolution
     new_state = SDGState(
         input=state.input,
         documents=state.documents,
-        evolved_question=evolved_question,
+        evolved_questions=state.evolved_questions + [evolved],
         context=state.context,
-        answer=state.answer
+        answer=state.answer,
+        num_evolve_passes=state.num_evolve_passes
     )
-
     logger.debug(f"Evolve node returning state: {new_state}")
     return new_state
graph/nodes/retrieve.py CHANGED
@@ -14,9 +14,10 @@ def retrieve_relevant_context(state: SDGState, vectorstore) -> SDGState:
     new_state = SDGState(
         input=state.input,
         documents=state.documents,
-        evolved_question=state.evolved_question,
+        evolved_questions=state.evolved_questions,
         context=[doc.page_content for doc in retrieved_docs],
-        answer=state.answer
+        answer=state.answer,
+        num_evolve_passes=state.num_evolve_passes
     )
 
     logger.debug(f"Retrieve node returning state: {new_state}")
graph/types.py CHANGED
@@ -5,6 +5,11 @@ from pydantic import BaseModel, Field
 class SDGState(BaseModel):
     input: str = Field(default="")
     documents: List[Document] = Field(default_factory=list)
-    evolved_question: str = Field(default="")
+    evolved_questions: List[str] = Field(default_factory=list)
     context: List[str] = Field(default_factory=list)
     answer: str = Field(default="")
+    num_evolve_passes: int = Field(default=2)
+
+    @property
+    def evolved_question(self):
+        return self.evolved_questions[-1] if self.evolved_questions else self.input
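Note: the `evolved_question` property keeps the nodes' "latest question" access pattern working, falling back to the original `input` until the first evolution exists. A quick sketch of that behavior:

```python
from graph.types import SDGState

state = SDGState(input="How did LLMs evolve in 2023?")
assert state.evolved_question == "How did LLMs evolve in 2023?"  # no evolutions yet: falls back to input

state = SDGState(
    input="How did LLMs evolve in 2023?",
    evolved_questions=["First rewrite", "Second rewrite"],
)
assert state.evolved_question == "Second rewrite"  # always the most recent evolution
```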
main.py CHANGED
@@ -20,6 +20,7 @@ class DocumentEncoder(json.JSONEncoder):
         if isinstance(obj, SDGState):
             return {
                 "input": obj.input,
+                "evolved_questions": obj.evolved_questions,
                 "evolved_question": obj.evolved_question,
                 "context": obj.context,
                 "answer": obj.answer
@@ -63,6 +64,27 @@ def load_or_generate_documents() -> list[Document]:
     return docs
 
 
+def format_results(all_results):
+    """Format results into the standard JSON structure."""
+    evolved_questions = [
+        {"id": f"q{i}", "question": result.evolved_questions[-1], "evolution_type": "simple"}
+        for i, result in enumerate(all_results)
+    ]
+    answers = [
+        {"id": f"q{i}", "answer": result.answer}
+        for i, result in enumerate(all_results)
+    ]
+    contexts = [
+        {"id": f"q{i}", "contexts": result.context}
+        for i, result in enumerate(all_results)
+    ]
+    return {
+        "evolved_questions": evolved_questions,
+        "answers": answers,
+        "contexts": contexts
+    }
+
+
 def main():
     if is_dev_mode():
         print("🚧 Running in development mode...")
@@ -74,11 +96,50 @@ def main():
 
         llm = ChatOpenAI(model="gpt-3.5-turbo", openai_api_key=None)  # None will use env var
         graph = build_sdg_graph(docs, vectorstore, llm)
-        initial_state = SDGState(input="How did LLMs evolve in 2023?")
 
-        result = graph.invoke(initial_state)
-        print("🧠 Agent Output:")
-        print(json.dumps(result, indent=2, ensure_ascii=False, cls=DocumentEncoder))
+        # Set up initial state with desired number of passes
+        num_evolve_passes = int(os.environ.get("NUM_EVOLVE_PASSES", "2"))
+        state = SDGState(
+            input="How did LLMs evolve in 2023?",
+            documents=[],
+            evolved_questions=[],
+            context=[],
+            answer="",
+            num_evolve_passes=num_evolve_passes
+        )
+
+        # Run the graph for each evolution pass
+        all_results = []
+        print(f"🔄 Running {num_evolve_passes} evolution passes...")
+        for i in range(num_evolve_passes):
+            print(f"\n📝 Evolution pass {i+1}/{num_evolve_passes}:")
+            result = graph.invoke(state)
+            if not isinstance(result, SDGState):
+                result = SDGState(**dict(result))
+            all_results.append(result)
+            # Update state for next iteration with evolved questions
+            state = SDGState(
+                input=state.input,
+                documents=state.documents,
+                evolved_questions=result.evolved_questions,  # Pass forward all evolved questions
+                context=[],  # Reset context for next iteration
+                answer="",  # Reset answer for next iteration
+                num_evolve_passes=num_evolve_passes
+            )
+            print(f"  Question: {result.evolved_questions[-1]}")
+            print(f"  Answer: {result.answer[:100]}...")
+
+        # Format and output results
+        print("\n🧠 Final Output:")
+        results = format_results(all_results)
+        print(json.dumps(results, indent=2, ensure_ascii=False, cls=DocumentEncoder))
+
+        # Save results to file
+        output_file = Path("generated/results.json")
+        output_file.parent.mkdir(parents=True, exist_ok=True)
+        with open(output_file, "w", encoding="utf-8") as f:
+            json.dump(results, f, indent=2, ensure_ascii=False, cls=DocumentEncoder)
+        print(f"\n💾 Results saved to {output_file}")
     else:
         print("🔒 Production mode detected. Skipping document generation.")
 
pyproject.toml CHANGED
@@ -15,7 +15,8 @@ dependencies = [
     "openai",
     "tiktoken",
     "langchain-openai",
-    "faiss-cpu",
+    "faiss-cpu==1.7.4",
+    "numpy<2.0.0",
     "streamlit"
 ]
 
tests/graph/nodes/test_evolve.py CHANGED
@@ -1,12 +1,79 @@
 from graph.types import SDGState
 from graph.nodes.evolve import evolve_question
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, call
 
-def test_evolve_question_modifies_state():
+def test_evolve_question_initial_state():
+    # Test evolution from initial state (should use input)
     state = SDGState(input="What were the top LLMs in 2023?")
     mock_llm = MagicMock()
     mock_llm.invoke.return_value = MagicMock(content="Evolved: What were the top LLMs in 2023?")
     updated_state = evolve_question(state, mock_llm)
 
-    assert updated_state.evolved_question.startswith("Evolved:")
-    assert updated_state.evolved_question.endswith("2023?")
+    # Should use challenging prompt first (even index)
+    mock_llm.invoke.assert_called_once_with(
+        "Rewrite or evolve the following question to be more challenging or insightful:\n\nWhat were the top LLMs in 2023?"
+    )
+    assert len(updated_state.evolved_questions) == 1
+    assert updated_state.evolved_questions[0] == "Evolved: What were the top LLMs in 2023?"
+    assert updated_state.evolved_question == "Evolved: What were the top LLMs in 2023?"
+
+def test_evolve_question_with_one_evolution():
+    # Test evolution with one existing evolution (should use creative prompt)
+    state = SDGState(
+        input="Base question",
+        evolved_questions=["First evolution"]
+    )
+    mock_llm = MagicMock()
+    mock_llm.invoke.return_value = MagicMock(content="Creative evolution")
+    updated_state = evolve_question(state, mock_llm)
+
+    # Should use creative prompt (odd index)
+    mock_llm.invoke.assert_called_once_with(
+        "Rewrite or evolve the following question to be more creative or original:\n\nFirst evolution"
+    )
+    assert len(updated_state.evolved_questions) == 2
+    assert updated_state.evolved_questions == ["First evolution", "Creative evolution"]
+    assert updated_state.evolved_question == "Creative evolution"
+
+def test_evolve_question_with_two_evolutions():
+    # Test evolution with two existing evolutions (should use challenging prompt)
+    state = SDGState(
+        input="Base question",
+        evolved_questions=["First evolution", "Second evolution"]
+    )
+    mock_llm = MagicMock()
+    mock_llm.invoke.return_value = MagicMock(content="Challenging evolution")
+    updated_state = evolve_question(state, mock_llm)
+
+    # Should use challenging prompt (even index)
+    mock_llm.invoke.assert_called_once_with(
+        "Rewrite or evolve the following question to be more challenging or insightful:\n\nSecond evolution"
+    )
+    assert len(updated_state.evolved_questions) == 3
+    assert updated_state.evolved_questions == ["First evolution", "Second evolution", "Challenging evolution"]
+    assert updated_state.evolved_question == "Challenging evolution"
+
+def test_state_preservation():
+    # Test that other state fields are preserved
+    initial_state = SDGState(
+        input="Base question",
+        evolved_questions=["First evolution"],
+        documents=[],
+        context=["Some context"],
+        answer="Previous answer",
+        num_evolve_passes=5
+    )
+    mock_llm = MagicMock()
+    mock_llm.invoke.return_value = MagicMock(content="New evolution")
+    updated_state = evolve_question(initial_state, mock_llm)
+
+    # Check that all fields are preserved except evolved_questions
+    assert updated_state.input == initial_state.input
+    assert updated_state.documents == initial_state.documents
+    assert updated_state.context == initial_state.context
+    assert updated_state.answer == initial_state.answer
+    assert updated_state.num_evolve_passes == initial_state.num_evolve_passes
+    # Check that evolved_questions is updated correctly
+    assert len(updated_state.evolved_questions) == 2
+    assert updated_state.evolved_questions[0] == "First evolution"
+    assert updated_state.evolved_questions[1] == "New evolution"
tests/graph/test_build_graph.py CHANGED
@@ -17,6 +17,8 @@ def test_build_sdg_graph_runs():
     result = graph.invoke(state)
 
     assert isinstance(result, dict)
-    assert "evolved_question" in result
+    assert "evolved_questions" in result
+    if result["evolved_questions"]:
+        assert result["evolved_questions"][-1] == "Evolved test question"
     assert result["context"]
     assert "Relevant content" in result["context"][0]
uv.lock CHANGED
The diff for this file is too large to render. See raw diff