File size: 3,822 Bytes
edb9ee2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import html
import os
import signal
import subprocess
import tempfile
from pathlib import Path
from textwrap import dedent

# Disable Gradio analytics; set before gradio is imported.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

import gradio as gr
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from huggingface_hub import HfApi, ModelCard, whoami


def process_model(ft_model_id: str, base_model_id: str, rank: str, private_repo: bool, oauth_token: gr.OAuthToken | None):
    """Extract a LoRA adapter from a fine-tuned model and upload it to the Hub.

    Runs the ``mergekit-extract-lora`` CLI on ``ft_model_id`` and
    ``base_model_id`` into a temporary directory, then uploads the produced
    files to ``<username>/LoRA-<fine-tuned model name>`` owned by the
    logged-in user (an existing repo with that name is reused).

    Args:
        ft_model_id: Hub repo id of the fine-tuned model (``owner/name``).
        base_model_id: Hub repo id of the base model.
        rank: LoRA rank, forwarded to the CLI as ``--rank=<rank>``.
        private_repo: Create the target repo as private.
        oauth_token: OAuth token of the logged-in user, or ``None`` when the
            user is not logged in.

    Returns:
        An HTML snippet (rendered by a ``gr.Markdown`` output) reporting
        either success with a link to the repo, or the error message.

    Raises:
        gr.Error: If the user is not logged in, or a model repository is
            missing or malformed. All other failures are reported through
            the returned HTML instead of being raised.
    """
    if oauth_token is None or oauth_token.token is None:
        raise gr.Error("You must be logged in")
    # Both ids are passed to the CLI as positional arguments, so a value that
    # starts with "-" would be parsed as an option. Valid Hub repo ids never
    # start with "-", so reject those up front along with empty selections.
    for kind, repo_id in (("fine tuned", ft_model_id), ("base", base_model_id)):
        if not repo_id:
            raise gr.Error(f"Please select a {kind} model repository")
        if repo_id.startswith("-"):
            raise gr.Error(f"Invalid {kind} model repository: {repo_id}")
    model_name = ft_model_id.split("/")[-1]

    os.makedirs("outputs", exist_ok=True)

    try:
        api = HfApi(token=oauth_token.token)

        # The temporary directory (and the extracted adapter inside it) is
        # removed when this block exits, i.e. after the upload has finished.
        with tempfile.TemporaryDirectory(dir="outputs") as outputdir:
            result = subprocess.run(
                [
                    "mergekit-extract-lora",
                    ft_model_id,
                    base_model_id,
                    outputdir,
                    f"--rank={rank}",
                ],
                shell=False,
                capture_output=True,
                # Decode output so errors shown to the user are readable text
                # instead of a bytes repr; never fail on undecodable bytes.
                encoding="utf-8",
                errors="replace",
            )
            print(result)
            if result.returncode != 0:
                raise RuntimeError(f"Error converting to LoRA PEFT: {result.stderr}")
            print("Model converted to LoRA PEFT successfully!")
            print(f"Converted model path: {outputdir}")

            # Check output dir
            if not os.listdir(outputdir):
                raise RuntimeError("Output directory is empty!")

            # Create repo (reused if it already exists)
            username = whoami(oauth_token.token)["name"]
            new_repo_url = api.create_repo(repo_id=f"{username}/LoRA-{model_name}", exist_ok=True, private=private_repo)
            new_repo_id = new_repo_url.repo_id
            print("Repo created successfully!", new_repo_url)

            # Upload the whole output directory to the repo root.
            api.upload_folder(
                folder_path=outputdir,
                path_in_repo="",
                repo_id=new_repo_id,
            )
            print("Uploaded", outputdir)

        return (
            f'<h1>✅ DONE</h1><br/><br/>Find your repo here: <a href="{new_repo_url}" target="_blank" style="text-decoration:underline">{new_repo_id}</a>'
        )
    except Exception as e:
        # Escape so text such as "<module>" in a traceback is displayed rather
        # than swallowed as an HTML tag by the Markdown output.
        return f"<h1>❌ ERROR</h1><br/><br/>{html.escape(str(e))}"


css = """/* Custom CSS to allow scrolling */
.gradio-container {overflow-y: auto;}
"""

# Create Gradio interface.
# The input components below are created inside the Blocks context and then
# passed to a nested gr.Interface that wires them to process_model.
# process_model's `oauth_token` argument is not listed in `inputs`: Gradio
# supplies parameters annotated with gr.OAuthToken from the login session
# started by the LoginButton.
with gr.Blocks(css=css) as demo:
    gr.Markdown("You must be logged in.")
    gr.LoginButton(min_width=250)

    ft_model_id = HuggingfaceHubSearch(
        label="Fine tuned model repository",
        placeholder="Search for repository on Huggingface",
        search_type="model",
    )

    base_model_id = HuggingfaceHubSearch(
        label="Base model repository",
        placeholder="Search for repository on Huggingface",
        search_type="model",
    )

    # Values are strings; process_model forwards the choice as --rank=<value>.
    rank = gr.Dropdown(
        ["16", "32", "64", "128"],
        label="LoRA rank",
        info="The higher the rank, the better the result, but the heavier the adapter",
        value="32",
        filterable=False,
        visible=True,
    )

    private_repo = gr.Checkbox(
        value=False,
        label="Private Repo",
        info="Create a private repo under your username.",
    )

    iface = gr.Interface(
        fn=process_model,
        inputs=[
            ft_model_id,
            base_model_id,
            rank,
            private_repo,
        ],
        outputs=[
            gr.Markdown(label="output"),
        ],
        title="Convert fine tuned model into LoRA with mergekit-extract-lora",
        description="The space takes a fine tuned model and a base model, then makes a PEFT-compatible LoRA adapter based on the difference between the 2 models.",
        api_name=False,
    )

# Launch the interface: process one request at a time, with at most 5 waiting
# in the queue.
demo.queue(default_concurrency_limit=1, max_size=5).launch(debug=True, show_api=False)