"""Generate an image from a text prompt via a diffusion server on FairCompute.

If the inference endpoint at ``ENDPOINT_ADDRESS`` does not answer, the server
container (``DOCKER_IMAGE``) is first launched on ``TARGET_NODE`` through the
FairCompute API; the prompt is then sent to the endpoint's ``/infer`` route.
"""

import base64
import logging
import os
import time
from io import BytesIO

import requests
from fair import FairClient
from octoai.client import Client as OctoAiClient
from PIL import Image

logger = logging.getLogger(__name__)

SERVER_ADDRESS = "https://faircompute.com:8000"
ENDPOINT_ADDRESS = "http://dikobraz.mooo.com:5000"
TARGET_NODE = "119eccba-2388-43c1-bdb9-02133049604c"

# Local development settings:
# SERVER_ADDRESS = "http://localhost:8000"
# ENDPOINT_ADDRESS = "http://localhost:5000"
# TARGET_NODE = "ef09913249aa40ecba7d0097f7622855"

DOCKER_IMAGE = "faircompute/diffusion-octo:v1"


class EndpointClient:
    """Client for the inference endpoint served at ``ENDPOINT_ADDRESS``."""

    def infer(self, prompt):
        """Run text-to-image inference for *prompt*.

        Args:
            prompt: Text sent to the endpoint as ``inputs["prompt"]["text"]``.

        Returns:
            A ``PIL.Image.Image`` decoded from the base64 payload found at
            ``response["output"]["image_b64"]``.

        Errors raised by ``OctoAiClient.infer``, and lookups of missing
        response keys, propagate to the caller unchanged.
        """
        client = OctoAiClient()
        inputs = {"prompt": {"text": prompt}}
        response = client.infer(endpoint_url=f"{ENDPOINT_ADDRESS}/infer", inputs=inputs)
        image_bytes = base64.b64decode(response["output"]["image_b64"])
        # Image.open is lazy and reads from the buffer on demand; the returned
        # image keeps a reference to the BytesIO, so it must not be closed here.
        return Image.open(BytesIO(image_bytes))


class ServerNotReadyException(Exception):
    """Raised when ``ENDPOINT_ADDRESS`` does not respond within the allowed retries."""


def wait_for_server(retries, timeout, delay=1.0):
    """Poll ``ENDPOINT_ADDRESS`` until it answers with a non-error HTTP status.

    Args:
        retries: Maximum number of GET attempts; must be at least 1.
        timeout: Per-request timeout in seconds, passed to ``requests.get``.
        delay: Seconds to sleep after each failed attempt except the last.

    Raises:
        ValueError: If *retries* is less than 1.
        ServerNotReadyException: If every attempt fails; the last request
            error is chained as ``__cause__``.
    """
    if retries < 1:
        # With zero attempts the loop never runs and the function would
        # report the server as ready without ever contacting it.
        raise ValueError(f"retries must be at least 1, got {retries}")
    for attempt in range(1, retries + 1):
        try:
            response = requests.get(ENDPOINT_ADDRESS, timeout=timeout)
            response.raise_for_status()
        except (requests.exceptions.ConnectionError,
                requests.exceptions.HTTPError,
                requests.exceptions.Timeout) as e:
            if attempt == retries:
                raise ServerNotReadyException("Failed to start the server") from e
            logger.info("Server is not ready yet")
            time.sleep(delay)
        else:
            return


def start_server():
    """Launch ``DOCKER_IMAGE`` on ``TARGET_NODE`` and block until it responds.

    Credentials are read from the ``FAIRCOMPUTE_EMAIL`` and
    ``FAIRCOMPUTE_PASSWORD`` environment variables.

    Raises:
        ServerNotReadyException: If the endpoint is still unreachable after
            10 polling attempts.
    """
    # default credentials will work only for local server built in debug mode
    client = FairClient(server_address=SERVER_ADDRESS,
                        user_email=os.getenv("FAIRCOMPUTE_EMAIL", "debug-usr"),
                        user_password=os.getenv("FAIRCOMPUTE_PASSWORD", "debug-pwd"))
    # ENDPOINT_ADDRESS uses port 5000, so the first tuple element appears to be
    # the port exposed on the node and 8080 the container port -- confirm
    # against FairClient.run's ``ports`` semantics.
    # detach=True presumably returns without waiting for the container to
    # exit, which is why readiness is polled below.
    client.run(node=TARGET_NODE, image=DOCKER_IMAGE, ports=[(5000, 8080)], detach=True)
    # wait until the server is ready
    wait_for_server(retries=10, timeout=1.0)


def text_to_image(text):
    """Generate an image for *text*, starting the server first if needed.

    A single quick probe checks whether the endpoint is already up; if it is
    not, :func:`start_server` launches it before inference is requested.

    Returns:
        The ``PIL.Image.Image`` produced by :meth:`EndpointClient.infer`.
    """
    try:
        wait_for_server(retries=1, timeout=1.0, delay=0.0)
    except ServerNotReadyException:
        start_server()
    return EndpointClient().infer(text)


if __name__ == "__main__":
    # The module only logs at INFO level, which the default root configuration
    # (WARNING) would drop; enable it when run as a script.
    logging.basicConfig(level=logging.INFO)
    image = text_to_image(text="Robot dinosaur\n")
    image.save("result.png")