---
license: mit
datasets:
- likaixin/IconStack-Captions-48M
- likaixin/IconStack-48M-Pre-Rendered
- starvector/svg-stack
language:
- en
metrics:
- accuracy
base_model:
- laion/CLIP-ViT-B-32-laion2B-s34B-b79K
tags:
- art
- icon
model-index:
- name: IconClip-ViT-B-32
results:
- task:
type: zero-shot-classification
dataset:
name: ui-icon-dataset
type: ui-icon-dataset
metrics:
- name: acc@1
type: accuracy
value: 78.815
- name: acc@5
type: accuracy
value: 93.966
---
# Model Description
A CLIP ViT-B/32 model trained on the [IconStack dataset](https://huggingface.co/datasets/likaixin/IconStack-Captions-48M) using [OpenCLIP](https://github.com/mlfoundations/open_clip).
It reaches 78.82% top-1 (93.97% top-5) zero-shot classification accuracy on [ui-icon-dataset](https://huggingface.co/datasets/likaixin/ui-icon-dataset).
# Usage
## Installation
Install the `open_clip_torch` package to use this model:
```bash
pip install open_clip_torch
```
## Icon-to-Text Zero-Shot Classification
```python
import torch
from PIL import Image
import open_clip

CLIP_TEXT_TEMPLATE = "an icon of {}"
ICON_CLASSES = ["add", "close", "play", ...]  # Replace with your class names

# Load the checkpoint into a ViT-B/32 architecture (matching this model).
model_checkpoint = "<path_to_your_local_model>"
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained=model_checkpoint)
model.eval()
tokenizer = open_clip.get_tokenizer('ViT-B-32')

# Preprocess the icon image and tokenize one prompt per class.
image = preprocess(Image.open("icon.png")).unsqueeze(0)
text = tokenizer([CLIP_TEXT_TEMPLATE.format(cls) for cls in ICON_CLASSES])

with torch.no_grad(), torch.autocast("cuda"):
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    # L2-normalize so the dot product below is a cosine similarity.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print("Label probs:", text_probs)  # prints something like: [[1., 0., 0., ...]]
```
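To turn `text_probs` into readable predictions, the top-k classes can be looked up by index. A minimal sketch (using dummy probabilities in place of the model output, so it runs without a checkpoint; the class list is a placeholder):

```python
import torch

ICON_CLASSES = ["add", "close", "play"]  # Placeholder class names

# Dummy probabilities standing in for the `text_probs` computed above;
# in practice, use the softmax output from the model.
text_probs = torch.tensor([[0.90, 0.07, 0.03]])

# Take the top-2 classes for the (single) image.
top_probs, top_idxs = text_probs.topk(2, dim=-1)
for prob, idx in zip(top_probs[0], top_idxs[0]):
    print(f"{ICON_CLASSES[idx]}: {prob.item():.2f}")
```

This prints the predicted labels in order of confidence, e.g. `add: 0.90` then `close: 0.07`.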