rayochoajr commited on
Commit
784e7de
·
verified ·
1 Parent(s): 707dce7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +235 -0
app.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This script demonstrates the use of the websockets API and the SaveImageWebsocket node to retrieve images directly
2
+ # without saving them to disk. It includes a Gradio interface for user interaction.
3
+
4
+ # JSON Comment Section
5
+ """
6
+ {
7
+ "inputs": [
8
+ {
9
+ "type": "text",
10
+ "label": "Text Prompt",
11
+ "name": "text_prompt",
12
+ "default": "masterpiece best quality man"
13
+ },
14
+ {
15
+ "type": "number",
16
+ "label": "Seed",
17
+ "name": "seed",
18
+ "default": 5
19
+ },
20
+ {
21
+ "type": "radio",
22
+ "label": "Server",
23
+ "name": "server",
24
+ "choices": ["AWS Server", "Home Server"],
25
+ "default": "Home Server"
26
+ }
27
+ ],
28
+ "outputs": [
29
+ {
30
+ "type": "image",
31
+ "label": "Generated Image",
32
+ "name": "generated_image"
33
+ }
34
+ ]
35
+ }
36
+ """
37
+
38
+ import gradio as gr
39
+ import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
40
+ import uuid
41
+ import json
42
+ import urllib.request
43
+ import urllib.parse
44
+ from PIL import Image
45
+ import io
46
+
47
+ # Generate a unique client ID for the session
48
+ client_id = str(uuid.uuid4())
49
+
50
+ # Function to queue a prompt to the server
51
+ def queue_prompt(prompt, server_address):
52
+ p = {"prompt": prompt, "client_id": client_id}
53
+ data = json.dumps(p).encode('utf-8')
54
+ req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
55
+ return json.loads(urllib.request.urlopen(req).read())
56
+
57
+ # Function to retrieve an image from the server
58
+ def get_image(filename, subfolder, folder_type, server_address):
59
+ data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
60
+ url_values = urllib.parse.urlencode(data)
61
+ with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response:
62
+ return response.read()
63
+
64
+ # Function to retrieve the history of a prompt from the server
65
+ def get_history(prompt_id, server_address):
66
+ with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
67
+ return json.loads(response.read())
68
+
69
+ # Function to get images from the websocket connection
70
+ def get_images(ws, prompt, server_address):
71
+ prompt_id = queue_prompt(prompt, server_address)['prompt_id']
72
+ output_images = {}
73
+ current_node = ""
74
+ while True:
75
+ out = ws.recv()
76
+ if isinstance(out, str):
77
+ message = json.loads(out)
78
+ if message['type'] == 'executing':
79
+ data = message['data']
80
+ if data['prompt_id'] == prompt_id:
81
+ if data['node'] is None:
82
+ break # Execution is done
83
+ else:
84
+ current_node = data['node']
85
+ else:
86
+ if current_node == 'save_image_websocket_node':
87
+ images_output = output_images.get(current_node, [])
88
+ images_output.append(out[8:])
89
+ output_images[current_node] = images_output
90
+
91
+ return output_images
92
+
93
+ # Function to generate an image based on the provided text prompt, seed, and server
94
+ def generate_image(text_prompt, seed, server):
95
+ prompt_text = """
96
+ {
97
+ "3": {
98
+ "class_type": "KSampler",
99
+ "inputs": {
100
+ "cfg": 8,
101
+ "denoise": 1,
102
+ "latent_image": [
103
+ "5",
104
+ 0
105
+ ],
106
+ "model": [
107
+ "4",
108
+ 0
109
+ ],
110
+ "negative": [
111
+ "7",
112
+ 0
113
+ ],
114
+ "positive": [
115
+ "6",
116
+ 0
117
+ ],
118
+ "sampler_name": "euler",
119
+ "scheduler": "normal",
120
+ "seed": 8566257,
121
+ "steps": 20
122
+ }
123
+ },
124
+ "4": {
125
+ "class_type": "CheckpointLoaderSimple",
126
+ "inputs": {
127
+ "ckpt_name": "v1-5-pruned-emaonly.ckpt"
128
+ }
129
+ },
130
+ "5": {
131
+ "class_type": "EmptyLatentImage",
132
+ "inputs": {
133
+ "batch_size": 1,
134
+ "height": 512,
135
+ "width": 768
136
+ }
137
+ },
138
+ "6": {
139
+ "class_type": "CLIPTextEncode",
140
+ "inputs": {
141
+ "clip": [
142
+ "4",
143
+ 1
144
+ ],
145
+ "text": "masterpiece best quality girl"
146
+ }
147
+ },
148
+ "7": {
149
+ "class_type": "CLIPTextEncode",
150
+ "inputs": {
151
+ "clip": [
152
+ "4",
153
+ 1
154
+ ],
155
+ "text": "bad hands"
156
+ }
157
+ },
158
+ "8": {
159
+ "class_type": "VAEDecode",
160
+ "inputs": {
161
+ "samples": [
162
+ "3",
163
+ 0
164
+ ],
165
+ "vae": [
166
+ "4",
167
+ 2
168
+ ]
169
+ }
170
+ },
171
+ "save_image_websocket_node": {
172
+ "class_type": "SaveImageWebsocket",
173
+ "inputs": {
174
+ "images": [
175
+ "8",
176
+ 0
177
+ ]
178
+ }
179
+ }
180
+ }
181
+ """
182
+
183
+ prompt = json.loads(prompt_text)
184
+ # Set the text prompt for our positive CLIPTextEncode
185
+ prompt["6"]["inputs"]["text"] = text_prompt
186
+
187
+ # Set the seed for our KSampler node
188
+ prompt["3"]["inputs"]["seed"] = seed
189
+
190
+ # Determine the server address based on the selected server
191
+ server_address = "3.14.144.23:8188" if server == "AWS Server" else "192.168.50.136:8188"
192
+
193
+ # Establish a websocket connection
194
+ ws = websocket.WebSocket()
195
+ ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
196
+ images = get_images(ws, prompt, server_address)
197
+
198
+ # Get the first image from the output
199
+ image = None
200
+
201
+ for node_id in images:
202
+ for image_data in images[node_id]:
203
+ image = Image.open(io.BytesIO(image_data))
204
+ break
205
+ if image:
206
+ break
207
+
208
+ return image
209
+
210
+ # Function to handle cancel request
211
+ def cancel_request():
212
+ return "Request Cancelled"
213
+
214
+ # Gradio Interface using Blocks
215
+ with gr.Blocks() as demo:
216
+ gr.Markdown("# Image Generation with Websockets API")
217
+ gr.Markdown("Generate images using a Websockets API and SaveImageWebsocket node.")
218
+
219
+ with gr.Row():
220
+ with gr.Column():
221
+ text_prompt = gr.Textbox(label="Text Prompt", value="masterpiece best quality man")
222
+ seed = gr.Number(label="Seed", value=5)
223
+ server = gr.Radio(label="Server", choices=["AWS Server", "Home Server"], value="Home Server")
224
+ generate_button = gr.Button("Generate Image")
225
+ cancel_button = gr.Button("Cancel Request")
226
+
227
+ with gr.Column():
228
+ output_image = gr.Image(label="Generated Image")
229
+
230
+ # Set up button click events
231
+ generate_button.click(fn=generate_image, inputs=[text_prompt, seed, server], outputs=output_image)
232
+ cancel_button.click(fn=cancel_request, inputs=[], outputs=[])
233
+
234
+ # Launch the Gradio interface
235
+ demo.launch()