Dmitry Trifonov commited on
Commit
5048eb4
·
1 Parent(s): 04f1dbd

use new OctoAI architecture based on OctoAI server

Browse files
Files changed (3) hide show
  1. app.py +5 -6
  2. fair.py +63 -38
  3. requirements.txt +2 -1
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import gradio as gr
2
- from PIL import Image
3
  import numpy as np
4
 
5
  from fair import text_to_image
@@ -7,12 +6,12 @@ from fair import text_to_image
7
  model_id = "runwayml/stable-diffusion-v1-5"
8
  device = "cuda"
9
 
 
10
  def infer(prompt):
11
- image_file = text_to_image(prompt)
12
- image = np.array(Image.open(image_file).convert('RGB'))
13
- images = [image]
14
- return images
15
-
16
 
17
  css = """
18
  .gradio-container {
 
1
  import gradio as gr
 
2
  import numpy as np
3
 
4
  from fair import text_to_image
 
6
  model_id = "runwayml/stable-diffusion-v1-5"
7
  device = "cuda"
8
 
9
+
10
  def infer(prompt):
11
+ image = text_to_image(prompt)
12
+ image = np.array(image.convert('RGB'))
13
+ return [image]
14
+
 
15
 
16
  css = """
17
  .gradio-container {
fair.py CHANGED
@@ -1,16 +1,27 @@
 
1
  import json
 
2
  import os
3
  import time
 
4
  from typing import List, BinaryIO
5
- import logging
 
 
 
6
  logger = logging.getLogger()
7
 
8
  import requests
9
  import tempfile
10
 
11
  SERVER_ADDRESS = "https://faircompute.com:8000/api/v1"
 
 
12
  # SERVER_ADDRESS = "http://localhost:8000/api/v1"
13
- DOCKER_IMAGE = "faircompute/stable-diffusion:pytorch-1.13.1-cu116"
 
 
 
14
 
15
 
16
  class FairApiClient:
@@ -49,7 +60,7 @@ class FairApiClient:
49
 
50
  return response
51
 
52
- def put_job(self, image: str, command: List[str], input_files, output_files):
53
  url = f"{self.server_address}/jobs"
54
  data = {
55
  'type': 'V016',
@@ -57,12 +68,12 @@ class FairApiClient:
57
  'type': 'V016',
58
  'image': image,
59
  'runtime': 'nvidia',
60
- 'ports': [],
61
  'command': command,
62
  },
63
  'input_files': input_files,
64
  'output_files': output_files,
65
- 'target_node': None,
66
  }
67
  response = self.put(url=url, data=data)
68
 
@@ -113,7 +124,39 @@ class FairApiClient:
113
  print(f"Failed to receive {path}")
114
 
115
 
116
- def text_to_image(text):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  # default credentials will work only for local server built in debug mode
118
  email = os.getenv('FAIRCOMPUTE_EMAIL', "debug-email")
119
  password = os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd")
@@ -123,52 +166,34 @@ def text_to_image(text):
123
  job_id = client.put_job(
124
  image=DOCKER_IMAGE,
125
  command=[],
 
126
  input_files=[],
127
- output_files=["/workspace/result.png"])
128
 
129
  logger.info(job_id)
130
 
131
  status = client.get_job_status(job_id=job_id)
132
  logger.info(status)
133
 
134
- while status != "Processing" and status != "Completed":
135
  status = client.get_job_status(job_id=job_id)
136
  logger.info(status)
137
  time.sleep(0.5)
138
 
139
- res = client.put_job_stream_data(
140
- job_id=job_id,
141
- name="stdin",
142
- data=text + "\n")
143
- logger.info(res)
144
 
145
- res = client.put_job_stream_eof(
146
- job_id=job_id,
147
- name="stdin")
148
- logger.info(res)
149
 
150
- status = client.get_job_status(job_id=job_id)
151
- logger.info(status)
 
 
 
152
 
153
- while status == "Processing":
154
- status = client.get_job_status(
155
- job_id=job_id)
156
- logger.info(status)
157
- time.sleep(0.5)
158
- if status == "Completed":
159
- logger.info("Done!")
160
- else:
161
- logger.info("Job Failed")
162
- file = client.wait_for_file(
163
- job_id=job_id,
164
- path="%2Fworkspace%2Fresult.png")
165
- return file
166
 
167
 
168
- if __name__=="__main__":
169
- from PIL import Image
170
- file = text_to_image(text="Robot dinozaur\n")
171
- image = Image.open(file)
172
  image.save("result.png")
173
-
174
-
 
1
+ import base64
2
  import json
3
+ import logging
4
  import os
5
  import time
6
+ from io import BytesIO
7
  from typing import List, BinaryIO
8
+
9
+ from PIL import Image
10
+ from octoai.client import Client as OctoAiClient
11
+
12
  logger = logging.getLogger()
13
 
14
  import requests
15
  import tempfile
16
 
17
  SERVER_ADDRESS = "https://faircompute.com:8000/api/v1"
18
+ ENDPOINT_ADDRESS = "http://dikobraz.mooo.com:5000"
19
+ TARGET_NODE = "119eccba-2388-43c1-bdb9-02133049604c"
20
  # SERVER_ADDRESS = "http://localhost:8000/api/v1"
21
+ # ENDPOINT_ADDRESS = "http://localhost:5000"
22
+ # TARGET_NODE = None
23
+
24
+ DOCKER_IMAGE = "faircompute/diffusion-octo:latest"
25
 
26
 
27
  class FairApiClient:
 
60
 
61
  return response
62
 
63
+ def put_job(self, image: str, command: List[str], ports: List[tuple[int, int]], input_files, output_files):
64
  url = f"{self.server_address}/jobs"
65
  data = {
66
  'type': 'V016',
 
68
  'type': 'V016',
69
  'image': image,
70
  'runtime': 'nvidia',
71
+ 'ports': [[{"port": host_port, "ip": 'null'}, {"port": container_port, "protocol": "Tcp"}] for (host_port, container_port) in ports],
72
  'command': command,
73
  },
74
  'input_files': input_files,
75
  'output_files': output_files,
76
+ 'target_node': TARGET_NODE,
77
  }
78
  response = self.put(url=url, data=data)
79
 
 
124
  print(f"Failed to receive {path}")
125
 
126
 
127
+ class EndpointClient:
128
+ def infer(self, prompt):
129
+ client = OctoAiClient()
130
+
131
+ inputs = {"prompt": {"text": prompt}}
132
+ response = client.infer(endpoint_url=f"{ENDPOINT_ADDRESS}/infer", inputs=inputs)
133
+
134
+ image_b64 = response["output"]["image_b64"]
135
+ image_data = base64.b64decode(image_b64)
136
+ image_data = BytesIO(image_data)
137
+ image = Image.open(image_data)
138
+
139
+ return image
140
+
141
+
142
+ class ServerNotReadyException(Exception):
143
+ pass
144
+
145
+
146
+ def wait_for_server(retries, timeout):
147
+ for i in range(retries):
148
+ try:
149
+ r = requests.get(ENDPOINT_ADDRESS)
150
+ r.raise_for_status()
151
+ return
152
+ except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError, requests.exceptions.Timeout):
153
+ logger.info("Server is not ready yet")
154
+ time.sleep(timeout)
155
+ else:
156
+ raise ServerNotReadyException("Failed to start the server")
157
+
158
+
159
+ def start_server():
160
  # default credentials will work only for local server built in debug mode
161
  email = os.getenv('FAIRCOMPUTE_EMAIL', "debug-email")
162
  password = os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd")
 
166
  job_id = client.put_job(
167
  image=DOCKER_IMAGE,
168
  command=[],
169
+ ports=[(5000, 8080)],
170
  input_files=[],
171
+ output_files=[])
172
 
173
  logger.info(job_id)
174
 
175
  status = client.get_job_status(job_id=job_id)
176
  logger.info(status)
177
 
178
+ while status != "Processing":
179
  status = client.get_job_status(job_id=job_id)
180
  logger.info(status)
181
  time.sleep(0.5)
182
 
183
+ # wait until the server is ready
184
+ wait_for_server(retries=10, timeout=1.0)
 
 
 
185
 
 
 
 
 
186
 
187
+ def text_to_image(text):
188
+ try:
189
+ wait_for_server(retries=1, timeout=0.0)
190
+ except ServerNotReadyException:
191
+ start_server()
192
 
193
+ client = EndpointClient()
194
+ return client.infer(text)
 
 
 
 
 
 
 
 
 
 
 
195
 
196
 
197
+ if __name__ == "__main__":
198
+ image = text_to_image(text="Robot dinozaur\n")
 
 
199
  image.save("result.png")
 
 
requirements.txt CHANGED
@@ -1 +1,2 @@
1
- gradio < 4
 
 
1
+ gradio < 4
2
+ octoai-sdk