# cfpb-assistant / agent.py
# (Hugging Face file view: last commit af169fc "updated" by ofermend, 5.47 kB)
import os
from typing import Optional
from pydantic import Field, BaseModel
from omegaconf import DictConfig, OmegaConf
from llama_index.core.utilities.sql_wrapper import SQLDatabase
from sqlalchemy import create_engine
from dotenv import load_dotenv
load_dotenv(override=True)
from vectara_agentic.agent import Agent
from vectara_agentic.tools import ToolsFactory, VectaraToolFactory
from vectara_agentic.types import ModelProvider, AgentType
from vectara_agentic.agent_config import AgentConfig
def create_assistant_tools(cfg, *, db_url: str = "sqlite:///cfpb_database.db"):
    """Build the list of tools available to the CFPB complaints agent.

    Args:
        cfg: Config object exposing ``api_keys`` and ``corpus_keys``
            (e.g. the one returned by ``get_agent_config``); used to reach
            the Vectara corpus for the ``ask_complaints`` RAG tool.
        db_url: SQLAlchemy URL of the complaints database backing the
            analytical database tools (named with the ``cfpb`` prefix).
            Defaults to the original ``sqlite:///cfpb_database.db``. Note that
            a relative ``sqlite:///`` path is resolved against the process's
            current working directory, not this file's directory; pass an
            absolute URL when the app may be launched from elsewhere.

    Returns:
        A list: the standard tools, then the database tools, then the
        ``ask_complaints`` RAG tool.
    """

    class QueryCFPBComplaints(BaseModel):
        # Argument schema for ``ask_complaints`` (passed as tool_args_schema
        # below); both fields are optional. Kept docstring-free on purpose:
        # pydantic would copy a class docstring into the JSON schema
        # "description" sent with the tool.
        company: Optional[str] = Field(
            default=None,
            description="The company that the complaint is about.",
            examples=['CAPITAL ONE FINANCIAL CORPORATION', 'BANK OF AMERICA, NATIONAL ASSOCIATION', 'CITIBANK, N.A.', 'WELLS FARGO & COMPANY', 'JPMORGAN CHASE & CO.'],
        )
        state: Optional[str] = Field(
            default=None,
            description="The two-character state code where the consumer lives.",
            examples=['CA', 'FL', 'NY', 'TX', 'GA'],
        )

    vec_factory = VectaraToolFactory(
        vectara_api_key=cfg.api_keys,
        vectara_corpus_key=cfg.corpus_keys,
    )

    # Name of the Vectara summarizer prompt used to generate the RAG answer.
    summarizer = 'vectara-summary-table-md-query-ext-jan-2025-gpt-4o'
    ask_complaints = vec_factory.create_rag_tool(
        tool_name="ask_complaints",
        tool_description="""
Given a user query,
returns a response to a user question about customer complaints for bank services.
""",
        tool_args_schema=QueryCFPBComplaints,
        # Chain reranker over the top 100 hits: slingshot (cutoff 0.2),
        # then MMR (diversity_bias 0.2).
        reranker="chain",
        rerank_k=100,
        rerank_chain=[
            {
                "type": "slingshot",
                "cutoff": 0.2,
            },
            {
                "type": "mmr",
                "diversity_bias": 0.2,
            },
        ],
        n_sentences_before=2,
        n_sentences_after=2,
        lambda_val=0.005,
        summary_num_results=10,
        max_tokens=4096,
        max_response_chars=8192,
        vectara_summarizer=summarizer,
        include_citations=True,
        verbose=True,
    )

    tools_factory = ToolsFactory()
    db_tools = tools_factory.database_tools(
        tool_name_prefix="cfpb",
        content_description="Customer complaints about five banks (Bank of America, Wells Fargo, Capital One, Chase, and CITI Bank) and geographic information (counties and zip codes)",
        # NOTE(review): SQLite typically creates an empty file if the path does
        # not exist, so a wrong CWD shows up as "no tables" rather than an error.
        sql_database=SQLDatabase(create_engine(db_url)),
    )

    return tools_factory.standard_tools() + db_tools + [ask_complaints]
def _agent_config_from_env(prefix: str) -> "AgentConfig":
    """Build an AgentConfig from environment variables named ``{prefix}<FIELD>``.

    ``prefix`` is ``"VECTARA_AGENTIC_"`` for the primary agent and
    ``"VECTARA_AGENTIC_FALLBACK_"`` for the fallback agent. Unset variables
    default to OpenAI for the agent type and both LLM providers, and to an
    empty model name (presumably letting vectara_agentic choose its default
    model — confirm).
    """
    return AgentConfig(
        agent_type=os.getenv(f"{prefix}AGENT_TYPE", AgentType.OPENAI.value),
        main_llm_provider=os.getenv(f"{prefix}MAIN_LLM_PROVIDER", ModelProvider.OPENAI.value),
        main_llm_model_name=os.getenv(f"{prefix}MAIN_MODEL_NAME", ""),
        tool_llm_provider=os.getenv(f"{prefix}TOOL_LLM_PROVIDER", ModelProvider.OPENAI.value),
        tool_llm_model_name=os.getenv(f"{prefix}TOOL_MODEL_NAME", ""),
        # Shared by primary and fallback configs: no FALLBACK_ variant is read.
        observer=os.getenv("VECTARA_AGENTIC_OBSERVER_TYPE", "NO_OBSERVER"),
    )


def initialize_agent(_cfg, agent_progress_callback=None):
    """Create the CFPB customer-complaints agent.

    Args:
        _cfg: Config passed through to ``create_assistant_tools`` (must expose
            ``api_keys`` and ``corpus_keys``).
        agent_progress_callback: Optional callback forwarded unchanged to
            ``Agent``.

    Returns:
        The constructed ``Agent``, after ``agent.report(detailed=False)`` has
        been called on it.
    """
    cfpb_complaints_bot_instructions = """
- You are a helpful research assistant in conversation with a user.
- You are an expert in the domain of complaints recorded by the CFPB (Consumer Financial Protection Bureau).
- For informational questions about customer complaints, use the 'ask_complaints' tool.
- For analytical questions, use the database tools: cfpb_load_data, cfpb_load_sample_data, cfpb_list_tables, cfpb_describe_tables and cfpb_load_unique_values.
- Never discuss politics, and always respond politely.
"""
    # Built before the tools so configuration errors surface first.
    agent_config = _agent_config_from_env("VECTARA_AGENTIC_")
    fallback_agent_config = _agent_config_from_env("VECTARA_AGENTIC_FALLBACK_")

    agent = Agent(
        tools=create_assistant_tools(_cfg),
        topic="Customer complaints from the Consumer Financial Protection Bureau (CFPB)",
        custom_instructions=cfpb_complaints_bot_instructions,
        agent_progress_callback=agent_progress_callback,
        validate_tools=True,
        verbose=True,
        agent_config=agent_config,
        fallback_agent_config=fallback_agent_config,
    )
    agent.report(detailed=False)
    return agent
def get_agent_config() -> DictConfig:
    """Build the demo configuration from environment variables.

    Returns:
        A DictConfig (``OmegaConf.create`` on a dict) with the Vectara
        credentials ``corpus_keys`` and ``api_keys``, ``examples`` (the raw
        ``QUERY_EXAMPLES`` value, or None when unset), and the demo's
        name/welcome/description strings.

    Raises:
        KeyError: if ``VECTARA_CORPUS_KEYS`` or ``VECTARA_API_KEYS`` is unset.
    """
    cfg = OmegaConf.create({
        # os.environ values are already str; no conversion needed.
        'corpus_keys': os.environ['VECTARA_CORPUS_KEYS'],
        'api_keys': os.environ['VECTARA_API_KEYS'],
        'examples': os.environ.get('QUERY_EXAMPLES'),
        'demo_name': "cfpb-assistant",
        'demo_welcome': "Welcome to the CFPB Customer Complaints demo.",
        'demo_description': "This assistant can help you gain insights into customer complaints to banks recorded by the Consumer Financial Protection Bureau.",
    })
    return cfg