lastmass committed on
Commit 8c137b5 · verified · 1 Parent(s): 6f1355a

Update app.py

Files changed (1)
  1. app.py +64 -27
app.py CHANGED
@@ -1,6 +1,7 @@
  import os
  import sys
  import time
  from huggingface_hub import snapshot_download

  # --- Configuration (overridable via environment variables) ---
@@ -8,67 +9,104 @@ MODEL_REPO = os.getenv("MODEL_REPO", "mradermacher/Qwen3_Medical_GRPO-i1-GGUF")
  MODEL_FILE = os.getenv("MODEL_FILE", "Qwen3_Medical_GRPO.i1-Q4_K_M.gguf")
  MODEL_DIR = os.getenv("MODEL_DIR", "/models")
  MODEL_PATH = os.path.join(MODEL_DIR, MODEL_FILE)
- HF_TOKEN = os.getenv("HF_TOKEN", None)  # if the model is private, set this in the Space Secrets
- # optional thread count (defaults to 8 if unset)
  N_THREADS = int(os.getenv("N_THREADS", "8"))

- # --- Ensure the model file exists; download it from the Hugging Face Hub if missing ---
  os.makedirs(MODEL_DIR, exist_ok=True)

  def download_model_if_missing():
      if os.path.exists(MODEL_PATH):
          print(f"Model already exists at {MODEL_PATH}")
          return

-     print(f"Model not found at {MODEL_PATH}. Attempting to download from {MODEL_REPO} ...")
      try:
-         # snapshot_download pulls the repo contents into MODEL_DIR; allow_patterns limits it to the file we need
-         snapshot_download(
              repo_id=MODEL_REPO,
              repo_type="model",
              local_dir=MODEL_DIR,
              token=HF_TOKEN,
-             allow_patterns=[MODEL_FILE],
-             ignore_patterns=["*"]  # ignore everything by default; allow_patterns wins for the file we need
          )
-     except Exception as e:
-         print("Error while trying to download the model:", e, file=sys.stderr)
-         print("If the model is private, make sure HF_TOKEN is set in Space Secrets and has read access.", file=sys.stderr)
-         raise
-
-     # give the filesystem a moment to settle (optional)
-     time.sleep(1)
-
-     if not os.path.exists(MODEL_PATH):
-         # snapshot_download sometimes puts files in a subdirectory; search MODEL_DIR for it
          found = None
-         for root, dirs, files in os.walk(MODEL_DIR):
              if MODEL_FILE in files:
                  found = os.path.join(root, MODEL_FILE)
                  break
          if found:
-             print(f"Found model at {found}; moving to {MODEL_PATH}")
-             os.replace(found, MODEL_PATH)
          else:
-             raise RuntimeError(f"Model download finished but {MODEL_PATH} still not found. Check repo contents.")

  download_model_if_missing()

- # --- Now import and initialize llama_cpp (the model is guaranteed to exist) ---
  try:
      from llama_cpp import Llama
  except Exception as e:
-     print("Failed to import llama_cpp. Ensure the wheel you installed matches the runtime (musl vs glibc) and required libs are present.", file=sys.stderr)
      raise

  if not os.path.exists(MODEL_PATH):
      raise RuntimeError(f"Model path does not exist after download: {MODEL_PATH}")

- # initialize the model (give N_THREADS a sane default)
  n_threads = max(1, N_THREADS)
  llm = Llama(model_path=MODEL_PATH, n_ctx=4096, n_threads=n_threads)

- # --- System prompt and Gradio interface ---
  import gradio as gr

  system_prompt = """You are given a problem.
@@ -82,7 +120,6 @@ def chat(user_input):
          response = llm(prompt, max_tokens=2048, temperature=0.7)
          return response["choices"][0]["text"]
      except Exception as e:
-         # catch runtime errors and return a friendly message (also printed to the container logs)
          err_msg = f"Error while generating: {e}"
          print(err_msg, file=sys.stderr)
          return err_msg
 
  import os
  import sys
  import time
+ import urllib.request
  from huggingface_hub import snapshot_download

  # --- Configuration (overridable via environment variables) ---
  MODEL_REPO = os.getenv("MODEL_REPO", "mradermacher/Qwen3_Medical_GRPO-i1-GGUF")
  MODEL_FILE = os.getenv("MODEL_FILE", "Qwen3_Medical_GRPO.i1-Q4_K_M.gguf")
  MODEL_DIR = os.getenv("MODEL_DIR", "/models")
  MODEL_PATH = os.path.join(MODEL_DIR, MODEL_FILE)
+ HF_TOKEN = os.getenv("HF_TOKEN", None)  # if the model is private, set this in Space Secrets
  N_THREADS = int(os.getenv("N_THREADS", "8"))

  os.makedirs(MODEL_DIR, exist_ok=True)

+ def download_via_http(url, dest_path, token=None, chunk_size=4*1024*1024):
+     """Chunked download via urllib, with optional Bearer token."""
+     print(f"Downloading via HTTP: {url} -> {dest_path}")
+     req = urllib.request.Request(url)
+     if token:
+         req.add_header("Authorization", f"Bearer {token}")
+     try:
+         with urllib.request.urlopen(req, timeout=120) as resp:
+             # urlopen follows redirects itself; any other status is an error
+             if resp.status not in (200, 301, 302):
+                 raise RuntimeError(f"HTTP download returned status {resp.status}")
+             # write to a temp file, then rename once the download completes
+             tmp_dest = dest_path + ".part"
+             with open(tmp_dest, "wb") as fh:
+                 while True:
+                     chunk = resp.read(chunk_size)
+                     if not chunk:
+                         break
+                     fh.write(chunk)
+             os.replace(tmp_dest, dest_path)
+             print("HTTP download finished.")
+     except Exception:
+         # clean up any partial download before re-raising
+         tmp_dest = dest_path + ".part"
+         if os.path.exists(tmp_dest):
+             os.remove(tmp_dest)
+         raise
+
  def download_model_if_missing():
      if os.path.exists(MODEL_PATH):
          print(f"Model already exists at {MODEL_PATH}")
          return

+     print(f"Model not found at {MODEL_PATH}. Trying snapshot_download from {MODEL_REPO} ...")
+     # try huggingface_hub.snapshot_download first (preferred)
      try:
+         outdir = snapshot_download(
              repo_id=MODEL_REPO,
              repo_type="model",
              local_dir=MODEL_DIR,
              token=HF_TOKEN,
+             allow_patterns=[MODEL_FILE]  # fetch only the file we need
          )
+         # snapshot_download may nest files under the returned dir; search for the target file
          found = None
+         for root, _, files in os.walk(outdir):
              if MODEL_FILE in files:
                  found = os.path.join(root, MODEL_FILE)
                  break
          if found:
+             # move it to MODEL_PATH (skip if it is already in the right place)
+             if os.path.abspath(found) != os.path.abspath(MODEL_PATH):
+                 print(f"Found model at {found}, moving to {MODEL_PATH}")
+                 os.replace(found, MODEL_PATH)
+             print("snapshot_download succeeded.")
+             return
          else:
+             print("snapshot_download did not find the file (0 files). Will try direct HTTP download as fallback.")
+     except Exception as e:
+         print("snapshot_download failed / returned nothing:", e, file=sys.stderr)
+
+     # fallback: build the resolve URL directly and download
+     direct_url = f"https://huggingface.co/{MODEL_REPO}/resolve/main/{MODEL_FILE}"
+     try:
+         download_via_http(direct_url, MODEL_PATH, token=HF_TOKEN)
+         return
+     except Exception as e:
+         print("Direct HTTP download failed:", e, file=sys.stderr)
+     # last resort: search MODEL_DIR once more (just in case)
+     for root, _, files in os.walk(MODEL_DIR):
+         if MODEL_FILE in files:
+             found = os.path.join(root, MODEL_FILE)
+             print(f"Found model at {found} after fallback search; moving to {MODEL_PATH}")
+             os.replace(found, MODEL_PATH)
+             return
+     raise RuntimeError(f"Model download finished but {MODEL_PATH} still not found. Check repo contents and network.")

+ # run the download
  download_model_if_missing()
+ time.sleep(0.5)

+ # --- Import and initialize llama_cpp (the model now exists) ---
  try:
      from llama_cpp import Llama
  except Exception as e:
+     print("Failed to import llama_cpp. Ensure the wheel matches the runtime and required system libs are present.", file=sys.stderr)
      raise

  if not os.path.exists(MODEL_PATH):
      raise RuntimeError(f"Model path does not exist after download: {MODEL_PATH}")

  n_threads = max(1, N_THREADS)
  llm = Llama(model_path=MODEL_PATH, n_ctx=4096, n_threads=n_threads)

+ # --- Gradio interface ---
  import gradio as gr

  system_prompt = """You are given a problem.

          response = llm(prompt, max_tokens=2048, temperature=0.7)
          return response["choices"][0]["text"]
      except Exception as e:
          err_msg = f"Error while generating: {e}"
          print(err_msg, file=sys.stderr)
          return err_msg
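A quick way to sanity-check the new HTTP fallback without pulling the multi-gigabyte GGUF is to point download_via_http at a small file in the same repo. A minimal sketch, assuming it runs in the same module as the app.py above, and that the repo serves a README.md at its root (an assumption; true of most Hub repos, but not something this commit references):

# Smoke test for the fallback path (a sketch, not part of the commit).
# Reuses download_via_http, MODEL_REPO, and HF_TOKEN defined in app.py above;
# README.md at the repo root is assumed, and /tmp is a placeholder destination.
test_url = f"https://huggingface.co/{MODEL_REPO}/resolve/main/README.md"
download_via_http(test_url, "/tmp/README.md", token=HF_TOKEN)
print(os.path.getsize("/tmp/README.md"), "bytes downloaded")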