Commit
·
3a45ac7
1
Parent(s):
a9b3bf8
add nparams count
Browse files
app.py
CHANGED
|
@@ -50,20 +50,25 @@ if __name__ == "__main__":
|
|
| 50 |
with gr.Tab("Example Prompts"):
|
| 51 |
examples = gr.Examples(examples=example_list, inputs=[text])
|
| 52 |
|
| 53 |
-
with gr.Column(variant='panel',
|
| 54 |
# Define original model output components
|
| 55 |
gr.Markdown('<h2 align="center">Original Stable Diffusion 1.4</h2>')
|
| 56 |
original_model_output = gr.Image(label="Original Model")
|
| 57 |
with gr.Row().style(equal_height=True):
|
| 58 |
-
|
|
|
|
|
|
|
| 59 |
original_model_error = gr.Markdown()
|
|
|
|
| 60 |
|
| 61 |
-
with gr.Column(variant='panel',
|
| 62 |
# Define compressed model output components
|
| 63 |
gr.Markdown('<h2 align="center">Compressed Stable Diffusion (Ours)</h2>')
|
| 64 |
-
compressed_model_output = gr.Image(label="Compressed Model")
|
| 65 |
with gr.Row().style(equal_height=True):
|
| 66 |
-
|
|
|
|
|
|
|
| 67 |
compressed_model_error = gr.Markdown()
|
| 68 |
|
| 69 |
inputs = [text, negative, guidance_scale, steps, seed]
|
|
|
|
| 50 |
with gr.Tab("Example Prompts"):
|
| 51 |
examples = gr.Examples(examples=example_list, inputs=[text])
|
| 52 |
|
| 53 |
+
with gr.Column(variant='panel',scale=35):
|
| 54 |
# Define original model output components
|
| 55 |
gr.Markdown('<h2 align="center">Original Stable Diffusion 1.4</h2>')
|
| 56 |
original_model_output = gr.Image(label="Original Model")
|
| 57 |
with gr.Row().style(equal_height=True):
|
| 58 |
+
with gr.Column():
|
| 59 |
+
original_model_test_time = gr.Textbox(value="", label="Inference Time (sec)")
|
| 60 |
+
original_model_params = gr.Textbox(value=servicer.get_sdm_params(servicer.pipe_original), label="# Parameters")
|
| 61 |
original_model_error = gr.Markdown()
|
| 62 |
+
|
| 63 |
|
| 64 |
+
with gr.Column(variant='panel',scale=35):
|
| 65 |
# Define compressed model output components
|
| 66 |
gr.Markdown('<h2 align="center">Compressed Stable Diffusion (Ours)</h2>')
|
| 67 |
+
compressed_model_output = gr.Image(label="Compressed Model")
|
| 68 |
with gr.Row().style(equal_height=True):
|
| 69 |
+
with gr.Column():
|
| 70 |
+
compressed_model_test_time = gr.Textbox(value="", label="Inference Time (sec)")
|
| 71 |
+
compressed_model_params = gr.Textbox(value=servicer.get_sdm_params(servicer.pipe_compressed), label="# Parameters")
|
| 72 |
compressed_model_error = gr.Markdown()
|
| 73 |
|
| 74 |
inputs = [text, negative, guidance_scale, steps, seed]
|
demo.py
CHANGED
|
@@ -26,6 +26,17 @@ class SdmCompressionDemo:
|
|
| 26 |
self.pipe_compressed = self.pipe_compressed.to(self.device)
|
| 27 |
self.device_msg = 'Tested on GPU.' if 'cuda' in self.device else 'Tested on CPU.'
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
def generate_image(self, pipe, text, negative, guidance_scale, steps, seed):
|
| 30 |
generator = torch.Generator(self.device).manual_seed(seed)
|
| 31 |
start = time.time()
|
|
|
|
| 26 |
self.pipe_compressed = self.pipe_compressed.to(self.device)
|
| 27 |
self.device_msg = 'Tested on GPU.' if 'cuda' in self.device else 'Tested on CPU.'
|
| 28 |
|
| 29 |
+
def _count_params(self, model):
|
| 30 |
+
return sum(p.numel() for p in model.parameters())
|
| 31 |
+
|
| 32 |
+
def get_sdm_params(self, pipe):
|
| 33 |
+
params_unet = self._count_params(pipe.unet)
|
| 34 |
+
params_text_enc = self._count_params(pipe.text_encoder)
|
| 35 |
+
params_image_dec = self._count_params(pipe.vae.decoder)
|
| 36 |
+
params_total = params_unet + params_text_enc + params_image_dec
|
| 37 |
+
return f"Total {(params_total/1e6):.1f}M (U-Net {(params_unet/1e6):.1f}M)"
|
| 38 |
+
|
| 39 |
+
|
| 40 |
def generate_image(self, pipe, text, negative, guidance_scale, steps, seed):
|
| 41 |
generator = torch.Generator(self.device).manual_seed(seed)
|
| 42 |
start = time.time()
|