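"""Serve a DreamShaper 8 text-to-image model on a Fair Compute GPU node,
exposed to the Internet through a rathole reverse tunnel on a small public VM."""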
import base64
import hashlib
import logging
import os
import time
from io import BytesIO

import requests
from PIL import Image
from fair import FairClient

logger = logging.getLogger()

SERVER_ADDRESS = "https://faircompute.com:8000"
INFERENCE_NODE = "magnus"
TUNNEL_NODE = "gcs-e2-micro"
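# alternative values for testing against a locally hosted server: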
# SERVER_ADDRESS = "http://localhost:8000"
# INFERENCE_NODE = "ef09913249aa40ecba7d0097f7622855"
# TUNNEL_NODE = "c312e6c4788b00c73c287ab0445d3655"

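# Docker images: the DreamShaper 8 text-to-image API and the rathole reverse tunnel.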
INFERENCE_DOCKER_IMAGE = "faircompute/diffusers-api-dreamshaper-8"
TUNNEL_DOCKER_IMAGE = "rapiz1/rathole"

endpoint_client = None
fair_client = None


class EndpointClient:
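    """Thin HTTP client for the inference endpoint exposed through the tunnel."""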
    def __init__(self, server_address, timeout):
        self.endpoint_address = f'http://{server_address}:5000'
        response = requests.get(f'{self.endpoint_address}/healthcheck', timeout=timeout).json()
        if response['state'] != 'healthy':
            raise ServerNotReadyException("Server is not healthy")

    def infer(self, prompt):
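        """Request a single 512x512 image for `prompt` and decode it from the response."""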
        inputs = {
            "modelInputs": {
                "prompt": prompt,
                "num_inference_steps": 25,
                "width": 512,
                "height": 512,
            },
            "callInputs": {
                "MODEL_ID": "lykon/dreamshaper-8",
                "PIPELINE": "AutoPipelineForText2Image",
                "SCHEDULER": "DEISMultistepScheduler",
                "PRECISION": "fp16",
                "REVISION": "fp16",
            },
        }

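        # the endpoint returns the generated image as a base64-encoded string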
        response = requests.post(self.endpoint_address, json=inputs).json()
        image_data = BytesIO(base64.b64decode(response["image_base64"]))
        image = Image.open(image_data)

        return image


class ServerNotReadyException(Exception):
    pass


def create_fair_client():
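    """Create a FairClient using the FAIRCOMPUTE_EMAIL/FAIRCOMPUTE_PASSWORD environment variables."""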
    return FairClient(server_address=SERVER_ADDRESS,
                      user_email=os.getenv('FAIRCOMPUTE_EMAIL', "debug-usr"),
                      user_password=os.getenv('FAIRCOMPUTE_PASSWORD', "debug-pwd"))


def create_endpoint_client(fc, retries, timeout=1.0, delay=2.0):
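    """Poll the tunneled endpoint until it passes a healthcheck, retrying up to `retries` times."""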
    nodes = fc.cluster().nodes.list()
    server_address = next(info['host_address'] for info in nodes if info['name'] == TUNNEL_NODE)
    for _ in range(retries):
        try:
            return EndpointClient(server_address, timeout=timeout)
        except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError,
                requests.exceptions.Timeout, ServerNotReadyException) as e:
            logger.exception(e)
            time.sleep(delay)

    raise ServerNotReadyException("Failed to start the server")


def start_tunnel(fc: FairClient):
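    """Run a rathole server on the public tunnel node and a rathole client on the GPU node."""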
    # derive a deterministic authentication token from a shared secret
    token = hashlib.sha256(os.getenv('FAIRCOMPUTE_PASSWORD', "debug-pwd").encode()).hexdigest()

    # start the rathole server on the tunnel node
    server_config = f"""
[server]
bind_addr = "0.0.0.0:2333"  # port that rathole listens for clients

[server.services.inference_server]
token = "{token}"           # token that is used to authenticate the client for the service
bind_addr = "0.0.0.0:5000"  # port that exposes service to the Internet
"""
    with open('server.toml', 'w') as file:
        file.write(server_config)
    fc.run(node_name=TUNNEL_NODE,
           image=TUNNEL_DOCKER_IMAGE,
           command=["--server", "/app/config.toml"],
           volumes=[("./server.toml", "/app/config.toml")],
           network="host",
           detach=True)

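    # start the rathole client on the inference node; it forwards traffic arriving
    # at the rathole server to the local inference server on port 5001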
    nodes = fc.cluster().nodes.list()
    server_address = next(info['host_address'] for info in nodes if info['name'] == TUNNEL_NODE)
    client_config = f"""
[client]
remote_addr = "{server_address}:2333"       # address of the rathole server

[client.services.inference_server]
token = "{token}"                           # token that is used to authenticate the client for the service
local_addr = "127.0.0.1:5001"               # address of the service that needs to be forwarded
"""
    with open('client.toml', 'w') as file:
        file.write(client_config)
    fc.run(node_name=INFERENCE_NODE,
           image=TUNNEL_DOCKER_IMAGE,
           command=["--client", "/app/config.toml"],
           volumes=[("./client.toml", "/app/config.toml")],
           network="host",
           detach=True)


def start_inference_server(fc: FairClient):
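    """Launch the text-to-image container on the GPU node, publishing it on local port 5001."""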
    fc.run(node_name=INFERENCE_NODE,
           image=INFERENCE_DOCKER_IMAGE,
           runtime="nvidia",
           ports=[(5001, 8000)],
           detach=True)


def text_to_image(text):
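    """Generate an image for `text`, lazily provisioning the server and tunnel on first use."""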
    global endpoint_client
    global fair_client
    if fair_client is None:
        fair_client = create_fair_client()

    try:
        # client is configured, try to do inference right away
        if endpoint_client is not None:
            return endpoint_client.infer(text)
        # client is not configured; try connecting once in case the server is already running
        else:
            endpoint_client = create_endpoint_client(fair_client, retries=1)
    except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError, requests.exceptions.Timeout, ServerNotReadyException):
        # inference server is not ready, start inference server and open the tunnel
        start_inference_server(fair_client)
        start_tunnel(fair_client)
        endpoint_client = create_endpoint_client(fair_client, retries=10)

    # run inference
    return endpoint_client.infer(text)


if __name__ == "__main__":
    image = text_to_image(text="Robot dinosaur")
    image.save("result.png")