Spaces:
Sleeping
Sleeping
import os
from typing import Optional

from dotenv import load_dotenv
from llama_index.core.utilities.sql_wrapper import SQLDatabase
from omegaconf import DictConfig, OmegaConf
from pydantic import BaseModel, Field
from sqlalchemy import create_engine

# Load .env (overriding existing values) before the vectara_agentic imports
# below; kept in this position in case those modules read environment
# variables at import time.
load_dotenv(override=True)

from vectara_agentic.agent import Agent  # noqa: E402
from vectara_agentic.agent_config import AgentConfig  # noqa: E402
from vectara_agentic.tools import ToolsFactory, VectaraToolFactory  # noqa: E402
from vectara_agentic.types import AgentType, ModelProvider  # noqa: E402
def create_assistant_tools(cfg):
    """Assemble the tools available to the CFPB complaints agent.

    Args:
        cfg: Configuration object exposing ``api_keys`` and ``corpus_keys``
            attributes, used to build the Vectara tool factory.

    Returns:
        The ``ToolsFactory`` standard tools, followed by the ``cfpb``-prefixed
        database tools over the SQLite file ``cfpb_database.db``, followed by
        the ``ask_complaints`` Vectara RAG tool.
    """

    class QueryCFPBComplaints(BaseModel):
        # Argument schema for the ask_complaints tool (tool_args_schema below).
        company: Optional[str] = Field(
            default=None,
            description="The company that the complaint is about.",
            examples=[
                'CAPITAL ONE FINANCIAL CORPORATION',
                'BANK OF AMERICA, NATIONAL ASSOCIATION',
                'CITIBANK, N.A.',
                'WELLS FARGO & COMPANY',
                'JPMORGAN CHASE & CO.',
            ],
        )
        state: Optional[str] = Field(
            default=None,
            description="The two-character state code where the consumer lives.",
            examples=['CA', 'FL', 'NY', 'TX', 'GA'],
        )

    vectara_factory = VectaraToolFactory(
        vectara_api_key=cfg.api_keys,
        vectara_corpus_key=cfg.corpus_keys,
    )

    # Reranker chain: a slingshot step with cutoff 0.2, then MMR with
    # diversity bias 0.2.
    rerank_steps = [
        {"type": "slingshot", "cutoff": 0.2},
        {"type": "mmr", "diversity_bias": 0.2},
    ]
    description = (
        "\n"
        "Given a user query,\n"
        "returns a response to a user question about customer complaints for bank services.\n"
    )

    ask_complaints = vectara_factory.create_rag_tool(
        tool_name="ask_complaints",
        tool_description=description,
        tool_args_schema=QueryCFPBComplaints,
        reranker="chain",
        rerank_k=100,
        rerank_chain=rerank_steps,
        n_sentences_before=2,
        n_sentences_after=2,
        lambda_val=0.005,
        summary_num_results=10,
        max_tokens=4096,
        max_response_chars=8192,
        vectara_summarizer='vectara-summary-table-md-query-ext-jan-2025-gpt-4o',
        include_citations=True,
        verbose=True,
    )

    factory = ToolsFactory()
    # Three slashes in the SQLAlchemy SQLite URL mean a path relative to the
    # current working directory.
    engine = create_engine('sqlite:///cfpb_database.db')
    db_tools = factory.database_tools(
        tool_name_prefix="cfpb",
        content_description=(
            "Customer complaints about five banks (Bank of America, Wells Fargo, "
            "Capital One, Chase, and CITI Bank) and geographic information "
            "(counties and zip codes)"
        ),
        sql_database=SQLDatabase(engine),
    )

    return [*factory.standard_tools(), *db_tools, ask_complaints]
def _agent_config_from_env(prefix: str):
    """Build an ``AgentConfig`` from environment variables starting with ``prefix``.

    Reads ``{prefix}AGENT_TYPE``, ``{prefix}MAIN_LLM_PROVIDER``,
    ``{prefix}MAIN_MODEL_NAME``, ``{prefix}TOOL_LLM_PROVIDER`` and
    ``{prefix}TOOL_MODEL_NAME``. The agent type and both providers default to
    the OpenAI values, and both model names default to ``""``. The observer is
    always read from ``VECTARA_AGENTIC_OBSERVER_TYPE`` (default
    ``"NO_OBSERVER"``), so primary and fallback configs share that setting.
    """
    return AgentConfig(
        agent_type=os.getenv(f"{prefix}AGENT_TYPE", AgentType.OPENAI.value),
        main_llm_provider=os.getenv(f"{prefix}MAIN_LLM_PROVIDER", ModelProvider.OPENAI.value),
        main_llm_model_name=os.getenv(f"{prefix}MAIN_MODEL_NAME", ""),
        tool_llm_provider=os.getenv(f"{prefix}TOOL_LLM_PROVIDER", ModelProvider.OPENAI.value),
        tool_llm_model_name=os.getenv(f"{prefix}TOOL_MODEL_NAME", ""),
        observer=os.getenv("VECTARA_AGENTIC_OBSERVER_TYPE", "NO_OBSERVER"),
    )


def initialize_agent(_cfg, agent_progress_callback=None):
    """Create the CFPB customer-complaints agent.

    Args:
        _cfg: Configuration forwarded to ``create_assistant_tools``. It must
            expose ``api_keys`` and ``corpus_keys``.
        agent_progress_callback: Optional callback forwarded to ``Agent``.

    Returns:
        The configured ``Agent``, after ``agent.report(detailed=False)`` has
        been called on it.
    """
    # The database tool names listed here must match those generated for the
    # "cfpb" prefix in create_assistant_tools. Keep them in sync.
    cfpb_complaints_bot_instructions = (
        "\n"
        "- You are a helpful research assistant in conversation with a user.\n"
        "- You are an expert in the domain of complaints recorded by the CFPB (Consumer Financial Protection Bureau).\n"
        "- For informational questions about customer complaints, use the 'ask_complaints' tool.\n"
        "- For analytical questions, use the database tools: cfpb_load_data, cfpb_load_sample_data, cfpb_list_tables, cfpb_describe_tables and cfpb_load_unique_values.\n"
        "- Never discuss politics, and always respond politely.\n"
    )

    # The primary config reads VECTARA_AGENTIC_*. The fallback config reads the
    # same settings with a FALLBACK_ infix (VECTARA_AGENTIC_FALLBACK_*).
    agent_config = _agent_config_from_env("VECTARA_AGENTIC_")
    fallback_agent_config = _agent_config_from_env("VECTARA_AGENTIC_FALLBACK_")

    agent = Agent(
        tools=create_assistant_tools(_cfg),
        topic="Customer complaints from the Consumer Financial Protection Bureau (CFPB)",
        custom_instructions=cfpb_complaints_bot_instructions,
        agent_progress_callback=agent_progress_callback,
        validate_tools=True,
        verbose=True,
        agent_config=agent_config,
        fallback_agent_config=fallback_agent_config,
    )
    agent.report(detailed=False)
    return agent
def get_agent_config() -> DictConfig:
    """Build the demo configuration from environment variables.

    Reads the required ``VECTARA_CORPUS_KEYS`` and ``VECTARA_API_KEYS`` and the
    optional ``QUERY_EXAMPLES`` (``None`` when unset). It combines them with
    the fixed demo name, welcome message and description.

    Returns:
        The ``DictConfig`` produced by ``OmegaConf.create``, with keys
        ``corpus_keys``, ``api_keys``, ``examples``, ``demo_name``,
        ``demo_welcome`` and ``demo_description``.

    Raises:
        KeyError: If any required environment variable is unset. The message
            names every missing variable.
    """
    required = ('VECTARA_CORPUS_KEYS', 'VECTARA_API_KEYS')
    missing = [name for name in required if name not in os.environ]
    if missing:
        raise KeyError(
            f"Missing required environment variable(s): {', '.join(missing)}"
        )

    return OmegaConf.create({
        # os.environ values are already str, so no conversion is needed.
        'corpus_keys': os.environ['VECTARA_CORPUS_KEYS'],
        'api_keys': os.environ['VECTARA_API_KEYS'],
        'examples': os.environ.get('QUERY_EXAMPLES', None),
        'demo_name': "cfpb-assistant",
        'demo_welcome': "Welcome to the CFPB Customer Complaints demo.",
        'demo_description': "This assistant can help you gain insights into customer complaints to banks recorded by the Consumer Financial Protection Bureau.",
    })