Create 1_Simple-Gemini.py
1_Simple-Gemini.py +187 -0
1_Simple-Gemini.py
ADDED
@@ -0,0 +1,187 @@
from langchain import hub
from langchain.agents import Tool, create_react_agent
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.utilities import GoogleSerperAPIWrapper
import os
import operator
from typing import TypedDict, Annotated, Union
from langchain_core.agents import AgentAction, AgentFinish, AgentActionMessageLog
from langchain_core.messages import BaseMessage
from langgraph.prebuilt.tool_executor import ToolExecutor
from langgraph.prebuilt import ToolInvocation
from langgraph.graph import END, StateGraph
import streamlit as st
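# Note: this Space is a single Streamlit script; locally it would be launched with
# `streamlit run 1_Simple-Gemini.py` (assuming the packages imported above are installed).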

st.set_page_config(page_title="LangChain Agent", layout="wide")

def main():
    # Streamlit UI elements
    st.title("LangGraph Agent + Gemini Pro + Custom Tool + Streamlit")

    # Input from user
    input_text = st.text_area("Enter your text:")

    if st.button("Run Agent"):

        os.environ["SERPER_API_KEY"] = "YOUR-KEY-API"

        search = GoogleSerperAPIWrapper()

        def toggle_case(word):
            toggled_word = ""
            for char in word:
                if char.islower():
                    toggled_word += char.upper()
                elif char.isupper():
                    toggled_word += char.lower()
                else:
                    toggled_word += char
            return toggled_word

        def sort_string(string):
            return ''.join(sorted(string))

        tools = [
            Tool(
                name="Search",
                func=search.run,
                description="useful for when you need to answer questions about current events",
            ),
            Tool(
                name="Toggle_Case",
                func=lambda word: toggle_case(word),
                description="use when you want to convert letters to uppercase or lowercase",
            ),
            Tool(
                name="Sort String",
                func=lambda string: sort_string(string),
                description="use when you want to sort a string alphabetically",
            ),
        ]

        prompt = hub.pull("hwchase17/react")

        llm = ChatGoogleGenerativeAI(model="gemini-pro",
                                     google_api_key="Your_API_KEY",
                                     convert_system_message_to_human=True,
                                     verbose=True,
                                     )

        # the ReAct agent runnable is required by run_agent below
        agent_runnable = create_react_agent(llm, tools, prompt)

        class AgentState(TypedDict):
            input: str
            chat_history: list[BaseMessage]
            agent_outcome: Union[AgentAction, AgentFinish, None]
            return_direct: bool
            intermediate_steps: Annotated[list[tuple[AgentAction, str]], operator.add]

        tool_executor = ToolExecutor(tools)

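        # The three callables below drive the graph: run_agent asks the ReAct agent for its
        # next step, execute_tools runs the chosen tool, and should_continue routes to
        # "continue", "final" or "end" based on the latest outcome.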
        def run_agent(state):
            """
            # if you want to better manage intermediate steps
            inputs = state.copy()
            if len(inputs['intermediate_steps']) > 5:
                inputs['intermediate_steps'] = inputs['intermediate_steps'][-5:]
            """
            agent_outcome = agent_runnable.invoke(state)
            return {"agent_outcome": agent_outcome}

        def execute_tools(state):

            messages = [state['agent_outcome']]
            last_message = messages[-1]
            ######### human in the loop ###########
            # human input y/n
            # Get the most recent agent_outcome - this is the key added in the `agent` above
            # state_action = state['agent_outcome']
            # human_key = input(f"[y/n] continue with: {state_action}?")
            # if human_key == "n":
            #     raise ValueError

            tool_name = last_message.tool
            arguments = last_message
            if tool_name == "Search" or tool_name == "Sort String" or tool_name == "Toggle_Case":

                if "return_direct" in arguments:
                    del arguments["return_direct"]
            action = ToolInvocation(
                tool=tool_name,
                tool_input=last_message.tool_input,
            )
            response = tool_executor.invoke(action)
            return {"intermediate_steps": [(state['agent_outcome'], response)]}

        def should_continue(state):

            messages = [state['agent_outcome']]
            last_message = messages[-1]
            if "Action" not in last_message.log:
                return "end"
            else:
                arguments = state["return_direct"]
                if arguments is True:
                    return "final"
                else:
                    return "continue"

        def first_agent(inputs):
            action = AgentActionMessageLog(
                tool="Search",
                tool_input=inputs["input"],
                log="",
                message_log=[]
            )
            return {"agent_outcome": action}

        workflow = StateGraph(AgentState)

        workflow.add_node("agent", run_agent)
        workflow.add_node("action", execute_tools)
        workflow.add_node("final", execute_tools)
        # uncomment if you want to always call a certain tool first
        # workflow.add_node("first_agent", first_agent)

        workflow.set_entry_point("agent")
        # uncomment if you want to always call a certain tool first
        # workflow.set_entry_point("first_agent")

        workflow.add_conditional_edges(
            "agent",
            should_continue,
            {
                "continue": "action",
                "final": "final",
                "end": END
            }
        )

        workflow.add_edge('action', 'agent')
        workflow.add_edge('final', END)
        # uncomment if you want to always call a certain tool first
        # workflow.add_edge('first_agent', 'action')
        app = workflow.compile()
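        # Graph wiring recap: "agent" decides, should_continue routes to "action" (run the
        # tool, then loop back to "agent"), to "final" (run the tool and return its result
        # directly), or to END once the agent produces a final answer.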

        inputs = {"input": input_text, "chat_history": [], "return_direct": False}
        results = []
        for s in app.stream(inputs):
            result = list(s.values())[0]
            results.append(result)
            st.write(result)  # Display each step's output

        # result = app.invoke({"input": input_text, "chat_history": [], "return_direct": False})
        # print(result["agent_outcome"].return_values["output"])

if __name__ == "__main__":
    main()  # without this call the script would only set the page config and render nothing