Spaces:
Sleeping
Sleeping
import base64 | |
import json | |
import logging | |
import os | |
import time | |
from io import BytesIO | |
from typing import List, BinaryIO | |
from PIL import Image | |
from octoai.client import Client as OctoAiClient | |
logger = logging.getLogger() | |
import requests | |
import tempfile | |
# --- Deployment configuration -------------------------------------------
# Base URL of the FairCompute REST API (used to build every API route).
SERVER_ADDRESS = "https://faircompute.com:8000/api/v0"
# Base URL of the inference endpoint; also probed for readiness.
ENDPOINT_ADDRESS = "http://dikobraz.mooo.com:5000"
# Value sent as 'target_node' when a job is submitted.
TARGET_NODE = "119eccba-2388-43c1-bdb9-02133049604c"

# Alternative settings for a local setup (swap in for the three above):
# SERVER_ADDRESS = "http://localhost:8000/api/v0"
# ENDPOINT_ADDRESS = "http://localhost:5000"
# TARGET_NODE = None

# Docker image submitted as the job's container image.
DOCKER_IMAGE = "faircompute/diffusion-octo:v1"
class FairApiClient:
    """Small HTTP client for the FairCompute REST API.

    Call :meth:`authenticate` first; every other request sends the token it
    stores as a ``Bearer`` authorization header.
    """

    def __init__(self, server_address: str):
        """Remember the API base URL; no network traffic happens here."""
        self.server_address = server_address
        # Bearer token, populated by authenticate().
        self.token = None

    def _auth_headers(self, **extra) -> dict:
        """Return request headers: *extra* plus the Bearer authorization."""
        return {**extra, 'Authorization': f'Bearer {self.token}'}

    @staticmethod
    def _raise_on_error(response):
        """Raise ``Exception`` with the status code if *response* is not ok."""
        if not response.ok:
            raise Exception(f"Error! status: {response.status_code}")

    def authenticate(self, email: str, password: str):
        """Log in and store the returned token on the client.

        Raises:
            Exception: if the server answers with an error status. Checking
                the status first avoids an opaque KeyError / JSON decoding
                error when the login is rejected.
        """
        url = f'{self.server_address}/auth/login'
        json_obj = {"email": email, "password": password, "version": "V018"}
        resp = requests.post(url, json=json_obj)
        self._raise_on_error(resp)
        self.token = resp.json()["token"]

    def get(self, url, **kwargs):
        """Send an authenticated GET; extra *kwargs* go to ``requests.get``.

        Returns the response object. Raises ``Exception`` on an error status.
        """
        response = requests.get(url, headers=self._auth_headers(), **kwargs)
        self._raise_on_error(response)
        return response

    def put(self, url, data):
        """Send an authenticated PUT with a JSON body.

        *data* is sent unchanged when it is already a ``str``; anything else
        (including ``None``, which becomes ``null``) is serialized with
        ``json.dumps``.

        Returns the response object. Raises ``Exception`` on an error status;
        206 responses are accepted, because requests reports every status
        below 400 as ok.
        """
        headers = {'Content-Type': 'application/json', **self._auth_headers()}
        if not isinstance(data, str):
            data = json.dumps(data)
        response = requests.put(url, headers=headers, data=data)
        self._raise_on_error(response)
        return response

    def put_job(self, image: str, command: List[str], ports: List[tuple[int, int]], input_files, output_files):
        """Submit a job and return its ``(id, pid)`` from the server reply.

        Args:
            image: container image for the job.
            command: command list placed in the container description.
            ports: ``(host_port, container_port)`` pairs to map.
            input_files: passed through as the request's 'input_files'.
            output_files: passed through as the request's 'output_files'.
        """
        url = f"{self.server_address}/jobs"
        # NOTE(review): 'ip' is sent as the string 'null', not JSON null --
        # confirm that this is what the server expects.
        port_mappings = [
            [{"port": host_port, "ip": 'null'}, {"port": container_port, "protocol": "Tcp"}]
            for (host_port, container_port) in ports
        ]
        data = {
            'version': 'V018',
            'container_desc': {
                'version': 'V018',
                'image': image,
                'runtime': 'nvidia',
                'ports': port_mappings,
                'command': command,
            },
            'input_files': input_files,
            'output_files': output_files,
            'target_node': TARGET_NODE,
        }
        response = self.put(url=url, data=data).json()
        return response['id'], response['pid']

    def get_job_info(self, job_id):
        """Return the decoded JSON from the job's ``/stat`` route."""
        url = f"{self.server_address}/jobs/{job_id}/stat"
        return self.get(url=url).json()

    def get_cluster_summary(self):
        """Return the decoded JSON from the ``/nodes/summary`` route."""
        url = f"{self.server_address}/nodes/summary"
        return self.get(url=url).json()

    def put_job_stream_data(self, job_id, name, data):
        """PUT *data* to the job stream *name*; return the response text."""
        url = f"{self.server_address}/jobs/{job_id}/data/streams/{name}"
        response = self.put(url=url, data=data)
        return response.text

    def put_job_stream_eof(self, job_id, name):
        """PUT to the ``/eof`` route of job stream *name*; return the response text."""
        url = f"{self.server_address}/jobs/{job_id}/data/streams/{name}/eof"
        response = self.put(url=url, data=None)
        return response.text

    def wait_for_file(self, job_id, path, attempts=10) -> "BinaryIO | None":
        """Download a job file, retrying until it can be fetched.

        Makes up to *attempts* requests, sleeping 0.5 s after each failure.
        Progress and errors are printed.

        Returns:
            A temporary file holding the content, positioned at the start
            (the caller is responsible for closing it), or ``None`` if every
            attempt failed.
        """
        url = f"{self.server_address}/jobs/{job_id}/data/files/{path}"
        headers = self._auth_headers()
        for _ in range(attempts):
            print(f"Waiting for file {path}...")
            f = None
            try:
                with requests.get(url=url, headers=headers, stream=True) as r:
                    r.raise_for_status()
                    f = tempfile.TemporaryFile()
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
                print(f"File {path} ready")
                f.seek(0, 0)
                return f
            except Exception as e:
                # Do not leak a partially written temp file on failure.
                if f is not None:
                    f.close()
                print(e)
                time.sleep(0.5)
        print(f"Failed to receive {path}")
        return None
class EndpointClient:
    """Client for the text-to-image inference endpoint at ENDPOINT_ADDRESS."""

    def infer(self, prompt):
        """Send *prompt* to the endpoint's ``/infer`` route and return the image.

        The reply is expected to carry a base64-encoded image under
        ``["output"]["image_b64"]``; it is decoded and opened with PIL.
        """
        payload = {"prompt": {"text": prompt}}
        reply = OctoAiClient().infer(endpoint_url=f"{ENDPOINT_ADDRESS}/infer", inputs=payload)
        raw_bytes = base64.b64decode(reply["output"]["image_b64"])
        return Image.open(BytesIO(raw_bytes))
class ServerNotReadyException(Exception):
    """Raised when the inference endpoint does not become reachable."""
def wait_for_server(retries, timeout, delay=1.0):
    """Block until a GET on ENDPOINT_ADDRESS succeeds.

    Args:
        retries: maximum number of probes; must be at least 1.
        timeout: per-request timeout passed to ``requests.get``.
        delay: seconds to sleep between failed probes.

    Raises:
        ServerNotReadyException: if the last probe fails (chained to the
            underlying requests error), or if ``retries`` is below 1. In
            that case no probe would run, so readiness cannot be confirmed
            and returning silently would wrongly signal success.
    """
    if retries < 1:
        raise ServerNotReadyException(
            f"Server was not probed: retries must be >= 1, got {retries}")

    for attempt in range(1, retries + 1):
        try:
            r = requests.get(ENDPOINT_ADDRESS, timeout=timeout)
            r.raise_for_status()
        except (requests.exceptions.ConnectionError,
                requests.exceptions.HTTPError,
                requests.exceptions.Timeout) as e:
            if attempt == retries:
                raise ServerNotReadyException("Failed to start the server") from e
            logger.info("Server is not ready yet")
            time.sleep(delay)
        else:
            # The endpoint answered with a success status.
            return
def start_server():
    """Submit the inference job and wait until its endpoint answers.

    Credentials come from the FAIRCOMPUTE_EMAIL / FAIRCOMPUTE_PASSWORD
    environment variables. The fallback values will work only for a local
    server built in debug mode.
    """
    api = FairApiClient(SERVER_ADDRESS)
    api.authenticate(
        email=os.environ.get('FAIRCOMPUTE_EMAIL', "debug-usr"),
        password=os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd"),
    )

    # Host port 5000 is mapped to container port 8080.
    job_id, job_pid = api.put_job(
        image=DOCKER_IMAGE,
        command=[],
        ports=[(5000, 8080)],
        input_files=[],
        output_files=[],
    )
    logger.info(f"Job id: {job_id}, pid: {job_pid}")

    # Poll the job until it reports "Processing". There is no upper bound
    # on how long this waits.
    job_info = api.get_job_info(job_id=job_id)
    logger.info(job_info)
    while True:
        if job_info["status"] == "Processing":
            break
        job_info = api.get_job_info(job_id=job_id)
        logger.info(job_info)
        time.sleep(0.5)

    # The job is running; now wait until the server inside it is ready.
    wait_for_server(retries=10, timeout=1.0)
def text_to_image(text):
    """Generate an image for *text* via the inference endpoint.

    Returns whatever ``EndpointClient.infer`` returns for the prompt.
    """
    # Automatic server start-up is currently disabled; the intended logic was:
    # try:
    #     wait_for_server(retries=1, timeout=1.0, delay=0.0)
    # except ServerNotReadyException:
    #     start_server()
    endpoint = EndpointClient()
    picture = endpoint.infer(text)
    return picture
if __name__ == "__main__":
    # Demo run: render one prompt and write the picture next to the script.
    text_to_image(text="Robot dinosaur\n").save("result.png")