import os
import sys
import time
import urllib.request
from huggingface_hub import snapshot_download
# --- Configuration (overridable via environment variables) ---
MODEL_REPO = os.getenv("MODEL_REPO", "lastmass/Qwen3_Medical_GRPO")
MODEL_FILE = os.getenv("MODEL_FILE", "Qwen3_Medical_GRPO.i1-Q4_K_M.gguf")
MODEL_DIR = os.getenv("MODEL_DIR", "/models")
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_FILE)
HF_TOKEN = os.getenv("HF_TOKEN", None)  # If the model repo is private, set this in the Space secrets
N_THREADS = int(os.getenv("N_THREADS", "8"))
os.makedirs(MODEL_DIR, exist_ok=True)
def download_via_http(url, dest_path, token=None, chunk_size=4 * 1024 * 1024):
    """Download a file in chunks with urllib, optionally authenticating with a Bearer token."""
    print(f"Downloading via HTTP: {url} -> {dest_path}")
    req = urllib.request.Request(url)
    if token:
        req.add_header("Authorization", f"Bearer {token}")
    # Write to a temporary file and rename it only after the download completes
    tmp_dest = dest_path + ".part"
    try:
        with urllib.request.urlopen(req, timeout=120) as resp:
            # urlopen follows redirects, so anything other than a success status is an error
            if resp.status not in (200, 301, 302):
                raise RuntimeError(f"HTTP download returned status {resp.status}")
            with open(tmp_dest, "wb") as fh:
                while True:
                    chunk = resp.read(chunk_size)
                    if not chunk:
                        break
                    fh.write(chunk)
        os.replace(tmp_dest, dest_path)
        print("HTTP download finished.")
    except Exception:
        # Remove the partial file so a retry starts clean
        if os.path.exists(tmp_dest):
            os.remove(tmp_dest)
        raise
def download_model_if_missing():
    if os.path.exists(MODEL_PATH):
        print(f"Model already exists at {MODEL_PATH}")
        return
    print(f"Model not found at {MODEL_PATH}. Trying snapshot_download from {MODEL_REPO} ...")
    # First try huggingface_hub.snapshot_download (preferred path)
    try:
        outdir = snapshot_download(
            repo_id=MODEL_REPO,
            repo_type="model",
            local_dir=MODEL_DIR,
            token=HF_TOKEN,
            allow_patterns=[MODEL_FILE],  # fetch only the file we need
        )
        # snapshot_download may return a nested download directory; search it for the target file
        found = None
        for root, _, files in os.walk(outdir):
            if MODEL_FILE in files:
                found = os.path.join(root, MODEL_FILE)
                break
        if found:
            # If found, move it to MODEL_PATH (skip if it is already in the right place)
            if os.path.abspath(found) != os.path.abspath(MODEL_PATH):
                print(f"Found model at {found}, moving to {MODEL_PATH}")
                os.replace(found, MODEL_PATH)
            print("snapshot_download succeeded.")
            return
        else:
            print("snapshot_download did not find the file (0 files). Will try direct HTTP download as fallback.")
    except Exception as e:
        print("snapshot_download failed / returned nothing:", e, file=sys.stderr)
    # Fallback: build the resolve URL directly and download it
    direct_url = f"https://huggingface.co/{MODEL_REPO}/resolve/main/{MODEL_FILE}"
    try:
        download_via_http(direct_url, MODEL_PATH, token=HF_TOKEN)
        return
    except Exception as e:
        print("Direct HTTP download failed:", e, file=sys.stderr)
    # Last resort: search MODEL_DIR once more in case the file landed somewhere else
    for root, _, files in os.walk(MODEL_DIR):
        if MODEL_FILE in files:
            found = os.path.join(root, MODEL_FILE)
            print(f"Found model at {found} after fallback search; moving to {MODEL_PATH}")
            os.replace(found, MODEL_PATH)
            return
    raise RuntimeError(f"All download attempts failed and {MODEL_PATH} was not found. Check the repo contents and network access.")
# Run the download at import time, before the model is loaded
download_model_if_missing()
time.sleep(0.5)
# --- Import and initialize llama_cpp (the model file must exist by now) ---
try:
    from llama_cpp import Llama
except Exception:
    print("Failed to import llama_cpp. Ensure the wheel matches the runtime and required system libs are present.", file=sys.stderr)
    raise
if not os.path.exists(MODEL_PATH):
    raise RuntimeError(f"Model path does not exist after download: {MODEL_PATH}")
n_threads = max(1, N_THREADS)
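# Load the GGUF model with llama-cpp-python: n_ctx sets the context window in tokens,
# n_threads the number of CPU threads used for generation.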
llm = Llama(model_path=MODEL_PATH, n_ctx=4096, n_threads=n_threads)
# --- Gradio interface ---
import gradio as gr
system_prompt = """You are given a problem.
Think about the problem and provide your working out.
Place it between <start_working_out> and <end_working_out>.
Then, provide your solution between <SOLUTION></SOLUTION>"""
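# Note: the <start_working_out>/<SOLUTION> markers are assumed to match the reasoning
# format the GRPO fine-tune was trained with; the raw completion is returned unparsed.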
def chat(user_input):
    try:
        prompt = system_prompt + "\n\nUser input: " + user_input + " <start_working_out>"
        response = llm(prompt, max_tokens=1024, temperature=0.7)  # allow up to 1024 generated tokens
        return response["choices"][0]["text"]
    except Exception as e:
        err_msg = f"Error while generating: {e}"
        print(err_msg, file=sys.stderr)
        return err_msg
with gr.Blocks() as demo:
    gr.Markdown("# 🏥 Qwen3 Medical GGUF Demo")
    # Warning and usage notice shown at the top of the page
    with gr.Row():
        with gr.Column():
            gr.HTML("""
<div style="background-color: #fff3cd; border: 1px solid #ffeaa7; border-radius: 5px; padding: 15px; margin-bottom: 20px;">
<h3 style="color: #856404; margin-top: 0;">⚠️ Performance Notice</h3>
<p style="color: #856404; margin-bottom: 10px;">
This demo runs the <strong>lastmass/Qwen3_Medical_GRPO</strong> model (Q4_K_M quantized version)
on Hugging Face's free CPU hardware. Inference is <strong>very slow</strong>.
</p>
<p style="color: #856404; margin-bottom: 0;">
For better performance, we recommend running inference <strong>locally</strong> with GPU acceleration.
Please refer to the <a href="https://huggingface.co/lastmass/Qwen3_Medical_GRPO" target="_blank">model repository</a>
for usage instructions. For optimal performance, use <strong>vLLM</strong> for inference.
</p>
</div>
""")
    with gr.Row():
        with gr.Column():
            input_box = gr.Textbox(label="Input your question", placeholder="Please enter your medical question...")
            submit_btn = gr.Button("Generate Response")
        with gr.Column():
            output_box = gr.Textbox(label="Model Response", lines=10)

    submit_btn.click(fn=chat, inputs=input_box, outputs=output_box)

demo.launch(server_name="0.0.0.0", server_port=7860)
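# Local testing sketch (assumptions: llama-cpp-python, gradio, and huggingface_hub are
# installed, and this script is saved as app.py):
#   MODEL_DIR=./models python app.py   # then open http://localhost:7860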