File size: 4,893 Bytes
4a11192
a63a8f0
 
 
 
e13c481
eecb24e
a63a8f0
8c7d6fb
a63a8f0
 
df2d67a
 
62c7d61
569e3d3
a63a8f0
 
 
 
 
16dcabb
 
 
 
 
 
 
 
 
 
 
a63a8f0
e13c481
43408ea
 
 
e13c481
 
 
 
 
 
a232e0c
a63a8f0
08800f1
875acc3
f183eb6
4e83cae
739fa1b
dc1cf71
4da242e
 
e13c481
 
dc1cf71
 
 
4da242e
 
e13c481
 
dc1cf71
 
a63a8f0
e13c481
a63a8f0
 
 
 
 
 
 
 
 
 
 
 
 
16dcabb
 
 
 
 
 
 
 
 
 
a63a8f0
 
a9eb656
a274f3e
90f5249
 
a63a8f0
f6a1585
9a32b6c
4c3f09a
fb3682d
 
 
f6a1585
fb3682d
3f44b7d
 
f6a1585
e8f35b2
 
c08838c
 
e8f35b2
c08838c
a9b3d15
82b64ca
e13c481
a9b3d15
326497b
c19881c
a9b3d15
067b45f
a9b3d15
c08838c
a63a8f0
a232e0c
a63a8f0
 
aacdbea
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127

import gradio as gr
import os
import requests
import json
import base64
from io import BytesIO
from huggingface_hub import login
from PIL import Image


# Address of the backend service that runs UnlearnDiffAtk. Each value can be
# overridden through an environment variable; the defaults keep the previously
# hard-coded endpoint so existing deployments behave the same.
myip = os.environ.get("UDIFF_HOST", "34.219.98.113")
myport = int(os.environ.get("UDIFF_PORT", "8000"))

# True when the SPACE_ID environment variable is present (presumably set by
# Hugging Face Spaces — confirm). Not read elsewhere in the visible code.
is_spaces = "SPACE_ID" in os.environ

is_shared_ui = False

from css_html_js import custom_css

from about import (
    CITATION_BUTTON_LABEL,
    CITATION_BUTTON_TEXT,
    EVALUATION_QUEUE_TEXT,
    INTRODUCTION_TEXT,
    LLM_BENCHMARKS_TEXT,
    TITLE,
)


def process_image_from_binary(img_stream):
    """Decode a base64-encoded image payload into a PIL image.

    Args:
        img_stream: base64 text (``str`` or ``bytes``) taken from the backend's
            JSON reply, or ``None``/empty when no image was produced.

    Returns:
        A fully decoded ``PIL.Image.Image``, or ``None`` when ``img_stream`` is
        ``None`` or empty.

    Raises:
        binascii.Error: if ``img_stream`` is not valid base64 (raised by
            ``base64.b64decode``).
        PIL.UnidentifiedImageError / OSError: if the decoded bytes are not a
            readable image.
    """
    # An empty payload can never be decoded as an image; treat it the same as a
    # missing one instead of letting Image.open raise on zero bytes.
    if not img_stream:
        print("no image binary")
        return None
    img = Image.open(BytesIO(base64.b64decode(img_stream)))
    # Image.open is lazy: force decoding now so a corrupt payload fails inside
    # this helper rather than later while Gradio renders the image.
    img.load()
    return img

def excute_udiff(diffusion_model_id, concept, steps, attack_id):
    """Run one UnlearnDiffAtk request against the remote backend.

    POSTs the selected unlearned model, concept, step count and attack index as
    JSON to ``http://{myip}:{myport}/udiff`` and unpacks the reply for the UI.

    Args:
        diffusion_model_id: unlearned DM chosen in the UI (e.g. "ESD").
        concept: unlearned concept chosen in the UI (e.g. "Nudity").
        steps: number of attack steps from the slider.
        attack_id: attack index as entered by the user (sent unchanged).

    Returns:
        ``(input_prompt, output_prompt, no_attack_img, attack_img)``. When the
        backend answers with a status other than 200, the prompts are ``""``
        and the images are ``None``.

    Raises:
        requests.RequestException: on connection errors or timeouts.
        KeyError: if a 200 reply lacks one of the expected fields.
    """
    print(f"my IP is {myip}, my port is {myport}")
    print(f"my input is diffusion_model_id: {diffusion_model_id}, concept: {concept}, steps: {steps}")
    response = requests.post(
        f"http://{myip}:{myport}/udiff",
        json={
            "diffusion_model_id": diffusion_model_id,
            "concept": concept,
            "steps": steps,
            "attack_id": attack_id,
        },
        # (connect, read) timeouts in seconds; the read timeout is long
        # because the backend may take a long time to finish an attack.
        timeout=(10, 1200),
    )
    print(f"result: {response}")

    prompt1 = ""
    prompt2 = ""
    img1 = None
    img2 = None
    if response.status_code == 200:
        response_json = response.json()
        # Do not print the whole reply: the two image fields are base64 blobs
        # and would flood the log on every request. Log keys and prompts only.
        print(f"response keys: {list(response_json)}")
        prompt1 = response_json['input_prompt']
        prompt2 = response_json['output_prompt']
        print(f"input_prompt: {prompt1!r}, output_prompt: {prompt2!r}")
        img1 = process_image_from_binary(response_json['no_attack_img'])
        img2 = process_image_from_binary(response_json['attack_img'])
    else:
        print(f"Request failed with status code {response.status_code}")

    return prompt1, prompt2, img1, img2


# Element-level CSS tweaks (instruction/arrow positioning, component min-heights,
# image heights, markdown font size, centered title).
# NOTE(review): this constant is not referenced in the visible code — the
# gr.Blocks(...) call below passes css=custom_css instead. Confirm it is unused
# before removing it.
css = '''
    .instruction{position: absolute; top: 0;right: 0;margin-top: 0px !important}
    .arrow{position: absolute;top: 0;right: -110px;margin-top: -8px !important}
    #component-4, #component-3, #component-10{min-height: 0}
    .duplicate-button img{margin: 0}
    #img_1, #img_2, #img_3, #img_4{height:15rem}
    #mdStyle{font-size: 0.7rem}
    #titleCenter {text-align:center}
'''


# Build the UnlearnDiffAtk demo UI.
# Project page: https://www.optml-group.com/posts/mu_attack
# Code: https://github.com/OPTML-Group/Diffusion-MU-Attack
# Paper: https://arxiv.org/abs/2310.11868
with gr.Blocks(css=custom_css) as demo:
    gr.HTML(TITLE)
    gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")

    # Attack configuration: concept, unlearned model, attack index, step count.
    with gr.Row() as udiff:
        with gr.Row():
            drop = gr.Dropdown(["Object-Church", "Object-Parachute", "Object-Garbage_Truck","Style-VanGogh",
                               "Nudity"],
                               label="Unlearning undesirable concepts")
        with gr.Column():
            drop_model = gr.Dropdown(["ESD", "FMN", "SPM"],
                               label="Unlearned DMs")
        with gr.Column():
            # Free-text field; its string value is forwarded to the backend unchanged.
            atk_idx = gr.Textbox(label="attack index")

        with gr.Column():
            shown_columns_step = gr.Slider(
                            0, 100, value=40,
                            step=1, label="Attack Steps", info="Choose between 0 and 100",
                            interactive=True,)

    # Results: original prompt/image on the left, adversarial prompt/image on the right.
    with gr.Row() as attack:
        with gr.Column(min_width=256):
            text_input = gr.Textbox(label="Input Prompt")

            orig_img = gr.Image(label="Image Generated by Input Prompt",width=256,show_share_button=False,show_download_button=False)
        with gr.Column():
            start_button = gr.Button("UnlearnDiffAtk!",size='lg')
        with gr.Column(min_width=256):
            text_ouput = gr.Textbox(label="Prompt Generated by UnlearnDiffAtk")
            result_img = gr.Image(label="Image Generated by Prompt of UnlearnDiffAtk",width=256,show_share_button=False,show_download_button=False)

    # `inputs` order must match excute_udiff(diffusion_model_id, concept, steps,
    # attack_id); `outputs` order must match its (prompt1, prompt2, img1, img2) return.
    start_button.click(fn=excute_udiff, inputs=[drop_model, drop, shown_columns_step, atk_idx], outputs=[text_input, text_ouput, orig_img, result_img], api_name="udiff")


# queue() lets long-running attack requests be processed through Gradio's queue;
# server_name='0.0.0.0' binds the server on all network interfaces.
demo.queue().launch(server_name='0.0.0.0')