|
import base64 |
|
import logging |
|
import os |
|
import hashlib |
|
|
|
import requests |
|
import time |
|
from io import BytesIO |
|
|
|
from PIL import Image |
|
from fair import FairClient |
|
|
|
logger = logging.getLogger() |
|
|
|
|
|
|
|
|
|
# Fair Compute control-plane server (note: HTTPS on port 5000).
SERVER_ADDRESS = "https://54.91.82.249:5000"

# Cluster node that runs the GPU inference container.
INFERENCE_NODE = "Nikhil-Macbook"

# Publicly reachable node that hosts the rathole tunnel server.
TUNNEL_NODE = "nikhil-tunneling-node"



# Docker image serving the DreamShaper 8 diffusers text-to-image API.
INFERENCE_DOCKER_IMAGE = "faircompute/diffusers-api-dreamshaper-8"

# rathole reverse-tunnel image, used on both the server and client side.
TUNNEL_DOCKER_IMAGE = "rapiz1/rathole"



# Lazily-initialized module-level singletons; populated by text_to_image().
endpoint_client = None

fair_client = None
|
|
|
class EndpointClient:
    """Thin HTTP client for the text-to-image inference endpoint.

    The endpoint is reached at ``http://<server_address>:5000`` (the port
    exposed by the rathole tunnel server) and is expected to provide a
    ``/healthcheck`` route plus a JSON inference API at the root.
    """

    def __init__(self, server_address, timeout):
        """Probe the endpoint's healthcheck and fail fast if it is down.

        :param server_address: host name or IP of the tunnel node.
        :param timeout: healthcheck request timeout in seconds.
        :raises Exception: if the healthcheck reports a non-healthy state.
        :raises requests.exceptions.RequestException: on connection errors
            or timeout (callers retry on these).
        """
        self.endpoint_address = f'http://{server_address}:5000'
        # Build the URL with string formatting: the original used
        # os.path.join, which is for filesystem paths and would join with
        # a backslash on Windows, producing an invalid URL.
        response = requests.get(f'{self.endpoint_address}/healthcheck', timeout=timeout).json()
        if response['state'] != 'healthy':
            raise Exception("Server is not healthy")

    def infer(self, prompt):
        """Run text-to-image inference for ``prompt``; return a PIL image.

        :param prompt: text prompt for the diffusion model.
        :return: the decoded PIL.Image.Image (512x512, 25 inference steps).
        """
        inputs = {
            "modelInputs": {
                "prompt": prompt,
                "num_inference_steps": 25,
                "width": 512,
                "height": 512,
            },
            "callInputs": {
                "MODEL_ID": "lykon/dreamshaper-8",
                "PIPELINE": "AutoPipelineForText2Image",
                "SCHEDULER": "DEISMultistepScheduler",
                "PRECISION": "fp16",
                "REVISION": "fp16",
            },
        }

        # Deliberately no timeout here: image generation can legitimately
        # take a long time on the GPU node.
        response = requests.post(self.endpoint_address, json=inputs).json()
        # The server returns the PNG payload base64-encoded in JSON.
        image_data = BytesIO(base64.b64decode(response["image_base64"]))
        image = Image.open(image_data)

        return image
|
|
|
|
|
class ServerNotReadyException(Exception):
    """Raised when the inference endpoint cannot be reached after retrying."""
|
|
|
|
|
def create_fair_client():
    """Build a FairClient authenticated against SERVER_ADDRESS.

    Credentials come from the FAIRCOMPUTE_EMAIL / FAIRCOMPUTE_PASSWORD
    environment variables, falling back to debug placeholders.
    """
    # Use os.environ.get for both lookups; the original mixed os.getenv and
    # os.environ.get, which are equivalent but inconsistent.
    return FairClient(server_address=SERVER_ADDRESS,
                      user_email=os.environ.get('FAIRCOMPUTE_EMAIL', "debug-usr"),
                      user_password=os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd"))
|
|
|
|
|
def create_endpoint_client(fc, retries, timeout=1.0, delay=2.0):
    """Resolve the tunnel node's public address and connect to the endpoint.

    :param fc: FairClient used to list cluster nodes.
    :param retries: number of connection attempts before giving up.
    :param timeout: per-attempt healthcheck timeout in seconds.
    :param delay: sleep between attempts in seconds.
    :return: a healthy EndpointClient.
    :raises ServerNotReadyException: when every attempt fails.
    """
    nodes = fc.cluster().nodes.list()
    # NOTE(review): raises StopIteration if TUNNEL_NODE is absent from the
    # cluster — assumed to be a deployment error worth crashing on.
    server_address = next(info['host_address'] for info in nodes if info['name'] == TUNNEL_NODE)
    for _ in range(retries):
        try:
            return EndpointClient(server_address, timeout=timeout)
        except (requests.exceptions.ConnectionError,
                requests.exceptions.HTTPError,
                requests.exceptions.Timeout):
            # Use the module-level logger; the original called the root
            # logging module directly, bypassing the configured `logger`.
            logger.exception("Endpoint not ready, retrying in %.1f s", delay)
            time.sleep(delay)

    raise ServerNotReadyException("Failed to start the server")
|
|
|
|
|
def start_tunnel(fc: FairClient):
    """Start a rathole tunnel linking the tunnel node and the inference node.

    Writes server.toml / client.toml locally, then launches the rathole
    image on each node (detached, host networking). The shared token is
    derived from the account password so both sides agree without extra
    coordination.
    """
    auth_token = hashlib.sha256(os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd").encode()).hexdigest()

    # --- tunnel server on the public node -----------------------------------
    server_toml = f"""
[server]
bind_addr = "0.0.0.0:2333" # port that rathole listens for clients

[server.services.inference_server]
token = "{auth_token}" # token that is used to authenticate the client for the service
bind_addr = "0.0.0.0:5000" # port that exposes service to the Internet
"""
    with open('server.toml', 'w') as fh:
        fh.write(server_toml)
    fc.run(node_name=TUNNEL_NODE,
           image=TUNNEL_DOCKER_IMAGE,
           command=["--server", "/app/config.toml"],
           volumes=[("./server.toml", "/app/config.toml")],
           network="host",
           detach=True)

    # --- tunnel client on the inference node --------------------------------
    tunnel_host = next(info['host_address']
                       for info in fc.cluster().nodes.list()
                       if info['name'] == TUNNEL_NODE)
    client_toml = f"""
[client]
remote_addr = "{tunnel_host}:2333" # address of the rathole server

[client.services.inference_server]
token = "{auth_token}" # token that is used to authenticate the client for the service
local_addr = "127.0.0.1:5001" # address of the service that needs to be forwarded
"""
    with open('client.toml', 'w') as fh:
        fh.write(client_toml)
    fc.run(node_name=INFERENCE_NODE,
           image=TUNNEL_DOCKER_IMAGE,
           command=["--client", "/app/config.toml"],
           volumes=[("./client.toml", "/app/config.toml")],
           network="host",
           detach=True)
|
|
|
|
|
def start_inference_server(fc: FairClient):
    """Launch the diffusers inference container on the GPU node (detached)."""
    # ports maps (5001, 8000) — presumably host port 5001 to container port
    # 8000, matching the rathole client's local_addr 127.0.0.1:5001; confirm
    # against FairClient docs.
    launch_kwargs = dict(node_name=INFERENCE_NODE,
                         image=INFERENCE_DOCKER_IMAGE,
                         runtime="nvidia",
                         ports=[(5001, 8000)],
                         detach=True)
    fc.run(**launch_kwargs)
|
|
|
|
|
def text_to_image(text):
    """Generate an image for ``text``, lazily provisioning the backend.

    Module-level singletons cache the FairClient and EndpointClient. On the
    first call (or when the cached endpoint errors out), the inference
    server and tunnel are (re)started and the connection is retried.
    """
    global endpoint_client
    global fair_client

    if fair_client is None:
        fair_client = create_fair_client()

    recoverable = (requests.exceptions.ConnectionError,
                   requests.exceptions.HTTPError,
                   requests.exceptions.Timeout,
                   ServerNotReadyException)
    try:
        if endpoint_client is None:
            # First use: one quick attempt against an already-running stack.
            endpoint_client = create_endpoint_client(fair_client, retries=1)
        else:
            return endpoint_client.infer(text)
    except recoverable:
        # Cold start or stale endpoint: bring the stack up, then reconnect
        # patiently while the containers boot.
        start_inference_server(fair_client)
        start_tunnel(fair_client)
        endpoint_client = create_endpoint_client(fair_client, retries=10)

    return endpoint_client.infer(text)
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: generate one image and write it next to the script.
    result = text_to_image(text="Robot dinosaur\n")
    result.save("result.png")
|
|