Spaces:
Runtime error
Runtime error
Commit
·
876dc56
1
Parent(s):
fc81a43
update
Browse files- .gitignore +5 -0
- app.py +3 -2
- app_canny.py +12 -12
- autoregressive/models/gpt_t2i.py +0 -2
- checkpoints/flan-t5-xl/flan-t5-xl/spiece.model +3 -0
- language/t5.py +6 -4
- model.py +11 -5
.gitignore
CHANGED
|
@@ -154,6 +154,11 @@ dmypy.json
|
|
| 154 |
# Cython debug symbols
|
| 155 |
cython_debug/
|
| 156 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
# PyCharm
|
| 158 |
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 159 |
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
|
|
|
| 154 |
# Cython debug symbols
|
| 155 |
cython_debug/
|
| 156 |
|
| 157 |
+
*.safetensors
|
| 158 |
+
*.lock
|
| 159 |
+
*.bin
|
| 160 |
+
*.pt
|
| 161 |
+
*.json
|
| 162 |
# PyCharm
|
| 163 |
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 164 |
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
app.py
CHANGED
|
@@ -18,6 +18,7 @@ DESCRIPTION = "# [ControlAR: Controllable Image Generation with Autoregressive M
|
|
| 18 |
SHOW_DUPLICATE_BUTTON = os.getenv("SHOW_DUPLICATE_BUTTON") == "1"
|
| 19 |
model = Model()
|
| 20 |
device = "cuda"
|
|
|
|
| 21 |
with gr.Blocks(css="style.css") as demo:
|
| 22 |
gr.Markdown(DESCRIPTION)
|
| 23 |
gr.DuplicateButton(
|
|
@@ -26,8 +27,8 @@ with gr.Blocks(css="style.css") as demo:
|
|
| 26 |
visible=SHOW_DUPLICATE_BUTTON,
|
| 27 |
)
|
| 28 |
with gr.Tabs():
|
| 29 |
-
with gr.TabItem("Depth"):
|
| 30 |
-
|
| 31 |
with gr.TabItem("Canny"):
|
| 32 |
create_demo_canny(model.process_canny)
|
| 33 |
|
|
|
|
| 18 |
SHOW_DUPLICATE_BUTTON = os.getenv("SHOW_DUPLICATE_BUTTON") == "1"
|
| 19 |
model = Model()
|
| 20 |
device = "cuda"
|
| 21 |
+
model.to(device)
|
| 22 |
with gr.Blocks(css="style.css") as demo:
|
| 23 |
gr.Markdown(DESCRIPTION)
|
| 24 |
gr.DuplicateButton(
|
|
|
|
| 27 |
visible=SHOW_DUPLICATE_BUTTON,
|
| 28 |
)
|
| 29 |
with gr.Tabs():
|
| 30 |
+
# with gr.TabItem("Depth"):
|
| 31 |
+
# create_demo_depth(model.process_depth)
|
| 32 |
with gr.TabItem("Canny"):
|
| 33 |
create_demo_canny(model.process_canny)
|
| 34 |
|
app_canny.py
CHANGED
|
@@ -104,18 +104,18 @@ def create_demo(process):
|
|
| 104 |
canny_low_threshold,
|
| 105 |
canny_high_threshold,
|
| 106 |
]
|
| 107 |
-
prompt.submit(
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
).then(
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
)
|
| 119 |
run_button.click(
|
| 120 |
fn=randomize_seed_fn,
|
| 121 |
inputs=[seed, randomize_seed],
|
|
|
|
| 104 |
canny_low_threshold,
|
| 105 |
canny_high_threshold,
|
| 106 |
]
|
| 107 |
+
# prompt.submit(
|
| 108 |
+
# fn=randomize_seed_fn,
|
| 109 |
+
# inputs=[seed, randomize_seed],
|
| 110 |
+
# outputs=seed,
|
| 111 |
+
# queue=False,
|
| 112 |
+
# api_name=False,
|
| 113 |
+
# ).then(
|
| 114 |
+
# fn=process,
|
| 115 |
+
# inputs=inputs,
|
| 116 |
+
# outputs=result,
|
| 117 |
+
# api_name=False,
|
| 118 |
+
# )
|
| 119 |
run_button.click(
|
| 120 |
fn=randomize_seed_fn,
|
| 121 |
inputs=[seed, randomize_seed],
|
autoregressive/models/gpt_t2i.py
CHANGED
|
@@ -375,8 +375,6 @@ class Transformer(nn.Module):
|
|
| 375 |
# Zero-out output layers:
|
| 376 |
nn.init.constant_(self.output.weight, 0)
|
| 377 |
|
| 378 |
-
|
| 379 |
-
|
| 380 |
def _init_weights(self, module):
|
| 381 |
std = self.config.initializer_range
|
| 382 |
if isinstance(module, nn.Linear):
|
|
|
|
| 375 |
# Zero-out output layers:
|
| 376 |
nn.init.constant_(self.output.weight, 0)
|
| 377 |
|
|
|
|
|
|
|
| 378 |
def _init_weights(self, module):
|
| 379 |
std = self.config.initializer_range
|
| 380 |
if isinstance(module, nn.Linear):
|
checkpoints/flan-t5-xl/flan-t5-xl/spiece.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
|
| 3 |
+
size 791656
|
language/t5.py
CHANGED
|
@@ -18,7 +18,7 @@ class T5Embedder:
|
|
| 18 |
|
| 19 |
def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, local_cache=False, cache_dir=None, hf_token=None, use_text_preprocessing=True,
|
| 20 |
t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None, model_max_length=120):
|
| 21 |
-
self.device = torch.device(
|
| 22 |
self.torch_dtype = torch_dtype or torch.bfloat16
|
| 23 |
if t5_model_kwargs is None:
|
| 24 |
t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype}
|
|
@@ -53,6 +53,7 @@ class T5Embedder:
|
|
| 53 |
print(tokenizer_path)
|
| 54 |
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
| 55 |
self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()
|
|
|
|
| 56 |
self.model_max_length = model_max_length
|
| 57 |
|
| 58 |
def get_text_embeddings(self, texts):
|
|
@@ -72,11 +73,12 @@ class T5Embedder:
|
|
| 72 |
text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask']
|
| 73 |
|
| 74 |
with torch.no_grad():
|
|
|
|
| 75 |
text_encoder_embs = self.model(
|
| 76 |
-
input_ids=text_tokens_and_mask['input_ids'].to(self.device),
|
| 77 |
-
attention_mask=text_tokens_and_mask['attention_mask'].to(self.device),
|
| 78 |
)['last_hidden_state'].detach()
|
| 79 |
-
return text_encoder_embs, text_tokens_and_mask['attention_mask'].to(self.device)
|
| 80 |
|
| 81 |
def text_preprocessing(self, text):
|
| 82 |
if self.use_text_preprocessing:
|
|
|
|
| 18 |
|
| 19 |
def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, local_cache=False, cache_dir=None, hf_token=None, use_text_preprocessing=True,
|
| 20 |
t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None, model_max_length=120):
|
| 21 |
+
self.device = torch.device('cuda:0')
|
| 22 |
self.torch_dtype = torch_dtype or torch.bfloat16
|
| 23 |
if t5_model_kwargs is None:
|
| 24 |
t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype}
|
|
|
|
| 53 |
print(tokenizer_path)
|
| 54 |
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
| 55 |
self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()
|
| 56 |
+
self.model.to('cuda')
|
| 57 |
self.model_max_length = model_max_length
|
| 58 |
|
| 59 |
def get_text_embeddings(self, texts):
|
|
|
|
| 73 |
text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask']
|
| 74 |
|
| 75 |
with torch.no_grad():
|
| 76 |
+
print("t5:", self.model.device)
|
| 77 |
text_encoder_embs = self.model(
|
| 78 |
+
input_ids=text_tokens_and_mask['input_ids'].to(self.model.device),
|
| 79 |
+
attention_mask=text_tokens_and_mask['attention_mask'].to(self.model.device),
|
| 80 |
)['last_hidden_state'].detach()
|
| 81 |
+
return text_encoder_embs, text_tokens_and_mask['attention_mask'].to(self.model.device)
|
| 82 |
|
| 83 |
def text_preprocessing(self, text):
|
| 84 |
if self.use_text_preprocessing:
|
model.py
CHANGED
|
@@ -40,7 +40,7 @@ class Model:
|
|
| 40 |
|
| 41 |
def __init__(self):
|
| 42 |
self.device = torch.device(
|
| 43 |
-
"cuda:0"
|
| 44 |
self.base_model_id = ""
|
| 45 |
self.task_name = ""
|
| 46 |
self.vq_model = self.load_vq()
|
|
@@ -48,12 +48,17 @@ class Model:
|
|
| 48 |
self.gpt_model_canny = self.load_gpt(condition_type='canny')
|
| 49 |
self.gpt_model_depth = self.load_gpt(condition_type='depth')
|
| 50 |
self.get_control_canny = CannyDetector()
|
| 51 |
-
self.get_control_depth = MidasDetector(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
def load_vq(self):
|
| 54 |
vq_model = VQ_models["VQ-16"](codebook_size=16384,
|
| 55 |
codebook_embed_dim=8)
|
| 56 |
-
vq_model.to(
|
| 57 |
vq_model.eval()
|
| 58 |
checkpoint = torch.load(f"checkpoints/vq_ds16_t2i.pt",
|
| 59 |
map_location="cpu")
|
|
@@ -71,7 +76,7 @@ class Model:
|
|
| 71 |
cls_token_num=120,
|
| 72 |
model_type='t2i',
|
| 73 |
condition_type=condition_type,
|
| 74 |
-
).to(device=
|
| 75 |
|
| 76 |
model_weight = load_file(gpt_ckpt)
|
| 77 |
gpt_model.load_state_dict(model_weight, strict=False)
|
|
@@ -82,7 +87,7 @@ class Model:
|
|
| 82 |
def load_t5(self):
|
| 83 |
precision = torch.bfloat16
|
| 84 |
t5_model = T5Embedder(
|
| 85 |
-
device=
|
| 86 |
local_cache=False,
|
| 87 |
cache_dir='checkpoints/flan-t5-xl',
|
| 88 |
dir_or_name='flan-t5-xl',
|
|
@@ -134,6 +139,7 @@ class Model:
|
|
| 134 |
c_emb_masks = new_emb_masks
|
| 135 |
qzshape = [len(c_indices), 8, H // 16, W // 16]
|
| 136 |
t1 = time.time()
|
|
|
|
| 137 |
index_sample = generate(
|
| 138 |
self.gpt_model_canny,
|
| 139 |
c_indices,
|
|
|
|
| 40 |
|
| 41 |
def __init__(self):
|
| 42 |
self.device = torch.device(
|
| 43 |
+
"cuda:0")
|
| 44 |
self.base_model_id = ""
|
| 45 |
self.task_name = ""
|
| 46 |
self.vq_model = self.load_vq()
|
|
|
|
| 48 |
self.gpt_model_canny = self.load_gpt(condition_type='canny')
|
| 49 |
self.gpt_model_depth = self.load_gpt(condition_type='depth')
|
| 50 |
self.get_control_canny = CannyDetector()
|
| 51 |
+
self.get_control_depth = MidasDetector('cuda')
|
| 52 |
+
|
| 53 |
+
def to(self, device):
|
| 54 |
+
self.gpt_model_canny.to('cuda')
|
| 55 |
+
print(next(self.gpt_model_canny.adapter.parameters()).device)
|
| 56 |
+
# print(self.gpt_model_canny.device)
|
| 57 |
|
| 58 |
def load_vq(self):
|
| 59 |
vq_model = VQ_models["VQ-16"](codebook_size=16384,
|
| 60 |
codebook_embed_dim=8)
|
| 61 |
+
vq_model.to('cuda')
|
| 62 |
vq_model.eval()
|
| 63 |
checkpoint = torch.load(f"checkpoints/vq_ds16_t2i.pt",
|
| 64 |
map_location="cpu")
|
|
|
|
| 76 |
cls_token_num=120,
|
| 77 |
model_type='t2i',
|
| 78 |
condition_type=condition_type,
|
| 79 |
+
).to(device='cuda', dtype=precision)
|
| 80 |
|
| 81 |
model_weight = load_file(gpt_ckpt)
|
| 82 |
gpt_model.load_state_dict(model_weight, strict=False)
|
|
|
|
| 87 |
def load_t5(self):
|
| 88 |
precision = torch.bfloat16
|
| 89 |
t5_model = T5Embedder(
|
| 90 |
+
device="cuda",
|
| 91 |
local_cache=False,
|
| 92 |
cache_dir='checkpoints/flan-t5-xl',
|
| 93 |
dir_or_name='flan-t5-xl',
|
|
|
|
| 139 |
c_emb_masks = new_emb_masks
|
| 140 |
qzshape = [len(c_indices), 8, H // 16, W // 16]
|
| 141 |
t1 = time.time()
|
| 142 |
+
print(caption_embs.device)
|
| 143 |
index_sample = generate(
|
| 144 |
self.gpt_model_canny,
|
| 145 |
c_indices,
|