File size: 2,769 Bytes
f060249
 
 
6c7a4d8
f060249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdc1919
f060249
 
 
6c7a4d8
 
 
 
f060249
6c7a4d8
 
 
 
bdc1919
6c7a4d8
 
 
 
bdc1919
6c7a4d8
bdc1919
 
 
6c7a4d8
 
 
 
 
f060249
 
 
 
 
 
 
 
 
6c7a4d8
f060249
 
6c7a4d8
f060249
6c7a4d8
 
 
 
f060249
 
 
 
 
 
 
 
 
 
386c5dd
6c7a4d8
f060249
 
 
 
 
6c7a4d8
f060249
 
6c7a4d8
f060249
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import base64
import logging
import os
import requests
import time
from io import BytesIO

from PIL import Image
from fair import FairClient

# Logger named after this module (instead of the root logger) so callers can
# configure or silence this script's output without touching global logging.
logger = logging.getLogger(__name__)

# FairCompute server that FairClient connects to in start_server().
SERVER_ADDRESS = "https://faircompute.com:8000"
# Base URL queried for health checks and inference requests (presumably served
# by the container that start_server() launches — confirm against deployment).
ENDPOINT_ADDRESS = "http://dikobraz.mooo.com:5000"
# Node identifier passed to FairClient.run() in start_server().
TARGET_NODE = "119eccba-2388-43c1-bdb9-02133049604c"
# Alternative settings for a local development setup:
# SERVER_ADDRESS = "http://localhost:8000"
# ENDPOINT_ADDRESS = "http://localhost:5000"
# TARGET_NODE = "ef09913249aa40ecba7d0097f7622855"

# Docker image that start_server() launches on TARGET_NODE.
DOCKER_IMAGE = "faircompute/diffusers-api-dreamshaper-8"


class EndpointClient:
    """Minimal HTTP client for the diffusers-api endpoint at ENDPOINT_ADDRESS.

    Construction performs a health check, so an instance only exists once the
    endpoint has answered its health check with ``state == 'healthy'``.
    """

    def __init__(self, timeout):
        """Probe ``<ENDPOINT_ADDRESS>/healthcheck`` and fail unless it is healthy.

        Args:
            timeout: Seconds to wait for the health-check response.

        Raises:
            requests.exceptions.ConnectionError, requests.exceptions.Timeout:
                the endpoint could not be reached in time.
            requests.exceptions.HTTPError: the endpoint answered with a
                4xx/5xx status.
            Exception: the endpoint answered but did not report
                ``state == 'healthy'``.
        """
        # Build the URL with plain string ops: os.path.join is meant for
        # filesystem paths and would insert a backslash on Windows.
        url = f"{ENDPOINT_ADDRESS.rstrip('/')}/healthcheck"
        response = requests.get(url, timeout=timeout)
        # Turn error statuses into HTTPError (which callers such as
        # wait_for_server handle) instead of failing later on the body.
        response.raise_for_status()
        if response.json().get('state') != 'healthy':
            raise Exception("Server is not healthy")

    def infer(self, prompt, *, num_inference_steps=25, width=512, height=512, timeout=None):
        """Generate an image for *prompt* and return it as a PIL image.

        Args:
            prompt: Text prompt sent as ``modelInputs.prompt``.
            num_inference_steps: Sent as ``modelInputs.num_inference_steps``.
            width: Sent as ``modelInputs.width``.
            height: Sent as ``modelInputs.height``.
            timeout: Seconds to wait for the POST response; ``None`` (the
                default) waits indefinitely, matching the previous behavior.

        Returns:
            A PIL image decoded from the response's ``image_base64`` field.

        Raises:
            requests.exceptions.HTTPError: the endpoint answered with a
                4xx/5xx status.
        """
        inputs = {
            "modelInputs": {
                "prompt": prompt,
                "num_inference_steps": num_inference_steps,
                "width": width,
                "height": height,
            },
            "callInputs": {
                "MODEL_ID": "lykon/dreamshaper-8",
                "PIPELINE": "AutoPipelineForText2Image",
                "SCHEDULER": "DEISMultistepScheduler",
                "PRECISION": "fp16",
                "REVISION": "fp16",
            },
        }

        response = requests.post(ENDPOINT_ADDRESS, json=inputs, timeout=timeout)
        # Fail with the HTTP status rather than a KeyError on "image_base64".
        response.raise_for_status()
        image_data = BytesIO(base64.b64decode(response.json()["image_base64"]))
        image = Image.open(image_data)

        return image


class ServerNotReadyException(Exception):
    """Signals that the inference endpoint never passed its health check."""


def wait_for_server(retries, timeout=1.0, delay=2.0):
    """Poll the endpoint until it is healthy and return a connected client.

    Args:
        retries: Maximum number of health-check attempts.
        timeout: Per-attempt request timeout, in seconds.
        delay: Seconds to sleep between consecutive failed attempts.

    Returns:
        An EndpointClient whose health check succeeded.

    Raises:
        ServerNotReadyException: every attempt failed with a connection, HTTP,
            or timeout error; the last such error is chained as the cause.
    """
    last_error = None
    for attempt in range(1, retries + 1):
        try:
            return EndpointClient(timeout=timeout)
        except (requests.exceptions.ConnectionError,
                requests.exceptions.HTTPError,
                requests.exceptions.Timeout) as err:
            last_error = err
            # A failed probe is expected while the server is starting, so log a
            # concise warning on the module logger rather than a root traceback.
            logger.warning("Endpoint not ready (attempt %d/%d): %s", attempt, retries, err)
            # Sleeping after the final attempt would only delay the failure.
            if attempt < retries:
                time.sleep(delay)

    raise ServerNotReadyException("Failed to start the server") from last_error


def start_server():
    """Ask the FairCompute server to run DOCKER_IMAGE on TARGET_NODE.

    Credentials are read from FAIRCOMPUTE_EMAIL / FAIRCOMPUTE_PASSWORD. The
    fallback values are debug credentials that only a local server built in
    debug mode accepts. The container is requested with the "nvidia" runtime,
    the port pair (5000, 8000) and detach=True.
    """
    email = os.getenv('FAIRCOMPUTE_EMAIL', "debug-usr")
    password = os.getenv('FAIRCOMPUTE_PASSWORD', "debug-pwd")
    fair = FairClient(server_address=SERVER_ADDRESS, user_email=email, user_password=password)

    fair.run(
        node=TARGET_NODE,
        image=DOCKER_IMAGE,
        runtime="nvidia",
        ports=[(5000, 8000)],
        detach=True,
    )


def text_to_image(text):
    """Render *text* to an image using the remote diffusers endpoint.

    A single health probe checks for an already-running endpoint. If none
    answers, the server is launched via start_server() and up to ten probes
    are made before inference.

    Raises:
        ServerNotReadyException: the endpoint is still not healthy after launch.
    """
    try:
        endpoint = wait_for_server(retries=1)
    except ServerNotReadyException:
        # Nothing healthy answered: launch the container, then wait for it.
        start_server()
        endpoint = wait_for_server(retries=10)

    generated = endpoint.infer(text)
    return generated


if __name__ == "__main__":
    # Script entry point: generate one sample image and save it to result.png.
    result = text_to_image(text="Robot dinosaur\n")
    result.save("result.png")