sumityadav329 commited on
Commit
114c45d
·
verified ·
1 Parent(s): ada9c6c

updated app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -108
app.py CHANGED
@@ -1,122 +1,111 @@
 
1
  import gradio as gr
2
- import numpy as np
3
- import random
4
- import spaces
5
- import torch
6
- from diffusers import DiffusionPipeline
7
 
8
- dtype = torch.bfloat16
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
-
11
- pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
12
-
13
- MAX_SEED = np.iinfo(np.int32).max
14
- MAX_IMAGE_SIZE = 2048
15
-
16
- @spaces.GPU()
17
- def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
18
- if randomize_seed:
19
- seed = random.randint(0, MAX_SEED)
20
- generator = torch.Generator().manual_seed(seed)
21
- image = pipe(
22
- prompt = prompt,
23
- width = width,
24
- height = height,
25
- num_inference_steps = num_inference_steps,
26
- generator = generator,
27
- guidance_scale=0.0
28
- ).images[0]
29
- return image, seed
30
-
31
- examples = [
32
- "a tiny astronaut hatching from an egg on the moon",
33
- "a cat holding a sign that says hello world",
34
- "an anime illustration of a wiener schnitzel",
35
- ]
36
-
37
- css="""
38
- #col-container {
39
- margin: 0 auto;
40
- max-width: 520px;
41
- }
42
- """
43
-
44
- with gr.Blocks(css=css) as demo:
45
 
46
- with gr.Column(elem_id="col-container"):
47
- gr.Markdown(f"""# FLUX.1 [schnell]
48
- 12B param rectified flow transformer distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/) for 4 step generation
49
- [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-schnell)]
50
- """)
 
51
 
52
- with gr.Row():
53
-
54
- prompt = gr.Text(
55
- label="Prompt",
56
- show_label=False,
57
- max_lines=1,
58
- placeholder="Enter your prompt",
59
- container=False,
60
- )
61
-
62
- run_button = gr.Button("Run", scale=0)
63
 
64
- result = gr.Image(label="Result", show_label=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
- with gr.Accordion("Advanced Settings", open=False):
67
-
68
- seed = gr.Slider(
69
- label="Seed",
70
- minimum=0,
71
- maximum=MAX_SEED,
72
- step=1,
73
- value=0,
74
- )
75
-
76
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
77
-
78
- with gr.Row():
79
-
80
- width = gr.Slider(
81
- label="Width",
82
- minimum=256,
83
- maximum=MAX_IMAGE_SIZE,
84
- step=32,
85
- value=1024,
86
  )
87
 
88
- height = gr.Slider(
89
- label="Height",
90
- minimum=256,
91
- maximum=MAX_IMAGE_SIZE,
92
- step=32,
93
- value=1024,
94
- )
95
-
96
- with gr.Row():
 
 
 
 
 
 
 
97
 
98
-
99
- num_inference_steps = gr.Slider(
100
- label="Number of inference steps",
101
- minimum=1,
102
- maximum=50,
103
- step=1,
104
- value=4,
 
 
105
  )
106
 
107
- gr.Examples(
108
- examples = examples,
109
- fn = infer,
110
- inputs = [prompt],
111
- outputs = [result, seed],
112
- cache_examples="lazy"
 
 
 
113
  )
 
 
114
 
115
- gr.on(
116
- triggers=[run_button.click, prompt.submit],
117
- fn = infer,
118
- inputs = [prompt, seed, randomize_seed, width, height, num_inference_steps],
119
- outputs = [result, seed]
120
- )
 
 
 
 
 
 
 
121
 
122
- demo.launch()
 
 
1
+ import os
2
  import gradio as gr
3
+ from PIL import Image
4
+ import io
5
+ from utils import query_hf_api
 
 
6
 
7
+ def generate_image(prompt: str) -> Image.Image:
8
+ """
9
+ Generate an image from a text prompt.
10
+
11
+ Args:
12
+ prompt (str): Text description for image generation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ Returns:
15
+ Image.Image: Generated PIL Image
16
+ """
17
+ try:
18
+ # Generate image bytes
19
+ image_bytes = query_hf_api(prompt)
20
 
21
+ # Convert to PIL Image
22
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
 
 
 
 
 
 
 
 
 
23
 
24
+ return image
25
+
26
+ except Exception as e:
27
+ print(f"Image generation error: {e}")
28
+ return None
29
+
30
+ def create_gradio_interface():
31
+ """
32
+ Create and configure the Gradio interface.
33
+
34
+ Returns:
35
+ gr.Blocks: Configured Gradio interface
36
+ """
37
+ with gr.Blocks(
38
+ theme=gr.themes.Soft(),
39
+ title="🎨 AI Image Generator"
40
+ ) as demo:
41
+ # Title and Description
42
+ gr.Markdown("# 🎨 AI Image Generator")
43
+ gr.Markdown("Generate stunning images from your text prompts using AI!")
44
 
45
+ # Input and Output Components
46
+ with gr.Row():
47
+ with gr.Column(scale=3):
48
+ # Prompt Input
49
+ text_input = gr.Textbox(
50
+ label="Enter your image prompt",
51
+ placeholder="e.g., 'Astronaut riding a bike on Mars at sunset'",
52
+ lines=3
 
 
 
 
 
 
 
 
 
 
 
 
53
  )
54
 
55
+ # Advanced Options
56
+ with gr.Accordion("Advanced Options", open=False):
57
+ steps_slider = gr.Slider(
58
+ minimum=10,
59
+ maximum=100,
60
+ value=50,
61
+ step=1,
62
+ label="Inference Steps"
63
+ )
64
+ guidance_slider = gr.Slider(
65
+ minimum=1,
66
+ maximum=20,
67
+ value=7.5,
68
+ step=0.5,
69
+ label="Guidance Scale"
70
+ )
71
 
72
+ # Generate Button
73
+ generate_button = gr.Button("✨ Generate Image", variant="primary")
74
+
75
+ # Output Image Display
76
+ with gr.Column(scale=4):
77
+ output_image = gr.Image(
78
+ label="Generated Image",
79
+ type="pil",
80
+ interactive=False
81
  )
82
 
83
+ # Error Handling Output
84
+ error_output = gr.Textbox(label="Status", visible=False)
85
+
86
+ # Event Handlers
87
+ generate_button.click(
88
+ fn=generate_image,
89
+ inputs=[text_input],
90
+ outputs=[output_image, error_output],
91
+ api_name="generate"
92
  )
93
+
94
+ return demo
95
 
96
+ def main():
97
+ """
98
+ Main entry point for the Gradio application.
99
+ """
100
+ try:
101
+ demo = create_gradio_interface()
102
+ demo.launch(
103
+ server_name="0.0.0.0", # Listen on all network interfaces
104
+ server_port=7860, # Default Gradio port
105
+ share=True # Set to True if you want a public link
106
+ )
107
+ except Exception as e:
108
+ print(f"Error launching Gradio app: {e}")
109
 
110
+ if __name__ == "__main__":
111
+ main()