Update app.py
app.py
CHANGED
```diff
@@ -5,6 +5,7 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
 import gradio as gr
 import tempfile
 import torch
+import requests
 
 from huggingface_hub import HfApi, ModelCard, whoami
 from gradio_huggingfacehub_search import HuggingfaceHubSearch
```
```diff
@@ -74,17 +75,32 @@ def run_command(command):
 
 ###########
 
+def guess_base_model(ft_model_id):
+    res = requests.get(f"https://huggingface.co/api/models/{ft_model_id}")
+    res = res.json()
+    for tag in res["tags"]:
+        if tag.startswith("base_model:"):
+            return tag.split(":")[-1]
+    raise Exception("Cannot guess the base model, please enter it manually")
+
 
 def process_model(ft_model_id: str, base_model_id: str, rank: str, private_repo, oauth_token: gr.OAuthToken | None):
     if oauth_token is None or oauth_token.token is None:
         raise gr.Error("You must be logged in")
     model_name = ft_model_id.split('/')[-1]
 
+    # validate the oauth token
+    whoami(oauth_token.token)
+
     if not os.path.exists("outputs"):
         os.makedirs("outputs")
 
     try:
         api = HfApi(token=oauth_token.token)
+
+        if not base_model_id:
+            base_model_id = guess_base_model(ft_model_id)
+            print("guess_base_model", base_model_id)
 
         with tempfile.TemporaryDirectory(dir="outputs") as outputdir:
             device = "cuda" if torch.cuda.is_available() else "cpu"
```
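The new `guess_base_model` helper leans on the public Hub API: when a model card declares `base_model:` metadata, the repo's JSON metadata exposes it as a `base_model:<repo_id>` tag. The `whoami(oauth_token.token)` call added above it is a cheap validity check that fails fast on a bad token before any conversion work starts. A standalone sketch of the lookup (the fine-tune repo id at the bottom is hypothetical, for illustration only):

```python
import requests

def guess_base_model(ft_model_id: str) -> str:
    # The public Hub endpoint returns repo metadata as JSON,
    # including the "tags" list populated from the model card.
    res = requests.get(f"https://huggingface.co/api/models/{ft_model_id}")
    res.raise_for_status()
    for tag in res.json()["tags"]:
        # A card with `base_model: org/name` metadata surfaces
        # a "base_model:org/name" tag here.
        if tag.startswith("base_model:"):
            return tag.split(":")[-1]
    raise Exception("Cannot guess the base model, please enter it manually")

# Hypothetical repo id; any fine-tune whose card sets base_model works:
print(guess_base_model("someuser/my-llama-finetune"))
```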
```diff
@@ -102,7 +118,7 @@ def process_model(ft_model_id: str, base_model_id: str, rank: str, private_repo,
             print("output_stdout", output_stdout)
             print("output_stderr", output_stderr)
             if returncode != 0:
-                raise Exception(f"Error converting to LoRA PEFT {
+                raise Exception(f"Error converting to LoRA PEFT {output_stderr}")
             print("Model converted to LoRA PEFT successfully!")
             print(f"Converted model path: {outputdir}")
 
```
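The reworked exception now embeds the captured stderr, so a failed conversion reports the actual tool error instead of a bare message. The `run_command` helper named in the hunk header is not part of this diff; judging from the `returncode`/`output_stdout`/`output_stderr` triple it returns, it plausibly wraps `subprocess.run` like this (an assumption for illustration, not the Space's actual code):

```python
import subprocess

def run_command(command):
    # Assumed shape of the helper referenced in the hunk header:
    # capture both streams so a non-zero exit can surface stderr,
    # as the updated exception message now does.
    result = subprocess.run(command, capture_output=True, text=True)
    return result.returncode, result.stdout, result.stderr
```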
```diff
@@ -146,8 +162,8 @@ with gr.Blocks(css=css) as demo:
     )
 
     base_model_id = HuggingfaceHubSearch(
-        label="Base model repository",
-        placeholder="
+        label="Base model repository (optional)",
+        placeholder="If empty, it will be guessed from repo tags",
         search_type="model",
     )
 
```
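On the UI side, relabeling the field as optional pairs with the server-side fallback: leaving the search box empty sends a falsy `base_model_id` into `process_model`, which then calls `guess_base_model`. A minimal wiring sketch under that assumption (the button, textbox, and `on_convert` handler are illustrative stand-ins, not the Space's real layout):

```python
import gradio as gr
from gradio_huggingfacehub_search import HuggingfaceHubSearch

def on_convert(ft_model_id, base_model_id):
    # Hypothetical stand-in for process_model: shows only the
    # optional-field fallback, not the conversion itself.
    if not base_model_id:
        return f"Would guess the base model for {ft_model_id} from repo tags"
    return f"Would convert {ft_model_id} against {base_model_id}"

with gr.Blocks() as demo:
    ft_model_id = HuggingfaceHubSearch(
        label="Fine-tuned model repository",
        search_type="model",
    )
    base_model_id = HuggingfaceHubSearch(
        label="Base model repository (optional)",
        placeholder="If empty, it will be guessed from repo tags",
        search_type="model",
    )
    out = gr.Textbox(label="Status")
    gr.Button("Convert").click(on_convert, [ft_model_id, base_model_id], out)

demo.launch()
```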