Dmitry Trifonov committed on
Commit
b6fedf9
·
1 Parent(s): 0075cb7

use just endpoint in text to image demo

Browse files
Files changed (1) hide show
  1. text_to_image.py +23 -145
text_to_image.py CHANGED
@@ -1,158 +1,36 @@
1
  import base64
2
- import logging
3
- import os
4
- import hashlib
5
-
6
- import requests
7
- import time
8
  from io import BytesIO
9
 
 
10
  from PIL import Image
11
- from fair import FairClient
12
-
13
- logger = logging.getLogger()
14
-
15
- SERVER_ADDRESS = "https://faircompute.com:8000"
16
- INFERENCE_NODE = "magnus"
17
- TUNNEL_NODE = "gcs-e2-micro"
18
- # SERVER_ADDRESS = "http://localhost:8000"
19
- # INFERENCE_NODE = "ef09913249aa40ecba7d0097f7622855"
20
- # TUNNEL_NODE = "c312e6c4788b00c73c287ab0445d3655"
21
-
22
- INFERENCE_DOCKER_IMAGE = "faircompute/diffusers-api-dreamshaper-8"
23
- TUNNEL_DOCKER_IMAGE = "rapiz1/rathole"
24
-
25
- endpoint_client = None
26
- fair_client = None
27
-
28
-
29
class EndpointClient:
    """HTTP client for an inference endpoint served on port 5000 of a host.

    Constructing the client probes the ``healthcheck`` route and fails unless
    the server reports the state ``'healthy'``.
    """

    def __init__(self, server_address, timeout):
        """Connect to ``http://{server_address}:5000`` and verify it is healthy.

        Args:
            server_address: Host name or IP address of the endpoint (no scheme,
                no port).
            timeout: Timeout, in seconds, passed to the health-check request.

        Raises:
            requests.exceptions.RequestException: On connection failure,
                timeout, or an HTTP error status (``HTTPError``).
            Exception: If the reported state is not ``'healthy'``.
        """
        self.endpoint_address = f'http://{server_address}:5000'
        # Build the URL with string formatting: os.path.join is meant for
        # filesystem paths and would insert a backslash on Windows.
        response = requests.get(f'{self.endpoint_address}/healthcheck', timeout=timeout)
        # Report an HTTP error status as HTTPError instead of failing later
        # while parsing the body.
        response.raise_for_status()
        if response.json()['state'] != 'healthy':
            raise Exception("Server is not healthy")

    def infer(self, prompt):
        """Request an image for ``prompt`` and return it as a PIL image.

        Args:
            prompt: Text prompt sent as ``modelInputs.prompt``.

        Returns:
            The image opened from the base64-decoded ``image_base64`` field of
            the JSON response.

        Raises:
            requests.exceptions.RequestException: On connection failure or an
                HTTP error status.
            KeyError: If the response JSON has no ``image_base64`` field.
        """
        inputs = {
            "modelInputs": {
                "prompt": prompt,
                "num_inference_steps": 25,
                "width": 512,
                "height": 512,
            },
            "callInputs": {
                "MODEL_ID": "lykon/dreamshaper-8",
                "PIPELINE": "AutoPipelineForText2Image",
                "SCHEDULER": "DEISMultistepScheduler",
                "PRECISION": "fp16",
                "REVISION": "fp16",
            },
        }

        response = requests.post(self.endpoint_address, json=inputs)
        # Fail with a clear HTTPError rather than a confusing JSON/KeyError.
        response.raise_for_status()
        payload = response.json()
        image_data = BytesIO(base64.b64decode(payload["image_base64"]))
        image = Image.open(image_data)

        return image
58
-
59
-
60
class ServerNotReadyException(Exception):
    """Signals that the inference server could not be reached in a ready state."""
62
-
63
-
64
def create_fair_client():
    """Build a FairClient for ``SERVER_ADDRESS`` with credentials from the environment.

    The e-mail is read from ``FAIRCOMPUTE_EMAIL`` and the password from
    ``FAIRCOMPUTE_PASSWORD``; when a variable is unset, ``"debug-usr"`` and
    ``"debug-pwd"`` are used respectively.
    """
    email = os.getenv('FAIRCOMPUTE_EMAIL', "debug-usr")
    password = os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd")
    return FairClient(
        server_address=SERVER_ADDRESS,
        user_email=email,
        user_password=password,
    )
68
-
69
-
70
def create_endpoint_client(fc, retries, timeout=1.0, delay=2.0):
    """Create an EndpointClient for the tunnel node, retrying on network errors.

    Args:
        fc: Client whose ``cluster().nodes.list()`` yields node records with
            ``'name'`` and ``'host_address'`` keys.
        retries: Maximum number of connection attempts.
        timeout: Timeout, in seconds, for each health-check attempt.
        delay: Pause, in seconds, between consecutive attempts.

    Returns:
        A connected ``EndpointClient``.

    Raises:
        LookupError: If no node named ``TUNNEL_NODE`` is listed.
        ServerNotReadyException: If every attempt failed with a connection
            error, HTTP error, or timeout.
    """
    nodes = fc.cluster().nodes.list()
    try:
        server_address = next(info['host_address'] for info in nodes if info['name'] == TUNNEL_NODE)
    except StopIteration:
        # A bare StopIteration gives no hint about what is missing.
        raise LookupError(f"Node '{TUNNEL_NODE}' was not found in the cluster") from None

    for attempt in range(retries):
        try:
            return EndpointClient(server_address, timeout=timeout)
        except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError, requests.exceptions.Timeout) as e:
            logging.exception(e)
            # Only wait when another attempt follows; sleeping after the last
            # failure would just delay the error.
            if attempt < retries - 1:
                time.sleep(delay)

    raise ServerNotReadyException("Failed to start the server")
81
-
82
-
83
def start_tunnel(fc: FairClient):
    """Start a rathole server on ``TUNNEL_NODE`` and a rathole client on ``INFERENCE_NODE``.

    Two config files, ``server.toml`` and ``client.toml``, are written to the
    current working directory and mounted into the containers as
    ``/app/config.toml``. Both containers are started detached with host
    networking.

    Args:
        fc: Client used to run the containers and to list the cluster nodes.
    """
    # generate fixed random authentication token based off some secret
    secret = os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd")
    token = hashlib.sha256(secret.encode()).hexdigest()

    # start tunnel node
    server_toml = f"""
[server]
bind_addr = "0.0.0.0:2333" # port that rathole listens for clients

[server.services.inference_server]
token = "{token}" # token that is used to authenticate the client for the service
bind_addr = "0.0.0.0:5000" # port that exposes service to the Internet
"""
    with open('server.toml', 'w') as server_file:
        server_file.write(server_toml)
    fc.run(
        node_name=TUNNEL_NODE,
        image=TUNNEL_DOCKER_IMAGE,
        command=["--server", "/app/config.toml"],
        volumes=[("./server.toml", "/app/config.toml")],
        network="host",
        detach=True,
    )

    # The client has to dial the tunnel node, so look up its host address.
    cluster_nodes = fc.cluster().nodes.list()
    tunnel_host = next(info['host_address'] for info in cluster_nodes if info['name'] == TUNNEL_NODE)
    client_toml = f"""
[client]
remote_addr = "{tunnel_host}:2333" # address of the rathole server

[client.services.inference_server]
token = "{token}" # token that is used to authenticate the client for the service
local_addr = "127.0.0.1:5001" # address of the service that needs to be forwarded
"""
    with open('client.toml', 'w') as client_file:
        client_file.write(client_toml)
    fc.run(
        node_name=INFERENCE_NODE,
        image=TUNNEL_DOCKER_IMAGE,
        command=["--client", "/app/config.toml"],
        volumes=[("./client.toml", "/app/config.toml")],
        network="host",
        detach=True,
    )
123
-
124
 
125
def start_inference_server(fc: FairClient):
    """Launch the inference container on ``INFERENCE_NODE`` in detached mode.

    Args:
        fc: Client used to run the container.
    """
    launch_options = {
        "node_name": INFERENCE_NODE,
        "image": INFERENCE_DOCKER_IMAGE,
        "runtime": "nvidia",
        # NOTE(review): presumably (host port, container port) - confirm
        # against the FairClient.run documentation.
        "ports": [(5001, 8000)],
        "detach": True,
    }
    fc.run(**launch_options)
131
 
132
 
133
def text_to_image(text):
    """Generate an image for ``text``, starting the remote services if needed.

    The module-level ``fair_client`` and ``endpoint_client`` are created on
    first use and reused afterwards. If the endpoint cannot be used or reached,
    the inference server and the tunnel are started and the connection is
    retried up to 10 times before running inference.

    Args:
        text: Prompt forwarded to ``EndpointClient.infer``.

    Returns:
        Whatever ``EndpointClient.infer`` returns for the prompt.
    """
    global endpoint_client, fair_client

    if fair_client is None:
        fair_client = create_fair_client()

    # Failures that mean "the server is not up (yet)" rather than a real error.
    not_ready_errors = (
        requests.exceptions.ConnectionError,
        requests.exceptions.HTTPError,
        requests.exceptions.Timeout,
        ServerNotReadyException,
    )

    try:
        if endpoint_client is None:
            # client is not configured, try connecting to the inference server, maybe it is running
            endpoint_client = create_endpoint_client(fair_client, 1)
        else:
            # client is configured, try to do inference right away
            return endpoint_client.infer(text)
    except not_ready_errors:
        # inference server is not ready, start inference server and open the tunnel
        start_inference_server(fair_client)
        start_tunnel(fair_client)
        endpoint_client = create_endpoint_client(fair_client, retries=10)

    # run inference
    return endpoint_client.infer(text)
154
 
155
 
156
  if __name__ == "__main__":
157
- image = text_to_image(text="Robot dinosaur\n")
158
  image.save("result.png")
 
1
  import base64
 
 
 
 
 
 
2
  from io import BytesIO
3
 
4
+ import requests
5
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
+ ENDPOINT_ADDRESS = "http://35.233.231.20:5000"
 
 
 
 
 
8
 
9
 
10
def text_to_image(prompt, *, num_inference_steps=25, width=512, height=512, timeout=300):
    """Generate an image for ``prompt`` by calling the inference endpoint.

    Args:
        prompt: Text prompt sent as ``modelInputs.prompt``.
        num_inference_steps: Value sent as ``modelInputs.num_inference_steps``.
        width: Value sent as ``modelInputs.width``.
        height: Value sent as ``modelInputs.height``.
        timeout: Timeout, in seconds, passed to ``requests.post``; ``None``
            waits indefinitely.

    Returns:
        The image opened from the base64-decoded ``image_base64`` field of
        the JSON response.

    Raises:
        requests.exceptions.RequestException: On connection failure, timeout,
            or an HTTP error status.
        KeyError: If the response JSON has no ``image_base64`` field.
    """
    inputs = {
        "modelInputs": {
            "prompt": prompt,
            "num_inference_steps": num_inference_steps,
            "width": width,
            "height": height,
        },
        "callInputs": {
            "MODEL_ID": "lykon/dreamshaper-8",
            "PIPELINE": "AutoPipelineForText2Image",
            "SCHEDULER": "DEISMultistepScheduler",
            "PRECISION": "fp16",
            "REVISION": "fp16",
        },
    }

    # Without a timeout an unresponsive endpoint would block the caller forever.
    response = requests.post(ENDPOINT_ADDRESS, json=inputs, timeout=timeout)
    # Report an HTTP error status as HTTPError instead of failing later while
    # parsing the body.
    response.raise_for_status()
    payload = response.json()
    image_data = BytesIO(base64.b64decode(payload["image_base64"]))
    image = Image.open(image_data)

    return image
 
32
 
33
 
34
  if __name__ == "__main__":
35
+ image = text_to_image(prompt="Robot dinosaur")
36
  image.save("result.png")