Update app.py
app.py CHANGED
@@ -1,307 +1,78 @@
-import os
-import sys
-from pathlib import Path
+import torch
+from ram import get_transform, inference_ram, inference_tag2text
+from ram.models import ram, tag2text_caption
 
-# setup Grouded-Segment-Anything
-# building GroundingDINO requires torch but imports it before installing,
-# so directly installing in requirements.txt causes dependency error.
-# 1. build with "-e" option to keep the bin file in ./GroundingDINO/groundingdino/, rather than in site-package dir.
-os.system("pip install -e ./GroundingDINO/")
-# 2. for unknown reason, "import groundingdino" will fill due to unable to find the module, even after installing.
-# add ./GroundingDINO/ to PATH, so package "groundingdino" can be imported.
-sys.path.append(str(Path(__file__).parent / "GroundingDINO"))
-
-import random  # noqa: E402
-
-import cv2  # noqa: E402
-import groundingdino.datasets.transforms as T  # noqa: E402
-import numpy as np  # noqa: E402
-import torch  # noqa: E402
-import torchvision  # noqa: E402
-import torchvision.transforms as TS  # noqa: E402
-from groundingdino.models import build_model  # noqa: E402
-from groundingdino.util.slconfig import SLConfig  # noqa: E402
-from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap  # noqa: E402
-from PIL import Image, ImageDraw, ImageFont  # noqa: E402
-from ram import inference_ram  # noqa: E402
-from ram import inference_tag2text  # noqa: E402
-from ram.models import ram  # noqa: E402
-from ram.models import tag2text_caption  # noqa: E402
-from segment_anything import SamPredictor, build_sam  # noqa: E402
-
-
-# args
-config_file = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
 ram_checkpoint = "./ram_swin_large_14m.pth"
 tag2text_checkpoint = "./tag2text_swin_14m.pth"
-grounded_checkpoint = "./groundingdino_swint_ogc.pth"
-sam_checkpoint = "./sam_vit_h_4b8939.pth"
-box_threshold = 0.25
-text_threshold = 0.2
-iou_threshold = 0.5
-device = "cpu"
-
-
-def load_model(model_config_path, model_checkpoint_path, device):
-    args = SLConfig.fromfile(model_config_path)
-    args.device = device
-    model = build_model(args)
-    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
-    load_res = model.load_state_dict(
-        clean_state_dict(checkpoint["model"]), strict=False)
-    print(load_res)
-    _ = model.eval()
-    return model
-
-
-def get_grounding_output(model, image, caption, box_threshold, text_threshold, device="cpu"):
-    caption = caption.lower()
-    caption = caption.strip()
-    if not caption.endswith("."):
-        caption = caption + "."
-    model = model.to(device)
-    image = image.to(device)
-    with torch.no_grad():
-        outputs = model(image[None], captions=[caption])
-    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
-    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)
-    logits.shape[0]
-
-    # filter output
-    logits_filt = logits.clone()
-    boxes_filt = boxes.clone()
-    filt_mask = logits_filt.max(dim=1)[0] > box_threshold
-    logits_filt = logits_filt[filt_mask]  # num_filt, 256
-    boxes_filt = boxes_filt[filt_mask]  # num_filt, 4
-    logits_filt.shape[0]
-
-    # get phrase
-    tokenlizer = model.tokenizer
-    tokenized = tokenlizer(caption)
-    # build pred
-    pred_phrases = []
-    scores = []
-    for logit, box in zip(logits_filt, boxes_filt):
-        pred_phrase = get_phrases_from_posmap(
-            logit > text_threshold, tokenized, tokenlizer)
-        pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
-        scores.append(logit.max().item())
-
-    return boxes_filt, torch.Tensor(scores), pred_phrases
-
-
-def draw_mask(mask, draw, random_color=False):
-    if random_color:
-        color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 153)
-    else:
-        color = (30, 144, 255, 153)
-
-    nonzero_coords = np.transpose(np.nonzero(mask))
-
-    for coord in nonzero_coords:
-        draw.point(coord[::-1], fill=color)
-
-
-def draw_box(box, draw, label):
-    # random color
-    color = tuple(np.random.randint(0, 255, size=3).tolist())
-    line_width = int(max(4, min(20, 0.006*max(draw.im.size))))
-    draw.rectangle(((box[0], box[1]), (box[2], box[3])), outline=color, width=line_width)
-
-    if label:
-        font_path = os.path.join(
-            cv2.__path__[0], 'qt', 'fonts', 'DejaVuSans.ttf')
-        font_size = int(max(12, min(60, 0.02*max(draw.im.size))))
-        font = ImageFont.truetype(font_path, size=font_size)
-        if hasattr(font, "getbbox"):
-            bbox = draw.textbbox((box[0], box[1]), str(label), font)
-        else:
-            w, h = draw.textsize(str(label), font)
-            bbox = (box[0], box[1], w + box[0], box[1] + h)
-        draw.rectangle(bbox, fill=color)
-        draw.text((box[0], box[1]), str(label), fill="white", font=font)
-
-        draw.text((box[0], box[1]), label, font=font)
+image_size = 384
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 @torch.no_grad()
-def inference(
-    raw_image, specified_tags, do_det_seg,
-    tagging_model_type, tagging_model, grounding_dino_model, sam_model
-):
+def inference(raw_image, specified_tags, tagging_model_type, tagging_model, transform):
     print(f"Start processing, image size {raw_image.size}")
-    raw_image = raw_image.convert("RGB")
-
-    # run tagging model
-    normalize = TS.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-    transform = TS.Compose([
-        TS.Resize((384, 384)),
-        TS.ToTensor(),
-        normalize
-    ])
 
-    image = raw_image.resize((384, 384))
-    image = transform(image).unsqueeze(0).to(device)
+    image = transform(raw_image).unsqueeze(0).to(device)
 
-    # Currently ", " is better for detecting single tags
-    # while ". " is a little worse in some case
     if tagging_model_type == "RAM":
         res = inference_ram(image, tagging_model)
-        tags = res[0].strip(' ').replace('  ', ' ')
-        tags_chinese = res[1].strip(' ').replace('  ', ' ')
+        tags = res[0].strip(' ').replace('  ', ' ')
+        tags_chinese = res[1].strip(' ').replace('  ', ' ')
         print("Tags: ", tags)
-        print("标签: ", tags_chinese)
+        print("标签: ", tags_chinese)
+        return tags, tags_chinese
     else:
         res = inference_tag2text(image, tagging_model, specified_tags)
-        tags = res[0].strip(' ').replace('  ', ' ')
+        tags = res[0].strip(' ').replace('  ', ' ')
         caption = res[2]
         print(f"Tags: {tags}")
        print(f"Caption: {caption}")
+        return tags, caption
 
-    # return
-    if not do_det_seg:
-        if tagging_model_type == "RAM":
-            return tags.replace(", ", " | "), tags_chinese.replace(", ", " | "), None
-        else:
-            return tags.replace(", ", " | "), caption, None
-
-    # run groundingDINO
-    transform = T.Compose([
-        T.RandomResize([800], max_size=1333),
-        T.ToTensor(),
-        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
-    ])
-
-    image, _ = transform(raw_image, None)  # 3, h, w
-
-    boxes_filt, scores, pred_phrases = get_grounding_output(
-        grounding_dino_model, image, tags, box_threshold, text_threshold, device=device
-    )
-    print("GroundingDINO finished")
-
-    # run SAM
-    image = np.asarray(raw_image)
-    sam_model.set_image(image)
-
-    size = raw_image.size
-    H, W = size[1], size[0]
-    for i in range(boxes_filt.size(0)):
-        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
-        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
-        boxes_filt[i][2:] += boxes_filt[i][:2]
-
-    boxes_filt = boxes_filt.cpu()
-    # use NMS to handle overlapped boxes
-    print(f"Before NMS: {boxes_filt.shape[0]} boxes")
-    nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
-    boxes_filt = boxes_filt[nms_idx]
-    pred_phrases = [pred_phrases[idx] for idx in nms_idx]
-    print(f"After NMS: {boxes_filt.shape[0]} boxes")
-
-    transformed_boxes = sam_model.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
-
-    masks, _, _ = sam_model.predict_torch(
-        point_coords=None,
-        point_labels=None,
-        boxes=transformed_boxes.to(device),
-        multimask_output=False,
-    )
-    print("SAM finished")
-
-    # draw output image
-    mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
-
-    mask_draw = ImageDraw.Draw(mask_image)
-    for mask in masks:
-        draw_mask(mask[0].cpu().numpy(), mask_draw, random_color=True)
-
-    image_draw = ImageDraw.Draw(raw_image)
-
-    for box, label in zip(boxes_filt, pred_phrases):
-        draw_box(box, image_draw, label)
-
-    out_image = raw_image.convert('RGBA')
-    out_image.alpha_composite(mask_image)
-
-    # return
-    if tagging_model_type == "RAM":
-        return tags.replace(", ", " | "), tags_chinese.replace(", ", " | "), out_image
-    else:
-        return tags.replace(", ", " | "), caption, out_image
 
+def inference_with_ram(img):
+    return inference(img, None, "RAM", ram_model, transform)
 
+
+def inference_with_t2t(img, input_tags):
+    return inference(img, input_tags, "Tag2Text", tag2text_model, transform)
+
+
 if __name__ == "__main__":
     import gradio as gr
 
-    # load RAM
-    ram_model = ram(pretrained=ram_checkpoint, image_size=384, vit='swin_l')
-    ram_model.eval()
-    ram_model = ram_model.to(device)
-
-    # load Tag2Text
-    delete_tag_index = []  # filter out attributes and action categories which are difficult to grounding
-    for i in range(3012, 3429):
-        delete_tag_index.append(i)
-
-    tag2text_model = tag2text_caption(pretrained=tag2text_checkpoint,
-                                      image_size=384,
-                                      vit='swin_b',
-                                      delete_tag_index=delete_tag_index)
-    tag2text_model.threshold = 0.64  # we reduce the threshold to obtain more tags
-    tag2text_model.eval()
-    tag2text_model = tag2text_model.to(device)
-
-    # load groundingDINO
-    grounding_dino_model = load_model(config_file, grounded_checkpoint, device=device)
-
-    # load SAM
-    sam_model = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))
+    # get transform and load models
+    transform = get_transform(image_size=image_size)
+    ram_model = ram(pretrained=ram_checkpoint, image_size=image_size, vit='swin_l').eval().to(device)
+    tag2text_model = tag2text_caption(
+        pretrained=tag2text_checkpoint, image_size=image_size, vit='swin_b').eval().to(device)
 
     # build GUI
     def build_gui():
 
         description = """
-        <center><strong><font size='10'>Recognize Anything Model</font></strong></center>
+        <center><strong><font size='10'>Recognize Anything Model</font></strong></center>
         <br>
-        Welcome to the Recognize Anything Model / Tag2Text Model / Grounded-SAM demo!
+        <p>Welcome to the <a href='https://recognize-anything.github.io/' target='_blank'>Recognize Anything Model</a> / <a href='https://tag2text.github.io/Tag2Text' target='_blank'>Tag2Text Model</a> demo!</p>
         <li>
         <b>Recognize Anything Model:</b> Upload your image to get the <b>English and Chinese tags</b>!
         </li>
         <li>
-        <b>Tag2Text Model:</b> Upload your image to get the <b>tags and caption</b>!
-        (Optional: Specify tags to get the corresponding caption.)
-        </li>
-        <li>
-        <b>Grounded-SAM:</b> Tick the checkbox to get <b>boxes</b> and <b>masks</b> of tags!
+        <b>Tag2Text Model:</b> Upload your image to get the <b>tags and caption</b>! (Optional: Specify tags to get the corresponding caption.)
         </li>
-        <br>
-        Great thanks to <a href='https://huggingface.co/majinyu' target='_blank'>Ma Jinyu</a>, the major contributor of this demo
+        <p><b>More over:</b> Combine with <a href='https://github.com/IDEA-Research/Grounded-Segment-Anything' target='_blank'>Grounded-SAM</a>, you can get <b>boxes and masks</b>! Please run <a href='https://github.com/xinyu1205/recognize-anything/blob/main/gui_demo.ipynb' target='_blank'>this notebook</a> to try out!</p>
+        <p>Great thanks to <a href='https://huggingface.co/majinyu' target='_blank'>Ma Jinyu</a>, the major contributor of this demo!</p>
         """  # noqa
 
         article = """
         <p style='text-align: center'>
         RAM and Tag2Text are trained on open-source datasets, and we are persisting in refining and iterating upon it.<br/>
-        Grounded-SAM is a combination of Grounding DINO and SAM aming to detect and segment anything with text inputs.<br/>
         <a href='https://recognize-anything.github.io/' target='_blank'>Recognize Anything: A Strong Image Tagging Model</a>
         |
         <a href='https://https://tag2text.github.io/' target='_blank'>Tag2Text: Guiding Language-Image Model via Image Tagging</a>
-        |
-        <a href='https://github.com/IDEA-Research/Grounded-Segment-Anything' target='_blank'>Grounded-Segment-Anything</a>
         </p>
         """  # noqa
 
-        def inference_with_ram(img, do_det_seg):
-            return inference(
-                img, None, do_det_seg,
-                "RAM", ram_model, grounding_dino_model, sam_model
-            )
-
-        def inference_with_t2t(img, input_tags, do_det_seg):
-            return inference(
-                img, input_tags, do_det_seg,
-                "Tag2Text", tag2text_model, grounding_dino_model, sam_model
-            )
-
         with gr.Blocks(title="Recognize Anything Model") as demo:
             ###############
             # components
@@ -312,23 +83,24 @@ if __name__ == "__main__":
                 with gr.Row():
                     with gr.Column():
                         ram_in_img = gr.Image(type="pil")
-                        ram_opt_det_seg = gr.Checkbox(label="Get Boxes and Masks with Grounded-SAM", value=True)
                         with gr.Row():
                             ram_btn_run = gr.Button(value="Run")
-                            ram_btn_clear = gr.ClearButton()
+                            try:
+                                ram_btn_clear = gr.ClearButton()
+                            except AttributeError:  # old gradio does not have ClearButton, not big problem
+                                ram_btn_clear = None
                     with gr.Column():
-                        ram_out_img = gr.Image(type="pil")
                        ram_out_tag = gr.Textbox(label="Tags")
                         ram_out_biaoqian = gr.Textbox(label="标签")
                    gr.Examples(
                        examples=[
-                            ["images/demo1.jpg", True],
-                            ["images/demo2.jpg", True],
-                            ["images/demo4.jpg", True],
+                            ["images/demo1.jpg"],
+                            ["images/demo2.jpg"],
+                            ["images/demo4.jpg"],
                        ],
                        fn=inference_with_ram,
-                        inputs=[ram_in_img, ram_opt_det_seg],
-                        outputs=[ram_out_tag, ram_out_biaoqian, ram_out_img],
+                        inputs=[ram_in_img],
+                        outputs=[ram_out_tag, ram_out_biaoqian],
                        cache_examples=True
                    )
 
@@ -337,23 +109,24 @@ if __name__ == "__main__":
                    with gr.Column():
                        t2t_in_img = gr.Image(type="pil")
                        t2t_in_tag = gr.Textbox(label="User Specified Tags (Optional, separated by comma)")
-                        t2t_opt_det_seg = gr.Checkbox(label="Get Boxes and Masks with Grounded-SAM", value=True)
                        with gr.Row():
                            t2t_btn_run = gr.Button(value="Run")
-                            t2t_btn_clear = gr.ClearButton()
+                            try:
+                                t2t_btn_clear = gr.ClearButton()
+                            except AttributeError:  # old gradio does not have ClearButton, not big problem
+                                t2t_btn_clear = None
                    with gr.Column():
-                        t2t_out_img = gr.Image(type="pil")
                        t2t_out_tag = gr.Textbox(label="Tags")
                        t2t_out_cap = gr.Textbox(label="Caption")
                    gr.Examples(
                        examples=[
-                            ["images/demo4.jpg", "", True],
-                            ["images/demo4.jpg", "power line", True],
-                            ["images/demo4.jpg", "track, train", True],
+                            ["images/demo4.jpg", ""],
+                            ["images/demo4.jpg", "power line"],
+                            ["images/demo4.jpg", "track, train"],
                        ],
                        fn=inference_with_t2t,
-                        inputs=[t2t_in_img, t2t_in_tag, t2t_opt_det_seg],
-                        outputs=[t2t_out_tag, t2t_out_cap, t2t_out_img],
+                        inputs=[t2t_in_img, t2t_in_tag],
+                        outputs=[t2t_out_tag, t2t_out_cap],
                        cache_examples=True
                    )
 
@@ -365,22 +138,20 @@ if __name__ == "__main__":
            # run inference
            ram_btn_run.click(
                fn=inference_with_ram,
-                inputs=[ram_in_img, ram_opt_det_seg],
-                outputs=[ram_out_tag, ram_out_biaoqian, ram_out_img]
+                inputs=[ram_in_img],
+                outputs=[ram_out_tag, ram_out_biaoqian]
            )
            t2t_btn_run.click(
                fn=inference_with_t2t,
-                inputs=[t2t_in_img, t2t_in_tag, t2t_opt_det_seg],
-                outputs=[t2t_out_tag, t2t_out_cap, t2t_out_img]
+                inputs=[t2t_in_img, t2t_in_tag],
+                outputs=[t2t_out_tag, t2t_out_cap]
            )
 
-            # hide or show image output
-            ram_opt_det_seg.change(fn=lambda b: gr.update(visible=b), inputs=[ram_opt_det_seg], outputs=[ram_out_img])
-            t2t_opt_det_seg.change(fn=lambda b: gr.update(visible=b), inputs=[t2t_opt_det_seg], outputs=[t2t_out_img])
-
            # clear
-            ram_btn_clear.add([ram_in_img, ram_out_tag, ram_out_biaoqian, ram_out_img])
-            t2t_btn_clear.add([t2t_in_img, t2t_in_tag, t2t_out_tag, t2t_out_cap, t2t_out_img])
+            if ram_btn_clear is not None:
+                ram_btn_clear.add([ram_in_img, ram_out_tag, ram_out_biaoqian])
+            if t2t_btn_clear is not None:
+                t2t_btn_clear.add([t2t_in_img, t2t_in_tag, t2t_out_tag, t2t_out_cap])
 
            return demo
 
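The hunks end at `return demo`, so the code that actually serves the interface falls outside the excerpted diff. For orientation, the entry point presumably finishes along the following lines. This is a minimal sketch, not part of the commit: `queue()` and `launch()` are standard Gradio `Blocks` methods, but the exact calls and arguments the Space uses after `return demo` are not visible here.

    # Hypothetical tail of app.py (not shown in the hunks above), inside the
    # `if __name__ == "__main__":` block: build the Blocks app and serve it.
    demo = build_gui()
    demo.queue()   # assumed: queue() serializes requests, the usual choice on a shared Space
    demo.launch()  # starts the Gradio server

In that reading, removing the GroundingDINO/SAM pipeline also removes the `os.system("pip install -e ./GroundingDINO/")` install-at-import hack that the old file's own comments flag as fragile, which is consistent with this commit being a fix for the Space's runtime error.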