stable-diffusion / text_to_image.py
Dmitry Trifonov
use dreamshaper v8
bdc1919
raw
history blame
2.77 kB
import base64
import logging
import os
import requests
import time
from io import BytesIO
from PIL import Image
from fair import FairClient
# Module-level logger named after this module (a bare getLogger() call would
# return the root logger, so records could not be filtered per module).
logger = logging.getLogger(__name__)

# Remote deployment settings:
#   SERVER_ADDRESS   - passed to FairClient as its server_address.
#   ENDPOINT_ADDRESS - base URL that EndpointClient sends HTTP requests to.
#   TARGET_NODE      - node id the container is launched on.
SERVER_ADDRESS = "https://faircompute.com:8000"
ENDPOINT_ADDRESS = "http://dikobraz.mooo.com:5000"
TARGET_NODE = "119eccba-2388-43c1-bdb9-02133049604c"

# Local-development alternative (swap with the values above):
# SERVER_ADDRESS = "http://localhost:8000"
# ENDPOINT_ADDRESS = "http://localhost:5000"
# TARGET_NODE = "ef09913249aa40ecba7d0097f7622855"

# Docker image launched by start_server().
DOCKER_IMAGE = "faircompute/diffusers-api-dreamshaper-8"
class EndpointClient:
    """HTTP client for the inference endpoint at ``ENDPOINT_ADDRESS``.

    Constructing an instance performs a health check, so a successfully
    created client means the endpoint answered and reported ``healthy``.
    """

    def __init__(self, timeout):
        """Probe the endpoint's ``healthcheck`` route.

        :param timeout: seconds to wait for the health-check response
            (forwarded to ``requests.get``).
        :raises requests.exceptions.ConnectionError: endpoint unreachable.
        :raises requests.exceptions.Timeout: no answer within ``timeout``.
        :raises requests.exceptions.HTTPError: endpoint answered with an
            HTTP error status.
        :raises Exception: endpoint answered but its JSON ``state`` is not
            ``"healthy"`` (including when the field is missing).
        """
        # Build the URL with string operations: os.path.join uses the OS path
        # separator, which would produce a backslash on Windows.
        url = f"{ENDPOINT_ADDRESS.rstrip('/')}/healthcheck"
        response = requests.get(url, timeout=timeout)
        # Turn HTTP error statuses into HTTPError so retry loops that catch it
        # (wait_for_server) see a meaningful exception instead of a JSON error.
        response.raise_for_status()
        if response.json().get('state') != 'healthy':
            raise Exception("Server is not healthy")

    def infer(self, prompt, timeout=None):
        """Generate one 512x512 image for *prompt* and return it.

        :param prompt: text prompt sent as ``modelInputs.prompt``.
        :param timeout: optional seconds to wait for the HTTP response;
            ``None`` (the default) waits indefinitely, as before.
        :returns: the decoded image, opened with ``PIL.Image.open``.
        :raises requests.exceptions.HTTPError: endpoint returned an HTTP
            error status.
        :raises KeyError: the response JSON has no ``image_base64`` field.
        """
        inputs = {
            "modelInputs": {
                "prompt": prompt,
                "num_inference_steps": 25,
                "width": 512,
                "height": 512,
            },
            "callInputs": {
                "MODEL_ID": "lykon/dreamshaper-8",
                "PIPELINE": "AutoPipelineForText2Image",
                "SCHEDULER": "DEISMultistepScheduler",
                "PRECISION": "fp16",
                "REVISION": "fp16",
            },
        }
        response = requests.post(ENDPOINT_ADDRESS, json=inputs, timeout=timeout)
        response.raise_for_status()
        payload = response.json()
        if "image_base64" not in payload:
            # Include the payload so server-side errors are visible to the caller.
            raise KeyError(f"'image_base64' missing from endpoint response: {payload!r}")
        # The image object keeps a reference to this buffer, so it stays alive.
        image_data = BytesIO(base64.b64decode(payload["image_base64"]))
        return Image.open(image_data)
class ServerNotReadyException(Exception):
    """Raised by wait_for_server when the endpoint never passes its health check."""
def wait_for_server(retries, timeout=1.0, delay=2.0):
    """Poll the endpoint until it is healthy and return a connected client.

    :param retries: maximum number of health-check attempts.
    :param timeout: per-attempt health-check timeout, in seconds.
    :param delay: seconds to sleep between consecutive failed attempts
        (no sleep follows the final attempt).
    :returns: an ``EndpointClient`` whose health check succeeded.
    :raises ServerNotReadyException: every attempt failed (or ``retries``
        is not positive). The last failure is chained as ``__cause__``.
    """
    last_error = None
    for attempt in range(1, retries + 1):
        try:
            return EndpointClient(timeout=timeout)
        except (requests.exceptions.ConnectionError,
                requests.exceptions.HTTPError,
                requests.exceptions.Timeout) as e:
            last_error = e
            # Expected while the server is starting up: log a concise warning
            # instead of a full traceback on every attempt.
            logger.warning("Server not ready (attempt %d/%d): %s", attempt, retries, e)
        # Sleeping after the last attempt would only delay the failure.
        if attempt < retries:
            time.sleep(delay)
    raise ServerNotReadyException("Failed to start the server") from last_error
def start_server():
    """Launch DOCKER_IMAGE on TARGET_NODE through the FairCompute server.

    Credentials are read from FAIRCOMPUTE_EMAIL / FAIRCOMPUTE_PASSWORD. The
    fallback pair ("debug-usr" / "debug-pwd") works only against a local
    server built in debug mode. The container is started detached, on the
    "nvidia" runtime, with the port pair (5000, 8000).
    """
    email = os.getenv('FAIRCOMPUTE_EMAIL', "debug-usr")
    password = os.getenv('FAIRCOMPUTE_PASSWORD', "debug-pwd")
    fair = FairClient(
        server_address=SERVER_ADDRESS,
        user_email=email,
        user_password=password,
    )
    fair.run(
        node=TARGET_NODE,
        image=DOCKER_IMAGE,
        runtime="nvidia",
        ports=[(5000, 8000)],
        detach=True,
    )
def text_to_image(text):
    """Render *text* into an image, starting the server first if needed.

    A single health-check probe decides whether the endpoint is already up.
    If it is not, the container is launched and the endpoint is polled for
    up to 10 attempts before giving up with ServerNotReadyException.
    """
    def _connect():
        # Fast path: the endpoint is already running.
        try:
            return wait_for_server(retries=1)
        except ServerNotReadyException:
            # Nothing answered: launch the container and poll longer.
            start_server()
            return wait_for_server(retries=10)

    return _connect().infer(text)
if __name__ == "__main__":
    # Demo run: generate a single image and write it to the working directory.
    generated = text_to_image(text="Robot dinosaur\n")
    generated.save("result.png")