simonraj commited on
Commit
e4e42e1
1 Parent(s): eebea6d

Update app.py

Browse files

Langsmith trace

Files changed (1) hide show
  1. app.py +179 -168
app.py CHANGED
@@ -1,169 +1,180 @@
1
- import os
2
- import re
3
- import gradio as gr
4
- from dotenv import load_dotenv
5
- from langchain_community.utilities import SQLDatabase
6
- from langchain_openai import ChatOpenAI
7
- from langchain.chains import create_sql_query_chain
8
- from langchain_core.output_parsers import StrOutputParser
9
- from langchain_core.prompts import ChatPromptTemplate
10
- from langchain_core.runnables import RunnablePassthrough
11
- from langchain_core.output_parsers.openai_tools import PydanticToolsParser
12
- from langchain_core.pydantic_v1 import BaseModel, Field
13
- from typing import List
14
- import sqlite3
15
-
16
- # Load environment variables from .env file
17
- load_dotenv()
18
-
19
- # Set up the database connection
20
- db_path = os.path.join(os.path.dirname(__file__), "chinook.db")
21
- db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
22
-
23
- # Function to get table info
24
- def get_table_info(db_path):
25
- conn = sqlite3.connect(db_path)
26
- cursor = conn.cursor()
27
-
28
- # Get all table names
29
- cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
30
- tables = cursor.fetchall()
31
-
32
- table_info = {}
33
- for table in tables:
34
- table_name = table[0]
35
- cursor.execute(f"PRAGMA table_info({table_name})")
36
- columns = cursor.fetchall()
37
- column_names = [column[1] for column in columns]
38
- table_info[table_name] = column_names
39
-
40
- conn.close()
41
- return table_info
42
-
43
- # Get table info
44
- table_info = get_table_info(db_path)
45
-
46
- # Format table info for display
47
- def format_table_info(table_info):
48
- info_str = f"Total number of tables: {len(table_info)}\n\n"
49
- info_str += "Tables and their columns:\n\n"
50
- for table, columns in table_info.items():
51
- info_str += f"{table}:\n"
52
- for column in columns:
53
- info_str += f" - {column}\n"
54
- info_str += "\n"
55
- return info_str
56
-
57
- # Initialize the language model
58
- llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
59
-
60
- class Table(BaseModel):
61
- """Table in SQL database."""
62
- name: str = Field(description="Name of table in SQL database.")
63
-
64
- # Create the table selection prompt
65
- table_names = "\n".join(db.get_usable_table_names())
66
- system = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
67
- The tables are:
68
-
69
- {table_names}
70
-
71
- Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""
72
-
73
- table_prompt = ChatPromptTemplate.from_messages([
74
- ("system", system),
75
- ("human", "{input}"),
76
- ])
77
-
78
- llm_with_tools = llm.bind_tools([Table])
79
- output_parser = PydanticToolsParser(tools=[Table])
80
-
81
- table_chain = table_prompt | llm_with_tools | output_parser
82
-
83
- # Function to get table names from the output
84
- def get_table_names(output: List[Table]) -> List[str]:
85
- return [table.name for table in output]
86
-
87
- # Create the SQL query chain
88
- query_chain = create_sql_query_chain(llm, db)
89
-
90
- # Combine table selection and query generation
91
- full_chain = (
92
- RunnablePassthrough.assign(
93
- table_names_to_use=lambda x: get_table_names(table_chain.invoke({"input": x["question"]}))
94
- )
95
- | query_chain
96
- )
97
-
98
- # Function to strip markdown formatting from SQL query
99
- def strip_markdown(text):
100
- # Remove code block formatting
101
- text = re.sub(r'```sql\s*|\s*```', '', text)
102
- # Remove any leading/trailing whitespace
103
- return text.strip()
104
-
105
- # Function to execute SQL query
106
- def execute_query(query: str) -> str:
107
- try:
108
- # Strip markdown formatting before executing
109
- clean_query = strip_markdown(query)
110
- result = db.run(clean_query)
111
- return str(result)
112
- except Exception as e:
113
- return f"Error executing query: {str(e)}"
114
-
115
- # Create the answer generation prompt
116
- answer_prompt = ChatPromptTemplate.from_messages([
117
- ("system", """Given the following user question, corresponding SQL query, and SQL result, answer the user question.
118
- If there was an error in executing the SQL query, please explain the error and suggest a correction.
119
- Do not include any SQL code formatting or markdown in your response.
120
-
121
- Here is the database schema for reference:
122
- {table_info}"""),
123
- ("human", "Question: {question}\nSQL Query: {query}\nSQL Result: {result}\nAnswer:")
124
- ])
125
-
126
- # Assemble the final chain
127
- chain = (
128
- RunnablePassthrough.assign(query=lambda x: full_chain.invoke(x))
129
- .assign(result=lambda x: execute_query(x["query"]))
130
- | answer_prompt
131
- | llm
132
- | StrOutputParser()
133
- )
134
-
135
- # Function to process user input and generate response
136
- def process_input(message, history, table_info_str):
137
- response = chain.invoke({"question": message, "table_info": table_info_str})
138
- return response
139
-
140
- # Formatted table info
141
- formatted_table_info = format_table_info(table_info)
142
-
143
- # Create Gradio interface
144
- iface = gr.ChatInterface(
145
- fn=process_input,
146
- title="SQL Q&A Chatbot for Chinook Database",
147
- description="Ask questions about the Chinook music store database and get answers!",
148
- examples=[
149
- ["Who are the top 5 artists with the most albums in the database?"],
150
- ["What is the total sales amount for each country?"],
151
- ["Which employee has made the highest total sales, and what is the amount?"],
152
- ["What are the top 10 longest tracks in the database, and who are their artists?"],
153
- ["How many customers are there in each country, and what is the total sales for each?"]
154
- ],
155
- additional_inputs=[
156
- gr.Textbox(
157
- label="Database Schema",
158
- value=formatted_table_info,
159
- lines=10,
160
- max_lines=20,
161
- interactive=False
162
- )
163
- ],
164
- theme="soft"
165
- )
166
-
167
- # Launch the interface
168
- if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
169
  iface.launch()
 
1
+ import os
2
+ import re
3
+ import gradio as gr
4
+ from dotenv import load_dotenv
5
+ from langchain_community.utilities import SQLDatabase
6
+ from langchain_openai import ChatOpenAI
7
+ from langchain.chains import create_sql_query_chain
8
+ from langchain_core.output_parsers import StrOutputParser
9
+ from langchain_core.prompts import ChatPromptTemplate
10
+ from langchain_core.runnables import RunnablePassthrough
11
+ from langchain_core.output_parsers.openai_tools import PydanticToolsParser
12
+ from langchain_core.pydantic_v1 import BaseModel, Field
13
+ from typing import List
14
+ import sqlite3
15
+ from langsmith import traceable
16
+ from openai import OpenAI
17
+
18
+ # Load environment variables from .env file
19
+ load_dotenv()
20
+
21
+ # Set up LangSmith
22
+ os.environ["LANGCHAIN_TRACING_V2"] = "true"
23
+ os.environ["LANGCHAIN_API_KEY"] = os.getenv("LANGCHAIN_API_KEY")
24
+ os.environ["LANGCHAIN_PROJECT"] = "SQLq&a"
25
+
26
+ # Initialize OpenAI client
27
+ openai_client = OpenAI()
28
+
29
+ # Set up the database connection
30
+ db_path = os.path.join(os.path.dirname(__file__), "chinook.db")
31
+ db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
32
+
33
+ # Function to get table info
34
+ def get_table_info(db_path):
35
+ conn = sqlite3.connect(db_path)
36
+ cursor = conn.cursor()
37
+
38
+ # Get all table names
39
+ cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
40
+ tables = cursor.fetchall()
41
+
42
+ table_info = {}
43
+ for table in tables:
44
+ table_name = table[0]
45
+ cursor.execute(f"PRAGMA table_info({table_name})")
46
+ columns = cursor.fetchall()
47
+ column_names = [column[1] for column in columns]
48
+ table_info[table_name] = column_names
49
+
50
+ conn.close()
51
+ return table_info
52
+
53
+ # Get table info
54
+ table_info = get_table_info(db_path)
55
+
56
+ # Format table info for display
57
+ def format_table_info(table_info):
58
+ info_str = f"Total number of tables: {len(table_info)}\n\n"
59
+ info_str += "Tables and their columns:\n\n"
60
+ for table, columns in table_info.items():
61
+ info_str += f"{table}:\n"
62
+ for column in columns:
63
+ info_str += f" - {column}\n"
64
+ info_str += "\n"
65
+ return info_str
66
+
67
+ # Initialize the language model
68
+ llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
69
+
70
+ class Table(BaseModel):
71
+ """Table in SQL database."""
72
+ name: str = Field(description="Name of table in SQL database.")
73
+
74
+ # Create the table selection prompt
75
+ table_names = "\n".join(db.get_usable_table_names())
76
+ system = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
77
+ The tables are:
78
+
79
+ {table_names}
80
+
81
+ Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""
82
+
83
+ table_prompt = ChatPromptTemplate.from_messages([
84
+ ("system", system),
85
+ ("human", "{input}"),
86
+ ])
87
+
88
+ llm_with_tools = llm.bind_tools([Table])
89
+ output_parser = PydanticToolsParser(tools=[Table])
90
+
91
+ table_chain = table_prompt | llm_with_tools | output_parser
92
+
93
+ # Function to get table names from the output
94
+ def get_table_names(output: List[Table]) -> List[str]:
95
+ return [table.name for table in output]
96
+
97
+ # Create the SQL query chain
98
+ query_chain = create_sql_query_chain(llm, db)
99
+
100
+ # Combine table selection and query generation
101
+ full_chain = (
102
+ RunnablePassthrough.assign(
103
+ table_names_to_use=lambda x: get_table_names(table_chain.invoke({"input": x["question"]}))
104
+ )
105
+ | query_chain
106
+ )
107
+
108
+ # Function to strip markdown formatting from SQL query
109
+ def strip_markdown(text):
110
+ # Remove code block formatting
111
+ text = re.sub(r'```sql\s*|\s*```', '', text)
112
+ # Remove any leading/trailing whitespace
113
+ return text.strip()
114
+
115
+ # Function to execute SQL query
116
+ def execute_query(query: str) -> str:
117
+ try:
118
+ # Strip markdown formatting before executing
119
+ clean_query = strip_markdown(query)
120
+ result = db.run(clean_query)
121
+ return str(result)
122
+ except Exception as e:
123
+ return f"Error executing query: {str(e)}"
124
+
125
+ # Create the answer generation prompt
126
+ answer_prompt = ChatPromptTemplate.from_messages([
127
+ ("system", """Given the following user question, corresponding SQL query, and SQL result, answer the user question.
128
+ If there was an error in executing the SQL query, please explain the error and suggest a correction.
129
+ Do not include any SQL code formatting or markdown in your response.
130
+
131
+ Here is the database schema for reference:
132
+ {table_info}"""),
133
+ ("human", "Question: {question}\nSQL Query: {query}\nSQL Result: {result}\nAnswer:")
134
+ ])
135
+
136
+ # Assemble the final chain
137
+ chain = (
138
+ RunnablePassthrough.assign(query=lambda x: full_chain.invoke(x))
139
+ .assign(result=lambda x: execute_query(x["query"]))
140
+ | answer_prompt
141
+ | llm
142
+ | StrOutputParser()
143
+ )
144
+
145
+ # Function to process user input and generate response
146
+ @traceable
147
+ def process_input(message, history, table_info_str):
148
+ response = chain.invoke({"question": message, "table_info": table_info_str})
149
+ return response
150
+
151
+ # Formatted table info
152
+ formatted_table_info = format_table_info(table_info)
153
+
154
+ # Create Gradio interface
155
+ iface = gr.ChatInterface(
156
+ fn=process_input,
157
+ title="SQL Q&A Chatbot for Chinook Database",
158
+ description="Ask questions about the Chinook music store database and get answers!",
159
+ examples=[
160
+ ["Who are the top 5 artists with the most albums in the database?"],
161
+ ["What is the total sales amount for each country?"],
162
+ ["Which employee has made the highest total sales, and what is the amount?"],
163
+ ["What are the top 10 longest tracks in the database, and who are their artists?"],
164
+ ["How many customers are there in each country, and what is the total sales for each?"]
165
+ ],
166
+ additional_inputs=[
167
+ gr.Textbox(
168
+ label="Database Schema",
169
+ value=formatted_table_info,
170
+ lines=10,
171
+ max_lines=20,
172
+ interactive=False
173
+ )
174
+ ],
175
+ theme="soft"
176
+ )
177
+
178
+ # Launch the interface
179
+ if __name__ == "__main__":
180
  iface.launch()