lopho commited on
Commit
5b09d17
·
1 Parent(s): a1b3de8

remove trigger check, as it messes up the time estimate

Browse files
Files changed (2) hide show
  1. app.py +14 -23
  2. requirements.txt +1 -1
app.py CHANGED
@@ -89,6 +89,8 @@ def generate(
89
  def check_if_compiled(image, inference_steps, height, width, num_frames, message):
90
  height = int(height)
91
  width = int(width)
 
 
92
  hint_image = image
93
  if (hint_image is None, inference_steps, height, width, num_frames) in _seen_compilations:
94
  return ''
@@ -100,7 +102,7 @@ if _preheat:
100
  generate(
101
  prompt = 'preheating the oven',
102
  neg_prompt = '',
103
- image = { 'image': None, 'mask': None },
104
  inference_steps = 20,
105
  cfg = 12.0,
106
  seed = 0
@@ -109,7 +111,7 @@ if _preheat:
109
  dada = generate(
110
  prompt = 'Entertaining the guests with sailor songs played on an old harmonium.',
111
  neg_prompt = '',
112
- image = { 'image': Image.new('RGB', size = (512, 512), color = (0, 0, 0)), 'mask': None },
113
  inference_steps = 20,
114
  cfg = 12.0,
115
  seed = 0
@@ -227,29 +229,22 @@ with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled =
227
  )
228
  with gr.Column(variant = variant):
229
  #no_gpu = gr.Markdown('**Until a GPU is assigned expect extremely long runtimes up to 1h+**')
230
- will_trigger = gr.Markdown('')
231
- patience = gr.Markdown('')
232
  image_output = gr.Image(
233
  label = 'Output',
234
  value = 'example.webp',
235
  interactive = False
236
  )
237
- trigger_inputs = [ image_input, inference_steps_input, height_input, width_input, num_frames_input ]
238
- trigger_check_fun = partial(check_if_compiled, message = 'Current parameters will trigger compilation.')
239
- height_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
240
- width_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
241
- num_frames_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
242
- image_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
243
- inference_steps_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
244
- will_trigger.value = trigger_check_fun(image_input.value, inference_steps_input.value, height_input.value, width_input.value, num_frames_input.value)
245
  ev = submit_button.click(
246
- fn = partial(
247
- check_if_compiled,
248
- message = 'Please be patient. The model has to be compiled with current parameters.'
249
- ),
250
- inputs = trigger_inputs,
251
- outputs = patience
252
- ).then(
253
  fn = generate,
254
  inputs = [
255
  prompt_input,
@@ -265,10 +260,6 @@ with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled =
265
  ],
266
  outputs = image_output,
267
  postprocess = False
268
- ).then(
269
- fn = trigger_check_fun,
270
- inputs = trigger_inputs,
271
- outputs = will_trigger
272
  )
273
  #cancel_button.click(fn = lambda: None, cancels = ev)
274
 
 
89
  def check_if_compiled(image, inference_steps, height, width, num_frames, message):
90
  height = int(height)
91
  width = int(width)
92
+ height = (height // 64) * 64
93
+ width = (width // 64) * 64
94
  hint_image = image
95
  if (hint_image is None, inference_steps, height, width, num_frames) in _seen_compilations:
96
  return ''
 
102
  generate(
103
  prompt = 'preheating the oven',
104
  neg_prompt = '',
105
+ image = None,
106
  inference_steps = 20,
107
  cfg = 12.0,
108
  seed = 0
 
111
  dada = generate(
112
  prompt = 'Entertaining the guests with sailor songs played on an old harmonium.',
113
  neg_prompt = '',
114
+ image = Image.new('RGB', size = (512, 512), color = (0, 0, 0)),
115
  inference_steps = 20,
116
  cfg = 12.0,
117
  seed = 0
 
229
  )
230
  with gr.Column(variant = variant):
231
  #no_gpu = gr.Markdown('**Until a GPU is assigned expect extremely long runtimes up to 1h+**')
232
+ #will_trigger = gr.Markdown('')
233
+ patience = gr.Markdown('**Please be patient. The model might have to compile with current parameters.**')
234
  image_output = gr.Image(
235
  label = 'Output',
236
  value = 'example.webp',
237
  interactive = False
238
  )
239
+ #trigger_inputs = [ image_input, inference_steps_input, height_input, width_input, num_frames_input ]
240
+ #trigger_check_fun = partial(check_if_compiled, message = 'Current parameters will trigger compilation.')
241
+ #height_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
242
+ #width_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
243
+ #num_frames_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
244
+ #image_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
245
+ #inference_steps_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
246
+ #will_trigger.value = trigger_check_fun(image_input.value, inference_steps_input.value, height_input.value, width_input.value, num_frames_input.value)
247
  ev = submit_button.click(
 
 
 
 
 
 
 
248
  fn = generate,
249
  inputs = [
250
  prompt_input,
 
260
  ],
261
  outputs = image_output,
262
  postprocess = False
 
 
 
 
263
  )
264
  #cancel_button.click(fn = lambda: None, cancels = ev)
265
 
requirements.txt CHANGED
@@ -6,5 +6,5 @@ einops
6
  -f https://download.pytorch.org/whl/cpu/torch
7
  torch[cpu]
8
  -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
9
- jax[cuda11_cudnn82] #jax[cuda11_cudnn86] #jax[cuda11_cudnn805]
10
  flax
 
6
  -f https://download.pytorch.org/whl/cpu/torch
7
  torch[cpu]
8
  -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
9
+ jax[cuda11_pip] #jax[cuda11_cudnn82] #jax[cuda11_cudnn86] #jax[cuda11_cudnn805]
10
  flax