# Qwen3_Medical / app.py
import os
import sys
import time
import urllib.request
from huggingface_hub import snapshot_download
# --- Configuration (overridable via environment variables) ---
MODEL_REPO = os.getenv("MODEL_REPO", "lastmass/Qwen3_Medical_GRPO")
MODEL_FILE = os.getenv("MODEL_FILE", "Qwen3_Medical_GRPO.i1-Q4_K_M.gguf")
MODEL_DIR = os.getenv("MODEL_DIR", "/models")
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_FILE)
HF_TOKEN = os.getenv("HF_TOKEN", None)  # set this in Space Secrets if the model repo is private
N_THREADS = int(os.getenv("N_THREADS", "8"))
os.makedirs(MODEL_DIR, exist_ok=True)
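# These defaults can be overridden without editing this file, e.g. via the Space's
# "Variables and secrets" settings. A hypothetical example (values are illustrative only):
#   N_THREADS=4             -> use fewer CPU threads
#   MODEL_DIR=/data/models  -> store the GGUF file elsewhere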
def download_via_http(url, dest_path, token=None, chunk_size=4 * 1024 * 1024):
    """Download in chunks via urllib, optionally sending a Bearer token."""
    print(f"Downloading via HTTP: {url} -> {dest_path}")
    req = urllib.request.Request(url)
    if token:
        req.add_header("Authorization", f"Bearer {token}")
    # Write to a temporary file and rename only once the download completes.
    tmp_dest = dest_path + ".part"
    try:
        with urllib.request.urlopen(req, timeout=120) as resp:
            # Raise on any unexpected status.
            if resp.status not in (200, 301, 302):
                raise RuntimeError(f"HTTP download returned status {resp.status}")
            with open(tmp_dest, "wb") as fh:
                while True:
                    chunk = resp.read(chunk_size)
                    if not chunk:
                        break
                    fh.write(chunk)
        os.replace(tmp_dest, dest_path)
        print("HTTP download finished.")
    except Exception:
        # Remove the partial file so a retry starts from scratch.
        if os.path.exists(tmp_dest):
            os.remove(tmp_dest)
        raise
def download_model_if_missing():
    if os.path.exists(MODEL_PATH):
        print(f"Model already exists at {MODEL_PATH}")
        return
    print(f"Model not found at {MODEL_PATH}. Trying snapshot_download from {MODEL_REPO} ...")
    # Preferred path: huggingface_hub.snapshot_download.
    try:
        outdir = snapshot_download(
            repo_id=MODEL_REPO,
            repo_type="model",
            local_dir=MODEL_DIR,
            token=HF_TOKEN,
            allow_patterns=[MODEL_FILE],  # only fetch the file we need
        )
        # snapshot_download may place files in a nested directory; search for the target file.
        found = None
        for root, _, files in os.walk(outdir):
            if MODEL_FILE in files:
                found = os.path.join(root, MODEL_FILE)
                break
        if found:
            # Move the file to MODEL_PATH unless it is already in the right place.
            if os.path.abspath(found) != os.path.abspath(MODEL_PATH):
                print(f"Found model at {found}, moving to {MODEL_PATH}")
                os.replace(found, MODEL_PATH)
            print("snapshot_download succeeded.")
            return
        else:
            print("snapshot_download did not find the file (0 files). Will try direct HTTP download as fallback.")
    except Exception as e:
        print("snapshot_download failed / returned nothing:", e, file=sys.stderr)
    # Fallback: build the resolve URL directly and download it.
    direct_url = f"https://huggingface.co/{MODEL_REPO}/resolve/main/{MODEL_FILE}"
    try:
        download_via_http(direct_url, MODEL_PATH, token=HF_TOKEN)
        return
    except Exception as e:
        print("Direct HTTP download failed:", e, file=sys.stderr)
    # Last resort: search MODEL_DIR once more in case the file landed somewhere unexpected.
    for root, _, files in os.walk(MODEL_DIR):
        if MODEL_FILE in files:
            found = os.path.join(root, MODEL_FILE)
            print(f"Found model at {found} after fallback search; moving to {MODEL_PATH}")
            os.replace(found, MODEL_PATH)
            return
    raise RuntimeError(f"Model download finished but {MODEL_PATH} still not found. Check repo contents and network.")
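# Note: an alternative to the snapshot_download-plus-manual-HTTP logic above would be
# huggingface_hub.hf_hub_download, which fetches a single file directly. A minimal
# sketch, not used by this app:
#
#   from huggingface_hub import hf_hub_download
#   path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE,
#                          local_dir=MODEL_DIR, token=HF_TOKEN)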
# Run the download before the model is loaded.
download_model_if_missing()
time.sleep(0.5)
# --- Import and initialise llama_cpp (the model file must exist by now) ---
try:
    from llama_cpp import Llama
except Exception:
    print("Failed to import llama_cpp. Ensure the wheel matches the runtime and required system libs are present.", file=sys.stderr)
    raise
if not os.path.exists(MODEL_PATH):
    raise RuntimeError(f"Model path does not exist after download: {MODEL_PATH}")
n_threads = max(1, N_THREADS)
llm = Llama(model_path=MODEL_PATH, n_ctx=4096, n_threads=n_threads)
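# n_ctx=4096 caps the combined prompt and generation length; n_threads controls CPU
# parallelism. On a machine with a CUDA/Metal build of llama-cpp-python, layers could
# additionally be offloaded to the GPU via the n_gpu_layers argument (e.g.
# n_gpu_layers=-1 to offload all layers); it is omitted here because the Space runs
# on CPU-only hardware.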
# --- Gradio interface ---
import gradio as gr
system_prompt = """You are given a problem.
Think about the problem and provide your working out.
Place it between <start_working_out> and <end_working_out>.
Then, provide your solution between <SOLUTION></SOLUTION>"""
def chat(user_input):
    try:
        prompt = system_prompt + "\n\nUser input: " + user_input + " <start_working_out>"
        response = llm(prompt, max_tokens=1024, temperature=0.7)  # generate up to 1024 tokens
        return response["choices"][0]["text"]
    except Exception as e:
        err_msg = f"Error while generating: {e}"
        print(err_msg, file=sys.stderr)
        return err_msg
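# The system prompt asks the model to wrap its final answer in <SOLUTION></SOLUTION>
# tags. A minimal, optional helper for pulling out just that part is sketched below;
# it is illustrative only and is not wired into the Gradio callback above.
def extract_solution(text):
    """Return the text inside <SOLUTION>...</SOLUTION>, or the full text if absent."""
    import re
    match = re.search(r"<SOLUTION>(.*?)</SOLUTION>", text, re.DOTALL)
    return match.group(1).strip() if match else text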
with gr.Blocks() as demo:
    gr.Markdown("# 🏥 Qwen3 Medical GGUF Demo")
    # Warning and usage notice shown above the input area.
    with gr.Row():
        with gr.Column():
            gr.HTML("""
            <div style="background-color: #fff3cd; border: 1px solid #ffeaa7; border-radius: 5px; padding: 15px; margin-bottom: 20px;">
                <h3 style="color: #856404; margin-top: 0;">⚠️ Performance Notice</h3>
                <p style="color: #856404; margin-bottom: 10px;">
                    This demo runs the <strong>lastmass/Qwen3_Medical_GRPO</strong> model (Q4_K_M quantized version)
                    on Hugging Face's free CPU hardware. Inference is <strong>very slow</strong>.
                </p>
                <p style="color: #856404; margin-bottom: 0;">
                    For better performance, we recommend running inference <strong>locally</strong> with GPU acceleration.
                    Please refer to the <a href="https://huggingface.co/lastmass/Qwen3_Medical_GRPO" target="_blank">model repository</a>
                    for usage instructions. For optimal performance, use <strong>vLLM</strong> for inference.
                </p>
            </div>
            """)
    with gr.Row():
        with gr.Column():
            input_box = gr.Textbox(label="Input your question", placeholder="Please enter your medical question...")
            submit_btn = gr.Button("Generate Response")
        with gr.Column():
            output_box = gr.Textbox(label="Model Response", lines=10)
    submit_btn.click(fn=chat, inputs=input_box, outputs=output_box)
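# With slow CPU inference, enabling Gradio's request queue (demo.queue()) before
# launch() would serialise concurrent requests instead of letting them time out.
# This is a suggestion only; the original app launches directly.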
demo.launch(server_name="0.0.0.0", server_port=7860)