Spaces:
Running
Running
| import json | |
| # Import relevant functionality | |
| from langchain_core.messages import HumanMessage, AIMessage, SystemMessage | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from langgraph.prebuilt import create_react_agent | |
| from langchain_anthropic import ChatAnthropic | |
| import sys | |
| sys.path.append('.') | |
| from toolformers.base import Tool as AgoraTool | |
| from langchain_core.tools import tool as function_to_tool | |
| from toolformers.base import StringParameter, Toolformer, Conversation | |
| from utils import register_cost | |
# Token prices per model, keyed by the model name used to build the chat model.
# Each entry maps a usage counter name ('input_tokens' / 'output_tokens') to the
# price of a single token; the values are written as "price per million tokens
# divided by one million" (presumably USD — confirm against register_cost).
COSTS = {
    'claude-3-5-sonnet-latest': dict(
        input_tokens=3 / 1_000_000,
        output_tokens=15 / 1_000_000,
    ),
    'claude-3-5-haiku-latest': dict(
        input_tokens=1 / 1_000_000,
        output_tokens=5 / 1_000_000,
    ),
}
class LangChainConversation(Conversation):
    """A multi-turn chat session backed by a LangGraph agent.

    The running message history lives in ``self.messages`` and is replayed to
    the agent on every turn; the token cost of each turn is reported through
    ``register_cost`` under ``self.category``.
    """

    def __init__(self, model_name, agent, messages, category=None):
        """Create a conversation.

        Args:
            model_name: Model identifier; ``chat`` looks it up in ``COSTS`` to
                price token usage.
            agent: Agent exposing ``stream`` (e.g. the result of
                ``create_react_agent``).
            messages: Initial message history; this list is appended to in
                place by ``chat``.
            category: Label passed to ``register_cost`` for every turn.
        """
        self.model_name = model_name
        self.agent = agent
        self.messages = messages
        self.category = category

    @staticmethod
    def _message_text(message):
        """Return the plain text carried by a LangChain message.

        ``content`` is either a string or a list of content blocks. String
        blocks and ``{"type": "text", "text": ...}`` dict blocks contribute
        text; any other block (e.g. a tool-use block) is skipped.
        """
        content = message.content
        if isinstance(content, str):
            return content
        parts = []
        for block in content:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and block.get('type') == 'text':
                parts.append(block.get('text', ''))
        return ''.join(parts)

    def chat(self, message, role='user', print_output=True) -> str:
        """Send one user turn to the agent and return its textual reply.

        The message is appended to the history as a ``HumanMessage`` and the
        whole history is streamed through the agent with
        ``stream_mode="values"``, under which LangGraph yields the full state
        after each step. Only the messages that come after the pre-existing
        history (i.e. those produced during this turn) are used to build the
        reply and to price the turn, so earlier turns are neither repeated in
        the reply nor billed again.

        Args:
            message: Text of the user turn.
            role: Accepted for interface compatibility; currently ignored —
                the turn is always sent as a ``HumanMessage``.
            print_output: When true (the default), print every streamed state
                chunk followed by a ``----`` separator.

        Returns:
            The concatenated text of the AI messages produced during this
            turn, each counted once. The reply is also appended to the
            history as an ``AIMessage``.

        Raises:
            KeyError: if ``self.model_name`` has no entry in ``COSTS``
                (checked before the agent is called, so no unpriced request
                is made).
        """
        rates = COSTS[self.model_name]

        self.messages.append(HumanMessage(content=message))
        # Everything in the final state beyond this index was produced by the
        # agent during this turn.
        history_length = len(self.messages)

        final_state = None
        for chunk in self.agent.stream({"messages": self.messages}, stream_mode="values"):
            if print_output:
                print(chunk)
                print("----")
            # Each "values" chunk is a full snapshot, so the last one wins.
            final_state = chunk

        new_messages = [] if final_state is None else final_state['messages'][history_length:]
        new_ai_messages = [m for m in new_messages if isinstance(m, AIMessage)]

        final_message = ''.join(self._message_text(m) for m in new_ai_messages)

        total_cost = 0
        for ai_message in new_ai_messages:
            usage = ai_message.usage_metadata
            if usage is None:
                # Messages without usage data cannot be priced.
                continue
            for cost_name in ('input_tokens', 'output_tokens'):
                total_cost += rates[cost_name] * usage.get(cost_name, 0)
        register_cost(self.category, total_cost)

        self.messages.append(AIMessage(content=final_message))
        return final_message
class LangChainAnthropicToolformer(Toolformer):
    """Toolformer that drives an Anthropic chat model through a LangGraph ReAct agent."""

    def __init__(self, model_name, api_key):
        """Store the Anthropic model name and API key used for every new conversation."""
        self.api_key = api_key
        self.model_name = model_name

    def new_conversation(self, prompt, tools, category=None):
        """Build a ReAct agent over ``tools`` and wrap it in a conversation.

        Args:
            prompt: System prompt that seeds the message history.
            tools: Agora tools; each one is turned into a LangChain tool via
                its ``as_annotated_function()`` and LangChain's ``tool``.
            category: Cost category forwarded to the conversation.

        Returns:
            A ``LangChainConversation`` whose history starts with the system
            prompt.
        """
        langchain_tools = [
            function_to_tool(agora_tool.as_annotated_function())
            for agora_tool in tools
        ]
        chat_model = ChatAnthropic(model_name=self.model_name, api_key=self.api_key)
        react_agent = create_react_agent(chat_model, langchain_tools)
        initial_history = [SystemMessage(prompt)]
        return LangChainConversation(
            self.model_name,
            react_agent,
            initial_history,
            category,
        )
| #weather_tool = AgoraTool("WeatherForecastAPI", "A simple tool that returns the weather", [StringParameter( | |
| # name="location", | |
| # description="The name of the location for which the weather forecast is requested.", | |
| # required=True | |
| #)], lambda location: 'Sunny', { | |
| # "type": "string" | |
| #}) | |
| # | |
| #tools = [agora_tool_to_langchain(weather_tool)] | |
| #toolformer = LangChainAnthropicToolformer("claude-3-5-sonnet-latest", 'YOUR_ANTHROPIC_API_KEY')  # never commit a real key | |
| #conversation = toolformer.new_conversation('You are a weather bot', [weather_tool]) | |
| # | |
| #print(conversation.chat('What is the weather in San Francisco?')) | |