Spaces: Runtime error

Commit 0e7e92c
Parent(s): b62a9c0
update

Files changed:
- app.py +2 -2
- autoregressive/models/generate.py +1 -1
- model.py +9 -6
app.py
CHANGED
@@ -54,8 +54,8 @@ hf_hub_download(repo_id="facebook/dinov2-small", filename="pytorch_model.bin", l
 DESCRIPTION = "# [ControlAR: Controllable Image Generation with Autoregressive Models](https://arxiv.org/abs/2410.02705) \n ### The first row in outputs is the input image and condition. The second row is the images generated by ControlAR. \n ### You can run locally by following the instruction on our [Github Repo](https://github.com/hustvl/ControlAR)."
 SHOW_DUPLICATE_BUTTON = os.getenv("SHOW_DUPLICATE_BUTTON") == "1"
 model = Model()
-device = "cuda"
-model.to(device)
+# device = "cuda"
+# model.to(device)
 with gr.Blocks(css="style.css") as demo:
     gr.Markdown(DESCRIPTION)
     gr.DuplicateButton(
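The two commented lines drop the module-level CUDA placement, so nothing touches the GPU at import time. Below is a minimal sketch of the deferred-placement pattern this points toward, assuming a hypothetical run() request handler and a generate() method (neither is shown in this diff), and assuming Model exposes the same .to() the old module-level call relied on: weights stay on CPU until a request arrives, which is the usual arrangement on Spaces hardware where the GPU may not be available when the app module is first imported.

import torch

model = Model()  # construct on CPU at import time; no GPU assumed yet

def run(image, prompt):
    # Hypothetical request handler: move to the GPU lazily, only when work arrives.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    with torch.inference_mode():
        return model.generate(image, prompt)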
autoregressive/models/generate.py
CHANGED
@@ -145,7 +145,7 @@ def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_int
     print(condition)
     condition = torch.ones_like(condition)
     condition = model.adapter_mlp(condition)
-    print(condition)
+    # print(condition)
     if model.model_type == 'c2i':
         if cfg_scale > 1.0:
             cond_null = torch.ones_like(cond) * model.num_classes
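The only change here silences a debug print. As a small aside, not from the repo: routing such traces through the standard logging module makes verbosity a runtime switch instead of an edit-and-redeploy; the logger name and helper below are illustrative.

import logging

logger = logging.getLogger("controlar.generate")

def trace_tensor(tag, tensor):
    # Emitted only when the app enables DEBUG, e.g. logging.basicConfig(level=logging.DEBUG)
    logger.debug("%s: shape=%s dtype=%s", tag,
                 getattr(tensor, "shape", None), getattr(tensor, "dtype", None))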
model.py
CHANGED
@@ -44,7 +44,7 @@ class Model:
 
     def __init__(self):
         self.device = torch.device(
-            "cuda
+            "cuda")
         self.base_model_id = ""
         self.task_name = ""
         self.vq_model = self.load_vq()
@@ -63,7 +63,7 @@ class Model:
     def load_vq(self):
         vq_model = VQ_models["VQ-16"](codebook_size=16384,
                                       codebook_embed_dim=8)
-        vq_model.to('cuda')
+        # vq_model.to('cuda')
         vq_model.eval()
         checkpoint = torch.load(f"checkpoints/vq_ds16_t2i.pt",
                                 map_location="cpu")
@@ -82,11 +82,13 @@ class Model:
             cls_token_num=120,
             model_type='t2i',
             condition_type=condition_type,
-        ).to(device='
+        ).to(device='cpu', dtype=precision)
 
         model_weight = load_file(gpt_ckpt)
-        #
+        # print("prev:", model_weight['adapter.model.embeddings.patch_embeddings.projection.weight'])
+        gpt_model.load_state_dict(model_weight, strict=True)
         gpt_model.eval()
+        print("loaded:", gpt_model.adapter.model.embeddings.patch_embeddings.projection.weight)
         print("gpt model is loaded")
         return gpt_model
 
@@ -121,8 +123,9 @@ class Model:
         image = resize_image_to_16_multiple(image, 'canny')
         W, H = image.size
         print(W, H)
-        self.
-        self.
+        print("before cuda", self.gpt_model_canny.adapter.model.embeddings.patch_embeddings.projection.weight)
+        self.t5_model.model.to('cuda')
+        self.gpt_model_canny.to('cuda')
 
         condition_img = self.get_control_canny(np.array(image), low_threshold,
                                                high_threshold)