trifonova commited on
Commit
a01e50d
·
1 Parent(s): b60abf4

Add inference with fair compute api

Browse files
Files changed (3) hide show
  1. app.py +57 -58
  2. fair.py +254 -0
  3. requirements.txt +1 -1
app.py CHANGED
@@ -1,36 +1,35 @@
1
  import gradio as gr
2
- from datasets import load_dataset
3
  from PIL import Image
4
  import re
5
  import os
6
  import requests
 
 
 
7
 
8
  from share_btn import community_icon_html, loading_icon_html, share_js
9
 
10
  model_id = "runwayml/stable-diffusion-v1-5"
11
  device = "cuda"
12
 
13
- word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
14
- word_list = word_list_dataset["train"]['text']
15
 
16
- is_gpu_busy = False
17
  def infer(prompt):
18
- global is_gpu_busy
19
  samples = 4
20
  steps = 50
21
  scale = 7.5
22
- for filter in word_list:
23
- if re.search(rf"\b{filter}\b", prompt):
24
- raise gr.Error("Unsafe content found. Please try again with different prompts.")
25
-
26
  images = []
27
- url = os.getenv('JAX_BACKEND_URL')
28
- payload = {'prompt': prompt}
29
- images_request = requests.post(url, json = payload)
30
- for image in images_request.json()["images"]:
31
- image_b64 = (f"data:image/jpeg;base64,{image}")
32
- images.append(image_b64)
33
-
34
  return images
35
 
36
 
@@ -239,55 +238,55 @@ with block:
239
  rounded=(False, True, True, False),
240
  full_width=False,
241
  )
242
-
243
  gallery = gr.Gallery(
244
  label="Generated images", show_label=False, elem_id="gallery"
245
  ).style(grid=[2], height="auto")
246
 
247
- with gr.Group(elem_id="container-advanced-btns"):
248
- advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
249
- with gr.Group(elem_id="share-btn-container"):
250
- community_icon = gr.HTML(community_icon_html)
251
- loading_icon = gr.HTML(loading_icon_html)
252
- share_button = gr.Button("Share to community", elem_id="share-btn")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
- with gr.Row(elem_id="advanced-options"):
255
- gr.Markdown("Advanced settings are temporarily unavailable")
256
- samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
257
- steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
258
- scale = gr.Slider(
259
- label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
260
- )
261
- seed = gr.Slider(
262
- label="Seed",
263
- minimum=0,
264
- maximum=2147483647,
265
- step=1,
266
- randomize=True,
267
- )
268
 
269
- ex = gr.Examples(examples=examples, fn=infer, inputs=text, outputs=[gallery], cache_examples=True, postprocess=False)
270
- ex.dataset.headers = [""]
271
-
272
- text.submit(infer, inputs=text, outputs=[gallery], postprocess=False)
273
- btn.click(infer, inputs=text, outputs=[gallery], postprocess=False)
274
 
275
- advanced_button.click(
276
- None,
277
- [],
278
- text,
279
- _js="""
280
- () => {
281
- const options = document.querySelector("body > gradio-app").querySelector("#advanced-options");
282
- options.style.display = ["none", ""].includes(options.style.display) ? "flex" : "none";
283
- }""",
284
- )
285
- share_button.click(
286
- None,
287
- [],
288
- [],
289
- _js=share_js,
290
- )
291
  gr.HTML(
292
  """
293
  <div class="footer">
 
1
  import gradio as gr
2
+ #from datasets import load_dataset
3
  from PIL import Image
4
  import re
5
  import os
6
  import requests
7
+ import numpy as np
8
+
9
+ from fair import text_to_image
10
 
11
  from share_btn import community_icon_html, loading_icon_html, share_js
12
 
13
  model_id = "runwayml/stable-diffusion-v1-5"
14
  device = "cuda"
15
 
16
+ #word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
17
+ #word_list = word_list_dataset["train"]['text']
18
 
19
+ #is_gpu_busy = False
20
  def infer(prompt):
21
+ #global is_gpu_busy
22
  samples = 4
23
  steps = 50
24
  scale = 7.5
25
+ # for filter in word_list:
26
+ # if re.search(rf"\b{filter}\b", prompt):
27
+ # raise gr.Error("Unsafe content found. Please try again with different prompts.")
28
+ #
29
  images = []
30
+ image = text_to_image(prompt)
31
+ image = np.array(Image.open(image).convert('RGB'))
32
+ images = [image, image, image, image]
 
 
 
 
33
  return images
34
 
35
 
 
238
  rounded=(False, True, True, False),
239
  full_width=False,
240
  )
241
+ # gallery = gr.Image(type="filepath").style(grid=[2], height="auto")
242
  gallery = gr.Gallery(
243
  label="Generated images", show_label=False, elem_id="gallery"
244
  ).style(grid=[2], height="auto")
245
 
246
+ # with gr.Group(elem_id="container-advanced-btns"):
247
+ # advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
248
+ # with gr.Group(elem_id="share-btn-container"):
249
+ # community_icon = gr.HTML(community_icon_html)
250
+ # loading_icon = gr.HTML(loading_icon_html)
251
+ # share_button = gr.Button("Share to community", elem_id="share-btn")
252
+ #
253
+ # with gr.Row(elem_id="advanced-options"):
254
+ # gr.Markdown("Advanced settings are temporarily unavailable")
255
+ # samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
256
+ # steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
257
+ # scale = gr.Slider(
258
+ # label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
259
+ # )
260
+ # seed = gr.Slider(
261
+ # label="Seed",
262
+ # minimum=0,
263
+ # maximum=2147483647,
264
+ # step=1,
265
+ # randomize=True,
266
+ # )
267
 
268
+ #ex = gr.Examples(examples=examples, fn=infer, inputs=text, outputs=[gallery], cache_examples=True, postprocess=False)
269
+ #ex.dataset.headers = [""]
 
 
 
 
 
 
 
 
 
 
 
 
270
 
271
+ text.submit(infer, inputs=text, outputs=[gallery])
272
+ btn.click(infer, inputs=text, outputs=[gallery])
 
 
 
273
 
274
+ # advanced_button.click(
275
+ # None,
276
+ # [],
277
+ # text,
278
+ # _js="""
279
+ # () => {
280
+ # const options = document.querySelector("body > gradio-app").querySelector("#advanced-options");
281
+ # options.style.display = ["none", ""].includes(options.style.display) ? "flex" : "none";
282
+ # }""",
283
+ # )
284
+ # share_button.click(
285
+ # None,
286
+ # [],
287
+ # [],
288
+ # _js=share_js,
289
+ # )
290
  gr.HTML(
291
  """
292
  <div class="footer">
fair.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import time
4
+ from typing import List
5
+ import logging
6
+ logger = logging.getLogger()
7
+
8
+ import requests
9
+
10
+ #SERVER_ADRESS="https://faircompute.com:8000/api/v1"
11
+ SERVER_ADRESS="http://localhost:8000/api/v1"
12
+ DOCKER_IMAGE="faircompute/stable-diffusion:pytorch-1.13.1-cu116"
13
+ #DOCKER_IMAGE="sha256:e06453fe869556ea3e63572a935aed4261337b261fdf7bda370472b0587409a9"
14
+
15
+ def authenticate(email: str, password: str):
16
+ url = f'{SERVER_ADRESS}/auth/login'
17
+ json_obj = {"email": email, "password": password}
18
+ resp = requests.post(url, json=json_obj)
19
+ token = resp.json()["token"]
20
+ return token
21
+
22
+ def get(url, token, **kwargs):
23
+ headers = {
24
+ 'Authorization': f'Bearer {token}'
25
+ }
26
+ response = requests.get(url, headers=headers, **kwargs)
27
+
28
+ if not response.ok:
29
+ raise Exception(f"Error! status: {response.status_code}")
30
+ return response
31
+
32
+
33
+ def put(url, token, data):
34
+ headers = {
35
+ 'Content-Type': 'application/json',
36
+ 'Authorization': f'Bearer {token}'
37
+ }
38
+ if not isinstance(data, str):
39
+ data = json.dumps(data)
40
+ response = requests.put(url, headers=headers, data=data)
41
+
42
+ if not response.ok and response.status_code != 206:
43
+ raise Exception(f"Error! status: {response.status_code}")
44
+ return response
45
+
46
+
47
+ def put_program(token, launcher: str, image: str, runtime: str, command: List[str]):
48
+ url = f"{SERVER_ADRESS}/programs"
49
+ data = {
50
+ launcher: {
51
+ "image": image,
52
+ "command": command,
53
+ "runtime": runtime
54
+ }
55
+ }
56
+ response = put(url=url, token=token, data=data)
57
+
58
+ return int(response.text)
59
+
60
+
61
+ def put_job(token, program_id, input_files, output_files):
62
+ url = f"{SERVER_ADRESS}/jobs?program={program_id}"
63
+ data = {
64
+ 'input_files': input_files,
65
+ 'output_files': output_files
66
+ }
67
+
68
+ response = put(url=url, token=token, data=data)
69
+
70
+ return int(response.text)
71
+
72
+
73
+ def get_job_status(token, job_id):
74
+ url = f"{SERVER_ADRESS}/jobs/{job_id}/status"
75
+ response = get(url=url, token=token)
76
+ return response.text
77
+
78
+
79
+ def get_cluster_summary(token):
80
+ url = f"{SERVER_ADRESS}/nodes/summary"
81
+
82
+ response = get(token=token, url=url)
83
+
84
+ return response.json()
85
+
86
+
87
+ def put_job_stream_data(token, job_id, name, data):
88
+ url = f"{SERVER_ADRESS}/jobs/{job_id}/data/streams/{name}"
89
+ response = put(url=url, token=token, data=data)
90
+
91
+ return response.text
92
+
93
+
94
+ def put_job_stream_eof(token, job_id, name):
95
+ url = f"{SERVER_ADRESS}/jobs/{job_id}/data/streams/{name}/eof"
96
+
97
+ response = put(url=url, token=token, data=None)
98
+
99
+ return response.text
100
+
101
+
102
+ def wait_for_file(token, job_id, path, local_path, attempts=10):
103
+ headers = {
104
+ 'Authorization': f'Bearer {token}'
105
+ }
106
+ for i in range(attempts):
107
+ url = f"{SERVER_ADRESS}/jobs/{job_id}/data/files/{path}"
108
+ print(f"Waiting for file {path}...")
109
+ try:
110
+ with requests.get(url=url, headers=headers, stream=True) as r:
111
+ r.raise_for_status()
112
+ with open(local_path, 'wb') as f:
113
+ for chunk in r.iter_content(chunk_size=8192):
114
+ f.write(chunk)
115
+
116
+ print(f"File {local_path} ready")
117
+ return local_path
118
+ except Exception as e:
119
+ print(e)
120
+ time.sleep(0.5)
121
+
122
+ print(f"Failed to receive {local_path}")
123
+
124
+
125
+ def text_to_image(text):
126
+ email = os.getenv('FAIRCOMPUTE_EMAIL')
127
+ password = os.environ.get('FAIRCOMPUTE_PASSWORD')
128
+ token = authenticate(email=email, password=password)
129
+
130
+ logger.info(token)
131
+
132
+ summary = get_cluster_summary(token=token)
133
+ logger.info("Summary:")
134
+ logger.info(summary)
135
+ program_id = put_program(token=token,
136
+ launcher="Docker",
137
+ image=DOCKER_IMAGE,
138
+ runtime="nvidia",
139
+ command=[])
140
+ logger.info(program_id)
141
+
142
+ job_id = put_job(token=token,
143
+ program_id=program_id,
144
+ input_files=[],
145
+ output_files=["/workspace/result.png"])
146
+
147
+ logger.info(job_id)
148
+
149
+ status = get_job_status(token=token,
150
+ job_id=job_id)
151
+ logger.info(status)
152
+
153
+ while status != "Processing" and status != "Completed":
154
+ status = get_job_status(token=token,
155
+ job_id=job_id)
156
+ logger.info(status)
157
+ time.sleep(0.5)
158
+
159
+ res = put_job_stream_data(token=token,
160
+ job_id=job_id,
161
+ name="stdin",
162
+ data=text + "\n")
163
+ logger.info(res)
164
+
165
+ res = put_job_stream_eof(token=token,
166
+ job_id=job_id,
167
+ name="stdin")
168
+ logger.info(res)
169
+
170
+ status = get_job_status(token=token,
171
+ job_id=job_id)
172
+ logger.info(status)
173
+
174
+ while status == "Processing":
175
+ status = get_job_status(token=token,
176
+ job_id=job_id)
177
+ logger.info(status)
178
+ time.sleep(0.5)
179
+ if status == "Completed":
180
+ logger.info("Done!")
181
+ else:
182
+ logger.info("Job Failed")
183
+ resp = wait_for_file(token=token,
184
+ job_id=job_id,
185
+ path="%2Fworkspace%2Fresult.png",
186
+ local_path="result.png")
187
+ logger.info(resp)
188
+ return resp
189
+
190
+
191
+ if __name__=="__main__":
192
+ email = os.getenv('FAIRCOMPUTE_EMAIL')
193
+ password = os.environ.get('FAIRCOMPUTE_PASSWORD')
194
+ token = authenticate(email=email, password=password)
195
+
196
+ print(token)
197
+
198
+ summary = get_cluster_summary(token=token)
199
+ print("Summary:")
200
+ print(summary)
201
+ program_id = put_program(token=token,
202
+ launcher="Docker",
203
+ image=DOCKER_IMAGE,
204
+ runtime="nvidia",
205
+ command=[])
206
+ print(program_id)
207
+
208
+ job_id = put_job(token=token,
209
+ program_id=program_id,
210
+ input_files=[],
211
+ output_files=["/workspace/result.png"])
212
+
213
+ print(job_id)
214
+
215
+ status = get_job_status(token=token,
216
+ job_id=job_id)
217
+ print(status)
218
+
219
+ while status != "Processing" and status != "Completed":
220
+ status = get_job_status(token=token,
221
+ job_id=job_id)
222
+ print(status)
223
+ time.sleep(0.5)
224
+
225
+ res = put_job_stream_data(token=token,
226
+ job_id=job_id,
227
+ name="stdin",
228
+ data="Robot dinozaur\n")
229
+ print(res)
230
+
231
+ res = put_job_stream_eof(token=token,
232
+ job_id=job_id,
233
+ name="stdin")
234
+ print(res)
235
+
236
+ status = get_job_status(token=token,
237
+ job_id=job_id)
238
+ print(status)
239
+
240
+ while status == "Processing":
241
+ status = get_job_status(token=token,
242
+ job_id=job_id)
243
+ print(status)
244
+ time.sleep(0.5)
245
+ if status == "Completed":
246
+ print("Done!")
247
+ else:
248
+ print("Job Failed")
249
+ resp = wait_for_file(token=token,
250
+ job_id=job_id,
251
+ path="%2Fworkspace%2Fresult.png",
252
+ local_path="result.png")
253
+ print(resp)
254
+
requirements.txt CHANGED
@@ -1 +1 @@
1
- python-dotenv
 
1
+ gradio < 4