Spaces:
Sleeping
Sleeping
Dmitry Trifonov
commited on
Commit
·
5048eb4
1
Parent(s):
04f1dbd
use new OctoAI architecture based on OctoAI server
Browse files- app.py +5 -6
- fair.py +63 -38
- requirements.txt +2 -1
app.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
import gradio as gr
|
2 |
-
from PIL import Image
|
3 |
import numpy as np
|
4 |
|
5 |
from fair import text_to_image
|
@@ -7,12 +6,12 @@ from fair import text_to_image
|
|
7 |
model_id = "runwayml/stable-diffusion-v1-5"
|
8 |
device = "cuda"
|
9 |
|
|
|
10 |
def infer(prompt):
|
11 |
-
|
12 |
-
image = np.array(
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
|
17 |
css = """
|
18 |
.gradio-container {
|
|
|
1 |
import gradio as gr
|
|
|
2 |
import numpy as np
|
3 |
|
4 |
from fair import text_to_image
|
|
|
6 |
model_id = "runwayml/stable-diffusion-v1-5"
|
7 |
device = "cuda"
|
8 |
|
9 |
+
|
10 |
def infer(prompt):
|
11 |
+
image = text_to_image(prompt)
|
12 |
+
image = np.array(image.convert('RGB'))
|
13 |
+
return [image]
|
14 |
+
|
|
|
15 |
|
16 |
css = """
|
17 |
.gradio-container {
|
fair.py
CHANGED
@@ -1,16 +1,27 @@
|
|
|
|
1 |
import json
|
|
|
2 |
import os
|
3 |
import time
|
|
|
4 |
from typing import List, BinaryIO
|
5 |
-
|
|
|
|
|
|
|
6 |
logger = logging.getLogger()
|
7 |
|
8 |
import requests
|
9 |
import tempfile
|
10 |
|
11 |
SERVER_ADDRESS = "https://faircompute.com:8000/api/v1"
|
|
|
|
|
12 |
# SERVER_ADDRESS = "http://localhost:8000/api/v1"
|
13 |
-
|
|
|
|
|
|
|
14 |
|
15 |
|
16 |
class FairApiClient:
|
@@ -49,7 +60,7 @@ class FairApiClient:
|
|
49 |
|
50 |
return response
|
51 |
|
52 |
-
def put_job(self, image: str, command: List[str], input_files, output_files):
|
53 |
url = f"{self.server_address}/jobs"
|
54 |
data = {
|
55 |
'type': 'V016',
|
@@ -57,12 +68,12 @@ class FairApiClient:
|
|
57 |
'type': 'V016',
|
58 |
'image': image,
|
59 |
'runtime': 'nvidia',
|
60 |
-
'ports': [],
|
61 |
'command': command,
|
62 |
},
|
63 |
'input_files': input_files,
|
64 |
'output_files': output_files,
|
65 |
-
'target_node':
|
66 |
}
|
67 |
response = self.put(url=url, data=data)
|
68 |
|
@@ -113,7 +124,39 @@ class FairApiClient:
|
|
113 |
print(f"Failed to receive {path}")
|
114 |
|
115 |
|
116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
# default credentials will work only for local server built in debug mode
|
118 |
email = os.getenv('FAIRCOMPUTE_EMAIL', "debug-email")
|
119 |
password = os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd")
|
@@ -123,52 +166,34 @@ def text_to_image(text):
|
|
123 |
job_id = client.put_job(
|
124 |
image=DOCKER_IMAGE,
|
125 |
command=[],
|
|
|
126 |
input_files=[],
|
127 |
-
output_files=[
|
128 |
|
129 |
logger.info(job_id)
|
130 |
|
131 |
status = client.get_job_status(job_id=job_id)
|
132 |
logger.info(status)
|
133 |
|
134 |
-
while status != "Processing"
|
135 |
status = client.get_job_status(job_id=job_id)
|
136 |
logger.info(status)
|
137 |
time.sleep(0.5)
|
138 |
|
139 |
-
|
140 |
-
|
141 |
-
name="stdin",
|
142 |
-
data=text + "\n")
|
143 |
-
logger.info(res)
|
144 |
|
145 |
-
res = client.put_job_stream_eof(
|
146 |
-
job_id=job_id,
|
147 |
-
name="stdin")
|
148 |
-
logger.info(res)
|
149 |
|
150 |
-
|
151 |
-
|
|
|
|
|
|
|
152 |
|
153 |
-
|
154 |
-
|
155 |
-
job_id=job_id)
|
156 |
-
logger.info(status)
|
157 |
-
time.sleep(0.5)
|
158 |
-
if status == "Completed":
|
159 |
-
logger.info("Done!")
|
160 |
-
else:
|
161 |
-
logger.info("Job Failed")
|
162 |
-
file = client.wait_for_file(
|
163 |
-
job_id=job_id,
|
164 |
-
path="%2Fworkspace%2Fresult.png")
|
165 |
-
return file
|
166 |
|
167 |
|
168 |
-
if __name__=="__main__":
|
169 |
-
|
170 |
-
file = text_to_image(text="Robot dinozaur\n")
|
171 |
-
image = Image.open(file)
|
172 |
image.save("result.png")
|
173 |
-
|
174 |
-
|
|
|
1 |
+
import base64
|
2 |
import json
|
3 |
+
import logging
|
4 |
import os
|
5 |
import time
|
6 |
+
from io import BytesIO
|
7 |
from typing import List, BinaryIO
|
8 |
+
|
9 |
+
from PIL import Image
|
10 |
+
from octoai.client import Client as OctoAiClient
|
11 |
+
|
12 |
logger = logging.getLogger()
|
13 |
|
14 |
import requests
|
15 |
import tempfile
|
16 |
|
17 |
SERVER_ADDRESS = "https://faircompute.com:8000/api/v1"
|
18 |
+
ENDPOINT_ADDRESS = "http://dikobraz.mooo.com:5000"
|
19 |
+
TARGET_NODE = "119eccba-2388-43c1-bdb9-02133049604c"
|
20 |
# SERVER_ADDRESS = "http://localhost:8000/api/v1"
|
21 |
+
# ENDPOINT_ADDRESS = "http://localhost:5000"
|
22 |
+
# TARGET_NODE = None
|
23 |
+
|
24 |
+
DOCKER_IMAGE = "faircompute/diffusion-octo:latest"
|
25 |
|
26 |
|
27 |
class FairApiClient:
|
|
|
60 |
|
61 |
return response
|
62 |
|
63 |
+
def put_job(self, image: str, command: List[str], ports: List[tuple[int, int]], input_files, output_files):
|
64 |
url = f"{self.server_address}/jobs"
|
65 |
data = {
|
66 |
'type': 'V016',
|
|
|
68 |
'type': 'V016',
|
69 |
'image': image,
|
70 |
'runtime': 'nvidia',
|
71 |
+
'ports': [[{"port": host_port, "ip": 'null'}, {"port": container_port, "protocol": "Tcp"}] for (host_port, container_port) in ports],
|
72 |
'command': command,
|
73 |
},
|
74 |
'input_files': input_files,
|
75 |
'output_files': output_files,
|
76 |
+
'target_node': TARGET_NODE,
|
77 |
}
|
78 |
response = self.put(url=url, data=data)
|
79 |
|
|
|
124 |
print(f"Failed to receive {path}")
|
125 |
|
126 |
|
127 |
+
class EndpointClient:
|
128 |
+
def infer(self, prompt):
|
129 |
+
client = OctoAiClient()
|
130 |
+
|
131 |
+
inputs = {"prompt": {"text": prompt}}
|
132 |
+
response = client.infer(endpoint_url=f"{ENDPOINT_ADDRESS}/infer", inputs=inputs)
|
133 |
+
|
134 |
+
image_b64 = response["output"]["image_b64"]
|
135 |
+
image_data = base64.b64decode(image_b64)
|
136 |
+
image_data = BytesIO(image_data)
|
137 |
+
image = Image.open(image_data)
|
138 |
+
|
139 |
+
return image
|
140 |
+
|
141 |
+
|
142 |
+
class ServerNotReadyException(Exception):
|
143 |
+
pass
|
144 |
+
|
145 |
+
|
146 |
+
def wait_for_server(retries, timeout):
|
147 |
+
for i in range(retries):
|
148 |
+
try:
|
149 |
+
r = requests.get(ENDPOINT_ADDRESS)
|
150 |
+
r.raise_for_status()
|
151 |
+
return
|
152 |
+
except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError, requests.exceptions.Timeout):
|
153 |
+
logger.info("Server is not ready yet")
|
154 |
+
time.sleep(timeout)
|
155 |
+
else:
|
156 |
+
raise ServerNotReadyException("Failed to start the server")
|
157 |
+
|
158 |
+
|
159 |
+
def start_server():
|
160 |
# default credentials will work only for local server built in debug mode
|
161 |
email = os.getenv('FAIRCOMPUTE_EMAIL', "debug-email")
|
162 |
password = os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd")
|
|
|
166 |
job_id = client.put_job(
|
167 |
image=DOCKER_IMAGE,
|
168 |
command=[],
|
169 |
+
ports=[(5000, 8080)],
|
170 |
input_files=[],
|
171 |
+
output_files=[])
|
172 |
|
173 |
logger.info(job_id)
|
174 |
|
175 |
status = client.get_job_status(job_id=job_id)
|
176 |
logger.info(status)
|
177 |
|
178 |
+
while status != "Processing":
|
179 |
status = client.get_job_status(job_id=job_id)
|
180 |
logger.info(status)
|
181 |
time.sleep(0.5)
|
182 |
|
183 |
+
# wait until the server is ready
|
184 |
+
wait_for_server(retries=10, timeout=1.0)
|
|
|
|
|
|
|
185 |
|
|
|
|
|
|
|
|
|
186 |
|
187 |
+
def text_to_image(text):
|
188 |
+
try:
|
189 |
+
wait_for_server(retries=1, timeout=0.0)
|
190 |
+
except ServerNotReadyException:
|
191 |
+
start_server()
|
192 |
|
193 |
+
client = EndpointClient()
|
194 |
+
return client.infer(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
|
196 |
|
197 |
+
if __name__ == "__main__":
|
198 |
+
image = text_to_image(text="Robot dinozaur\n")
|
|
|
|
|
199 |
image.save("result.png")
|
|
|
|
requirements.txt
CHANGED
@@ -1 +1,2 @@
|
|
1 |
-
gradio < 4
|
|
|
|
1 |
+
gradio < 4
|
2 |
+
octoai-sdk
|