"""Text-to-image client for a diffusers-api endpoint hosted on FairCompute.

If the inference endpoint is not reachable, the Docker image is started on a
FairCompute node, the endpoint is polled until it reports healthy, and then a
text-to-image request is sent (model ``lykon/dreamshaper-8``).
"""

import base64
import logging
import os
import time
from io import BytesIO

import requests
from PIL import Image

from fair import FairClient

logger = logging.getLogger(__name__)

SERVER_ADDRESS = "https://faircompute.com:8000"
ENDPOINT_ADDRESS = "http://dikobraz.mooo.com:5000"
TARGET_NODE = "119eccba-2388-43c1-bdb9-02133049604c"

# Alternative settings for a local development setup:
# SERVER_ADDRESS = "http://localhost:8000"
# ENDPOINT_ADDRESS = "http://localhost:5000"
# TARGET_NODE = "ef09913249aa40ecba7d0097f7622855"

DOCKER_IMAGE = "faircompute/diffusers-api-dreamshaper-8"

# Upper bound (seconds) for a single inference request. Without a timeout,
# requests waits indefinitely, so a stalled server would hang the client.
INFERENCE_TIMEOUT = 600.0


def _endpoint_url(path=""):
    """Return the URL of *path* under ENDPOINT_ADDRESS.

    URLs always use '/' separators, so the pieces are joined as strings
    rather than with os.path.join (which would insert '\\' on Windows).
    """
    if not path:
        return ENDPOINT_ADDRESS
    return ENDPOINT_ADDRESS.rstrip("/") + "/" + path.lstrip("/")


class EndpointClient:
    """Minimal HTTP client for the diffusers-api inference endpoint."""

    def __init__(self, timeout):
        """Verify that the endpoint is reachable and reports itself healthy.

        :param timeout: seconds to wait for the health-check response.
        :raises requests.exceptions.ConnectionError: endpoint unreachable.
        :raises requests.exceptions.Timeout: no answer within *timeout*.
        :raises requests.exceptions.HTTPError: non-2xx health-check response.
        :raises Exception: the endpoint answered but its state is not
            'healthy'.
        """
        response = requests.get(_endpoint_url("healthcheck"), timeout=timeout)
        # Surface 4xx/5xx (e.g. from a proxy while the container is still
        # starting) as HTTPError so callers can retry, instead of failing
        # later while decoding a non-JSON error body.
        response.raise_for_status()
        if response.json()['state'] != 'healthy':
            raise Exception("Server is not healthy")

    def infer(self, prompt, timeout=INFERENCE_TIMEOUT):
        """Generate an image for *prompt*.

        :param prompt: text prompt sent to the model.
        :param timeout: seconds to wait for the inference response.
        :returns: the decoded image as a PIL ``Image``.
        :raises requests.exceptions.HTTPError: non-2xx inference response.
        :raises requests.exceptions.Timeout: no answer within *timeout*.
        :raises KeyError: the response lacks an ``image_base64`` field.
        """
        inputs = {
            "modelInputs": {
                "prompt": prompt,
                "num_inference_steps": 25,
                "width": 512,
                "height": 512,
            },
            "callInputs": {
                "MODEL_ID": "lykon/dreamshaper-8",
                "PIPELINE": "AutoPipelineForText2Image",
                "SCHEDULER": "DEISMultistepScheduler",
                "PRECISION": "fp16",
                "REVISION": "fp16",
            },
        }
        response = requests.post(ENDPOINT_ADDRESS, json=inputs, timeout=timeout)
        response.raise_for_status()
        image_data = BytesIO(base64.b64decode(response.json()["image_base64"]))
        return Image.open(image_data)


class ServerNotReadyException(Exception):
    """Raised when the inference endpoint could not be reached in time."""


def wait_for_server(retries, timeout=1.0, delay=2.0):
    """Poll the endpoint until a health check succeeds.

    :param retries: maximum number of health-check attempts.
    :param timeout: per-attempt health-check timeout in seconds.
    :param delay: pause in seconds between failed attempts.
    :returns: a ready ``EndpointClient``.
    :raises ServerNotReadyException: if every attempt failed; the last
        underlying error is chained as its cause.
    """
    last_error = None
    for attempt in range(1, retries + 1):
        try:
            return EndpointClient(timeout=timeout)
        except (requests.exceptions.ConnectionError,
                requests.exceptions.HTTPError,
                requests.exceptions.Timeout) as e:
            # Expected while the server is down or booting; no traceback needed.
            logger.warning("Endpoint not ready (attempt %d/%d): %s", attempt, retries, e)
            last_error = e
        # Sleeping after the final attempt would only delay the failure.
        if attempt < retries:
            time.sleep(delay)
    raise ServerNotReadyException("Failed to start the server") from last_error


def start_server():
    """Launch DOCKER_IMAGE on TARGET_NODE through the FairCompute server.

    Credentials are read from FAIRCOMPUTE_EMAIL / FAIRCOMPUTE_PASSWORD. The
    fallback values work only against a local server built in debug mode.
    """
    client = FairClient(server_address=SERVER_ADDRESS,
                        user_email=os.environ.get('FAIRCOMPUTE_EMAIL', "debug-usr"),
                        user_password=os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd"))
    # NOTE(review): port tuple presumably maps (host, container); confirm
    # against FairClient.run.
    client.run(node=TARGET_NODE,
               image=DOCKER_IMAGE,
               runtime="nvidia",
               ports=[(5000, 8000)],
               detach=True)


def text_to_image(text):
    """Generate an image for *text*, starting the endpoint first if needed.

    :param text: prompt to render.
    :returns: the generated PIL ``Image``.
    :raises ServerNotReadyException: if the endpoint does not come up after
        being started.
    """
    try:
        # A single quick probe: is the endpoint already running?
        client = wait_for_server(retries=1)
    except ServerNotReadyException:
        start_server()
        client = wait_for_server(retries=10)
    return client.infer(text)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    image = text_to_image(text="Robot dinosaur\n")
    image.save("result.png")