Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Commit
·
63a6e70
1
Parent(s):
38d05ac
hybrid-backend (#256)
Browse files- Swap to hybrid backend (53066e3cf8ddbe2b02bee8a96c1feee01d99a01d)
Co-authored-by: Apolinario <[email protected]>
app.py
CHANGED
|
@@ -5,7 +5,11 @@ from torch import autocast
|
|
| 5 |
from diffusers import StableDiffusionPipeline
|
| 6 |
from datasets import load_dataset
|
| 7 |
from PIL import Image
|
|
|
|
|
|
|
| 8 |
import re
|
|
|
|
|
|
|
| 9 |
|
| 10 |
from share_btn import community_icon_html, loading_icon_html, share_js
|
| 11 |
|
|
@@ -21,27 +25,44 @@ torch.backends.cudnn.benchmark = True
|
|
| 21 |
word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
|
| 22 |
word_list = word_list_dataset["train"]['text']
|
| 23 |
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
#When running locally you can also remove this filter
|
| 26 |
for filter in word_list:
|
| 27 |
if re.search(rf"\b{filter}\b", prompt):
|
| 28 |
raise gr.Error("Unsafe content found. Please try again with different prompts.")
|
| 29 |
|
| 30 |
-
generator = torch.Generator(device=device).manual_seed(seed)
|
| 31 |
-
|
| 32 |
-
images_list = pipe(
|
| 33 |
-
[prompt] * samples,
|
| 34 |
-
num_inference_steps=steps,
|
| 35 |
-
guidance_scale=scale,
|
| 36 |
-
generator=generator,
|
| 37 |
-
)
|
| 38 |
images = []
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
return images, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
|
| 46 |
|
| 47 |
|
|
@@ -298,6 +319,7 @@ with block:
|
|
| 298 |
share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
|
| 299 |
|
| 300 |
with gr.Row(elem_id="advanced-options"):
|
|
|
|
| 301 |
samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
|
| 302 |
steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
|
| 303 |
scale = gr.Slider(
|
|
@@ -311,13 +333,13 @@ with block:
|
|
| 311 |
randomize=True,
|
| 312 |
)
|
| 313 |
|
| 314 |
-
ex = gr.Examples(examples=examples, fn=infer, inputs=
|
| 315 |
ex.dataset.headers = [""]
|
| 316 |
|
| 317 |
|
| 318 |
-
text.submit(infer, inputs=
|
| 319 |
|
| 320 |
-
btn.click(infer, inputs=
|
| 321 |
|
| 322 |
advanced_button.click(
|
| 323 |
None,
|
|
@@ -350,4 +372,4 @@ Despite how impressive being able to turn text into image is, beware to the fact
|
|
| 350 |
"""
|
| 351 |
)
|
| 352 |
|
| 353 |
-
block.queue(max_size=25).launch()
|
|
|
|
| 5 |
from diffusers import StableDiffusionPipeline
|
| 6 |
from datasets import load_dataset
|
| 7 |
from PIL import Image
|
| 8 |
+
from io import BytesIO
|
| 9 |
+
import base64
|
| 10 |
import re
|
| 11 |
+
import os
|
| 12 |
+
import requests
|
| 13 |
|
| 14 |
from share_btn import community_icon_html, loading_icon_html, share_js
|
| 15 |
|
|
|
|
| 25 |
word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
|
| 26 |
word_list = word_list_dataset["train"]['text']
|
| 27 |
|
| 28 |
+
is_gpu_busy = False
|
| 29 |
+
def infer(prompt):
|
| 30 |
+
global is_gpu_busy
|
| 31 |
+
samples = 4
|
| 32 |
+
steps = 50
|
| 33 |
+
scale = 7.5
|
| 34 |
#When running locally you can also remove this filter
|
| 35 |
for filter in word_list:
|
| 36 |
if re.search(rf"\b{filter}\b", prompt):
|
| 37 |
raise gr.Error("Unsafe content found. Please try again with different prompts.")
|
| 38 |
|
| 39 |
+
#generator = torch.Generator(device=device).manual_seed(seed)
|
| 40 |
+
print("Is GPU busy? ", is_gpu_busy)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
images = []
|
| 42 |
+
if(not is_gpu_busy):
|
| 43 |
+
is_gpu_busy = True
|
| 44 |
+
images_list = pipe(
|
| 45 |
+
[prompt] * samples,
|
| 46 |
+
num_inference_steps=steps,
|
| 47 |
+
guidance_scale=scale,
|
| 48 |
+
#generator=generator,
|
| 49 |
+
)
|
| 50 |
+
is_gpu_busy = False
|
| 51 |
+
safe_image = Image.open(r"unsafe.png")
|
| 52 |
+
for i, image in enumerate(images_list["sample"]):
|
| 53 |
+
if(images_list["nsfw_content_detected"][i]):
|
| 54 |
+
images.append(safe_image)
|
| 55 |
+
else:
|
| 56 |
+
images.append(image)
|
| 57 |
+
else:
|
| 58 |
+
url = os.getenv('JAX_BACKEND_URL')
|
| 59 |
+
payload = {'prompt': prompt}
|
| 60 |
+
images_request = requests.post(url, json = payload)
|
| 61 |
+
for image in images_request.json()["images"]:
|
| 62 |
+
image_decoded = Image.open(BytesIO(base64.b64decode(image)))
|
| 63 |
+
images.append(image_decoded)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
return images, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
|
| 67 |
|
| 68 |
|
|
|
|
| 319 |
share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
|
| 320 |
|
| 321 |
with gr.Row(elem_id="advanced-options"):
|
| 322 |
+
gr.Markdown("Advanced settings are temporarily unavailable")
|
| 323 |
samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
|
| 324 |
steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
|
| 325 |
scale = gr.Slider(
|
|
|
|
| 333 |
randomize=True,
|
| 334 |
)
|
| 335 |
|
| 336 |
+
ex = gr.Examples(examples=examples, fn=infer, inputs=text, outputs=[gallery, community_icon, loading_icon, share_button], cache_examples=True)
|
| 337 |
ex.dataset.headers = [""]
|
| 338 |
|
| 339 |
|
| 340 |
+
text.submit(infer, inputs=text, outputs=[gallery, community_icon, loading_icon, share_button])
|
| 341 |
|
| 342 |
+
btn.click(infer, inputs=text, outputs=[gallery, community_icon, loading_icon, share_button])
|
| 343 |
|
| 344 |
advanced_button.click(
|
| 345 |
None,
|
|
|
|
| 372 |
"""
|
| 373 |
)
|
| 374 |
|
| 375 |
+
block.queue(max_size=25, concurrency_count=2).launch()
|