ngxson HF staff commited on
Commit
a9bf2b7
·
verified ·
1 Parent(s): af5ee17

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -3
app.py CHANGED
@@ -5,6 +5,7 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
5
  import gradio as gr
6
  import tempfile
7
  import torch
 
8
 
9
  from huggingface_hub import HfApi, ModelCard, whoami
10
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
@@ -74,17 +75,32 @@ def run_command(command):
74
 
75
  ###########
76
 
 
 
 
 
 
 
 
 
77
 
78
  def process_model(ft_model_id: str, base_model_id: str, rank: str, private_repo, oauth_token: gr.OAuthToken | None):
79
  if oauth_token is None or oauth_token.token is None:
80
  raise gr.Error("You must be logged in")
81
  model_name = ft_model_id.split('/')[-1]
82
 
 
 
 
83
  if not os.path.exists("outputs"):
84
  os.makedirs("outputs")
85
 
86
  try:
87
  api = HfApi(token=oauth_token.token)
 
 
 
 
88
 
89
  with tempfile.TemporaryDirectory(dir="outputs") as outputdir:
90
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -102,7 +118,7 @@ def process_model(ft_model_id: str, base_model_id: str, rank: str, private_repo,
102
  print("output_stdout", output_stdout)
103
  print("output_stderr", output_stderr)
104
  if returncode != 0:
105
- raise Exception(f"Error converting to LoRA PEFT {q_method}: {output_stderr}")
106
  print("Model converted to LoRA PEFT successfully!")
107
  print(f"Converted model path: {outputdir}")
108
 
@@ -146,8 +162,8 @@ with gr.Blocks(css=css) as demo:
146
  )
147
 
148
  base_model_id = HuggingfaceHubSearch(
149
- label="Base model repository",
150
- placeholder="Base model",
151
  search_type="model",
152
  )
153
 
 
5
  import gradio as gr
6
  import tempfile
7
  import torch
8
+ import requests
9
 
10
  from huggingface_hub import HfApi, ModelCard, whoami
11
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
 
75
 
76
  ###########
77
 
78
+ def guess_base_model(ft_model_id):
79
+ res = requests.get(f"https://huggingface.co/api/models/{ft_model_id}")
80
+ res = res.json()
81
+ for tag in res["tags"]:
82
+ if tag.startswith("base_model:"):
83
+ return tag.split(":")[-1]
84
+ raise Exception("Cannot guess the base model, please enter it manually")
85
+
86
 
87
  def process_model(ft_model_id: str, base_model_id: str, rank: str, private_repo, oauth_token: gr.OAuthToken | None):
88
  if oauth_token is None or oauth_token.token is None:
89
  raise gr.Error("You must be logged in")
90
  model_name = ft_model_id.split('/')[-1]
91
 
92
+ # validate the oauth token
93
+ whoami(oauth_token.token)
94
+
95
  if not os.path.exists("outputs"):
96
  os.makedirs("outputs")
97
 
98
  try:
99
  api = HfApi(token=oauth_token.token)
100
+
101
+ if not base_model_id:
102
+ base_model_id = guess_base_model(ft_model_id)
103
+ print("guess_base_model", base_model_id)
104
 
105
  with tempfile.TemporaryDirectory(dir="outputs") as outputdir:
106
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
118
  print("output_stdout", output_stdout)
119
  print("output_stderr", output_stderr)
120
  if returncode != 0:
121
+ raise Exception(f"Error converting to LoRA PEFT {output_stderr}")
122
  print("Model converted to LoRA PEFT successfully!")
123
  print(f"Converted model path: {outputdir}")
124
 
 
162
  )
163
 
164
  base_model_id = HuggingfaceHubSearch(
165
+ label="Base model repository (optional)",
166
+ placeholder="If empty, it will be guessed from repo tags",
167
  search_type="model",
168
  )
169