Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
# Number of (color swatch, prompt) rows pre-built in the UI; process_sketch
# sizes its visibility and swatch update lists to this same length.
MAX_COLORS: int = 12
def get_high_freq_colors(image):
    """Return the colors that occur often enough in *image* to count as drawn regions.

    Args:
        image: A PIL ``Image`` (anything exposing ``getcolors``, ``width`` and
            ``height``).

    Returns:
        A list of ``(count, color)`` tuples as produced by ``Image.getcolors``,
        sorted by descending count and keeping only colors whose count is
        greater than both 2 and one third of the mean per-color count.
        Empty when the image has no pixels.
    """
    # An image cannot contain more distinct colors than it has pixels, so the
    # pixel count is a maxcolors bound for which getcolors never returns None.
    pixel_count = image.width * image.height
    counted = image.getcolors(maxcolors=max(pixel_count, 1)) or []
    if not counted:
        return []
    sorted_colors = sorted(counted, key=lambda c: c[0], reverse=True)
    mean_freq = sum(c[0] for c in sorted_colors) / len(sorted_colors)
    # Drop rare colors: keep a color only if it occurs more than twice AND more
    # than a third of the average per-color count.
    threshold = max(2, mean_freq / 3)
    return [c for c in sorted_colors if c[0] > threshold]
def color_quantization(image, n_colors):
    """Snap every pixel of *image* to the nearest of its *n_colors* most frequent colors.

    Args:
        image: ``(H, W, C)`` array with ``C >= 3``. Only the first three (RGB)
            channels are used, so an alpha channel is ignored.
        n_colors: Number of palette colors to keep; must be positive. If the
            image has fewer distinct colors, all of them are kept.

    Returns:
        ``(H, W, 3)`` ``uint8`` array in which every pixel is a palette color.

    Raises:
        ValueError: If ``n_colors`` is less than 1.
    """
    if n_colors < 1:
        raise ValueError(f"n_colors must be positive, got {n_colors}")
    rgb = np.asarray(image)[..., :3]
    height, width = rgb.shape[0], rgb.shape[1]
    # Signed 64-bit so the squared-distance arithmetic below cannot wrap.
    pixels = rgb.reshape(-1, 3).astype(np.int64)
    if pixels.shape[0] == 0:
        return np.zeros((height, width, 3), dtype=np.uint8)
    # Count distinct colors directly instead of allocating a dense 256**3
    # histogram (~128 MB of float64 bins).
    palette, counts = np.unique(pixels, axis=0, return_counts=True)
    # Most frequent first; a stable sort makes tie order deterministic.
    palette = palette[np.argsort(-counts, kind="stable")][:n_colors]
    dists = np.sum((pixels[:, None, :] - palette[None, :, :]) ** 2, axis=2)
    labels = np.argmin(dists, axis=1)
    return palette[labels].reshape(height, width, 3).astype(np.uint8)
def create_binary_matrix(img_arr, target_color):
    """Return a 0/1 integer mask of the pixels in *img_arr* equal to *target_color*.

    Args:
        img_arr: ``(H, W, C)`` array of pixel values.
        target_color: Channel values to match, e.g. ``(r, g, b)``. When it has
            fewer entries than the image has channels (an RGB target against an
            RGBA image), only the leading channels are compared.

    Returns:
        ``(H, W)`` integer array (``astype(int)``): 1 where every compared
        channel matches, 0 elsewhere.
    """
    target = np.asarray(target_color)
    pixels = np.asarray(img_arr)
    if target.ndim == 1 and target.shape[0] < pixels.shape[-1]:
        # Ignore trailing channels (e.g. alpha) the target does not specify.
        pixels = pixels[..., : target.shape[0]]
    return np.all(pixels == target, axis=-1).astype(int)
def process_sketch(image, binary_matrixes):
    """Split a finished sketch into per-color region masks and reveal the prompt rows.

    The most frequent color is treated as the canvas background; each of the
    remaining high-frequency colors (at most ``MAX_COLORS``) becomes one region.

    Args:
        image: The sketch as a PIL image.
        binary_matrixes: Previous value of the masks state. Ignored: a fresh
            list is built on every call, so repeated clicks do not accumulate
            stale masks that would misalign masks and prompts in
            ``process_generation``.

    Returns:
        A list in the order the click handler's outputs expect: an update
        making the post-sketch column visible, the new list of masks,
        ``MAX_COLORS`` row-visibility updates, then ``MAX_COLORS`` swatch HTML
        updates.
    """
    high_freq_colors = get_high_freq_colors(image)
    # Skip the background color and cap at the number of rows the UI has;
    # going past MAX_COLORS used to raise IndexError.
    region_colors = high_freq_colors[1:MAX_COLORS + 1]

    masks = []
    visibilities = [gr.update(visible=False) for _ in range(MAX_COLORS)]
    swatches = [
        gr.update(value='<div class="color-bg-item" style="background-color: black"></div>')
        for _ in range(MAX_COLORS)
    ]
    if region_colors:
        # Quantize against all high-frequency colors so every pixel snaps to
        # its nearest one before exact-match masks are taken.
        quantized = color_quantization(np.array(image), n_colors=len(high_freq_colors))
        for n, (_count, color) in enumerate(region_colors):
            r, g, b = color[0], color[1], color[2]
            masks.append(create_binary_matrix(quantized, (r, g, b)))
            visibilities[n] = gr.update(visible=True)
            swatches[n] = gr.update(
                value=f'<div class="color-bg-item" style="background-color: rgb({r},{g},{b})"></div>'
            )
    return [gr.update(visible=True), masks, *visibilities, *swatches]
def process_generation(binary_matrixes, master_prompt, *prompts):
    """Placeholder handler for the "Generate!" button; produces no image yet.

    Args:
        binary_matrixes: Region masks stored in state by ``process_sketch``.
        master_prompt: Text of the "General Prompt" box.
        *prompts: One text per pre-built prompt row; only the first
            ``len(binary_matrixes)`` of them correspond to a mask.

    Returns:
        ``None`` — generation is not implemented.
    """
    # Pair prompts with masks by position; surplus prompt boxes are dropped.
    n_regions = len(binary_matrixes)
    clipped_prompts = prompts[0:n_regions]
    # TODO: run region-based generation using master_prompt as the main prompt
    # and (binary_matrixes, clipped_prompts) as the masked inputs.
    return None
# Extra CSS handed to gr.Blocks: flex-centers the box around each color
# swatch (#color-bg), gives each swatch a 32px-high full-width bar, and
# stretches the main button to full width.
css = "\n".join(
    [
        "",
        "#color-bg{display:flex;justify-content: center;align-items: center;}",
        ".color-bg-item{width: 100%; height: 32px}",
        "#main_button{width:100%}",
        "",
    ]
)
def update_css(aspect):
    """Return a Gradio update that resizes the sketch canvas for *aspect*.

    Args:
        aspect: ``"Square"`` (512x512), ``"Horizontal"`` (768x512) or
            ``"Vertical"`` (512x768). Any other value falls back to square
            instead of raising ``UnboundLocalError``.

    Returns:
        A ``gr.update`` whose value is a ``<style>`` tag setting the width of
        ``#main-image`` and the height of ``.fixed-height`` elements.
    """
    sizes = {
        "Square": (512, 512),
        "Horizontal": (768, 512),
        "Vertical": (512, 768),
    }
    width, height = sizes.get(aspect, (512, 512))
    return gr.update(value=f"<style>#main-image{{width: {width}px}} .fixed-height{{height: {height}px !important}}</style>")
# Two-stage UI: sketch on a canvas, then enter one prompt per detected color
# and generate.
with gr.Blocks(css=css) as demo:
    # Region masks: written by process_sketch, read by process_generation.
    binary_matrixes = gr.State([])
    gr.Markdown('''## Control your Stable Diffusion generation with Sketches
This Space demonstrates MultiDiffusion region-based generation using Stable Diffusion model. To get started, draw your masks and type your prompts. More details in the [project page](https://multidiffusion.github.io).
''')
    with gr.Row():
        with gr.Box(elem_id="main-image"):
            with gr.Row():
                image = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil")
            with gr.Row():
                aspect = gr.Radio(["Square", "Horizontal", "Vertical"], value="Square", label="Aspect Ratio")
            button_run = gr.Button("I've finished my sketch", elem_id="main_button")
        prompts = []
        colors = []
        # Slots filled by the `with ... as color_row[n]` statements below.
        color_row = [None] * MAX_COLORS
        # Hidden until process_sketch returns a visible=True update for it.
        with gr.Column(visible=False) as post_sketch:
            general_prompt = gr.Textbox(label="General Prompt")
            # Pre-build MAX_COLORS (swatch, prompt) rows; process_sketch decides
            # which are shown and recolors their swatches.
            for n in range(MAX_COLORS):
                with gr.Row(visible=False) as color_row[n]:
                    with gr.Box(elem_id="color-bg"):
                        colors.append(gr.HTML('<div class="color-bg-item" style="background-color: black"></div>'))
                    prompts.append(gr.Textbox(label="Prompt for this color"))
            final_run_btn = gr.Button("Generate!")
        out_image = gr.Image(label="Result")
    # NOTE(review): this Markdown is empty — its content may have been lost;
    # confirm what belongs here.
    gr.Markdown('''

''')
    # Style tag that update_css rewrites when the aspect ratio changes.
    css_height = gr.HTML("<style>#main-image{width: 512px} .fixed-height{height: 512px !important}</style>")
    aspect.change(update_css, inputs=aspect, outputs=css_height)
    button_run.click(process_sketch, inputs=[image, binary_matrixes], outputs=[post_sketch, binary_matrixes, *color_row, *colors])
    final_run_btn.click(process_generation, inputs=[binary_matrixes, general_prompt, *prompts], outputs=out_image)

# Launch only when executed as a script, so importing this module (e.g. to test
# the helpers above) does not start a blocking server.
if __name__ == "__main__":
    demo.launch()