Upload app.py
app.py CHANGED
@@ -7,24 +7,20 @@ HF_USER = "fbaldassarri"
 TEQ_KEYWORD = "TEQ"
 
 def list_teq_models():
-    # List all models for the user, filter those with "TEQ" in the name
     models = list_models(author=HF_USER)
     return [model.modelId for model in models if TEQ_KEYWORD in model.modelId]
 
 def list_model_files(model_id):
-    # List files in the repo that are likely to be weights/config
     files = list_repo_files(model_id)
     weights = [f for f in files if f.endswith('.pt')]
     configs = [f for f in files if f.endswith('.json')]
     return weights, configs
 
 def run_teq_inference(model_id, weights_file, config_file, base_model, prompt, max_new_tokens, debug):
-    # Download files if not present
     local_model_dir = f"./models/{model_id.replace('/', '_')}"
     os.makedirs(local_model_dir, exist_ok=True)
-
-
-    # Call teq_inference.py as a subprocess for isolation
+    hf_hub_download(model_id, weights_file, local_dir=local_model_dir)
+    hf_hub_download(model_id, config_file, local_dir=local_model_dir)
     cmd = [
         "python", "teq_inference.py",
         "--model_dir", local_model_dir,
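The two hf_hub_download calls added in this hunk replace the removed placeholder lines: they fetch the selected weights and config into the app's local model directory before inference. A hedged sketch of that one step (the repo id and filename below are illustrative, not taken from the commit):

    from huggingface_hub import hf_hub_download

    # Illustrative values; the app derives these from the selected model and dropdowns.
    repo_id = "fbaldassarri/opt-350m-TEQ"                   # hypothetical TEQ repo
    local_model_dir = "./models/fbaldassarri_opt-350m-TEQ"

    # Downloads the file (or reuses a cached copy) and returns its local path;
    # local_dir places it under the app's ./models/... tree instead of the HF cache.
    weights_path = hf_hub_download(repo_id, "model_weights.pt", local_dir=local_model_dir)
    print(weights_path)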
@@ -38,9 +34,7 @@ def run_teq_inference(model_id, weights_file, config_file, base_model, prompt, m
     if debug:
         cmd.append("--debug")
     result = subprocess.run(cmd, capture_output=True, text=True)
-    # Extract generated text from logs
     output = result.stdout + "\n" + result.stderr
-    # Try to find the generated text in logs
     marker = "Generated text:"
     if marker in output:
         return output.split(marker)[-1].strip()
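With the comments stripped, the logic of this hunk is: run teq_inference.py in a subprocess, concatenate its stdout and stderr (the script may log to either stream), and treat everything after the last "Generated text:" marker as the model output. A self-contained sketch of that marker-scraping pattern, assuming a child script that prints its result after the marker:

    import subprocess

    # Assumed contract: the child script prints "Generated text: <output>" in its logs.
    result = subprocess.run(
        ["python", "teq_inference.py", "--model_dir", "./models/demo"],
        capture_output=True,
        text=True,
    )
    output = result.stdout + "\n" + result.stderr  # logs may land on either stream
    marker = "Generated text:"
    if marker in output:
        generated = output.split(marker)[-1].strip()  # keep text after the last marker
        print(generated)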
@@ -48,18 +42,21 @@ def run_teq_inference(model_id, weights_file, config_file, base_model, prompt, m
 
 def update_files(model_id):
     weights, configs = list_model_files(model_id)
+    # Default to first file if available, else empty string
     weights_val = weights[0] if weights else ""
     configs_val = configs[0] if configs else ""
-    return
+    return (
+        gr.Dropdown.update(choices=weights, value=weights_val),
+        gr.Dropdown.update(choices=configs, value=configs_val)
+    )
 
 def build_ui():
     teq_models = list_teq_models()
     with gr.Blocks() as demo:
         gr.Markdown("# TEQ Quantized Model Inference Demo")
-
-
-
-        config_file = gr.Dropdown(choices=[], label="Config File (.json)")
+        model_id = gr.Dropdown(teq_models, label="Select TEQ Model", interactive=True)
+        weights_file = gr.Dropdown(choices=[], label="Weights File (.pt)", interactive=True)
+        config_file = gr.Dropdown(choices=[], label="Config File (.json)", interactive=True)
         base_model = gr.Textbox(label="Base Model Name", value="facebook/opt-350m")
         prompt = gr.Textbox(label="Prompt", value="Once upon a time, a little girl")
         max_new_tokens = gr.Slider(10, 512, value=100, label="Max New Tokens")
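The rewritten update_files returns one gr.Dropdown.update() per output component; this is the Gradio 3.x way to change a dropdown's choices at runtime (Gradio 4.x removed the per-component update() in favor of returning gr.Dropdown(...) or gr.update(...)). A minimal runnable sketch of the same cascade, with a hypothetical in-memory stand-in for list_model_files():

    import gradio as gr

    # Hypothetical stand-in for list_model_files(); not part of the commit.
    FILES = {
        "fbaldassarri/opt-350m-TEQ": ["model_weights.pt", "config.json"],
    }

    def update_files(model_id):
        files = FILES.get(model_id, [])
        weights = [f for f in files if f.endswith(".pt")]
        configs = [f for f in files if f.endswith(".json")]
        # Return one update per component listed in `outputs`, in the same order.
        return (
            gr.Dropdown.update(choices=weights, value=weights[0] if weights else ""),
            gr.Dropdown.update(choices=configs, value=configs[0] if configs else ""),
        )

    with gr.Blocks() as demo:
        model_id = gr.Dropdown(list(FILES), label="Select TEQ Model", interactive=True)
        weights_file = gr.Dropdown(choices=[], label="Weights File (.pt)", interactive=True)
        config_file = gr.Dropdown(choices=[], label="Config File (.json)", interactive=True)
        model_id.change(update_files, inputs=model_id, outputs=[weights_file, config_file])

    demo.launch()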
@@ -67,6 +64,7 @@ def build_ui():
         output = gr.Textbox(label="Generated Text", lines=10)
         run_btn = gr.Button("Run Inference")
 
+        # When model_id changes, update weights_file and config_file dropdowns
         model_id.change(
             update_files,
             inputs=model_id,