Dmitry Trifonov committed on
Commit f060249 · 1 Parent(s): 784c8c5

update to use faircompute python library
Files changed (4)
  1. app.py +1 -1
  2. fair.py +0 -200
  3. requirements.txt +1 -0
  4. text_to_image.py +85 -0
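
In short: the hand-rolled REST wrapper in fair.py (FairApiClient, with manual job submission and status polling) is removed, and the new text_to_image.py starts the remote container through the faircompute library's FairClient instead. A condensed sketch of the new startup path, assembled from the diffs below (the constants and credentials are the ones defined in text_to_image.py; treat this as an illustration rather than a drop-in script):

import os

from fair import FairClient  # provided by the new faircompute==0.19.0 requirement

SERVER_ADDRESS = "https://faircompute.com:8000"
TARGET_NODE = "119eccba-2388-43c1-bdb9-02133049604c"
DOCKER_IMAGE = "faircompute/diffusion-octo:v1"

# Authentication happens in the constructor instead of a separate /auth/login request.
client = FairClient(server_address=SERVER_ADDRESS,
                    user_email=os.getenv('FAIRCOMPUTE_EMAIL', "debug-usr"),
                    user_password=os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd"))

# A single detached run() replaces FairApiClient.put_job() plus the status-polling loop.
client.run(node=TARGET_NODE, image=DOCKER_IMAGE, ports=[(5000, 8080)], detach=True)

app.py only needs the one-line import change, since the public text_to_image(text) entry point keeps its signature.
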
app.py CHANGED
@@ -1,7 +1,7 @@
 import gradio as gr
 import numpy as np
 
-from fair import text_to_image
+from text_to_image import text_to_image
 
 model_id = "runwayml/stable-diffusion-v1-5"
 device = "cuda"

fair.py DELETED
@@ -1,200 +0,0 @@
-import base64
-import json
-import logging
-import os
-import time
-from io import BytesIO
-from typing import List, BinaryIO
-
-from PIL import Image
-from octoai.client import Client as OctoAiClient
-
-logger = logging.getLogger()
-
-import requests
-import tempfile
-
-SERVER_ADDRESS = "https://faircompute.com:8000/api/v0"
-ENDPOINT_ADDRESS = "http://dikobraz.mooo.com:5000"
-TARGET_NODE = "119eccba-2388-43c1-bdb9-02133049604c"
-# SERVER_ADDRESS = "http://localhost:8000/api/v0"
-# ENDPOINT_ADDRESS = "http://localhost:5000"
-# TARGET_NODE = None
-
-DOCKER_IMAGE = "faircompute/diffusion-octo:v1"
-
-
-class FairApiClient:
-    def __init__(self, server_address: str):
-        self.server_address = server_address
-        self.token = None
-
-    def authenticate(self, email: str, password: str):
-        url = f'{self.server_address}/auth/login'
-        json_obj = {"email": email, "password": password, "version": "V018"}
-        resp = requests.post(url, json=json_obj)
-        self.token = resp.json()["token"]
-
-    def get(self, url, **kwargs):
-        headers = {
-            'Authorization': f'Bearer {self.token}'
-        }
-        response = requests.get(url, headers=headers, **kwargs)
-
-        if not response.ok:
-            raise Exception(f"Error! status: {response.status_code}")
-
-        return response
-
-    def put(self, url, data):
-        headers = {
-            'Content-Type': 'application/json',
-            'Authorization': f'Bearer {self.token}'
-        }
-        if not isinstance(data, str):
-            data = json.dumps(data)
-        response = requests.put(url, headers=headers, data=data)
-
-        if not response.ok and response.status_code != 206:
-            raise Exception(f"Error! status: {response.status_code}")
-
-        return response
-
-    def put_job(self, image: str, command: List[str], ports: List[tuple[int, int]], input_files, output_files):
-        url = f"{self.server_address}/jobs"
-        data = {
-            'version': 'V018',
-            'container_desc': {
-                'version': 'V018',
-                'image': image,
-                'runtime': 'nvidia',
-                'ports': [[{"port": host_port, "ip": 'null'}, {"port": container_port, "protocol": "Tcp"}] for (host_port, container_port) in ports],
-                'command': command,
-            },
-            'input_files': input_files,
-            'output_files': output_files,
-            'target_node': TARGET_NODE,
-        }
-        response = self.put(url=url, data=data).json()
-
-        return response['id'], response['pid']
-
-    def get_job_info(self, job_id):
-        url = f"{self.server_address}/jobs/{job_id}/stat"
-        response = self.get(url=url).json()
-        return response
-
-    def get_cluster_summary(self):
-        url = f"{self.server_address}/nodes/summary"
-        response = self.get(url=url)
-
-        return response.json()
-
-    def put_job_stream_data(self, job_id, name, data):
-        url = f"{self.server_address}/jobs/{job_id}/data/streams/{name}"
-        response = self.put(url=url, data=data)
-        return response.text
-
-    def put_job_stream_eof(self, job_id, name):
-        url = f"{self.server_address}/jobs/{job_id}/data/streams/{name}/eof"
-        response = self.put(url=url, data=None)
-        return response.text
-
-    def wait_for_file(self, job_id, path, attempts=10) -> BinaryIO:
-        headers = {
-            'Authorization': f'Bearer {self.token}'
-        }
-        for i in range(attempts):
-            url = f"{self.server_address}/jobs/{job_id}/data/files/{path}"
-            print(f"Waiting for file {path}...")
-            try:
-                with requests.get(url=url, headers=headers, stream=True) as r:
-                    r.raise_for_status()
-                    f = tempfile.TemporaryFile()
-                    for chunk in r.iter_content(chunk_size=8192):
-                        f.write(chunk)
-
-                    print(f"File {path} ready")
-                    f.seek(0, 0)
-                    return f
-            except Exception as e:
-                print(e)
-                time.sleep(0.5)
-
-        print(f"Failed to receive {path}")
-
-
-class EndpointClient:
-    def infer(self, prompt):
-        client = OctoAiClient()
-
-        inputs = {"prompt": {"text": prompt}}
-        response = client.infer(endpoint_url=f"{ENDPOINT_ADDRESS}/infer", inputs=inputs)
-
-        image_b64 = response["output"]["image_b64"]
-        image_data = base64.b64decode(image_b64)
-        image_data = BytesIO(image_data)
-        image = Image.open(image_data)
-
-        return image
-
-
-class ServerNotReadyException(Exception):
-    pass
-
-
-def wait_for_server(retries, timeout, delay=1.0):
-    for i in range(retries):
-        try:
-            r = requests.get(ENDPOINT_ADDRESS, timeout=timeout)
-            r.raise_for_status()
-            return
-        except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError, requests.exceptions.Timeout) as e:
-            if i == retries - 1:
-                raise ServerNotReadyException("Failed to start the server") from e
-            else:
-                logger.info("Server is not ready yet")
-                time.sleep(delay)
-
-
-def start_server():
-    # default credentials will work only for local server built in debug mode
-    email = os.getenv('FAIRCOMPUTE_EMAIL', "debug-usr")
-    password = os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd")
-    client = FairApiClient(SERVER_ADDRESS)
-    client.authenticate(email=email, password=password)
-
-    job_id, job_pid = client.put_job(
-        image=DOCKER_IMAGE,
-        command=[],
-        ports=[(5000, 8080)],
-        input_files=[],
-        output_files=[])
-
-    logger.info(f"Job id: {job_id}, pid: {job_pid}")
-
-    info = client.get_job_info(job_id=job_id)
-    logger.info(info)
-
-    while info["status"] != "Processing":
-        info = client.get_job_info(job_id=job_id)
-        logger.info(info)
-        time.sleep(0.5)
-
-    # wait until the server is ready
-    wait_for_server(retries=10, timeout=1.0)
-
-
-def text_to_image(text):
-    # try:
-    #     wait_for_server(retries=1, timeout=1.0, delay=0.0)
-    # except ServerNotReadyException:
-    #     start_server()
-
-    client = EndpointClient()
-    return client.infer(text)
-
-
-if __name__ == "__main__":
-    image = text_to_image(text="Robot dinosaur\n")
-    image.save("result.png")

requirements.txt CHANGED
@@ -1,2 +1,3 @@
 gradio < 4
 octoai-sdk
+faircompute==0.19.0

text_to_image.py ADDED
@@ -0,0 +1,85 @@
+import base64
+import logging
+import os
+import time
+from io import BytesIO
+
+from PIL import Image
+from octoai.client import Client as OctoAiClient
+from fair import FairClient
+
+logger = logging.getLogger()
+
+import requests
+
+SERVER_ADDRESS = "https://faircompute.com:8000"
+ENDPOINT_ADDRESS = "http://dikobraz.mooo.com:5000"
+TARGET_NODE = "119eccba-2388-43c1-bdb9-02133049604c"
+# SERVER_ADDRESS = "http://localhost:8000"
+# ENDPOINT_ADDRESS = "http://localhost:5000"
+# TARGET_NODE = "ef09913249aa40ecba7d0097f7622855"
+
+DOCKER_IMAGE = "faircompute/diffusion-octo:v1"
+
+
+class EndpointClient:
+    def infer(self, prompt):
+        client = OctoAiClient()
+
+        inputs = {"prompt": {"text": prompt}}
+        response = client.infer(endpoint_url=f"{ENDPOINT_ADDRESS}/infer", inputs=inputs)
+
+        image_b64 = response["output"]["image_b64"]
+        image_data = base64.b64decode(image_b64)
+        image_data = BytesIO(image_data)
+        image = Image.open(image_data)
+
+        return image
+
+
+class ServerNotReadyException(Exception):
+    pass
+
+
+def wait_for_server(retries, timeout, delay=1.0):
+    for i in range(retries):
+        try:
+            r = requests.get(ENDPOINT_ADDRESS, timeout=timeout)
+            r.raise_for_status()
+            return
+        except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError, requests.exceptions.Timeout) as e:
+            if i == retries - 1:
+                raise ServerNotReadyException("Failed to start the server") from e
+            else:
+                logger.info("Server is not ready yet")
+                time.sleep(delay)
+
+
+def start_server():
+    # default credentials will work only for local server built in debug mode
+    client = FairClient(server_address=SERVER_ADDRESS,
+                        user_email=os.getenv('FAIRCOMPUTE_EMAIL', "debug-usr"),
+                        user_password=os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd"))
+
+    client.run(node=TARGET_NODE,
+               image=DOCKER_IMAGE,
+               ports=[(5000, 8080)],
+               detach=True)
+
+    # wait until the server is ready
+    wait_for_server(retries=10, timeout=1.0)
+
+
+def text_to_image(text):
+    try:
+        wait_for_server(retries=1, timeout=1.0, delay=0.0)
+    except ServerNotReadyException:
+        start_server()
+
+    client = EndpointClient()
+    return client.infer(text)
+
+
+if __name__ == "__main__":
+    image = text_to_image(text="Robot dinosaur\n")
+    image.save("result.png")