charliebaby2023 commited on
Commit
53d6862
Β·
verified Β·
1 Parent(s): fb80c3c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -80
app.py CHANGED
@@ -1,126 +1,109 @@
1
  import gradio as gr
2
  from random import randint
3
- from all_models import models
4
  from datetime import datetime
5
 
6
- kii = "mohawk femboy racecar driver"
7
-
8
-
9
- model_group_selector = gr.Dropdown(
10
- choices=list(model_groups.keys()),
11
- label="Select Model Group",
12
- value="Group A", # Default group
13
- )
14
-
15
- # Update output based on the selected group
16
- def update_models(group_name):
17
- return model_groups[group_name] # Return the list of models for the selected group
18
-
19
- # Connect the dropdown to the update function
20
- model_group_selector.change(update_models, inputs=[model_group_selector], outputs=[...])
21
-
22
-
23
- def get_current_time():
24
- now = datetime.now()
25
- current_time = now.strftime("%Y-%m-%d %H:%M:%S")
26
- return f'{kii} {current_time}'
27
-
28
-
 
 
 
 
 
29
  def load_fn(models):
30
- models_load = {}
31
  for model in models:
32
- if model not in models_load:
33
- try:
34
- m = gr.load(f'models/{model}') # Adjust `gr.load` as needed
35
- except Exception as error:
36
- m = gr.Interface(lambda txt: None, ['text'], ['image'])
37
- models_load[model] = m
38
- return models_load
39
-
40
-
41
- models_load = load_fn(models)
42
-
43
- def extend_choices(choices, num_models):
44
- return choices + (num_models - len(choices)) * ['NA']
45
-
46
-
47
- def update_imgbox(choices, num_models):
48
- choices_plus = extend_choices(choices, num_models)
49
- return [gr.Image(None, label=m, visible=(m != 'NA')) for m in choices_plus]
50
 
 
51
 
 
52
  def gen_fn(model_str, prompt, tallies):
53
  if model_str == 'NA':
54
  return None, tallies
55
 
56
  noise = str(randint(0, 9999))
57
  combined_prompt = f'{prompt} {model_str} {noise}'
58
- print(f"Generating with prompt: {combined_prompt}") # Debug line
59
 
60
  try:
61
- result = models_load.get(model_str, lambda txt: None)(combined_prompt)
62
- if result is not None:
63
- tallies[model_str] += 1
64
  return result, tallies
65
  except Exception as e:
66
  print(f"Error generating for {model_str}: {e}")
67
  return None, tallies
68
 
69
-
70
  def make_me():
71
  with gr.Row():
72
- # Input elements
73
  model_group_selector = gr.Dropdown(
74
- choices=list(model_groups.keys()), label="Select Model Group", value="Group A"
 
 
75
  )
76
- txt_input = gr.Textbox(label='Your prompt:', lines=2, value=kii)
77
  gen_button = gr.Button('Generate images', elem_id="generate-btn")
78
- stop_button = gr.Button('Stop', variant='secondary', interactive=False)
79
-
80
- gr.HTML("""
81
- <div style="text-align: center; max-width: 100%; margin: 0 auto;">
82
- <body></body>
83
- </div>
84
- """)
85
 
86
  with gr.Row():
87
- # Output elements
88
- output = gr.State([])
89
- tally_boxes = gr.State({})
90
- output_display = gr.Column()
91
- with output_display:
92
- result_images = []
93
- tally_counters = []
94
 
95
  def update_outputs(group_name):
96
  selected_models = model_groups[group_name]
97
- result_images.clear()
98
- tally_counters.clear()
99
- for model in selected_models:
100
- result_images.append(gr.Image(label=model, width=170, height=170))
101
- tally_counters.append(gr.Textbox(value="0", label=f"Tally for {model}", interactive=False))
102
- return result_images, tally_counters, {model: 0 for model in selected_models}
103
 
104
  model_group_selector.change(
105
- update_outputs, [model_group_selector], [output, tally_boxes]
 
 
106
  )
107
 
108
  def generate_images(prompt, outputs, tallies):
109
- for idx, model_element in enumerate(outputs):
110
  model_str = list(tallies.keys())[idx]
111
  result, tallies = gen_fn(model_str, prompt, tallies)
112
- model_element.update(value=result)
113
- for idx, tally_box in enumerate(tally_counters):
114
- tally_box.update(value=str(tallies[list(tallies.keys())[idx]]))
115
  return tallies
116
 
117
  gen_button.click(
118
  generate_images,
119
- inputs=[txt_input, output, tally_boxes],
120
- outputs=[tally_boxes],
121
  )
122
 
123
-
124
  js_code = """
125
  <script>
126
  const originalScroll = window.scrollTo;
@@ -136,7 +119,6 @@ js_code = """
136
  """
137
 
138
  with gr.Blocks() as demo:
139
- gr.Markdown("<div></div>")
140
  make_me()
141
  gr.Markdown(js_code)
142
 
 
1
  import gradio as gr
2
  from random import randint
 
3
  from datetime import datetime
4
 
5
+ # Define manual grouping for models
6
+ model_groups = {
7
+ "Group A": [
8
+ "Bakanayatsu/ponyDiffusion-V6-XL-Turbo-DPO",
9
+ "John6666/photo-realistic-pony-v5-sdxl",
10
+ "John6666/photo-realistic-pony-v5-sdxl",
11
+ ],
12
+ "Group B": [
13
+ "John6666/jib-mix-pony-realistic-v2-sdxl",
14
+ "John6666/3x3x3mixxl-v2-sdxl-spo",
15
+ "John6666/3x3x3mixxl-v2-sdxl",
16
+ ],
17
+ "Group C": [
18
+ "John6666/3x3x3mixxl-v2-sdxl-spo",
19
+ "John6666/3x3x3mixxl-v2-sdxl",
20
+ "John6666/titania-mix-realistic-pony-gbv10-sdxl",
21
+ ],
22
+ "Group D": [
23
+ "John6666/titania-mix-realistic-pony-gbv20-sdxl",
24
+ "John6666/titania-mix-realistic-pony-gbv30-sdxl",
25
+ "John6666/mala-anime-mix-nsfw-pony-xl-v3-sdxl",
26
+ ],
27
+ }
28
+
29
+ # Placeholder for models
30
+ models_load = {model: None for group in model_groups.values() for model in group}
31
+
32
+ # Function to simulate model loading
33
  def load_fn(models):
34
+ loaded_models = {}
35
  for model in models:
36
+ try:
37
+ # Simulate model loading
38
+ loaded_models[model] = lambda txt: f"Generated image for {txt}"
39
+ except Exception as error:
40
+ loaded_models[model] = lambda txt: None
41
+ return loaded_models
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ models_load = load_fn(models_load.keys())
44
 
45
+ # Function to simulate generation with tally tracking
46
  def gen_fn(model_str, prompt, tallies):
47
  if model_str == 'NA':
48
  return None, tallies
49
 
50
  noise = str(randint(0, 9999))
51
  combined_prompt = f'{prompt} {model_str} {noise}'
52
+ print(f"Generating with prompt: {combined_prompt}")
53
 
54
  try:
55
+ result = models_load[model_str](combined_prompt)
56
+ tallies[model_str] += 1
 
57
  return result, tallies
58
  except Exception as e:
59
  print(f"Error generating for {model_str}: {e}")
60
  return None, tallies
61
 
62
+ # Build the Gradio interface
63
  def make_me():
64
  with gr.Row():
 
65
  model_group_selector = gr.Dropdown(
66
+ choices=list(model_groups.keys()),
67
+ label="Select Model Group",
68
+ value="Group A", # Default group
69
  )
70
+ txt_input = gr.Textbox(label='Your prompt:', lines=2, value="mohawk femboy racecar driver")
71
  gen_button = gr.Button('Generate images', elem_id="generate-btn")
 
 
 
 
 
 
 
72
 
73
  with gr.Row():
74
+ output = gr.State([]) # Placeholder for output elements
75
+ tallies = gr.State({}) # Tally counter for each model
76
+ result_images = gr.Column()
77
+ tally_counters = gr.Column()
 
 
 
78
 
79
  def update_outputs(group_name):
80
  selected_models = model_groups[group_name]
81
+ outputs = [gr.Image(label=model, visible=True, width=170, height=170) for model in selected_models]
82
+ tallies_dict = {model: 0 for model in selected_models}
83
+ tally_boxes = [gr.Textbox(value="0", label=f"Tally for {model}", interactive=False) for model in selected_models]
84
+ return outputs, tally_boxes, tallies_dict
 
 
85
 
86
  model_group_selector.change(
87
+ update_outputs,
88
+ inputs=[model_group_selector],
89
+ outputs=[result_images, tally_counters, tallies],
90
  )
91
 
92
  def generate_images(prompt, outputs, tallies):
93
+ for idx, model_output in enumerate(outputs):
94
  model_str = list(tallies.keys())[idx]
95
  result, tallies = gen_fn(model_str, prompt, tallies)
96
+ model_output.update(value=result)
97
+ for idx, tally_box in enumerate(tallies.values()):
98
+ tally_counters[idx].update(value=str(tally_box))
99
  return tallies
100
 
101
  gen_button.click(
102
  generate_images,
103
+ inputs=[txt_input, output, tallies],
104
+ outputs=[tallies],
105
  )
106
 
 
107
  js_code = """
108
  <script>
109
  const originalScroll = window.scrollTo;
 
119
  """
120
 
121
  with gr.Blocks() as demo:
 
122
  make_me()
123
  gr.Markdown(js_code)
124