Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
|
@@ -1,17 +1,15 @@
|
|
| 1 |
import os
|
| 2 |
import gradio as gr
|
| 3 |
-
from huggingface_hub import list_repo_files, hf_hub_download
|
| 4 |
import subprocess
|
| 5 |
|
| 6 |
-
# Constants
|
| 7 |
HF_USER = "fbaldassarri"
|
| 8 |
TEQ_KEYWORD = "TEQ"
|
| 9 |
|
| 10 |
def list_teq_models():
|
| 11 |
-
# List all
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
return [repo.id for repo in repos if TEQ_KEYWORD in repo.id]
|
| 15 |
|
| 16 |
def list_model_files(model_id):
|
| 17 |
# List files in the repo that are likely to be weights/config
|
|
@@ -48,24 +46,32 @@ def run_teq_inference(model_id, weights_file, config_file, base_model, prompt, m
|
|
| 48 |
return output.split(marker)[-1].strip()
|
| 49 |
return output
|
| 50 |
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
teq_models = list_teq_models()
|
| 54 |
with gr.Blocks() as demo:
|
| 55 |
gr.Markdown("# TEQ Quantized Model Inference Demo")
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
|
|
|
| 59 |
base_model = gr.Textbox(label="Base Model Name", value="facebook/opt-350m")
|
| 60 |
prompt = gr.Textbox(label="Prompt", value="Once upon a time, a little girl")
|
| 61 |
max_new_tokens = gr.Slider(10, 512, value=100, label="Max New Tokens")
|
| 62 |
debug = gr.Checkbox(label="Debug Mode")
|
| 63 |
output = gr.Textbox(label="Generated Text", lines=10)
|
| 64 |
-
def update_files(model_id):
|
| 65 |
-
weights, configs = list_model_files(model_id)
|
| 66 |
-
return gr.update(choices=weights), gr.update(choices=configs)
|
| 67 |
-
model_id.change(update_files, inputs=model_id, outputs=[weights_file, config_file])
|
| 68 |
run_btn = gr.Button("Run Inference")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
run_btn.click(
|
| 70 |
run_teq_inference,
|
| 71 |
inputs=[model_id, weights_file, config_file, base_model, prompt, max_new_tokens, debug],
|
|
@@ -74,4 +80,4 @@ def ui():
|
|
| 74 |
return demo
|
| 75 |
|
| 76 |
if __name__ == "__main__":
|
| 77 |
-
|
|
|
|
| 1 |
import os
|
| 2 |
import gradio as gr
|
| 3 |
+
from huggingface_hub import list_models, list_repo_files, hf_hub_download
|
| 4 |
import subprocess
|
| 5 |
|
|
|
|
| 6 |
HF_USER = "fbaldassarri"
|
| 7 |
TEQ_KEYWORD = "TEQ"
|
| 8 |
|
| 9 |
def list_teq_models():
    """Return the repo IDs of this user's Hub models tagged with the TEQ keyword.

    Queries the Hugging Face Hub for every model published by ``HF_USER`` and
    keeps only those whose ID contains ``TEQ_KEYWORD``.
    """
    # List all models for the user, filter those with "TEQ" in the name
    candidates = list_models(author=HF_USER)
    matching = []
    for candidate in candidates:
        if TEQ_KEYWORD in candidate.modelId:
            matching.append(candidate.modelId)
    return matching
|
|
|
|
| 13 |
|
| 14 |
def list_model_files(model_id):
|
| 15 |
# List files in the repo that are likely to be weights/config
|
|
|
|
| 46 |
return output.split(marker)[-1].strip()
|
| 47 |
return output
|
| 48 |
|
| 49 |
+
def update_files(model_id):
    """Refresh the weights/config dropdown choices for a newly selected repo.

    Parameters
    ----------
    model_id : str
        Hub repo ID whose files should populate the dropdowns.

    Returns
    -------
    tuple
        A pair of Gradio update objects for the weights and config dropdowns,
        preselecting the first entry of each list (empty string when the repo
        has no matching files).
    """
    weights, configs = list_model_files(model_id)
    weights_val = weights[0] if weights else ""
    configs_val = configs[0] if configs else ""
    # gr.update() is accepted by both Gradio 3.x and 4.x; the component-bound
    # gr.Dropdown.update() was removed in Gradio 4 and raises AttributeError.
    return (
        gr.update(choices=weights, value=weights_val),
        gr.update(choices=configs, value=configs_val),
    )
|
| 54 |
+
|
| 55 |
+
def build_ui():
|
| 56 |
teq_models = list_teq_models()
|
| 57 |
with gr.Blocks() as demo:
|
| 58 |
gr.Markdown("# TEQ Quantized Model Inference Demo")
|
| 59 |
+
with gr.Row():
|
| 60 |
+
model_id = gr.Dropdown(teq_models, label="Select TEQ Model")
|
| 61 |
+
weights_file = gr.Dropdown(choices=[], label="Weights File (.pt)")
|
| 62 |
+
config_file = gr.Dropdown(choices=[], label="Config File (.json)")
|
| 63 |
base_model = gr.Textbox(label="Base Model Name", value="facebook/opt-350m")
|
| 64 |
prompt = gr.Textbox(label="Prompt", value="Once upon a time, a little girl")
|
| 65 |
max_new_tokens = gr.Slider(10, 512, value=100, label="Max New Tokens")
|
| 66 |
debug = gr.Checkbox(label="Debug Mode")
|
| 67 |
output = gr.Textbox(label="Generated Text", lines=10)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
run_btn = gr.Button("Run Inference")
|
| 69 |
+
|
| 70 |
+
model_id.change(
|
| 71 |
+
update_files,
|
| 72 |
+
inputs=model_id,
|
| 73 |
+
outputs=[weights_file, config_file]
|
| 74 |
+
)
|
| 75 |
run_btn.click(
|
| 76 |
run_teq_inference,
|
| 77 |
inputs=[model_id, weights_file, config_file, base_model, prompt, max_new_tokens, debug],
|
|
|
|
| 80 |
return demo
|
| 81 |
|
| 82 |
if __name__ == "__main__":
|
| 83 |
+
build_ui().launch()
|