Update run_civitai_sdxl.py
run_civitai_sdxl.py  +60 -17
CHANGED
@@ -1,9 +1,13 @@
+#### inference code derived from https://civitai.com/models/833294 and https://huggingface.co/spaces/nyanko7/toaru-xl-model
+
 import torch
-from diffusers import StableDiffusionXLPipeline
-from
+from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
+from PIL import Image
+import random
 
+# Initialize the pipe; this only needs to run once
+ckpt_path = "nyaflow-xl-alpha.safetensors" # https://huggingface.co/nyanko7/nyaflow-xl-alpha
 ckpt_path = "noobaiXLNAIXL_vPred10Version.safetensors" #### https://civitai.com/models/833294
-ckpt_path = "nyaflow-xl-alpha.safetensors" #### https://huggingface.co/nyanko7/nyaflow-xl-alpha
 pipe = StableDiffusionXLPipeline.from_single_file(
     ckpt_path,
     use_safetensors=True,
@@ -14,17 +18,56 @@ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, **sch
 pipe.enable_xformers_memory_efficient_attention()
 pipe = pipe.to("cuda")
 
-
-
-
-
-
-prompt
-
-
-
-
-guidance_scale=
-
-
-
+# Default generation parameters
+PRESET_Q = "year_2022, best quality, high quality, very aesthetic"
+NEGATIVE_PROMPT = "lowres, worst quality, displeasing, bad anatomy, text, error, extra digit, cropped, error, fewer, extra, missing, worst quality, jpeg artifacts, censored, ai-generated worst quality displeasing, bad quality"
+
+def generate_image(
+    prompt: str,
+    preset: str = PRESET_Q,
+    height: int = 1216,
+    width: int = 832,
+    negative_prompt: str = NEGATIVE_PROMPT,
+    guidance_scale: float = 4.0,
+    randomize_seed: bool = True,
+    seed: int = 42,
+    inference_steps: int = 25,
+) -> Image:
+    # Merge the prompt with the quality preset
+    prompt = prompt.strip() + ", " + preset.strip()
+    negative_prompt = negative_prompt.strip() if negative_prompt and negative_prompt.strip() else None
+
+    # Randomize the seed
+    if randomize_seed:
+        seed = random.randint(0, 9007199254740991)
+
+    # Set up the generator
+    generator = torch.Generator(device="cuda").manual_seed(seed)
+
+    # Cap the number of inference steps
+    if inference_steps > 50:
+        inference_steps = 50
+
+    # Generate the image
+    image = pipe(
+        prompt,
+        height=height,
+        width=width,
+        negative_prompt=negative_prompt,
+        guidance_scale=guidance_scale,
+        generator=generator,
+        num_inference_steps=inference_steps
+    ).images[0]
+
+    return image
+
+
+# Example calls
+if __name__ == "__main__":
+    prompt = "zhongli"
+    image = generate_image(prompt)
+    image
+
+    prompt = "Neuvillette"
+    image = generate_image(prompt)
+    image
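
Note: the context line in the second hunk header cuts off at **sch, so the scheduler override that run_civitai_sdxl.py actually uses is not visible in this diff. For a v-prediction checkpoint such as noobaiXLNAIXL_vPred10Version, that line commonly expands to something like the sketch below; the scheduler_args values shown here (v-prediction plus zero-terminal-SNR rescaling) and the torch_dtype choice are assumptions, not a copy of the hidden lines.

import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler

pipe = StableDiffusionXLPipeline.from_single_file(
    "noobaiXLNAIXL_vPred10Version.safetensors",  # https://civitai.com/models/833294
    use_safetensors=True,
    torch_dtype=torch.float16,  # assumed; half precision is typical for SDXL on CUDA
)

# Assumed override for a vPred checkpoint; the file's real scheduler_args are truncated above.
scheduler_args = {"prediction_type": "v_prediction", "rescale_betas_zero_snr": True}
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, **scheduler_args)

One usage note on the example block: the bare image expressions at the end of the __main__ section only display output in an interactive notebook. When the file runs as a plain script, a call such as image.save("zhongli.png") (the filename is illustrative) would be needed to keep the result.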
|