Spaces:
Sleeping
Sleeping
import base64 | |
import logging | |
import os | |
import time | |
from io import BytesIO | |
from PIL import Image | |
from octoai.client import Client as OctoAiClient | |
from fair import FairClient | |
logger = logging.getLogger() | |
import requests | |
# Active settings.
# SERVER_ADDRESS   -> passed to FairClient as the control-server address.
# ENDPOINT_ADDRESS -> base URL polled for readiness and used for "/infer" requests.
# TARGET_NODE      -> node id handed to FairClient.run().
SERVER_ADDRESS = "https://faircompute.com:8000"
ENDPOINT_ADDRESS = "http://dikobraz.mooo.com:5000"
TARGET_NODE = "119eccba-2388-43c1-bdb9-02133049604c"

# Alternative values for a locally running server (swap in for the three above):
# SERVER_ADDRESS = "http://localhost:8000"
# ENDPOINT_ADDRESS = "http://localhost:5000"
# TARGET_NODE = "ef09913249aa40ecba7d0097f7622855"

# Container image that start_server() asks the node to run.
DOCKER_IMAGE = "faircompute/diffusion-octo:v1"
class EndpointClient:
    """Client for the text-to-image inference endpoint at ``ENDPOINT_ADDRESS``."""

    def infer(self, prompt):
        """Generate an image for ``prompt`` and return it as a PIL image.

        Sends ``{"prompt": {"text": prompt}}`` to ``{ENDPOINT_ADDRESS}/infer``
        through ``OctoAiClient.infer`` and decodes the base64 payload found at
        ``response["output"]["image_b64"]``.

        Args:
            prompt: Text description of the image to generate.

        Returns:
            The object produced by ``PIL.Image.open`` on the decoded bytes.
        """
        payload = {"prompt": {"text": prompt}}
        response = OctoAiClient().infer(
            endpoint_url=f"{ENDPOINT_ADDRESS}/infer",
            inputs=payload,
        )
        raw_bytes = base64.b64decode(response["output"]["image_b64"])
        # Image.open needs a file-like object, so wrap the decoded bytes.
        return Image.open(BytesIO(raw_bytes))
class ServerNotReadyException(Exception):
    """Raised when the inference endpoint does not become reachable in time."""
def wait_for_server(retries, timeout, delay=1.0):
    """Poll ``ENDPOINT_ADDRESS`` until it answers with a successful HTTP status.

    Each attempt issues a GET request; the function returns as soon as one
    attempt succeeds (``raise_for_status`` does not raise).

    Args:
        retries: Maximum number of attempts. Must be at least 1.
        timeout: Timeout for each request, forwarded to ``requests.get``.
        delay: Time passed to ``time.sleep`` between failed attempts.

    Raises:
        ValueError: If ``retries`` is less than 1. Without this check the loop
            would never run and the function would return as if the server
            were ready, although it was never probed.
        ServerNotReadyException: If the final attempt fails with a connection
            error, an HTTP error status, or a timeout.
    """
    if retries < 1:
        raise ValueError(f"retries must be at least 1, got {retries!r}")

    for attempt in range(retries):
        try:
            response = requests.get(ENDPOINT_ADDRESS, timeout=timeout)
            response.raise_for_status()
            return
        except (
            requests.exceptions.ConnectionError,
            requests.exceptions.HTTPError,
            requests.exceptions.Timeout,
        ) as e:
            # Out of attempts: surface the last failure as the cause.
            if attempt == retries - 1:
                raise ServerNotReadyException("Failed to start the server") from e
            logger.info("Server is not ready yet")
            time.sleep(delay)
def start_server():
    """Launch the inference container on the target node and wait for it.

    Authenticates against ``SERVER_ADDRESS`` with credentials taken from the
    ``FAIRCOMPUTE_EMAIL`` / ``FAIRCOMPUTE_PASSWORD`` environment variables,
    starts ``DOCKER_IMAGE`` detached on ``TARGET_NODE``, then polls the
    endpoint up to 10 times.

    Raises:
        ServerNotReadyException: Propagated from ``wait_for_server`` when the
            endpoint never becomes reachable.
    """
    # default credentials will work only for local server built in debug mode
    email = os.getenv("FAIRCOMPUTE_EMAIL", "debug-usr")
    password = os.getenv("FAIRCOMPUTE_PASSWORD", "debug-pwd")

    fair_client = FairClient(
        server_address=SERVER_ADDRESS,
        user_email=email,
        user_password=password,
    )
    # NOTE(review): (5000, 8080) is presumably host port -> container port;
    # confirm against FairClient.run.
    fair_client.run(
        node=TARGET_NODE,
        image=DOCKER_IMAGE,
        ports=[(5000, 8080)],
        detach=True,
    )

    # wait until the server is ready
    wait_for_server(retries=10, timeout=1.0)
def text_to_image(text):
    """Return an image generated from ``text``, starting the server if needed.

    Probes the endpoint once with no delay; if that probe raises
    ``ServerNotReadyException`` the server is started via ``start_server``
    before the inference request is made.

    Args:
        text: Prompt forwarded to ``EndpointClient.infer``.

    Returns:
        Whatever ``EndpointClient.infer`` returns for the prompt.
    """
    try:
        # Single quick probe: one attempt, no sleep.
        wait_for_server(retries=1, timeout=1.0, delay=0.0)
    except ServerNotReadyException:
        start_server()

    return EndpointClient().infer(text)
if __name__ == "__main__":
    # Manual smoke run: generate one image and write it next to the script.
    generated = text_to_image(text="Robot dinosaur\n")
    generated.save("result.png")