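"""Text-to-image inference on Fair Compute.

Runs a DreamShaper 8 (Stable Diffusion) inference server in a Docker
container on a GPU node and exposes it to the Internet through a rathole
tunnel hosted on a small public-facing node.
"""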
import base64
import hashlib
import logging
import os
import time
from io import BytesIO

import requests
from PIL import Image

from fair import FairClient

logger = logging.getLogger()

SERVER_ADDRESS = "https://faircompute.com:8000"
INFERENCE_NODE = "magnus"
TUNNEL_NODE = "gcs-e2-micro"

# Local debugging overrides:
# SERVER_ADDRESS = "http://localhost:8000"
# INFERENCE_NODE = "ef09913249aa40ecba7d0097f7622855"
# TUNNEL_NODE = "c312e6c4788b00c73c287ab0445d3655"

INFERENCE_DOCKER_IMAGE = "faircompute/diffusers-api-dreamshaper-8"
TUNNEL_DOCKER_IMAGE = "rapiz1/rathole"

endpoint_client = None
fair_client = None

class EndpointClient:
    def __init__(self, server_address, timeout):
        self.endpoint_address = f'http://{server_address}:5000'
        # fail fast if the endpoint is not reachable or not healthy yet
        response = requests.get(f'{self.endpoint_address}/healthcheck', timeout=timeout).json()
        if response['state'] != 'healthy':
            raise Exception("Server is not healthy")

    def infer(self, prompt):
        inputs = {
            "modelInputs": {
                "prompt": prompt,
                "num_inference_steps": 25,
                "width": 512,
                "height": 512,
            },
            "callInputs": {
                "MODEL_ID": "lykon/dreamshaper-8",
                "PIPELINE": "AutoPipelineForText2Image",
                "SCHEDULER": "DEISMultistepScheduler",
                "PRECISION": "fp16",
                "REVISION": "fp16",
            },
        }
        response = requests.post(self.endpoint_address, json=inputs).json()
        image_data = BytesIO(base64.b64decode(response["image_base64"]))
        image = Image.open(image_data)
        return image
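
# Example usage (hypothetical address; assumes the tunnel is already up and
# exposing port 5000 on the tunnel node):
#
#     client = EndpointClient("203.0.113.7", timeout=5.0)
#     client.infer("An astronaut riding a horse").save("out.png")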

class ServerNotReadyException(Exception):
    pass

def create_fair_client():
    return FairClient(server_address=SERVER_ADDRESS,
                      user_email=os.environ.get('FAIRCOMPUTE_EMAIL', "debug-usr"),
                      user_password=os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd"))
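
# Credentials are read from the environment; the "debug-usr" / "debug-pwd"
# fallbacks are only meaningful against a local debug server. For real use,
# set (placeholder values shown):
#
#     FAIRCOMPUTE_EMAIL=you@example.com
#     FAIRCOMPUTE_PASSWORD=...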

def create_endpoint_client(fc, retries, timeout=1.0, delay=2.0):
    nodes = fc.cluster().nodes.list()
    server_address = next(info['host_address'] for info in nodes if info['name'] == TUNNEL_NODE)
    for _ in range(retries):
        try:
            return EndpointClient(server_address, timeout=timeout)
        except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError, requests.exceptions.Timeout) as e:
            logger.exception(e)
            time.sleep(delay)
    raise ServerNotReadyException("Server did not become ready in time")

def start_tunnel(fc: FairClient):
    # generate a fixed authentication token derived from a secret
    # (the account password), so that server and client sides agree on it
    token = hashlib.sha256(os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd").encode()).hexdigest()

    # start the rathole server on the public-facing tunnel node
    server_config = f"""
[server]
bind_addr = "0.0.0.0:2333"   # port on which rathole listens for clients

[server.services.inference_server]
token = "{token}"            # token used to authenticate the client for this service
bind_addr = "0.0.0.0:5000"   # port that exposes the service to the Internet
"""
    with open('server.toml', 'w') as file:
        file.write(server_config)
    fc.run(node_name=TUNNEL_NODE,
           image=TUNNEL_DOCKER_IMAGE,
           command=["--server", "/app/config.toml"],
           volumes=[("./server.toml", "/app/config.toml")],
           network="host",
           detach=True)

    # start the rathole client on the inference node, forwarding the local
    # inference server through the tunnel
    nodes = fc.cluster().nodes.list()
    server_address = next(info['host_address'] for info in nodes if info['name'] == TUNNEL_NODE)
    client_config = f"""
[client]
remote_addr = "{server_address}:2333"   # address of the rathole server

[client.services.inference_server]
token = "{token}"                       # token used to authenticate the client for this service
local_addr = "127.0.0.1:5001"           # address of the service to be forwarded
"""
    with open('client.toml', 'w') as file:
        file.write(client_config)
    fc.run(node_name=INFERENCE_NODE,
           image=TUNNEL_DOCKER_IMAGE,
           command=["--client", "/app/config.toml"],
           volumes=[("./client.toml", "/app/config.toml")],
           network="host",
           detach=True)
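
# Resulting data path, per the configs above (sketch):
#
#     user -> TUNNEL_NODE:5000 (rathole server, public)
#          -> rathole client on INFERENCE_NODE (tunnel over port 2333)
#          -> 127.0.0.1:5001 -> inference container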

def start_inference_server(fc: FairClient):
    fc.run(node_name=INFERENCE_NODE,
           image=INFERENCE_DOCKER_IMAGE,
           runtime="nvidia",
           ports=[(5001, 8000)],
           detach=True)
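
# ports=[(5001, 8000)] is assumed to follow a (host, container) convention,
# like `docker run -p 5001:8000`: the inference container listens on 8000,
# and host port 5001 is what client.toml forwards as local_addr.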

def text_to_image(text):
    global endpoint_client
    global fair_client
    if fair_client is None:
        fair_client = create_fair_client()
    try:
        # client is configured, try to do inference right away
        if endpoint_client is not None:
            return endpoint_client.infer(text)
        # client is not configured, try connecting to the inference server,
        # in case it is already running
        else:
            endpoint_client = create_endpoint_client(fair_client, retries=1)
    except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError, requests.exceptions.Timeout, ServerNotReadyException):
        # inference server is not ready, start it and open the tunnel
        start_inference_server(fair_client)
        start_tunnel(fair_client)
        endpoint_client = create_endpoint_client(fair_client, retries=10)
    # run inference
    return endpoint_client.infer(text)
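
# Note: the first call after a cold start can take a while, since Docker
# images may need to be pulled and the model loaded; create_endpoint_client
# retries for roughly retries * (timeout + delay) seconds to absorb that.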

if __name__ == "__main__":
    image = text_to_image(text="Robot dinosaur")
    image.save("result.png")