# Spaces: Running on Zero
# Hugging Face ZeroGPU helper, kept as the very first import (ahead of torch).
# NOTE(review): nothing in this module is decorated with @spaces.GPU --
# confirm how GPU time is requested on ZeroGPU.
import spaces
import datetime
import os
import random
import subprocess

import gradio as gr
import numpy as np
import paramiko
from image_gen_aux import UpscaleWithModel
import cyper
from PIL import Image

# JAX/XLA environment, assigned before `import keras` below so it is already
# in place when Keras and its backend are imported -- keep this ordering.
# NOTE(review): KERAS_BACKEND is not set before the import, so Keras uses its
# configured default backend; and if that backend is JAX, JAX_PLATFORMS='cpu'
# keeps it on the CPU. Confirm both are intended.
os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"
import keras
import keras_hub
import torch

# Device the upscaler is moved to when it is loaded.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Request full-precision matmuls: no TF32 and no reduced-precision
# bf16/fp16 reductions; cuDNN autotuning and determinism both left off.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
torch.set_float32_matmul_precision("highest")

# Lazily-loaded models, populated by load_model(); None means "not loaded".
upscaler_2 = None
text_to_image = None
def load_model():
    """Load the upscaler and the SD3 text-to-image pipeline into module globals.

    Each model is instantiated only while its global is still ``None``, so
    repeated "Load model" clicks, or a click followed by an ``infer_*`` call,
    reuse the already-loaded objects instead of rebuilding them.

    Returns:
        The ``keras_hub.models.StableDiffusion3TextToImage`` instance stored
        in the ``text_to_image`` global.
    """
    global upscaler_2, text_to_image
    if upscaler_2 is None:
        upscaler_2 = UpscaleWithModel.from_pretrained("Kim2091/ClearRealityV1").to(device)
    if text_to_image is None:
        text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
            "stable_diffusion_3_medium", width=768, height=768, dtype="bfloat16"
        )
    return text_to_image
# Source of the SFTP upload helper. It is compiled at import time with
# cyper.inline below and called elsewhere as `pyx.upload_to_ftp(path)`.
code = r'''
import paramiko
import os
FTP_HOST = '1ink.us'
FTP_USER = 'ford442'
FTP_PASS = os.getenv("FTP_PASS")
FTP_DIR = '1ink.us/stable_diff/'
def upload_to_ftp(filename):
    """Best-effort SFTP upload of local `filename` to FTP_DIR + filename.

    Failures are printed, not raised, so callers keep running.
    """
    transport = None
    sftp = None
    try:
        transport = paramiko.Transport((FTP_HOST, 22))
        destination_path = FTP_DIR + filename
        transport.connect(username=FTP_USER, password=FTP_PASS)
        sftp = paramiko.SFTPClient.from_transport(transport)
        sftp.put(filename, destination_path)
        print(f"Uploaded {filename} to FTP server")
    except Exception as e:
        print(f"FTP upload error: {e}")
    finally:
        # Release the session even when connect/put failed part-way.
        if sftp is not None:
            sftp.close()
        if transport is not None:
            transport.close()
'''
pyx = cyper.inline(code, fast_indexing=True, directives=dict(boundscheck=False, wraparound=False, language_level=3))
# Largest value a seed may take: the maximum signed 32-bit integer.
MAX_SEED = 2**31 - 1
# Image-size cap. NOTE(review): not referenced by the infer_* handlers --
# confirm it is still needed.
MAX_IMAGE_SIZE = 4096
def _to_pil(image):
    """Return *image* as a PIL image; PIL inputs are returned unchanged.

    NOTE(review): keras_hub's ``generate`` presumably returns a numpy uint8
    array of shape (H, W, 3) for a single prompt -- confirm for the installed
    version. ``Image.fromarray`` raises if the dtype/shape is unsupported.
    """
    if isinstance(image, Image.Image):
        return image
    return Image.fromarray(np.asarray(image))


def _generate_and_upload(prompt, negative_prompt, guidance_scale, num_inference_steps):
    """Shared body of infer_30/infer_60/infer_90.

    Loads the models on first use and generates one image with a random seed.
    It saves that image as ``sd3keras_<timestamp>.png`` and uploads it. It then
    runs the tiled upscaler, shrinks the result by 4x per side, saves it as
    ``sd3keras_upscale_<timestamp>.png`` and uploads that too.

    Args:
        prompt: Text prompt.
        negative_prompt: Text passed as ``negative_prompts``; ignored if empty.
        guidance_scale: Passed through to ``generate``.
        num_inference_steps: Passed to ``generate`` as ``num_steps`` (cast to
            int because it comes from a Gradio slider).

    Returns:
        ``(image, prompt)``: the generated (not upscaled) PIL image, and the
        prompt echoed back for the prompt textbox.
    """
    import datetime  # local import keeps this handler self-contained

    global text_to_image
    if text_to_image is None:
        text_to_image = load_model()
    # NOTE(review): keras is imported at module load, so these assignments
    # presumably do not change the backend/platform already in use.
    os.environ['JAX_PLATFORMS'] = 'gpu'
    os.environ['KERAS_BACKEND'] = 'jax'

    seed = random.randint(0, MAX_SEED)
    # generate() accepts a bare prompt or a {"prompts", "negative_prompts"}
    # dict; the dict form is used only when a negative prompt is supplied.
    inputs = prompt
    if negative_prompt:
        inputs = {"prompts": prompt, "negative_prompts": negative_prompt}
    raw_image = text_to_image.generate(
        inputs,
        num_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        seed=seed,
    )
    print('-- got image --')
    sd_image = _to_pil(raw_image)

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    sd35_path = f"sd3keras_{timestamp}.png"
    sd_image.save(sd35_path, optimize=False, compress_level=0)
    pyx.upload_to_ftp(sd35_path)

    with torch.no_grad():
        upscale2 = upscaler_2(sd_image, tiling=True, tile_width=256, tile_height=256)
    print('-- got upscaled image --')
    # Shrink the upscaler output by 4x per side using Lanczos resampling.
    downscale2 = upscale2.resize(
        (upscale2.width // 4, upscale2.height // 4), Image.LANCZOS
    )
    upscale_path = f"sd3keras_upscale_{timestamp}.png"
    downscale2.save(upscale_path, optimize=False, compress_level=0)
    pyx.upload_to_ftp(upscale_path)
    return sd_image, prompt


def infer_30(
    prompt,
    negative_prompt,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Gradio handler for "Run 30"; delegates to _generate_and_upload.

    ``progress`` is accepted so Gradio can track tqdm progress bars.
    NOTE(review): infer_30/60/90 behave identically -- the suffix presumably
    names a ZeroGPU ``@spaces.GPU(duration=...)`` budget that is not applied.
    """
    return _generate_and_upload(prompt, negative_prompt, guidance_scale, num_inference_steps)


def infer_60(
    prompt,
    negative_prompt,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Gradio handler for "Run 60"; same behavior as infer_30."""
    return _generate_and_upload(prompt, negative_prompt, guidance_scale, num_inference_steps)


def infer_90(
    prompt,
    negative_prompt,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Gradio handler for "Run 90"; same behavior as infer_30."""
    return _generate_and_upload(prompt, negative_prompt, guidance_scale, num_inference_steps)
css = """
#col-container {margin: 0 auto;max-width: 640px;}
body{background-color: blue;}
"""


def _load_model_from_ui():
    """Click handler for "Load model": load the models and return nothing.

    The event is registered with ``outputs=[]``, so the pipeline object that
    load_model() returns is deliberately discarded here.
    """
    load_model()


with gr.Blocks(theme=gr.themes.Origin(), css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # StableDiffusion 3 Medium from Keras-hub")
        # Receives the prompt echoed back by the infer_* handlers.
        expanded_prompt_output = gr.Textbox(label="Prompt", lines=1)
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            load_button = gr.Button("Load model", scale=0, variant="primary")
            run_button_30 = gr.Button("Run 30", scale=0, variant="primary")
            run_button_60 = gr.Button("Run 60", scale=0, variant="primary")
            run_button_90 = gr.Button("Run 90", scale=0, variant="primary")
        result = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings", open=True):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=True,
                value="bad anatomy, poorly drawn hands, distorted face, blurry, out of frame, low resolution, grainy, pixelated, disfigured, mutated, extra limbs, bad composition"
            )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=30.0,
                    step=0.1,
                    value=4.2,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=500,
                    step=1,
                    value=50,
                )
    gr.on(
        triggers=[load_button.click],
        fn=_load_model_from_ui,
        inputs=[],
        outputs=[],
    )
    # All three run buttons share the same inputs and outputs.
    infer_inputs = [prompt, negative_prompt, guidance_scale, num_inference_steps]
    infer_outputs = [result, expanded_prompt_output]
    # Enter in the prompt box is bound to infer_30 only, so one submit starts
    # exactly one generation (binding it to every handler would start three).
    gr.on(
        triggers=[run_button_30.click, prompt.submit],
        fn=infer_30,
        inputs=infer_inputs,
        outputs=infer_outputs,
    )
    gr.on(
        triggers=[run_button_60.click],
        fn=infer_60,
        inputs=infer_inputs,
        outputs=infer_outputs,
    )
    gr.on(
        triggers=[run_button_90.click],
        fn=infer_90,
        inputs=infer_inputs,
        outputs=infer_outputs,
    )

if __name__ == "__main__":
    demo.launch()