Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
@@ -1,17 +1,15 @@
|
|
1 |
import os
|
2 |
import gradio as gr
|
3 |
-
from huggingface_hub import list_repo_files, hf_hub_download
|
4 |
import subprocess
|
5 |
|
6 |
-
# Constants
|
7 |
HF_USER = "fbaldassarri"
|
8 |
TEQ_KEYWORD = "TEQ"
|
9 |
|
10 |
def list_teq_models():
|
11 |
-
# List all
|
12 |
-
|
13 |
-
|
14 |
-
return [repo.id for repo in repos if TEQ_KEYWORD in repo.id]
|
15 |
|
16 |
def list_model_files(model_id):
|
17 |
# List files in the repo that are likely to be weights/config
|
@@ -48,24 +46,32 @@ def run_teq_inference(model_id, weights_file, config_file, base_model, prompt, m
|
|
48 |
return output.split(marker)[-1].strip()
|
49 |
return output
|
50 |
|
51 |
-
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
53 |
teq_models = list_teq_models()
|
54 |
with gr.Blocks() as demo:
|
55 |
gr.Markdown("# TEQ Quantized Model Inference Demo")
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
59 |
base_model = gr.Textbox(label="Base Model Name", value="facebook/opt-350m")
|
60 |
prompt = gr.Textbox(label="Prompt", value="Once upon a time, a little girl")
|
61 |
max_new_tokens = gr.Slider(10, 512, value=100, label="Max New Tokens")
|
62 |
debug = gr.Checkbox(label="Debug Mode")
|
63 |
output = gr.Textbox(label="Generated Text", lines=10)
|
64 |
-
def update_files(model_id):
|
65 |
-
weights, configs = list_model_files(model_id)
|
66 |
-
return gr.update(choices=weights), gr.update(choices=configs)
|
67 |
-
model_id.change(update_files, inputs=model_id, outputs=[weights_file, config_file])
|
68 |
run_btn = gr.Button("Run Inference")
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
run_btn.click(
|
70 |
run_teq_inference,
|
71 |
inputs=[model_id, weights_file, config_file, base_model, prompt, max_new_tokens, debug],
|
@@ -74,4 +80,4 @@ def ui():
|
|
74 |
return demo
|
75 |
|
76 |
if __name__ == "__main__":
|
77 |
-
|
|
|
1 |
import os
|
2 |
import gradio as gr
|
3 |
+
from huggingface_hub import list_models, list_repo_files, hf_hub_download
|
4 |
import subprocess
|
5 |
|
|
|
6 |
HF_USER = "fbaldassarri"
|
7 |
TEQ_KEYWORD = "TEQ"
|
8 |
|
9 |
def list_teq_models():
    """Return repo ids of HF_USER's models whose id contains the TEQ keyword.

    Queries the Hugging Face Hub for every model authored by ``HF_USER`` and
    keeps only those whose repo id mentions ``TEQ_KEYWORD``.

    Returns:
        list[str]: Matching model repo ids (e.g. ``"fbaldassarri/...-TEQ-..."``).
    """
    # List all models for the user, filter those with "TEQ" in the name.
    models = list_models(author=HF_USER)
    # ``modelId`` is a deprecated alias in recent huggingface_hub releases;
    # ``id`` is the stable attribute on ModelInfo.
    return [model.id for model in models if TEQ_KEYWORD in model.id]
|
|
|
13 |
|
14 |
def list_model_files(model_id):
|
15 |
# List files in the repo that are likely to be weights/config
|
|
|
46 |
return output.split(marker)[-1].strip()
|
47 |
return output
|
48 |
|
49 |
+
def update_files(model_id):
    """Refresh the weights/config dropdowns when a new model repo is picked.

    Args:
        model_id: Repo id of the selected TEQ model.

    Returns:
        A pair of Gradio update payloads for the weights-file and
        config-file dropdowns, pre-selecting the first entry of each
        (or "" when the repo exposes none).
    """
    weights, configs = list_model_files(model_id)
    # Fall back to an empty selection when a repo has no candidate files.
    weights_val = weights[0] if weights else ""
    configs_val = configs[0] if configs else ""
    # gr.Dropdown.update() was removed in Gradio 4.x; gr.update() is the
    # portable way to patch component properties from a callback.
    return (
        gr.update(choices=weights, value=weights_val),
        gr.update(choices=configs, value=configs_val),
    )
|
54 |
+
|
55 |
+
def build_ui():
|
56 |
teq_models = list_teq_models()
|
57 |
with gr.Blocks() as demo:
|
58 |
gr.Markdown("# TEQ Quantized Model Inference Demo")
|
59 |
+
with gr.Row():
|
60 |
+
model_id = gr.Dropdown(teq_models, label="Select TEQ Model")
|
61 |
+
weights_file = gr.Dropdown(choices=[], label="Weights File (.pt)")
|
62 |
+
config_file = gr.Dropdown(choices=[], label="Config File (.json)")
|
63 |
base_model = gr.Textbox(label="Base Model Name", value="facebook/opt-350m")
|
64 |
prompt = gr.Textbox(label="Prompt", value="Once upon a time, a little girl")
|
65 |
max_new_tokens = gr.Slider(10, 512, value=100, label="Max New Tokens")
|
66 |
debug = gr.Checkbox(label="Debug Mode")
|
67 |
output = gr.Textbox(label="Generated Text", lines=10)
|
|
|
|
|
|
|
|
|
68 |
run_btn = gr.Button("Run Inference")
|
69 |
+
|
70 |
+
model_id.change(
|
71 |
+
update_files,
|
72 |
+
inputs=model_id,
|
73 |
+
outputs=[weights_file, config_file]
|
74 |
+
)
|
75 |
run_btn.click(
|
76 |
run_teq_inference,
|
77 |
inputs=[model_id, weights_file, config_file, base_model, prompt, max_new_tokens, debug],
|
|
|
80 |
return demo
|
81 |
|
82 |
if __name__ == "__main__":
    # Construct the Gradio app and start serving it.
    demo = build_ui()
    demo.launch()
|