retwpay committed · verified
Commit bb71e61 · parent: 5908e1a

Update app.py

Files changed (1): app.py (+76 -21)
app.py CHANGED
@@ -4,44 +4,99 @@ import numpy as np
 import PIL.Image
 from PIL import Image
 import random
-from diffusers import ControlNetModel, StableDiffusionXLPipeline, AutoencoderKL
-from diffusers import DDIMScheduler, EulerAncestralDiscreteScheduler
-import cv2
+from diffusers import StableDiffusionXLPipeline
+from diffusers import EulerAncestralDiscreteScheduler
 import torch
+from compel import Compel, ReturnedEmbeddingsType
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+# Make sure to use torch.float16 consistently throughout the pipeline
 pipe = StableDiffusionXLPipeline.from_pretrained(
     "votepurchase/noobreal_v21",
     torch_dtype=torch.float16,
+    variant="fp16",  # Explicitly use the fp16 variant
+    use_safetensors=True  # Use safetensors if available
 )
 
 pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
 pipe.to(device)
 
+# Force all components to use the same dtype
+pipe.text_encoder.to(torch.float16)
+pipe.text_encoder_2.to(torch.float16)
+pipe.vae.to(torch.float16)
+pipe.unet.to(torch.float16)
+
+# Added: initialize Compel for long-prompt processing
+compel = Compel(
+    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
+    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
+    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
+    requires_pooled=[False, True],
+    truncate_long_prompts=False
+)
+
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1216
 
+# Added: simple long-prompt processing function
+def process_long_prompt(prompt, negative_prompt=""):
+    """Simple long-prompt processing using Compel"""
+    try:
+        conditioning, pooled = compel([prompt, negative_prompt])
+        return conditioning, pooled
+    except Exception as e:
+        print(f"Long prompt processing failed: {e}, falling back to standard processing")
+        return None, None
 
 @spaces.GPU
 def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
-
+    # Changed: remove the 60-word limit warning and add a long-prompt check
+    use_long_prompt = len(prompt.split()) > 60 or len(prompt) > 300
+
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
 
-    generator = torch.Generator().manual_seed(seed)
-
-    output_image = pipe(
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        guidance_scale=guidance_scale,
-        num_inference_steps=num_inference_steps,
-        width=width,
-        height=height,
-        generator=generator
-    ).images[0]
-
-    return output_image
+    generator = torch.Generator(device=device).manual_seed(seed)
+
+    try:
+        # Added: try long-prompt processing first if the prompt is long
+        if use_long_prompt:
+            print("Using long prompt processing...")
+            conditioning, pooled = process_long_prompt(prompt, negative_prompt)
+
+            if conditioning is not None:
+                output_image = pipe(
+                    prompt_embeds=conditioning[0:1],
+                    pooled_prompt_embeds=pooled[0:1],
+                    negative_prompt_embeds=conditioning[1:2],
+                    negative_pooled_prompt_embeds=pooled[1:2],
+                    guidance_scale=guidance_scale,
+                    num_inference_steps=num_inference_steps,
+                    width=width,
+                    height=height,
+                    generator=generator
+                ).images[0]
+                return output_image
+
+        # Fall back to standard processing
+        output_image = pipe(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            guidance_scale=guidance_scale,
+            num_inference_steps=num_inference_steps,
+            width=width,
+            height=height,
+            generator=generator
+        ).images[0]
+
+        return output_image
+    except RuntimeError as e:
+        print(f"Error during generation: {e}")
+        # Return a blank image on error
+        error_img = Image.new('RGB', (width, height), color=(0, 0, 0))
+        return error_img
 
 
 css = """
@@ -60,7 +115,7 @@ with gr.Blocks(css=css) as demo:
                 label="Prompt",
                 show_label=False,
                 max_lines=1,
-                placeholder="Enter your prompt",
+                placeholder="Enter your prompt (long prompts are automatically supported)",  # Changed: updated placeholder
                 container=False,
             )
 
@@ -93,7 +148,7 @@ with gr.Blocks(css=css) as demo:
                     minimum=256,
                     maximum=MAX_IMAGE_SIZE,
                     step=32,
-                    value=1024,#832,
+                    value=1024,
                 )
 
                 height = gr.Slider(
@@ -101,7 +156,7 @@ with gr.Blocks(css=css) as demo:
                     minimum=256,
                     maximum=MAX_IMAGE_SIZE,
                     step=32,
-                    value=1024,#1216,
+                    value=1024,
                 )
 
             with gr.Row():
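Note: both sliders keep `step=32`, so every selectable size is a multiple of 32 within [256, `MAX_IMAGE_SIZE`], which is safe for SDXL's VAE (8x downsampling). A hypothetical helper, not in the commit, that snaps a free-form size onto the same grid:

```python
# Hypothetical helper (not in the commit): snap an arbitrary requested size
# to the slider's grid -- a multiple of 32 clamped to [256, MAX_IMAGE_SIZE].
def snap_size(value: int, step: int = 32, lo: int = 256, hi: int = 1216) -> int:
    return max(lo, min(hi, round(value / step) * step))

print(snap_size(1000))  # 992
print(snap_size(4096))  # 1216 (clamped to MAX_IMAGE_SIZE)
```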
@@ -121,7 +176,7 @@ with gr.Blocks(css=css) as demo:
                     value=28,
                 )
 
-    run_button.click(#lambda x: None, inputs=None, outputs=result).then(
+    run_button.click(
        fn=infer,
        inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result]
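Note: the removed remnant was inert rather than broken: since `#` comments out the rest of the line, the old code still parsed as a plain `run_button.click(` call, and the `.then(` inside the comment never ran. For reference, a sketch of what the hinted-at clear-then-run chaining would look like in Gradio (not part of the commit):

```python
# Sketch only: restore the clear-then-run behaviour the commented-out code
# hinted at. .click() returns an event that supports .then() chaining.
run_button.click(
    fn=lambda: None,   # first clear the previous result
    inputs=None,
    outputs=result,
).then(
    fn=infer,
    inputs=[prompt, negative_prompt, seed, randomize_seed,
            width, height, guidance_scale, num_inference_steps],
    outputs=[result],
)
```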
 
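Note: the explicit `.to(torch.float16)` calls in the diff guard against individual components loading in float32 and triggering dtype-mismatch errors at inference. An illustrative sanity check, not in the commit, to confirm the cast took effect:

```python
# Illustrative check: confirm every sub-module ended up in float16.
for name in ("unet", "vae", "text_encoder", "text_encoder_2"):
    dtype = next(getattr(pipe, name).parameters()).dtype
    print(f"{name}: {dtype}")  # expect torch.float16 for all four
```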