Spaces:
Sleeping
Sleeping
import base64 | |
import logging | |
import os | |
import requests | |
import time | |
from io import BytesIO | |
from PIL import Image | |
from fair import FairClient | |
# Module logger. Named after the module (not the root logger) so that code
# embedding this script can configure or silence it independently.
logger = logging.getLogger(__name__)

# FairCompute server used by start_server() to launch the container.
SERVER_ADDRESS = "https://faircompute.com:8000"
# Address of the inference endpoint that EndpointClient talks to.
ENDPOINT_ADDRESS = "http://dikobraz.mooo.com:5000"
# Node identifier passed to FairClient.run() when launching the container.
TARGET_NODE = "119eccba-2388-43c1-bdb9-02133049604c"

# Local development: swap in these values to target a locally running server.
# SERVER_ADDRESS = "http://localhost:8000"
# ENDPOINT_ADDRESS = "http://localhost:5000"
# TARGET_NODE = "ef09913249aa40ecba7d0097f7622855"

# Docker image launched on TARGET_NODE (serves Dreamshaper-8, judging by the
# name and the MODEL_ID used in EndpointClient requests).
DOCKER_IMAGE = "faircompute/diffusers-api-dreamshaper-8"
class EndpointClient:
    """HTTP client for the text-to-image endpoint at ``ENDPOINT_ADDRESS``.

    Constructing an instance runs a health check, so an existing client
    refers to an endpoint that reported ``state == 'healthy'`` at that time.
    """

    def __init__(self, timeout):
        """Check that the endpoint is reachable and reports itself healthy.

        Args:
            timeout: Seconds to wait for the health-check response.

        Raises:
            requests.exceptions.RequestException: If the endpoint cannot be
                reached, times out, or answers with an HTTP error status.
            ValueError: If the health-check body is not valid JSON (requests'
                JSON decode errors derive from ValueError).
            Exception: If the endpoint answers but its ``state`` is not
                ``'healthy'``.
        """
        # Join with a literal '/': os.path.join uses the OS path separator and
        # would produce a broken URL ('...:5000\\healthcheck') on Windows.
        url = f"{ENDPOINT_ADDRESS.rstrip('/')}/healthcheck"
        response = requests.get(url, timeout=timeout)
        # Surface 4xx/5xx as HTTPError, which wait_for_server retries on.
        response.raise_for_status()
        payload = response.json()
        if not isinstance(payload, dict) or payload.get('state') != 'healthy':
            raise Exception("Server is not healthy")

    @staticmethod
    def build_inputs(prompt, num_inference_steps=25, width=512, height=512):
        """Return the JSON request body for one text-to-image call.

        ``modelInputs`` holds the per-request generation parameters;
        ``callInputs`` names the model, pipeline, scheduler and precision to use.
        """
        return {
            "modelInputs": {
                "prompt": prompt,
                "num_inference_steps": num_inference_steps,
                "width": width,
                "height": height,
            },
            "callInputs": {
                "MODEL_ID": "lykon/dreamshaper-8",
                "PIPELINE": "AutoPipelineForText2Image",
                "SCHEDULER": "DEISMultistepScheduler",
                "PRECISION": "fp16",
                "REVISION": "fp16",
            },
        }

    def infer(self, prompt, num_inference_steps=25, width=512, height=512, timeout=None):
        """Generate an image for ``prompt``.

        Args:
            prompt: Text prompt to render.
            num_inference_steps: Number of inference steps (default 25).
            width: Output width in pixels (default 512).
            height: Output height in pixels (default 512).
            timeout: Seconds to wait for the response. ``None`` (the default)
                waits indefinitely, as before.

        Returns:
            A PIL image decoded from the base64 ``image_base64`` field of the
            JSON response.

        Raises:
            requests.exceptions.HTTPError: If the endpoint returns an error status.
            KeyError: If the response JSON has no ``image_base64`` field.
        """
        inputs = self.build_inputs(prompt, num_inference_steps, width, height)
        response = requests.post(ENDPOINT_ADDRESS, json=inputs, timeout=timeout)
        # Fail with the HTTP status rather than a KeyError on an error body.
        response.raise_for_status()
        image_data = BytesIO(base64.b64decode(response.json()["image_base64"]))
        return Image.open(image_data)
class ServerNotReadyException(Exception):
    """Raised by wait_for_server when no healthy endpoint answered in time."""
def wait_for_server(retries, timeout=1.0, delay=2.0):
    """Poll the inference endpoint until it passes its health check.

    Args:
        retries: Maximum number of health-check attempts. If ``retries <= 0``,
            no attempt is made and ServerNotReadyException is raised at once.
        timeout: Per-attempt HTTP timeout in seconds.
        delay: Seconds to sleep between consecutive failed attempts.

    Returns:
        An EndpointClient for an endpoint that reported itself healthy.

    Raises:
        ServerNotReadyException: If every attempt failed with a connection,
            HTTP, timeout or malformed-JSON error. The last such error is
            chained as ``__cause__``.
    """
    last_error = None
    for attempt in range(1, retries + 1):
        try:
            return EndpointClient(timeout=timeout)
        # ValueError covers a non-JSON health-check body (e.g. an HTML error
        # page), which previously escaped the retry loop.
        # NOTE(review): EndpointClient reports a reachable-but-unhealthy server
        # with a bare Exception, which is deliberately not caught here; confirm
        # whether that state should also be retried.
        except (requests.exceptions.ConnectionError,
                requests.exceptions.HTTPError,
                requests.exceptions.Timeout,
                ValueError) as e:
            last_error = e
            # "Not up yet" is expected (text_to_image probes first), so log a
            # one-line warning instead of a full traceback per attempt.
            logger.warning("Endpoint not ready (attempt %d/%d): %s", attempt, retries, e)
            # Sleeping after the final attempt would only delay the failure.
            if attempt < retries:
                time.sleep(delay)
    raise ServerNotReadyException("Failed to start the server") from last_error
def start_server():
    """Launch DOCKER_IMAGE on TARGET_NODE through the FairCompute server.

    Credentials are read from FAIRCOMPUTE_EMAIL / FAIRCOMPUTE_PASSWORD.
    The fallback values ("debug-usr" / "debug-pwd") only work against a
    local server built in debug mode.
    """
    email = os.getenv('FAIRCOMPUTE_EMAIL', "debug-usr")
    password = os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd")
    fair = FairClient(
        server_address=SERVER_ADDRESS,
        user_email=email,
        user_password=password,
    )
    # NOTE(review): the endpoint is addressed on port 5000, so the tuple is
    # presumably (host_port, container_port); confirm against FairClient.run.
    port_mapping = [(5000, 8000)]
    fair.run(
        node=TARGET_NODE,
        image=DOCKER_IMAGE,
        runtime="nvidia",
        ports=port_mapping,
        detach=True,
    )
def text_to_image(text):
    """Render ``text`` into an image, launching the server first if needed.

    One quick health probe decides whether the endpoint is already up. If it
    is not, the container is started and the endpoint is polled up to 10
    times before inference runs.

    Returns:
        Whatever EndpointClient.infer returns for ``text``.
    """
    quick_probe = 1
    startup_attempts = 10
    try:
        endpoint = wait_for_server(retries=quick_probe)
    except ServerNotReadyException:
        # Nothing answered: start the container and give it time to come up.
        start_server()
        endpoint = wait_for_server(retries=startup_attempts)
    image = endpoint.infer(text)
    return image
if __name__ == "__main__":
    # Example run: generate one image and save it as result.png in the
    # current working directory.
    output_path = "result.png"
    generated = text_to_image(text="Robot dinosaur\n")
    generated.save(output_path)