# Source: Hugging Face Space app.py ("Update app.py", commit a86abeb verified)
import spaces
import os
import subprocess
#subprocess.run(['sh', './jax.sh'])
# Disabled bootstrap block (kept for reference): installed Miniconda and a
# specific cuDNN build at container startup.
'''
subprocess.run(['sh', './conda.sh'])
import sys
conda_prefix = os.path.expanduser("~/miniconda3")
conda_bin = os.path.join(conda_prefix, "bin")
# Add Conda's bin directory to your PATH
os.environ["PATH"] = conda_bin + os.pathsep + os.environ["PATH"]
# Activate the base environment (adjust if needed)
os.system(f'{conda_bin}/conda init --all')
os.system(f'{conda_bin}/conda activate base')
os.system(f'{conda_bin}/conda install nvidia/label/cudnn-9.3.0::cudnn')
'''
import gradio as gr
import numpy as np
import paramiko
from image_gen_aux import UpscaleWithModel
import cyper
from PIL import Image
# Keep JAX on CPU while keras/keras_hub are imported below; the infer_*
# handlers flip this to 'gpu' once inside a @spaces.GPU allocation.
os.environ['JAX_PLATFORMS'] = 'cpu'
import random
# NOTE(review): presumably tells XLA it may claim the whole GPU memory pool —
# confirm against JAX's XLA_PYTHON_CLIENT_MEM_FRACTION docs.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
import keras
import keras_hub
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Turn off every reduced-precision matmul/cudnn shortcut so the torch-side
# upscaler runs at full fp32 precision.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
#torch.backends.cuda.preferred_blas_library="cublas"
#torch.backends.cuda.preferred_linalg_library="cusolver"
torch.set_float32_matmul_precision("highest")
# Models are loaded lazily (inside a GPU allocation) by load_model().
upscaler_2 = None # UpscaleWithModel.from_pretrained("Kim2091/ClearRealityV1").to(device)
text_to_image = None
@spaces.GPU(duration=60)
def load_model():
    """Populate the module-level model globals inside a 60 s GPU allocation.

    Loads the ClearReality upscaler onto ``device`` and the SD3-medium
    text-to-image preset (bfloat16, 768x768), then returns the latter so
    callers may also use the return value directly.
    """
    global upscaler_2, text_to_image
    upscale_model = UpscaleWithModel.from_pretrained("Kim2091/ClearRealityV1")
    upscaler_2 = upscale_model.to(device)
    text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
        "stable_diffusion_3_medium", width=768, height=768, dtype="bfloat16"
    )
    return text_to_image
# SFTP uploader, compiled at import time with cyper. Credentials come from the
# FTP_PASS env var; upload failures are logged, never raised, so a broken
# upload cannot kill an inference run.
code = r'''
import paramiko
import os

FTP_HOST = '1ink.us'
FTP_USER = 'ford442'
FTP_PASS = os.getenv("FTP_PASS")
FTP_DIR = '1ink.us/stable_diff/'

def upload_to_ftp(filename):
    """Upload *filename* to FTP_DIR on the SFTP host (port 22), best-effort."""
    try:
        transport = paramiko.Transport((FTP_HOST, 22))
        destination_path = FTP_DIR + filename
        transport.connect(username=FTP_USER, password=FTP_PASS)
        sftp = paramiko.SFTPClient.from_transport(transport)
        sftp.put(filename, destination_path)
        sftp.close()
        transport.close()
        # fix: message previously printed the literal "(unknown)" instead of
        # interpolating the uploaded file's name.
        print(f"Uploaded {filename} to FTP server")
    except Exception as e:
        print(f"FTP upload error: {e}")
'''
pyx = cyper.inline(code, fast_indexing=True, directives=dict(boundscheck=False, wraparound=False, language_level=3))

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 4096
@spaces.GPU(duration=40)
def infer_30(
    prompt,
    negative_prompt,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one SD3 image (40 s GPU budget), upload it and an upscaled
    copy via SFTP, and return (image, prompt) for the Gradio outputs.

    Lazily loads the models on first call; the seed is randomized per call.
    """
    # fix: `datetime` was used below but never imported anywhere in the file,
    # so every call raised NameError at the timestamp line.
    import datetime
    global text_to_image
    if text_to_image is None:
        text_to_image = load_model()  # also populates upscaler_2
    # Switch JAX/Keras onto the GPU now that we are inside the allocation.
    os.environ['JAX_PLATFORMS'] = 'gpu'
    os.environ['KERAS_BACKEND'] = 'jax'
    seed = random.randint(0, MAX_SEED)
    # NOTE(review): negative_prompt is accepted but never passed to generate();
    # confirm whether keras_hub's generate() supports negative prompts.
    sd_image = text_to_image.generate(
        prompt,
        num_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        seed=seed
    )
    print('-- got image --')
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    sd35_path = f"sd3keras_{timestamp}.png"
    sd_image.save(sd35_path, optimize=False, compress_level=0)
    pyx.upload_to_ftp(sd35_path)
    with torch.no_grad():
        upscale2 = upscaler_2(sd_image, tiling=True, tile_width=256, tile_height=256)
    print('-- got upscaled image --')
    # The upscaler output is shrunk 4x with Lanczos before upload.
    downscale2 = upscale2.resize((upscale2.width // 4, upscale2.height // 4), Image.LANCZOS)
    upscale_path = f"sd3keras_upscale_{timestamp}.png"
    downscale2.save(upscale_path, optimize=False, compress_level=0)
    pyx.upload_to_ftp(upscale_path)
    return sd_image, prompt
@spaces.GPU(duration=70)
def infer_60(
    prompt,
    negative_prompt,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one SD3 image (70 s GPU budget), upload it and an upscaled
    copy via SFTP, and return (image, prompt) for the Gradio outputs.

    Lazily loads the models on first call; the seed is randomized per call.
    """
    # fix: `datetime` was used below but never imported anywhere in the file,
    # so every call raised NameError at the timestamp line.
    import datetime
    global text_to_image
    if text_to_image is None:
        text_to_image = load_model()  # also populates upscaler_2
    # Switch JAX/Keras onto the GPU now that we are inside the allocation.
    os.environ['JAX_PLATFORMS'] = 'gpu'
    os.environ['KERAS_BACKEND'] = 'jax'
    seed = random.randint(0, MAX_SEED)
    # NOTE(review): negative_prompt is accepted but never passed to generate();
    # confirm whether keras_hub's generate() supports negative prompts.
    sd_image = text_to_image.generate(
        prompt,
        num_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        seed=seed
    )
    print('-- got image --')
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    sd35_path = f"sd3keras_{timestamp}.png"
    sd_image.save(sd35_path, optimize=False, compress_level=0)
    pyx.upload_to_ftp(sd35_path)
    with torch.no_grad():
        upscale2 = upscaler_2(sd_image, tiling=True, tile_width=256, tile_height=256)
    print('-- got upscaled image --')
    # The upscaler output is shrunk 4x with Lanczos before upload.
    downscale2 = upscale2.resize((upscale2.width // 4, upscale2.height // 4), Image.LANCZOS)
    upscale_path = f"sd3keras_upscale_{timestamp}.png"
    downscale2.save(upscale_path, optimize=False, compress_level=0)
    pyx.upload_to_ftp(upscale_path)
    return sd_image, prompt
@spaces.GPU(duration=100)
def infer_90(
    prompt,
    negative_prompt,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one SD3 image (100 s GPU budget), upload it and an upscaled
    copy via SFTP, and return (image, prompt) for the Gradio outputs.

    Lazily loads the models on first call; the seed is randomized per call.
    """
    # fix: `datetime` was used below but never imported anywhere in the file,
    # so every call raised NameError at the timestamp line.
    import datetime
    global text_to_image
    if text_to_image is None:
        text_to_image = load_model()  # also populates upscaler_2
    # Switch JAX/Keras onto the GPU now that we are inside the allocation.
    os.environ['JAX_PLATFORMS'] = 'gpu'
    os.environ['KERAS_BACKEND'] = 'jax'
    seed = random.randint(0, MAX_SEED)
    # NOTE(review): negative_prompt is accepted but never passed to generate();
    # confirm whether keras_hub's generate() supports negative prompts.
    sd_image = text_to_image.generate(
        prompt,
        num_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        seed=seed
    )
    print('-- got image --')
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    sd35_path = f"sd3keras_{timestamp}.png"
    sd_image.save(sd35_path, optimize=False, compress_level=0)
    pyx.upload_to_ftp(sd35_path)
    with torch.no_grad():
        upscale2 = upscaler_2(sd_image, tiling=True, tile_width=256, tile_height=256)
    print('-- got upscaled image --')
    # The upscaler output is shrunk 4x with Lanczos before upload.
    downscale2 = upscale2.resize((upscale2.width // 4, upscale2.height // 4), Image.LANCZOS)
    upscale_path = f"sd3keras_upscale_{timestamp}.png"
    downscale2.save(upscale_path, optimize=False, compress_level=0)
    pyx.upload_to_ftp(upscale_path)
    return sd_image, prompt
css = """
#col-container {margin: 0 auto;max-width: 640px;}
body{background-color: blue;}
"""

# Gradio UI: one prompt box, an explicit model-load button, and three run
# buttons that differ only in their @spaces.GPU duration budget.
with gr.Blocks(theme=gr.themes.Origin(), css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # StableDiffusion 3 Medium from Keras-hub")
        expanded_prompt_output = gr.Textbox(label="Prompt", lines=1)
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            load_button = gr.Button("Load model", scale=0, variant="primary")
            run_button_30 = gr.Button("Run 30", scale=0, variant="primary")
            run_button_60 = gr.Button("Run 60", scale=0, variant="primary")
            run_button_90 = gr.Button("Run 90", scale=0, variant="primary")
        result = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings", open=True):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=True,
                value="bad anatomy, poorly drawn hands, distorted face, blurry, out of frame, low resolution, grainy, pixelated, disfigured, mutated, extra limbs, bad composition"
            )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=30.0,
                    step=0.1,
                    value=4.2,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=500,
                    step=1,
                    value=50,
                )
    gr.on(
        triggers=[load_button.click],
        fn=load_model,
        inputs=[],
        outputs=[],
    )
    # fix: prompt.submit was previously a trigger on ALL THREE infer handlers,
    # so one Enter keypress launched three GPU jobs. Keep the submit shortcut
    # only on the cheapest (30 s) runner; the other buttons remain click-only.
    gr.on(
        triggers=[run_button_30.click, prompt.submit],
        fn=infer_30,
        inputs=[
            prompt,
            negative_prompt,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, expanded_prompt_output],
    )
    gr.on(
        triggers=[run_button_60.click],
        fn=infer_60,
        inputs=[
            prompt,
            negative_prompt,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, expanded_prompt_output],
    )
    gr.on(
        triggers=[run_button_90.click],
        fn=infer_90,
        inputs=[
            prompt,
            negative_prompt,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, expanded_prompt_output],
    )

if __name__ == "__main__":
    demo.launch()