# Spaces: Sleeping
import gradio as gr | |
import json | |
import os | |
import mimetypes | |
import google.generativeai as genai | |
from google.generativeai import types | |
from PIL import Image | |
import time | |
# --- Helper Function to Save Generated Image --- | |
def save_binary_file(directory, file_name, data):
    """Write binary data to ``directory/file_name`` and return the path.

    Args:
        directory: Target directory. It is created (including parents) when
            it does not exist. An empty string means the current directory.
        file_name: Name of the file to create inside ``directory``.
        data: Bytes-like payload written verbatim in binary mode.

    Returns:
        The path of the written file, as produced by ``os.path.join``.

    Raises:
        OSError: If the directory cannot be created or the file cannot be
            written.
    """
    # os.makedirs("") raises, so only create a directory when one is named.
    if directory:
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(directory, exist_ok=True)
    file_path = os.path.join(directory, file_name)
    with open(file_path, "wb") as f:
        f.write(data)
    print(f"File saved to: {file_path}")
    return file_path
# --- Main Function to Generate Image --- | |
def _extension_for_mime(mime_type, default=".png"):
    """Return a file extension for ``mime_type``, or ``default`` if unknown."""
    if not mime_type:
        return default
    # guess_extension returns None for unregistered types; never let that
    # leak into a file name as the literal text "None".
    return mimetypes.guess_extension(mime_type) or default


def generate_image(
    api_key,
    reference_image,
    scene,
    subject_type,
    age_range,
    hair,
    makeup,
    jewellery,
    top,
    bottom,
    footwear,
    wardrobe_notes,
    pose_angle,
    body_pose,
    hands_pose,
    framing,
    camera_device,
    flash,
    orientation,
    aspect_ratio,
    distance,
    focus,
    texture,
    sharpness,
    color,
    effects,
    background_environment,
    background_props,
    style_genre,
    authenticity,
    use_original_structure,
    face_description,
    ban_mirror,
    ban_phone,
    ban_selfie,
    ban_grainy,
    ban_harsh_flash,
    ban_logos,
    ban_nsfw,
    ban_cropped_feet,
    output_count,
    output_size,
    safety,
    variant_name,
    variant_angle,
):
    """Build a JSON prompt from the form values and request image output.

    The form values are folded into a nested JSON document that is sent,
    together with the reference image bytes, as a single user message. Every
    inline-data part of the streamed response is written to the
    ``generated_images`` directory.

    Args:
        api_key: API key passed to the client constructor.
        reference_image: Path of the reference image file on disk.
        scene .. face_description: Free-form values copied unchanged into the
            matching sections of the JSON payload ("subject", "wardrobe",
            "pose", "camera", "look", "background", "style",
            "reference_face").
        ban_mirror .. ban_cropped_feet: Flags; each truthy flag adds its
            phrase to the payload's "ban" list.
        output_count: Number stored (as ``int``) under "output.count".
        output_size: Value stored under "output.size".
        safety: Value stored under "output.safety".
        variant_name: Name of the single entry in "variants".
        variant_angle: Angle of the single entry in "variants".

    Returns:
        A 3-tuple ``(files, json_text, status)`` where ``files`` is the list
        of saved file paths (or ``None`` when nothing was produced),
        ``json_text`` is the JSON payload as a string and ``status`` is a
        human-readable message.

    Raises:
        gr.Error: If the API key or reference image is missing, the image
            type cannot be determined, or the API call or file handling
            fails.
    """
    # --- Input Validation ---
    # Strip first so a whitespace-only key is reported as missing.
    api_key = (api_key or "").strip()
    if not api_key:
        raise gr.Error("API Key is missing. Please enter your Gemini API key.")
    if reference_image is None:
        raise gr.Error("Reference image is missing. Please upload an image.")
    image_mime_type = mimetypes.guess_type(reference_image)[0]
    if image_mime_type is None:
        # Fail with a clear message rather than sending a None MIME type.
        raise gr.Error(
            "Could not determine the type of the reference image. Please "
            "upload a file with a standard image extension (for example "
            ".png or .jpg)."
        )

    # --- Build Banned List ---
    ban_flags = (
        (ban_mirror, "mirror"),
        (ban_phone, "phone"),
        (ban_selfie, "selfie look"),
        (ban_grainy, "grainy noise"),
        (ban_harsh_flash, "harsh LED flash"),
        (ban_logos, "logos/brand text"),
        (ban_nsfw, "nsfw"),
        (ban_cropped_feet, "cropped feet"),
    )
    banned_items = [label for enabled, label in ban_flags if enabled]

    # --- Construct JSON Payload ---
    output_json = {
        "scene": scene,
        "subject": {
            "type": subject_type,
            "age_range": age_range,
            "hair": hair,
            "makeup": makeup,
            "jewellery": jewellery,
        },
        "wardrobe": {
            "top": top,
            "bottom": bottom,
            "footwear": footwear,
            "notes": wardrobe_notes,
        },
        "pose": {
            "angle": pose_angle,
            "body": body_pose,
            "hands": hands_pose,
            "framing": framing,
        },
        "camera": {
            "device": camera_device,
            "flash": flash,
            "orientation": orientation,
            "aspect_ratio": aspect_ratio,
            "distance": distance,
            "focus": focus,
        },
        "look": {
            "texture": texture,
            "sharpness": sharpness,
            "color": color,
            "effects": effects,
        },
        "background": {
            "environment": background_environment,
            "props": background_props,
        },
        "style": {"genre": style_genre, "authenticity": authenticity},
        "reference_face": {
            "use_original_structure": use_original_structure,
            "description": face_description,
        },
        "ban": banned_items,
        "output": {
            "count": int(output_count),
            "size": output_size,
            "safety": safety,
        },
        "variants": [{"name": variant_name, "angle": variant_angle}],
    }
    final_json_string = json.dumps(output_json, indent=4)

    # --- Call Gemini API ---
    try:
        # NOTE(review): `genai.Client`, `types.Part.from_data` and
        # `types.GenerateContentConfig` must exist in whatever package this
        # file imports as `genai` / `types`. They do not look like members of
        # `google.generativeai` -- confirm against the installed SDK.
        client = genai.Client(api_key=api_key)

        # Prepare the prompt parts (JSON instructions + reference image).
        prompt_text_part = types.Part.from_text(text=final_json_string)
        with open(reference_image, "rb") as f:
            image_data = f.read()
        image_part = types.Part.from_data(
            data=image_data, mime_type=image_mime_type
        )

        # NOTE(review): the request asks for IMAGE output; verify that this
        # model name supports image responses.
        model = "gemini-1.5-flash-latest"
        contents = [
            types.Content(role="user", parts=[prompt_text_part, image_part])
        ]
        generate_content_config = types.GenerateContentConfig(
            response_modalities=["IMAGE", "TEXT"],
        )

        # --- Process Streaming Response ---
        output_files = []
        output_directory = "generated_images"
        # One timestamp per request keeps the files of a run grouped.
        timestamp = int(time.time())
        file_index = 0

        response_stream = client.models.generate_content_stream(
            model=model,
            contents=contents,
            config=generate_content_config,
        )
        for chunk in response_stream:
            if not chunk.candidates:
                continue
            content = chunk.candidates[0].content
            if not (content and content.parts):
                continue
            # Inspect every part: a chunk may carry text and image together,
            # and looking only at the first part would drop the image.
            for part in content.parts:
                inline_data = part.inline_data
                if inline_data and inline_data.data:
                    # ".png" fallback is an assumption about the payload --
                    # TODO confirm against the API's actual output types.
                    file_extension = _extension_for_mime(inline_data.mime_type)
                    file_name = f"output_{timestamp}_{file_index}{file_extension}"
                    saved_file_path = save_binary_file(
                        output_directory, file_name, inline_data.data
                    )
                    output_files.append(saved_file_path)
                    file_index += 1
                elif part.text:
                    print(f"Received text chunk: {part.text}")

        if not output_files:
            return None, final_json_string, "No image was generated. Please check the model's response or your prompt."
        # Return file paths for the Gallery and the JSON for inspection.
        return output_files, final_json_string, "Image generation complete."
    except Exception as e:
        # Top-level boundary: surface any failure in the UI, keeping the
        # original exception attached as the cause for debugging.
        error_message = f"An error occurred: {str(e)}"
        print(error_message)
        raise gr.Error(error_message) from e
# --- Gradio Interface Definition --- | |
# Layout: a two-column page. The left column holds tabbed inputs, the right
# column holds the trigger button and the outputs. The click handler is wired
# inside the Blocks context so the event is registered on `demo`.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Gemini Image Generation Studio")
    gr.Markdown("Use the tabs below to define your image, then click 'Generate Image' to call the API.")
    with gr.Row():
        with gr.Column(scale=1):
            # --- Left Column for Inputs ---
            with gr.Tabs():
                with gr.TabItem("🔑 API & Image"):
                    api_key_input = gr.Textbox(label="Gemini API Key", type="password", info="Your API key is required to generate images.")
                    # gr.Image has no `info` keyword, so the hint is shown
                    # as a Markdown line under the upload widget instead.
                    reference_image_input = gr.Image(label="Reference Image", type="filepath")
                    gr.Markdown("Upload the base image for generation or editing.")
                with gr.TabItem("🎨 Scene & Subject"):
                    scene_input = gr.Textbox(label="Scene", value="cinematic outdoor portrait; professional photography")
                    subject_type_input = gr.Textbox(label="Subject Type", value="adult woman (idol vibe)")
                    age_range_input = gr.Textbox(label="Age Range", value="20s")
                    hair_input = gr.Textbox(label="Hair", value="straight or styled natural open hair with natural shine")
                    makeup_input = gr.Textbox(label="Makeup", value="glossy lips, soft eyeliner, luminous skin")
                    jewellery_input = gr.Textbox(label="Jewellery", value="small hoops, thin chain, subtle bracelets")
                with gr.TabItem("👕 Wardrobe"):
                    top_input = gr.Textbox(label="Top", value="basic tee or camisole")
                    bottom_input = gr.Textbox(label="Bottom", value="denim shorts or mini skirt")
                    footwear_input = gr.Textbox(label="Footwear", value="sneakers or ankle boots")
                    wardrobe_notes_input = gr.Textbox(label="Wardrobe Notes", value="casual modern look, styled for natural setting")
                with gr.TabItem("🧍 Pose & Framing"):
                    pose_angle_input = gr.Dropdown(label="Pose Angle", choices=["three-quarter", "full body"], value="three-quarter")
                    body_pose_input = gr.Textbox(label="Body Pose", value="standing or walking casually, relaxed natural posture")
                    hands_pose_input = gr.Textbox(label="Hands Pose", value="one resting by side or touching hair, the other relaxed")
                    framing_input = gr.Dropdown(label="Framing", choices=["head-to-toe", "waist-up", "cinematic composition"], value="waist-up")
                with gr.TabItem("📷 Camera & Look"):
                    camera_device_input = gr.Textbox(label="Camera Device", value="professional cinema camera / DSLR with prime lens")
                    flash_input = gr.Textbox(label="Flash", value="none; natural golden hour light or soft reflectors")
                    orientation_input = gr.Dropdown(label="Orientation", choices=["vertical", "horizontal"], value="vertical")
                    aspect_ratio_input = gr.Dropdown(label="Aspect Ratio", choices=["16:9", "3:2", "4:3", "1:1"], value="16:9")
                    distance_input = gr.Textbox(label="Distance", value="cinematic portrait distance with shallow depth")
                    focus_input = gr.Textbox(label="Focus", value="sharp on subject; soft bokeh background")
                    texture_input = gr.Textbox(label="Texture", value="smooth high-resolution detail")
                    sharpness_input = gr.Textbox(label="Sharpness", value="very high; crisp cinematic clarity")
                    color_input = gr.Textbox(label="Color", value="warm cinematic grading; golden tones and soft contrast")
                    effects_input = gr.Textbox(label="Effects", value="subtle film grain; natural light flares, depth of field")
                with gr.TabItem("🌳 Background & Style"):
                    background_environment_input = gr.Textbox(label="Background Environment", value="nature setting — forest, park, or meadow with soft light")
                    background_props_input = gr.Textbox(label="Background Props", value="none; focus on subject against natural backdrop")
                    style_genre_input = gr.Textbox(label="Style Genre", value="cinematic portrait photography")
                    authenticity_input = gr.Textbox(label="Authenticity", value="natural, elegant, polished")
                with gr.TabItem("👤 Face & Bans"):
                    use_original_structure_input = gr.Checkbox(label="Use Original Face Structure", value=True)
                    face_description_input = gr.Textbox(label="Face Description", value="maintain the same face shape, features, and proportions as in the provided reference image")
                    gr.Markdown("#### Banned Items")
                    with gr.Row():
                        ban_mirror_input = gr.Checkbox(label="Mirror")
                        ban_phone_input = gr.Checkbox(label="Phone")
                        ban_selfie_input = gr.Checkbox(label="Selfie Look")
                        ban_grainy_input = gr.Checkbox(label="Grainy Noise")
                    with gr.Row():
                        ban_harsh_flash_input = gr.Checkbox(label="Harsh Flash")
                        ban_logos_input = gr.Checkbox(label="Logos")
                        ban_nsfw_input = gr.Checkbox(label="NSFW")
                        ban_cropped_feet_input = gr.Checkbox(label="Cropped Feet")
                with gr.TabItem("⚙️ Output & Variants"):
                    output_count_input = gr.Slider(label="Output Count", minimum=1, maximum=4, step=1, value=1)
                    output_size_input = gr.Textbox(label="Output Size", value="1024x1024")
                    safety_input = gr.Dropdown(label="Safety", choices=["strict", "moderate", "none"], value="strict")
                    variant_name_input = gr.Textbox(label="Variant Name", value="cinematic_nature_fullbody")
                    variant_angle_input = gr.Textbox(label="Variant Angle", value="full-body shot in meadow or forest path, subject centered with depth of field")
        with gr.Column(scale=1):
            # --- Right Column for Outputs ---
            generate_button = gr.Button("Generate Image", variant="primary")
            status_text = gr.Textbox(label="Status", interactive=False)
            image_gallery = gr.Gallery(label="Generated Image(s)", show_label=True, elem_id="gallery", columns=[2], rows=[2], object_fit="contain", height="auto")
            json_output = gr.JSON(label="Generated JSON Input")

    # --- Button Click Action ---
    # Order matters: components are passed positionally to generate_image,
    # so this list must follow that function's parameter order exactly.
    all_inputs = [
        api_key_input, reference_image_input, scene_input, subject_type_input,
        age_range_input, hair_input, makeup_input, jewellery_input, top_input,
        bottom_input, footwear_input, wardrobe_notes_input, pose_angle_input,
        body_pose_input, hands_pose_input, framing_input, camera_device_input,
        flash_input, orientation_input, aspect_ratio_input, distance_input,
        focus_input, texture_input, sharpness_input, color_input, effects_input,
        background_environment_input, background_props_input, style_genre_input,
        authenticity_input, use_original_structure_input, face_description_input,
        ban_mirror_input, ban_phone_input, ban_selfie_input, ban_grainy_input,
        ban_harsh_flash_input, ban_logos_input, ban_nsfw_input,
        ban_cropped_feet_input, output_count_input, output_size_input,
        safety_input, variant_name_input, variant_angle_input,
    ]
    generate_button.click(
        fn=generate_image,
        inputs=all_inputs,
        outputs=[image_gallery, json_output, status_text],
    )

if __name__ == "__main__":
    demo.launch(debug=True)