File size: 2,458 Bytes
f060249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import base64
import logging
import os
import time
from io import BytesIO

from PIL import Image
from octoai.client import Client as OctoAiClient
from fair import FairClient

# Module-level logger named after this module, so its records can be filtered or
# configured on their own; they still propagate to the root logger's handlers.
logger = logging.getLogger(__name__)

import requests

# Deployment configuration. Each value can be overridden through an environment
# variable, so the script can target another deployment without editing code.
# For example, a local debug setup might use:
#   FAIRCOMPUTE_SERVER_ADDRESS=http://localhost:8000
#   FAIRCOMPUTE_ENDPOINT_ADDRESS=http://localhost:5000
#   FAIRCOMPUTE_TARGET_NODE=ef09913249aa40ecba7d0097f7622855

# FairCompute control server that FairClient connects to.
SERVER_ADDRESS = os.getenv("FAIRCOMPUTE_SERVER_ADDRESS", "https://faircompute.com:8000")
# Base URL of the inference endpoint: probed for readiness and used for /infer requests.
ENDPOINT_ADDRESS = os.getenv("FAIRCOMPUTE_ENDPOINT_ADDRESS", "http://dikobraz.mooo.com:5000")
# Node identifier on which the Docker image is launched.
TARGET_NODE = os.getenv("FAIRCOMPUTE_TARGET_NODE", "119eccba-2388-43c1-bdb9-02133049604c")

# Docker image run on TARGET_NODE by start_server.
DOCKER_IMAGE = os.getenv("FAIRCOMPUTE_DOCKER_IMAGE", "faircompute/diffusion-octo:v1")


class EndpointClient:
    """Client for the diffusion inference endpoint served at ENDPOINT_ADDRESS."""

    def infer(self, prompt):
        """Request an image for ``prompt`` from the endpoint's ``/infer`` route.

        The request inputs are ``{"prompt": {"text": prompt}}``. The response is
        read as a mapping holding a base64-encoded image under
        ``["output"]["image_b64"]``.

        Args:
            prompt: Text prompt describing the desired image.

        Returns:
            The ``PIL.Image`` opened from the decoded image bytes.
        """
        request_inputs = {"prompt": {"text": prompt}}
        octo_client = OctoAiClient()
        reply = octo_client.infer(endpoint_url=f"{ENDPOINT_ADDRESS}/infer", inputs=request_inputs)

        encoded = reply["output"]["image_b64"]
        return Image.open(BytesIO(base64.b64decode(encoded)))


class ServerNotReadyException(Exception):
    """Raised when the inference endpoint still fails to respond after all readiness probes."""


def wait_for_server(retries, timeout, delay=1.0):
    """Poll ENDPOINT_ADDRESS until a GET request succeeds with a non-error status.

    Args:
        retries: Maximum number of GET attempts; must be at least 1.
        timeout: Per-request timeout in seconds, passed to ``requests.get``.
        delay: Seconds to sleep between failed attempts. There is no sleep after
            the final attempt.

    Raises:
        ValueError: If ``retries`` is less than 1. No probe could be made, so
            readiness cannot be confirmed.
        ServerNotReadyException: If every attempt fails with a connection error,
            an HTTP error status, or a timeout. The last error is chained as the cause.
    """
    # Without this guard the loop below would not run, and the function would
    # return as if the server were ready.
    if retries < 1:
        raise ValueError(f"retries must be at least 1, got {retries!r}")

    for attempt in range(1, retries + 1):
        try:
            response = requests.get(ENDPOINT_ADDRESS, timeout=timeout)
            # Treat 4xx/5xx answers as "not ready yet" as well.
            response.raise_for_status()
        except (requests.exceptions.ConnectionError,
                requests.exceptions.HTTPError,
                requests.exceptions.Timeout) as e:
            if attempt == retries:
                raise ServerNotReadyException("Failed to start the server") from e
            logger.info("Server is not ready yet")
            time.sleep(delay)
        else:
            return


def start_server(*, retries=10, timeout=1.0):
    """Launch DOCKER_IMAGE on TARGET_NODE through FairCompute and wait for it to answer.

    Credentials are read from the FAIRCOMPUTE_EMAIL and FAIRCOMPUTE_PASSWORD
    environment variables. If either is unset, a debug default is used for it.
    Those defaults work only against a local server built in debug mode, so a
    warning is logged when they are used.

    Args:
        retries: Number of readiness probes passed to ``wait_for_server``.
        timeout: Per-probe timeout in seconds passed to ``wait_for_server``.

    Raises:
        ServerNotReadyException: Propagated from ``wait_for_server`` if the
            endpoint does not become ready within ``retries`` probes.
    """
    email = os.getenv('FAIRCOMPUTE_EMAIL')
    password = os.getenv('FAIRCOMPUTE_PASSWORD')
    if email is None or password is None:
        logger.warning("FAIRCOMPUTE_EMAIL/FAIRCOMPUTE_PASSWORD not set; using debug credentials, "
                       "which work only for a local server built in debug mode")

    client = FairClient(server_address=SERVER_ADDRESS,
                        user_email=email if email is not None else "debug-usr",
                        user_password=password if password is not None else "debug-pwd")

    # detach=True presumably returns once the container is started rather than
    # blocking on it; verify against the FairClient.run documentation.
    client.run(node=TARGET_NODE,
               image=DOCKER_IMAGE,
               ports=[(5000, 8080)],
               detach=True)

    # wait until the server is ready
    wait_for_server(retries=retries, timeout=timeout)


def text_to_image(text):
    """Render ``text`` into an image, launching the inference server first if needed.

    A single quick probe (one attempt, 1 s timeout, no delay) checks whether the
    endpoint already answers. If it does not, ``start_server`` is called before
    the prompt is sent.

    Args:
        text: Prompt forwarded to ``EndpointClient.infer``.

    Returns:
        Whatever ``EndpointClient.infer`` returns for the prompt.
    """
    def _already_running():
        # Turn the probe's failure exception into a boolean result.
        try:
            wait_for_server(retries=1, timeout=1.0, delay=0.0)
        except ServerNotReadyException:
            return False
        return True

    if not _already_running():
        start_server()

    return EndpointClient().infer(text)


if __name__ == "__main__":
    # Demo run: render a sample prompt and save the result in the working directory.
    text_to_image(text="Robot dinosaur\n").save("result.png")