File size: 3,282 Bytes
6e0708a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import gradio as gr
from dotenv import load_dotenv
from google import genai
from google.genai import types

# Load settings from a local .env file (if present) via python-dotenv so the
# os.getenv lookups for GEMINI_API_KEY and PROMPT below can find them.
load_dotenv()

def save_binary_file(file_name, data):
    """Write a bytes payload to disk and hand back the destination path.

    Args:
        file_name: Path of the file to create (an existing file is overwritten).
        data: Bytes-like payload written verbatim in binary mode.

    Returns:
        The ``file_name`` argument, unchanged, so calls can be chained.
    """
    with open(file_name, mode="wb") as sink:
        sink.write(data)

    return file_name

def generate_couple_photo(input_image):
    """Generate a "couple photo" rendition of an uploaded image with Gemini.

    The image is uploaded through the Gemini Files API and sent, together with
    the prompt read from the ``PROMPT`` environment variable, to the
    ``gemini-2.0-flash-exp-image-generation`` model.

    Args:
        input_image: Filesystem path of the source image (the Gradio input
            component uses ``type="filepath"``); ``None`` or empty when the
            user has not uploaded anything.

    Returns:
        A ``(image_path, text)`` tuple. ``image_path`` is the path of the
        saved generated image, or ``None`` when no image was produced.
        ``text`` is the concatenated text returned by the model, or an
        ``"Error: ..."`` message on failure. Failures are reported through
        ``text`` rather than raised so the UI can display them.
    """
    # Stdlib helpers used only here; imported locally so the block is
    # self-contained.
    import mimetypes
    import tempfile

    # Validate all prerequisites up front so the user gets a precise message
    # instead of an opaque SDK exception.
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        return None, "Error: GEMINI_API_KEY environment variable not set"

    if not input_image:
        return None, "Error: no input image provided"

    # There is no built-in default prompt: it must come from the environment.
    prompt_text = os.getenv("PROMPT")
    if not prompt_text:
        return None, "Error: PROMPT environment variable not set"

    try:
        client = genai.Client(api_key=api_key)

        # Upload the input image
        uploaded_file = client.files.upload(file=input_image)

        model = "gemini-2.0-flash-exp-image-generation"
        contents = [
            types.Content(
                role="user",
                parts=[
                    types.Part.from_uri(
                        file_uri=uploaded_file.uri,
                        mime_type=uploaded_file.mime_type,
                    ),
                    types.Part.from_text(text=prompt_text),
                ],
            ),
        ]

        generate_content_config = types.GenerateContentConfig(
            temperature=1,
            top_p=0.95,
            top_k=40,
            max_output_tokens=8192,
            response_modalities=["image", "text"],
            safety_settings=[
                types.SafetySetting(
                    category="HARM_CATEGORY_CIVIC_INTEGRITY",
                    threshold="OFF",
                ),
            ],
            response_mime_type="text/plain",
        )

        # Generate content
        response = client.models.generate_content(
            model=model,
            contents=contents,
            config=generate_content_config,
        )

        # A response may carry no candidates, or a candidate without content;
        # handle that explicitly rather than failing on an attribute/index
        # error.
        candidates = response.candidates or []
        content = candidates[0].content if candidates else None
        parts = (content.parts if content is not None else None) or []
        if not parts:
            return None, (
                "Error: the model returned no content "
                "(the request may have been blocked)"
            )

        text_chunks = []
        output_image_path = None

        for part in parts:
            if part.text is not None:
                text_chunks.append(part.text)
            elif part.inline_data is not None:
                # Write each result to its own temp file so concurrent
                # requests cannot overwrite each other's output, and choose
                # the extension from the returned MIME type (falling back to
                # .jpg). If several image parts are returned, the last wins.
                suffix = (
                    mimetypes.guess_extension(part.inline_data.mime_type or "")
                    or ".jpg"
                )
                fd, image_path = tempfile.mkstemp(
                    prefix="generated_couple_photo_", suffix=suffix
                )
                os.close(fd)
                output_image_path = save_binary_file(
                    image_path,
                    part.inline_data.data,
                )

        return output_image_path, "".join(text_chunks)

    except Exception as e:
        # Top-level UI boundary: surface any SDK/network failure as text.
        return None, f"Error: {e}"

# Create Gradio interface
def create_interface():
    """Assemble the Gradio Blocks UI for the couple photo generator.

    The layout is a single row with two columns: the image upload and the
    trigger button on the left, the generated image and a notes textbox on
    the right. Clicking the button calls ``generate_couple_photo`` with the
    uploaded image and routes its two return values to the output image and
    the textbox.

    Returns:
        The constructed ``gr.Blocks`` application; it is not launched here.
    """
    app = gr.Blocks(title="Couple Photo Generator")

    with app:
        gr.Markdown("# Couple Photo Generator")
        gr.Markdown("Upload an image and get a generated couple photo version of it.")

        with gr.Row():
            # Left column: input controls.
            with gr.Column():
                source_image = gr.Image(type="filepath", label="Upload Image")
                generate_button = gr.Button("Generate Couple Photo")

            # Right column: results.
            with gr.Column():
                result_image = gr.Image(label="Generated Couple Photo")
                result_notes = gr.Textbox(label="Generation Notes")

        generate_button.click(
            generate_couple_photo,
            inputs=[source_image],
            outputs=[result_image, result_notes],
        )

    return app

# Script entry point: build the interface and launch it only when this file
# is executed directly, not when it is imported as a module.
if __name__ == "__main__":
    # Launch the app
    demo = create_interface()
    demo.launch()