Spaces:
Sleeping
Sleeping
import streamlit as st | |
import requests | |
import subprocess | |
import re | |
import sys | |
# Prompt layout sent to the completion model. Placeholders: {instruction},
# {input} (optional schema section) and {question}.
PROMPT_TEMPLATE = (
    "### Instruction:\n{instruction}\n\n"
    "### Input:\n{input}\n"
    "### Question:\n{question}\n\n"
    "### Response (use duckdb shorthand if possible):\n"
)
# {has_schema} is either "." or ", given a duckdb database schema."
INSTRUCTION_TEMPLATE = (
    "Your task is to generate valid duckdb SQL to answer the following question"
    "{has_schema}"
)  # noqa: E501
# Streamlit markdown shown when the generated SQL fails validation.
# Placeholders: {sql_query} (the generated SQL) and {error_msg} (validator output).
ERROR_MESSAGE = (
    ":red[ Quack! Much to our regret, SQL generation has gone a tad duck-side-down.\n"
    "The model is currently not able to craft a correct SQL query for this request. \n"
    "Sorry my duck friend. ]\n\n"
    ":red[If the question is about your own database, make sure to set the correct schema. "
    "Otherwise, try to rephrase your request. ]\n\n"
    "```sql\n{sql_query}\n```\n\n"
    "```sql\n{error_msg}\n```"
)
# NOTE(review): not referenced by generate_sql, which sends stop="<s>" to the
# model; confirm whether these stop sequences were meant to be used.
STOP_TOKENS = ["###", ";", "--", "```"]
def _lowercase_column_types(schema):
    """Return *schema* with the type word after each column name lowercased.

    Only text inside ``CREATE TABLE name( ... );`` bodies is rewritten. Within
    each body, every non-overlapping ``<word> <word>`` pair has its second word
    lowercased, e.g. ``id INTEGER`` -> ``id integer``.
    """

    def _lower_pair(match):
        return f"{match[1]} {match[2].lower()}"

    def _lower_body(match):
        body = re.sub(r"(\w+) (\w+)", _lower_pair, match[2])
        return f"{match[1]}{body}{match[3]}"

    # Rewrite each matched body in place instead of issuing global
    # str.replace calls: a global replace of "id INT" would also hit
    # "pid INTEGER" and leave a corrupted "pid intEGER".
    return re.sub(
        r"(CREATE TABLE [^(]+\()(.*?)(\);)",
        _lower_body,
        schema,
        flags=re.DOTALL | re.MULTILINE,
    )


def generate_prompt(question, schema):
    """Build the completion prompt for a natural-language SQL request.

    Args:
        question: The user's question, inserted verbatim into the prompt.
        schema: Optional ``CREATE TABLE`` DDL. When non-empty, column types are
            lowercased and the schema is included in the prompt's input section.

    Returns:
        The fully formatted prompt string.
    """
    if schema:
        schema = _lowercase_column_types(schema)
        input_section = (
            "Here is the database schema that the SQL query will run on:\n"
            f"{schema}\n"
        )
        has_schema = ", given a duckdb database schema."
    else:
        # Covers both "" and None, so no schema is claimed when none was given.
        input_section = ""
        has_schema = "."
    return PROMPT_TEMPLATE.format(
        instruction=INSTRUCTION_TEMPLATE.format(has_schema=has_schema),
        input=input_section,
        question=question,
    )
def generate_sql(question, schema, timeout=60):
    """Ask the hosted DuckDB-NSQL completion endpoint to write SQL.

    Args:
        question: Natural-language request from the user.
        schema: Optional ``CREATE TABLE`` DDL; may be empty.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The ``text`` of the first completion choice (the generated SQL).

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            non-2xx HTTP status.
        KeyError / IndexError: If the JSON body lacks ``choices[0]["text"]``.
        ValueError: If the response body is not valid JSON.
    """
    prompt = generate_prompt(question, schema)
    api_base = "https://text-motherduck-sql-fp16-4vycuix6qcp2.octoai.run"
    url = f"{api_base}/v1/completions"
    body = {
        "model": "motherduck-sql-fp16",
        "prompt": prompt,
        "temperature": 0.1,
        "max_tokens": 200,
        # NOTE(review): the module-level STOP_TOKENS list is not used here;
        # confirm which stop sequences the model expects.
        "stop": "<s>",
        "n": 1,
    }
    headers = {"Authorization": f"Bearer {st.secrets['octoml_token']}"}
    # Close the session's connection pool when done, and bound the wait so a
    # stalled endpoint cannot hang the app indefinitely.
    with requests.Session() as session:
        with session.post(url, json=body, headers=headers, timeout=timeout) as resp:
            # Surface HTTP errors directly instead of a KeyError on "choices".
            resp.raise_for_status()
            return resp.json()["choices"][0]["text"]
def validate_sql(query, schema, timeout=0.5):
    """Check *query* against *schema* using the ``validate_sql.py`` helper.

    The helper is run as ``<python> ./validate_sql.py <query> <schema>``. The
    path is relative to the current working directory. Any output on the
    helper's stderr is treated as a validation error.

    Args:
        query: SQL text to validate.
        schema: Schema DDL passed to the helper as its second argument.
        timeout: Seconds to wait for the helper before killing it.

    Returns:
        ``(True, "")`` when the helper wrote nothing to stderr or timed out.
        Otherwise ``(False, message)``, where *message* is the helper's stderr
        as a str. If stderr has more than three lines, the first three are
        dropped (presumably the traceback header).
    """
    process = subprocess.Popen(
        [sys.executable, "./validate_sql.py", query, schema],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    try:
        # Collect potential parser/binder error messages from stderr.
        _, stderr = process.communicate(timeout=timeout)
    except subprocess.TimeoutExpired:
        # Kill the child and finish communicating so it is reaped and its
        # pipes are drained (cleanup recommended by the subprocess docs).
        process.kill()
        process.communicate()
        # Timeout reached, so parsing and binding was very likely successful.
        # NOTE(review): heuristic -- a hung helper is also reported as valid.
        return True, ""
    if not stderr:
        return True, ""
    lines = stderr.decode("utf8", errors="replace").split("\n")
    # Skip the traceback header when present; short outputs are kept whole.
    if len(lines) > 3:
        lines = lines[3:]
    # Always return a str (short outputs used to leak out as a list).
    return False, "\n".join(lines)
st.title("DuckDB-NSQL-7B Demo")

# Optional panel where users can replace the example schema with their own.
expander = st.expander("Customize Schema (Optional)")
expander.markdown(
    "If your DuckDB database is `database.duckdb`, execute this command in your "
    "terminal to get your current schema:"
)
# Shell one-liner that dumps `.schema` and uses sed to break each
# CREATE TABLE onto one column per line.
expander.markdown(
    """```bash\necho ".schema" | duckdb database.duckdb | sed 's/(/(\\n /g' | sed 's/, /,\\n /g' | sed 's/);/\\n);\\n/g'\n```""",
)
# Example schema pre-filled in the schema editor; users may overwrite it.
default_schema = """CREATE TABLE rideshare(
  hvfhs_license_num VARCHAR,
  dispatching_base_num VARCHAR,
  originating_base_num VARCHAR,
  request_datetime TIMESTAMP,
  on_scene_datetime TIMESTAMP,
  pickup_datetime TIMESTAMP,
  dropoff_datetime TIMESTAMP,
  PULocationID BIGINT,
  DOLocationID BIGINT,
  trip_miles DOUBLE,
  trip_time BIGINT,
  base_passenger_fare DOUBLE,
  tolls DOUBLE,
  bcf DOUBLE,
  sales_tax DOUBLE,
  congestion_surcharge DOUBLE,
  airport_fee DOUBLE,
  tips DOUBLE,
  driver_pay DOUBLE,
  shared_request_flag VARCHAR,
  shared_match_flag VARCHAR,
  access_a_ride_flag VARCHAR,
  wav_request_flag VARCHAR,
  wav_match_flag VARCHAR
);
CREATE TABLE service_requests(
  unique_key BIGINT,
  created_date TIMESTAMP,
  closed_date TIMESTAMP,
  agency VARCHAR,
  agency_name VARCHAR,
  complaint_type VARCHAR,
  descriptor VARCHAR,
  location_type VARCHAR,
  incident_zip VARCHAR,
  incident_address VARCHAR,
  street_name VARCHAR,
  cross_street_1 VARCHAR,
  cross_street_2 VARCHAR,
  intersection_street_1 VARCHAR,
  intersection_street_2 VARCHAR,
  address_type VARCHAR,
  city VARCHAR,
  landmark VARCHAR,
  facility_type VARCHAR,
  status VARCHAR,
  due_date TIMESTAMP,
  resolution_description VARCHAR,
  resolution_action_updated_date TIMESTAMP,
  community_board VARCHAR,
  bbl VARCHAR,
  borough VARCHAR,
  x_coordinate_state_plane VARCHAR,
  y_coordinate_state_plane VARCHAR,
  open_data_channel_type VARCHAR,
  park_facility_name VARCHAR,
  park_borough VARCHAR,
  vehicle_type VARCHAR,
  taxi_company_borough VARCHAR,
  taxi_pick_up_location VARCHAR,
  bridge_highway_name VARCHAR,
  bridge_highway_direction VARCHAR,
  road_ramp VARCHAR,
  bridge_highway_segment VARCHAR,
  latitude DOUBLE,
  longitude DOUBLE
);
CREATE TABLE taxi(
  VendorID BIGINT,
  tpep_pickup_datetime TIMESTAMP,
  tpep_dropoff_datetime TIMESTAMP,
  passenger_count DOUBLE,
  trip_distance DOUBLE,
  RatecodeID DOUBLE,
  store_and_fwd_flag VARCHAR,
  PULocationID BIGINT,
  DOLocationID BIGINT,
  payment_type BIGINT,
  fare_amount DOUBLE,
  extra DOUBLE,
  mta_tax DOUBLE,
  tip_amount DOUBLE,
  tolls_amount DOUBLE,
  improvement_surcharge DOUBLE,
  total_amount DOUBLE,
  congestion_surcharge DOUBLE,
  airport_fee DOUBLE,
  drivers VARCHAR[],
  speeding_tickets STRUCT(date TIMESTAMP, speed VARCHAR)[],
  other_violations JSON
);"""
# Editable schema, pre-filled with the example schema.
schema = expander.text_area("Current schema:", value=default_schema, height=500)

# Input field for text prompt
text_prompt = st.text_input(
    "What DuckDB SQL query can I write for you?", value="Read a CSV file from test.csv"
)

if text_prompt:
    try:
        sql_query = generate_sql(text_prompt, schema)
    except (requests.RequestException, KeyError, IndexError, ValueError) as err:
        # Network failures, timeouts, error responses and malformed JSON would
        # otherwise crash the page with a raw traceback.
        st.error(f"SQL generation failed: {err}")
    else:
        valid, msg = validate_sql(sql_query, schema)
        if valid:
            st.markdown(f"""```sql\n{sql_query}\n```""")
        else:
            st.markdown(ERROR_MESSAGE.format(sql_query=sql_query, error_msg=msg))