Canstralian committed on
Commit
0dc11a4
·
verified ·
1 Parent(s): eb5600c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -13
app.py CHANGED
@@ -1,16 +1,44 @@
1
- import gradio as gr
2
- from datasets import load_dataset
 
3
 
4
- # Load datasets
5
- ds1 = load_dataset("b-mc2/sql-create-context")
6
- ds2 = load_dataset("TuneIt/o1-python")
7
- ds3 = load_dataset("HuggingFaceFW/fineweb-2", "aai_Latn")
8
- ds4 = load_dataset("HuggingFaceFW/fineweb-2", "aai_Latn_removed")
9
- ds5 = load_dataset("HuggingFaceFW/fineweb-2", "aak_Latn")
10
- ds6 = load_dataset("sentence-transformers/embedding-training-data")
11
 
12
- # Load the model and create a Gradio interface
13
- demo = gr.load("huggingface/distilbert-base-uncased")
 
 
14
 
15
- # Launch the Gradio app
16
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ import torch
4
 
5
+ # Load the tokenizer and models
6
+ tokenizer = AutoTokenizer.from_pretrained("cssupport/t5-small-awesome-text-to-sql")
7
+ original_model = AutoModelForSeq2SeqLM.from_pretrained("cssupport/t5-small-awesome-text-to-sql", torch_dtype=torch.bfloat16)
8
+ ft_model = AutoModelForSeq2SeqLM.from_pretrained("daljeetsingh/sql_ft_t5small_kag", torch_dtype=torch.bfloat16)
 
 
 
9
 
10
+ # Move models to GPU
11
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
12
+ original_model.to(device)
13
+ ft_model.to(device)
14
 
15
+ # Streamlit app layout
16
+ st.title("SQL Generation with T5 Models")
17
+
18
+ # Input text box
19
+ input_text = st.text_area("Enter your query:", height=150)
20
+
21
+ # Generate button
22
+ if st.button("Generate SQL"):
23
+ if input_text:
24
+ # Tokenize input
25
+ inputs = tokenizer(input_text, return_tensors='pt').to(device)
26
+
27
+ # Generate SQL queries
28
+ with torch.no_grad():
29
+ original_sql = tokenizer.decode(
30
+ original_model.generate(inputs["input_ids"], max_new_tokens=200)[0],
31
+ skip_special_tokens=True
32
+ )
33
+ ft_sql = tokenizer.decode(
34
+ ft_model.generate(inputs["input_ids"], max_new_tokens=200)[0],
35
+ skip_special_tokens=True
36
+ )
37
+
38
+ # Display results
39
+ st.subheader("Original Model Output")
40
+ st.write(original_sql)
41
+ st.subheader("Fine-Tuned Model Output")
42
+ st.write(ft_sql)
43
+ else:
44
+ st.warning("Please enter a query to generate SQL.")