Add trust remote code option and improve HF Token handling in model conversion
app.py CHANGED

@@ -2,9 +2,12 @@ import logging
 import os
 import subprocess
 import sys
+import tempfile
+import tarfile
+import shutil
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 from urllib.request import urlopen, urlretrieve
 
 import streamlit as st
@@ -20,6 +23,7 @@ class Config:
 
     hf_token: str
    hf_username: str
+    is_using_user_token: bool
     transformers_version: str = "3.5.0"
     hf_base_url: str = "https://huggingface.co"
     transformers_base_url: str = (
@@ -32,18 +36,32 @@ class Config:
         """Create config from environment variables and secrets."""
         system_token = st.secrets.get("HF_TOKEN")
         user_token = st.session_state.get("user_hf_token")
+
         if user_token:
             hf_username = whoami(token=user_token)["name"]
         else:
             hf_username = (
                 os.getenv("SPACE_AUTHOR_NAME") or whoami(token=system_token)["name"]
             )
+
         hf_token = user_token or system_token
 
         if not hf_token:
-            raise ValueError(…)
+            raise ValueError(
+                "When the user token is not provided, the system token must be set."
+            )
 
-        return cls(…)
+        return cls(
+            hf_token=hf_token,
+            hf_username=hf_username,
+            is_using_user_token=bool(user_token),
+        )
+
+    def get_env_vars(self) -> Dict[str, str]:
+        """Get environment variables with HF_TOKEN set appropriately."""
+        env_vars = os.environ.copy()
+        env_vars["HF_TOKEN"] = self.hf_token
+        return env_vars
 
 
 class ModelConverter:
@@ -82,9 +100,6 @@ class ModelConverter:
 
     def _extract_archive(self, archive_path: Path) -> None:
         """Extract the downloaded archive."""
-        import tarfile
-        import tempfile
-
         with tempfile.TemporaryDirectory() as tmp_dir:
             with tarfile.open(archive_path, "r:gz") as tar:
                 tar.extractall(tmp_dir)
@@ -92,43 +107,46 @@ class ModelConverter:
             extracted_folder = next(Path(tmp_dir).iterdir())
             extracted_folder.rename(self.config.repo_path)
 
-    def convert_model(…):
+    def _run_conversion_subprocess(
+        self, input_model_id: str, extra_args: List[str] = None
+    ) -> subprocess.CompletedProcess:
+        """Run the conversion subprocess with the given arguments."""
+        cmd = [
+            sys.executable,
+            "-m",
+            "scripts.convert",
+            "--quantize",
+            "--model_id",
+            input_model_id,
+        ]
+
+        if extra_args:
+            cmd.extend(extra_args)
+
+        return subprocess.run(
+            cmd,
+            cwd=self.config.repo_path,
+            capture_output=True,
+            text=True,
+            env=self.config.get_env_vars(),
+        )
+
+    def convert_model(
+        self, input_model_id: str, trust_remote_code=False
+    ) -> Tuple[bool, Optional[str]]:
         """Convert the model to ONNX format."""
         try:
             if trust_remote_code:
-                if …
-                …
-                …
-                    sys.executable,
-                    "-m",
-                    "scripts.convert",
-                    "--quantize",
-                    "--trust_remote_code",
-                    "--model_id",
-                    input_model_id,
-                ],
-                cwd=self.config.repo_path,
-                capture_output=True,
-                text=True,
-                env={},
+                if not self.config.is_using_user_token:
+                    raise Exception(
+                        "Trust Remote Code requires your own HuggingFace token."
                    )
-            …
-            …
-            …
-            result = subprocess.run(
-                [
-                    sys.executable,
-                    "-m",
-                    "scripts.convert",
-                    "--quantize",
-                    "--model_id",
-                    input_model_id,
-                ],
-                cwd=self.config.repo_path,
-                capture_output=True,
-                text=True,
-                env={},
+
+                result = self._run_conversion_subprocess(
+                    input_model_id, extra_args=["--trust_remote_code"]
                )
+            else:
+                result = self._run_conversion_subprocess(input_model_id)
 
             if result.returncode != 0:
                 return False, result.stderr
@@ -158,8 +176,6 @@ class ModelConverter:
         except Exception as e:
             return str(e)
         finally:
-            import shutil
-
             shutil.rmtree(model_folder_path, ignore_errors=True)
 
     def generate_readme(self, imi: str):
@@ -197,9 +213,11 @@ def main():
         type="password",
         key="user_hf_token",
     )
-    trust_remote_code = st.toggle("Trust Remote Code…")
+    trust_remote_code = st.toggle("Optional: Trust Remote Code.")
     if trust_remote_code:
-        st.warning(…)
+        st.warning(
+            "This option should only be enabled for repositories you trust and in which you have read the code, as it will execute arbitrary code present in the model repository. When this option is enabled, you must use your own Hugging Face write token."
+        )
 
     if config.hf_username == input_model_id.split("/")[0]:
         same_repo = st.checkbox(
@@ -229,7 +247,9 @@ def main():
         return
 
     with st.spinner("Converting model..."):
-        success, stderr = converter.convert_model(…)
+        success, stderr = converter.convert_model(
+            input_model_id, trust_remote_code=trust_remote_code
+        )
         if not success:
             st.error(f"Conversion failed: {stderr}")
             return
@@ -253,4 +273,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    main()
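
For readers tracing the new token-priority logic outside of Streamlit, below is a minimal sketch of the resolution order implemented in Config.from_env: a user-supplied token always wins, otherwise the Space's system token is used, and the username comes from huggingface_hub.whoami. The helper name resolve_identity and the plain function arguments standing in for st.secrets and st.session_state are illustrative only, not part of app.py.

from typing import Optional, Tuple

from huggingface_hub import whoami


def resolve_identity(
    user_token: Optional[str],
    system_token: Optional[str],
    space_author: Optional[str] = None,
) -> Tuple[str, str, bool]:
    """Illustrative mirror of Config.from_env: (token, username, is_using_user_token)."""
    if user_token:
        # A user-supplied token takes priority; the username is derived from it.
        username = whoami(token=user_token)["name"]
    else:
        # Fall back to the Space author name, then to the system token's identity.
        username = space_author or whoami(token=system_token)["name"]

    token = user_token or system_token
    if not token:
        raise ValueError("When the user token is not provided, the system token must be set.")

    return token, username, bool(user_token)

The returned is_using_user_token flag is what convert_model checks before honoring trust_remote_code, so remote code is never executed under the shared system token.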
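
The token-handling improvement itself is in how HF_TOKEN reaches the conversion subprocess: the old inline subprocess.run calls passed env={}, so scripts.convert ran without any Hugging Face credentials, while _run_conversion_subprocess now passes self.config.get_env_vars(), a copy of the parent environment with HF_TOKEN injected. A self-contained sketch of that mechanism follows; the build_env helper, the placeholder token value, and the inline child snippet are illustrative, not taken from app.py.

import os
import subprocess
import sys


def build_env(hf_token: str) -> dict:
    # Same pattern as Config.get_env_vars: copy the parent environment
    # and inject the token so the child process can authenticate.
    env = os.environ.copy()
    env["HF_TOKEN"] = hf_token
    return env


# The child simply prints HF_TOKEN to show the value crossed the process boundary.
child = "import os; print(os.environ.get('HF_TOKEN'))"

result = subprocess.run(
    [sys.executable, "-c", child],
    env=build_env("hf_dummy_token"),  # placeholder value for illustration
    capture_output=True,
    text=True,
)
print(result.stdout.strip())  # prints: hf_dummy_token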