Spaces:
Runtime error
Runtime error
liuyizhang
commited on
Commit
·
6e3e561
1
Parent(s):
1c6fec7
update app.py
Browse files
app.py
CHANGED
|
@@ -434,7 +434,9 @@ def concatenate_images_vertical(image1, image2):
|
|
| 434 |
|
| 435 |
return new_image
|
| 436 |
|
| 437 |
-
def relate_anything(
|
|
|
|
|
|
|
| 438 |
w, h = input_image.size
|
| 439 |
max_edge = 1500
|
| 440 |
if w > max_edge or h > max_edge:
|
|
@@ -442,12 +444,14 @@ def relate_anything(input_image, k):
|
|
| 442 |
new_size = (int(w / ratio), int(h / ratio))
|
| 443 |
input_image.thumbnail(new_size)
|
| 444 |
|
|
|
|
| 445 |
# load image
|
| 446 |
pil_image = input_image.convert('RGBA')
|
| 447 |
image = np.array(input_image)
|
| 448 |
sam_masks = sam_mask_generator.generate(image)
|
| 449 |
filtered_masks = sort_and_deduplicate(sam_masks)
|
| 450 |
|
|
|
|
| 451 |
feat_list = []
|
| 452 |
for fm in filtered_masks:
|
| 453 |
feat = torch.Tensor(fm['feat']).unsqueeze(0).unsqueeze(0).to(device)
|
|
@@ -455,6 +459,7 @@ def relate_anything(input_image, k):
|
|
| 455 |
feat = torch.cat(feat_list, dim=1).to(device)
|
| 456 |
matrix_output, rel_triplets = ram_model.predict(feat)
|
| 457 |
|
|
|
|
| 458 |
pil_image_list = []
|
| 459 |
for i, rel in enumerate(rel_triplets[:k]):
|
| 460 |
s,o,r = int(rel[0]),int(rel[1]),int(rel[2])
|
|
@@ -473,6 +478,7 @@ def relate_anything(input_image, k):
|
|
| 473 |
concate_pil_image = concatenate_images_vertical(current_pil_image, title_image)
|
| 474 |
pil_image_list.append(concate_pil_image)
|
| 475 |
|
|
|
|
| 476 |
yield pil_image_list
|
| 477 |
|
| 478 |
|
|
|
|
| 434 |
|
| 435 |
return new_image
|
| 436 |
|
| 437 |
+
def relate_anything(input_image_mask, k):
|
| 438 |
+
logger.info(f'relate_anything_1_')
|
| 439 |
+
input_image = input_image_mask['image']
|
| 440 |
w, h = input_image.size
|
| 441 |
max_edge = 1500
|
| 442 |
if w > max_edge or h > max_edge:
|
|
|
|
| 444 |
new_size = (int(w / ratio), int(h / ratio))
|
| 445 |
input_image.thumbnail(new_size)
|
| 446 |
|
| 447 |
+
logger.info(f'relate_anything_2_')
|
| 448 |
# load image
|
| 449 |
pil_image = input_image.convert('RGBA')
|
| 450 |
image = np.array(input_image)
|
| 451 |
sam_masks = sam_mask_generator.generate(image)
|
| 452 |
filtered_masks = sort_and_deduplicate(sam_masks)
|
| 453 |
|
| 454 |
+
logger.info(f'relate_anything_3_')
|
| 455 |
feat_list = []
|
| 456 |
for fm in filtered_masks:
|
| 457 |
feat = torch.Tensor(fm['feat']).unsqueeze(0).unsqueeze(0).to(device)
|
|
|
|
| 459 |
feat = torch.cat(feat_list, dim=1).to(device)
|
| 460 |
matrix_output, rel_triplets = ram_model.predict(feat)
|
| 461 |
|
| 462 |
+
logger.info(f'relate_anything_4_')
|
| 463 |
pil_image_list = []
|
| 464 |
for i, rel in enumerate(rel_triplets[:k]):
|
| 465 |
s,o,r = int(rel[0]),int(rel[1]),int(rel[2])
|
|
|
|
| 478 |
concate_pil_image = concatenate_images_vertical(current_pil_image, title_image)
|
| 479 |
pil_image_list.append(concate_pil_image)
|
| 480 |
|
| 481 |
+
logger.info(f'relate_anything_5_')
|
| 482 |
yield pil_image_list
|
| 483 |
|
| 484 |
|