fbaldassarri committed
Commit 2b10ab8 · verified · 1 Parent(s): 82c86d4

Upload app.py

Files changed (1): app.py +11 -13
app.py CHANGED
@@ -7,24 +7,20 @@ HF_USER = "fbaldassarri"
7
  TEQ_KEYWORD = "TEQ"
8
 
9
  def list_teq_models():
10
- # List all models for the user, filter those with "TEQ" in the name
11
  models = list_models(author=HF_USER)
12
  return [model.modelId for model in models if TEQ_KEYWORD in model.modelId]
13
 
14
  def list_model_files(model_id):
15
- # List files in the repo that are likely to be weights/config
16
  files = list_repo_files(model_id)
17
  weights = [f for f in files if f.endswith('.pt')]
18
  configs = [f for f in files if f.endswith('.json')]
19
  return weights, configs
20
 
21
  def run_teq_inference(model_id, weights_file, config_file, base_model, prompt, max_new_tokens, debug):
22
- # Download files if not present
23
  local_model_dir = f"./models/{model_id.replace('/', '_')}"
24
  os.makedirs(local_model_dir, exist_ok=True)
25
- weights_path = hf_hub_download(model_id, weights_file, local_dir=local_model_dir)
26
- config_path = hf_hub_download(model_id, config_file, local_dir=local_model_dir)
27
- # Call teq_inference.py as a subprocess for isolation
28
  cmd = [
29
  "python", "teq_inference.py",
30
  "--model_dir", local_model_dir,
@@ -38,9 +34,7 @@ def run_teq_inference(model_id, weights_file, config_file, base_model, prompt, m
38
  if debug:
39
  cmd.append("--debug")
40
  result = subprocess.run(cmd, capture_output=True, text=True)
41
- # Extract generated text from logs
42
  output = result.stdout + "\n" + result.stderr
43
- # Try to find the generated text in logs
44
  marker = "Generated text:"
45
  if marker in output:
46
  return output.split(marker)[-1].strip()
@@ -48,18 +42,21 @@ def run_teq_inference(model_id, weights_file, config_file, base_model, prompt, m
48
 
49
  def update_files(model_id):
50
  weights, configs = list_model_files(model_id)
 
51
  weights_val = weights[0] if weights else ""
52
  configs_val = configs[0] if configs else ""
53
- return gr.Dropdown.update(choices=weights, value=weights_val), gr.Dropdown.update(choices=configs, value=configs_val)
 
 
 
54
 
55
  def build_ui():
56
  teq_models = list_teq_models()
57
  with gr.Blocks() as demo:
58
  gr.Markdown("# TEQ Quantized Model Inference Demo")
59
- with gr.Row():
60
- model_id = gr.Dropdown(teq_models, label="Select TEQ Model")
61
- weights_file = gr.Dropdown(choices=[], label="Weights File (.pt)")
62
- config_file = gr.Dropdown(choices=[], label="Config File (.json)")
63
  base_model = gr.Textbox(label="Base Model Name", value="facebook/opt-350m")
64
  prompt = gr.Textbox(label="Prompt", value="Once upon a time, a little girl")
65
  max_new_tokens = gr.Slider(10, 512, value=100, label="Max New Tokens")
@@ -67,6 +64,7 @@ def build_ui():
67
  output = gr.Textbox(label="Generated Text", lines=10)
68
  run_btn = gr.Button("Run Inference")
69
 
 
70
  model_id.change(
71
  update_files,
72
  inputs=model_id,
 
7
  TEQ_KEYWORD = "TEQ"
8
 
9
  def list_teq_models():
 
10
  models = list_models(author=HF_USER)
11
  return [model.modelId for model in models if TEQ_KEYWORD in model.modelId]
12
 
13
  def list_model_files(model_id):
 
14
  files = list_repo_files(model_id)
15
  weights = [f for f in files if f.endswith('.pt')]
16
  configs = [f for f in files if f.endswith('.json')]
17
  return weights, configs
18
 
19
  def run_teq_inference(model_id, weights_file, config_file, base_model, prompt, max_new_tokens, debug):
 
20
  local_model_dir = f"./models/{model_id.replace('/', '_')}"
21
  os.makedirs(local_model_dir, exist_ok=True)
22
+ hf_hub_download(model_id, weights_file, local_dir=local_model_dir)
23
+ hf_hub_download(model_id, config_file, local_dir=local_model_dir)
 
24
  cmd = [
25
  "python", "teq_inference.py",
26
  "--model_dir", local_model_dir,
 
34
  if debug:
35
  cmd.append("--debug")
36
  result = subprocess.run(cmd, capture_output=True, text=True)
 
37
  output = result.stdout + "\n" + result.stderr
 
38
  marker = "Generated text:"
39
  if marker in output:
40
  return output.split(marker)[-1].strip()
 
42
 
43
  def update_files(model_id):
44
  weights, configs = list_model_files(model_id)
45
+ # Default to first file if available, else empty string
46
  weights_val = weights[0] if weights else ""
47
  configs_val = configs[0] if configs else ""
48
+ return (
49
+ gr.Dropdown.update(choices=weights, value=weights_val),
50
+ gr.Dropdown.update(choices=configs, value=configs_val)
51
+ )
52
 
53
  def build_ui():
54
  teq_models = list_teq_models()
55
  with gr.Blocks() as demo:
56
  gr.Markdown("# TEQ Quantized Model Inference Demo")
57
+ model_id = gr.Dropdown(teq_models, label="Select TEQ Model", interactive=True)
58
+ weights_file = gr.Dropdown(choices=[], label="Weights File (.pt)", interactive=True)
59
+ config_file = gr.Dropdown(choices=[], label="Config File (.json)", interactive=True)
 
60
  base_model = gr.Textbox(label="Base Model Name", value="facebook/opt-350m")
61
  prompt = gr.Textbox(label="Prompt", value="Once upon a time, a little girl")
62
  max_new_tokens = gr.Slider(10, 512, value=100, label="Max New Tokens")
 
64
  output = gr.Textbox(label="Generated Text", lines=10)
65
  run_btn = gr.Button("Run Inference")
66
 
67
+ # When model_id changes, update weights_file and config_file dropdowns
68
  model_id.change(
69
  update_files,
70
  inputs=model_id,
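
The core of this commit is the dynamic-dropdown wiring: selecting a model repopulates the weights and config dropdowns from the tuple returned by update_files, with interactive=True keeping the dropdowns user-editable even though they are also event outputs (Gradio would otherwise infer them as static). Below is a minimal, self-contained sketch of that pattern under the Gradio 3.x API the diff uses, where gr.Dropdown.update() patches a component from an event handler (in Gradio 4+, the handler would return gr.Dropdown(choices=..., value=...) instead). The FAKE_FILES table is a hypothetical stand-in for the app's list_model_files(), which reads real repo file lists from the Hub.

import gradio as gr

# Hypothetical stand-in for list_model_files(); the real app calls
# list_repo_files() on the Hub and filters by .pt / .json extension.
FAKE_FILES = {
    "fbaldassarri/example-TEQ": (["model.pt"], ["config.json"]),
    "fbaldassarri/other-TEQ": (["a.pt", "b.pt"], ["quant_config.json"]),
}

def update_files(model_id):
    weights, configs = FAKE_FILES.get(model_id, ([], []))
    # One update per output component, returned in output order (Gradio 3.x API)
    return (
        gr.Dropdown.update(choices=weights, value=weights[0] if weights else ""),
        gr.Dropdown.update(choices=configs, value=configs[0] if configs else ""),
    )

with gr.Blocks() as demo:
    model_id = gr.Dropdown(list(FAKE_FILES), label="Select TEQ Model", interactive=True)
    weights_file = gr.Dropdown(choices=[], label="Weights File (.pt)", interactive=True)
    config_file = gr.Dropdown(choices=[], label="Config File (.json)", interactive=True)
    # The change event maps one input to two outputs, matching the tuple order above
    model_id.change(update_files, inputs=model_id, outputs=[weights_file, config_file])

demo.launch()

Dropping the gr.Row() wrapper simply stacks the three dropdowns vertically instead of side by side; the event wiring is unchanged by that layout choice.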