votepurchase commited on
Commit
126c369
·
verified ·
1 Parent(s): bddef80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -5
app.py CHANGED
@@ -7,6 +7,7 @@ import random
7
  from diffusers import StableDiffusionXLPipeline
8
  from diffusers import EulerAncestralDiscreteScheduler
9
  import torch
 
10
 
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
@@ -27,21 +28,59 @@ pipe.text_encoder_2.to(torch.float16)
27
  pipe.vae.to(torch.float16)
28
  pipe.unet.to(torch.float16)
29
 
 
 
 
 
 
 
 
 
 
30
  MAX_SEED = np.iinfo(np.int32).max
31
  MAX_IMAGE_SIZE = 1216
 
 
 
 
 
 
 
 
 
 
32
 
33
  @spaces.GPU
34
  def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
35
- # Check and truncate prompt if too long (CLIP can only handle 77 tokens)
36
- if len(prompt.split()) > 60: # Rough estimate to avoid exceeding token limit
37
- print("Warning: Prompt may be too long and will be truncated by the model")
38
-
39
  if randomize_seed:
40
  seed = random.randint(0, MAX_SEED)
41
 
42
  generator = torch.Generator(device=device).manual_seed(seed)
43
 
44
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  output_image = pipe(
46
  prompt=prompt,
47
  negative_prompt=negative_prompt,
@@ -76,7 +115,7 @@ with gr.Blocks(css=css) as demo:
76
  label="Prompt",
77
  show_label=False,
78
  max_lines=1,
79
- placeholder="Enter your prompt (keep it under 60 words for best results)",
80
  container=False,
81
  )
82
 
 
7
  from diffusers import StableDiffusionXLPipeline
8
  from diffusers import EulerAncestralDiscreteScheduler
9
  import torch
10
+ from compel import Compel, ReturnedEmbeddingsType
11
 
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
 
28
  pipe.vae.to(torch.float16)
29
  pipe.unet.to(torch.float16)
30
 
31
+ # 追加: Initialize Compel for long prompt processing
32
+ compel = Compel(
33
+ tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
34
+ text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
35
+ returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
36
+ requires_pooled=[False, True],
37
+ truncate_long_prompts=False
38
+ )
39
+
40
  MAX_SEED = np.iinfo(np.int32).max
41
  MAX_IMAGE_SIZE = 1216
42
+
43
+ # 追加: Simple long prompt processing function
44
+ def process_long_prompt(prompt, negative_prompt=""):
45
+ """Simple long prompt processing using Compel"""
46
+ try:
47
+ conditioning, pooled = compel([prompt, negative_prompt])
48
+ return conditioning, pooled
49
+ except Exception as e:
50
+ print(f"Long prompt processing failed: {e}, falling back to standard processing")
51
+ return None, None
52
 
53
  @spaces.GPU
54
  def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
55
+ # 変更: Remove the 60-word limit warning and add long prompt check
56
+ use_long_prompt = len(prompt.split()) > 60 or len(prompt) > 300
57
+
 
58
  if randomize_seed:
59
  seed = random.randint(0, MAX_SEED)
60
 
61
  generator = torch.Generator(device=device).manual_seed(seed)
62
 
63
  try:
64
+ # 追加: Try long prompt processing first if prompt is long
65
+ if use_long_prompt:
66
+ print("Using long prompt processing...")
67
+ conditioning, pooled = process_long_prompt(prompt, negative_prompt)
68
+
69
+ if conditioning is not None:
70
+ output_image = pipe(
71
+ prompt_embeds=conditioning[0:1],
72
+ pooled_prompt_embeds=pooled[0:1],
73
+ negative_prompt_embeds=conditioning[1:2],
74
+ negative_pooled_prompt_embeds=pooled[1:2],
75
+ guidance_scale=guidance_scale,
76
+ num_inference_steps=num_inference_steps,
77
+ width=width,
78
+ height=height,
79
+ generator=generator
80
+ ).images[0]
81
+ return output_image
82
+
83
+ # Fall back to standard processing
84
  output_image = pipe(
85
  prompt=prompt,
86
  negative_prompt=negative_prompt,
 
115
  label="Prompt",
116
  show_label=False,
117
  max_lines=1,
118
+ placeholder="Enter your prompt (long prompts are automatically supported)", # 変更: Updated placeholder
119
  container=False,
120
  )
121