Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -421,7 +421,6 @@ def import_state(state, json_text):
|
|
| 421 |
|
| 422 |
### Main worker
|
| 423 |
|
| 424 |
-
|
| 425 |
def register(state, drawpad, model):
|
| 426 |
seed_everything(state.seed if state.seed >=0 else np.random.randint(2147483647))
|
| 427 |
print('Generate!')
|
|
@@ -436,13 +435,13 @@ def register(state, drawpad, model):
|
|
| 436 |
print('Inpainting mode: ', inpainting_mode)
|
| 437 |
|
| 438 |
user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
|
| 439 |
-
foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
|
| 440 |
-
user_input = torch.tensor(user_input[..., :-1]) # (H, W, 3)
|
| 441 |
|
| 442 |
palette = torch.tensor([
|
| 443 |
tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
|
| 444 |
for s in opt.colors[1:]
|
| 445 |
-
]) # (N, 3)
|
| 446 |
masks = (palette[:, None, None, :] == user_input[None]).all(dim=-1)[:, None, ...] # (N, 1, H, W)
|
| 447 |
# has_masks = [i for i, m in enumerate(masks.sum(dim=(1, 2, 3)) == 0) if not m]
|
| 448 |
has_masks = list(range(opt.max_palettes))
|
|
@@ -542,13 +541,13 @@ def draw(state, drawpad):
|
|
| 542 |
# conn = Client(opt.address, authkey=opt.authkey)
|
| 543 |
|
| 544 |
user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
|
| 545 |
-
foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
|
| 546 |
-
user_input = torch.tensor(user_input[..., :-1]) # (H, W, 3)
|
| 547 |
|
| 548 |
palette = torch.tensor([
|
| 549 |
tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
|
| 550 |
for s in opt.colors[1:]
|
| 551 |
-
]) # (N, 3)
|
| 552 |
masks = (palette[:, None, None, :] == user_input[None]).all(dim=-1)[:, None, ...] # (N, 1, H, W)
|
| 553 |
# has_masks = [i for i, m in enumerate(masks.sum(dim=(1, 2, 3)) == 0) if not m]
|
| 554 |
has_masks = list(range(opt.max_palettes))
|
|
|
|
| 421 |
|
| 422 |
### Main worker
|
| 423 |
|
|
|
|
| 424 |
def register(state, drawpad, model):
|
| 425 |
seed_everything(state.seed if state.seed >=0 else np.random.randint(2147483647))
|
| 426 |
print('Generate!')
|
|
|
|
| 435 |
print('Inpainting mode: ', inpainting_mode)
|
| 436 |
|
| 437 |
user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
|
| 438 |
+
foreground_mask = torch.tensor(user_input[..., -1], device=model.device)[None, None] # (1, 1, H, W)
|
| 439 |
+
user_input = torch.tensor(user_input[..., :-1], device=model.device) # (H, W, 3)
|
| 440 |
|
| 441 |
palette = torch.tensor([
|
| 442 |
tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
|
| 443 |
for s in opt.colors[1:]
|
| 444 |
+
], device=model.device) # (N, 3)
|
| 445 |
masks = (palette[:, None, None, :] == user_input[None]).all(dim=-1)[:, None, ...] # (N, 1, H, W)
|
| 446 |
# has_masks = [i for i, m in enumerate(masks.sum(dim=(1, 2, 3)) == 0) if not m]
|
| 447 |
has_masks = list(range(opt.max_palettes))
|
|
|
|
| 541 |
# conn = Client(opt.address, authkey=opt.authkey)
|
| 542 |
|
| 543 |
user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
|
| 544 |
+
foreground_mask = torch.tensor(user_input[..., -1], device=model.device)[None, None] # (1, 1, H, W)
|
| 545 |
+
user_input = torch.tensor(user_input[..., :-1], device=model.device) # (H, W, 3)
|
| 546 |
|
| 547 |
palette = torch.tensor([
|
| 548 |
tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
|
| 549 |
for s in opt.colors[1:]
|
| 550 |
+
], device=model.device) # (N, 3)
|
| 551 |
masks = (palette[:, None, None, :] == user_input[None]).all(dim=-1)[:, None, ...] # (N, 1, H, W)
|
| 552 |
# has_masks = [i for i, m in enumerate(masks.sum(dim=(1, 2, 3)) == 0) if not m]
|
| 553 |
has_masks = list(range(opt.max_palettes))
|