Dmitry Trifonov committed on
Commit
aeb505f
·
1 Parent(s): 5d69e43

add tunneling support

Browse files
Files changed (1) hide show
  1. text_to_image.py +94 -27
text_to_image.py CHANGED
@@ -1,6 +1,8 @@
1
  import base64
2
  import logging
3
  import os
 
 
4
  import requests
5
  import time
6
  from io import BytesIO
@@ -11,18 +13,23 @@ from fair import FairClient
11
  logger = logging.getLogger()
12
 
13
  SERVER_ADDRESS = "https://faircompute.com:8000"
14
- ENDPOINT_ADDRESS = "http://dikobraz.mooo.com:5000"
15
- TARGET_NODE = "119eccba-2388-43c1-bdb9-02133049604c"
16
  # SERVER_ADDRESS = "http://localhost:8000"
17
- # ENDPOINT_ADDRESS = "http://localhost:5000"
18
- # TARGET_NODE = "ef09913249aa40ecba7d0097f7622855"
 
 
 
19
 
20
- DOCKER_IMAGE = "faircompute/diffusers-api-dreamshaper-8"
 
21
 
22
 
23
  class EndpointClient:
24
- def __init__(self, timeout):
25
- response = requests.get(os.path.join(ENDPOINT_ADDRESS, 'healthcheck'), timeout=timeout).json()
 
26
  if response['state'] != 'healthy':
27
  raise Exception("Server is not healthy")
28
 
@@ -43,7 +50,7 @@ class EndpointClient:
43
  },
44
  }
45
 
46
- response = requests.post(ENDPOINT_ADDRESS, json=inputs).json()
47
  image_data = BytesIO(base64.b64decode(response["image_base64"]))
48
  image = Image.open(image_data)
49
 
@@ -54,10 +61,17 @@ class ServerNotReadyException(Exception):
54
  pass
55
 
56
 
57
- def wait_for_server(retries, timeout=1.0, delay=2.0):
 
 
 
 
 
 
 
58
  for i in range(retries):
59
  try:
60
- return EndpointClient(timeout=timeout)
61
  except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError, requests.exceptions.Timeout) as e:
62
  logging.exception(e)
63
  time.sleep(delay)
@@ -65,27 +79,80 @@ def wait_for_server(retries, timeout=1.0, delay=2.0):
65
  raise ServerNotReadyException("Failed to start the server")
66
 
67
 
68
- def start_server():
69
- # default credentials will work only for local server built in debug mode
70
- client = FairClient(server_address=SERVER_ADDRESS,
71
- user_email=os.getenv('FAIRCOMPUTE_EMAIL', "debug-usr"),
72
- user_password=os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd"))
73
-
74
- client.run(node=TARGET_NODE,
75
- image=DOCKER_IMAGE,
76
- runtime="nvidia",
77
- ports=[(5000, 8000)],
78
- detach=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
 
81
  def text_to_image(text):
82
- try:
83
- client = wait_for_server(retries=1)
84
- except ServerNotReadyException:
85
- start_server()
86
- client = wait_for_server(retries=10)
87
 
88
- return client.infer(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
 
91
  if __name__ == "__main__":
 
1
  import base64
2
  import logging
3
  import os
4
+ import hashlib
5
+
6
  import requests
7
  import time
8
  from io import BytesIO
 
13
  logger = logging.getLogger()
14
 
15
  SERVER_ADDRESS = "https://faircompute.com:8000"
16
+ INFERENCE_NODE = "119eccba-2388-43c1-bdb9-02133049604c"
17
+ TUNNEL_NODE = "c312e6c4788b00c73c287ab0445d3655"
18
  # SERVER_ADDRESS = "http://localhost:8000"
19
+ # INFERENCE_NODE = "ef09913249aa40ecba7d0097f7622855"
20
+ # TUNNEL_NODE = "c312e6c4788b00c73c287ab0445d3655"
21
+
22
+ INFERENCE_DOCKER_IMAGE = "faircompute/diffusers-api-dreamshaper-8"
23
+ TUNNEL_DOCKER_IMAGE = "rapiz1/rathole"
24
 
25
+ endpoint_client = None
26
+ fair_client = None
27
 
28
 
29
  class EndpointClient:
30
+ def __init__(self, server_address, timeout):
31
+ self.endpoint_address = f'http://{server_address}:5000'
32
+ response = requests.get(os.path.join(self.endpoint_address, 'healthcheck'), timeout=timeout).json()
33
  if response['state'] != 'healthy':
34
  raise Exception("Server is not healthy")
35
 
 
50
  },
51
  }
52
 
53
+ response = requests.post(self.endpoint_address, json=inputs).json()
54
  image_data = BytesIO(base64.b64decode(response["image_base64"]))
55
  image = Image.open(image_data)
56
 
 
61
  pass
62
 
63
 
64
+ def create_fair_client():
65
+ return FairClient(server_address=SERVER_ADDRESS,
66
+ user_email=os.getenv('FAIRCOMPUTE_EMAIL', "debug-usr"),
67
+ user_password=os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd"))
68
+
69
+
70
+ def create_endpoint_client(fc, retries, timeout=1.0, delay=2.0):
71
+ server_address = next(info['host_address'] for info in fc.get_nodes() if info['node_id'] == TUNNEL_NODE)
72
  for i in range(retries):
73
  try:
74
+ return EndpointClient(server_address, timeout=timeout)
75
  except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError, requests.exceptions.Timeout) as e:
76
  logging.exception(e)
77
  time.sleep(delay)
 
79
  raise ServerNotReadyException("Failed to start the server")
80
 
81
 
82
+ def start_tunnel(fc: FairClient):
83
+ # generate fixed random authentication token based off some secret
84
+ token = hashlib.sha256(os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd").encode()).hexdigest()
85
+
86
+ # start tunnel node
87
+ server_config = f"""
88
+ [server]
89
+ bind_addr = "0.0.0.0:2333" # port that rathole listens for clients
90
+
91
+ [server.services.inference_server]
92
+ token = "{token}" # token that is used to authenticate the client for the service
93
+ bind_addr = "0.0.0.0:5000" # port that exposes service to the Internet
94
+ """
95
+ with open('server.toml', 'w') as file:
96
+ file.write(server_config)
97
+ fc.run(node=TUNNEL_NODE,
98
+ image=TUNNEL_DOCKER_IMAGE,
99
+ command=["--server", "/app/config.toml"],
100
+ volumes=[("./server.toml", "/app/config.toml")],
101
+ network="host",
102
+ detach=True)
103
+
104
+ server_address = next(info['host_address'] for info in fc.get_nodes() if info['node_id'] == TUNNEL_NODE)
105
+ client_config = f"""
106
+ [client]
107
+ remote_addr = "{server_address}:2333" # address of the rathole server
108
+
109
+ [client.services.inference_server]
110
+ token = "{token}" # token that is used to authenticate the client for the service
111
+ local_addr = "127.0.0.1:5001" # address of the service that needs to be forwarded
112
+ """
113
+ with open('client.toml', 'w') as file:
114
+ file.write(client_config)
115
+ fc.run(node=INFERENCE_NODE,
116
+ image=TUNNEL_DOCKER_IMAGE,
117
+ command=["--client", "/app/config.toml"],
118
+ volumes=[("./client.toml", "/app/config.toml")],
119
+ network="host",
120
+ detach=True)
121
+
122
+
123
+ def start_inference_server(fc: FairClient):
124
+ fc.run(node=INFERENCE_NODE,
125
+ image=INFERENCE_DOCKER_IMAGE,
126
+ runtime="nvidia",
127
+ ports=[(5001, 8000)],
128
+ detach=True)
129
+
130
+
131
+ def start_services(fc):
132
+ start_tunnel(fc)
133
+ start_inference_server(fc)
134
 
135
 
136
  def text_to_image(text):
137
+ global endpoint_client
138
+ global fair_client
139
+ if fair_client is None:
140
+ fair_client = create_fair_client()
 
141
 
142
+ try:
143
+ if endpoint_client is None: # try connecting to the server
144
+ endpoint_client = create_endpoint_client(fair_client, 1)
145
+ else: # try inference
146
+ return endpoint_client.infer(text)
147
+ except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError, requests.exceptions.Timeout):
148
+ endpoint_client = None
149
+
150
+ # start all services
151
+ if endpoint_client is None:
152
+ start_services(fair_client)
153
+ endpoint_client = create_endpoint_client(fair_client, retries=10)
154
+
155
+ return endpoint_client.infer(text)
156
 
157
 
158
  if __name__ == "__main__":