import base64
import hashlib
import logging
import os
import time
from io import BytesIO

import requests
from PIL import Image

from fair import FairClient
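# Text-to-image inference on a FairCompute cluster: a diffusers-based
# DreamShaper 8 server runs on a GPU node (INFERENCE_NODE) and is exposed
# to the Internet through a rathole tunnel hosted on a publicly reachable
# node (TUNNEL_NODE).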
logger = logging.getLogger(__name__)
# SERVER_ADDRESS = "https://faircompute.com:8000"
# INFERENCE_NODE = "magnus"
# TUNNEL_NODE = "gcs-e2-micro"
SERVER_ADDRESS = "https://54.91.82.249:5000"
INFERENCE_NODE = "Nikhil-Macbook"
TUNNEL_NODE = "nikhil-tunneling-node"
INFERENCE_DOCKER_IMAGE = "faircompute/diffusers-api-dreamshaper-8"
TUNNEL_DOCKER_IMAGE = "rapiz1/rathole"
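# Module-level clients, created lazily on first use and reused across
# text_to_image() calls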
endpoint_client = None
fair_client = None
class EndpointClient:
    def __init__(self, server_address, timeout):
        self.endpoint_address = f'http://{server_address}:5000'
        response = requests.get(f'{self.endpoint_address}/healthcheck', timeout=timeout).json()
        if response['state'] != 'healthy':
            raise Exception("Server is not healthy")

    def infer(self, prompt):
        inputs = {
            "modelInputs": {
                "prompt": prompt,
                "num_inference_steps": 25,
                "width": 512,
                "height": 512,
            },
            "callInputs": {
                "MODEL_ID": "lykon/dreamshaper-8",
                "PIPELINE": "AutoPipelineForText2Image",
                "SCHEDULER": "DEISMultistepScheduler",
                "PRECISION": "fp16",
                "REVISION": "fp16",
            },
        }
        response = requests.post(self.endpoint_address, json=inputs).json()
        image_data = BytesIO(base64.b64decode(response["image_base64"]))
        image = Image.open(image_data)
        return image
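# A minimal usage sketch (assumes a tunnel already exposes the inference
# server on port 5000 of the tunnel node; the host below is a placeholder):
#   client = EndpointClient("203.0.113.10", timeout=1.0)
#   client.infer("An astronaut riding a horse").save("out.png")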
class ServerNotReadyException(Exception):
    pass

def create_fair_client():
    return FairClient(server_address=SERVER_ADDRESS,
                      user_email=os.getenv('FAIRCOMPUTE_EMAIL', "debug-usr"),
                      user_password=os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd"))

def create_endpoint_client(fc, retries, timeout=1.0, delay=2.0):
    nodes = fc.cluster().nodes.list()
    server_address = next(info['host_address'] for info in nodes if info['name'] == TUNNEL_NODE)
    for _ in range(retries):
        try:
            return EndpointClient(server_address, timeout=timeout)
        except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError, requests.exceptions.Timeout) as e:
            logger.exception(e)
            time.sleep(delay)
    raise ServerNotReadyException("Failed to start the server")
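# start_tunnel wires the two nodes together: a rathole server container on
# the public TUNNEL_NODE and a rathole client container on the INFERENCE_NODE,
# so the locally bound inference server becomes reachable from the Internet.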
def start_tunnel(fc: FairClient):
    # generate a fixed authentication token derived from a secret
    token = hashlib.sha256(os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd").encode()).hexdigest()

    # start the rathole server on the tunnel node
    server_config = f"""
[server]
bind_addr = "0.0.0.0:2333"   # port that rathole listens for clients on

[server.services.inference_server]
token = "{token}"            # token that is used to authenticate the client for the service
bind_addr = "0.0.0.0:5000"   # port that exposes the service to the Internet
"""
    with open('server.toml', 'w') as file:
        file.write(server_config)
    fc.run(node_name=TUNNEL_NODE,
           image=TUNNEL_DOCKER_IMAGE,
           command=["--server", "/app/config.toml"],
           volumes=[("./server.toml", "/app/config.toml")],
           network="host",
           detach=True)

    # start the rathole client on the inference node, pointing at the tunnel node
    nodes = fc.cluster().nodes.list()
    server_address = next(info['host_address'] for info in nodes if info['name'] == TUNNEL_NODE)
    client_config = f"""
[client]
remote_addr = "{server_address}:2333"   # address of the rathole server

[client.services.inference_server]
token = "{token}"                       # token that is used to authenticate the client for the service
local_addr = "127.0.0.1:5001"           # address of the service that needs to be forwarded
"""
    with open('client.toml', 'w') as file:
        file.write(client_config)
    fc.run(node_name=INFERENCE_NODE,
           image=TUNNEL_DOCKER_IMAGE,
           command=["--client", "/app/config.toml"],
           volumes=[("./client.toml", "/app/config.toml")],
           network="host",
           detach=True)
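# Resulting traffic flow: requests hit <tunnel node>:5000, rathole forwards
# them over the 2333 control channel to the inference node, which delivers
# them to the local inference server on 127.0.0.1:5001.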
def start_inference_server(fc: FairClient):
    fc.run(node_name=INFERENCE_NODE,
           image=INFERENCE_DOCKER_IMAGE,
           runtime="nvidia",
           ports=[(5001, 8000)],
           detach=True)
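# Publishes the server's container port 8000 on host port 5001 (assuming
# fc.run takes (host, container) port pairs), which is the local_addr the
# rathole client forwards incoming traffic to.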
def text_to_image(text):
    global endpoint_client
    global fair_client
    if fair_client is None:
        fair_client = create_fair_client()
    try:
        # client is configured, try to do inference right away
        if endpoint_client is not None:
            return endpoint_client.infer(text)
        # client is not configured, try connecting to the inference server, maybe it is running
        else:
            endpoint_client = create_endpoint_client(fair_client, retries=1)
    except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError, requests.exceptions.Timeout, ServerNotReadyException):
        # inference server is not ready, start the inference server and open the tunnel
        start_inference_server(fair_client)
        start_tunnel(fair_client)
        endpoint_client = create_endpoint_client(fair_client, retries=10)

    # run inference
    return endpoint_client.infer(text)

if __name__ == "__main__":
    image = text_to_image(text="Robot dinosaur\n")
    image.save("result.png")