burhansyam commited on
Commit
cf0796f
·
verified ·
1 Parent(s): 26cbc6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -41
app.py CHANGED
@@ -1,75 +1,136 @@
1
  import os
2
  import gradio as gr
 
3
  from random import randint
4
  from all_models import models
 
 
 
5
 
6
  css_code = os.getenv("DazDinGo_CSS")
7
 
8
- def load_fn(models):
9
- global models_load
10
- models_load = {}
11
- for model in models:
12
- if model not in models_load.keys():
13
- try:
14
- m = gr.load(f'models/{model}')
15
- except Exception as error:
16
- m = gr.Interface(lambda txt: None, ['text'], ['image'])
17
- models_load.update({model: m})
18
 
19
- load_fn(models)
20
 
21
- num_models = len(models)
22
- default_models = models[:num_models]
23
-
24
- def extend_choices(choices):
25
- return choices + (num_models - len(choices)) * ['NA']
26
-
27
- def update_imgbox(choices):
28
- choices_plus = extend_choices(choices)
29
- return [gr.Image(None, label=m, visible=(m != 'NA'), elem_id="custom_image") for m in choices_plus]
30
-
31
- def gen_fn(model_str, prompt):
32
  if model_str == 'NA':
33
  return None
34
  noise = str(randint(0, 4294967296))
35
  klir = '| ultra detail, ultra elaboration, ultra quality, perfect'
36
  return models_load[model_str](f'{prompt} {klir} {noise}')
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def make_me():
39
  with gr.Row():
40
  with gr.Column(scale=4):
41
- txt_input = gr.Textbox(label='Your prompt:', lines=4, container=False, elem_id="custom_textbox", placeholder="Prompt", height=250)
 
 
 
 
 
 
42
 
43
  with gr.Column(scale=1):
44
- gen_button = gr.Button('Generate images', elem_id="custom_gen_button")
45
- stop_button = gr.Button('Stop', variant='secondary', interactive=False, elem_id="custom_stop_button")
 
46
 
47
  def on_generate_click():
48
- return gr.Button('Generate images', elem_id="custom_gen_button"), gr.Button('Stop', variant='secondary', interactive=True, elem_id="custom_stop_button")
49
 
50
  def on_stop_click():
51
- return gr.Button('Generate images', elem_id="custom_gen_button"), gr.Button('Stop', variant='secondary', interactive=False, elem_id="custom_stop_button")
52
 
53
  gen_button.click(on_generate_click, inputs=None, outputs=[gen_button, stop_button])
54
  stop_button.click(on_stop_click, inputs=None, outputs=[gen_button, stop_button])
55
 
56
  with gr.Row():
57
- output = [gr.Image(label=m, width=512, max_height=768, elem_id="custom_image", show_label=True, interactive=False, show_share_button=False) for m in default_models]
58
- current_models = [gr.Textbox(m, visible=False) for m in default_models]
59
- for m, o in zip(current_models, output):
60
- gen_event = gen_button.click(gen_fn, [m, txt_input], o)
61
- stop_button.click(on_stop_click, inputs=None, outputs=[gen_button, stop_button], cancels=[gen_event])
62
-
63
- with gr.Accordion('Model selection', elem_id="custom_accordion"):
64
- model_choice = gr.CheckboxGroup(models, label=f'{num_models} different models selected', value=default_models, interactive=True, elem_id="custom_checkbox_group")
65
- model_choice.change(update_imgbox, model_choice, output)
66
- model_choice.change(extend_choices, model_choice, current_models)
67
-
68
- with gr.Row():
69
- gr.HTML("")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
 
71
  with gr.Blocks(css=css_code) as demo:
72
  make_me()
73
 
 
74
  demo.queue(concurrency_count=50)
75
- demo.launch()
 
 
 
 
 
 
 
1
  import os
2
  import gradio as gr
3
+ import base64
4
  from random import randint
5
  from all_models import models
6
+ from io import BytesIO
7
+ from PIL import Image
8
+ from fastapi import FastAPI, Request
9
 
10
  css_code = os.getenv("DazDinGo_CSS")
11
 
12
+ # Load models
13
+ models_load = {}
14
+ for model in models:
15
+ try:
16
+ models_load[model] = gr.load(f'models/{model}')
17
+ except Exception as error:
18
+ models_load[model] = gr.Interface(lambda txt: None, ['text'], ['image'])
 
 
 
19
 
20
+ app = FastAPI()
21
 
22
+ def gen_image(model_str, prompt):
 
 
 
 
 
 
 
 
 
 
23
  if model_str == 'NA':
24
  return None
25
  noise = str(randint(0, 4294967296))
26
  klir = '| ultra detail, ultra elaboration, ultra quality, perfect'
27
  return models_load[model_str](f'{prompt} {klir} {noise}')
28
 
29
+ def image_to_base64(image):
30
+ buffered = BytesIO()
31
+ if isinstance(image, str): # if it's a file path
32
+ img = Image.open(image)
33
+ img.save(buffered, format="JPEG")
34
+ else: # if it's a PIL Image
35
+ image.save(buffered, format="JPEG")
36
+ return base64.b64encode(buffered.getvalue()).decode()
37
+
38
+ # API endpoint
39
+ @app.post("/generate")
40
+ async def api_generate(request: Request):
41
+ data = await request.json()
42
+ model = data.get('model', models[0])
43
+ prompt = data.get('prompt', '')
44
+
45
+ if model not in models:
46
+ return {"error": "Model not found"}
47
+
48
+ image = gen_image(model, prompt)
49
+ if image is None:
50
+ return {"error": "Image generation failed"}
51
+
52
+ base64_str = image_to_base64(image)
53
+
54
+ return {
55
+ "status": "success",
56
+ "model": model,
57
+ "prompt": prompt,
58
+ "image_base64": base64_str,
59
+ "image_format": "jpeg"
60
+ }
61
+
62
+ # Gradio Interface
63
  def make_me():
64
  with gr.Row():
65
  with gr.Column(scale=4):
66
+ txt_input = gr.Textbox(
67
+ label='Your prompt:',
68
+ lines=4,
69
+ container=False,
70
+ elem_id="custom_textbox",
71
+ placeholder="Prompt"
72
+ )
73
 
74
  with gr.Column(scale=1):
75
+ gen_button = gr.Button('Generate image', elem_id="custom_gen_button")
76
+ stop_button = gr.Button('Stop', variant='secondary', interactive=False,
77
+ elem_id="custom_stop_button")
78
 
79
  def on_generate_click():
80
+ return gr.Button('Generate image', elem_id="custom_gen_button"), gr.Button('Stop', variant='secondary', interactive=True, elem_id="custom_stop_button")
81
 
82
  def on_stop_click():
83
+ return gr.Button('Generate image', elem_id="custom_gen_button"), gr.Button('Stop', variant='secondary', interactive=False, elem_id="custom_stop_button")
84
 
85
  gen_button.click(on_generate_click, inputs=None, outputs=[gen_button, stop_button])
86
  stop_button.click(on_stop_click, inputs=None, outputs=[gen_button, stop_button])
87
 
88
  with gr.Row():
89
+ with gr.Column():
90
+ model_dropdown = gr.Dropdown(models, label="Select Model",
91
+ value=models[0] if models else None)
92
+ output_image = gr.Image(
93
+ label="Generated Image",
94
+ width=512,
95
+ height=768,
96
+ elem_id="custom_image",
97
+ show_label=True,
98
+ interactive=False
99
+ )
100
+
101
+ # JSON output
102
+ json_output = gr.JSON(label="API Response")
103
+
104
+ def generate_wrapper(model_str, prompt):
105
+ image = gen_image(model_str, prompt)
106
+ if image is None:
107
+ return None, {"error": "Generation failed"}
108
+
109
+ base64_str = image_to_base64(image)
110
+ response = {
111
+ "status": "success",
112
+ "model": model_str,
113
+ "prompt": prompt,
114
+ "image_base64": base64_str,
115
+ "image_format": "jpeg"
116
+ }
117
+ return image, response
118
+
119
+ gen_event = gen_button.click(generate_wrapper, [model_dropdown, txt_input],
120
+ [output_image, json_output])
121
+ stop_button.click(on_stop_click, inputs=None,
122
+ outputs=[gen_button, stop_button], cancels=[gen_event])
123
 
124
+ # Create Gradio app
125
  with gr.Blocks(css=css_code) as demo:
126
  make_me()
127
 
128
+ # Enable queue before mounting
129
  demo.queue(concurrency_count=50)
130
+
131
+ # Mount Gradio app to FastAPI
132
+ app = gr.mount_gradio_app(app, demo, path="/")
133
+
134
+ if __name__ == "__main__":
135
+ import uvicorn
136
+ uvicorn.run(app, host="0.0.0.0", port=7860)