# FairComputeStableDifussionDemo / text_to_image.py
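"""Text-to-image demo for FairCompute.

Launches a DreamShaper 8 (Stable Diffusion) inference server on a GPU node,
exposes it to the Internet through a rathole tunnel running on a second node,
and sends text-to-image requests to the resulting endpoint.
"""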
import base64
import hashlib
import logging
import os
import time
from io import BytesIO

import requests
from PIL import Image

from fair import FairClient

logger = logging.getLogger()
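
# Deployment configuration: the FairCompute server address, the node that runs
# the model, and the publicly reachable node that hosts the tunnel.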
# SERVER_ADDRESS = "https://faircompute.com:8000"
# INFERENCE_NODE = "magnus"
# TUNNEL_NODE = "gcs-e2-micro"
SERVER_ADDRESS = "https://54.91.82.249:5000"
INFERENCE_NODE = "Nikhil-Macbook"
TUNNEL_NODE = "nikhil-tunneling-node"
INFERENCE_DOCKER_IMAGE = "faircompute/diffusers-api-dreamshaper-8"
TUNNEL_DOCKER_IMAGE = "rapiz1/rathole"

# Lazily created clients, reused across calls to text_to_image().
endpoint_client = None
fair_client = None
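

# Thin HTTP client for the text-to-image endpoint exposed through the tunnel.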
class EndpointClient:
    def __init__(self, server_address, timeout):
        self.endpoint_address = f'http://{server_address}:5000'
        response = requests.get(f'{self.endpoint_address}/healthcheck', timeout=timeout).json()
        if response['state'] != 'healthy':
            raise Exception("Server is not healthy")
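
    # The request payload appears to follow the docker-diffusers-api convention
    # used by this image: "modelInputs" is forwarded to the diffusers pipeline,
    # while "callInputs" selects the model, pipeline, scheduler and precision.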
    def infer(self, prompt):
        inputs = {
            "modelInputs": {
                "prompt": prompt,
                "num_inference_steps": 25,
                "width": 512,
                "height": 512,
            },
            "callInputs": {
                "MODEL_ID": "lykon/dreamshaper-8",
                "PIPELINE": "AutoPipelineForText2Image",
                "SCHEDULER": "DEISMultistepScheduler",
                "PRECISION": "fp16",
                "REVISION": "fp16",
            },
        }
        response = requests.post(self.endpoint_address, json=inputs).json()
        image_data = BytesIO(base64.b64decode(response["image_base64"]))
        image = Image.open(image_data)
        return image


class ServerNotReadyException(Exception):
    pass
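

# Credentials come from the environment; the "debug-usr"/"debug-pwd" fallbacks
# are presumably only meant for local testing against a development server.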
def create_fair_client():
    return FairClient(server_address=SERVER_ADDRESS,
                      user_email=os.getenv('FAIRCOMPUTE_EMAIL', "debug-usr"),
                      user_password=os.getenv('FAIRCOMPUTE_PASSWORD', "debug-pwd"))
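

# Poll the tunnel node until the inference server behind it reports healthy.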
def create_endpoint_client(fc, retries, timeout=1.0, delay=2.0):
    nodes = fc.cluster().nodes.list()
    server_address = next(info['host_address'] for info in nodes if info['name'] == TUNNEL_NODE)
    for _ in range(retries):
        try:
            return EndpointClient(server_address, timeout=timeout)
        except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError, requests.exceptions.Timeout) as e:
            logger.exception(e)
            time.sleep(delay)
    raise ServerNotReadyException("Failed to start the server")
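

# Rathole tunnel layout: the server half runs on the public TUNNEL_NODE,
# accepting tunnel clients on port 2333 and Internet traffic on port 5000;
# the client half runs next to the model on INFERENCE_NODE and forwards that
# traffic to the inference server listening on local port 5001.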
def start_tunnel(fc: FairClient):
    # derive a fixed authentication token from a secret
    token = hashlib.sha256(os.getenv('FAIRCOMPUTE_PASSWORD', "debug-pwd").encode()).hexdigest()

    # start the rathole server on the tunnel node
    server_config = f"""
[server]
bind_addr = "0.0.0.0:2333"  # port that rathole listens on for clients

[server.services.inference_server]
token = "{token}"           # token used to authenticate the client for the service
bind_addr = "0.0.0.0:5000"  # port that exposes the service to the Internet
"""
    with open('server.toml', 'w') as file:
        file.write(server_config)
    fc.run(node_name=TUNNEL_NODE,
           image=TUNNEL_DOCKER_IMAGE,
           command=["--server", "/app/config.toml"],
           volumes=[("./server.toml", "/app/config.toml")],
           network="host",
           detach=True)

    # start the rathole client on the inference node
    nodes = fc.cluster().nodes.list()
    server_address = next(info['host_address'] for info in nodes if info['name'] == TUNNEL_NODE)
    client_config = f"""
[client]
remote_addr = "{server_address}:2333"  # address of the rathole server

[client.services.inference_server]
token = "{token}"               # token used to authenticate the client for the service
local_addr = "127.0.0.1:5001"   # address of the service that needs to be forwarded
"""
    with open('client.toml', 'w') as file:
        file.write(client_config)
    fc.run(node_name=INFERENCE_NODE,
           image=TUNNEL_DOCKER_IMAGE,
           command=["--client", "/app/config.toml"],
           volumes=[("./client.toml", "/app/config.toml")],
           network="host",
           detach=True)
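

# Launch the diffusers API container on the GPU node; ports=[(5001, 8000)]
# presumably maps host port 5001 (the rathole client's local_addr) to
# container port 8000, where the inference server listens.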
def start_inference_server(fc: FairClient):
    fc.run(node_name=INFERENCE_NODE,
           image=INFERENCE_DOCKER_IMAGE,
           runtime="nvidia",
           ports=[(5001, 8000)],
           detach=True)
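

# Generate an image for the given prompt, lazily bringing up the inference
# server and tunnel on first use (or whenever the endpoint is unreachable).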
def text_to_image(text):
    global endpoint_client
    global fair_client
    if fair_client is None:
        fair_client = create_fair_client()
    try:
        # client is configured, try to run inference right away
        if endpoint_client is not None:
            return endpoint_client.infer(text)
        # client is not configured, try connecting to the inference server,
        # maybe it is already running
        else:
            endpoint_client = create_endpoint_client(fair_client, retries=1)
    except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError, requests.exceptions.Timeout, ServerNotReadyException):
        # inference server is not ready, start it and open the tunnel
        start_inference_server(fair_client)
        start_tunnel(fair_client)
        endpoint_client = create_endpoint_client(fair_client, retries=10)

    # run inference
    return endpoint_client.infer(text)


if __name__ == "__main__":
    image = text_to_image(text="Robot dinosaur")
    image.save("result.png")