fbaldassarri committed
Commit 82c86d4 · verified · 1 Parent(s): dff2fe0

Upload app.py

Files changed (1): app.py (+22, -16)
app.py CHANGED
@@ -1,17 +1,15 @@
 import os
 import gradio as gr
-from huggingface_hub import list_repo_files, hf_hub_download
+from huggingface_hub import list_models, list_repo_files, hf_hub_download
 import subprocess
 
-# Constants
 HF_USER = "fbaldassarri"
 TEQ_KEYWORD = "TEQ"
 
 def list_teq_models():
-    # List all repos with TEQ in their name
-    from huggingface_hub import list_repos
-    repos = list_repos(HF_USER)
-    return [repo.id for repo in repos if TEQ_KEYWORD in repo.id]
+    # List all models for the user, filter those with "TEQ" in the name
+    models = list_models(author=HF_USER)
+    return [model.modelId for model in models if TEQ_KEYWORD in model.modelId]
 
 def list_model_files(model_id):
     # List files in the repo that are likely to be weights/config
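Note on the hunk above: the old helper imported `list_repos`, which is not part of the public `huggingface_hub` API, so the commit switches to `list_models(author=...)`, which yields `ModelInfo` objects. The body of `list_model_files` is not shown in this diff; the following is only a plausible standalone sketch built on the `list_repo_files` import the file keeps, and the `.pt`/`.json` filtering is an assumption rather than something taken from the commit.

```python
from huggingface_hub import list_repo_files

def list_model_files(model_id):
    # Hypothetical sketch -- the real body is not part of this diff.
    # List repo files that are likely to be TEQ weights/config.
    files = list_repo_files(model_id)
    weights = [f for f in files if f.endswith(".pt")]
    configs = [f for f in files if f.endswith(".json")]
    return weights, configs
```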
@@ -48,24 +46,32 @@ def run_teq_inference(model_id, weights_file, config_file, base_model, prompt, m
         return output.split(marker)[-1].strip()
     return output
 
-# Gradio UI
-def ui():
+def update_files(model_id):
+    weights, configs = list_model_files(model_id)
+    weights_val = weights[0] if weights else ""
+    configs_val = configs[0] if configs else ""
+    return gr.Dropdown.update(choices=weights, value=weights_val), gr.Dropdown.update(choices=configs, value=configs_val)
+
+def build_ui():
     teq_models = list_teq_models()
     with gr.Blocks() as demo:
         gr.Markdown("# TEQ Quantized Model Inference Demo")
-        model_id = gr.Dropdown(teq_models, label="Select TEQ Model")
-        weights_file = gr.Textbox(label="Weights File (.pt)")
-        config_file = gr.Textbox(label="Config File (.json)")
+        with gr.Row():
+            model_id = gr.Dropdown(teq_models, label="Select TEQ Model")
+            weights_file = gr.Dropdown(choices=[], label="Weights File (.pt)")
+            config_file = gr.Dropdown(choices=[], label="Config File (.json)")
         base_model = gr.Textbox(label="Base Model Name", value="facebook/opt-350m")
         prompt = gr.Textbox(label="Prompt", value="Once upon a time, a little girl")
         max_new_tokens = gr.Slider(10, 512, value=100, label="Max New Tokens")
         debug = gr.Checkbox(label="Debug Mode")
         output = gr.Textbox(label="Generated Text", lines=10)
-        def update_files(model_id):
-            weights, configs = list_model_files(model_id)
-            return gr.update(choices=weights), gr.update(choices=configs)
-        model_id.change(update_files, inputs=model_id, outputs=[weights_file, config_file])
         run_btn = gr.Button("Run Inference")
+
+        model_id.change(
+            update_files,
+            inputs=model_id,
+            outputs=[weights_file, config_file]
+        )
         run_btn.click(
             run_teq_inference,
             inputs=[model_id, weights_file, config_file, base_model, prompt, max_new_tokens, debug],
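Compatibility note on the hunk above: `gr.Dropdown.update(...)` exists in Gradio 3.x but was removed in Gradio 4.x, where event handlers return `gr.update(...)` or a new component instance instead. A version-tolerant sketch of the same helper, reusing the app's `list_model_files`, might look like this; it is not part of the commit.

```python
import gradio as gr

def update_files(model_id):
    weights, configs = list_model_files(model_id)
    # gr.update(...) is accepted by Gradio 3.x and 4.x alike;
    # the per-component gr.Dropdown.update(...) only exists in 3.x.
    return (
        gr.update(choices=weights, value=weights[0] if weights else None),
        gr.update(choices=configs, value=configs[0] if configs else None),
    )
```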
@@ -74,4 +80,4 @@ def ui():
     return demo
 
 if __name__ == "__main__":
-    ui().launch()
+    build_ui().launch()
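The diff only shows the tail of `run_teq_inference` (the marker split in the second hunk), while `hf_hub_download` and `subprocess` are imported at the top of the file. Below is a hypothetical sketch of how such a runner could be wired together; the script name, CLI flags, and marker value are assumptions for illustration, not taken from the commit.

```python
import subprocess
from huggingface_hub import hf_hub_download

def run_teq_inference(model_id, weights_file, config_file, base_model,
                      prompt, max_new_tokens, debug):
    # Hypothetical sketch -- the real body lives outside this diff.
    # Fetch the selected TEQ artifacts from the Hub.
    weights_path = hf_hub_download(repo_id=model_id, filename=weights_file)
    config_path = hf_hub_download(repo_id=model_id, filename=config_file)

    # Assumed runner script and CLI flags (not taken from the commit).
    cmd = [
        "python", "run_teq.py",
        "--base_model", base_model,
        "--weights", weights_path,
        "--config", config_path,
        "--prompt", prompt,
        "--max_new_tokens", str(max_new_tokens),
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    output = (result.stdout + result.stderr) if debug else result.stdout

    # Mirror the tail shown in the second hunk: keep only text after a marker.
    marker = "=== GENERATED ==="  # assumed marker value
    if marker in output:
        return output.split(marker)[-1].strip()
    return output
```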
 