stable-diffusion / fair.py
trifonova's picture
Clean up commented out code
2e91066
raw
history blame
5.38 kB
import json
import logging
import os
import tempfile
import time
from typing import BinaryIO, List, Optional

import requests

logger = logging.getLogger()
# Base URL of the FairCompute REST API. Set the FAIRCOMPUTE_SERVER_ADDRESS
# environment variable to override it, e.g. "http://localhost:8000/api/v1"
# when developing against a local server.
# NOTE: the misspelled name is kept because other modules may import it.
SERVER_ADRESS = os.environ.get(
    "FAIRCOMPUTE_SERVER_ADDRESS", "https://faircompute.com:8000/api/v1")
# Docker image registered as the program that generation jobs run
# (PyTorch 1.13.1 / CUDA 11.6 according to its tag).
DOCKER_IMAGE = "faircompute/stable-diffusion:pytorch-1.13.1-cu116"
class FairApiClient:
    """Small synchronous client for the FairCompute REST API.

    Call :meth:`authenticate` first; the bearer token it stores is attached
    to every later request.
    """

    def __init__(self, server_address: str):
        # Base API URL; endpoint paths are appended with a leading "/".
        self.server_address = server_address
        # Bearer token, None until authenticate() succeeds.
        self.token = None

    def _auth_headers(self) -> dict:
        """Return the Authorization header carrying the current token."""
        return {'Authorization': f'Bearer {self.token}'}

    def authenticate(self, email: str, password: str):
        """Log in with *email*/*password* and store the returned token.

        Raises:
            Exception: if the server answers with a status >= 400.
            KeyError: if the response JSON has no "token" field.
        """
        url = f'{self.server_address}/auth/login'
        json_obj = {"email": email, "password": password}
        resp = requests.post(url, json=json_obj)
        # Report the HTTP status (same style as get/put) rather than failing
        # later with an opaque KeyError/JSONDecodeError on the error body.
        if not resp.ok:
            raise Exception(f"Error! status: {resp.status_code}")
        self.token = resp.json()["token"]

    def get(self, url, **kwargs):
        """Authenticated GET; extra kwargs are forwarded to requests.get.

        Raises:
            Exception: if the response status is >= 400.
        """
        response = requests.get(url, headers=self._auth_headers(), **kwargs)
        if not response.ok:
            raise Exception(f"Error! status: {response.status_code}")
        return response

    def put(self, url, data):
        """Authenticated PUT with a JSON content type.

        Non-string *data* is JSON-encoded (``None`` becomes ``"null"``);
        strings are sent as-is.

        Raises:
            Exception: if the response status is >= 400.
        """
        headers = {'Content-Type': 'application/json', **self._auth_headers()}
        if not isinstance(data, str):
            data = json.dumps(data)
        response = requests.put(url, headers=headers, data=data)
        # response.ok is True for every status < 400, including 206 Partial
        # Content, so no separate 206 check is needed.
        if not response.ok:
            raise Exception(f"Error! status: {response.status_code}")
        return response

    def put_program(self, launcher: str, image: str, runtime: str, command: List[str]):
        """Register a program and return its id (the integer response body)."""
        url = f"{self.server_address}/programs"
        data = {
            launcher: {
                "image": image,
                "command": command,
                "runtime": runtime
            }
        }
        response = self.put(url=url, data=data)
        return int(response.text)

    def put_job(self, program_id, input_files, output_files):
        """Create a job for *program_id* and return its id (integer body)."""
        url = f"{self.server_address}/jobs?program={program_id}"
        data = {
            'input_files': input_files,
            'output_files': output_files
        }
        response = self.put(url=url, data=data)
        return int(response.text)

    def get_job_status(self, job_id):
        """Return the job's status as the raw response text."""
        url = f"{self.server_address}/jobs/{job_id}/status"
        response = self.get(url=url)
        return response.text

    def get_cluster_summary(self):
        """Return the decoded JSON from the /nodes/summary endpoint."""
        url = f"{self.server_address}/nodes/summary"
        response = self.get(url=url)
        return response.json()

    def put_job_stream_data(self, job_id, name, data):
        """Write *data* to the job's stream *name*; return the response text."""
        url = f"{self.server_address}/jobs/{job_id}/data/streams/{name}"
        response = self.put(url=url, data=data)
        return response.text

    def put_job_stream_eof(self, job_id, name):
        """Signal end-of-file on the job's stream *name*; return response text."""
        url = f"{self.server_address}/jobs/{job_id}/data/streams/{name}/eof"
        response = self.put(url=url, data=None)
        return response.text

    def wait_for_file(self, job_id, path, attempts=10) -> Optional[BinaryIO]:
        """Download a job's file into an anonymous temporary file.

        Args:
            job_id: Job that owns the file.
            path: File path, already URL-encoded by the caller (for example
                "%2Fworkspace%2Fresult.png"); it is inserted into the URL verbatim.
            attempts: Maximum number of download attempts. There is a 0.5 s
                pause after each failed attempt.

        Returns:
            A temporary file rewound to offset 0 holding the contents, or None
            if every attempt failed. The caller should close the file.
        """
        url = f"{self.server_address}/jobs/{job_id}/data/files/{path}"
        headers = self._auth_headers()
        for _ in range(attempts):
            print(f"Waiting for file {path}...")
            f = None
            try:
                with requests.get(url=url, headers=headers, stream=True) as r:
                    r.raise_for_status()
                    f = tempfile.TemporaryFile()
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            except Exception as e:
                # Deliberate best-effort retry (the file may not exist yet);
                # drop any partially written download instead of leaking it.
                if f is not None:
                    f.close()
                print(e)
                time.sleep(0.5)
                continue
            print(f"File {path} ready")
            f.seek(0, 0)
            return f
        print(f"Failed to receive {path}")
        return None
def _poll_job_status(client, job_id, keep_waiting, deadline, poll_interval):
    """Poll the job's status until ``keep_waiting(status)`` is false; return it.

    Raises TimeoutError once ``time.monotonic()`` passes *deadline* (None
    means no limit).
    """
    status = client.get_job_status(job_id=job_id)
    logger.info(status)
    while keep_waiting(status):
        if deadline is not None and time.monotonic() > deadline:
            raise TimeoutError(
                f"Job {job_id} still in status {status!r} after timeout")
        status = client.get_job_status(job_id=job_id)
        logger.info(status)
        time.sleep(poll_interval)
    return status


def text_to_image(text, *, timeout=None, poll_interval=0.5):
    """Run the Stable Diffusion image on FairCompute for the prompt *text*.

    The function:

    1. Registers DOCKER_IMAGE as a program.
    2. Creates a job that produces ``/workspace/result.png``.
    3. Waits for the job to start.
    4. Writes ``text`` plus a newline to the job's stdin, then sends EOF.
    5. Waits for the job to finish and downloads the result.

    Credentials come from the FAIRCOMPUTE_EMAIL and FAIRCOMPUTE_PASSWORD
    environment variables.

    Args:
        text: Prompt for the container. A newline is appended here, so
            callers should not add one.
        timeout: Optional overall limit, in seconds, on waiting for job
            status changes. None (default) waits indefinitely.
        poll_interval: Seconds to sleep between status polls.

    Returns:
        A temporary binary file rewound to the start with the PNG contents,
        or None if the download failed.

    Raises:
        RuntimeError: if either credential environment variable is missing.
        TimeoutError: if *timeout* is given and is exceeded while polling.
    """
    from urllib.parse import quote

    email = os.environ.get('FAIRCOMPUTE_EMAIL')
    password = os.environ.get('FAIRCOMPUTE_PASSWORD')
    if not email or not password:
        raise RuntimeError(
            "FAIRCOMPUTE_EMAIL and FAIRCOMPUTE_PASSWORD must be set")

    deadline = None if timeout is None else time.monotonic() + timeout
    output_path = "/workspace/result.png"

    client = FairApiClient(SERVER_ADRESS)
    client.authenticate(email=email, password=password)
    program_id = client.put_program(
        launcher="Docker",
        image=DOCKER_IMAGE,
        runtime="nvidia",
        command=[])
    logger.info(program_id)
    job_id = client.put_job(
        program_id=program_id,
        input_files=[],
        output_files=[output_path])
    logger.info(job_id)

    # Wait until the container is running (or already done) before streaming
    # the prompt. Any other status keeps this loop going, so a job stuck
    # before starting is only interrupted by `timeout`.
    _poll_job_status(
        client, job_id,
        lambda s: s not in ("Processing", "Completed"),
        deadline, poll_interval)

    res = client.put_job_stream_data(
        job_id=job_id,
        name="stdin",
        data=text + "\n")
    logger.info(res)
    res = client.put_job_stream_eof(
        job_id=job_id,
        name="stdin")
    logger.info(res)

    status = _poll_job_status(
        client, job_id, lambda s: s == "Processing", deadline, poll_interval)
    if status == "Completed":
        logger.info("Done!")
    else:
        logger.error("Job Failed")

    # The file endpoint expects the path percent-encoded with "/" escaped,
    # i.e. "%2Fworkspace%2Fresult.png".
    return client.wait_for_file(
        job_id=job_id,
        path=quote(output_path, safe=""))
if __name__ == "__main__":
    from PIL import Image

    # Show this module's logger.info progress messages when run as a script.
    logging.basicConfig(level=logging.INFO)
    # text_to_image() appends the newline itself; including one here would
    # send an extra blank line to the container's stdin.
    file = text_to_image(text="Robot dinozaur")
    if file is None:
        raise SystemExit("Failed to download the generated image")
    # Close both the temporary download and the decoded image when done.
    with file, Image.open(file) as image:
        image.save("result.png")