Ghibliy / app.py
arrafaqat's picture
Update app.py
a457b7a verified
import spaces
import os
import json
import time
import torch
from PIL import Image
from tqdm import tqdm
import gradio as gr
from safetensors.torch import save_file
from src.pipeline import FluxPipeline
from src.transformer_flux import FluxTransformer2DModel
from src.lora_helper import set_single_lora, set_multi_lora, unset_lora
# Model locations: the FLUX.1-dev base checkpoint id and the local directory
# that holds the control LoRA weights.
base_path = "black-forest-labs/FLUX.1-dev"
lora_base_path = "./models"

# Optional shared secret read from the HF_TOKEN environment variable (set it in
# the Hugging Face Space settings). When present, generation requests must
# supply the same value.
API_TOKEN = os.getenv("HF_TOKEN")

# Load the pipeline first, then replace its transformer with the project's
# FluxTransformer2DModel loaded from the "transformer" subfolder, and finally
# move the whole pipeline to the GPU.
pipe = FluxPipeline.from_pretrained(
    base_path,
    torch_dtype=torch.bfloat16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    base_path,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
pipe.transformer = transformer
pipe.to("cuda")
def clear_cache(transformer):
    """Empty the ``bank_kv`` cache of every attention processor.

    Args:
        transformer: Object exposing an ``attn_processors`` mapping whose
            values each carry a ``bank_kv`` container with a ``clear()``
            method.
    """
    # Only the processors themselves are needed; the mapping keys are unused.
    for attn_processor in transformer.attn_processors.values():
        attn_processor.bank_kv.clear()
# Generation entry point used by the Gradio UI, with optional token verification
@spaces.GPU()
def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type, api_token=""):
    """Generate one image from a prompt and an optional conditioning image.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        spatial_img: Conditioning image, or ``None`` to run without one.
        height: Requested output height; rounded down to a multiple of 8.
        width: Requested output width; rounded down to a multiple of 8.
        seed: Seed for the CPU random generator; coerced to ``int``.
        control_type: Control LoRA name; only ``"Ghibli"`` loads a LoRA.
        api_token: Token supplied by the caller. It is only checked when the
            module-level ``API_TOKEN`` is set.

    Returns:
        The first image produced by the pipeline on success; otherwise a
        string starting with ``"ERROR: "`` that describes the failure.
    """
    # NOTE(review): the UI binds this function's result to an image output, yet
    # failures are reported as plain strings - confirm how that is rendered.
    import hmac  # function-scope import: only needed for the token check

    # Check if API token is required and valid. The comparison is done in
    # constant time on bytes so that response timing does not reveal how much
    # of a guessed token matched; encoding also avoids compare_digest's
    # ASCII-only restriction on str arguments.
    if API_TOKEN:
        supplied_token = (api_token or "").encode("utf-8")
        if not hmac.compare_digest(supplied_token, API_TOKEN.encode("utf-8")):
            return "ERROR: Invalid API token. Please provide a valid token to generate images."
    try:
        # Ensure height and width are divisible by 8
        height = int(height)
        width = int(width)
        if height % 8 != 0 or width % 8 != 0:
            # Round down to the previous multiple of 8
            height = (height // 8) * 8
            width = (width // 8) * 8
            print(f"Dimensions adjusted to be divisible by 8: {height}x{width}")
        # The seed may arrive as a float from the UI; manual_seed needs an int.
        seed = int(seed)
        # Set the control type
        if control_type == "Ghibli":
            lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
            set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
        # Process the image
        spatial_imgs = [spatial_img] if spatial_img is not None else []
        try:
            image = pipe(
                prompt,
                height=height,
                width=width,
                guidance_scale=3.5,
                num_inference_steps=25,
                max_sequence_length=512,
                generator=torch.Generator("cpu").manual_seed(seed),
                subject_images=[],
                spatial_images=spatial_imgs,
                cond_size=512,
            ).images[0]
        finally:
            # Always drop cached entries, even when generation fails, so a
            # failed request cannot leave stale state for the next one.
            clear_cache(pipe.transformer)
        return image
    except Exception as e:
        # Top-level boundary for the UI callback: log and report the failure.
        error_message = f"Error during generation: {str(e)}"
        print(error_message)
        return f"ERROR: {error_message}"
# Define the Gradio interface components
control_types = ["Ghibli"]

# Create the Gradio Blocks interface
with gr.Blocks() as demo:
    gr.Markdown("# Ghibli Studio Control Image Generation with EasyControl")
    # Only show token field if API token is required
    if API_TOKEN:
        gr.Markdown("⚠️ **AUTHENTICATION REQUIRED**: Please enter your API token to use this service.")
        api_token = gr.Textbox(label="API Token", type="password", value="")
    else:
        api_token = gr.Textbox(visible=False, value="")  # Hidden field with empty value
    gr.Markdown("The model is trained on **only 100 real Asian faces** paired with **GPT-4o-generated Ghibli-style counterparts**, and it preserves facial features while applying the iconic anime aesthetic.")
    gr.Markdown("Generate images using EasyControl with Ghibli control LoRAs.(Due to hardware constraints, only low-resolution images can be generated. For high-resolution (1024+), please set up your own environment.)")
    gr.Markdown("**[Attention!!]**:The recommended prompts for using Ghibli Control LoRA should include the trigger words: `Ghibli Studio style, Charming hand-drawn anime-style illustration`")
    gr.Markdown("😊😊If you like this demo, please give us a star (github: [EasyControl](https://github.com/Xiaojiu-z/EasyControl))")
    gr.Markdown("**NOTE**: Both height and width must be divisible by 8. Values will be automatically adjusted if needed.")

    with gr.Tab("Ghibli Condition Generation"):
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="Ghibli Studio style, Charming hand-drawn anime-style illustration")
                spatial_img = gr.Image(label="Ghibli Image", type="pil")
                height = gr.Slider(minimum=256, maximum=1024, step=8, label="Height", value=768)
                width = gr.Slider(minimum=256, maximum=1024, step=8, label="Width", value=768)
                # precision=0 makes the component deliver an integer seed.
                seed = gr.Number(label="Seed", value=42, precision=0)
                control_type = gr.Dropdown(choices=control_types, label="Control Type", value="Ghibli")
                single_generate_btn = gr.Button("Generate Image")
            with gr.Column():
                single_output_image = gr.Image(label="Generated Image")

        # Set up examples. Example rows are rendered in the public UI, so they
        # must never carry the secret token: they only fill the six generation
        # inputs and the visitor still has to type the token when one is
        # required.
        example_prompt = "Ghibli Studio style, Charming hand-drawn anime-style illustration"
        # (image file, height, width, seed) for each example row
        example_specs = [
            ("00.png", 680, 1024, 5),
            ("02.png", 560, 1024, 42),
            ("03.png", 568, 1024, 1),
            ("04.png", 768, 672, 1),
            ("06.png", 896, 1024, 1),
            ("07.png", 528, 800, 1),
            ("08.png", 696, 1024, 1),
            ("09.png", 896, 1024, 1),
        ]
        example_inputs = [prompt, spatial_img, height, width, seed, control_type]
        example_data = [
            [example_prompt, Image.open(f"./test_imgs/{file_name}"), ex_height, ex_width, ex_seed, "Ghibli"]
            for file_name, ex_height, ex_width, ex_seed in example_specs
        ]
        gr.Examples(
            examples=example_data,
            inputs=example_inputs,
            outputs=single_output_image,
            fn=single_condition_generate_image,
            cache_examples=False,
            label="Single Condition Examples",
        )

    # Link the button to the generation function; the token field is only
    # passed along when a token is actually required.
    inputs = [prompt, spatial_img, height, width, seed, control_type]
    if API_TOKEN:
        inputs.append(api_token)
    single_generate_btn.click(
        single_condition_generate_image,
        inputs=inputs,
        outputs=single_output_image,
    )

# Launch the Gradio app
demo.queue().launch()