Spaces:
Sleeping
Sleeping
| import base64 | |
| import json | |
| import logging | |
| import os | |
| import time | |
| from io import BytesIO | |
| from typing import List, BinaryIO | |
| from PIL import Image | |
| from octoai.client import Client as OctoAiClient | |
| logger = logging.getLogger() | |
| import requests | |
| import tempfile | |
# Remote (production) deployment endpoints and scheduling target.
SERVER_ADDRESS = "https://faircompute.com:8000/api/v0"
ENDPOINT_ADDRESS = "http://dikobraz.mooo.com:5000"
TARGET_NODE = "119eccba-2388-43c1-bdb9-02133049604c"
# Local debug deployment (matches the debug-usr/debug-pwd defaults in start_server()).
# SERVER_ADDRESS = "http://localhost:8000/api/v0"
# ENDPOINT_ADDRESS = "http://localhost:5000"
# TARGET_NODE = None
# Container image launched by start_server().
DOCKER_IMAGE = "faircompute/diffusion-octo:v1"
class FairApiClient:
    """Thin REST client for the faircompute job API.

    Wraps ``requests`` with bearer-token authentication.  ``authenticate``
    must be called first so ``self.token`` is populated; every other
    method sends that token in an ``Authorization`` header.
    """

    def __init__(self, server_address: str):
        self.server_address = server_address
        # Bearer token; set by authenticate(), None until then.
        self.token = None

    def authenticate(self, email: str, password: str):
        """Log in and store the bearer token for subsequent requests.

        Raises ``requests.HTTPError`` on a non-2xx response (the original
        code read ``resp.json()["token"]`` unconditionally, turning bad
        credentials into an opaque KeyError/JSONDecodeError).
        """
        url = f'{self.server_address}/auth/login'
        json_obj = {"email": email, "password": password, "version": "V018"}
        resp = requests.post(url, json=json_obj)
        resp.raise_for_status()
        self.token = resp.json()["token"]

    def get(self, url, **kwargs):
        """Authenticated GET; raises on any non-2xx status."""
        headers = {
            'Authorization': f'Bearer {self.token}'
        }
        response = requests.get(url, headers=headers, **kwargs)
        if not response.ok:
            raise Exception(f"Error! status: {response.status_code}")
        return response

    def put(self, url, data):
        """Authenticated PUT of a JSON payload.

        ``data`` may be a pre-serialized string or any JSON-serializable
        object.  206 (Partial Content) is tolerated because the streaming
        endpoints reply with it; every other non-2xx status raises.
        """
        headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {self.token}'
        }
        if not isinstance(data, str):
            data = json.dumps(data)
        response = requests.put(url, headers=headers, data=data)
        if not response.ok and response.status_code != 206:
            raise Exception(f"Error! status: {response.status_code}")
        return response

    def put_job(self, image: str, command: List[str], ports: List[tuple[int, int]], input_files, output_files):
        """Submit a container job and return ``(job_id, pid)``.

        ``ports`` is a list of ``(host_port, container_port)`` pairs.
        """
        url = f"{self.server_address}/jobs"
        data = {
            'version': 'V018',
            'container_desc': {
                'version': 'V018',
                'image': image,
                'runtime': 'nvidia',
                # NOTE(review): the string 'null' (not JSON null) is sent as
                # the host IP — presumably what the server expects; confirm.
                'ports': [[{"port": host_port, "ip": 'null'}, {"port": container_port, "protocol": "Tcp"}] for (host_port, container_port) in ports],
                'command': command,
            },
            'input_files': input_files,
            'output_files': output_files,
            'target_node': TARGET_NODE,
        }
        response = self.put(url=url, data=data).json()
        return response['id'], response['pid']

    def get_job_info(self, job_id):
        """Return the job's status record (``/jobs/{id}/stat``)."""
        url = f"{self.server_address}/jobs/{job_id}/stat"
        response = self.get(url=url).json()
        return response

    def get_cluster_summary(self):
        """Return the cluster-wide node summary."""
        url = f"{self.server_address}/nodes/summary"
        response = self.get(url=url)
        return response.json()

    def put_job_stream_data(self, job_id, name, data):
        """Append ``data`` to the named input stream of a job."""
        url = f"{self.server_address}/jobs/{job_id}/data/streams/{name}"
        response = self.put(url=url, data=data)
        return response.text

    def put_job_stream_eof(self, job_id, name):
        """Signal end-of-stream for the named input stream of a job."""
        url = f"{self.server_address}/jobs/{job_id}/data/streams/{name}/eof"
        response = self.put(url=url, data=None)
        return response.text

    def wait_for_file(self, job_id, path, attempts=10) -> BinaryIO:
        """Poll for a job output file and return it as a seekable temp file.

        Retries every 0.5 s up to ``attempts`` times.  Raises
        ``FileNotFoundError`` when the file never becomes available (the
        original silently returned None, contradicting the ``BinaryIO``
        annotation and deferring the crash to the caller).
        """
        headers = {
            'Authorization': f'Bearer {self.token}'
        }
        url = f"{self.server_address}/jobs/{job_id}/data/files/{path}"
        for _ in range(attempts):
            print(f"Waiting for file {path}...")
            try:
                with requests.get(url=url, headers=headers, stream=True) as r:
                    r.raise_for_status()
                    f = tempfile.TemporaryFile()
                    try:
                        for chunk in r.iter_content(chunk_size=8192):
                            f.write(chunk)
                    except Exception:
                        # Don't leak the temp file on a partial download.
                        f.close()
                        raise
                    print(f"File {path} ready")
                    f.seek(0, 0)
                    return f
            except Exception as e:
                print(e)
                time.sleep(0.5)
        print(f"Failed to receive {path}")
        raise FileNotFoundError(f"Failed to receive {path}")
class EndpointClient:
    """Client for the deployed diffusion inference endpoint."""

    def infer(self, prompt):
        """Send ``prompt`` to the endpoint and return the decoded PIL image."""
        octo = OctoAiClient()
        payload = {"prompt": {"text": prompt}}
        result = octo.infer(endpoint_url=f"{ENDPOINT_ADDRESS}/infer", inputs=payload)
        # The endpoint returns the image as a base64-encoded blob.
        raw_bytes = base64.b64decode(result["output"]["image_b64"])
        return Image.open(BytesIO(raw_bytes))
class ServerNotReadyException(Exception):
    """Raised by wait_for_server() when the endpoint never answers within the retry budget."""
    pass
def wait_for_server(retries, timeout, delay=1.0):
    """Block until ENDPOINT_ADDRESS answers an HTTP GET with a 2xx status.

    Performs up to ``retries`` probes, each with the given request
    ``timeout``, sleeping ``delay`` seconds between failed probes.
    Raises ServerNotReadyException (chained to the last network error)
    when every probe fails.
    """
    last_attempt = retries - 1
    for attempt in range(retries):
        try:
            response = requests.get(ENDPOINT_ADDRESS, timeout=timeout)
            response.raise_for_status()
        except (requests.exceptions.ConnectionError,
                requests.exceptions.HTTPError,
                requests.exceptions.Timeout) as err:
            if attempt == last_attempt:
                raise ServerNotReadyException("Failed to start the server") from err
            logger.info("Server is not ready yet")
            time.sleep(delay)
        else:
            # Server responded successfully — done waiting.
            return
def start_server():
    """Launch the diffusion container job and block until it is serving.

    Authenticates against the faircompute API, submits the DOCKER_IMAGE
    job mapping host port 5000 to container port 8080, polls until the
    job reports "Processing", then waits for the HTTP endpoint itself.
    """
    # default credentials will work only for local server built in debug mode
    email = os.getenv('FAIRCOMPUTE_EMAIL', "debug-usr")
    password = os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd")
    api = FairApiClient(SERVER_ADDRESS)
    api.authenticate(email=email, password=password)
    job_id, job_pid = api.put_job(
        image=DOCKER_IMAGE,
        command=[],
        ports=[(5000, 8080)],
        input_files=[],
        output_files=[])
    logger.info(f"Job id: {job_id}, pid: {job_pid}")
    # Poll every 0.5 s until the scheduler reports the job as running.
    # NOTE(review): this spins forever if the job ends in a failure state
    # that is never "Processing" — consider also checking terminal statuses.
    info = api.get_job_info(job_id=job_id)
    logger.info(info)
    while info["status"] != "Processing":
        info = api.get_job_info(job_id=job_id)
        logger.info(info)
        time.sleep(0.5)
    # wait until the server is ready
    wait_for_server(retries=10, timeout=1.0)
def text_to_image(text):
    """Generate an image for ``text`` via the remote inference endpoint."""
    # Lazy server startup is currently disabled; the endpoint is assumed
    # to be running already.
    # try:
    #     wait_for_server(retries=1, timeout=1.0, delay=0.0)
    # except ServerNotReadyException:
    #     start_server()
    endpoint = EndpointClient()
    return endpoint.infer(text)
if __name__ == "__main__":
    # Generate a sample image and persist it next to the script.
    result = text_to_image(text="Robot dinosaur\n")
    result.save("result.png")