# Ghibliy / app.py — Hugging Face Space by arrafaqat
# Last commit: f7fb3ee "Update app.py" (verified), 7.94 kB
# `spaces` (Hugging Face ZeroGPU) stays the first import so it is loaded before
# torch, as the ZeroGPU documentation recommends.
import spaces

import base64
import json
import mimetypes
import os
import tempfile
import time

import gradio as gr
import torch
from flask import Flask, Response, jsonify, request, send_file
from PIL import Image
from safetensors.torch import save_file
from tqdm import tqdm

from src.lora_helper import set_multi_lora, set_single_lora, unset_lora
from src.pipeline import FluxPipeline
from src.transformer_flux import FluxTransformer2DModel
# Build the FLUX.1-dev pipeline once, at import time.
# Hugging Face Hub id of the base model, and the local folder holding the
# control-LoRA weight files (joined with "Ghibli.safetensors" at inference).
base_path = "black-forest-labs/FLUX.1-dev"
lora_base_path = "./models"

# Load the project's FluxTransformer2DModel from the "transformer" subfolder,
# then the pipeline itself; both use bfloat16 weights.
transformer = FluxTransformer2DModel.from_pretrained(
    base_path,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)

# Swap the project transformer into the pipeline, then move it all to the GPU.
pipe.transformer = transformer
pipe.to("cuda")
def clear_cache(transformer):
    """Empty the ``bank_kv`` container of every attention processor.

    Called after each pipeline run so that whatever the processors stored in
    ``bank_kv`` is not carried over into the next request.

    Args:
        transformer: Model exposing an ``attn_processors`` mapping of
            name -> processor object.

    Processors that have no ``bank_kv`` attribute, or whose ``bank_kv`` is
    None, are skipped rather than raising AttributeError. This can happen,
    presumably, for default processors when no control LoRA has been set
    (TODO confirm against src.lora_helper).
    """
    for processor in transformer.attn_processors.values():
        bank = getattr(processor, "bank_kv", None)
        if bank is not None:
            bank.clear()
# Inference handler used by the Gradio UI (button click and examples).
# `spaces.GPU` is the Hugging Face ZeroGPU decorator that attaches a GPU to each call.
@spaces.GPU()
def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type):
    """Generate one image with the FLUX pipeline, optionally image-conditioned.

    Args:
        prompt: Text prompt passed to the pipeline.
        spatial_img: PIL image used as the spatial condition, or None to
            generate without one.
        height: Output height in pixels (converted with ``int``).
        width: Output width in pixels (converted with ``int``).
        seed: Seed for a CPU ``torch.Generator``. ``gr.Number`` may deliver a
            float, and ``manual_seed`` requires an int, so it is converted with ``int``.
        control_type: "Ghibli" loads ``<lora_base_path>/Ghibli.safetensors``
            into ``pipe.transformer``. Any other value skips ``set_single_lora``,
            leaving the transformer's attention processors as they are.

    Returns:
        The first image of the pipeline output.
    """
    if control_type == "Ghibli":
        lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
        set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)

    spatial_imgs = [spatial_img] if spatial_img is not None else []
    try:
        image = pipe(
            prompt,
            height=int(height),
            width=int(width),
            guidance_scale=3.5,
            num_inference_steps=25,
            max_sequence_length=512,
            generator=torch.Generator("cpu").manual_seed(int(seed)),
            subject_images=[],
            spatial_images=spatial_imgs,
            cond_size=512,
        ).images[0]
    finally:
        # Clear processor state even if generation fails, so the next request
        # does not start with leftovers from this one.
        clear_cache(pipe.transformer)
    return image
# Control types offered in the UI dropdown.
control_types = ["Ghibli"]

# Example rows for gr.Examples, in the UI input order:
# [prompt, spatial image, height, width, seed, control type].
# Only the image, size and seed differ between rows, so those are tabulated
# as (image path, height, width, seed) and expanded below.
single_examples = [
    [
        "Ghibli Studio style, Charming hand-drawn anime-style illustration",
        Image.open(example_path),
        example_height,
        example_width,
        example_seed,
        "Ghibli",
    ]
    for example_path, example_height, example_width, example_seed in (
        ("./test_imgs/00.png", 680, 1024, 5),
        ("./test_imgs/02.png", 560, 1024, 42),
        ("./test_imgs/03.png", 568, 1024, 1),
        ("./test_imgs/04.png", 768, 672, 1),
        ("./test_imgs/06.png", 896, 1024, 1),
        ("./test_imgs/07.png", 528, 800, 1),
        ("./test_imgs/08.png", 696, 1024, 1),
        ("./test_imgs/09.png", 896, 1024, 1),
    )
]
# Create the Gradio Blocks interface
with gr.Blocks() as demo:
    gr.Markdown("# Ghibli Studio Control Image Generation with EasyControl")
    gr.Markdown("The model is trained on **only 100 real Asian faces** paired with **GPT-4o-generated Ghibli-style counterparts**, and it preserves facial features while applying the iconic anime aesthetic.")
    gr.Markdown("Generate images using EasyControl with Ghibli control LoRAs.(Due to hardware constraints, only low-resolution images can be generated. For high-resolution (1024+), please set up your own environment.)")
    gr.Markdown("**[Attention!!]**:The recommended prompts for using Ghibli Control LoRA should include the trigger words: `Ghibli Studio style, Charming hand-drawn anime-style illustration`")
    gr.Markdown("😊😊If you like this demo, please give us a star (github: [EasyControl](https://github.com/Xiaojiu-z/EasyControl))")

    with gr.Tab("Ghibli Condition Generation"):
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="Ghibli Studio style, Charming hand-drawn anime-style illustration")
                spatial_img = gr.Image(label="Ghibli Image", type="pil")  # Upload the source image
                height = gr.Slider(minimum=256, maximum=1024, step=64, label="Height", value=768)
                width = gr.Slider(minimum=256, maximum=1024, step=64, label="Width", value=768)
                # precision=0 makes Gradio return an int; torch's manual_seed
                # rejects floats.
                seed = gr.Number(label="Seed", value=42, precision=0)
                # Preselect a control type: with no value the handler receives
                # None and never loads the LoRA.
                control_type = gr.Dropdown(choices=control_types, label="Control Type", value=control_types[0])
                single_generate_btn = gr.Button("Generate Image")
            with gr.Column():
                single_output_image = gr.Image(label="Generated Image")

        # Add examples for Single Condition Generation
        gr.Examples(
            examples=single_examples,
            inputs=[prompt, spatial_img, height, width, seed, control_type],
            outputs=single_output_image,
            fn=single_condition_generate_image,
            cache_examples=False,  # Do not pre-compute example outputs at startup
            label="Single Condition Examples"
        )

    # Link the buttons to the functions
    single_generate_btn.click(
        single_condition_generate_image,
        inputs=[prompt, spatial_img, height, width, seed, control_type],
        outputs=single_output_image
    )

    # Add proxy endpoints to access temporary files
    # NOTE(review): the routes listed here are registered further down via
    # Flask; confirm they are actually reachable from the Gradio server.
    with gr.Tab("API Endpoints"):
        gr.Markdown("## Image Proxy Endpoints")
        gr.Markdown("The following endpoints allow external access to generated images:")
        gr.Markdown("- `/proxy-image?path=PATH` - Access a generated image directly")
        gr.Markdown("- `/proxy-image-base64?path=PATH` - Get a generated image as base64")
        gr.Markdown("These endpoints should only be used programmatically.")
# HTTP routes that serve Gradio's temporary output files.
# NOTE: Gradio is built on FastAPI and has no `gr.flask_app`; the previous
# `app = gr.flask_app` raised AttributeError at import time. A standalone
# Flask app is created instead. `demo.launch()` does NOT serve these routes:
# run `app` separately, or port the routes to FastAPI with
# `gr.mount_gradio_app`, to expose them.
app = Flask(__name__)

# Only files under Gradio's temp directory may be proxied; without this, any
# readable file on the server (e.g. /etc/passwd) could be fetched. Gradio uses
# GRADIO_TEMP_DIR when set, otherwise <system temp dir>/gradio.
_PROXY_ROOT = os.path.realpath(
    os.environ.get("GRADIO_TEMP_DIR") or os.path.join(tempfile.gettempdir(), "gradio")
)


class _ProxyPathError(Exception):
    """A rejected ``path`` query value, carrying the HTTP status to return."""

    def __init__(self, message, status):
        super().__init__(message)
        self.message = message
        self.status = status


def _resolve_proxy_path(raw_path):
    """Validate a ``path`` query value and return its real absolute path.

    Raises:
        _ProxyPathError: 400 if the value is missing, contains '..', or
            resolves outside ``_PROXY_ROOT``; 404 if it is not a regular file.
    """
    if not raw_path:
        raise _ProxyPathError("No path specified", 400)
    if ".." in raw_path:
        raise _ProxyPathError("Invalid path", 400)
    # Relative values are treated as absolute, as before.
    if not raw_path.startswith("/"):
        raw_path = "/" + raw_path
    try:
        # realpath also follows symlinks, so links cannot escape the root.
        resolved = os.path.realpath(raw_path)
    except ValueError:  # e.g. an embedded NUL byte
        raise _ProxyPathError("Invalid path", 400) from None
    if os.path.commonpath([_PROXY_ROOT, resolved]) != _PROXY_ROOT:
        raise _ProxyPathError("Invalid path", 400)
    if not os.path.isfile(resolved):
        raise _ProxyPathError("File not found", 404)
    return resolved


def _guess_image_mimetype(path):
    """Guess the MIME type from the file name, defaulting to image/webp."""
    return mimetypes.guess_type(path)[0] or "image/webp"


@app.route('/proxy-image')
def proxy_image():
    """
    Proxies an image from Gradio's temp directory
    Example usage: /proxy-image?path=/tmp/gradio/4f461b7833afb048edcda4fd8968db08aeb352e871a5445721d63b11f9f489c7/image.webp
    """
    try:
        file_path = _resolve_proxy_path(request.args.get('path'))
    except _ProxyPathError as err:
        return err.message, err.status
    try:
        return send_file(file_path, mimetype=_guess_image_mimetype(file_path))
    except OSError:
        # Do not echo exception details (server paths, errno text) to clients.
        return "Error accessing file", 404


@app.route('/proxy-image-base64')
def proxy_image_base64():
    """
    Proxies an image from Gradio's temp directory and returns it as base64
    Example usage: /proxy-image-base64?path=/tmp/gradio/4f461b7833afb048edcda4fd8968db08aeb352e871a5445721d63b11f9f489c7/image.webp
    """
    try:
        file_path = _resolve_proxy_path(request.args.get('path'))
    except _ProxyPathError as err:
        return jsonify({"error": err.message}), err.status
    try:
        with open(file_path, 'rb') as f:
            base64_data = base64.b64encode(f.read()).decode('utf-8')
    except OSError:
        return jsonify({"error": "Error accessing file"}), 404
    return jsonify({"base64": base64_data, "mime_type": _guess_image_mimetype(file_path)})
# Start serving: turn on Gradio's request queue (queue() configures `demo`
# in place), then listen on all interfaces at port 7860 with share=True.
demo.queue()
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    share=True,
)