Dmitry Trifonov committed
Commit 6c7a4d8 · 1 Parent(s): 2629a94

update demo to SDXL Turbo
Files changed (3):
  1. app.py (+6 -12)
  2. requirements.txt (+0 -1)
  3. text_to_image.py (+32 -27)
app.py CHANGED
@@ -150,13 +150,13 @@ with block:
           "
         >
           <h1 style="font-weight: 900; margin-bottom: 7px;">
-            Fair Compute Demo
+            SDXL Turbo by Fair Compute
           </h1>
         </div>
         <p style="margin-bottom: 10px; font-size: 94%">
-          A demo of a popular Stable Diffusion neural network that runs on regular computers.
-          Run AI models on your computer and avoid paying egregious cloud fees
-          This demo is powered by the <a href="https://faircompute.com/">FairCompute</a> platform.
+          SDXL Turbo model for generating high-quality images.<br/>
+          Model: <a href="https://huggingface.co/stabilityai/sdxl-turbo">https://huggingface.co/stabilityai/sdxl-turbo</a><br/>
+          <br/>
         </p>
       </div>
       """
@@ -191,15 +191,9 @@ with block:
 
     gr.HTML(
         """
-        <div class="footer">
-            <p>Model by <a href="https://huggingface.co/CompVis" style="text-decoration: underline;" target="_blank">CompVis</a> and <a href="https://runwayml.com/" style="text-decoration: underline;" target="_blank">Runway</a> powered by <a href="https://faircompute.com/" style="text-decoration: underline;" target="_blank">FairCompute</a> platform
-            </p>
-        </div>
         <div class="acknowledgments">
-            <p><h4>LICENSE</h4>
-            The model is licensed with a <a href="https://huggingface.co/spaces/CompVis/stable-diffusion-license" style="text-decoration: underline;" target="_blank">CreativeML Open RAIL-M</a> license. The authors claim no rights on the outputs you generate, you are free to use them and are accountable for their use which must not go against the provisions set in this license. The license forbids you from sharing any content that violates any laws, produce any harm to a person, disseminate any personal information that would be meant for harm, spread misinformation and target vulnerable groups. For the full list of restrictions please <a href="https://huggingface.co/spaces/CompVis/stable-diffusion-license" style="text-decoration: underline;" target="_blank">read the license</a></p>
-            <p><h4>Biases and content acknowledgment</h4>
-            Despite how impressive being able to turn text into image is, beware to the fact that this model may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography and violence. The model was trained on the <a href="https://laion.ai/blog/laion-5b/" style="text-decoration: underline;" target="_blank">LAION-5B dataset</a>, which scraped non-curated image-text-pairs from the internet (the exception being the removal of illegal content) and is meant for research purposes. You can read more in the <a href="https://huggingface.co/runwayml/stable-diffusion-v1-5" style="text-decoration: underline;" target="_blank">model card</a></p>
+            Run AI models on your computer and avoid paying egregious cloud fees.
+            This demo is powered by the <a href="https://faircompute.com/">FairCompute</a> platform.
         </div>
         """
     )
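For context, the HTML edited above sits inside a Gradio Blocks app (requirements.txt pins gradio < 4). A minimal sketch of how the demo presumably wires the updated header and footer around `text_to_image`; the component names and layout here are illustrative assumptions, not taken from app.py:

```python
import gradio as gr                       # requirements.txt pins gradio < 4
from text_to_image import text_to_image   # returns a PIL.Image

# Hypothetical wiring; the real app.py defines `block` with its own CSS and layout.
with gr.Blocks() as block:
    gr.HTML("...")                                   # header HTML updated in this commit
    prompt = gr.Textbox(label="Prompt")
    output = gr.Image(label="Generated image", type="pil")
    generate = gr.Button("Generate")
    generate.click(fn=text_to_image, inputs=prompt, outputs=output)
    gr.HTML("...")                                   # acknowledgments HTML updated in this commit

if __name__ == "__main__":
    block.launch()
```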
requirements.txt CHANGED
@@ -1,4 +1,3 @@
 gradio < 4
-octoai-sdk
 faircompute==0.20.0
 httpx==0.24.1
text_to_image.py CHANGED
@@ -1,17 +1,15 @@
 import base64
 import logging
 import os
+import requests
 import time
 from io import BytesIO
 
 from PIL import Image
-from octoai.client import Client as OctoAiClient
 from fair import FairClient
 
 logger = logging.getLogger()
 
-import requests
-
 SERVER_ADDRESS = "https://faircompute.com:8000"
 ENDPOINT_ADDRESS = "http://dikobraz.mooo.com:5000"
 TARGET_NODE = "119eccba-2388-43c1-bdb9-02133049604c"
@@ -19,19 +17,32 @@ TARGET_NODE = "119eccba-2388-43c1-bdb9-02133049604c"
 # ENDPOINT_ADDRESS = "http://localhost:5000"
 # TARGET_NODE = "ef09913249aa40ecba7d0097f7622855"
 
-DOCKER_IMAGE = "faircompute/diffusion-octo:v1"
+DOCKER_IMAGE = "faircompute/diffusers-api-sdxl-turbo"
 
 
 class EndpointClient:
-    def infer(self, prompt):
-        client = OctoAiClient()
+    def __init__(self, timeout):
+        response = requests.get(os.path.join(ENDPOINT_ADDRESS, 'healthcheck'), timeout=timeout).json()
+        if response['state'] != 'healthy':
+            raise Exception("Server is not healthy")
 
-        inputs = {"prompt": {"text": prompt}}
-        response = client.infer(endpoint_url=f"{ENDPOINT_ADDRESS}/infer", inputs=inputs)
-
-        image_b64 = response["output"]["image_b64"]
-        image_data = base64.b64decode(image_b64)
-        image_data = BytesIO(image_data)
+    def infer(self, prompt):
+        inputs = {
+            "modelInputs": {
+                "prompt": prompt,
+                "num_inference_steps": 4,
+                "guidance_scale": 0.0,
+                "width": 512,
+                "height": 512,
+            },
+            "callInputs": {
+                "MODEL_ID": "stabilityai/sdxl-turbo",
+                "PIPELINE": "AutoPipelineForText2Image",
+            },
+        }
+
+        response = requests.post(ENDPOINT_ADDRESS, json=inputs).json()
+        image_data = BytesIO(base64.b64decode(response["image_base64"]))
         image = Image.open(image_data)
 
         return image
@@ -41,18 +52,15 @@ class ServerNotReadyException(Exception):
     pass
 
 
-def wait_for_server(retries, timeout, delay=1.0):
+def wait_for_server(retries, timeout=1.0, delay=2.0):
     for i in range(retries):
         try:
-            r = requests.get(ENDPOINT_ADDRESS, timeout=timeout)
-            r.raise_for_status()
-            return
+            return EndpointClient(timeout=timeout)
         except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError, requests.exceptions.Timeout) as e:
-            if i == retries - 1:
-                raise ServerNotReadyException("Failed to start the server") from e
-            else:
-                logger.info("Server is not ready yet")
-                time.sleep(delay)
+            logging.exception(e)
+            time.sleep(delay)
+
+    raise ServerNotReadyException("Failed to start the server")
 
 
 def start_server():
@@ -64,20 +72,17 @@ def start_server():
     client.run(node=TARGET_NODE,
                image=DOCKER_IMAGE,
                runtime="nvidia",
-               ports=[(5000, 8080)],
+               ports=[(5000, 8000)],
                detach=True)
 
-    # wait until the server is ready
-    wait_for_server(retries=10, timeout=1.0)
-
 
 def text_to_image(text):
     try:
-        wait_for_server(retries=1, timeout=1.0, delay=0.0)
+        client = wait_for_server(retries=1)
     except ServerNotReadyException:
         start_server()
+        client = wait_for_server(retries=10)
 
-    client = EndpointClient()
     return client.infer(text)
 
 
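For reference, the updated client assumes the request/response contract of the faircompute/diffusers-api-sdxl-turbo container: a /healthcheck endpoint that reports its state, and a POST to the endpoint root that takes modelInputs/callInputs and returns a base64-encoded image. A standalone sketch of that contract, mirroring the fields used in EndpointClient above; whether the demo endpoint is reachable from outside the FairCompute node is an assumption:

```python
# Standalone sketch of the HTTP contract EndpointClient relies on.
# Field names ("state", "image_base64", modelInputs/callInputs) mirror text_to_image.py;
# the endpoint address is the same demo node and may not be reachable from your machine.
import base64
from io import BytesIO

import requests
from PIL import Image

ENDPOINT_ADDRESS = "http://dikobraz.mooo.com:5000"

# 1. The container is considered ready once /healthcheck reports state == "healthy".
health = requests.get(f"{ENDPOINT_ADDRESS}/healthcheck", timeout=1.0).json()
assert health["state"] == "healthy", health

# 2. SDXL Turbo runs with few steps and no classifier-free guidance (guidance_scale=0.0).
payload = {
    "modelInputs": {
        "prompt": "a photo of an astronaut riding a horse",
        "num_inference_steps": 4,
        "guidance_scale": 0.0,
        "width": 512,
        "height": 512,
    },
    "callInputs": {
        "MODEL_ID": "stabilityai/sdxl-turbo",
        "PIPELINE": "AutoPipelineForText2Image",
    },
}
response = requests.post(ENDPOINT_ADDRESS, json=payload).json()

# 3. The image comes back base64-encoded, decoded the same way EndpointClient.infer does.
image = Image.open(BytesIO(base64.b64decode(response["image_base64"])))
image.save("sdxl_turbo_sample.png")
```

In the demo itself, the same flow is reached through text_to_image(), which first tries to connect to a running endpoint and otherwise starts the container on the target node via FairClient before retrying.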