Spaces:
Sleeping
Sleeping
import collections | |
import heapq | |
import json | |
import os | |
import logging | |
import gradio as gr | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from open_clip import create_model, get_tokenizer | |
from torchvision import transforms | |
from templates import openai_imagenet_template | |
# Logging configuration shared by the whole app.
log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=log_format)
# A module-named logger (instead of the root logger) makes %(name)s identify the
# origin of each record; records still propagate to the root handler that
# basicConfig installs above.
logger = logging.getLogger(__name__)
# Model identifiers handed to open_clip's create_model / get_tokenizer.
model_str = "hf-hub:imageomics/bioclip"
tokenizer_str = "ViT-B-16"

# Files holding the precomputed species text-embedding matrix and its names.
txt_emb_npy = "txt_emb_species.npy"
txt_names_json = "txt_emb_species.json"

# Species whose probability is not strictly above this are skipped when
# probabilities are summed up to a coarser rank.
min_prob = 1e-9
# How many top predictions are returned and displayed.
k = 5

# Prefer the first GPU when one is available.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Image pipeline applied before model.encode_image:
#   1. convert the input image to a CHW float tensor,
#   2. resize to a fixed 224x224 square (aspect ratio is not preserved),
#   3. normalize each channel with the mean/std values commonly used for
#      OpenAI CLIP-style models.
preprocess_img = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224), antialias=True),
    transforms.Normalize(
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711),
    ),
])
# Taxonomic ranks from coarsest to finest; a rank is referred to elsewhere by
# its position in this tuple (the last position means species).
ranks = (
    "Kingdom",
    "Phylum",
    "Class",
    "Order",
    "Family",
    "Genus",
    "Species",
)

# Example rows for the open-ended tab: [image path, rank name].
open_domain_examples = [["example1_Pararge_aegeria.jpg", ranks[-1]]]

# Example rows for the zero-shot tab: [image path, newline-separated classes].
zero_shot_examples = [
    [
        "example1_Pararge_aegeria.jpg",
        " \n".join(
            [
                "Pararge aegeria",
                "Pieris brassicae",
                "Satyrium w-album",
                "Danaus chrysippus",
            ]
        ),
    ]
]
def indexed(lst, indices):
    """Return a new list holding ``lst[i]`` for every ``i`` in ``indices``, in order."""
    picked = []
    for idx in indices:
        picked.append(lst[idx])
    return picked
@torch.no_grad()
def get_txt_features(classnames, templates):
    """Build one unit-norm text embedding per class name.

    Each class name is expanded with every template, tokenized and encoded with
    the module-level ``tokenizer`` / ``model``, L2-normalized per prompt,
    averaged over prompts, and re-normalized to unit length.

    Runs under ``torch.no_grad()``: this is inference only, so no autograd graph
    is built through the text encoder.

    Args:
        classnames: Sequence of class-name strings.
        templates: Callables applied to each class name to produce prompt text
            for the tokenizer.

    Returns:
        Tensor with the per-class embeddings stacked along dim 1 -- presumably
        shaped (embed_dim, len(classnames)); TODO confirm against encode_text.

    Raises:
        RuntimeError: from ``torch.stack`` when ``classnames`` is empty.
    """
    all_features = []
    for classname in classnames:
        txts = tokenizer([template(classname) for template in templates]).to(device)
        txt_features = model.encode_text(txts)
        # Normalize each prompt embedding, average over prompts, then renormalize.
        txt_features = F.normalize(txt_features, dim=-1).mean(dim=0)
        txt_features /= txt_features.norm()
        all_features.append(txt_features)
    return torch.stack(all_features, dim=1)
@torch.no_grad()
def zero_shot_classification(img, cls_str: str) -> dict[str, float]:
    """Classify ``img`` against a user-supplied list of candidate classes.

    Args:
        img: Image accepted by ``preprocess_img`` (the Gradio image input value).
        cls_str: Newline-separated class names; surrounding whitespace and blank
            lines are ignored. ``None`` is treated like an empty string.

    Returns:
        Mapping from each class name to its softmax probability over the
        supplied classes. Empty when no class names were given.
    """
    classes = [cls.strip() for cls in (cls_str or "").split("\n") if cls.strip()]
    if not classes:
        # Nothing to score; get_txt_features cannot stack an empty list.
        return {}

    txt_features = get_txt_features(classes, openai_imagenet_template)

    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    # squeeze(0) drops only the batch axis, so a single class still yields a 1-D
    # tensor and tolist() returns a list that can be zipped with ``classes``.
    logits = (model.logit_scale.exp() * img_features @ txt_features).squeeze(0)
    probs = F.softmax(logits, dim=0).to("cpu").tolist()
    return dict(zip(classes, probs))
def format_name(taxon, common):
    """Join the taxon parts with spaces and append ``(common)`` when given.

    Args:
        taxon: Iterable of name strings, joined with single spaces.
        common: Common name; any falsy value (``""``, ``None``) omits the suffix.
    """
    scientific = " ".join(taxon)
    return f"{scientific} ({common})" if common else scientific
@torch.no_grad()
def open_domain_classification(img, rank: int) -> dict[str, float]:
    """
    Predicts from the entire tree of life.

    If targeting a higher rank than species, then this function predicts among all
    species, then sums up species-level probabilities for the given rank.

    Args:
        img: Image accepted by ``preprocess_img`` (the Gradio image input value).
        rank: Index into ``ranks``; the last index means species.

    Returns:
        Up to ``k`` names mapped to plain float probabilities, highest first.
        Species names are built with ``format_name``; coarser-rank names are the
        first ``rank + 1`` taxon parts joined with spaces. Each ``txt_names``
        entry appears to be ``(taxon_parts, common_name)`` -- confirm against
        the JSON file.
    """
    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    # squeeze(0) drops only the batch axis, leaving one logit per species.
    logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze(0)
    probs = F.softmax(logits, dim=0)

    # If predicting species, no need to sum probabilities.
    if rank + 1 == len(ranks):
        topk = probs.topk(min(k, probs.numel()))
        return {
            format_name(*txt_names[i]): prob
            for i, prob in zip(topk.indices.tolist(), topk.values.tolist())
        }

    # Sum up by the rank. Only species above min_prob contribute. as_tuple=True
    # always yields a 1-D index tensor, even when exactly one species passes
    # (.squeeze() would give a non-iterable 0-d tensor in that case). Values are
    # moved to Python floats in one transfer instead of per-element indexing.
    keep = torch.nonzero(probs > min_prob, as_tuple=True)[0]
    output = collections.defaultdict(float)
    for i, prob in zip(keep.tolist(), probs[keep].tolist()):
        output[" ".join(txt_names[i][0][: rank + 1])] += prob

    topk_names = heapq.nlargest(k, output, key=output.get)
    return {name: output[name] for name in topk_names}
def change_output(choice):
    """Return a fresh, empty prediction Label titled with the selected rank.

    ``choice`` is used as an index into ``ranks``.
    """
    rank_name = ranks[choice]
    return gr.Label(
        value=None,
        label=rank_name,
        show_label=True,
        num_top_classes=k,
    )
# Client-side script for gr.Blocks(js=...): types the page title letter by
# letter into a banner inserted at the top of the Gradio container. It returns
# early, instead of throwing or inserting a second banner, when the container
# is missing or the banner already exists.
js = """
function createGradioAnimation() {
    var gradioContainer = document.querySelector('.gradio-container');
    if (!gradioContainer || document.getElementById('gradio-animation')) {
        return 'Animation skipped';
    }

    var container = document.createElement('div');
    container.id = 'gradio-animation';
    container.style.fontSize = '2em';
    container.style.fontWeight = 'bold';
    container.style.textAlign = 'center';
    container.style.marginBottom = '20px';

    var text = 'Global Species Identifier: Powered by Artificial Intelligence';
    for (var i = 0; i < text.length; i++) {
        (function(i) {
            setTimeout(function() {
                var letter = document.createElement('span');
                letter.style.opacity = '0';
                letter.style.transition = 'opacity 0.5s';
                letter.innerText = text[i];
                container.appendChild(letter);
                setTimeout(function() {
                    letter.style.opacity = '1';
                }, 50);
            }, i * 50);
        })(i);
    }

    gradioContainer.insertBefore(container, gradioContainer.firstChild);
    return 'Animation created';
}
"""
# Script entry point: loads the model, tokenizer and species text embeddings into
# the module-level globals read by the classification functions, then builds and
# launches the Gradio UI.
if __name__ == "__main__":
    logger.info("Starting.")
    model = create_model(model_str, output_dict=True, require_pretrained=True)
    model = model.to(device)
    # Per the open_clip README, models are created in train mode; switch to
    # inference mode before predicting.
    model.eval()
    logger.info("Created model.")

    tokenizer = get_tokenizer(tokenizer_str)

    txt_emb = torch.from_numpy(np.load(txt_emb_npy, mmap_mode="r")).to(device)
    # Explicit encoding: names in the JSON may contain non-ASCII characters, and
    # the platform default encoding is not guaranteed to be UTF-8.
    with open(txt_names_json, encoding="utf-8") as fd:
        txt_names = json.load(fd)

    # Embedding columns that are entirely zero are counted as not yet indexed.
    done = txt_emb.any(dim=0).sum().item()
    total = txt_emb.shape[1]
    status_msg = ""
    if done != total:
        status_msg = f"{done}/{total} ({done / total * 100:.1f}%) indexed"
        logger.warning("Species text embeddings incomplete: %s", status_msg)

    with gr.Blocks(
        title="Global Species Identifier: Powered by Artificial Intelligence",
        css="footer {visibility: hidden}",
        js=js,
    ) as app:
        # NOTE(review): this copy advertises habitat/distribution/conservation
        # details and continuous model improvement, but the UI below only shows
        # label probabilities -- confirm the wording with the app owner.
        # Blank lines separate the Markdown blocks; without them "Join us..."
        # would render as a continuation of the last list item.
        gr.Markdown(
            """
Upload an image of any plant, animal, or other organism, and our Artificial Intelligence-powered tool will identify the species. Our database covers species from around the world, aiming to support biodiversity awareness and conservation efforts.

Features include:

- **Instant identification** of plants, animals, and other organisms.
- **Detailed information** on species, including habitat, distribution, and conservation status.
- An **interactive, user-friendly interface** designed for both experts and enthusiasts.
- **Continuous learning and improvement** of AI models to expand the app's knowledge base and accuracy.

Join us in exploring the diversity of life on Earth, powered by the intelligence of technology. Start your journey of discovery today!
"""
        )
        img_input = gr.Image()

        with gr.Tab("Open-Ended"):
            with gr.Row():
                with gr.Column():
                    rank_dropdown = gr.Dropdown(
                        label="Taxonomic Rank",
                        info="Which taxonomic rank to predict. Fine-grained ranks (genus, species) are more challenging.",
                        choices=ranks,
                        value="Species",
                        # The handlers receive the rank's position in `ranks`.
                        type="index",
                    )
                    open_domain_btn = gr.Button("Submit", variant="primary")
                with gr.Column():
                    open_domain_output = gr.Label(
                        num_top_classes=k,
                        label="Prediction",
                        show_label=True,
                        value=None,
                    )
            with gr.Row():
                gr.Examples(
                    examples=open_domain_examples,
                    inputs=[img_input, rank_dropdown],
                    cache_examples=True,
                    fn=open_domain_classification,
                    outputs=[open_domain_output],
                )

        with gr.Tab("Zero-Shot"):
            with gr.Row():
                with gr.Column():
                    classes_txt = gr.Textbox(
                        placeholder="Pararge aegeria \nPieris brassicae \nSatyrium w-album \nDanaus chrysippus\n...",
                        lines=3,
                        label="Classes",
                        show_label=True,
                        info="Use taxonomic names where possible; include common names if possible.",
                    )
                    zero_shot_btn = gr.Button("Submit", variant="primary")
                with gr.Column():
                    zero_shot_output = gr.Label(
                        num_top_classes=k, label="Prediction", show_label=True
                    )
            with gr.Row():
                gr.Examples(
                    examples=zero_shot_examples,
                    inputs=[img_input, classes_txt],
                    cache_examples=True,
                    fn=zero_shot_classification,
                    outputs=[zero_shot_output],
                )

        rank_dropdown.change(
            fn=change_output, inputs=rank_dropdown, outputs=[open_domain_output]
        )
        open_domain_btn.click(
            fn=open_domain_classification,
            inputs=[img_input, rank_dropdown],
            outputs=[open_domain_output],
        )
        zero_shot_btn.click(
            fn=zero_shot_classification,
            inputs=[img_input, classes_txt],
            outputs=zero_shot_output,
        )

    app.queue(max_size=20)
    app.launch(show_api=False)