Merge pull request #17 from mwalker-tmd/feature/evolve-instruct
Changed files:

- Dockerfile +10 -3
- README.md +49 -5
- app.py +37 -13
- graph/nodes/answer.py +3 -2
- graph/nodes/evolve.py +15 -8
- graph/nodes/retrieve.py +3 -2
- graph/types.py +7 -2
- main.py +65 -4
- pyproject.toml +2 -1
- tests/graph/nodes/test_evolve.py +71 -4
- tests/graph/test_build_graph.py +3 -1
- uv.lock +0 -0
Dockerfile
CHANGED

@@ -1,6 +1,10 @@
 # Use Python 3.11 as base image
 FROM python:3.11-slim
 
+# Add build argument for version tracking
+ARG BUILD_VERSION=1.0.0
+ENV BUILD_VERSION=${BUILD_VERSION}
+
 # Set working directory
 WORKDIR /app
 
@@ -26,18 +30,21 @@ COPY data/ data/
 
 # Create a shell script to run the application
 RUN echo '#!/bin/bash\n\
+echo "Starting application version ${BUILD_VERSION}"\n\
 source /app/.venv/bin/activate\n\
-exec /app/.venv/bin/streamlit run app.py --server.port=8501 --server.address=0.0.0.0' > /app/run.sh && \
+PORT=${PORT:-8501}\n\
+exec /app/.venv/bin/streamlit run app.py --server.port=${PORT} --server.address=0.0.0.0' > /app/run.sh && \
 chmod +x /app/run.sh
 
-# Expose the port Streamlit runs on
-EXPOSE 8501
+# Expose the default port Streamlit runs on
+EXPOSE ${PORT:-8501}
 
 # Set environment variables
 ENV PYTHONUNBUFFERED=1
 ENV ENVIRONMENT=development
 ENV LANGCHAIN_TRACING_V2=false
 ENV PATH="/app/.venv/bin:$PATH"
+ENV PORT=8501
 
 # Command to run the application
 CMD ["/app/run.sh"]
README.md
CHANGED

@@ -14,11 +14,48 @@ This project reproduces the RAGAS Synthetic Data Generation steps using LangGraph.
 
 ## Features
 
-- Synthetic data generation using Evol Instruct
-…
-…
+- Synthetic data generation using Evol Instruct methodology
+- Iterative question evolution with alternating prompts:
+  - Even iterations: More challenging and insightful questions
+  - Odd iterations: More creative and original questions
+- Consistent state management across iterations
+- Standardized JSON output format with linked questions, answers, and contexts
 - Deployed as a Streamlit app on Hugging Face Spaces
 
+## Evol Instruct Implementation
+
+This project implements the Evol Instruct methodology for evolving questions through multiple iterations. The implementation has several key aspects that should be considered when modifying the code:
+
+### Core Principles
+
+1. **Single Evolution Per Pass**: Each graph invocation performs one evolution step, maintaining clarity and control over the evolution process.
+2. **Alternating Prompts**: The system alternates between:
+   - Challenging/insightful prompts (even-numbered iterations)
+   - Creative/original prompts (odd-numbered iterations)
+3. **State Management**: Evolution history is preserved between iterations, and each node in the chain processes only the latest evolved question.
+4. **Configurable Evolution Count**: The number of evolution passes can be set through the UI or an environment variable, allowing flexibility in the evolution process.
+
+### Implementation Details
+
+- The evolution logic is implemented in `graph/nodes/evolve.py`
+- Prompt selection is based on the number of existing evolutions
+- State management ensures each evolution builds upon previous results
+- Results maintain consistent IDs (`q0`, `q1`, etc.) across questions, answers, and contexts
+
+### Configuration
+
+- Number of evolution passes can be controlled via:
+  - Streamlit UI slider (web interface)
+  - `NUM_EVOLVE_PASSES` environment variable (CLI)
+
+### ⚠️ Important Considerations
+
+When modifying this codebase, please keep in mind:
+1. The evolution process is intentionally sequential and builds upon previous iterations
+2. Maintaining the alternating prompt pattern is crucial for question diversity
+3. State management between iterations must preserve the evolution history
+4. The ID system (`q0`, `q1`, etc.) must remain consistent across all collections
+
 ## Quick Start
 
 ### Local Development
@@ -68,13 +105,20 @@ The following environment variables need to be set in your HuggingFace Space settings:
 - `OPENAI_API_KEY`: Your OpenAI API key
 - `LANGCHAIN_API_KEY`: Your LangChain API key (optional)
 - `LANGCHAIN_PROJECT`: Your LangChain project name (optional)
+- `LANGCHAIN_TRACING_V2`: Set to "true" to enable tracing
 - `ENVIRONMENT`: Set to "production" for production mode
+- `NUM_EVOLVE_PASSES`: Number of evolution iterations (default: 2)
+- `VECTORSTORE_PATH`: Path to store vectors (default: /tmp/vectorstore)
 
 ## Project Structure
 
 - `app.py`: Streamlit application for the Hugging Face deployment
+- `main.py`: CLI interface with the same functionality as the web app
 - `preprocess/`: Code for preprocessing HTML files and creating embeddings
 - `graph/`: LangGraph implementation for synthetic data generation
+  - `nodes/`: Individual graph nodes (evolve, retrieve, answer)
+  - `types.py`: State management and data structures
+  - `build_graph.py`: Graph construction and configuration
 - `data/`: HTML files containing LLM evolution data
-- `tests/`: Test files
-- `generated/`: Generated documents and …
+- `tests/`: Test files ensuring correct implementation
+- `generated/`: Generated documents, vectorstore, and results
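To make the "standardized JSON output format" concrete, here is the shape produced by `format_results` (added in `main.py` below). This is illustrative only: the question, answer, and context strings are invented placeholders, and the shared `q0`/`q1` IDs are what link the three collections.

```python
# Illustrative sketch of the shape returned by format_results() in main.py.
# All string values are placeholders, not real model output.
example_output = {
    "evolved_questions": [
        {"id": "q0", "question": "How did LLMs evolve in 2023?", "evolution_type": "simple"},
        {"id": "q1", "question": "Which 2023 LLM advances mattered most, and why?", "evolution_type": "simple"},
    ],
    "answers": [
        {"id": "q0", "answer": "Based on the retrieved context:\n..."},
        {"id": "q1", "answer": "Based on the retrieved context:\n..."},
    ],
    "contexts": [
        {"id": "q0", "contexts": ["snippet from a retrieved document"]},
        {"id": "q1", "contexts": ["another retrieved snippet"]},
    ],
}
```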
app.py
CHANGED

@@ -52,24 +52,46 @@ def initialize_resources():
 # Initialize resources
 docs, vectorstore, graph = initialize_resources()
 
+# Add a number input for evolution passes
+num_evolve_passes = st.number_input(
+    label="Number of Evolution Passes",
+    min_value=1,
+    max_value=10,
+    value=2,
+    step=1,
+    help="How many times to evolve the question (alternates between challenging and creative prompts)."
+)
+
 # Generate synthetic data button
 if st.button("Generate Synthetic Data"):
     with st.spinner("Generating synthetic data..."):
         # Create initial state
-        initial_state = SDGState(
+        state = SDGState(
             input="Generate synthetic data about LLM evolution",
             documents=[],
-            …
+            evolved_questions=[],
             context=[],
-            answer=""
+            answer="",
+            num_evolve_passes=num_evolve_passes
         )
-        logger.debug(f"Initial state before invoke: {initial_state}")
 
-        # …
-        …
-        …
-        …
-        result = …
+        # Run the graph for each evolution pass
+        all_results = []
+        for i in range(num_evolve_passes):
+            logger.debug(f"Running evolution pass {i+1}/{num_evolve_passes}")
+            result = graph.invoke(state)
+            if not isinstance(result, SDGState):
+                result = SDGState(**dict(result))
+            all_results.append(result)
+            # Update state for next iteration with evolved questions
+            state = SDGState(
+                input=state.input,
+                documents=state.documents,
+                evolved_questions=result.evolved_questions,  # Pass forward all evolved questions
+                context=[],  # Reset context for next iteration
+                answer="",   # Reset answer for next iteration
+                num_evolve_passes=num_evolve_passes
+            )
 
         # Display results
         st.subheader("Generated Data")
@@ -77,22 +99,24 @@ if st.button("Generate Synthetic Data"):
         # Display evolved questions
         st.markdown("### Evolved Questions")
         evolved_questions = [
-            {"id": f"q{i}", "question": …
-            for i, …
+            {"id": f"q{i}", "question": result.evolved_questions[-1], "evolution_type": "simple"}
+            for i, result in enumerate(all_results)
         ]
         st.json(evolved_questions)
 
         # Display answers
         st.markdown("### Answers")
         answers = [
-            {"id": "…
+            {"id": f"q{i}", "answer": result.answer}
+            for i, result in enumerate(all_results)
         ]
         st.json(answers)
 
         # Display contexts
         st.markdown("### Contexts")
         contexts = [
-            {"id": "…
+            {"id": f"q{i}", "contexts": result.context}
+            for i, result in enumerate(all_results)
         ]
         st.json(contexts)
 
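One detail in the loop above deserves a note: a compiled LangGraph graph returns a plain dict of state fields from `invoke` (the build-graph test further down asserts exactly that), which is why the code guards with `SDGState(**dict(result))`. A minimal sketch of that round-trip, with placeholder field values:

```python
from graph.types import SDGState

# What graph.invoke(state) hands back: a mapping of state fields, not the model.
raw_result = {
    "input": "Generate synthetic data about LLM evolution",
    "documents": [],
    "evolved_questions": ["Evolved: a harder question"],
    "context": ["a retrieved snippet"],
    "answer": "Based on the retrieved context:\na retrieved snippet",
    "num_evolve_passes": 2,
}

# Re-validating through the Pydantic constructor restores attribute access.
result = SDGState(**raw_result)
assert result.evolved_question == "Evolved: a harder question"
```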
graph/nodes/answer.py
CHANGED

@@ -17,9 +17,10 @@ def generate_answer(state: SDGState) -> SDGState:
     new_state = SDGState(
         input=state.input,
         documents=state.documents,
-        …
+        evolved_questions=state.evolved_questions,
         context=state.context,
-        answer=f"Based on the retrieved context:\n{context_snippet}"
+        answer=f"Based on the retrieved context:\n{context_snippet}",
+        num_evolve_passes=state.num_evolve_passes
     )
 
     logger.debug(f"Answer node returning state: {new_state}")
graph/nodes/evolve.py
CHANGED

@@ -5,20 +5,27 @@ import logging
 logger = logging.getLogger(__name__)
 
 def evolve_question(state: SDGState, llm) -> SDGState:
-    …
+    prompts = [
+        "Rewrite or evolve the following question to be more challenging or insightful:\n\n{}",
+        "Rewrite or evolve the following question to be more creative or original:\n\n{}"
+    ]
 
-    # …
-    …
+    # Choose prompt based on number of existing evolutions (even/odd)
+    prompt_idx = len(state.evolved_questions) % len(prompts)
+    prompt = prompts[prompt_idx].format(state.evolved_question)
+
+    # Generate new evolution
     response = llm.invoke(prompt)
-    …
-    …
+    evolved = response.content if hasattr(response, 'content') else str(response)
 
+    # Create new state with appended evolution
     new_state = SDGState(
         input=state.input,
         documents=state.documents,
-        …
+        evolved_questions=state.evolved_questions + [evolved],
         context=state.context,
-        answer=state.answer
+        answer=state.answer,
+        num_evolve_passes=state.num_evolve_passes
     )
-    …
     logger.debug(f"Evolve node returning state: {new_state}")
     return new_state
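To see the even/odd alternation concretely, here is a small driver using a stub LLM (the `EchoLLM` class is invented for illustration; it just wraps the question it is given, assuming the repo's modules are importable):

```python
from types import SimpleNamespace
from graph.types import SDGState
from graph.nodes.evolve import evolve_question

class EchoLLM:
    """Stub LLM for illustration: wraps the question it is asked to evolve."""
    def invoke(self, prompt):
        # The prompt template ends with the question after a blank line.
        question = prompt.splitlines()[-1]
        return SimpleNamespace(content=f"evolved({question})")

state = SDGState(input="How did LLMs evolve in 2023?")
for _ in range(3):
    state = evolve_question(state, EchoLLM())

# Passes with 0 or 2 prior evolutions used the "challenging" prompt;
# the pass with 1 prior evolution used the "creative" prompt.
print(state.evolved_questions)
# ['evolved(How did LLMs evolve in 2023?)',
#  'evolved(evolved(How did LLMs evolve in 2023?))',
#  'evolved(evolved(evolved(How did LLMs evolve in 2023?)))']
```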
graph/nodes/retrieve.py
CHANGED

@@ -14,9 +14,10 @@ def retrieve_relevant_context(state: SDGState, vectorstore) -> SDGState:
     new_state = SDGState(
         input=state.input,
         documents=state.documents,
-        …
+        evolved_questions=state.evolved_questions,
         context=[doc.page_content for doc in retrieved_docs],
-        answer=state.answer
+        answer=state.answer,
+        num_evolve_passes=state.num_evolve_passes
     )
 
     logger.debug(f"Retrieve node returning state: {new_state}")
graph/types.py
CHANGED

@@ -5,6 +5,11 @@ from pydantic import BaseModel, Field
 class SDGState(BaseModel):
     input: str = Field(default="")
     documents: List[Document] = Field(default_factory=list)
-    evolved_question: str = Field(default="")
+    evolved_questions: List[str] = Field(default_factory=list)
     context: List[str] = Field(default_factory=list)
     answer: str = Field(default="")
+    num_evolve_passes: int = Field(default=2)
+
+    @property
+    def evolved_question(self):
+        return self.evolved_questions[-1] if self.evolved_questions else self.input
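The `evolved_question` property keeps the old single-question attribute working for callers such as the `DocumentEncoder` in `main.py`: it falls back to the raw input until the first evolution, then always points at the latest one. A quick sketch:

```python
from graph.types import SDGState

# Before any evolution, the property falls back to the raw input...
s = SDGState(input="How did LLMs evolve in 2023?")
assert s.evolved_question == "How did LLMs evolve in 2023?"

# ...and afterwards it always points at the latest evolution.
s = SDGState(input="How did LLMs evolve in 2023?",
             evolved_questions=["harder version", "creative version"])
assert s.evolved_question == "creative version"
```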
main.py
CHANGED

@@ -20,6 +20,7 @@ class DocumentEncoder(json.JSONEncoder):
         if isinstance(obj, SDGState):
             return {
                 "input": obj.input,
+                "evolved_questions": obj.evolved_questions,
                 "evolved_question": obj.evolved_question,
                 "context": obj.context,
                 "answer": obj.answer
@@ -63,6 +64,27 @@ def load_or_generate_documents() -> list[Document]:
     return docs
 
 
+def format_results(all_results):
+    """Format results into the standard JSON structure."""
+    evolved_questions = [
+        {"id": f"q{i}", "question": result.evolved_questions[-1], "evolution_type": "simple"}
+        for i, result in enumerate(all_results)
+    ]
+    answers = [
+        {"id": f"q{i}", "answer": result.answer}
+        for i, result in enumerate(all_results)
+    ]
+    contexts = [
+        {"id": f"q{i}", "contexts": result.context}
+        for i, result in enumerate(all_results)
+    ]
+    return {
+        "evolved_questions": evolved_questions,
+        "answers": answers,
+        "contexts": contexts
+    }
+
+
 def main():
     if is_dev_mode():
         print("🚧 Running in development mode...")
@@ -74,11 +96,50 @@ def main():
 
         llm = ChatOpenAI(model="gpt-3.5-turbo", openai_api_key=None)  # None will use env var
         graph = build_sdg_graph(docs, vectorstore, llm)
-        initial_state = SDGState(input="How did LLMs evolve in 2023?")
 
-        …
-        …
-        …
+        # Set up initial state with desired number of passes
+        num_evolve_passes = int(os.environ.get("NUM_EVOLVE_PASSES", "2"))
+        state = SDGState(
+            input="How did LLMs evolve in 2023?",
+            documents=[],
+            evolved_questions=[],
+            context=[],
+            answer="",
+            num_evolve_passes=num_evolve_passes
+        )
+
+        # Run the graph for each evolution pass
+        all_results = []
+        print(f"🔄 Running {num_evolve_passes} evolution passes...")
+        for i in range(num_evolve_passes):
+            print(f"\n📝 Evolution pass {i+1}/{num_evolve_passes}:")
+            result = graph.invoke(state)
+            if not isinstance(result, SDGState):
+                result = SDGState(**dict(result))
+            all_results.append(result)
+            # Update state for next iteration with evolved questions
+            state = SDGState(
+                input=state.input,
+                documents=state.documents,
+                evolved_questions=result.evolved_questions,  # Pass forward all evolved questions
+                context=[],  # Reset context for next iteration
+                answer="",   # Reset answer for next iteration
+                num_evolve_passes=num_evolve_passes
+            )
+            print(f"    Question: {result.evolved_questions[-1]}")
+            print(f"    Answer: {result.answer[:100]}...")
+
+        # Format and output results
+        print("\n🧠 Final Output:")
+        results = format_results(all_results)
+        print(json.dumps(results, indent=2, ensure_ascii=False, cls=DocumentEncoder))
+
+        # Save results to file
+        output_file = Path("generated/results.json")
+        output_file.parent.mkdir(parents=True, exist_ok=True)
+        with open(output_file, "w", encoding="utf-8") as f:
+            json.dump(results, f, indent=2, ensure_ascii=False, cls=DocumentEncoder)
+        print(f"\n💾 Results saved to {output_file}")
     else:
         print("🔒 Production mode detected. Skipping document generation.")
 
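A quick way to exercise `format_results` in isolation (assuming `main.py` can be imported without side effects, i.e. its entry point is guarded) is to feed it hand-built states in place of real graph output:

```python
from graph.types import SDGState
from main import format_results

# Two fake passes: each state carries the full evolution history so far.
fake_results = [
    SDGState(input="q", evolved_questions=["harder q"], context=["ctx A"], answer="answer A"),
    SDGState(input="q", evolved_questions=["harder q", "creative q"], context=["ctx B"], answer="answer B"),
]

out = format_results(fake_results)
assert [e["id"] for e in out["evolved_questions"]] == ["q0", "q1"]
assert out["evolved_questions"][1]["question"] == "creative q"  # latest evolution per pass
assert out["contexts"][1] == {"id": "q1", "contexts": ["ctx B"]}
```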
pyproject.toml
CHANGED

@@ -15,7 +15,8 @@ dependencies = [
     "openai",
     "tiktoken",
     "langchain-openai",
-    "faiss-cpu",
+    "faiss-cpu==1.7.4",
+    "numpy<2.0.0",
     "streamlit"
 ]
 
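The two pins presumably belong together: faiss-cpu 1.7.4 wheels predate NumPy 2.0, whose ABI change broke packages compiled against NumPy 1.x, so `numpy<2.0.0` keeps the pair importable. A small post-install sanity check (a sketch; it only exercises the numpy/faiss boundary):

```python
import numpy as np
import faiss  # faiss-cpu==1.7.4 pairs with numpy<2.0.0

# Build a tiny flat L2 index and query it to confirm the binding works.
vecs = np.random.rand(10, 8).astype("float32")
index = faiss.IndexFlatL2(8)
index.add(vecs)
dists, ids = index.search(vecs[:1], 3)
print(ids[0])  # the first vector's nearest neighbours (itself first)
```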
tests/graph/nodes/test_evolve.py
CHANGED

@@ -1,12 +1,79 @@
 from graph.types import SDGState
 from graph.nodes.evolve import evolve_question
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, call
 
-def …
+def test_evolve_question_initial_state():
+    # Test evolution from initial state (should use input)
     state = SDGState(input="What were the top LLMs in 2023?")
     mock_llm = MagicMock()
     mock_llm.invoke.return_value = MagicMock(content="Evolved: What were the top LLMs in 2023?")
     updated_state = evolve_question(state, mock_llm)
 
-    …
-    …
+    # Should use challenging prompt first (even index)
+    mock_llm.invoke.assert_called_once_with(
+        "Rewrite or evolve the following question to be more challenging or insightful:\n\nWhat were the top LLMs in 2023?"
+    )
+    assert len(updated_state.evolved_questions) == 1
+    assert updated_state.evolved_questions[0] == "Evolved: What were the top LLMs in 2023?"
+    assert updated_state.evolved_question == "Evolved: What were the top LLMs in 2023?"
+
+def test_evolve_question_with_one_evolution():
+    # Test evolution with one existing evolution (should use creative prompt)
+    state = SDGState(
+        input="Base question",
+        evolved_questions=["First evolution"]
+    )
+    mock_llm = MagicMock()
+    mock_llm.invoke.return_value = MagicMock(content="Creative evolution")
+    updated_state = evolve_question(state, mock_llm)
+
+    # Should use creative prompt (odd index)
+    mock_llm.invoke.assert_called_once_with(
+        "Rewrite or evolve the following question to be more creative or original:\n\nFirst evolution"
+    )
+    assert len(updated_state.evolved_questions) == 2
+    assert updated_state.evolved_questions == ["First evolution", "Creative evolution"]
+    assert updated_state.evolved_question == "Creative evolution"
+
+def test_evolve_question_with_two_evolutions():
+    # Test evolution with two existing evolutions (should use challenging prompt)
+    state = SDGState(
+        input="Base question",
+        evolved_questions=["First evolution", "Second evolution"]
+    )
+    mock_llm = MagicMock()
+    mock_llm.invoke.return_value = MagicMock(content="Challenging evolution")
+    updated_state = evolve_question(state, mock_llm)
+
+    # Should use challenging prompt (even index)
+    mock_llm.invoke.assert_called_once_with(
+        "Rewrite or evolve the following question to be more challenging or insightful:\n\nSecond evolution"
+    )
+    assert len(updated_state.evolved_questions) == 3
+    assert updated_state.evolved_questions == ["First evolution", "Second evolution", "Challenging evolution"]
+    assert updated_state.evolved_question == "Challenging evolution"
+
+def test_state_preservation():
+    # Test that other state fields are preserved
+    initial_state = SDGState(
+        input="Base question",
+        evolved_questions=["First evolution"],
+        documents=[],
+        context=["Some context"],
+        answer="Previous answer",
+        num_evolve_passes=5
+    )
+    mock_llm = MagicMock()
+    mock_llm.invoke.return_value = MagicMock(content="New evolution")
+    updated_state = evolve_question(initial_state, mock_llm)
+
+    # Check that all fields are preserved except evolved_questions
+    assert updated_state.input == initial_state.input
+    assert updated_state.documents == initial_state.documents
+    assert updated_state.context == initial_state.context
+    assert updated_state.answer == initial_state.answer
+    assert updated_state.num_evolve_passes == initial_state.num_evolve_passes
+    # Check that evolved_questions is updated correctly
+    assert len(updated_state.evolved_questions) == 2
+    assert updated_state.evolved_questions[0] == "First evolution"
+    assert updated_state.evolved_questions[1] == "New evolution"
tests/graph/test_build_graph.py
CHANGED

@@ -17,6 +17,8 @@ def test_build_sdg_graph_runs():
     result = graph.invoke(state)
 
     assert isinstance(result, dict)
-    assert "…
+    assert "evolved_questions" in result
+    if result["evolved_questions"]:
+        assert result["evolved_questions"][-1] == "Evolved test question"
     assert result["context"]
     assert "Relevant content" in result["context"][0]
uv.lock
CHANGED

The diff for this file is too large to render. See the raw diff.