File size: 5,389 Bytes
d42dfa2
 
 
 
 
 
 
 
 
 
 
 
 
 
bb1aed1
 
 
cafb8d0
8a8728a
 
d42dfa2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import base64
import logging
import os
import hashlib

import requests
import time
from io import BytesIO

from PIL import Image
from fair import FairClient

# Module-level logger (named after this module rather than the root logger).
logger = logging.getLogger(__name__)

# Deployment targets. Each can be overridden through the environment so the
# script can be pointed at another cluster without editing the source; the
# defaults are the values this script was written against.
# An alternative deployment used previously: server "https://faircompute.com:8000",
# inference node "magnus", tunnel node "gcs-e2-micro".
SERVER_ADDRESS = os.environ.get("FAIRCOMPUTE_SERVER_ADDRESS", "https://54.91.82.249:5000")
INFERENCE_NODE = os.environ.get("FAIRCOMPUTE_INFERENCE_NODE", "Nikhil-Macbook")
TUNNEL_NODE = os.environ.get("FAIRCOMPUTE_TUNNEL_NODE", "nikhil-tunneling-node")

# Container images for the model server and the rathole tunnel.
INFERENCE_DOCKER_IMAGE = "faircompute/diffusers-api-dreamshaper-8"
TUNNEL_DOCKER_IMAGE = "rapiz1/rathole"

# Lazily created clients, cached across calls to text_to_image().
endpoint_client = None
fair_client = None

class EndpointClient:
    """HTTP client for the text-to-image inference endpoint.

    Construction performs a health check, so a successfully built client has
    been seen to report a healthy state at least once.
    """

    def __init__(self, server_address, timeout, port=5000):
        """Connect to the endpoint and verify that it reports itself healthy.

        Args:
            server_address: Host name or IP of the node exposing the endpoint.
            timeout: Timeout in seconds for the health-check request.
            port: Port the endpoint listens on (defaults to 5000).

        Raises:
            requests.exceptions.ConnectionError: the endpoint is unreachable.
            requests.exceptions.Timeout: the health check timed out.
            requests.exceptions.HTTPError: the health check returned an
                HTTP error status.
            Exception: the endpoint answered but its ``state`` is not
                ``'healthy'`` (or is missing).
        """
        self.endpoint_address = f'http://{server_address}:{port}'
        # URLs are joined with '/', not os.path.join, which would use the OS
        # path separator ('\\' on Windows).
        response = requests.get(f'{self.endpoint_address}/healthcheck', timeout=timeout)
        # Turn HTTP error statuses into HTTPError, which callers treat as
        # "not ready yet" and retry.
        response.raise_for_status()
        if response.json().get('state') != 'healthy':
            raise Exception("Server is not healthy")

    def infer(self, prompt, timeout=None):
        """Generate a 512x512 image for *prompt*.

        Args:
            prompt: Text prompt passed to the model.
            timeout: Optional request timeout in seconds; ``None`` (the
                default) waits for the server indefinitely.

        Returns:
            A ``PIL.Image.Image`` decoded from the base64 ``image_base64``
            field of the JSON response.

        Raises:
            requests.exceptions.HTTPError: the server returned an HTTP error
                status.
            KeyError: the response JSON has no ``image_base64`` field.
        """
        inputs = {
            # Per-request generation parameters.
            "modelInputs": {
                "prompt": prompt,
                "num_inference_steps": 25,
                "width": 512,
                "height": 512,
            },
            # Model/pipeline selection understood by the inference server.
            "callInputs": {
                "MODEL_ID": "lykon/dreamshaper-8",
                "PIPELINE": "AutoPipelineForText2Image",
                "SCHEDULER": "DEISMultistepScheduler",
                "PRECISION": "fp16",
                "REVISION": "fp16",
            },
        }

        response = requests.post(self.endpoint_address, json=inputs, timeout=timeout)
        response.raise_for_status()
        image_data = BytesIO(base64.b64decode(response.json()["image_base64"]))
        return Image.open(image_data)


class ServerNotReadyException(Exception):
    """Signals that the inference endpoint could not be reached after all connection attempts."""


def create_fair_client():
    """Return a FairClient for SERVER_ADDRESS using credentials from the environment.

    Reads ``FAIRCOMPUTE_EMAIL`` and ``FAIRCOMPUTE_PASSWORD``. Placeholder
    debug credentials are used when either is unset.
    """
    email = os.environ.get('FAIRCOMPUTE_EMAIL', "debug-usr")
    password = os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd")
    return FairClient(
        server_address=SERVER_ADDRESS,
        user_email=email,
        user_password=password,
    )


def create_endpoint_client(fc, retries, timeout=1.0, delay=2.0):
    """Connect to the inference endpoint on TUNNEL_NODE, retrying on failure.

    Args:
        fc: FairClient used to look up the tunnel node's host address.
        retries: Maximum number of connection attempts.
        timeout: Per-attempt health-check timeout in seconds.
        delay: Seconds to wait between consecutive attempts.

    Returns:
        An EndpointClient whose health check succeeded.

    Raises:
        RuntimeError: TUNNEL_NODE is not among the cluster's nodes.
        ServerNotReadyException: every attempt failed with a connection,
            HTTP or timeout error. The last such error is chained as the cause.
    """
    nodes = fc.cluster().nodes.list()
    # A bare next() would leak a cryptic StopIteration when the node is absent.
    tunnel_node = next((info for info in nodes if info['name'] == TUNNEL_NODE), None)
    if tunnel_node is None:
        raise RuntimeError(f"Tunnel node {TUNNEL_NODE!r} is not registered in the cluster")
    server_address = tunnel_node['host_address']

    last_error = None
    for attempt in range(1, retries + 1):
        try:
            return EndpointClient(server_address, timeout=timeout)
        except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError, requests.exceptions.Timeout) as e:
            last_error = e
            # Failures are expected while the server boots. Log them briefly;
            # the final exception carries the full traceback via chaining.
            logger.warning("Inference endpoint not ready (attempt %d/%d): %s", attempt, retries, e)
            # Do not sleep after the last attempt; we are about to give up.
            if attempt < retries:
                time.sleep(delay)

    raise ServerNotReadyException("Failed to start the server") from last_error


def start_tunnel(fc: FairClient):
    """Open a rathole tunnel between TUNNEL_NODE and INFERENCE_NODE.

    Writes ``server.toml`` and ``client.toml`` to the current working
    directory. It then starts two detached TUNNEL_DOCKER_IMAGE containers with
    host networking, each with its config mounted at ``/app/config.toml``:

    * a rathole server on TUNNEL_NODE that accepts tunnel clients on port 2333
      and exposes the forwarded service on port 5000;
    * a rathole client on INFERENCE_NODE that forwards ``127.0.0.1:5001``.

    Args:
        fc: FairClient used to launch the containers and list cluster nodes.

    Raises:
        RuntimeError: TUNNEL_NODE is not among the cluster's nodes, so the
            client cannot be told where the tunnel server is.
    """
    # Shared authentication token. It is derived deterministically (not
    # randomly) from the account password, so the server and client configs
    # written below always agree.
    token = hashlib.sha256(os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd").encode()).hexdigest()

    # start the tunnel server on TUNNEL_NODE
    server_config = f"""
[server]
bind_addr = "0.0.0.0:2333"  # port that rathole listens for clients

[server.services.inference_server]
token = "{token}"           # token that is used to authenticate the client for the service
bind_addr = "0.0.0.0:5000"  # port that exposes service to the Internet
"""
    with open('server.toml', 'w', encoding='utf-8') as file:
        file.write(server_config)
    fc.run(node_name=TUNNEL_NODE,
           image=TUNNEL_DOCKER_IMAGE,
           command=["--server", "/app/config.toml"],
           volumes=[("./server.toml", "/app/config.toml")],
           network="host",
           detach=True)

    # The client must know the tunnel node's address to dial the server.
    nodes = fc.cluster().nodes.list()
    tunnel_node = next((info for info in nodes if info['name'] == TUNNEL_NODE), None)
    if tunnel_node is None:
        raise RuntimeError(f"Tunnel node {TUNNEL_NODE!r} is not registered in the cluster")
    server_address = tunnel_node['host_address']
    client_config = f"""
[client]
remote_addr = "{server_address}:2333"       # address of the rathole server

[client.services.inference_server]
token = "{token}"                           # token that is used to authenticate the client for the service
local_addr = "127.0.0.1:5001"               # address of the service that needs to be forwarded
"""
    with open('client.toml', 'w', encoding='utf-8') as file:
        file.write(client_config)
    fc.run(node_name=INFERENCE_NODE,
           image=TUNNEL_DOCKER_IMAGE,
           command=["--client", "/app/config.toml"],
           volumes=[("./client.toml", "/app/config.toml")],
           network="host",
           detach=True)


def start_inference_server(fc: FairClient):
    """Start the inference container on INFERENCE_NODE as a detached job.

    The container uses the ``nvidia`` runtime and a port mapping of
    ``(5001, 8000)``. Presumably this is host port 5001 to container port
    8000; confirm against FairClient.run.
    """
    run_options = dict(
        node_name=INFERENCE_NODE,
        image=INFERENCE_DOCKER_IMAGE,
        runtime="nvidia",
        ports=[(5001, 8000)],
        detach=True,
    )
    fc.run(**run_options)


def text_to_image(text):
    """Render *text* to an image, starting the remote infrastructure if needed.

    The FairClient and EndpointClient are cached in module globals across
    calls. This function (re)starts the inference server and tunnel, then
    polls the endpoint for up to 10 attempts, in two cases:

    * a cached endpoint fails with a connection, HTTP or timeout error;
    * a single probe for an endpoint fails.

    Once a client is available, inference runs with it.
    """
    global endpoint_client, fair_client

    if fair_client is None:
        fair_client = create_fair_client()

    recoverable_errors = (
        requests.exceptions.ConnectionError,
        requests.exceptions.HTTPError,
        requests.exceptions.Timeout,
        ServerNotReadyException,
    )
    try:
        if endpoint_client is None:
            # No client yet: probe once in case the server is already up.
            endpoint_client = create_endpoint_client(fair_client, 1)
        else:
            # Fast path: reuse the existing client.
            return endpoint_client.infer(text)
    except recoverable_errors:
        # Endpoint unusable: boot the server, open the tunnel, wait for it.
        start_inference_server(fair_client)
        start_tunnel(fair_client)
        endpoint_client = create_endpoint_client(fair_client, retries=10)

    return endpoint_client.infer(text)


if __name__ == "__main__":
    # Generate one sample image and save it as result.png in the working directory.
    generated = text_to_image(text="Robot dinosaur\n")
    generated.save("result.png")