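"""Gradio demo Space for the lastmass/Qwen3_Medical_GRPO GGUF model.

On startup the script downloads the quantized model file (if it is not
already present), loads it with llama-cpp-python, and serves a simple
question/answer UI on port 7860.
"""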
import os
import sys
import time
import urllib.request

from huggingface_hub import snapshot_download

# --- Configuration (all values can be overridden via environment variables) ---
MODEL_REPO = os.getenv("MODEL_REPO", "lastmass/Qwen3_Medical_GRPO")
MODEL_FILE = os.getenv("MODEL_FILE", "Qwen3_Medical_GRPO.i1-Q4_K_M.gguf")
MODEL_DIR = os.getenv("MODEL_DIR", "/models")
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_FILE)
HF_TOKEN = os.getenv("HF_TOKEN", None)  # set this in the Space Secrets if the model repo is private
N_THREADS = int(os.getenv("N_THREADS", "8"))

os.makedirs(MODEL_DIR, exist_ok=True)
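# For example, setting MODEL_FILE to another .gguf filename in the Space's
# environment variables serves a different quantization from the same repo.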
def download_via_http(url, dest_path, token=None, chunk_size=4 * 1024 * 1024):
    """Download a file in chunks via urllib, with optional Bearer-token auth."""
    print(f"Downloading via HTTP: {url} -> {dest_path}")
    req = urllib.request.Request(url)
    if token:
        req.add_header("Authorization", f"Bearer {token}")
    tmp_dest = dest_path + ".part"
    try:
        with urllib.request.urlopen(req, timeout=120) as resp:
            # urlopen follows redirects itself, so anything but 200 here is an error
            if resp.status != 200:
                raise RuntimeError(f"HTTP download returned status {resp.status}")
            # Write to a temporary file, then rename once the download completes
            with open(tmp_dest, "wb") as fh:
                while True:
                    chunk = resp.read(chunk_size)
                    if not chunk:
                        break
                    fh.write(chunk)
        os.replace(tmp_dest, dest_path)
        print("HTTP download finished.")
    except Exception:
        # Remove the partial file so a retry starts from a clean state
        if os.path.exists(tmp_dest):
            os.remove(tmp_dest)
        raise
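# Three-stage download strategy: try huggingface_hub.snapshot_download first,
# fall back to a direct resolve-URL HTTP fetch, and finally search MODEL_DIR
# in case an earlier attempt left the file in an unexpected location.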
def download_model_if_missing():
    if os.path.exists(MODEL_PATH):
        print(f"Model already exists at {MODEL_PATH}")
        return

    print(f"Model not found at {MODEL_PATH}. Trying snapshot_download from {MODEL_REPO} ...")
    # Preferred path: huggingface_hub.snapshot_download
    try:
        outdir = snapshot_download(
            repo_id=MODEL_REPO,
            repo_type="model",
            local_dir=MODEL_DIR,
            token=HF_TOKEN,
            allow_patterns=[MODEL_FILE],  # fetch only the file we need
        )
        # snapshot_download may nest files inside the returned directory; search for the target
        found = None
        for root, _, files in os.walk(outdir):
            if MODEL_FILE in files:
                found = os.path.join(root, MODEL_FILE)
                break
        if found:
            # Move it to MODEL_PATH (skip if it is already in the right place)
            if os.path.abspath(found) != os.path.abspath(MODEL_PATH):
                print(f"Found model at {found}, moving to {MODEL_PATH}")
                os.replace(found, MODEL_PATH)
            print("snapshot_download succeeded.")
            return
        else:
            print("snapshot_download did not find the file (0 files). Will try direct HTTP download as fallback.")
    except Exception as e:
        print("snapshot_download failed / returned nothing:", e, file=sys.stderr)

    # Fallback: construct the resolve URL and download it directly
    direct_url = f"https://huggingface.co/{MODEL_REPO}/resolve/main/{MODEL_FILE}"
    try:
        download_via_http(direct_url, MODEL_PATH, token=HF_TOKEN)
        return
    except Exception as e:
        print("Direct HTTP download failed:", e, file=sys.stderr)

    # Last resort: search MODEL_DIR once more in case the file landed elsewhere
    for root, _, files in os.walk(MODEL_DIR):
        if MODEL_FILE in files:
            found = os.path.join(root, MODEL_FILE)
            print(f"Found model at {found} after fallback search; moving to {MODEL_PATH}")
            os.replace(found, MODEL_PATH)
            return

    raise RuntimeError(f"Model download finished but {MODEL_PATH} still not found. Check repo contents and network.")
# Download the model before loading anything that depends on it
download_model_if_missing()
time.sleep(0.5)

# --- Import and initialize llama_cpp (now that the model is in place) ---
try:
    from llama_cpp import Llama
except Exception:
    print("Failed to import llama_cpp. Ensure the wheel matches the runtime and required system libs are present.", file=sys.stderr)
    raise

if not os.path.exists(MODEL_PATH):
    raise RuntimeError(f"Model path does not exist after download: {MODEL_PATH}")

n_threads = max(1, N_THREADS)
llm = Llama(model_path=MODEL_PATH, n_ctx=4096, n_threads=n_threads)
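# n_ctx=4096 sets a 4k-token context window. No GPU offload is requested, so
# llama-cpp-python runs the quantized model entirely on CPU and generation
# speed scales with n_threads.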
# --- Gradio interface ---
import gradio as gr

system_prompt = """You are given a problem.
Think about the problem and provide your working out.
Place it between <start_working_out> and <end_working_out>.
Then, provide your solution between <SOLUTION></SOLUTION>"""
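# chat() below appends "<start_working_out>" to the prompt so the model begins
# its reasoning immediately; the final answer is expected inside the
# <SOLUTION></SOLUTION> tags.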
def chat(user_input):
    try:
        prompt = system_prompt + "\n\nUser input: " + user_input + " <start_working_out>"
        response = llm(prompt, max_tokens=1024, temperature=0.7)
        return response["choices"][0]["text"]
    except Exception as e:
        err_msg = f"Error while generating: {e}"
        print(err_msg, file=sys.stderr)
        return err_msg
with gr.Blocks() as demo:
    gr.Markdown("# 🏥 Qwen3 Medical GGUF Demo")

    # Warning and usage notes
    with gr.Row():
        with gr.Column():
            gr.HTML("""
            <div style="background-color: #fff3cd; border: 1px solid #ffeaa7; border-radius: 5px; padding: 15px; margin-bottom: 20px;">
                <h3 style="color: #856404; margin-top: 0;">⚠️ Performance Notice</h3>
                <p style="color: #856404; margin-bottom: 10px;">
                    This demo runs the <strong>lastmass/Qwen3_Medical_GRPO</strong> model (Q4_K_M quantized version)
                    on Hugging Face's free CPU hardware. Inference is <strong>very slow</strong>.
                </p>
                <p style="color: #856404; margin-bottom: 0;">
                    For better performance, we recommend running inference <strong>locally</strong> with GPU acceleration.
                    Please refer to the <a href="https://huggingface.co/lastmass/Qwen3_Medical_GRPO" target="_blank">model repository</a>
                    for usage instructions. For optimal performance, use <strong>vLLM</strong> for inference.
                </p>
            </div>
            """)

    with gr.Row():
        with gr.Column():
            input_box = gr.Textbox(label="Input your question", placeholder="Please enter your medical question...")
            submit_btn = gr.Button("Generate Response")
        with gr.Column():
            output_box = gr.Textbox(label="Model Response", lines=10)

    submit_btn.click(fn=chat, inputs=input_box, outputs=output_box)
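# Bind to 0.0.0.0 on port 7860, the port Hugging Face Spaces expects a
# Gradio app to listen on.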
demo.launch(server_name="0.0.0.0", server_port=7860)