Spaces:
Sleeping
Sleeping
File size: 2,769 Bytes
f060249 6c7a4d8 f060249 bdc1919 f060249 6c7a4d8 f060249 6c7a4d8 bdc1919 6c7a4d8 bdc1919 6c7a4d8 bdc1919 6c7a4d8 f060249 6c7a4d8 f060249 6c7a4d8 f060249 6c7a4d8 f060249 386c5dd 6c7a4d8 f060249 6c7a4d8 f060249 6c7a4d8 f060249 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import base64
import logging
import os
import requests
import time
from io import BytesIO
from PIL import Image
from fair import FairClient
# Module-level logger. Named after the module (not the root logger) so that
# applications importing this file can configure or silence it independently.
logger = logging.getLogger(__name__)

# Remote FairCompute deployment used by default.
SERVER_ADDRESS = "https://faircompute.com:8000"
ENDPOINT_ADDRESS = "http://dikobraz.mooo.com:5000"
TARGET_NODE = "119eccba-2388-43c1-bdb9-02133049604c"

# Local development setup (FairCompute server built in debug mode);
# swap these in to run against localhost.
# SERVER_ADDRESS = "http://localhost:8000"
# ENDPOINT_ADDRESS = "http://localhost:5000"
# TARGET_NODE = "ef09913249aa40ecba7d0097f7622855"

# Docker image launched on TARGET_NODE; presumably it serves the inference
# endpoint reached at ENDPOINT_ADDRESS — confirm against the image docs.
DOCKER_IMAGE = "faircompute/diffusers-api-dreamshaper-8"
class EndpointClient:
    """Thin HTTP client for the inference endpoint at ENDPOINT_ADDRESS.

    Constructing an instance performs a health check, so a successfully
    created client means the endpoint answered and reported itself healthy.
    """

    def __init__(self, timeout):
        """Probe the endpoint's ``healthcheck`` route.

        Args:
            timeout: seconds to wait for the health-check response
                (passed straight to ``requests.get``).

        Raises:
            requests.exceptions.ConnectionError / Timeout: endpoint unreachable.
            requests.exceptions.HTTPError: endpoint answered with a 4xx/5xx status.
            Exception: endpoint answered but did not report ``state == 'healthy'``.
        """
        # Build the URL with string operations: os.path.join is a filesystem
        # helper and would insert "\\" on Windows.
        url = f"{ENDPOINT_ADDRESS.rstrip('/')}/healthcheck"
        response = requests.get(url, timeout=timeout)
        # Surface error statuses as HTTPError (which callers retry on) rather
        # than failing later while decoding a non-JSON error page.
        response.raise_for_status()
        if response.json().get('state') != 'healthy':
            raise Exception("Server is not healthy")

    def infer(self, prompt, timeout=None):
        """Render *prompt* to a 512x512 image with the ``lykon/dreamshaper-8`` model.

        Args:
            prompt: text prompt for the model.
            timeout: optional seconds to wait for the inference response;
                the default ``None`` waits indefinitely (previous behavior).

        Returns:
            A PIL image decoded from the base64 ``image_base64`` field of the
            JSON response.

        Raises:
            requests.exceptions.HTTPError: the endpoint returned a 4xx/5xx status.
            KeyError: the response JSON has no ``image_base64`` field.
        """
        inputs = {
            "modelInputs": {
                "prompt": prompt,
                "num_inference_steps": 25,
                "width": 512,
                "height": 512,
            },
            "callInputs": {
                "MODEL_ID": "lykon/dreamshaper-8",
                "PIPELINE": "AutoPipelineForText2Image",
                "SCHEDULER": "DEISMultistepScheduler",
                "PRECISION": "fp16",
                "REVISION": "fp16",
            },
        }
        response = requests.post(ENDPOINT_ADDRESS, json=inputs, timeout=timeout)
        response.raise_for_status()
        payload = response.json()
        image_data = BytesIO(base64.b64decode(payload["image_base64"]))
        return Image.open(image_data)
class ServerNotReadyException(Exception):
    """Raised when the inference endpoint never passes its health check."""
def wait_for_server(retries, timeout=1.0, delay=2.0):
    """Poll the endpoint until it passes its health check.

    Args:
        retries: maximum number of health-check attempts.
        timeout: per-attempt request timeout in seconds.
        delay: pause in seconds between consecutive attempts.

    Returns:
        An EndpointClient for an endpoint that reported itself healthy.

    Raises:
        ServerNotReadyException: every attempt failed; the last request
            error (if any) is chained as ``__cause__``.
    """
    last_error = None
    for attempt in range(retries):
        if attempt:
            # Pause only *between* attempts; sleeping after the final failure
            # would just delay the ServerNotReadyException.
            time.sleep(delay)
        try:
            return EndpointClient(timeout=timeout)
        except requests.exceptions.RequestException as e:
            # RequestException is the base of ConnectionError, HTTPError and
            # Timeout, and also covers other transport/decoding failures from
            # a half-started server.
            # NOTE(review): a reachable endpoint whose state is not 'healthy'
            # raises a plain Exception in EndpointClient and is not retried.
            logger.warning("Health check %d/%d failed: %s", attempt + 1, retries, e)
            last_error = e
    raise ServerNotReadyException("Failed to start the server") from last_error
def start_server():
    """Ask FairCompute to launch DOCKER_IMAGE on TARGET_NODE.

    Credentials are read from the FAIRCOMPUTE_EMAIL and FAIRCOMPUTE_PASSWORD
    environment variables. The fallback values only work against a local
    server built in debug mode.
    """
    email = os.getenv('FAIRCOMPUTE_EMAIL', "debug-usr")
    password = os.getenv('FAIRCOMPUTE_PASSWORD', "debug-pwd")
    fair_client = FairClient(
        server_address=SERVER_ADDRESS,
        user_email=email,
        user_password=password,
    )
    # NOTE(review): host port 5000 (the port ENDPOINT_ADDRESS uses) appears to
    # be mapped to container port 8000, and detach=True presumably returns
    # without waiting on the container — confirm against FairClient.run.
    launch_options = dict(
        node=TARGET_NODE,
        image=DOCKER_IMAGE,
        runtime="nvidia",
        ports=[(5000, 8000)],
        detach=True,
    )
    fair_client.run(**launch_options)
def text_to_image(text):
    """Generate an image for *text*, starting the remote server if needed.

    The endpoint is probed once. If that probe fails, the container is
    launched with start_server() and the endpoint is polled up to 10 times
    before the request is sent.

    Raises:
        ServerNotReadyException: the endpoint never became ready.
    """
    endpoint = None
    try:
        endpoint = wait_for_server(retries=1)
    except ServerNotReadyException:
        pass  # not running yet; launched below

    if endpoint is None:
        start_server()
        endpoint = wait_for_server(retries=10)

    return endpoint.infer(text)
if __name__ == "__main__":
    # Smoke run: render one prompt and save it as result.png in the
    # current working directory.
    generated = text_to_image(text="Robot dinosaur\n")
    generated.save("result.png")
|