|
import base64 |
|
from dataclasses import dataclass |
|
from io import BytesIO |
|
from pathlib import Path |
|
from typing import Literal, cast |
|
|
|
import gradio as gr |
|
import jinja2 |
|
from openai import OpenAI |
|
from PIL import Image |
|
from pydantic import BaseModel |
|
|
|
# Shared OpenAI client. No arguments are passed, so credentials and endpoint
# come from the SDK's own defaults (normally environment variables).
client = OpenAI()

# Jinja prompt templates are loaded from a "templates" directory that sits
# next to this source file.
TEMPLATES_DIR = Path(__file__).resolve().parent.joinpath("templates")
jinja_env = jinja2.Environment(
    loader=jinja2.FileSystemLoader(searchpath=str(TEMPLATES_DIR)),
)

# System message used by process_prompt's generation request.
SYSTEM_PROMPT = "You are expert prompt engineer"
|
|
|
# Closed set of style names offered by the app; STYLE_DEFINITIONS, the
# dropdown choices and the structured-output schema are all keyed by these.
StyleName = Literal[
    "General", "Fashion", "Emotional Lifestyle", "Extreme Sports",
    "Captivating", "Image Replication", "Red Bar Lighting", "Teal Noir",
]
|
|
|
|
|
@dataclass(frozen=True)
class StyleDefinition:
    """Immutable configuration record for one selectable prompt style.

    Attributes:
        name: The style's name, one of ``StyleName``.
        template_filename: Jinja template file name, resolved by ``jinja_env``
            (i.e. relative to ``TEMPLATES_DIR``).
        info: Short human-readable description; it is listed in the style
            guide given to the auto-selection model.
    """

    name: StyleName
    template_filename: str
    info: str
|
|
|
|
|
# Registry of every selectable style, keyed by style name and kept in display
# order. The mapping is derived from each definition's own ``name`` so a dict
# key can never disagree with the record it points to.
STYLE_DEFINITIONS: dict[StyleName, StyleDefinition] = {
    definition.name: definition
    for definition in (
        StyleDefinition(
            name="General",
            template_filename="general_prompt.jinja",
            info="Versatile, balanced storytelling with cinematic detail for most scenarios.",
        ),
        StyleDefinition(
            name="Fashion",
            template_filename="fashion_prompt.jinja",
            info="Editorial fashion aesthetic highlighting garments, styling, and runway polish.",
        ),
        StyleDefinition(
            name="Emotional Lifestyle",
            template_filename="emotional_lifestyle_prompt.jinja",
            info="Warm, candid lifestyle imagery that focuses on mood, relationships, and feelings.",
        ),
        StyleDefinition(
            name="Extreme Sports",
            template_filename="extreme_sports_prompt.jinja",
            info="High-adrenaline action shots that emphasize energy, motion, and athletic feats.",
        ),
        StyleDefinition(
            name="Captivating",
            template_filename="captivating_prompt.jinja",
            info="Visually striking compositions with dramatic flair and memorable storytelling.",
        ),
        StyleDefinition(
            name="Image Replication",
            template_filename="image_replication_prompt.jinja",
            info=(
                "Mimic the reference image's composition, lighting, and styling exactly while"
                " inserting the user or their face in place of the original subject. Eg. If the reference image is a music album cover, the user's face will be embedded in the album cover."
            ),
        ),
        StyleDefinition(
            name="Red Bar Lighting",
            template_filename="red_bar_lighting_prompt.jinja",
            info="Red bar lighting style for image generation.",
        ),
        StyleDefinition(
            name="Teal Noir",
            template_filename="teal_noir_prompt.jinja",
            info="Teal noir style for image generation.",
        ),
    )
}
|
|
|
# One Jinja template per style, loaded eagerly at import time, so a missing
# template file surfaces immediately rather than on first use.
PROMPT_TEMPLATES = {
    style_name: jinja_env.get_template(definition.template_filename)
    for style_name, definition in STYLE_DEFINITIONS.items()
}

DEFAULT_STYLE: StyleName = "General"

# Dropdown choices, in registry order (iterating a dict yields its keys).
STYLE_CHOICES: tuple[StyleName, ...] = tuple(STYLE_DEFINITIONS)

# "- <style>: <info>" bullet list embedded in the auto-selection prompt.
STYLE_INFORMATION_BLOCK = "\n".join(
    [f"- {style_name}: {definition.info}" for style_name, definition in STYLE_DEFINITIONS.items()]
)
|
|
|
|
|
# Structured-output schema used by recommend_style: the model must answer with
# exactly one style name.
# NOTE: documented with comments on purpose. Pydantic copies a class docstring
# into the JSON schema "description", so adding one would likely change the
# schema sent to the model.
class StyleSelectionResponse(BaseModel):
    # Must be one of the StyleName literals; pydantic validates this on parse.
    style: StyleName
|
|
|
|
|
def _image_to_jpeg_data_url(image) -> str:
    """Encode a PIL image as a base64 ``data:image/jpeg`` URL.

    The image is converted to RGB first because JPEG has no alpha channel, so
    e.g. RGBA uploads could not be saved as JPEG directly.
    """
    buffer = BytesIO()
    image.convert("RGB").save(buffer, format="JPEG", quality=90)
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/jpeg;base64,{encoded}"


def process_prompt(user_image, reference_image, target_label: str, user_prompt: str, style: StyleName) -> str:
    """Generate a style-specific image prompt with the OpenAI Responses API.

    Args:
        user_image: PIL image of the user, or ``None``.
        reference_image: Optional PIL reference image, or ``None``.
        target_label: Label appended after the model output. Surrounding
            whitespace is stripped; nothing is appended when it is empty/None.
        user_prompt: Free-text prompt rendered into the style's template.
        style: Key into ``PROMPT_TEMPLATES`` selecting the Jinja template.

    Returns:
        The model's output text, followed by a space and the stripped label
        when a non-empty label was given.

    Raises:
        ValueError: If ``style`` has no registered template.
    """
    # Resolve the template first so an invalid style fails before any image
    # encoding work is done.
    try:
        template = PROMPT_TEMPLATES[style]
    except KeyError as error:
        raise ValueError(f"Unsupported style: {style}") from error

    user_content = template.render(user_prompt=user_prompt)

    content = [{"type": "input_text", "text": user_content}]
    # Images follow the text: the user photo first, then the reference image.
    for image in (user_image, reference_image):
        if image is not None:
            content.append({"type": "input_image", "image_url": _image_to_jpeg_data_url(image)})

    response = client.responses.create(
        model="gpt-5",
        reasoning={"effort": "minimal"},
        input=[
            {
                "role": "system",
                "content": SYSTEM_PROMPT,
            },
            {
                "role": "user",
                "content": content,
            },
        ],
    )

    label = (target_label or "").strip()
    # Only add the separator when there is a label; otherwise the result
    # would end with a dangling space.
    if not label:
        return response.output_text
    return f"{response.output_text} {label}"
|
|
|
|
|
def recommend_style(user_prompt: str, reference_image: Image.Image | None) -> StyleName:
    """Ask a small model to pick the best-fitting style for a prompt.

    Args:
        user_prompt: The user's free-text prompt.
        reference_image: Optional PIL reference image, sent alongside the text
            when given.

    Returns:
        The style name chosen by the model (validated against ``StyleName``
        via ``StyleSelectionResponse``), or ``DEFAULT_STYLE`` when the SDK
        yields no parsed output (e.g. the model refused).
    """
    reference_image_url = None
    reference_note = ""
    if reference_image is not None:
        buffer = BytesIO()
        # JPEG has no alpha channel, hence the RGB conversion.
        reference_image.convert("RGB").save(buffer, format="JPEG", quality=90)
        b64_reference_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
        reference_image_url = f"data:image/jpeg;base64,{b64_reference_image}"
        # Only claim a reference image exists when one was actually attached.
        reference_note = " User has provided the reference image."

    # Built under its own name so the caller's ``user_prompt`` is not overwritten.
    instructions = "\n".join(
        [
            "You are an art director who must pick the most fitting style name for a user's prompt.",
            "Consider the available styles and choose the single best option." + reference_note,
            "",
            "Style Guide:",
            STYLE_INFORMATION_BLOCK,
            "",
            "User Prompt:",
            user_prompt or "",
            "",
        ]
    )

    content = [{"type": "input_text", "text": instructions}]
    if reference_image_url is not None:
        content.append({"type": "input_image", "image_url": reference_image_url})

    completion = client.responses.parse(
        model="gpt-5-mini",
        reasoning={"effort": "low"},
        input=[{
            "role": "user",
            "content": content,
        }],
        text_format=StyleSelectionResponse,
    )

    parsed = completion.output_parsed
    if parsed is None:
        # ``output_parsed`` can be None when no structured output was produced;
        # fall back to the default style instead of crashing on ``.style``.
        return DEFAULT_STYLE
    return parsed.style
|
|
|
|
|
def handle_auto_style_toggle(auto_enabled: bool) -> dict[str, object]:
    """Return a Gradio update that locks the style dropdown while auto-select is on.

    The dropdown is interactive only in manual mode (checkbox unchecked).
    """
    manual_mode = not auto_enabled
    return gr.update(interactive=manual_mode)
|
|
|
|
|
def generate_prompt_handler(
    user_image,
    reference_image,
    target_label: str,
    user_prompt: str,
    current_style: str | None,
    auto_style_enabled: bool,
):
    """Gradio click handler: choose a style (optionally automatically) and build the prompt.

    Args:
        user_image: PIL image of the user, or ``None``.
        reference_image: Optional PIL reference image, or ``None``.
        target_label: Label appended to the generated prompt.
        user_prompt: The user's free-text prompt.
        current_style: Style currently selected in the dropdown (may be None).
        auto_style_enabled: When True, the style is picked by ``recommend_style``
            and the dropdown value is ignored.

    Returns:
        A ``(display_text, dropdown_update)`` pair: the text shown in the
        output box and a ``gr.update`` that reflects the style used.

    Raises:
        gr.Error: If auto-select is off and no style is selected.
    """
    if auto_style_enabled:
        current_style = recommend_style(user_prompt, reference_image)
    elif not current_style:
        # Surface a readable message in the UI instead of process_prompt's
        # generic "Unsupported style: None" ValueError.
        raise gr.Error("Please select a style or enable auto-select.")

    prompt_text = process_prompt(
        user_image=user_image,
        reference_image=reference_image,
        target_label=target_label,
        user_prompt=user_prompt,
        style=cast("StyleName", current_style),
    )

    display_text = f"Selected style: {current_style}\n\n{prompt_text}"
    # Keep the dropdown's interactivity consistent with the checkbox. The
    # previous version always locked it, which left manual mode unusable
    # after the first generation.
    return display_text, gr.update(value=current_style, interactive=not auto_style_enabled)
|
|
|
|
|
# Gradio UI. Inputs and controls are in the left column, the generated prompt
# in the right column. Component creation order inside the ``with`` blocks
# determines the on-screen layout.
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: uploads, text inputs and style controls.
        with gr.Column():
            # type="pil" makes Gradio pass PIL images (or None) to the handler.
            user_image = gr.Image(
                label="Upload user photo",
                type="pil"
            )
            reference_image = gr.Image(
                label="Optional: Upload reference image (Eg. movie poster, music album cover, etc.)",
                type="pil",
            )
            target_label = gr.Textbox(
                label="Enter target label",
                placeholder="SMRA",
            )
            user_prompt = gr.Textbox(
                label="Enter your prompt",
                placeholder="picture of me while sitting in a chair in the ocean",
                lines=4,
            )
            # Starts non-interactive because the auto-select checkbox below
            # starts checked; handle_auto_style_toggle keeps the two in sync.
            style_dropdown = gr.Dropdown(
                choices=list(STYLE_CHOICES),
                value=DEFAULT_STYLE,
                label="Style Selection",
                info="Choose the visual style for your enhanced prompt",
                interactive=False,
            )
            auto_style_checkbox = gr.Checkbox(
                label="Auto-select best style",
                value=True,
            )
            generate_button = gr.Button("Generate Prompt")
        # Right column: generated prompt output.
        with gr.Column():
            prompt_output = gr.Textbox(
                label="Style Prompt",
                lines=20,
            )

    # Inputs are passed positionally to generate_prompt_handler in this order.
    # Its two return values fill the output box and update the dropdown.
    generate_button.click(
        generate_prompt_handler,
        inputs=[
            user_image,
            reference_image,
            target_label,
            user_prompt,
            style_dropdown,
            auto_style_checkbox,
        ],
        outputs=[prompt_output, style_dropdown],
    )
    # Enable or disable manual style selection whenever auto-select is toggled.
    auto_style_checkbox.change(
        handle_auto_style_toggle,
        inputs=[auto_style_checkbox],
        outputs=[style_dropdown],
    )

# NOTE(review): this launches the server as a side effect of importing the
# module; consider an ``if __name__ == "__main__":`` guard if the module is
# ever imported elsewhere.
demo.launch()
|
|