# Spaces: Runtime error
import torch | |
from torch import Tensor | |
import torch.nn as nn | |
from torch.nn import Conv2d | |
from torch.nn import functional as F | |
from torch.nn.modules.utils import _pair | |
from typing import Optional | |
from diffusers import StableDiffusionPipeline, DDPMScheduler | |
import diffusers | |
from PIL import Image | |
import gradio as gr | |
import spaces | |
import gc | |
def asymmetricConv2DConvForward_circular(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
    """Circular-padding stand-in for ``Conv2d._conv_forward``.

    Meant to be bound onto a ``Conv2d`` instance. The input is wrapped around
    with ``F.pad(..., mode="circular")`` one axis at a time, and the
    convolution itself then runs with zero padding so the layer's configured
    padding is not applied a second time.

    Side effect: the per-axis padding tuples are stored on the module as
    ``self.paddingX`` and ``self.paddingY`` on every call.

    Args:
        self: Module providing ``_reversed_padding_repeated_twice``,
            ``stride``, ``dilation`` and ``groups``.
        input: Tensor to convolve.
        weight: Convolution kernel.
        bias: Optional bias passed through to ``F.conv2d``.

    Returns:
        The output of ``F.conv2d`` on the circularly padded input.
    """
    # F.pad consumes padding starting from the last dimension, so the first
    # pair belongs to the last axis and the second pair to the one before it.
    pad_left, pad_right, pad_top, pad_bottom = self._reversed_padding_repeated_twice[:4]
    self.paddingX = (pad_left, pad_right, 0, 0)
    self.paddingY = (0, 0, pad_top, pad_bottom)

    wrapped = F.pad(
        F.pad(input, self.paddingX, mode="circular"),
        self.paddingY,
        mode="circular",
    )
    return F.conv2d(
        wrapped,
        weight,
        bias,
        stride=self.stride,
        padding=_pair(0),
        dilation=self.dilation,
        groups=self.groups,
    )
def make_seamless(model):
    """Switch every ``Conv2d`` inside ``model`` to circular (wrap-around) padding.

    Each ``torch.nn.Conv2d`` submodule gets its ``_conv_forward`` replaced, on
    the instance, by ``asymmetricConv2DConvForward_circular``.

    Args:
        model: Any ``torch.nn.Module``; it is modified in place.
    """
    conv_layers = (layer for layer in model.modules() if isinstance(layer, torch.nn.Conv2d))
    for conv in conv_layers:
        is_lora_conv = isinstance(conv, diffusers.models.lora.LoRACompatibleConv)
        if is_lora_conv and conv.lora_layer is None:
            # Install a no-op LoRA branch that contributes 0.
            # NOTE(review): presumably this steers LoRACompatibleConv.forward
            # onto the code path that goes through _conv_forward - confirm
            # against the installed diffusers version.
            conv.lora_layer = lambda *x: 0
        conv._conv_forward = asymmetricConv2DConvForward_circular.__get__(conv, Conv2d)
def disable_seamless(model):
    """Put every ``Conv2d`` inside ``model`` back on the stock ``_conv_forward``.

    Counterpart of ``make_seamless``: the instance-level ``_conv_forward`` is
    rebound to ``nn.Conv2d._conv_forward``.

    Args:
        model: Any ``torch.nn.Module``; it is modified in place.
    """
    for layer in model.modules():
        if not isinstance(layer, torch.nn.Conv2d):
            continue
        lora_conv_cls = diffusers.models.lora.LoRACompatibleConv
        if isinstance(layer, lora_conv_cls) and layer.lora_layer is None:
            # Same no-op LoRA branch (returns 0) that make_seamless installs.
            layer.lora_layer = lambda *x: 0
        stock_forward = nn.Conv2d._conv_forward
        layer._conv_forward = stock_forward.__get__(layer, Conv2d)
def diffusion_callback(pipe, step_index, timestep, callback_kwargs):
    """Per-step pipeline callback: roll the latents early, go seamless late.

    For the first 80% of the denoising steps the latents are rolled by half of
    their spatial size along both spatial axes. At the 80% mark the UNet and
    VAE convolutions are switched to circular padding via ``make_seamless``.
    Later steps leave the latents untouched.

    The roll distance is derived from the latent shape rather than fixed at
    64: 64 is half of the 128-cell latent of a 1024 px image, but on a 64-cell
    latent (512 px) a 64-cell roll is the identity and nothing would move.

    Args:
        pipe: The running pipeline; ``num_timesteps``, ``unet`` and ``vae``
            are read from it.
        step_index: Zero-based index of the step that just finished.
        timestep: Scheduler timestep (unused).
        callback_kwargs: Dict holding a ``"latents"`` tensor whose spatial
            axes are dimensions 2 and 3.

    Returns:
        ``callback_kwargs``, with ``"latents"`` replaced when a roll happened.
    """
    switch_step = int(pipe.num_timesteps * 0.8)

    if step_index == switch_step:
        make_seamless(pipe.unet)
        make_seamless(pipe.vae)
    elif step_index < switch_step:
        latents = callback_kwargs["latents"]
        half_shift = (latents.shape[2] // 2, latents.shape[3] // 2)
        callback_kwargs["latents"] = torch.roll(latents, shifts=half_shift, dims=(2, 3))

    return callback_kwargs
# Module-level setup: load the pipeline once so every request reuses it.
print("Loading Pattern Diffusion model...")
pipe = StableDiffusionPipeline.from_pretrained(
    "Arrexel/pattern-diffusion",
    # Half precision is only requested when a GPU is present; on CPU many
    # float16 kernels are missing or very slow, so fall back to float32.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    safety_checker=None,
    requires_safety_checker=False,
)
pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
if torch.cuda.is_available():
    # No explicit pipe.to("cuda") here: model CPU offload decides where each
    # component lives and moves it to the GPU only while it is running.
    pipe.enable_attention_slicing()
    pipe.enable_model_cpu_offload()
    print("Model loaded successfully on GPU with optimizations!")
else:
    print("GPU not available, using CPU")
def generate_pattern(prompt, width=1024, height=1024, num_inference_steps=50, guidance_scale=7.5, seed=None):
    """Generate one pattern image for ``prompt`` with the module-level ``pipe``.

    Circular padding is switched off before each run; ``diffusion_callback``
    turns it back on partway through the denoising loop.

    Args:
        prompt: Text prompt.
        width: Output width in pixels (cast to ``int``).
        height: Output height in pixels (cast to ``int``).
        num_inference_steps: Number of denoising steps (cast to ``int``).
        guidance_scale: Guidance scale forwarded to the pipeline.
        seed: Optional seed; ``None`` or ``""`` means unseeded.

    Returns:
        The first generated image, or ``None`` if generation raised (the
        error message is printed).
    """
    # NOTE(review): `spaces` is imported by this file but no `@spaces.GPU`
    # decorator is applied here - confirm whether the Space runs on ZeroGPU.
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    try:
        # The pipeline is deliberately NOT moved with pipe.to("cuda") here:
        # model CPU offload is enabled at load time and a manual move would
        # undo its memory savings. The generator is still created on the
        # execution device so seeded results stay on the same RNG as before.
        generator = None
        if seed is not None and seed != "":
            generator = torch.Generator(device=device_type).manual_seed(int(seed))

        # Start every run from the stock convolution forward.
        disable_seamless(pipe.unet)
        disable_seamless(pipe.vae)

        with torch.autocast(device_type):
            output = pipe(
                prompt=prompt,
                width=int(width),
                height=int(height),
                num_inference_steps=int(num_inference_steps),
                guidance_scale=guidance_scale,
                generator=generator,
                callback_on_step_end=diffusion_callback,
            ).images[0]
        return output
    except Exception as e:
        # Best-effort UI handler: report the failure and return an empty result.
        print(f"Error during generation: {str(e)}")
        return None
    finally:
        # Release cached GPU memory and collect garbage after every request.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
def create_interface():
    """Build the Gradio Blocks UI for the pattern generator.

    Layout: an introductory Markdown panel, a row with the input controls on
    the left and the generated image on the right, then a list of clickable
    example prompts. The "Generate Pattern" button calls ``generate_pattern``
    with the prompt, width, height, steps, guidance scale and seed, and shows
    the result in the image component.

    Returns:
        The assembled ``gr.Blocks`` instance (not yet launched).
    """
    default_prompt = "Vibrant watercolor floral pattern with pink, purple, and blue flowers against a white background."

    with gr.Blocks(title="Pattern Diffusion - Seamless Pattern Generator") as demo:
        gr.Markdown("""
# 🎨 Pattern Diffusion - Seamless Pattern Generator
**Model:** [Arrexel/pattern-diffusion](https://huggingface.co/Arrexel/pattern-diffusion)
This model specializes in generating patterns that can be repeated without visible seams,
ideal for prints, wallpapers, textiles, and surfaces.
**Strengths:**
- Excellent for floral and abstract patterns
- Understands foreground and background colors well
- Fast and efficient on VRAM
**Limitations:**
- Does not generate coherent text
- Difficulty with anatomy of living creatures
- Inconsistent geometry in simple geometric patterns
""")
        with gr.Row():
            # Left column: all generation inputs.
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder=default_prompt,
                    lines=3,
                    value=default_prompt,
                )
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=1024,
                        step=256,
                        value=1024,
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=1024,
                        step=256,
                        value=1024,
                    )
                with gr.Row():
                    steps = gr.Slider(
                        label="Inference Steps",
                        minimum=20,
                        maximum=100,
                        step=5,
                        value=50,
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1.0,
                        maximum=20.0,
                        step=0.5,
                        value=7.5,
                    )
                seed = gr.Number(
                    label="Seed (optional, leave empty for random)",
                    precision=0,
                )
                generate_btn = gr.Button("🎨 Generate Pattern", variant="primary", size="lg")
            # Right column: the generated image.
            with gr.Column():
                output_image = gr.Image(
                    label="Generated Pattern",
                    type="pil",
                    height=400,
                )
        gr.Markdown("## 📝 Example Prompts")
        examples = [
            ["Vibrant watercolor floral pattern with pink, purple, and blue flowers against a white background."],
            ["Abstract geometric pattern with gold and navy blue triangles on cream background"],
            ["Delicate cherry blossom pattern with soft pink petals on light gray background"],
            ["Art deco pattern with emerald green and gold lines on black background"],
            ["Tropical leaves pattern with various shades of green on white background"],
            ["Vintage damask pattern in burgundy and cream colors"],
            ["Modern minimalist dots pattern in pastel colors"],
            ["Mandala-inspired pattern with intricate details in blue and white"],
        ]
        # Examples only fill the prompt box; they do not trigger generation.
        gr.Examples(
            examples=examples,
            inputs=[prompt],
            label="Click an example to use",
        )
        generate_btn.click(
            fn=generate_pattern,
            inputs=[prompt, width, height, steps, guidance_scale, seed],
            outputs=[output_image],
        )
    return demo
if __name__ == "__main__":
    # Build the UI, cap the request queue at 20 pending jobs, then serve it.
    demo = create_interface()
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()