ABSA_Test_Space / supabase.py
Futuresony's picture
Upload supabase.py
029dfaa verified
import psycopg2
import os
import pickle # Still needed for general cache
import traceback
import numpy as np
import json
import base64 # Still needed for Google Sheets auth if that part of the code is kept elsewhere
import time # Still needed for general cache
# Assuming gspread and SentenceTransformer are installed
try:
import gspread
from oauth2client.service_account import ServiceAccountCredentials
from sentence_transformers import SentenceTransformer
print("gspread and SentenceTransformer imported successfully.")
except ImportError:
print("Error: Required libraries (gspread, oauth2client, sentence_transformers) not found.")
print("Please install them: pip install psycopg2-binary gspread oauth2client sentence-transformers numpy")
pass # Allow execution to continue with a warning
# Supabase database connection settings, read from environment variables.
# Set these in the environment where this script runs (e.g. Space secrets).
# Credentials must never be committed to source control, so the password has
# no in-code default: when SUPABASE_DB_PASSWORD is unset, connect_to_supabase()
# reports that the credentials are not fully set instead of connecting.
# NOTE(review): Supabase's direct Postgres host is usually of the form
# db.<project-ref>.supabase.co — confirm this default host.
SUPABASE_DB_HOST = os.getenv("SUPABASE_DB_HOST", "wziqfkzaqorzthpoxhjh.supabase.co")
SUPABASE_DB_NAME = os.getenv("SUPABASE_DB_NAME", "postgres")
SUPABASE_DB_USER = os.getenv("SUPABASE_DB_USER", "postgres")
SUPABASE_DB_PASSWORD = os.getenv("SUPABASE_DB_PASSWORD")
SUPABASE_DB_PORT = os.getenv("SUPABASE_DB_PORT", "5432")

# Google Sheets authentication settings (kept for reference if needed elsewhere).
GOOGLE_BASE64_CREDENTIALS = os.getenv("GOOGLE_BASE64_CREDENTIALS")
SHEET_ID = "19ipxC2vHYhpXCefpxpIkpeYdI43a1Ku2kYwecgUULIw"  # Replace with your actual Sheet ID

# Table names: business data lives in the user-specified 'manual' table.
BUSINESS_DATA_TABLE = "manual"
CONVERSATION_HISTORY_TABLE = "conversation_history"

# Embedding dimension; must match the Sentence Transformer model in use.
EMBEDDING_DIM = 384  # Dimension for paraphrase-MiniLM-L6-v2 or all-MiniLM-L6-v2
# --- Database Functions ---
def connect_to_supabase():
    """Open a new psycopg2 connection to the Supabase Postgres database.

    Connection parameters come from the module-level SUPABASE_DB_* settings.

    Returns:
        An open psycopg2 connection. Returns None when the host, database
        name, user or password setting is empty, or when psycopg2 raises
        OperationalError while connecting. Failures are reported via print().
    """
    print("Attempting to connect to Supabase database...")

    required_settings = (
        SUPABASE_DB_HOST,
        SUPABASE_DB_NAME,
        SUPABASE_DB_USER,
        SUPABASE_DB_PASSWORD,
    )
    if not all(required_settings):
        print("Error: Supabase database credentials (SUPABASE_DB_HOST, SUPABASE_DB_NAME, SUPABASE_DB_USER, SUPABASE_DB_PASSWORD) are not fully set as environment variables or defined in the script.")
        return None

    connection_kwargs = {
        "host": SUPABASE_DB_HOST,
        "database": SUPABASE_DB_NAME,
        "user": SUPABASE_DB_USER,
        "password": SUPABASE_DB_PASSWORD,
        "port": SUPABASE_DB_PORT,
    }
    try:
        connection = psycopg2.connect(**connection_kwargs)
    except psycopg2.OperationalError as exc:
        print(f"Database connection failed: {exc}")
        print(traceback.format_exc())
        return None

    print("Connected to Supabase database successfully!")
    return connection
def setup_db_schema(conn):
    """Create the pgvector extension and this module's tables if missing.

    Every statement uses ``IF NOT EXISTS`` and runs on ``conn``:

    * the pgvector ``vector`` extension;
    * ``BUSINESS_DATA_TABLE`` with id, "Service", "Description", "Price",
      "Available" and an ``embedding vector(EMBEDDING_DIM)`` column for RAG;
    * ``CONVERSATION_HISTORY_TABLE`` for logged conversation turns.

    Args:
        conn: an open psycopg2 connection.

    Returns:
        True once the statements are committed. On any exception the error is
        printed, the transaction is rolled back and False is returned.
    """
    print("Setting up database schema...")
    try:
        with conn.cursor() as cursor:
            # (statement, message printed after it succeeds), run in order.
            # Capitalised identifiers are double-quoted so Postgres keeps
            # their case instead of folding them to lower case.
            ddl_steps = [
                (
                    "CREATE EXTENSION IF NOT EXISTS vector;",
                    "pgvector extension enabled (if not already).",
                ),
                (
                    f"""
                    CREATE TABLE IF NOT EXISTS {BUSINESS_DATA_TABLE} (
                    id SERIAL PRIMARY KEY,
                    "Service" TEXT NOT NULL, -- Use double quotes for capitalized column names
                    "Description" TEXT NOT NULL, -- Use double quotes for capitalized column names
                    "Price" TEXT, -- Added Price column
                    "Available" TEXT, -- Added Available column
                    embedding vector({EMBEDDING_DIM}) -- Added embedding column for RAG
                    );
                    """,
                    f"Table '{BUSINESS_DATA_TABLE}' created (if not already) with columns: id, Service, Description, Price, Available, embedding.",
                ),
                (
                    f"""
                    CREATE TABLE IF NOT EXISTS {CONVERSATION_HISTORY_TABLE} (
                    id SERIAL PRIMARY KEY,
                    timestamp TIMESTAMP WITH TIME ZONE NOT NULL,
                    user_id TEXT,
                    user_query TEXT,
                    model_response TEXT,
                    tool_details JSONB,
                    model_used TEXT
                    );
                    """,
                    f"Table '{CONVERSATION_HISTORY_TABLE}' created (if not already).",
                ),
            ]
            for statement, done_message in ddl_steps:
                cursor.execute(statement)
                print(done_message)
        conn.commit()
        print("Database schema setup complete.")
        return True
    except Exception as err:
        print(f"Error setting up database schema: {err}")
        print(traceback.format_exc())
        conn.rollback()
        return False
# --- Manual business data (kept for the migration script; the main app load does not use it) ---
# Seed rows consumed by insert_manual_data_to_supabase(). Each row is a dict
# keyed by the column names of the business data table.
business_data_manual = [
    {
        "Service": "Savings Account",
        "Price": "Free",
        "Description": "A basic savings account with interest",
        "Available": "Yes",
    },
    # Add more rows here using the same keys.
]
# --- Data Insertion Function (using manual data) ---
def insert_manual_data_to_supabase(conn, embedder_model):
    """Seed ``BUSINESS_DATA_TABLE`` with the rows in ``business_data_manual``.

    For every row with a non-empty "Service" and "Description", a text of the
    form ``"Service: ... Description: ... Price: ... Available: ..."`` is
    embedded with ``embedder_model`` and stored next to the row's columns.
    Nothing is inserted when the table already holds rows. A row whose
    "Service" value is already present in the table is skipped.

    Args:
        conn: an open psycopg2 connection.
        embedder_model: object exposing ``encode(text, convert_to_tensor=False)``
            (e.g. a SentenceTransformer). ``None`` disables insertion.

    Returns:
        True if the table was already populated, or if the rows were
        processed and committed. False if a prerequisite is missing or a
        database error occurred; in that case the transaction is rolled back.
    """
    print("Inserting manual business data into database...")
    if embedder_model is None:
        print("Skipping data insertion: Embedder not available.")
        return False
    if EMBEDDING_DIM is None:
        print("Skipping data insertion: EMBEDDING_DIM not defined.")
        return False
    if not business_data_manual:
        print("No manual data defined for insertion.")
        return False

    def field(row, key):
        # Missing keys and None both become ""; non-strings are stringified
        # so .strip() cannot fail on e.g. a numeric Price.
        value = row.get(key)
        return "" if value is None else str(value).strip()

    # The table created by setup_db_schema() has no UNIQUE constraint on
    # "Service". Because of that, ON CONFLICT ("Service") is rejected by
    # Postgres and aborts the whole transaction. Skip duplicates with
    # NOT EXISTS instead, which works with or without such a constraint.
    insert_sql = f"""
        INSERT INTO {BUSINESS_DATA_TABLE} ("Service", "Description", "Price", "Available", embedding)
        SELECT %s, %s, %s, %s, %s::vector
        WHERE NOT EXISTS (
            SELECT 1 FROM {BUSINESS_DATA_TABLE} WHERE "Service" = %s
        );
    """

    insert_count = 0
    try:
        with conn.cursor() as cur:
            cur.execute(f"SELECT COUNT(*) FROM {BUSINESS_DATA_TABLE};")
            count = cur.fetchone()[0]
            if count > 0:
                print(f"Table '{BUSINESS_DATA_TABLE}' already contains {count} records. Skipping insertion of manual data.")
                return True  # Data is already there, so treat this as success.

            print(f"Processing {len(business_data_manual)} manual records for insertion.")
            for row in business_data_manual:
                service = field(row, "Service")
                description = field(row, "Description")
                price = field(row, "Price")
                available = field(row, "Available")
                if not service or not description:
                    print(f"Skipping row due to missing Service or Description: {row}")
                    continue

                # Embed all columns so retrieval can match on any of them.
                description_for_embedding = (
                    f"Service: {service}. Description: {description}. "
                    f"Price: {price}. Available: {available}."
                )

                # Only embedding failures are handled per row. Database errors
                # go to the outer handler, because a failed statement aborts
                # the transaction and later statements would fail too.
                try:
                    embedding = embedder_model.encode(description_for_embedding, convert_to_tensor=False)
                except Exception as embed_e:
                    print(f"Error generating embedding for Service '{service[:50]}...': {embed_e}")
                    print(traceback.format_exc())
                    print("Skipping insertion for this row.")
                    continue
                if embedding is None:
                    print(f"Skipping insertion for Service '{service[:50]}...' due to embedding generation failure.")
                    continue

                embedding_list = embedding.tolist() if hasattr(embedding, "tolist") else list(embedding)
                # A wrong-sized vector would be rejected by the vector(EMBEDDING_DIM)
                # column and abort the transaction, so skip it up front.
                if len(embedding_list) != EMBEDDING_DIM:
                    print(f"Skipping insertion for Service '{service[:50]}...': embedding has {len(embedding_list)} dimensions, expected {EMBEDDING_DIM}.")
                    continue

                cur.execute(insert_sql, (service, description, price, available, embedding_list, service))
                # rowcount is 0 when NOT EXISTS filtered out a duplicate Service.
                insert_count += cur.rowcount
        conn.commit()
    except Exception as e:
        conn.rollback()
        print(f"Error during data insertion: {e}")
        print(traceback.format_exc())
        return False

    print(f"Data insertion process completed. Inserted {insert_count} records.")
    return True
# --- Main Execution Flow for Migration Script ---
# This block is intended to be run separately to perform the initial data migration.
# The main application startup logic will be in a different __main__ block.
# if __name__ == "__main__":
# print("Starting RAG data insertion script from manual data...")
# # 1. Initialize Embedder Model
# try:
# print(f"Loading Sentence Transformer model for embeddings (dimension: {EMBEDDING_DIM})...")
# embedder = SentenceTransformer("paraphrase-MiniLM-L6-v2")
# if embedder.get_sentence_embedding_dimension() != EMBEDDING_DIM:
# print(f"Error: Loaded embedder dimension ({embedder.get_sentence_embedding_dimension()}) does not match expected EMBEDDING_DIM ({EMBEDDING_DIM}).")
# print("Please check the model or update EMBEDDING_DIM.")
# embedder = None
# else:
# print("Embedder model loaded successfully.")
# except Exception as e:
# print(f"Error loading Sentence Transformer model: {e}")
# print(traceback.format_exc())
# embedder = None
# if embedder is None:
# print("Embedder model not available. Cannot generate embeddings for data insertion.")
# pass
# # 2. Connect to Database and Setup Schema
# db_conn = connect_to_supabase()
# if db_conn is None:
# print("Database connection failed. Cannot setup schema or insert data.")
# pass
# else:
# try:
# if setup_db_schema(db_conn):
# print("\nDatabase schema setup successful.")
# # 3. Insert Manual Data
# if embedder is not None:
# if insert_manual_data_to_supabase(db_conn, embedder):
# print("\nManual RAG Data Insertion to PostgreSQL completed.")
# else:
# print("\nManual RAG Data Insertion to PostgreSQL failed.")
# else:
# print("\nEmbedder not available. Skipping manual data insertion.")
# else:
# print("\nDatabase schema setup failed.")
# finally:
# # 4. Close Database Connection
# if db_conn:
# db_conn.close()
# print("Database connection closed.")
# print("Manual data insertion script finished.")
# --- Update load_business_info to load from PostgreSQL 'manual' table ---
def _parse_pgvector(value):
    """Convert a pgvector ``vector`` column value to a 1-D float32 numpy array.

    Unless pgvector's psycopg2 type adapter has been registered, psycopg2
    returns ``vector`` values in their text form, e.g. ``"[0.1,0.2,0.3]"``,
    which is valid JSON. Already-decoded array-likes (lists, numpy arrays)
    are accepted as well.

    Returns:
        A float32 numpy array, or None if ``value`` is None (SQL NULL).

    Raises:
        ValueError or TypeError if the value cannot be read as numbers.
    """
    if value is None:
        return None
    if isinstance(value, str):
        value = json.loads(value)
    return np.asarray(value, dtype="float32").reshape(-1)


def load_business_info():
    """Load business rows with their stored embeddings and build a FAISS index.

    Reads "Service", "Description", "Price", "Available" and ``embedding``
    from ``BUSINESS_DATA_TABLE``. No embeddings are computed here; the stored
    ones are indexed. Sets these module globals:

    * ``data``: list of dicts with the four text columns, one per usable row;
    * ``descriptions_for_embedding``: a combined text per row, in the same
      order as ``data``;
    * ``rag_faiss_index``: an in-memory ``IndexFlatL2`` over the embeddings;
    * ``rag_metadata``: maps FAISS row positions to indices in ``data``;
    * ``business_info_available``: True only when the index was built.

    Rows whose embedding is NULL or unreadable are skipped, which keeps
    ``data`` and the index aligned. Failures are reported via print() and
    leave RAG marked unavailable.

    NOTE(review): relies on module globals ``embedder`` and ``faiss`` that
    this file does not define. They are presumably provided by the host
    application the code is integrated into.
    """
    global data, descriptions_for_embedding, business_info_available
    global rag_faiss_index, rag_metadata

    # Start from a clean "unavailable" state so every early return leaves
    # consistent globals behind.
    business_info_available = False
    rag_faiss_index = None
    rag_metadata = []
    data = []
    descriptions_for_embedding = []

    def report_unavailable(*messages):
        for message in messages:
            print(message)
        print("Business information retrieval (RAG) is NOT available.")

    print("Attempting to load RAG data from PostgreSQL 'manual' table...")

    # RAG cannot answer queries without the embedder, so check for it before
    # opening a database connection that would only be closed again.
    if globals().get("embedder") is None:
        print("Embedder not initialized. Cannot load RAG data embeddings.")
        return

    db_conn = connect_to_supabase()
    if db_conn is None:
        print("Failed to connect to database. RAG will be unavailable.")
        return

    try:
        with db_conn.cursor() as cur:
            # Ensure the vector type is available in this database.
            cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
            db_conn.commit()
            # Capitalised column names must be double-quoted.
            cur.execute(f"""
                SELECT "Service", "Description", "Price", "Available", embedding
                FROM {BUSINESS_DATA_TABLE};
            """)
            db_records = cur.fetchall()
    except Exception as e:
        print(f"An error occurred while accessing the database for RAG data: {e}")
        print(traceback.format_exc())
        return
    finally:
        # All rows are in memory (or the query failed); release the connection.
        db_conn.close()

    if not db_records:
        report_unavailable(f"Warning: No data found in table '{BUSINESS_DATA_TABLE}'. RAG will be unavailable.")
        return
    print(f"Loaded {len(db_records)} records from '{BUSINESS_DATA_TABLE}'.")

    rows, combined_texts, vectors = [], [], []
    skipped = 0
    for service, description, price, available, raw_embedding in db_records:
        try:
            vector = _parse_pgvector(raw_embedding)
        except (TypeError, ValueError) as e:
            print(f"Unreadable embedding for Service '{service}': {e}")
            vector = None
        if vector is None:
            skipped += 1
            continue
        rows.append({'Service': service, 'Description': description, 'Price': price, 'Available': available})
        # Combined text, used as the passage when re-ranking retrieval results.
        combined_texts.append(
            f"Service: {(service or '').strip()}. Description: {(description or '').strip()}. "
            f"Price: {(price or '').strip()}. Available: {(available or '').strip()}."
        )
        vectors.append(vector)
    if skipped:
        print(f"Skipped {skipped} record(s) with a missing or unreadable embedding.")

    data = rows
    descriptions_for_embedding = combined_texts

    if not vectors:
        report_unavailable("No valid data or embeddings to build FAISS index. RAG will be unavailable.")
        return

    wrong_dims = sorted({v.shape[0] for v in vectors if v.shape[0] != EMBEDDING_DIM})
    if wrong_dims:
        report_unavailable(
            f"Error: Embedding dimension mismatch. Expected {EMBEDDING_DIM}, got {wrong_dims}.",
            "This might happen if the embeddings in the database were generated with a different model or dimension.",
            "RAG will be unavailable.",
        )
        return

    faiss_module = globals().get("faiss")
    if faiss_module is None:
        report_unavailable("Error: the 'faiss' module is not available in this module's globals. RAG will be unavailable.")
        return

    print("Building in-memory FAISS index...")
    try:
        index = faiss_module.IndexFlatL2(EMBEDDING_DIM)  # exact L2 (Euclidean) search
        # Stack into one (n_rows, EMBEDDING_DIM) float32 matrix.
        index.add(np.vstack(vectors))
    except Exception as e:
        print(f"Error during FAISS index building: {e}")
        print(traceback.format_exc())
        report_unavailable()
        return

    rag_faiss_index = index
    # FAISS row i corresponds to data[i].
    rag_metadata = list(range(len(data)))
    business_info_available = True
    print(f"In-memory FAISS index built. Index size: {rag_faiss_index.ntotal}")
    print("Business information retrieval (RAG) is available using in-memory FAISS index from DB data.")
# --- Update retrieve_business_info to use data structure from 'manual' table ---
# The core logic of retrieve_business_info using FAISS search on in-memory data remains the same.
# However, the structure of the 'data' list it accesses now comes from the 'manual' table columns.
# The retrieval function already handles accessing 'Service' and 'Description' from the dictionary.
# If you need to return Price or Available, you can adjust the return format.
# For now, assuming it returns the dictionary as loaded into the 'data' list.
def retrieve_business_info(query: str, top_n: int = 3) -> list:
    """Return up to ``top_n`` business records most relevant to ``query``.

    Embeds ``query`` with the global ``embedder`` and searches the in-memory
    FAISS index (``rag_faiss_index``) built over the ``data`` records. If a
    global ``reranker`` is present, the hits are re-ordered by
    ``reranker.predict`` scores computed on ``(query, description)`` pairs,
    highest first. Otherwise the FAISS order is kept.

    The module globals are read with ``globals().get``, so calling this
    before the RAG data has been loaded returns ``[]`` instead of raising
    NameError.

    Args:
        query: free-text user query.
        top_n: maximum number of records to return; values <= 0 yield ``[]``.

    Returns:
        A list of record dicts (keys "Service", "Description", "Price",
        "Available" as loaded), or ``[]`` when retrieval is unavailable or
        fails. Errors are reported via print().
    """
    module_globals = globals()
    embedder_model = module_globals.get("embedder")
    index = module_globals.get("rag_faiss_index")
    records = module_globals.get("data") or []
    metadata = module_globals.get("rag_metadata") or []
    descriptions = module_globals.get("descriptions_for_embedding") or []

    if (
        not module_globals.get("business_info_available", False)
        or embedder_model is None
        or index is None
        or index.ntotal == 0
        or not records
        or not metadata
        or len(metadata) != len(records)
    ):
        print("Business information retrieval is not available, RAG index is empty, or data/metadata mismatch.")
        return []
    if top_n <= 0:
        # FAISS cannot search for zero neighbours; nothing was requested.
        return []

    try:
        query_embedding = embedder_model.encode(query, convert_to_tensor=False)
        query_matrix = np.asarray([query_embedding], dtype="float32")
        _, neighbour_ids = index.search(query_matrix, min(top_n, index.ntotal))

        # FAISS pads missing neighbours with -1; map valid hits back to data.
        original_indices = [metadata[i] for i in neighbour_ids[0] if 0 <= i < len(metadata)]
        top_results = [records[i] for i in original_indices]

        reranker_model = module_globals.get("reranker")
        # Re-ranking needs one description per data record.
        can_rerank = reranker_model is not None and bool(top_results) and len(descriptions) == len(records)
        if not can_rerank:
            print("Skipping re-ranking: Reranker not available or no results.")
            return top_results

        print("Re-ranking top results...")
        rerank_pairs = [(query, descriptions[i]) for i in original_indices]
        rerank_scores = reranker_model.predict(rerank_pairs)
        order = sorted(range(len(top_results)), key=lambda k: rerank_scores[k], reverse=True)
        print("Re-ranking complete.")
        return [top_results[k] for k in order]
    except Exception as e:
        print(f"Error during business information retrieval (FAISS search/re-ranking): {e}")
        print(traceback.format_exc())
        return []
# --- Update log_conversation to log to PostgreSQL conversation_history table ---
# This function was already updated in a previous step to log to the DB.
# Ensure the table name used here matches CONVERSATION_HISTORY_TABLE.
# Assuming CONVERSATION_HISTORY_TABLE is defined globally.
# def log_conversation(user_query: str, model_response: str, tool_details: dict = None, user_id: str = None, model_used: str = None):
# """
# Logs conversation data (query, response, timestamp, optional details) to the PostgreSQL database.
# """
# print("\n--- Attempting to log conversation to PostgreSQL Database ---")
# db_conn = connect_to_supabase() # Use the Supabase connection function
# if db_conn is None:
# print("Warning: Failed to connect to database. Skipping conversation logging.")
# return
# try:
# timestamp = datetime.now().astimezone().isoformat() # Use astimezone() for timezone-aware timestamp
# tool_details_json = json.dumps(tool_details) if tool_details is not None else None
# user_id_val = user_id if user_id is not None else "anonymous"
# model_used_val = model_used if model_used is not None else "unknown"
# with db_conn.cursor() as cur:
# cur.execute(f"""
# INSERT INTO {CONVERSATION_HISTORY_TABLE} (timestamp, user_id, user_query, model_response, tool_details, model_used)
# VALUES (%s, %s, %s, %s, %s, %s);
# """, (timestamp, user_id_val, user_query, model_response, tool_details_json, model_used_val))
# db_conn.commit()
# print("Conversation data successfully logged to PostgreSQL.")
# except Exception as e:
# print(f"An unexpected error occurred during database conversation logging: {e}")
# print(traceback.format_exc())
# if db_conn:
# db_conn.rollback()
# finally:
# if db_conn:
# db_conn.close()
# --- Update load_conversation_history to load from PostgreSQL conversation_history table ---
# This function was already updated in a previous step to load from the DB.
# Ensure the table name used here matches CONVERSATION_HISTORY_TABLE.
# Assuming CONVERSATION_HISTORY_TABLE is defined globally.
# def load_conversation_history(api_key: str) -> list[dict]:
# """Loads conversation history for a given API key from the PostgreSQL database."""
# user_id_to_load = api_key if api_key is not None else "anonymous"
# print(f"Attempting to load conversation history for user '{user_id_to_load}' from PostgreSQL...")
# history = []
# db_conn = connect_to_supabase() # Use the Supabase connection function
# if db_conn is None:
# print("Warning: Failed to connect to database. Cannot load conversation history.")
# return history # Return empty history on failure
# try:
# with db_conn.cursor() as cur:
# # Retrieve history ordered by timestamp for a specific user
# cur.execute(f"""
# SELECT user_query, model_response
# FROM {CONVERSATION_HISTORY_TABLE}
# WHERE user_id = %s
# ORDER BY timestamp;
# """, (user_id_to_load,))
# db_records = cur.fetchall()
# # Format the history as a list of dictionaries for compatibility with chat function
# for user_query, model_response in db_records:
# # Add user query role
# if user_query:
# history.append({"role": "user", "content": user_query})
# # Add assistant response role
# if model_response:
# history.append({"role": "assistant", "content": model_response})
# print(f"Loaded {len(history)} turns of conversation history for user '{user_id_to_load}' from PostgreSQL.")
# except Exception as e:
# print(f"Error loading conversation history from database: {e}")
# print(traceback.format_exc())
# history = [] # Ensure empty history is returned on error
# finally:
# if db_conn:
# db_conn.close()
# return history
# --- Main Application Startup Block (__main__) ---
# This block assumes it's part of the larger application script in the Hugging Face Space
# It needs to initialize global resources and then potentially launch a Gradio interface.
# Remove the separate data insertion script execution from this block.
# The data insertion is a one-time or separate process.
# if __name__ == "__main__":
# print("Starting main application startup...")
# # 1. Load/Create Hugging Face Dataset (still used for other logging if needed)
# # ... (existing code for HF dataset loading remains)
# # 2. Authenticate and Load Business Info from PostgreSQL (updated function)
# # This function now handles connecting to DB and loading data/embeddings into memory
# load_business_info()
# # 3. Initialize other necessary global variables/clients
# # (e.g., nlp, embedder, reranker, primary_client, fallback_client)
# # These need to be initialized after load_business_info if embedder/reranker are used by it
# # Assuming embedder and reranker are initialized here or earlier in the full script:
# # try:
# # embedder = SentenceTransformer("paraphrase-MiniLM-L6-v2")
# # print("Sentence Transformer (embedder) initialized.")
# # except Exception as e:
# # print(f"Error initializing embedder: {e}")
# # embedder = None
# # try:
# # reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L6-v2')
# # print("Cross-Encoder (reranker) initialized.")
# # except Exception as e:
# # print(f"Error initializing reranker: {e}")
# # reranker = None
# # try:
# # nlp = spacy.load("en_core_web_sm") # Assuming spacy is imported
# # print("SpaCy model initialized.")
# # except Exception as e:
# # print(f"Error initializing SpaCy model: {e}")
# # nlp = None
# # try:
# # primary_client = InferenceClient("meta-llama/Llama-3.3-70B-Instruct", token=HF_TOKEN) # Assuming InferenceClient and HF_TOKEN
# # print("Primary LLM client initialized.")
# # except Exception as e:
# # print(f"Error initializing primary client: {e}")
# # primary_client = None
# # try:
# # fallback_client = InferenceClient("meta-llama/Llama-3.3-70B-Instruct", token=HF_TOKEN) # Assuming InferenceClient and HF_TOKEN
# # print("Fallback LLM client initialized.")
# # except Exception as e:
# # print(f"Error initializing fallback client: {e}")
# # fallback_client = None
# # 4. Check RAG availability (based on load_business_info results)
# # Check business_info_available and rag_faiss_index which are set by load_business_info
# if not business_info_available or rag_faiss_index is None:
# print("Warning: Business information (PostgreSQL data) not loaded successfully or RAG index not built. RAG will not be available.")
# # 5. Initialize the general query cache (still uses local files)
# # Assuming initialize_general_cache is defined globally
# # initialize_general_cache()
# # 6. Launch Gradio Interface (assuming gr and chat are defined globally)
# # ... (Gradio interface setup and launch code)
# Note: The provided code block contains the updated function definitions.
# These need to be integrated into the complete application script in your Hugging Face Space.
# The __main__ block structure is commented out as a guide for integration.