stable-diffusion / text_to_image.py
Dmitry Trifonov
update to use faircompute python library
f060249
raw
history blame
2.46 kB
import base64
import logging
import os
import time
from io import BytesIO
from PIL import Image
from octoai.client import Client as OctoAiClient
from fair import FairClient
# Module-level logger named after this module rather than the root logger.
# Records still propagate to the root logger, so any root handlers/levels
# configured by the caller continue to apply.
logger = logging.getLogger(__name__)
import requests
# Deployment configuration.
# SERVER_ADDRESS, ENDPOINT_ADDRESS and TARGET_NODE can be overridden through
# environment variables, so the script can target another deployment without
# editing the source. For a local debug server, for example:
#   FAIRCOMPUTE_SERVER_ADDRESS=http://localhost:8000
#   FAIRCOMPUTE_ENDPOINT_ADDRESS=http://localhost:5000
#   FAIRCOMPUTE_TARGET_NODE=ef09913249aa40ecba7d0097f7622855

# FairCompute server that start_server's FairClient connects to.
SERVER_ADDRESS = os.getenv("FAIRCOMPUTE_SERVER_ADDRESS", "https://faircompute.com:8000")
# Base URL of the inference endpoint; probed by wait_for_server and used by
# EndpointClient (which appends "/infer").
ENDPOINT_ADDRESS = os.getenv("FAIRCOMPUTE_ENDPOINT_ADDRESS", "http://dikobraz.mooo.com:5000")
# Node identifier passed to FairClient.run when launching the container.
TARGET_NODE = os.getenv("FAIRCOMPUTE_TARGET_NODE", "119eccba-2388-43c1-bdb9-02133049604c")
# Container image launched by start_server.
DOCKER_IMAGE = "faircompute/diffusion-octo:v1"
class EndpointClient:
    """Client for the image-generation endpoint served at ENDPOINT_ADDRESS."""

    def infer(self, prompt):
        """Ask the endpoint's ``/infer`` route to render ``prompt`` as an image.

        A fresh OctoAI SDK client is created for each call and sent the
        payload ``{"prompt": {"text": prompt}}``. The base64 string found at
        ``response["output"]["image_b64"]`` is decoded and opened with PIL.

        Args:
            prompt: Text prompt forwarded unchanged to the endpoint.

        Returns:
            The PIL image opened from an in-memory buffer of the decoded bytes.

        Raises:
            KeyError: If the response lacks ``output`` / ``image_b64``.
            Anything raised by the OctoAI client, base64 decoding or
            ``Image.open`` is propagated unchanged.
        """
        payload = {"prompt": {"text": prompt}}
        octo = OctoAiClient()
        reply = octo.infer(endpoint_url=f"{ENDPOINT_ADDRESS}/infer", inputs=payload)
        decoded = base64.b64decode(reply["output"]["image_b64"])
        return Image.open(BytesIO(decoded))
class ServerNotReadyException(Exception):
    """Raised by wait_for_server when every probe of the endpoint has failed."""
def wait_for_server(retries, timeout, delay=1.0):
    """Poll ENDPOINT_ADDRESS until it answers with a non-error HTTP status.

    Args:
        retries: Total number of GET attempts to make; must be at least 1.
        timeout: Per-request timeout in seconds, passed to ``requests.get``.
        delay: Seconds to sleep between failed attempts (no sleep after the
            last attempt).

    Returns:
        None as soon as one request succeeds.

    Raises:
        ValueError: If ``retries`` is less than 1. No request would be made,
            so readiness could not be confirmed.
        ServerNotReadyException: If every attempt fails with a connection
            error, a timeout, or an HTTP error status. The last failure is
            chained as ``__cause__``.
    """
    if retries < 1:
        # With zero attempts the loop would fall through and the caller would
        # wrongly conclude that the server is ready.
        raise ValueError(f"retries must be at least 1, got {retries!r}")
    for attempt in range(1, retries + 1):
        try:
            response = requests.get(ENDPOINT_ADDRESS, timeout=timeout)
            response.raise_for_status()
        except (requests.exceptions.ConnectionError,
                requests.exceptions.HTTPError,
                requests.exceptions.Timeout) as e:
            if attempt == retries:
                raise ServerNotReadyException("Failed to start the server") from e
            logger.info("Server is not ready yet")
            time.sleep(delay)
        else:
            return
def start_server():
    """Launch DOCKER_IMAGE on TARGET_NODE via FairCompute and wait for it.

    Credentials are read from FAIRCOMPUTE_EMAIL and FAIRCOMPUTE_PASSWORD. The
    fallback values "debug-usr" / "debug-pwd" are only accepted by a local
    server built in debug mode.

    The container is started detached. The endpoint is then polled with
    wait_for_server (10 attempts, 1 s request timeout each).

    Raises:
        ServerNotReadyException: Propagated from wait_for_server if the
            endpoint never answers successfully.
    """
    email = os.getenv('FAIRCOMPUTE_EMAIL', "debug-usr")
    password = os.getenv('FAIRCOMPUTE_PASSWORD', "debug-pwd")
    compute = FairClient(server_address=SERVER_ADDRESS,
                         user_email=email,
                         user_password=password)

    # NOTE(review): (5000, 8080) presumably maps host port 5000 (the port in
    # ENDPOINT_ADDRESS) to container port 8080 — confirm against FairClient.run.
    compute.run(node=TARGET_NODE, image=DOCKER_IMAGE, ports=[(5000, 8080)], detach=True)

    # The call above does not block, so poll until the endpoint responds.
    wait_for_server(retries=10, timeout=1.0)
def text_to_image(text):
    """Generate an image for ``text``, starting the server first if needed.

    One quick probe (single attempt, 1 s timeout, no delay) checks whether
    the endpoint is already up. If it is not, start_server launches it and
    waits for readiness.

    Args:
        text: Prompt passed to EndpointClient.infer.

    Returns:
        The PIL image produced by EndpointClient.infer.
    """
    try:
        wait_for_server(retries=1, timeout=1.0, delay=0.0)
    except ServerNotReadyException:
        # Nothing answered the probe; bring the container up.
        start_server()
    return EndpointClient().infer(text)
if __name__ == "__main__":
    # Demo: render one prompt and save the result in the working directory.
    rendered = text_to_image(text="Robot dinosaur\n")
    rendered.save("result.png")