import os
import sys
import time
import urllib.request
from huggingface_hub import snapshot_download
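# Note: huggingface_hub, llama-cpp-python and gradio are assumed to be listed in the
# Space's requirements.txt; llama_cpp and gradio are imported further down, only after
# the model file is in place.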

# --- Configuration (can be overridden via environment variables) ---
MODEL_REPO = os.getenv("MODEL_REPO", "lastmass/Qwen3_Medical_GRPO")
MODEL_FILE = os.getenv("MODEL_FILE", "Qwen3_Medical_GRPO.i1-Q4_K_M.gguf")
MODEL_DIR = os.getenv("MODEL_DIR", "/models")
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_FILE)
HF_TOKEN = os.getenv("HF_TOKEN", None)  # set in the Space Secrets if the model repo is private
N_THREADS = int(os.getenv("N_THREADS", "8"))

os.makedirs(MODEL_DIR, exist_ok=True)

def download_via_http(url, dest_path, token=None, chunk_size=4*1024*1024):
    """使用 urllib 分块下载,支持 token(Bearer)"""
    print(f"Downloading via HTTP: {url} -> {dest_path}")
    req = urllib.request.Request(url)
    if token:
        req.add_header("Authorization", f"Bearer {token}")
    try:
        with urllib.request.urlopen(req, timeout=120) as resp:
            # Raise an error if the status is not 200
            if resp.status not in (200, 302, 301):
                raise RuntimeError(f"HTTP download returned status {resp.status}")
            # Write to a temporary file, then rename once the download completes
            tmp_dest = dest_path + ".part"
            with open(tmp_dest, "wb") as fh:
                while True:
                    chunk = resp.read(chunk_size)
                    if not chunk:
                        break
                    fh.write(chunk)
            os.replace(tmp_dest, dest_path)
            print("HTTP download finished.")
    except Exception:
        # Clean up any partial download before re-raising
        for leftover in (dest_path + ".part", dest_path):
            if os.path.exists(leftover):
                os.remove(leftover)
        raise

def download_model_if_missing():
    if os.path.exists(MODEL_PATH):
        print(f"Model already exists at {MODEL_PATH}")
        return

    print(f"Model not found at {MODEL_PATH}. Trying snapshot_download from {MODEL_REPO} ...")
    # Try huggingface_hub.snapshot_download first (preferred)
    try:
        outdir = snapshot_download(
            repo_id=MODEL_REPO,
            repo_type="model",
            local_dir=MODEL_DIR,
            token=HF_TOKEN,
            allow_patterns=[MODEL_FILE]  # only fetch the file we need
        )
        # snapshot_download may nest files under a download directory; search it for the target file
        found = None
        for root, _, files in os.walk(outdir):
            if MODEL_FILE in files:
                found = os.path.join(root, MODEL_FILE)
                break
        if found:
            # If found, move it to MODEL_PATH (skip if it is already in place)
            if os.path.abspath(found) != os.path.abspath(MODEL_PATH):
                print(f"Found model at {found}, moving to {MODEL_PATH}")
                os.replace(found, MODEL_PATH)
            print("snapshot_download succeeded.")
            return
        else:
            print("snapshot_download did not find the file (0 files). Will try direct HTTP download as fallback.")
    except Exception as e:
        print("snapshot_download failed / returned nothing:", e, file=sys.stderr)

    # Fallback: build the resolve URL directly and download over HTTP
    direct_url = f"https://huggingface.co/{MODEL_REPO}/resolve/main/{MODEL_FILE}"
    try:
        download_via_http(direct_url, MODEL_PATH, token=HF_TOKEN)
        return
    except Exception as e:
        print("Direct HTTP download failed:", e, file=sys.stderr)
        # As a last resort, search MODEL_DIR once more in case the file landed elsewhere
        for root, _, files in os.walk(MODEL_DIR):
            if MODEL_FILE in files:
                found = os.path.join(root, MODEL_FILE)
                print(f"Found model at {found} after fallback search; moving to {MODEL_PATH}")
                os.replace(found, MODEL_PATH)
                return
        raise RuntimeError(f"Model download finished but {MODEL_PATH} still not found. Check repo contents and network.")

# Run the download
download_model_if_missing()
time.sleep(0.5)  # brief pause after the download/rename before loading the model

# --- Import and initialize llama_cpp (the model file must exist by now) ---
try:
    from llama_cpp import Llama
except Exception as e:
    print("Failed to import llama_cpp. Ensure the wheel matches the runtime and required system libs are present.", file=sys.stderr)
    raise

if not os.path.exists(MODEL_PATH):
    raise RuntimeError(f"Model path does not exist after download: {MODEL_PATH}")

n_threads = max(1, N_THREADS)
llm = Llama(model_path=MODEL_PATH, n_ctx=4096, n_threads=n_threads)
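# Optional sanity check (commented out): a one-token generation here would fail fast at
# startup if the GGUF file is truncated or corrupted, instead of on the first user request.
# _ = llm("ping", max_tokens=1)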

# --- Gradio interface ---
import gradio as gr

system_prompt = """You are given a problem.
Think about the problem and provide your working out.
Place it between <start_working_out> and <end_working_out>.
Then, provide your solution between <SOLUTION></SOLUTION>"""
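# The <start_working_out>/<SOLUTION> markers presumably mirror the reasoning format the
# GRPO-tuned model was trained with; the raw completion returned below contains both the
# working-out and the <SOLUTION> block.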

def chat(user_input):
    try:
        prompt = system_prompt + "\n\nUser input: " + user_input + " <start_working_out>"
        response = llm(prompt, max_tokens=1024, temperature=0.7)  # cap generation at 1024 tokens
        return response["choices"][0]["text"]
    except Exception as e:
        err_msg = f"Error while generating: {e}"
        print(err_msg, file=sys.stderr)
        return err_msg

with gr.Blocks() as demo:
    gr.Markdown("# 🏥 Qwen3 Medical GGUF Demo")
    
    # Warning and usage notice
    with gr.Row():
        with gr.Column():
            gr.HTML("""
            <div style="background-color: #fff3cd; border: 1px solid #ffeaa7; border-radius: 5px; padding: 15px; margin-bottom: 20px;">
                <h3 style="color: #856404; margin-top: 0;">⚠️ Performance Notice</h3>
                <p style="color: #856404; margin-bottom: 10px;">
                    This demo runs the <strong>lastmass/Qwen3_Medical_GRPO</strong> model (Q4_K_M quantized version) 
                    on Hugging Face's free CPU hardware. Inference is <strong>very slow</strong>.
                </p>
                <p style="color: #856404; margin-bottom: 0;">
                    For better performance, we recommend running inference <strong>locally</strong> with GPU acceleration. 
                    Please refer to the <a href="https://huggingface.co/lastmass/Qwen3_Medical_GRPO" target="_blank">model repository</a> 
                    for usage instructions. For optimal performance, use <strong>vLLM</strong> for inference.
                </p>
            </div>
            """)
    
    with gr.Row():
        with gr.Column():
            input_box = gr.Textbox(label="Input your question", placeholder="Please enter your medical question...")
            submit_btn = gr.Button("Generate Response")
        with gr.Column():
            output_box = gr.Textbox(label="Model Response", lines=10)
    
    submit_btn.click(fn=chat, inputs=input_box, outputs=output_box)
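    # Optional: on free CPU hardware a single generation can take minutes, so enabling
    # Gradio's request queue (demo.queue()) would serialize requests into the single
    # Llama instance instead of running them concurrently.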

demo.launch(server_name="0.0.0.0", server_port=7860)