Update app.py
app.py CHANGED
@@ -15,7 +15,7 @@ from transformers import (
     CLIPTextModelWithProjection,
     CLIPVisionModelWithProjection,
     CLIPImageProcessor,
-    CLIPTokenizer
+    CLIPTokenizer,
 )
 
 from transformers import CLIPTokenizer
@@ -33,10 +33,11 @@ if torch.cuda.is_available():
     __device__ = "cuda"
     __dtype__ = torch.float16
 
+
 class Model:
     def __init__(self):
         self.device = __device__
-
+
         self.text_encoder = (
             CLIPTextModelWithProjection.from_pretrained(
                 "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
@@ -65,102 +66,48 @@ class Model:
         self.pipe = DiffusionPipeline.from_pretrained(
             "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=__dtype__
         ).to(self.device)
-
-    def inference(self, raw_data):
+
+    def inference(self, raw_data, seed):
+        generator = torch.Generator(device="cuda").manual_seed(seed)
         image_emb, negative_image_emb = self.pipe_prior(
             raw_data=raw_data,
+            generator=generator,
         ).to_tuple()
         image = self.pipe(
             image_embeds=image_emb,
             negative_image_embeds=negative_image_emb,
             num_inference_steps=50,
-            guidance_scale=
+            guidance_scale=7.5,
+            generator=generator,
         ).images[0]
         return image
-
-    def
-
-
-
-
-
-
-
-
-        data: dict[str, Any] = {}
-        data['text'] = text
-
-        txt = self.tokenizer(
-            text,
-            padding='max_length',
-            truncation=True,
-            return_tensors='pt',
-        )
-        txt_items = {k: v.to(device) for k, v in txt.items()}
-        new_feats = self.text_encoder(**txt_items)
-        new_last_hidden_states = new_feats.last_hidden_state[0].cpu().numpy()
-
-        plt.imshow(image)
-        plt.title('image')
-        plt.savefig('image_testt2.png')
-        plt.show()
-
-        mask_img = self.image_processor(image, return_tensors="pt").to(__device__)
-        vision_feats = self.vision_encoder(
-            **mask_img
-        ).image_embeds
-
-        entity_tokens = self.tokenizer(keyword)["input_ids"][1:-1]
-        for tid in entity_tokens:
-            indices = np.where(txt_items["input_ids"][0].cpu().numpy() == tid)[0]
-            new_last_hidden_states[indices] = vision_feats[0].cpu().numpy()
-            print(indices)
-
-        if image2 is not None:
-            mask_img2 = self.image_processor(image2, return_tensors="pt").to(__device__)
-            vision_feats2 = self.vision_encoder(
-                **mask_img2
-            ).image_embeds
-            if keyword2 is not None:
-                entity_tokens = self.tokenizer(keyword2)["input_ids"][1:-1]
-                for tid in entity_tokens:
-                    indices = np.where(txt_items["input_ids"][0].cpu().numpy() == tid)[0]
-                    new_last_hidden_states[indices] = vision_feats2[0].cpu().numpy()
-                    print(indices)
-
-        text_feats = {
-            "prompt_embeds": new_feats.text_embeds.to(__device__),
-            "text_encoder_hidden_states": torch.tensor(new_last_hidden_states).unsqueeze(0).to(__device__),
-            "text_mask": txt_items["attention_mask"].to(__device__),
-        }
-        return text_feats
-
-    def run(self,
-        image: dict[str, PIL.Image.Image],
-        keyword: str,
-        image2: dict[str, PIL.Image.Image],
-        keyword2: str,
-        text: str,
-    ):
-
-        # aug_feats = self.process_data(image["composite"], keyword, image2["composite"], keyword2, text)
+
+    def run(
+        self,
+        image: dict[str, PIL.Image.Image],
+        keyword: str,
+        image2: dict[str, PIL.Image.Image],
+        keyword2: str,
+        text: str,
+        seed: int,
+    ):
         sub_imgs = [image["composite"]]
-        if image2:
-            sub_imgs.append(image2["composite"])
         sun_keywords = [keyword]
-        if keyword2:
+        if keyword2 and keyword2 != "no subject":
             sun_keywords.append(keyword2)
+        if image2:
+            sub_imgs.append(image2["composite"])
         raw_data = {
             "prompt": text,
             "subject_images": sub_imgs,
-            "subject_keywords": sun_keywords
+            "subject_keywords": sun_keywords,
         }
-        image = self.inference(raw_data)
+        image = self.inference(raw_data, seed)
         return image
 
-def create_demo():
 
-
+def create_demo():
+    USAGE = """## To run the demo, you should:
 1. Upload your image.
 2. <span style='color: red;'>**Upload a masked subject image with white blankspace or whiten out manually using brush tool.**
 3. Input a Keyword i.e. 'Dog'
@@ -169,7 +116,7 @@ def create_demo():
 4-2. Input the Keyword i.e. 'Sunglasses'
 3. Input proper text prompts, such as "A photo of Dog" or "A Dog wearing sunglasses", Please use the same keyword in the prompt.
 4. Click the Run button.
-
+    """
 
     model = Model()
 
@@ -180,6 +127,8 @@ def create_demo():
 
 <p style="text-align: center; color: red;">This demo is currently hosted on either a small GPU or CPU. We will soon provide high-end GPU support.</p>
 <p style="text-align: center; color: red;">Please follow the instructions from here to run it locally: <a href="https://github.com/eclipse-t2i/lambda-eclipse-inference">GitHub Inference Code</a></p>
+
+<a href="https://colab.research.google.com/drive/1VcqzXZmilntec3AsIyzCqlstEhX4Pa1o?usp=sharing" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
             """
         )
         gr.Markdown(USAGE)
@@ -187,28 +136,41 @@ def create_demo():
             with gr.Column():
                 with gr.Group():
                     gr.Markdown(
-
-
+                        "Upload your first masked subject image or mask out marginal space"
+                    )
+                    image = gr.ImageEditor(
+                        label="Input",
+                        type="pil",
+                        brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"),
+                    )
                     keyword = gr.Text(
-                        label=
+                        label="Keyword",
                         placeholder='e.g. "Dog", "Goofie"',
-                        info=
+                        info="Keyword for first subject",
+                    )
                     gr.Markdown(
-
-
-
-                        label=
+                        "For Multi-Subject generation : Upload your second masked subject image or mask out marginal space"
+                    )
+                    image2 = gr.ImageEditor(
+                        label="Input",
+                        type="pil",
+                        brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"),
+                    )
+                    keyword2 = gr.Text(
+                        label="Keyword",
                         placeholder='e.g. "Sunglasses", "Grand Canyon"',
-                        info=
+                        info="Keyword for second subject",
+                    )
                     prompt = gr.Text(
-                        label=
+                        label="Prompt",
                         placeholder='e.g. "A photo of dog", "A dog wearing sunglasses"',
-                        info=
+                        info="Keep the keywords used previously in the prompt",
+                    )
 
-                run_button = gr.Button(
+                run_button = gr.Button("Run")
 
             with gr.Column():
-                result = gr.Image(label=
+                result = gr.Image(label="Result")
 
         inputs = [
             image,
@@ -217,18 +179,77 @@ def create_demo():
             keyword2,
             prompt,
         ]
-
+
         gr.Examples(
-            examples=[
-
+            examples=[
+                [
+                    os.path.join(os.path.dirname(__file__), "./assets/luffy.jpg"),
+                    "luffy",
+                    os.path.join(os.path.dirname(__file__), "./assets/white.jpg"),
+                    "no subject",
+                    "luffy holding a sword",
+                ],
+                [
+                    os.path.join(os.path.dirname(__file__), "./assets/luffy.jpg"),
+                    "luffy",
+                    os.path.join(os.path.dirname(__file__), "./assets/white.jpg"),
+                    "no subject",
+                    "luffy in the living room",
+                ],
+                [
+                    os.path.join(os.path.dirname(__file__), "./assets/teapot.jpg"),
+                    "teapot",
+                    os.path.join(os.path.dirname(__file__), "./assets/white.jpg"),
+                    "no subject",
+                    "teapot on a cobblestone street",
+                ],
+                [
+                    os.path.join(os.path.dirname(__file__), "./assets/trex.jpg"),
+                    "trex",
+                    os.path.join(os.path.dirname(__file__), "./assets/white.jpg"),
+                    "no subject",
+                    "trex near a river",
+                ],
+                [
+                    os.path.join(os.path.dirname(__file__), "./assets/cat.png"),
+                    "cat",
+                    os.path.join(
+                        os.path.dirname(__file__), "./assets/blue_sunglasses.png"
+                    ),
+                    "glasses",
+                    "A cat wearing glasses on a snowy field",
+                ],
+                [
+                    os.path.join(os.path.dirname(__file__), "./assets/statue.jpg"),
+                    "statue",
+                    os.path.join(os.path.dirname(__file__), "./assets/toilet.jpg"),
+                    "toilet",
+                    "statue sitting on a toilet",
+                ],
+                [
+                    os.path.join(os.path.dirname(__file__), "./assets/teddy.jpg"),
+                    "teddy",
+                    os.path.join(os.path.dirname(__file__), "./assets/luffy_hat.jpg"),
+                    "hat",
+                    "a teddy wearing the hat at a beach",
+                ],
+                [
+                    os.path.join(os.path.dirname(__file__), "./assets/chair.jpg"),
+                    "chair",
+                    os.path.join(os.path.dirname(__file__), "./assets/table.jpg"),
+                    "table",
+                    "a chair and table in living room",
+                ],
+            ],
+            inputs=inputs,
             fn=model.run,
             outputs=result,
         )
-
+
         run_button.click(fn=model.run, inputs=inputs, outputs=result)
     return demo
 
 
-if __name__ ==
+if __name__ == "__main__":
     demo = create_demo()
-    demo.queue(max_size=20).launch()
+    demo.queue(max_size=20).launch()
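For reference, a minimal sketch of driving the updated Model.run signature directly, without the Gradio UI. The module name `app` and the local asset paths are assumptions for illustration, not part of the commit; what the diff does show is that the new `seed` argument is threaded into a torch.Generator, so repeated calls with the same seed should reproduce the same image.

# Sketch only: assumes this Space's app.py is importable as `app` and that the
# ./assets images exist locally; both are assumptions, not part of the commit.
import PIL.Image
from app import Model

model = Model()  # loads the prior and the Kandinsky 2.2 decoder onto the GPU

# run() expects ImageEditor-style dicts with a "composite" PIL image.
subject = {"composite": PIL.Image.open("./assets/luffy.jpg").convert("RGB")}
# A plain white image plus the keyword "no subject" mirrors the single-subject examples.
blank = {"composite": PIL.Image.open("./assets/white.jpg").convert("RGB")}

# Same seed -> same torch.Generator state inside inference() -> reproducible output.
result = model.run(
    image=subject,
    keyword="luffy",
    image2=blank,
    keyword2="no subject",
    text="luffy holding a sword",
    seed=42,
)
result.save("luffy_sword.png")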