File size: 3,967 Bytes
cfefca9
 
6fd2572
d06af8a
 
874da76
1fb31cc
5c7e9d3
cfefca9
79cd377
cfefca9
5c7e9d3
79cd377
a1c96c0
 
 
 
79cd377
a1c96c0
79cd377
12af388
79cd377
 
 
6f43826
79cd377
 
a1c96c0
54eb6cc
 
 
cfefca9
 
a1c96c0
cfefca9
 
54eb6cc
 
 
 
874da76
 
 
 
 
 
54eb6cc
6fd2572
 
 
 
 
 
 
 
 
cfefca9
1fb31cc
 
e765d2d
874da76
a1c96c0
12af388
a1c96c0
874da76
 
a1c96c0
fe37c2f
874da76
 
 
 
 
 
 
 
 
 
 
79cd377
 
 
a1c96c0
 
79cd377
 
12af388
79cd377
 
 
 
 
874da76
047cb5b
e765d2d
cfefca9
a1c96c0
 
 
 
 
 
874da76
 
79cd377
54eb6cc
874da76
cfefca9
ff04f46
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import gradio as gr
import fal_client
from fal_client.client import FalClientError
import requests
from PIL import Image
from io import BytesIO
import traceback
import os

def _format_fal_error(exc):
    """Build a readable multi-line message from a ``FalClientError``.

    The original handler assumed ``exc.args[0]`` is a list of dicts that each
    carry a ``"msg"`` key. Other payload shapes (a plain string, a single dict,
    no args at all) are tolerated here, so formatting the error can never raise
    a second exception from inside the ``except`` clause.
    """
    detail = exc.args[0] if exc.args else exc
    items = detail if isinstance(detail, (list, tuple)) else [detail]
    messages = [
        item.get("msg", str(item)) if isinstance(item, dict) else str(item)
        for item in items
    ]
    return "Errors:\n" + "\n".join(messages)


def _load_image(url, timeout=60):
    """Return a PIL image for *url*, which may be an http(s) URL or a ``data:`` URI.

    NOTE(review): fal may return base64 ``data:`` URIs instead of CDN links
    (e.g. when ``sync_mode`` is enabled). ``requests`` cannot fetch that
    scheme, so those URIs are decoded locally.

    Raises:
        requests.HTTPError: if the download returns an error status.
        requests.Timeout: if the download takes longer than *timeout* seconds.
    """
    if url.startswith("data:"):
        import base64
        import urllib.parse

        header, _, payload = url.partition(",")
        if header.endswith(";base64"):
            data = base64.b64decode(payload)
        else:
            data = urllib.parse.unquote_to_bytes(payload)
    else:
        response = requests.get(url, timeout=timeout)
        # Fail loudly on HTTP errors instead of handing an error page to PIL.
        response.raise_for_status()
        data = response.content
    return Image.open(BytesIO(data))


def generate_image(api_key, prompt, image_size, seed, sync_mode, num_images, enable_safety_checker, safety_tolerance):
    """Generate images with FLUX1.1 [pro] on fal.ai and return Gradio updates.

    Args:
        api_key: fal.ai API key; exported as ``FAL_KEY`` for ``fal_client``.
        prompt: Text prompt sent to the model.
        image_size: Image-size preset name (e.g. ``"landscape_4_3"``).
        seed: Optional seed as text; ``None``/blank/whitespace means no seed.
        sync_mode: Forwarded as ``sync_mode`` when not ``None``.
        num_images: Number of images to request (coerced to ``int``).
        enable_safety_checker: Forwarded as ``enable_safety_checker``.
        safety_tolerance: Forwarded as ``safety_tolerance``.

    Returns:
        A two-element list of ``gr.update`` objects: the first for the
        gallery (a list of PIL images), the second for the response text box.
        On any error the gallery is cleared and the text box shows the error.
    """
    try:
        # NOTE(review): FAL_KEY is process-global, so concurrent requests that
        # use different keys can race; consider a per-request client if the
        # installed fal_client version supports one.
        os.environ['FAL_KEY'] = api_key

        arguments = {
            "prompt": prompt,
            "image_size": image_size,
            # Gradio sliders may deliver floats (e.g. 2.0); the API wants an int.
            "num_images": int(num_images),
            "enable_safety_checker": enable_safety_checker,
            "safety_tolerance": safety_tolerance,
        }

        seed_text = "" if seed is None else str(seed).strip()
        if seed_text:
            arguments["seed"] = int(seed_text)

        if sync_mode is not None:
            arguments["sync_mode"] = sync_mode

        # Log the actual request body
        print(f"Request Body: {arguments}")

        handler = fal_client.submit(
            "fal-ai/flux-pro/v1.1",
            arguments=arguments,
        )
        result = handler.get()

        # Display and log the response
        print(f"Response: {result}")

        images = [_load_image(img_info['url']) for img_info in result['images']]
        return [gr.update(value=images, visible=True), gr.update(value=str(result), visible=True)]

    except FalClientError as e:
        error_msg = _format_fal_error(e)
        print(error_msg)
        return [gr.update(value=[]), gr.update(value=error_msg)]
    except Exception as e:
        error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)
        return [gr.update(value=[]), gr.update(value=error_msg)]

def update_safety_tolerance_visibility(enable_safety):
    """Show or hide the safety-tolerance dropdown to match the safety checker.

    When the checker is disabled, the dropdown is hidden and set to "6", the
    most permissive level. When the checker is re-enabled, the dropdown is
    shown again and reset to the default "2". Turning the checker back on
    therefore does not silently keep the most permissive setting.

    Args:
        enable_safety: Current value of the "Enable Safety Checker" checkbox.

    Returns:
        A ``gr.update`` for the safety-tolerance dropdown.
    """
    return gr.update(visible=enable_safety, value="2" if enable_safety else "6")

# Build the Gradio UI and wire the event handlers.
with gr.Blocks() as demo:
    gr.Markdown("# FLUX1.1 [pro] Text-to-Image Generator")
    gr.Markdown("Get your API key at https://fal.ai/dashboard/keys")

    with gr.Row():
        api_key = gr.Textbox(label="API Key", type="password", placeholder="Enter your API key here")
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=2, placeholder="Enter your prompt here")
    with gr.Row():
        image_size = gr.Dropdown(
            label="Image Size",
            choices=["square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"],
            value="landscape_4_3"
        )
        num_images = gr.Slider(label="Number of Images", minimum=1, maximum=4, step=1, value=1)
    with gr.Row():
        seed = gr.Textbox(label="Seed (optional)", placeholder="Enter a number for reproducible results")
        sync_mode = gr.Checkbox(label="Sync Mode", value=False)
    with gr.Row():
        enable_safety_checker = gr.Checkbox(label="Enable Safety Checker", value=True)
        # "6" must be a valid choice: the note below documents it, and
        # update_safety_tolerance_visibility assigns it when the checker is
        # disabled. Gradio rejects dropdown values that are not in `choices`.
        safety_tolerance = gr.Dropdown(
            label="Safety Tolerance",
            choices=["1", "2", "3", "4", "5", "6"],
            value="2",
            visible=True
        )
    gr.Markdown("**Note:** Safety Tolerance: 1 is the most strict, 6 is the most permissive. Default is 2.")

    generate_btn = gr.Button("Generate Image")
    output_gallery = gr.Gallery(label="Generated Images", columns=2, rows=2)
    response_output = gr.Textbox(label="Response", visible=True)

    # Keep the tolerance dropdown in sync with the safety-checker checkbox.
    enable_safety_checker.change(
        fn=update_safety_tolerance_visibility,
        inputs=[enable_safety_checker],
        outputs=[safety_tolerance]
    )

    # Input order must match generate_image's positional parameters.
    generate_btn.click(
        fn=generate_image,
        inputs=[api_key, prompt, image_size, seed, sync_mode, num_images, enable_safety_checker, safety_tolerance],
        outputs=[output_gallery, response_output]
    )

if __name__ == "__main__":
    demo.launch()