Spaces:
Sleeping
Sleeping
import os | |
import sys | |
import time | |
# insert current directory to sys.path | |
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) | |
import re | |
import sqlite3 | |
import numpy as np | |
import pandas as pd | |
import streamlit as st | |
import requests | |
from googletrans import Translator | |
from langdetect import detect | |
from sql_formatter.core import format_sql | |
# Shared googletrans client; translate_question() reuses it for every call.
translator = Translator()

# Page-wide Streamlit settings: browser-tab title/icon and the wide layout.
st.set_page_config(page_title="Text To SQL", page_icon="📊", layout="wide")

# Endpoint of the text-to-SQL service. The TEXT_2_SQL_API environment
# variable overrides the built-in default below.
# Previous endpoint: http://83.219.197.235:40172/api/text2sql/ask
_DEFAULT_TEXT_2_SQL_API = "http://213.181.122.2:40057/api/text2sql/ask"
TEXT_2_SQL_API = os.getenv("TEXT_2_SQL_API", _DEFAULT_TEXT_2_SQL_API)
def load_database():
    """Open the app's SQLite database and apply the schema script to it.

    Reads ``resources/schema.sql`` and executes it against
    ``resources/ai_app.db``. Both paths are relative to the current working
    directory.

    Returns:
        sqlite3.Connection: an open connection with the schema applied.

    Raises:
        OSError: if the schema file cannot be read.
        sqlite3.Error: if the database cannot be opened or the script fails.
    """
    # Read the schema first so a missing/unreadable file does not leave an
    # open connection (or a freshly created, empty database file) behind.
    # The encoding is explicit so decoding does not depend on the locale.
    with open("resources/schema.sql", "r", encoding="utf-8") as f:
        schema_script = f.read()

    db_conn = sqlite3.connect("resources/ai_app.db")
    try:
        db_conn.executescript(schema_script)
    except Exception:
        # Do not leak the connection when the schema script is invalid.
        db_conn.close()
        raise
    return db_conn
# Shared SQLite connection used by execute_sql(); opened when this module runs.
db_conn = load_database()
def execute_sql(sql_query):
    """Run ``sql_query`` on the module-level ``db_conn`` and return its rows.

    Args:
        sql_query: SQL text to execute.

    Returns:
        The list produced by ``cursor.fetchall()``, or ``None`` when the
        query could not be executed. Callers treat ``None`` as "failed"
        (the retry loops on the demo pages rely on it).
    """
    try:
        cursor = db_conn.cursor()
        try:
            cursor.execute(sql_query)
            # Fetch before announcing success so a failure while retrieving
            # rows is not reported as a successful execution.
            rows = cursor.fetchall()
        finally:
            cursor.close()
    except Exception as e:
        # Deliberately broad: this is the UI boundary and any failure must
        # map to None. Surface the real reason instead of hiding it.
        st.warning(f"Could not execute the SQL query: {e}")
        return None

    st.success("SQL query executed successfully!")
    return rows
# @st.cache_data | |
def ask_text2sql(question, context, timeout=120):
    """Ask the text-to-SQL service for a query answering ``question``.

    A question that is not detected as English is translated to English
    with ``translate_question`` before it is sent.

    Args:
        question: natural-language question.
        context: schema text sent to the service as ``context``.
        timeout: seconds to wait for the HTTP request before giving up
            (the default is an arbitrary upper bound; tune as needed).

    Returns:
        The first element of the ``"answers"`` list in the JSON response.

    Raises:
        requests.RequestException: on connection errors, timeouts or a
            non-success HTTP status.
        KeyError, IndexError: if the response carries no answers.
    """
    from langdetect.lang_detect_exception import LangDetectException

    try:
        is_english = detect(question) == "en"
    except LangDetectException:
        # Detection fails on text with nothing to analyse (e.g. an empty
        # question); send such input unchanged instead of crashing.
        is_english = True

    if not is_english:
        question = translate_question(question)

    r = requests.post(
        TEXT_2_SQL_API,
        json={
            "context": context,
            "question": question,
        },
        # Without a timeout an unreachable service would hang the page.
        timeout=timeout,
    )
    # Fail with a clear HTTP error rather than a confusing JSON/KeyError.
    r.raise_for_status()
    return r.json()["answers"][0]
def translate_question(question):
    """Return ``question`` translated to English by the shared translator."""
    translated = translator.translate(question, dest="en")
    return translated.text
def load_example_df():
    """Load ``resources/examples.csv`` into a pandas DataFrame.

    The path is relative to the current working directory.
    """
    return pd.read_csv("resources/examples.csv")
def introduction():
    """Render the landing page: a title followed by the welcome paragraphs."""
    st.title("📊 Introduction")

    # Each entry is written as its own element, in this order.
    paragraphs = (
        "👋 Welcome to the Text to SQL app!",
        "🔍 This app allows you to explore the ability of Text to SQL model. The model is CodeLlama-13b finetuned using QLoRA on NSText2SQL dataset.",
        "📈 The NSText2SQL dataset contains more than 290.000 training samples. Then, the model is evaluated on Spider and vMLP datasets.",
        "📑 The other pages in this app include:",
        " - 📊 EDA Page: This page includes several visualizations to help you understand the two dataset: Spider and vMLP.",
        " - 💰 Text2SQL Page: This page allows you to generate SQL query from a given question and context.",
        " - 🧑💻 About Page: This page provides information about the app and its creators.",
        " - 📚 Reference Page: This page lists the references used in building this app.",
    )
    for paragraph in paragraphs:
        st.write(paragraph)
# Define a function for the EDA page | |
def eda():
    """Render the dataset-exploration page; currently only its title."""
    page_heading = "📊 Dataset Exploration"
    st.title(page_heading)
    # TODO: add Spider / vMLP visualisations here. The commented-out
    # candlestick, line-chart, histogram and raw-data code that used to
    # follow targeted a stock-price dataframe and relied on names (df, go,
    # px) that are not among this module's visible imports.
def preprocess_context(context):
    """Flatten ``context`` onto a single line.

    Newlines, tabs and carriage returns each become a space, after which
    every run of spaces is squeezed down to one space.
    """
    breaks_to_spaces = str.maketrans({"\n": " ", "\t": " ", "\r": " "})
    flattened = context.translate(breaks_to_spaces)
    return re.sub(" +", " ", flattened)
def examples():
    """Render the Examples page: one tab per row of the examples table.

    Each tab shows the example's context and question plus a button that
    asks the model for a SQL query. When "Execute SQL query" is ticked in
    the sidebar, the generated query is also run through ``execute_sql``;
    generation is retried up to "Number of tries" times until a query
    executes successfully.
    """
    st.title("Examples")
    st.write(
        "This page uses CodeLlama-13b finetuned using QLoRA on NSText2SQL dataset to generate SQL query from a given question and context.\nThe examples are listed below"
    )
    example_df = load_example_df()
    if example_df.empty:
        # st.tabs needs at least one label, so bail out early.
        st.info("No examples available.")
        return

    example_tabs = st.tabs([f"Example {i + 1}" for i in range(len(example_df))])

    with st.sidebar:
        # create a blank space
        st.write("")
        st.write("")
        st.write("")
        execute_sql_query = st.checkbox(
            "Execute SQL query",
        )
        num_tries = st.number_input(
            "Number of tries",
            value=3,
            min_value=1,
            max_value=10,
            step=1,
        )

    # Tabs are addressed by position, so enumerate the rows instead of
    # trusting the DataFrame's index labels to be 0..n-1.
    for pos, (_, row) in enumerate(example_df.iterrows()):
        with example_tabs[pos]:
            st.markdown("##### Context:")
            st.code(row["context"], language="sql")
            st.markdown("##### Question:")
            st.text(row["question"])
            if not st.button("Generate SQL query", key=f"exp-btn-{pos}"):
                continue

            st.markdown("##### SQL query:")
            # Without execution there is nothing to validate, so a single
            # generation is enough.
            attempts = int(num_tries) if execute_sql_query else 1
            with st.spinner("Generating SQL query..."):
                for _ in range(attempts):
                    start_time = time.perf_counter()
                    query = ask_text2sql(row["question"], row["context"])
                    elapsed = time.perf_counter() - start_time
                    st.markdown(
                        "The SQL query generated by the model in **{:.2f}s** is:".format(
                            elapsed
                        )
                    )
                    st.code(format_sql(query), language="sql")
                    if not execute_sql_query:
                        break
                    result = execute_sql(query)
                    if result is not None:
                        # Only present a result when execution succeeded; a
                        # failed attempt no longer shows an empty table.
                        st.write(
                            "Executing the SQL query yields the following result:"
                        )
                        st.dataframe(pd.DataFrame(result), hide_index=True)
                        break
# Define a function for the Interactive Demo page
def interactive_demo():
    """Render the Interactive Demo page.

    The user edits a schema context and a question, then asks the model
    for a SQL query. When "Execute SQL query" is ticked in the sidebar, the
    generated query is also run through ``execute_sql``; generation is
    retried up to "Number of tries" times until a query executes
    successfully.
    """
    st.title("Text to SQL using CodeLlama-13b")
    st.write(
        "This page uses CodeLlama-13b finetuned using QLoRA on NSText2SQL dataset to generate SQL query from a given question and context."
    )
    st.subheader("Input")
    context_placeholder = st.empty()
    question_placeholder = st.empty()
    # Default schema: every CREATE TABLE statement is closed with ")" and
    # terminated with ";" so the model receives well-formed DDL.
    context = context_placeholder.text_area(
        "##### Context",
        """CREATE TABLE customer (id number, name text, gender text, age number, district_id number);
CREATE TABLE registration (customer_id number, product_id number);
CREATE TABLE district (id number, name text, prefix text, province_id number);
CREATE TABLE province (id number, name text, code text);
CREATE TABLE product (id number, category text, name text, description text, price number, duration number, data_amount number, voice_amount number, sms_amount number);""",
        key="context",
        height=150,
    )
    question = question_placeholder.text_input(
        "##### Question",
        "Số lượng khách hàng có độ tuổi từ 30 đến 45 tuổi?",
        key="question",
    )
    get_sql_button = st.button("Generate SQL query")

    with st.sidebar:
        # create a blank space
        st.write("")
        st.write("")
        st.write("")
        execute_sql_query = st.checkbox(
            "Execute SQL query",
        )
        num_tries = st.number_input(
            "Number of tries",
            value=3,
            min_value=1,
            max_value=10,
            step=1,
        )

    if not get_sql_button:
        return

    if not question.strip():
        # Nothing to ask the model; avoid a pointless (or failing) request.
        st.warning("Please enter a question.")
        return

    st.markdown("##### Output")
    # Without execution there is nothing to validate, so a single
    # generation is enough.
    attempts = int(num_tries) if execute_sql_query else 1
    for _ in range(attempts):
        start_time = time.perf_counter()
        query = ask_text2sql(question, context)
        elapsed = time.perf_counter() - start_time
        st.markdown(
            "The SQL query generated by the model in **{:.2f}s** is:".format(elapsed)
        )
        # Display the SQL query in a code block
        st.code(format_sql(query), language="sql")
        if not execute_sql_query:
            break
        result = execute_sql(query)
        if result is not None:
            # Only present a result when execution succeeded; a failed
            # attempt no longer shows an empty table.
            st.write("Executing the SQL query yields the following result:")
            st.dataframe(pd.DataFrame(result), hide_index=True)
            break
def database_query_page():
    """Render a page that runs a user-typed SQL statement and shows its rows.

    NOTE(security): the text is passed verbatim to ``execute_sql``, so
    whoever can open this page can run any statement the connection
    accepts. Keep the page away from untrusted users or restrict it.
    """
    query = st.text_input("Enter SQL query")
    if not st.button("Execute"):
        return

    if not query.strip():
        st.warning("Please enter a SQL query.")
        return

    result = execute_sql(query)
    # execute_sql returns None on failure (and reports it itself); only
    # render a table when there is an actual result.
    if result is not None:
        st.dataframe(pd.DataFrame(result), hide_index=True)
# Define a function for the About page | |
def about():
    """Render the About page: purpose, author, data sources, thanks, license."""
    st.title("🧑💻 About")
    # The description now matches this app (text-to-SQL with CodeLlama-13b)
    # rather than the stock-price/LSTM text it previously carried.
    st.write(
        "This Streamlit app allows you to generate SQL queries from natural-language questions using a CodeLlama-13b model finetuned on the NSText2SQL dataset."
    )
    st.header("Author")
    # NOTE(review): the address below looks like a redaction placeholder;
    # confirm the real contact address.
    st.write(
        "This app was developed by Minh Nam. You can contact the author at [email protected]."
    )
    st.header("Data Sources")
    st.markdown(
        "The Spider dataset was sourced from [Spider](https://yale-lily.github.io/spider)."
    )
    st.markdown("The vMLP dataset is a private dataset from Viettel.")
    st.header("Acknowledgments")
    st.write(
        "The author would like to thank Dr. Nguyen Van Nam for his proper guidance, Mr. Nguyen Chi Dong for his support."
    )
    st.header("License")
    st.write(
        # "This app is licensed under the MIT License. See LICENSE.txt for more information."
        "N/A"
    )
def references():
    """Render the References page from a nested table of sections.

    Each section is a header followed by entries; every entry is rendered
    as a subheader, a description paragraph and a link paragraph.
    """
    st.title("📚 References")

    sections = (
        (
            "References for Text to SQL project using foundation model - CodeLlama-13b",
            (
                (
                    "1. 'Project for time-series data' by AI VIET NAM, et al.",
                    "This organization provides a tutorial on how to build a stock price prediction model using LSTM in the AIO2022 course.",
                    "Link: https://www.facebook.com/aivietnam.edu.vn",
                ),
                (
                    "2. 'PyTorch LSTMs for time series forecasting of Indian Stocks' by Vinayak Nayak",
                    "This blog post describes how to build a stock price prediction model using LSTM, RNN and CNN-sliding window model.",
                    "Link: https://medium.com/analytics-vidhya/pytorch-lstms-for-time-series-forecasting-of-indian-stocks-8a49157da8b9#b052",
                ),
            ),
        ),
        (
            "References for Streamlit",
            (
                (
                    "1. Streamlit Documentation",
                    "The official documentation for Streamlit provides detailed information about how to use the library and build Streamlit apps.",
                    "Link: https://docs.streamlit.io/",
                ),
                (
                    "2. Streamlit Community",
                    "The Streamlit community includes a forum and a GitHub repository with examples and resources for building Streamlit apps.",
                    "Link: https://discuss.streamlit.io/ and https://github.com/streamlit/streamlit/",
                ),
            ),
        ),
    )

    for section_title, entries in sections:
        st.header(section_title)
        for entry_title, description, link in entries:
            st.subheader(entry_title)
            st.write(description)
            st.write(link)
# Sidebar navigation. Each menu label maps to the function that renders the
# page, so the menu and the dispatch below cannot drift apart (the previous
# if/elif chain had an "EDA" branch no menu entry could ever select).
_PAGE_RENDERERS = {
    "Introduction": introduction,
    # "Datasets": eda,  # hidden until the dataset page has content
    "Examples": examples,
    "Interactive Demo": interactive_demo,
    "About": about,
    "References": references,
    "Database Query": database_query_page,
}

# Create the sidebar
st.sidebar.title("Menu")
pages = list(_PAGE_RENDERERS)  # dict order is the order shown in the menu
selected_page = st.sidebar.radio("Go to", pages)

if st.sidebar.button("Clear All"):
    # Clear both Streamlit caches, then restart the script run.
    st.cache_resource.clear()
    st.cache_data.clear()
    st.rerun()

# Show the appropriate page based on the selection
_render_page = _PAGE_RENDERERS.get(selected_page)
if _render_page is not None:
    _render_page()