# text2sql-demo / app.py
# Author: trminhnam20082002
# Commit message: feat: add main menu options and scripts
# Commit: 8621811
import os
import sys
import time
# Prepend this script's own directory (not the working directory) to sys.path so modules stored next to app.py are importable.
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
import re
import sqlite3
import numpy as np
import pandas as pd
import streamlit as st
import requests
from googletrans import Translator
from langdetect import detect
from sql_formatter.core import format_sql
# Shared googletrans client used by translate_question().
translator = Translator()

# Page-wide Streamlit settings.
st.set_page_config(
    page_title="Text To SQL",
    page_icon="📊",
    layout="wide",
)

# Endpoint of the remote Text-to-SQL service. It can be overridden with the
# TEXT_2_SQL_API environment variable; otherwise the hard-coded default is used.
TEXT_2_SQL_API = os.environ.get(
    "TEXT_2_SQL_API",
    "http://213.181.122.2:40057/api/text2sql/ask",
)
@st.cache_resource
def load_database(db_path="resources/ai_app.db", schema_path="resources/schema.sql"):
    """Open the demo SQLite database and run the bundled schema script on it.

    The connection is cached with ``st.cache_resource``, so the same
    connection object is reused across script reruns instead of being
    reopened on every interaction.

    Args:
        db_path: Path of the SQLite database file, relative to the working
            directory.
        schema_path: Path of a SQL script executed on the new connection.

    Returns:
        An open ``sqlite3.Connection``.
    """
    # The cached connection is used by later script runs, which Streamlit
    # executes on different threads. With the default check_same_thread=True
    # every such use raises sqlite3.ProgrammingError, so the check is disabled.
    # Concurrent writes from several sessions are not serialized here.
    connection = sqlite3.connect(db_path, check_same_thread=False)
    with open(schema_path, "r", encoding="utf-8") as schema_file:
        # NOTE(review): schema.sql is not visible here. If it uses plain
        # CREATE TABLE (no IF NOT EXISTS), this fails once the database file
        # already contains the tables. Confirm.
        connection.executescript(schema_file.read())
    return connection


# Module-level handle used by execute_sql(); re-fetched on every rerun.
db_conn = load_database()
def execute_sql(sql_query):
    """Run ``sql_query`` on the shared demo database and return every row.

    On success a confirmation message is shown on the page. Any failure is
    reported on the page instead of being raised. Callers treat a ``None``
    return as "try generating another query".

    Args:
        sql_query: A single SQL statement.

    Returns:
        The list of row tuples from ``cursor.fetchall()``, or ``None`` if the
        statement could not be executed.
    """
    cursor = None
    try:
        cursor = db_conn.cursor()
        cursor.execute(sql_query)
        # fetchall() can still fail (errors may surface while stepping rows),
        # so success is only reported after it returns.
        rows = cursor.fetchall()
    except Exception as exc:  # best-effort: report every failure, never crash the page
        # Show the real error so a schema mismatch can be told apart from a
        # syntax error or a connection problem.
        st.info(f"Database is not supported: {exc}")
        return None
    finally:
        if cursor is not None:
            cursor.close()
    st.success("SQL query executed successfully!")
    return rows
# Not cached (the st.cache_data decorator was disabled), so every retry made
# by the callers sends a fresh request to the service.
def ask_text2sql(question, context, timeout=(10, 300)):
    """Ask the remote Text-to-SQL service to turn ``question`` into SQL.

    Questions detected as non-English are first translated to English with
    ``translate_question``.

    Args:
        question: Natural-language question, in any language.
        context: Schema text (e.g. ``CREATE TABLE`` statements) sent with it.
        timeout: ``requests`` timeout in seconds, either a single number or a
            ``(connect, read)`` tuple. The default allows 10 s to connect and
            300 s for the model to answer, so an unreachable service cannot
            hang the page forever.

    Returns:
        The first element of the ``"answers"`` list in the JSON response.

    Raises:
        requests.RequestException: on connection errors, timeouts, or a
            non-2xx HTTP status.
    """
    from langdetect.lang_detect_exception import LangDetectException

    try:
        language = detect(question)
    except LangDetectException:
        # detect() raises when the text has no usable features (e.g. empty,
        # digits or punctuation only); send such input untranslated.
        language = "en"
    if language != "en":
        question = translate_question(question)

    response = requests.post(
        TEXT_2_SQL_API,
        json={
            "context": context,
            "question": question,
        },
        timeout=timeout,
    )
    # Fail with a clear HTTP error instead of a KeyError/JSON decode error.
    response.raise_for_status()
    return response.json()["answers"][0]
@st.cache_data
def translate_question(question):
    """Return ``question`` translated into English.

    Uses the module-level googletrans ``translator``. The result is memoised
    with ``st.cache_data``, so repeating a question reuses the cached text.
    """
    translation = translator.translate(question, dest="en")
    return translation.text
@st.cache_data
def load_example_df():
    """Load the example (context, question) rows shown on the Examples page.

    Cached with ``st.cache_data`` so the CSV is parsed once, not on every rerun.
    """
    return pd.read_csv("resources/examples.csv")
def introduction():
    """Render the landing page: what the model is and which pages the app has.

    The page list mirrors the entries of the sidebar menu.
    """
    st.title("📊 Introduction")
    st.write("👋 Welcome to the Text to SQL app!")
    st.write(
        "🔍 This app allows you to explore the ability of a Text to SQL model. The model is CodeLlama-13b finetuned using QLoRA on the NSText2SQL dataset."
    )
    st.write(
        "📈 The NSText2SQL dataset contains more than 290,000 training samples. The model is evaluated on the Spider and vMLP datasets."
    )
    st.write("📑 The other pages in this app include:")
    st.write(
        " - 🧪 Examples Page: This page runs the model on predefined questions and database schemas."
    )
    st.write(
        " - 💬 Interactive Demo Page: This page allows you to generate SQL queries from your own question and context."
    )
    st.write(
        " - 🧑‍💻 About Page: This page provides information about the app and its creators."
    )
    st.write(
        " - 📚 References Page: This page lists the references used in building this app."
    )
    st.write(
        " - 🗄️ Database Query Page: This page allows you to run SQL queries directly against the demo database."
    )
# Define a function for the EDA page
def eda():
    """Render the dataset-exploration (EDA) page.

    Placeholder: only the page title is rendered for now.
    """
    # TODO: add visualizations for the Spider and vMLP datasets. The previous
    # body was commented-out Tesla stock-price charts (using undefined
    # ``df``/``go``/``px``) copied from another project and has been removed.
    st.title("📊 Dataset Exploration")
def preprocess_context(context):
    """Flatten ``context`` onto a single line.

    Newline, tab and carriage-return characters each become a space. Every
    run of consecutive spaces is then squeezed into one. Leading and trailing
    spaces are kept.
    """
    flattened = context.translate({ord("\n"): " ", ord("\t"): " ", ord("\r"): " "})
    return re.sub(r" {2,}", " ", flattened)
def examples():
    """Render the Examples page.

    Every row of the example CSV (columns ``context`` and ``question``) gets
    its own tab with the schema, the question and a "Generate SQL query"
    button. Clicking the button asks the Text-to-SQL service for a query.
    When "Execute SQL query" is ticked in the sidebar, the query is also run
    on the demo database. If it fails, a new query is generated, up to
    "Number of tries" attempts in total.
    """
    st.title("Examples")
    st.write(
        "This page uses CodeLlama-13b finetuned using QLoRA on NSText2SQL dataset to generate SQL query from a given question and context.\nThe examples are listed below"
    )
    example_df = load_example_df()
    if example_df.empty:
        # st.tabs() rejects an empty list of labels.
        st.info("No examples available.")
        return

    with st.sidebar:
        # create a blank space
        st.write("")
        st.write("")
        st.write("")
        execute_sql_query = st.checkbox(
            "Execute SQL query",
        )
        num_tries = st.number_input(
            "Number of tries",
            value=3,
            min_value=1,
            max_value=10,
            step=1,
        )

    def generate_and_show(question, context):
        """Query the model once, display the formatted SQL with its latency, return the raw SQL."""
        start_time = time.perf_counter()
        query = ask_text2sql(question, context)
        elapsed = time.perf_counter() - start_time
        st.write(
            "The SQL query generated by the model in **{:.2f}s** is:".format(elapsed)
        )
        st.code(format_sql(query), language="sql")
        return query

    tab_labels = [f"Example {i}" for i in range(1, len(example_df) + 1)]
    # Pair tabs and rows by position, so a non-default DataFrame index cannot
    # break the lookup.
    for position, (tab, (_, row)) in enumerate(
        zip(st.tabs(tab_labels), example_df.iterrows())
    ):
        with tab:
            st.markdown("##### Context:")
            st.code(row["context"], language="sql")
            st.markdown("##### Question:")
            st.text(row["question"])
            if st.button("Generate SQL query", key=f"exp-btn-{position}"):
                st.markdown("##### SQL query:")
                with st.spinner("Generating SQL query..."):
                    if not execute_sql_query:
                        generate_and_show(row["question"], row["context"])
                    else:
                        for _ in range(int(num_tries)):
                            query = generate_and_show(row["question"], row["context"])
                            result = execute_sql(query)
                            # execute_sql reports failures itself and returns None;
                            # only show a result table for a successful run.
                            if result is not None:
                                st.write(
                                    "Executing the SQL query yields the following result:"
                                )
                                st.dataframe(pd.DataFrame(result), hide_index=True)
                                break
                        else:
                            st.warning(
                                f"No executable SQL query was generated after {int(num_tries)} tries."
                            )
# Define a function for the Interactive Demo page
def interactive_demo():
    """Render the Interactive Demo page.

    The user edits a schema ("Context") and a question, then presses
    "Generate SQL query". The pair is sent to the Text-to-SQL service and the
    generated SQL is displayed. When "Execute SQL query" is ticked in the
    sidebar, the query is also run on the demo database. If it fails, a new
    query is generated, up to "Number of tries" attempts in total.
    """
    st.title("Text to SQL using CodeLlama-13b")
    st.write(
        "This page uses CodeLlama-13b finetuned using QLoRA on NSText2SQL dataset to generate SQL query from a given question and context."
    )
    st.subheader("Input")
    # Default schema for the Context box. Every statement is closed with ")"
    # and terminated with ";" so the model receives valid SQL.
    default_context = (
        "CREATE TABLE customer (id number, name text, gender text, age number, district_id number);\n"
        "CREATE TABLE registration (customer_id number, product_id number);\n"
        "CREATE TABLE district (id number, name text, prefix text, province_id number);\n"
        "CREATE TABLE province (id number, name text, code text);\n"
        "CREATE TABLE product (id number, category text, name text, description text, price number, duration number, data_amount number, voice_amount number, sms_amount number);"
    )
    context = st.text_area(
        "##### Context",
        default_context,
        key="context",
        height=150,
    )
    question = st.text_input(
        "##### Question",
        "Số lượng khách hàng có độ tuổi từ 30 đến 45 tuổi?",
        key="question",
    )
    get_sql_button = st.button("Generate SQL query")
    with st.sidebar:
        # create a blank space
        st.write("")
        st.write("")
        st.write("")
        execute_sql_query = st.checkbox(
            "Execute SQL query",
        )
        num_tries = st.number_input(
            "Number of tries",
            value=3,
            min_value=1,
            max_value=10,
            step=1,
        )
    if not get_sql_button:
        return

    def generate_and_show():
        """Query the model once, display the formatted SQL with its latency, return the raw SQL."""
        start_time = time.perf_counter()
        query = ask_text2sql(question, context)
        elapsed = time.perf_counter() - start_time
        st.write(
            "The SQL query generated by the model in **{:.2f}s** is:".format(elapsed)
        )
        # Display the SQL query in a code block
        st.code(format_sql(query), language="sql")
        return query

    st.markdown("##### Output")
    with st.spinner("Generating SQL query..."):
        if not execute_sql_query:
            generate_and_show()
            return
        for _ in range(int(num_tries)):
            result = execute_sql(generate_and_show())
            # execute_sql reports failures itself and returns None; only show
            # a result table for a successful run.
            if result is not None:
                st.write("Executing the SQL query yields the following result:")
                st.dataframe(pd.DataFrame(result), hide_index=True)
                break
        else:
            st.warning(
                f"No executable SQL query was generated after {int(num_tries)} tries."
            )
def database_query_page():
    """Render a page that runs a hand-written SQL query on the demo database.

    The text typed by the user is passed unchanged to ``execute_sql``, so any
    statement the connection accepts can be run from here.
    """
    st.title("🗄️ Database Query")
    query = st.text_input("Enter SQL query")
    if st.button("Execute"):
        result = execute_sql(query)
        # execute_sql returns None after reporting an error; only show a table on success.
        if result is not None:
            st.dataframe(pd.DataFrame(result), hide_index=True)
# Define a function for the About page
def about():
    """Render the About page: purpose, author, data sources, acknowledgments, license."""
    st.title("🧑‍💻 About")
    st.write(
        "This Streamlit app allows you to generate SQL queries from natural-language questions using CodeLlama-13b finetuned with QLoRA on the NSText2SQL dataset."
    )
    st.header("Author")
    # NOTE(review): "[email protected]" looks like an email address that was
    # obfuscated when this file was copied from a web page; restore the real
    # contact address.
    st.write(
        "This app was developed by Minh Nam. You can contact the author at [email protected]."
    )
    st.header("Data Sources")
    st.markdown(
        "The Spider dataset was sourced from [Spider](https://yale-lily.github.io/spider)."
    )
    st.markdown("The vMLP dataset is a private dataset from Viettel.")
    st.header("Acknowledgments")
    st.write(
        "The author would like to thank Dr. Nguyen Van Nam for his proper guidance, Mr. Nguyen Chi Dong for his support."
    )
    st.header("License")
    # No license has been chosen yet (an MIT notice was drafted but disabled).
    st.write("N/A")
def references():
    """Render the References page: the model, fine-tuning method, datasets, and Streamlit."""
    st.title("📚 References")
    st.header(
        "References for Text to SQL project using foundation model - CodeLlama-13b"
    )
    st.subheader(
        "1. 'Code Llama: Open Foundation Models for Code' by Baptiste Rozière, et al."
    )
    st.write(
        "This paper introduces Code Llama, the family of code models that CodeLlama-13b belongs to."
    )
    st.write("Link: https://arxiv.org/abs/2308.12950")
    st.subheader(
        "2. 'QLoRA: Efficient Finetuning of Quantized LLMs' by Tim Dettmers, et al."
    )
    st.write(
        "This paper describes QLoRA, the method used to finetune CodeLlama-13b in this project."
    )
    st.write("Link: https://arxiv.org/abs/2305.14314")
    st.subheader("3. 'NSText2SQL' dataset by NumbersStation")
    st.write("This is the dataset used to finetune the model.")
    st.write("Link: https://huggingface.co/datasets/NumbersStation/NSText2SQL")
    st.subheader(
        "4. 'Spider: A Large-Scale Human-Labeled Dataset for Complex and Cross-Domain Semantic Parsing and Text-to-SQL Task' by Tao Yu, et al."
    )
    st.write("This is one of the datasets used to evaluate the model.")
    st.write("Link: https://yale-lily.github.io/spider")
    st.header("References for Streamlit")
    st.subheader("1. Streamlit Documentation")
    st.write(
        "The official documentation for Streamlit provides detailed information about how to use the library and build Streamlit apps."
    )
    st.write("Link: https://docs.streamlit.io/")
    st.subheader("2. Streamlit Community")
    st.write(
        "The Streamlit community includes a forum and a GitHub repository with examples and resources for building Streamlit apps."
    )
    st.write(
        "Link: https://discuss.streamlit.io/ and https://github.com/streamlit/streamlit/"
    )
# Create the sidebar
st.sidebar.title("Menu")

# Menu label -> page render function, in menu order. The EDA page (`eda`) is
# intentionally not listed (its "Datasets" menu entry was disabled); add an
# entry here to re-enable it.
page_renderers = {
    "Introduction": introduction,
    "Examples": examples,
    "Interactive Demo": interactive_demo,
    "About": about,
    "References": references,
    "Database Query": database_query_page,
}
pages = list(page_renderers)
selected_page = st.sidebar.radio("Go to", pages)
if st.sidebar.button("Clear All"):
    # Clear every st.cache_resource and st.cache_data cache, then rerun the script.
    st.cache_resource.clear()
    st.cache_data.clear()
    st.rerun()
# Show the appropriate page based on the selection
page_renderers[selected_page]()