"""Helpers for launching a diffusion container through the FairCompute HTTP API
and requesting images from its inference endpoint.

Running the module as a script generates one image and saves it as
``result.png``.
"""

import base64
import json
import logging
import os
import tempfile
import time
from io import BytesIO
from typing import BinaryIO, List, Optional

import requests
from PIL import Image
from octoai.client import Client as OctoAiClient

# Root logger (no name is passed), so records follow the application's
# top-level logging configuration.
logger = logging.getLogger()

SERVER_ADDRESS = "https://faircompute.com:8000/api/v0"
ENDPOINT_ADDRESS = "http://dikobraz.mooo.com:5000"
TARGET_NODE = "119eccba-2388-43c1-bdb9-02133049604c"
# Alternative settings for a locally running server:
# SERVER_ADDRESS = "http://localhost:8000/api/v0"
# ENDPOINT_ADDRESS = "http://localhost:5000"
# TARGET_NODE = None

DOCKER_IMAGE = "faircompute/diffusion-octo:v1"


class FairApiClient:
    """Minimal HTTP client for the job API rooted at ``server_address``."""

    def __init__(self, server_address: str):
        self.server_address = server_address
        # Bearer token; populated by authenticate().
        self.token = None

    def _auth_headers(self) -> dict:
        """Return the Authorization header built from the current token."""
        return {'Authorization': f'Bearer {self.token}'}

    @staticmethod
    def _raise_on_error(response):
        """Raise ``Exception`` when *response* carries an error status.

        ``Response.ok`` is true for every status below 400, so partial
        responses such as 206 are accepted.
        """
        if not response.ok:
            raise Exception(f"Error! status: {response.status_code}")

    def authenticate(self, email: str, password: str):
        """Log in and store the returned token on the client.

        Raises:
            Exception: if the server answers with an error status. Checking
                the status first avoids an opaque ``KeyError`` / JSON decode
                error when the login is rejected.
        """
        url = f'{self.server_address}/auth/login'
        json_obj = {"email": email, "password": password, "version": "V018"}
        resp = requests.post(url, json=json_obj)
        self._raise_on_error(resp)
        self.token = resp.json()["token"]

    def get(self, url, **kwargs):
        """Send an authenticated GET to *url* and return the response.

        Extra keyword arguments are forwarded to ``requests.get``.

        Raises:
            Exception: if the response status indicates an error.
        """
        response = requests.get(url, headers=self._auth_headers(), **kwargs)
        self._raise_on_error(response)
        return response

    def put(self, url, data):
        """Send an authenticated PUT with a JSON content type.

        A ``str`` payload is sent as is; anything else (including ``None``)
        is serialized with ``json.dumps`` first.

        Raises:
            Exception: if the response status indicates an error.
        """
        headers = {'Content-Type': 'application/json', **self._auth_headers()}
        if not isinstance(data, str):
            data = json.dumps(data)
        response = requests.put(url, headers=headers, data=data)
        self._raise_on_error(response)
        return response

    def put_job(self,
                image: str,
                command: List[str],
                ports: List[tuple[int, int]],
                input_files,
                output_files):
        """Submit a container job and return its ``(id, pid)`` pair.

        Args:
            image: container image name.
            command: command line passed to the container.
            ports: ``(host_port, container_port)`` pairs to expose.
            input_files: forwarded unchanged as ``input_files``.
            output_files: forwarded unchanged as ``output_files``.
        """
        url = f"{self.server_address}/jobs"
        # NOTE(review): the ip is sent as the literal string 'null', not a
        # JSON null - presumably what the server expects; confirm.
        port_mappings = [
            [{"port": host_port, "ip": 'null'},
             {"port": container_port, "protocol": "Tcp"}]
            for (host_port, container_port) in ports
        ]
        data = {
            'version': 'V018',
            'container_desc': {
                'version': 'V018',
                'image': image,
                'runtime': 'nvidia',
                'ports': port_mappings,
                'command': command,
            },
            'input_files': input_files,
            'output_files': output_files,
            'target_node': TARGET_NODE,
        }
        response = self.put(url=url, data=data).json()
        return response['id'], response['pid']

    def get_job_info(self, job_id):
        """Return the decoded JSON from the job's ``/stat`` resource."""
        url = f"{self.server_address}/jobs/{job_id}/stat"
        return self.get(url=url).json()

    def get_cluster_summary(self):
        """Return the decoded JSON from ``/nodes/summary``."""
        url = f"{self.server_address}/nodes/summary"
        return self.get(url=url).json()

    def put_job_stream_data(self, job_id, name, data):
        """PUT *data* to the job stream *name* and return the response text."""
        url = f"{self.server_address}/jobs/{job_id}/data/streams/{name}"
        return self.put(url=url, data=data).text

    def put_job_stream_eof(self, job_id, name):
        """PUT to the ``eof`` resource of stream *name*; return the response text."""
        url = f"{self.server_address}/jobs/{job_id}/data/streams/{name}/eof"
        return self.put(url=url, data=None).text

    def wait_for_file(self, job_id, path, attempts=10) -> Optional[BinaryIO]:
        """Poll for a job file and download it into a temporary file.

        Makes up to *attempts* requests, pausing 0.5 s after each failure.

        Returns:
            A temporary file positioned at offset 0 holding the downloaded
            bytes, or ``None`` when every attempt failed. The caller is
            responsible for closing the returned file.
        """
        url = f"{self.server_address}/jobs/{job_id}/data/files/{path}"
        headers = self._auth_headers()
        for _ in range(attempts):
            print(f"Waiting for file {path}...")
            buffer = None
            try:
                with requests.get(url=url, headers=headers, stream=True) as r:
                    r.raise_for_status()
                    buffer = tempfile.TemporaryFile()
                    for chunk in r.iter_content(chunk_size=8192):
                        buffer.write(chunk)
                print(f"File {path} ready")
                buffer.seek(0, 0)
                return buffer
            except Exception as e:
                # Best-effort polling: any failure just triggers a retry.
                # Close a partially written temp file so it is not leaked.
                if buffer is not None:
                    buffer.close()
                print(e)
                time.sleep(0.5)

        print(f"Failed to receive {path}")
        return None


class EndpointClient:
    """Client for the inference endpoint at ``ENDPOINT_ADDRESS``."""

    def infer(self, prompt):
        """Request an image for *prompt* and return it as a PIL image.

        The response is read from ``response["output"]["image_b64"]`` and
        base64-decoded before being opened with PIL.
        """
        client = OctoAiClient()
        inputs = {"prompt": {"text": prompt}}
        response = client.infer(endpoint_url=f"{ENDPOINT_ADDRESS}/infer", inputs=inputs)

        image_b64 = response["output"]["image_b64"]
        image_bytes = base64.b64decode(image_b64)
        return Image.open(BytesIO(image_bytes))


class ServerNotReadyException(Exception):
    """Raised when the endpoint does not respond within the allowed retries."""


def wait_for_server(retries, timeout, delay=1.0):
    """Block until a GET to ``ENDPOINT_ADDRESS`` succeeds.

    Args:
        retries: maximum number of requests to make.
        timeout: per-request timeout in seconds.
        delay: pause in seconds between failed requests.

    Raises:
        ServerNotReadyException: if the last allowed request also fails with
            a connection error, HTTP error status, or timeout.
    """
    for i in range(retries):
        try:
            r = requests.get(ENDPOINT_ADDRESS, timeout=timeout)
            r.raise_for_status()
            return
        except (requests.exceptions.ConnectionError,
                requests.exceptions.HTTPError,
                requests.exceptions.Timeout) as e:
            if i == retries - 1:
                raise ServerNotReadyException("Failed to start the server") from e
            logger.info("Server is not ready yet")
            time.sleep(delay)


def start_server():
    """Submit the diffusion container job and wait until its endpoint answers.

    Credentials come from ``FAIRCOMPUTE_EMAIL`` / ``FAIRCOMPUTE_PASSWORD``.

    Raises:
        ServerNotReadyException: if the endpoint is still unreachable after
            the job reports the "Processing" status.
    """
    # default credentials will work only for local server built in debug mode
    email = os.getenv('FAIRCOMPUTE_EMAIL', "debug-usr")
    password = os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd")
    client = FairApiClient(SERVER_ADDRESS)
    client.authenticate(email=email, password=password)

    job_id, job_pid = client.put_job(
        image=DOCKER_IMAGE,
        command=[],
        ports=[(5000, 8080)],
        input_files=[],
        output_files=[])

    logger.info("Job id: %s, pid: %s", job_id, job_pid)

    info = client.get_job_info(job_id=job_id)
    logger.info(info)

    # NOTE(review): this loop has no exit if the job never reaches
    # "Processing" (e.g. it fails first) - consider a deadline.
    while info["status"] != "Processing":
        # Pause before asking again so requests are spaced 0.5 s apart.
        time.sleep(0.5)
        info = client.get_job_info(job_id=job_id)
        logger.info(info)

    # wait until the server is ready
    wait_for_server(retries=10, timeout=1.0)


def text_to_image(text):
    """Return a PIL image generated by the endpoint for the prompt *text*."""
    # Automatic server start-up is currently disabled:
    # try:
    #     wait_for_server(retries=1, timeout=1.0, delay=0.0)
    # except ServerNotReadyException:
    #     start_server()

    client = EndpointClient()
    return client.infer(text)


if __name__ == "__main__":
    image = text_to_image(text="Robot dinosaur\n")
    image.save("result.png")