Felladrin committed on
Commit
db5b70b
·
verified ·
1 Parent(s): d8ee264

Add trust remote code option and improve HF Token handling in model conversion

Browse files
Files changed (1) hide show
  1. app.py +64 -44
app.py CHANGED
@@ -2,9 +2,12 @@ import logging
2
  import os
3
  import subprocess
4
  import sys
 
 
 
5
  from dataclasses import dataclass
6
  from pathlib import Path
7
- from typing import Optional, Tuple
8
  from urllib.request import urlopen, urlretrieve
9
 
10
  import streamlit as st
@@ -20,6 +23,7 @@ class Config:
20
 
21
  hf_token: str
22
  hf_username: str
 
23
  transformers_version: str = "3.5.0"
24
  hf_base_url: str = "https://huggingface.co"
25
  transformers_base_url: str = (
@@ -32,18 +36,32 @@ class Config:
32
  """Create config from environment variables and secrets."""
33
  system_token = st.secrets.get("HF_TOKEN")
34
  user_token = st.session_state.get("user_hf_token")
 
35
  if user_token:
36
  hf_username = whoami(token=user_token)["name"]
37
  else:
38
  hf_username = (
39
  os.getenv("SPACE_AUTHOR_NAME") or whoami(token=system_token)["name"]
40
  )
 
41
  hf_token = user_token or system_token
42
 
43
  if not hf_token:
44
- raise ValueError("HF_TOKEN must be set")
 
 
45
 
46
- return cls(hf_token=hf_token, hf_username=hf_username)
 
 
 
 
 
 
 
 
 
 
47
 
48
 
49
  class ModelConverter:
@@ -82,9 +100,6 @@ class ModelConverter:
82
 
83
  def _extract_archive(self, archive_path: Path) -> None:
84
  """Extract the downloaded archive."""
85
- import tarfile
86
- import tempfile
87
-
88
  with tempfile.TemporaryDirectory() as tmp_dir:
89
  with tarfile.open(archive_path, "r:gz") as tar:
90
  tar.extractall(tmp_dir)
@@ -92,43 +107,46 @@ class ModelConverter:
92
  extracted_folder = next(Path(tmp_dir).iterdir())
93
  extracted_folder.rename(self.config.repo_path)
94
 
95
- def convert_model(self, input_model_id: str, trust_remote_code=False) -> Tuple[bool, Optional[str]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  """Convert the model to ONNX format."""
97
  try:
98
  if trust_remote_code:
99
- if st.session_state.get("user_hf_token") != "":
100
- result = subprocess.run(
101
- [
102
- sys.executable,
103
- "-m",
104
- "scripts.convert",
105
- "--quantize",
106
- "--trust_remote_code",
107
- "--model_id",
108
- input_model_id,
109
- ],
110
- cwd=self.config.repo_path,
111
- capture_output=True,
112
- text=True,
113
- env={},
114
  )
115
- else:
116
- raise Exception("Trust Remote Code requires your own HuggingFace token.")
117
- else:
118
- result = subprocess.run(
119
- [
120
- sys.executable,
121
- "-m",
122
- "scripts.convert",
123
- "--quantize",
124
- "--model_id",
125
- input_model_id,
126
- ],
127
- cwd=self.config.repo_path,
128
- capture_output=True,
129
- text=True,
130
- env={},
131
  )
 
 
132
 
133
  if result.returncode != 0:
134
  return False, result.stderr
@@ -158,8 +176,6 @@ class ModelConverter:
158
  except Exception as e:
159
  return str(e)
160
  finally:
161
- import shutil
162
-
163
  shutil.rmtree(model_folder_path, ignore_errors=True)
164
 
165
  def generate_readme(self, imi: str):
@@ -197,9 +213,11 @@ def main():
197
  type="password",
198
  key="user_hf_token",
199
  )
200
- trust_remote_code = st.toggle("Trust Remote Code?")
201
  if trust_remote_code:
202
- st.warning("Remote code could be used for malicious purposes. Make sure you trust the code fully. You must use your own Hugging Face write token.")
 
 
203
 
204
  if config.hf_username == input_model_id.split("/")[0]:
205
  same_repo = st.checkbox(
@@ -229,7 +247,9 @@ def main():
229
  return
230
 
231
  with st.spinner("Converting model..."):
232
- success, stderr = converter.convert_model(input_model_id, trust_remote_code=trust_remote_code)
 
 
233
  if not success:
234
  st.error(f"Conversion failed: {stderr}")
235
  return
@@ -253,4 +273,4 @@ def main():
253
 
254
 
255
  if __name__ == "__main__":
256
- main()
 
2
  import os
3
  import subprocess
4
  import sys
5
+ import tempfile
6
+ import tarfile
7
+ import shutil
8
  from dataclasses import dataclass
9
  from pathlib import Path
10
+ from typing import Dict, List, Optional, Tuple
11
  from urllib.request import urlopen, urlretrieve
12
 
13
  import streamlit as st
 
23
 
24
  hf_token: str
25
  hf_username: str
26
+ is_using_user_token: bool
27
  transformers_version: str = "3.5.0"
28
  hf_base_url: str = "https://huggingface.co"
29
  transformers_base_url: str = (
 
36
  """Create config from environment variables and secrets."""
37
  system_token = st.secrets.get("HF_TOKEN")
38
  user_token = st.session_state.get("user_hf_token")
39
+
40
  if user_token:
41
  hf_username = whoami(token=user_token)["name"]
42
  else:
43
  hf_username = (
44
  os.getenv("SPACE_AUTHOR_NAME") or whoami(token=system_token)["name"]
45
  )
46
+
47
  hf_token = user_token or system_token
48
 
49
  if not hf_token:
50
+ raise ValueError(
51
+ "When the user token is not provided, the system token must be set."
52
+ )
53
 
54
+ return cls(
55
+ hf_token=hf_token,
56
+ hf_username=hf_username,
57
+ is_using_user_token=bool(user_token),
58
+ )
59
+
60
+ def get_env_vars(self) -> Dict[str, str]:
61
+ """Get environment variables with HF_TOKEN set appropriately."""
62
+ env_vars = os.environ.copy()
63
+ env_vars["HF_TOKEN"] = self.hf_token
64
+ return env_vars
65
 
66
 
67
  class ModelConverter:
 
100
 
101
  def _extract_archive(self, archive_path: Path) -> None:
102
  """Extract the downloaded archive."""
 
 
 
103
  with tempfile.TemporaryDirectory() as tmp_dir:
104
  with tarfile.open(archive_path, "r:gz") as tar:
105
  tar.extractall(tmp_dir)
 
107
  extracted_folder = next(Path(tmp_dir).iterdir())
108
  extracted_folder.rename(self.config.repo_path)
109
 
110
+ def _run_conversion_subprocess(
111
+ self, input_model_id: str, extra_args: List[str] = None
112
+ ) -> subprocess.CompletedProcess:
113
+ """Run the conversion subprocess with the given arguments."""
114
+ cmd = [
115
+ sys.executable,
116
+ "-m",
117
+ "scripts.convert",
118
+ "--quantize",
119
+ "--model_id",
120
+ input_model_id,
121
+ ]
122
+
123
+ if extra_args:
124
+ cmd.extend(extra_args)
125
+
126
+ return subprocess.run(
127
+ cmd,
128
+ cwd=self.config.repo_path,
129
+ capture_output=True,
130
+ text=True,
131
+ env=self.config.get_env_vars(),
132
+ )
133
+
134
+ def convert_model(
135
+ self, input_model_id: str, trust_remote_code=False
136
+ ) -> Tuple[bool, Optional[str]]:
137
  """Convert the model to ONNX format."""
138
  try:
139
  if trust_remote_code:
140
+ if not self.config.is_using_user_token:
141
+ raise Exception(
142
+ "Trust Remote Code requires your own HuggingFace token."
 
 
 
 
 
 
 
 
 
 
 
 
143
  )
144
+
145
+ result = self._run_conversion_subprocess(
146
+ input_model_id, extra_args=["--trust_remote_code"]
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  )
148
+ else:
149
+ result = self._run_conversion_subprocess(input_model_id)
150
 
151
  if result.returncode != 0:
152
  return False, result.stderr
 
176
  except Exception as e:
177
  return str(e)
178
  finally:
 
 
179
  shutil.rmtree(model_folder_path, ignore_errors=True)
180
 
181
  def generate_readme(self, imi: str):
 
213
  type="password",
214
  key="user_hf_token",
215
  )
216
+ trust_remote_code = st.toggle("Optional: Trust Remote Code.")
217
  if trust_remote_code:
218
+ st.warning(
219
+ "This option should only be enabled for repositories you trust and in which you have read the code, as it will execute arbitrary code present in the model repository. When this option is enabled, you must use your own Hugging Face write token."
220
+ )
221
 
222
  if config.hf_username == input_model_id.split("/")[0]:
223
  same_repo = st.checkbox(
 
247
  return
248
 
249
  with st.spinner("Converting model..."):
250
+ success, stderr = converter.convert_model(
251
+ input_model_id, trust_remote_code=trust_remote_code
252
+ )
253
  if not success:
254
  st.error(f"Conversion failed: {stderr}")
255
  return
 
273
 
274
 
275
  if __name__ == "__main__":
276
+ main()