burhansyam commited on
Commit
223039f
·
verified ·
1 Parent(s): 058cd31

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -32
app.py CHANGED
@@ -5,41 +5,62 @@ from random import randint
5
  from all_models import models
6
  from io import BytesIO
7
  from PIL import Image
 
 
8
 
9
  css_code = os.getenv("DazDinGo_CSS")
10
 
11
- def load_fn(models):
12
- global models_load
13
- models_load = {}
14
- for model in models:
15
- if model not in models_load.keys():
16
- try:
17
- m = gr.load(f'models/{model}')
18
- except Exception as error:
19
- m = gr.Interface(lambda txt: None, ['text'], ['image'])
20
- models_load.update({model: m})
21
 
22
- load_fn(models)
23
 
24
- def gen_fn(model_str, prompt):
25
  if model_str == 'NA':
26
- return None, ""
27
  noise = str(randint(0, 4294967296))
28
  klir = '| ultra detail, ultra elaboration, ultra quality, perfect'
29
- image = models_load[model_str](f'{prompt} {klir} {noise}')
30
-
31
- # Convert to base64
32
  buffered = BytesIO()
33
  if isinstance(image, str): # if it's a file path
34
  img = Image.open(image)
35
  img.save(buffered, format="JPEG")
36
  else: # if it's a PIL Image
37
  image.save(buffered, format="JPEG")
38
- img_str = base64.b64encode(buffered.getvalue()).decode()
39
- base64_code = f"data:image/jpeg;base64,{img_str}"
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- return image, base64_code
 
 
 
 
 
 
 
 
42
 
 
43
  def make_me():
44
  with gr.Row():
45
  with gr.Column(scale=4):
@@ -66,24 +87,38 @@ def make_me():
66
  value=models[0] if models else None)
67
  output_image = gr.Image(label="Generated Image", width=512,
68
  max_height=768, elem_id="custom_image",
69
- show_label=True, interactive=False,
70
- show_share_button=False)
 
 
71
 
72
- # Base64 output
73
- base64_output = gr.Textbox(label="Base64 Code", interactive=True,
74
- placeholder="Base64 code will appear here after generation",
75
- lines=4)
 
 
 
 
 
 
 
 
 
 
76
 
77
- gen_event = gen_button.click(gen_fn, [model_dropdown, txt_input],
78
- [output_image, base64_output])
79
  stop_button.click(on_stop_click, inputs=None,
80
  outputs=[gen_button, stop_button], cancels=[gen_event])
81
-
82
- with gr.Row():
83
- gr.HTML("")
84
 
 
85
  with gr.Blocks(css=css_code) as demo:
86
  make_me()
87
 
88
- demo.queue(concurrency_count=50)
89
- demo.launch()
 
 
 
 
 
5
  from all_models import models
6
  from io import BytesIO
7
  from PIL import Image
8
+ from fastapi import FastAPI, Request
9
+ import json
10
 
11
  css_code = os.getenv("DazDinGo_CSS")
12
 
13
+ # Load models
14
+ models_load = {}
15
+ for model in models:
16
+ try:
17
+ models_load[model] = gr.load(f'models/{model}')
18
+ except Exception as error:
19
+ models_load[model] = gr.Interface(lambda txt: None, ['text'], ['image'])
 
 
 
20
 
21
+ app = FastAPI()
22
 
23
+ def gen_image(model_str, prompt):
24
  if model_str == 'NA':
25
+ return None
26
  noise = str(randint(0, 4294967296))
27
  klir = '| ultra detail, ultra elaboration, ultra quality, perfect'
28
+ return models_load[model_str](f'{prompt} {klir} {noise}')
29
+
30
+ def image_to_base64(image):
31
  buffered = BytesIO()
32
  if isinstance(image, str): # if it's a file path
33
  img = Image.open(image)
34
  img.save(buffered, format="JPEG")
35
  else: # if it's a PIL Image
36
  image.save(buffered, format="JPEG")
37
+ return base64.b64encode(buffered.getvalue()).decode()
38
+
39
+ # API endpoint
40
+ @app.post("/generate")
41
+ async def api_generate(request: Request):
42
+ data = await request.json()
43
+ model = data.get('model', models[0])
44
+ prompt = data.get('prompt', '')
45
+
46
+ if model not in models:
47
+ return {"error": "Model not found"}
48
+
49
+ image = gen_image(model, prompt)
50
+ if image is None:
51
+ return {"error": "Image generation failed"}
52
 
53
+ base64_str = image_to_base64(image)
54
+
55
+ return {
56
+ "status": "success",
57
+ "model": model,
58
+ "prompt": prompt,
59
+ "image_base64": base64_str,
60
+ "image_format": "jpeg"
61
+ }
62
 
63
+ # Gradio Interface
64
  def make_me():
65
  with gr.Row():
66
  with gr.Column(scale=4):
 
87
  value=models[0] if models else None)
88
  output_image = gr.Image(label="Generated Image", width=512,
89
  max_height=768, elem_id="custom_image",
90
+ show_label=True, interactive=False)
91
+
92
+ # JSON output
93
+ json_output = gr.JSON(label="API Response")
94
 
95
+ def generate_wrapper(model_str, prompt):
96
+ image = gen_image(model_str, prompt)
97
+ if image is None:
98
+ return None, {"error": "Generation failed"}
99
+
100
+ base64_str = image_to_base64(image)
101
+ response = {
102
+ "status": "success",
103
+ "model": model_str,
104
+ "prompt": prompt,
105
+ "image_base64": base64_str,
106
+ "image_format": "jpeg"
107
+ }
108
+ return image, response
109
 
110
+ gen_event = gen_button.click(generate_wrapper, [model_dropdown, txt_input],
111
+ [output_image, json_output])
112
  stop_button.click(on_stop_click, inputs=None,
113
  outputs=[gen_button, stop_button], cancels=[gen_event])
 
 
 
114
 
115
+ # Create Gradio app
116
  with gr.Blocks(css=css_code) as demo:
117
  make_me()
118
 
119
+ # Mount Gradio app to FastAPI
120
+ app = gr.mount_gradio_app(app, demo, path="/")
121
+
122
+ if __name__ == "__main__":
123
+ import uvicorn
124
+ uvicorn.run(app, host="0.0.0.0", port=7860)