import os
import subprocess
import signal

# Disable Gradio analytics before importing gradio.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

import gradio as gr
import tempfile
import torch
import requests

from huggingface_hub import HfApi, ModelCard, whoami
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from pathlib import Path
from textwrap import dedent

import threading
from queue import Queue, Empty


def stream_output(pipe, queue):
    """Read output from pipe and put it in the queue."""
    for line in iter(pipe.readline, b''):
        queue.put(line.decode('utf-8').rstrip())
    pipe.close()
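# Note: stream_output is consumed by run_command below, with one reader thread per pipe,
# so stdout and stderr are drained concurrently and a full pipe buffer cannot stall the
# child process.

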
def run_command(command, env_vars):
    """Run a command, streaming its stdout/stderr to the console while capturing both."""
    process = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        universal_newlines=False,
        env=env_vars,
    )

    stdout_queue = Queue()
    stderr_queue = Queue()

    # Reader threads push decoded lines onto the queues so neither pipe blocks the other.
    stdout_thread = threading.Thread(target=stream_output, args=(process.stdout, stdout_queue))
    stderr_thread = threading.Thread(target=stream_output, args=(process.stderr, stderr_queue))
    stdout_thread.daemon = True
    stderr_thread.daemon = True
    stdout_thread.start()
    stderr_thread.start()

    output_stdout = ""
    output_stderr = ""

    # Poll both queues while the process runs so output is printed as it arrives.
    while process.poll() is None:
        try:
            stdout_line = stdout_queue.get_nowait()
            print(f"STDOUT: {stdout_line}")
            output_stdout += stdout_line + "\n"
        except Empty:
            pass

        try:
            stderr_line = stderr_queue.get_nowait()
            print(f"STDERR: {stderr_line}")
            output_stderr += stderr_line + "\n"
        except Empty:
            pass

    # Wait for the reader threads to finish, then drain anything still queued so the
    # tail of the output is not lost when the process exits.
    stdout_thread.join()
    stderr_thread.join()

    while not stdout_queue.empty():
        line = stdout_queue.get_nowait()
        print(f"STDOUT: {line}")
        output_stdout += line + "\n"
    while not stderr_queue.empty():
        line = stderr_queue.get_nowait()
        print(f"STDERR: {line}")
        output_stderr += line + "\n"

    return (process.returncode, output_stdout, output_stderr)
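

# Illustrative usage of run_command (not executed here); the command is an argv list and
# env_vars is the full environment passed to the child process:
#   rc, out, err = run_command(["echo", "hello"], dict(os.environ))
#   # rc is the exit code, out/err hold the captured stdout/stderr text

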
def guess_base_model(ft_model_id):
    """Guess the base model from the fine-tuned repo's metadata tags on the Hub."""
    res = requests.get(f"https://huggingface.co/api/models/{ft_model_id}")
    res.raise_for_status()
    res = res.json()
    for tag in res["tags"]:
        if tag.startswith("base_model:"):
            return tag.split(":")[-1]
    raise Exception("Cannot guess the base model, please enter it manually")
def process_model(ft_model_id: str, base_model_id: str, rank: str, private_repo, oauth_token: gr.OAuthToken | None):
    """Extract a LoRA adapter from ft_model_id against base_model_id and upload it to the Hub."""
    try:
        whoami(oauth_token.token)
    except Exception:
        raise gr.Error("You must be logged in")

    model_name = ft_model_id.split('/')[-1]

    if not os.path.exists("outputs"):
        os.makedirs("outputs")

    try:
        api = HfApi(token=oauth_token.token)

        if not base_model_id:
            base_model_id = guess_base_model(ft_model_id)
            print("guess_base_model", base_model_id)

        with tempfile.TemporaryDirectory(dir="outputs") as outputdir:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            # e.g. mergekit-extract-lora <ft_model_id> <base_model_id> <outputdir> --rank=32 --device=cuda
            cmd = [
                "mergekit-extract-lora",
                ft_model_id,
                base_model_id,
                outputdir,
                f"--rank={rank}",
                f"--device={device}",
            ]
            print("cmd", cmd)
            env_vars = dict(os.environ, HF_TOKEN=oauth_token.token)
            returncode, output_stdout, output_stderr = run_command(cmd, env_vars)
            print("returncode", returncode)
            print("output_stdout", output_stdout)
            print("output_stderr", output_stderr)
            if returncode != 0:
                raise Exception(f"Error converting to LoRA PEFT: {output_stderr}")
            print("Model converted to LoRA PEFT successfully!")
            print(f"Converted model path: {outputdir}")

            if not os.listdir(outputdir):
                raise Exception("Output directory is empty!")

            username = whoami(oauth_token.token)["name"]
            new_repo_url = api.create_repo(repo_id=f"{username}/LoRA-{model_name}", exist_ok=True, private=private_repo)
            new_repo_id = new_repo_url.repo_id
            print("Repo created successfully!", new_repo_url)

            api.upload_folder(
                folder_path=outputdir,
                path_in_repo="",
                repo_id=new_repo_id,
            )
            print("Uploaded", outputdir)

            return (
                f'<h1>✅ DONE</h1><br/><br/>Find your repo here: <a href="{new_repo_url}" target="_blank" style="text-decoration:underline">{new_repo_id}</a>'
            )
    except Exception as e:
        return (f"<h1>❌ ERROR</h1><br/><br/>{e}")
css="""/* Custom CSS to allow scrolling */ |
|
.gradio-container {overflow-y: auto;} |
|
""" |
|
|
|
with gr.Blocks(css=css) as demo:
    gr.Markdown("You must be logged in.")
    gr.LoginButton(min_width=250)

    ft_model_id = HuggingfaceHubSearch(
        label="Fine-tuned model repository",
        placeholder="Fine-tuned model",
        search_type="model",
    )

    base_model_id = HuggingfaceHubSearch(
        label="Base model repository (optional)",
        placeholder="If empty, it will be guessed from the repo tags",
        search_type="model",
    )

    rank = gr.Dropdown(
        ["16", "32", "64", "128"],
        label="LoRA rank",
        info="The higher the rank, the better the result, but the heavier the adapter",
        value="32",
        filterable=False,
        visible=True,
    )

    private_repo = gr.Checkbox(
        value=False,
        label="Private Repo",
        info="Create a private repo under your username.",
    )

    iface = gr.Interface(
        fn=process_model,
        inputs=[
            ft_model_id,
            base_model_id,
            rank,
            private_repo,
        ],
        outputs=[
            gr.Markdown(label="output"),
        ],
        title="Convert a fine-tuned model into a LoRA adapter with mergekit-extract-lora",
        description="This space takes a fine-tuned model and its base model, then builds a PEFT-compatible LoRA adapter from the difference between the two models.<br/><br/>NOTE: Each conversion takes about <b>5 to 20 minutes</b>, depending on how big the model is.",
        api_name=False,
    )

demo.queue(default_concurrency_limit=1, max_size=5).launch(debug=True, show_api=False)