Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -3,11 +3,13 @@ import torch
|
|
3 |
import numpy as np
|
4 |
from PIL import Image
|
5 |
import matplotlib.pyplot as plt
|
|
|
|
|
6 |
|
7 |
-
depth_anything = pipeline(task = "depth-estimation", model="nielsr/depth-anything-small", device=
|
8 |
checkpoint = "BAAI/seggpt-vit-large"
|
9 |
image_processor = SegGptImageProcessor.from_pretrained(checkpoint)
|
10 |
-
model = SegGptForImageSegmentation.from_pretrained(checkpoint)
|
11 |
|
12 |
def infer_seggpt(image_input, image_prompt, mask_prompt):
|
13 |
num_labels = 100
|
@@ -17,7 +19,7 @@ def infer_seggpt(image_input, image_prompt, mask_prompt):
|
|
17 |
prompt_masks=mask_prompt,
|
18 |
return_tensors="pt",
|
19 |
num_labels=num_labels
|
20 |
-
)
|
21 |
with torch.no_grad():
|
22 |
outputs = model(**inputs)
|
23 |
|
@@ -38,6 +40,7 @@ def infer_seggpt(image_input, image_prompt, mask_prompt):
|
|
38 |
plt.savefig("masks.png", bbox_inches='tight', pad_inches=0)
|
39 |
return "masks.png"
|
40 |
|
|
|
41 |
def infer(image_input, image_prompt, mask_prompt):
|
42 |
sg_masks = []
|
43 |
mask_prompt = depth_anything(image_prompt)["depth"].convert("RGB")
|
|
|
3 |
import numpy as np
|
4 |
from PIL import Image
|
5 |
import matplotlib.pyplot as plt
|
6 |
+
import spaces
|
7 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
8 |
|
9 |
+
depth_anything = pipeline(task = "depth-estimation", model="nielsr/depth-anything-small", device=device)
|
10 |
checkpoint = "BAAI/seggpt-vit-large"
|
11 |
image_processor = SegGptImageProcessor.from_pretrained(checkpoint)
|
12 |
+
model = SegGptForImageSegmentation.from_pretrained(checkpoint).to(device)
|
13 |
|
14 |
def infer_seggpt(image_input, image_prompt, mask_prompt):
|
15 |
num_labels = 100
|
|
|
19 |
prompt_masks=mask_prompt,
|
20 |
return_tensors="pt",
|
21 |
num_labels=num_labels
|
22 |
+
).to(device)
|
23 |
with torch.no_grad():
|
24 |
outputs = model(**inputs)
|
25 |
|
|
|
40 |
plt.savefig("masks.png", bbox_inches='tight', pad_inches=0)
|
41 |
return "masks.png"
|
42 |
|
43 |
+
@spaces.GPU
|
44 |
def infer(image_input, image_prompt, mask_prompt):
|
45 |
sg_masks = []
|
46 |
mask_prompt = depth_anything(image_prompt)["depth"].convert("RGB")
|