sheikhDipta003 committed on
Commit 2097562 · Parent: 464856a

add all files

app.py ADDED
@@ -0,0 +1,112 @@
+import gradio as gr
+import json
+import asyncio
+import time
+from typing import Any, Dict
+from src.enrichment_agent import graph
+# from dotenv import load_dotenv
+# load_dotenv()
+
+import os
+TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
+OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+
+if TAVILY_API_KEY:
+    print("TAVILY_API_KEY found!")
+else:
+    print("TAVILY_API_KEY not found. Please check your Secrets configuration.")
+
+if OPENAI_API_KEY:
+    print("OPENAI_API_KEY found!")
+else:
+    print("OPENAI_API_KEY not found. Please check your Secrets configuration.")
+
+def extract_leaf_nodes(data, parent_key=''):
+    """Extract only the leaf nodes (keys without nested key-value pairs)."""
+    leaf_nodes = {}
+    for key, value in data.items():
+        new_key = f"{parent_key}.{key}" if parent_key else key
+        if isinstance(value, dict):
+            leaf_nodes.update(extract_leaf_nodes(value, new_key))
+        elif isinstance(value, list) and all(isinstance(item, dict) for item in value):
+            for idx, item in enumerate(value):
+                leaf_nodes.update(extract_leaf_nodes(item, f"{new_key}[{idx}]"))
+        else:
+            leaf_nodes[new_key] = value
+    return leaf_nodes
+
+def agent_response(schema_json: str, topic: str):
+    try:
+        # Parse the schema JSON string
+        schema = json.loads(schema_json)
+    except json.JSONDecodeError:
+        return "Invalid JSON schema.", 0.0
+
+    async def fetch_data(schema: Dict[str, Any], topic: str) -> Dict[str, Any]:
+        return await graph.ainvoke({
+            "topic": topic,
+            "extraction_schema": schema,
+        })
+
+    # Measure processing time
+    start_time = time.time()
+    result = asyncio.run(fetch_data(schema, topic))
+    processing_time = time.time() - start_time
+
+    # Extract the 'info' dictionary from the result
+    info = result.get('info', {})
+
+    # Extract only the leaf nodes for display
+    leaf_nodes = extract_leaf_nodes(info)
+
+    # Format the key-value pairs as Markdown with newlines
+    display_data = "\n\n".join(f"**{key}**: {value}" for key, value in leaf_nodes.items())
+
+    return display_data, processing_time
+
+# Define the Gradio interface
+with gr.Blocks() as demo:
+    gr.Markdown(
+        """
+        <div style="text-align: center;">
+            <h1 style="color: #4CAF50;">🌟 Enrichment Agent Interface 🌟</h1>
+            <p style="font-size: 1.2em; color: #555;">
+                Dynamically extract and display information in a visually appealing format.
+            </p>
+        </div>
+        """
+    )
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            gr.Markdown("### 🛠 Input")
+            schema_input = gr.Textbox(
+                label="Extraction Schema (JSON)",
+                value=json.dumps({
+                    "type": "object",
+                    "properties": {
+                        "founder": {"type": "string", "description": "Name of the founder"},
+                        "websiteUrl": {"type": "string", "description": "Website URL"},
+                        "products_sold": {"type": "array", "items": {"type": "string"}}
+                    },
+                    "required": ["founder", "websiteUrl", "products_sold"]
+                }, indent=2),
+                lines=10,
+                placeholder="Enter the extraction schema in JSON format."
+            )
+            topic_input = gr.Textbox(label="Topic", placeholder="Enter the research topic, e.g., 'Google'")
+            submit_button = gr.Button("Submit 🚀")
+
+        with gr.Column(scale=2):
+            gr.Markdown("### 📊 Output")
+            output_display = gr.Markdown(label="Extracted Information")
+            time_display = gr.Textbox(label="Processing Time (seconds)", interactive=False)
+
+    def on_submit(schema, topic):
+        data, time_taken = agent_response(schema, topic)
+        return data, f"{time_taken:.2f}"
+
+    submit_button.click(on_submit, inputs=[schema_input, topic_input], outputs=[output_display, time_display])
+
+# Launch the interface
+demo.launch(share=True)
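
As a reading aid for extract_leaf_nodes above, here is a minimal sketch of the flattening it performs; the sample data is hypothetical and not part of the commit:

    # Hypothetical input: a flat key plus a list of dicts.
    sample = {
        "founder": "Larry Page",
        "products_sold": [{"name": "Search"}, {"name": "Ads"}],
    }
    print(extract_leaf_nodes(sample))
    # {'founder': 'Larry Page', 'products_sold[0].name': 'Search', 'products_sold[1].name': 'Ads'}
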
requirements.txt ADDED
Binary file (3.5 kB).
 
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (163 Bytes).
 
src/enrichment_agent/__init__.py ADDED
@@ -0,0 +1,5 @@
+"""Enrichment for a pre-defined schema."""
+
+from .graph import graph
+
+__all__ = ["graph"]
src/enrichment_agent/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (304 Bytes).

src/enrichment_agent/__pycache__/configuration.cpython-311.pyc ADDED
Binary file (3.44 kB).

src/enrichment_agent/__pycache__/graph.cpython-311.pyc ADDED
Binary file (11.1 kB).

src/enrichment_agent/__pycache__/prompts.cpython-311.pyc ADDED
Binary file (756 Bytes).

src/enrichment_agent/__pycache__/state.cpython-311.pyc ADDED
Binary file (2.99 kB).

src/enrichment_agent/__pycache__/tools.cpython-311.pyc ADDED
Binary file (4.23 kB).

src/enrichment_agent/__pycache__/utils.cpython-311.pyc ADDED
Binary file (2.21 kB).
 
src/enrichment_agent/configuration.py ADDED
@@ -0,0 +1,62 @@
+"""Define the configurable parameters for the agent."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field, fields
+from typing import Annotated, Optional
+
+from langchain_core.runnables import RunnableConfig, ensure_config
+
+from . import prompts
+
+
+@dataclass(kw_only=True)
+class Configuration:
+    """The configuration for the agent."""
+
+    model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
+        default="openai/gpt-3.5-turbo",
+        metadata={
+            "description": "The name of the language model to use for the agent. "
+            "Should be in the form: provider/model-name."
+        },
+    )
+
+    prompt: str = field(
+        default=prompts.MAIN_PROMPT,
+        metadata={
+            "description": "The main prompt template to use for the agent's interactions. "
+            "Expects two f-string arguments: {info} and {topic}."
+        },
+    )
+
+    max_search_results: int = field(
+        default=10,
+        metadata={
+            "description": "The maximum number of search results to return for each search query."
+        },
+    )
+
+    max_info_tool_calls: int = field(
+        default=3,
+        metadata={
+            "description": "The maximum number of times the Info tool can be called during a single interaction."
+        },
+    )
+
+    max_loops: int = field(
+        default=6,
+        metadata={
+            "description": "The maximum number of interaction loops allowed before the agent terminates."
+        },
+    )
+
+    @classmethod
+    def from_runnable_config(
+        cls, config: Optional[RunnableConfig] = None
+    ) -> Configuration:
+        """Load configuration w/ defaults for the given invocation."""
+        config = ensure_config(config)
+        configurable = config.get("configurable") or {}
+        _fields = {f.name for f in fields(cls) if f.init}
+        return cls(**{k: v for k, v in configurable.items() if k in _fields})
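
A quick usage sketch of from_runnable_config (the override values here are illustrative): only keys matching the dataclass fields are applied, so unrelated configurable entries are silently dropped:

    cfg = Configuration.from_runnable_config(
        {"configurable": {"model": "openai/gpt-4o", "max_loops": 3, "other": 1}}
    )
    assert cfg.model == "openai/gpt-4o" and cfg.max_loops == 3  # "other" is ignored
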
src/enrichment_agent/graph.py ADDED
@@ -0,0 +1,229 @@
+"""Define a data enrichment agent.
+
+Works with a chat model with tool calling support.
+"""
+
+import json
+from typing import Any, Dict, List, Literal, Optional, cast
+
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
+from langchain_core.runnables import RunnableConfig
+from langgraph.graph import StateGraph
+from langgraph.prebuilt import ToolNode
+from pydantic import BaseModel, Field
+
+from . import prompts
+from .configuration import Configuration
+from .state import InputState, OutputState, State
+from .tools import scrape_website, search
+from .utils import init_model
+
+
+async def call_agent_model(
+    state: State, *, config: Optional[RunnableConfig] = None
+) -> Dict[str, Any]:
+    """Call the primary Language Model (LLM) to decide on the next research action.
+
+    This asynchronous function performs the following steps:
+    1. Initializes configuration and sets up the 'Info' tool, which is the user-defined extraction schema.
+    2. Prepares the prompt and message history for the LLM.
+    3. Initializes and configures the LLM with available tools.
+    4. Invokes the LLM and processes its response.
+    5. Handles the LLM's decision to either continue research or submit final info.
+    """
+    # Load configuration from the provided RunnableConfig
+    configuration = Configuration.from_runnable_config(config)
+
+    # Define the 'Info' tool, which is the user-defined extraction schema
+    info_tool = {
+        "name": "Info",
+        "description": "Call this when you have gathered all the relevant info",
+        "parameters": state.extraction_schema,
+    }
+
+    # Format the prompt defined in prompts.py with the extraction schema and topic
+    p = configuration.prompt.format(
+        info=json.dumps(state.extraction_schema, indent=2), topic=state.topic
+    )
+
+    # Create the messages list with the formatted prompt and the previous messages
+    messages = [HumanMessage(content=p)] + state.messages
+
+    # Initialize the raw model with the provided configuration and bind the tools
+    raw_model = init_model(config)
+    model = raw_model.bind_tools([scrape_website, search, info_tool], tool_choice="any")
+    response = cast(AIMessage, await model.ainvoke(messages))
+
+    # Initialize info to None
+    info = None
+
+    # Check if the response has tool calls
+    if response.tool_calls:
+        for tool_call in response.tool_calls:
+            if tool_call["name"] == "Info":
+                info = tool_call["args"]
+                break
+    if info is not None:
+        # The agent is submitting their answer;
+        # ensure it isn't erroneously attempting to simultaneously perform research
+        response.tool_calls = [
+            next(tc for tc in response.tool_calls if tc["name"] == "Info")
+        ]
+    response_messages: List[BaseMessage] = [response]
+    if not response.tool_calls:  # If the LLM didn't respect the tool_choice
+        response_messages.append(
+            HumanMessage(content="Please respond by calling one of the provided tools.")
+        )
+    return {
+        "messages": response_messages,
+        "info": info,
+        # Add 1 to the step count
+        "loop_step": 1,
+    }
+
+
+class InfoIsSatisfactory(BaseModel):
+    """Validate whether the current extracted info is satisfactory and complete."""
+
+    reason: List[str] = Field(
+        description="First, provide reasoning for why this is either good or bad as a final result. Must include at least 3 reasons."
+    )
+    is_satisfactory: bool = Field(
+        description="After providing your reasoning, provide a value indicating whether the result is satisfactory. If not, you will continue researching."
+    )
+    improvement_instructions: Optional[str] = Field(
+        description="If the result is not satisfactory, provide clear and specific instructions on what needs to be improved or added to make the information satisfactory."
+        " This should include details on missing information, areas that need more depth, or specific aspects to focus on in further research.",
+        default=None,
+    )
+
+
+async def reflect(
+    state: State, *, config: Optional[RunnableConfig] = None
+) -> Dict[str, Any]:
+    """Validate the quality of the data enrichment agent's output.
+
+    This asynchronous function performs the following steps:
+    1. Prepares the initial prompt using the main prompt template.
+    2. Constructs a message history for the model.
+    3. Prepares a checker prompt to evaluate the presumed info.
+    4. Initializes and configures a language model with structured output.
+    5. Invokes the model to assess the quality of the gathered information.
+    6. Processes the model's response and determines if the info is satisfactory.
+    """
+    p = prompts.MAIN_PROMPT.format(
+        info=json.dumps(state.extraction_schema, indent=2), topic=state.topic
+    )
+    last_message = state.messages[-1]
+    if not isinstance(last_message, AIMessage):
+        raise ValueError(
+            f"{reflect.__name__} expects the last message in the state to be an AI message with tool calls."
+            f" Got: {type(last_message)}"
+        )
+    messages = [HumanMessage(content=p)] + state.messages[:-1]
+    presumed_info = state.info
+    checker_prompt = """I am thinking of calling the info tool with the info below. \
+Is this good? Give your reasoning as well. \
+You can encourage the Assistant to look at specific URLs if that seems relevant, or do more searches.
+If you don't think it is good, you should be very specific about what could be improved.
+
+{presumed_info}"""
+    p1 = checker_prompt.format(presumed_info=json.dumps(presumed_info or {}, indent=2))
+    messages.append(HumanMessage(content=p1))
+    raw_model = init_model(config)
+    bound_model = raw_model.with_structured_output(InfoIsSatisfactory)
+    response = cast(InfoIsSatisfactory, await bound_model.ainvoke(messages))
+    if response.is_satisfactory and presumed_info:
+        return {
+            "info": presumed_info,
+            "messages": [
+                ToolMessage(
+                    tool_call_id=last_message.tool_calls[0]["id"],
+                    content="\n".join(response.reason),
+                    name="Info",
+                    additional_kwargs={"artifact": response.model_dump()},
+                    status="success",
+                )
+            ],
+        }
+    else:
+        return {
+            "messages": [
+                ToolMessage(
+                    tool_call_id=last_message.tool_calls[0]["id"],
+                    content=f"Unsatisfactory response:\n{response.improvement_instructions}",
+                    name="Info",
+                    additional_kwargs={"artifact": response.model_dump()},
+                    status="error",
+                )
+            ]
+        }
+
+
+def route_after_agent(
+    state: State,
+) -> Literal["reflect", "tools", "call_agent_model", "__end__"]:
+    """Schedule the next node after the agent's action.
+
+    This function determines the next step in the research process based on the
+    last message in the state. It handles three main scenarios:
+
+    1. Error recovery: If the last message is unexpectedly not an AIMessage.
+    2. Info submission: If the agent has called the "Info" tool to submit findings.
+    3. Continued research: If the agent has called any other tool.
+    """
+    last_message = state.messages[-1]
+
+    # If for some reason the last message is not an AIMessage (due to a bug or unexpected behavior elsewhere in the code),
+    # this ensures the system doesn't crash but instead tries to recover by calling the agent model again.
+    if not isinstance(last_message, AIMessage):
+        return "call_agent_model"
+    # If the "Info" tool was called, then the model provided its extraction output. Reflect on the result
+    if last_message.tool_calls and last_message.tool_calls[0]["name"] == "Info":
+        return "reflect"
+    # The last message is a tool call that is not "Info" (extraction output)
+    else:
+        return "tools"
+
+
+def route_after_checker(
+    state: State, config: RunnableConfig
+) -> Literal["__end__", "call_agent_model"]:
+    """Schedule the next node after the checker's evaluation.
+
+    This function determines whether to continue the research process or end it
+    based on the checker's evaluation and the current state of the research.
+    """
+    configurable = Configuration.from_runnable_config(config)
+    last_message = state.messages[-1]
+
+    if state.loop_step < configurable.max_loops:
+        if not state.info:
+            return "call_agent_model"
+        if not isinstance(last_message, ToolMessage):
+            raise ValueError(
+                f"{route_after_checker.__name__} expected a tool message. Received: {type(last_message)}."
+            )
+        if last_message.status == "error":
+            # Research deemed unsatisfactory
+            return "call_agent_model"
+        # It's great!
+        return "__end__"
+    else:
+        return "__end__"
+
+
+# Create the graph
+workflow = StateGraph(
+    State, input=InputState, output=OutputState, config_schema=Configuration
+)
+workflow.add_node(call_agent_model)
+workflow.add_node(reflect)
+workflow.add_node("tools", ToolNode([search, scrape_website]))
+workflow.add_edge("__start__", "call_agent_model")
+workflow.add_conditional_edges("call_agent_model", route_after_agent)
+workflow.add_edge("tools", "call_agent_model")
+workflow.add_conditional_edges("reflect", route_after_checker)
+
+graph = workflow.compile()
+graph.name = "ResearchTopic"
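
For orientation, a minimal invocation sketch of the compiled graph (the schema and topic are illustrative, and TAVILY_API_KEY and OPENAI_API_KEY must be set):

    import asyncio

    result = asyncio.run(
        graph.ainvoke(
            {
                "topic": "LangChain",
                "extraction_schema": {
                    "type": "object",
                    "properties": {"founder": {"type": "string"}},
                },
            }
        )
    )
    print(result["info"])  # filled in once reflect() accepts the extraction
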
src/enrichment_agent/prompts.py ADDED
@@ -0,0 +1,17 @@
+"""Default prompts used in this project."""
+
+MAIN_PROMPT = """You are doing web research on behalf of a user. You are trying to figure out this information:
+
+<info>
+{info}
+</info>
+
+You have access to the following tools:
+
+- `Search`: call a search tool and get back some results
+- `ScrapeWebsite`: scrape a website and get relevant notes about the given request. This will update the notes above.
+- `Info`: call this when you are done and have gathered all the relevant info
+
+Here is the information you have about the topic you are researching:
+
+Topic: {topic}"""
src/enrichment_agent/state.py ADDED
@@ -0,0 +1,88 @@
+"""State definitions.
+
+State is the interface between the graph and end user as well as the
+data model used internally by the graph.
+"""
+
+import operator
+from dataclasses import dataclass, field
+from typing import Annotated, Any, List, Optional
+
+from langchain_core.messages import BaseMessage
+from langgraph.graph import add_messages
+
+
+@dataclass(kw_only=True)
+class InputState:
+    """Input state defines the interface between the graph and the user (external API)."""
+
+    topic: str
+    "The topic for which the agent is tasked to gather information."
+
+    extraction_schema: dict[str, Any]
+    "The json schema defines the information the agent is tasked with filling out."
+
+    info: Optional[dict[str, Any]] = field(default=None)
+    "The info state tracks the current extracted data for the given topic, conforming to the provided schema. This is primarily populated by the agent."
+
+
+@dataclass(kw_only=True)
+class State(InputState):
+    """A graph's State defines three main things.
+
+    1. The structure of the data to be passed between nodes (which "channels" to read from/write to and their types)
+    2. Default values for each field
+    3. Reducers for the state's fields. Reducers are functions that determine how to apply updates to the state.
+    See [Reducers](https://langchain-ai.github.io/langgraph/concepts/low_level/#reducers) for more information.
+    """
+
+    messages: Annotated[List[BaseMessage], add_messages] = field(default_factory=list)
+    """
+    Messages track the primary execution state of the agent.
+
+    Typically accumulates a pattern of:
+
+    1. HumanMessage - user input
+    2. AIMessage with .tool_calls - agent picking tool(s) to use to collect
+        information
+    3. ToolMessage(s) - the responses (or errors) from the executed tools
+
+        (... repeat steps 2 and 3 as needed ...)
+    4. AIMessage without .tool_calls - agent responding in unstructured
+        format to the user.
+
+    5. HumanMessage - user responds with the next conversational turn.
+
+        (... repeat steps 2-5 as needed ... )
+
+    Merges two lists of messages, updating existing messages by ID.
+
+    By default, this ensures the state is "append-only", unless the
+    new message has the same ID as an existing message.
+
+    Returns:
+        A new list of messages with the messages from `right` merged into `left`.
+        If a message in `right` has the same ID as a message in `left`, the
+        message from `right` will replace the message from `left`.
+    """
+
+    loop_step: Annotated[int, operator.add] = field(default=0)
+
+    # Feel free to add additional attributes to your state as needed.
+    # Common examples include retrieved documents, extracted entities, API connections, etc.
+
+
+@dataclass(kw_only=True)
+class OutputState:
+    """The response object for the end user.
+
+    This class defines the structure of the output that will be provided
+    to the user after the graph's execution is complete.
+    """
+
+    info: dict[str, Any]
+    """
+    A dictionary containing the extracted and processed information
+    based on the user's query and the graph's execution.
+    This is the primary output of the enrichment process.
+    """
src/enrichment_agent/tools.py ADDED
@@ -0,0 +1,74 @@
+"""Tools for data enrichment.
+
+This module contains functions that are directly exposed to the LLM as tools.
+These tools can be used for tasks such as web searching and scraping.
+Users can edit and extend these tools as needed.
+"""
+
+import json
+from typing import Any, Optional, cast
+
+import aiohttp
+from langchain_community.tools.tavily_search import TavilySearchResults
+from langchain_core.runnables import RunnableConfig
+from langchain_core.tools import InjectedToolArg
+from langgraph.prebuilt import InjectedState
+from typing_extensions import Annotated
+
+from .configuration import Configuration
+from .state import State
+from .utils import init_model
+
+
+async def search(
+    query: str, *, config: Annotated[RunnableConfig, InjectedToolArg]
+) -> Optional[list[dict[str, Any]]]:
+    """Query a search engine.
+
+    This function queries the web to fetch comprehensive, accurate, and trusted results. It's particularly useful
+    for answering questions about current events. Provide as much context in the query as needed to ensure high recall.
+    """
+    configuration = Configuration.from_runnable_config(config)
+    wrapped = TavilySearchResults(max_results=configuration.max_search_results)
+    result = await wrapped.ainvoke({"query": query})
+    return cast(list[dict[str, Any]], result)
+
+
+_INFO_PROMPT = """You are doing web research on behalf of a user. You are trying to find out this information:
+
+<info>
+{info}
+</info>
+
+You just scraped the following website: {url}
+
+Based on the website content below, jot down some notes about the website.
+
+<Website content>
+{content}
+</Website content>"""
+
+
+async def scrape_website(
+    url: str,
+    *,
+    state: Annotated[State, InjectedState],
+    config: Annotated[RunnableConfig, InjectedToolArg],
+) -> str:
+    """Scrape and summarize content from a given URL.
+
+    Returns:
+        str: A summary of the scraped content, tailored to the extraction schema.
+    """
+    async with aiohttp.ClientSession() as session:
+        async with session.get(url) as response:
+            content = await response.text()
+
+    p = _INFO_PROMPT.format(
+        info=json.dumps(state.extraction_schema, indent=2),
+        url=url,
+        content=content[:40_000],
+    )
+    raw_model = init_model(config)
+    result = await raw_model.ainvoke(p)
+    return str(result.content)
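
Because the injected arguments are ordinary keyword parameters, the tools can also be exercised directly outside the graph; a rough sketch (the query is illustrative and TAVILY_API_KEY must be set):

    import asyncio

    async def demo() -> None:
        results = await search(
            "LangGraph data enrichment",
            config={"configurable": {"max_search_results": 3}},
        )
        for r in results or []:
            print(r.get("url"))

    asyncio.run(demo())
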
src/enrichment_agent/utils.py ADDED
@@ -0,0 +1,34 @@
+"""Utility functions used in our graph."""
+
+from typing import Optional
+
+from langchain.chat_models import init_chat_model
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import AnyMessage
+from langchain_core.runnables import RunnableConfig
+
+from .configuration import Configuration
+
+
+def get_message_text(msg: AnyMessage) -> str:
+    """Get the text content of a message."""
+    content = msg.content
+    if isinstance(content, str):
+        return content
+    elif isinstance(content, dict):
+        return content.get("text", "")
+    else:
+        txts = [c if isinstance(c, str) else (c.get("text") or "") for c in content]
+        return "".join(txts).strip()
+
+
+def init_model(config: Optional[RunnableConfig] = None) -> BaseChatModel:
+    """Initialize the configured chat model."""
+    configuration = Configuration.from_runnable_config(config)
+    fully_specified_name = configuration.model
+    if "/" in fully_specified_name:
+        provider, model = fully_specified_name.split("/", maxsplit=1)
+    else:
+        provider = None
+        model = fully_specified_name
+    return init_chat_model(model, model_provider=provider)
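
The split in init_model implements the provider/model-name convention from Configuration; a small sketch of the parsing (model names illustrative):

    # "openai/gpt-3.5-turbo" -> provider "openai", model "gpt-3.5-turbo"
    # "gpt-3.5-turbo"        -> provider None (init_chat_model infers it)
    provider, model = "openai/gpt-3.5-turbo".split("/", maxsplit=1)
    assert (provider, model) == ("openai", "gpt-3.5-turbo")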