import io
import os
from typing import List

import PIL.Image
import requests
import torch
from diffusers import AutoencoderTiny, StableDiffusionPipeline

from streamdiffusion import StreamDiffusion
from streamdiffusion.acceleration.sfast import accelerate_with_stable_fast
from streamdiffusion.image_utils import postprocess_image

def download_image(url: str):
    response = requests.get(url)
    image = PIL.Image.open(io.BytesIO(response.content))
    return image

class StreamDiffusionWrapper:
    def __init__(
        self,
        model_id: str,
        lcm_lora_id: str,
        vae_id: str,
        device: str,
        dtype: torch.dtype,
        t_index_list: List[int],
        warmup: int,
    ):
        self.device = device
        self.dtype = dtype
        self.prompt = ""

        self.stream = self._load_model(
            model_id=model_id,
            lcm_lora_id=lcm_lora_id,
            vae_id=vae_id,
            t_index_list=t_index_list,
            warmup=warmup,
        )

    def _load_model(
        self,
        model_id: str,
        lcm_lora_id: str,
        vae_id: str,
        t_index_list: List[int],
        warmup: int,
    ):
        # Load the base pipeline from a local checkpoint file or from the Hugging Face Hub.
        if os.path.exists(model_id):
            pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_single_file(model_id).to(
                device=self.device, dtype=self.dtype
            )
        else:
            pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(model_id).to(
                device=self.device, dtype=self.dtype
            )

        stream = StreamDiffusion(
            pipe=pipe,
            t_index_list=t_index_list,
            torch_dtype=self.dtype,
            is_drawing=True,
        )
        # Fuse the LCM-LoRA weights into the UNet, swap in the tiny autoencoder
        # for fast decoding, then accelerate the stream with stable-fast.
        stream.load_lcm_lora(lcm_lora_id)
        stream.fuse_lora()
        stream.vae = AutoencoderTiny.from_pretrained(vae_id).to(device=pipe.device, dtype=pipe.dtype)
        stream = accelerate_with_stable_fast(stream)
        stream.prepare(
            "",
            num_inference_steps=50,
            generator=torch.manual_seed(2),
        )

        # Warmup: run a few throwaway generations so CUDA kernels and the
        # stable-fast graph are compiled before serving real requests.
        for _ in range(warmup):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            start.record()
            stream.txt2img()
            end.record()
            torch.cuda.synchronize()

        return stream

    def __call__(self, prompt: str) -> List[PIL.Image.Image]:
        self.stream.prepare("")
        images = []
        # Run 12 generations; the first 3 outputs are discarded as warm-up frames
        # while the stream's denoising batch fills, leaving 9 images.
        for i in range(9 + 3):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            start.record()
            if self.prompt != prompt:
                self.stream.update_prompt(prompt)
                self.prompt = prompt

            x_output = self.stream.txt2img()
            if i >= 3:
                images.append(postprocess_image(x_output, output_type="pil")[0])
            end.record()

            torch.cuda.synchronize()

        return images

if __name__ == "__main__":
    # Example invocation. The model/LoRA/VAE ids and settings below are
    # illustrative values, not necessarily those this Space was launched with.
    wrapper = StreamDiffusionWrapper(
        model_id="KBlueLeaf/kohaku-v2.1",
        lcm_lora_id="latent-consistency/lcm-lora-sdv1-5",
        vae_id="madebyollin/taesd",
        device="cuda",
        dtype=torch.float16,
        t_index_list=[0, 16, 32, 45],
        warmup=10,
    )
    images = wrapper("an example prompt")