burhansyam commited on
Commit
26cbc6f
·
verified ·
1 Parent(s): de6733b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -102
app.py CHANGED
@@ -1,136 +1,75 @@
1
  import os
2
  import gradio as gr
3
- import base64
4
  from random import randint
5
  from all_models import models
6
- from io import BytesIO
7
- from PIL import Image
8
- from fastapi import FastAPI, Request
9
 
10
  css_code = os.getenv("DazDinGo_CSS")
11
 
12
- # Load models
13
- models_load = {}
14
- for model in models:
15
- try:
16
- models_load[model] = gr.load(f'models/{model}')
17
- except Exception as error:
18
- models_load[model] = gr.Interface(lambda txt: None, ['text'], ['image'])
 
 
 
19
 
20
- app = FastAPI()
21
 
22
- def gen_image(model_str, prompt):
 
 
 
 
 
 
 
 
 
 
23
  if model_str == 'NA':
24
  return None
25
  noise = str(randint(0, 4294967296))
26
  klir = '| ultra detail, ultra elaboration, ultra quality, perfect'
27
  return models_load[model_str](f'{prompt} {klir} {noise}')
28
 
29
- def image_to_base64(image):
30
- buffered = BytesIO()
31
- if isinstance(image, str): # if it's a file path
32
- img = Image.open(image)
33
- img.save(buffered, format="JPEG")
34
- else: # if it's a PIL Image
35
- image.save(buffered, format="JPEG")
36
- return base64.b64encode(buffered.getvalue()).decode()
37
-
38
- # API endpoint
39
- @app.post("/generate")
40
- async def api_generate(request: Request):
41
- data = await request.json()
42
- model = data.get('model', models[0])
43
- prompt = data.get('prompt', '')
44
-
45
- if model not in models:
46
- return {"error": "Model not found"}
47
-
48
- image = gen_image(model, prompt)
49
- if image is None:
50
- return {"error": "Image generation failed"}
51
-
52
- base64_str = image_to_base64(image)
53
-
54
- return {
55
- "status": "success",
56
- "model": model,
57
- "prompt": prompt,
58
- "image_base64": base64_str,
59
- "image_format": "jpeg"
60
- }
61
-
62
- # Gradio Interface
63
  def make_me():
64
  with gr.Row():
65
  with gr.Column(scale=4):
66
- txt_input = gr.Textbox(
67
- label='Your prompt:',
68
- lines=4,
69
- container=False,
70
- elem_id="custom_textbox",
71
- placeholder="Prompt"
72
- )
73
 
74
  with gr.Column(scale=1):
75
- gen_button = gr.Button('Generate image', elem_id="custom_gen_button")
76
- stop_button = gr.Button('Stop', variant='secondary', interactive=False,
77
- elem_id="custom_stop_button")
78
 
79
  def on_generate_click():
80
- return gr.Button('Generate image', elem_id="custom_gen_button"), gr.Button('Stop', variant='secondary', interactive=True, elem_id="custom_stop_button")
81
 
82
  def on_stop_click():
83
- return gr.Button('Generate image', elem_id="custom_gen_button"), gr.Button('Stop', variant='secondary', interactive=False, elem_id="custom_stop_button")
84
 
85
  gen_button.click(on_generate_click, inputs=None, outputs=[gen_button, stop_button])
86
  stop_button.click(on_stop_click, inputs=None, outputs=[gen_button, stop_button])
87
 
88
  with gr.Row():
89
- with gr.Column():
90
- model_dropdown = gr.Dropdown(models, label="Select Model",
91
- value=models[0] if models else None)
92
- output_image = gr.Image(
93
- label="Generated Image",
94
- width=512,
95
- height=768,
96
- elem_id="custom_image",
97
- show_label=True,
98
- interactive=False
99
- )
100
-
101
- # JSON output
102
- json_output = gr.JSON(label="API Response")
103
-
104
- def generate_wrapper(model_str, prompt):
105
- image = gen_image(model_str, prompt)
106
- if image is None:
107
- return None, {"error": "Generation failed"}
108
-
109
- base64_str = image_to_base64(image)
110
- response = {
111
- "status": "success",
112
- "model": model_str,
113
- "prompt": prompt,
114
- "image_base64": base64_str,
115
- "image_format": "jpeg"
116
- }
117
- return image, response
118
-
119
- gen_event = gen_button.click(generate_wrapper, [model_dropdown, txt_input],
120
- [output_image, json_output])
121
- stop_button.click(on_stop_click, inputs=None,
122
- outputs=[gen_button, stop_button], cancels=[gen_event])
123
 
124
- # Create Gradio app
125
  with gr.Blocks(css=css_code) as demo:
126
  make_me()
127
 
128
- # Enable queue before mounting
129
  demo.queue(concurrency_count=50)
130
-
131
- # Mount Gradio app to FastAPI
132
- app = gr.mount_gradio_app(app, demo, path="/")
133
-
134
- if __name__ == "__main__":
135
- import uvicorn
136
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import os
2
  import gradio as gr
 
3
  from random import randint
4
  from all_models import models
 
 
 
5
 
6
  css_code = os.getenv("DazDinGo_CSS")
7
 
8
+ def load_fn(models):
9
+ global models_load
10
+ models_load = {}
11
+ for model in models:
12
+ if model not in models_load.keys():
13
+ try:
14
+ m = gr.load(f'models/{model}')
15
+ except Exception as error:
16
+ m = gr.Interface(lambda txt: None, ['text'], ['image'])
17
+ models_load.update({model: m})
18
 
19
+ load_fn(models)
20
 
21
+ num_models = len(models)
22
+ default_models = models[:num_models]
23
+
24
+ def extend_choices(choices):
25
+ return choices + (num_models - len(choices)) * ['NA']
26
+
27
+ def update_imgbox(choices):
28
+ choices_plus = extend_choices(choices)
29
+ return [gr.Image(None, label=m, visible=(m != 'NA'), elem_id="custom_image") for m in choices_plus]
30
+
31
+ def gen_fn(model_str, prompt):
32
  if model_str == 'NA':
33
  return None
34
  noise = str(randint(0, 4294967296))
35
  klir = '| ultra detail, ultra elaboration, ultra quality, perfect'
36
  return models_load[model_str](f'{prompt} {klir} {noise}')
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def make_me():
39
  with gr.Row():
40
  with gr.Column(scale=4):
41
+ txt_input = gr.Textbox(label='Your prompt:', lines=4, container=False, elem_id="custom_textbox", placeholder="Prompt", height=250)
 
 
 
 
 
 
42
 
43
  with gr.Column(scale=1):
44
+ gen_button = gr.Button('Generate images', elem_id="custom_gen_button")
45
+ stop_button = gr.Button('Stop', variant='secondary', interactive=False, elem_id="custom_stop_button")
 
46
 
47
  def on_generate_click():
48
+ return gr.Button('Generate images', elem_id="custom_gen_button"), gr.Button('Stop', variant='secondary', interactive=True, elem_id="custom_stop_button")
49
 
50
  def on_stop_click():
51
+ return gr.Button('Generate images', elem_id="custom_gen_button"), gr.Button('Stop', variant='secondary', interactive=False, elem_id="custom_stop_button")
52
 
53
  gen_button.click(on_generate_click, inputs=None, outputs=[gen_button, stop_button])
54
  stop_button.click(on_stop_click, inputs=None, outputs=[gen_button, stop_button])
55
 
56
  with gr.Row():
57
+ output = [gr.Image(label=m, width=512, max_height=768, elem_id="custom_image", show_label=True, interactive=False, show_share_button=False) for m in default_models]
58
+ current_models = [gr.Textbox(m, visible=False) for m in default_models]
59
+ for m, o in zip(current_models, output):
60
+ gen_event = gen_button.click(gen_fn, [m, txt_input], o)
61
+ stop_button.click(on_stop_click, inputs=None, outputs=[gen_button, stop_button], cancels=[gen_event])
62
+
63
+ with gr.Accordion('Model selection', elem_id="custom_accordion"):
64
+ model_choice = gr.CheckboxGroup(models, label=f'{num_models} different models selected', value=default_models, interactive=True, elem_id="custom_checkbox_group")
65
+ model_choice.change(update_imgbox, model_choice, output)
66
+ model_choice.change(extend_choices, model_choice, current_models)
67
+
68
+ with gr.Row():
69
+ gr.HTML("")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
 
71
  with gr.Blocks(css=css_code) as demo:
72
  make_me()
73
 
 
74
  demo.queue(concurrency_count=50)
75
+ demo.launch()