| | import numpy as np |
| | import streamlit as st |
| | from PIL import Image |
| | import torch |
| | import clip |
| | from torchray.attribution.grad_cam import grad_cam |
| | from miniclip.imageWrangle import heatmap, min_max_norm, torch_to_rgba |
| |
|
# Use the full browser width so heatmap galleries have room side-by-side.
st.set_page_config(layout="wide")

# Prefer the GPU when one is present; otherwise run everything on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
| |
|
| |
|
@st.cache(show_spinner=True, allow_output_mutation=True)
def get_model():
    """Load CLIP RN50 plus its preprocessing transform, cached across reruns.

    Returns the (model, preprocess) pair produced by ``clip.load``. ``jit=False``
    loads the non-scripted model so Grad-CAM hooks can attach to its layers.
    """
    model_and_preprocess = clip.load("RN50", device=device, jit=False)
    return model_and_preprocess
| |
|
| |
|
| | |
| |
|
# --- Sidebar controls -------------------------------------------------------
st.sidebar.header('Options')
# Blend factor between the original image and the saliency overlay.
alpha = st.sidebar.radio("select alpha", [0.5, 0.7, 0.8], index=1)
# Visual-backbone layer used for Grad-CAM (only one choice offered for RN50).
layer = st.sidebar.selectbox("select saliency layer", ['layer4.2.relu'], index=0)

# --- Main page header and user inputs ---------------------------------------
st.header("Saliency Map demo for CLIP")
st.write(
    "a quick experiment by [Hendrik Strobelt](http://hendrik.strobelt.com) ([MIT-IBM Watson AI Lab](https://mitibmwatsonailab.mit.edu/)) ")
with st.beta_expander('1. Upload Image', expanded=True):
    imageFile = st.file_uploader("Select a file:", type=[".jpg", ".png", ".jpeg"])

with st.beta_expander('2. Write Descriptions', expanded=True):
    # FIX: corrected spelling in the user-facing label ("seperated" -> "separated").
    textarea = st.text_area("Descriptions separated by semicolon", "a car; a dog; a cat")
    prefix = st.text_input("(optional) Prefix all descriptions with: ", "an image of")
| |
|
if imageFile:
    st.markdown("<hr style='border:black solid;'>", unsafe_allow_html=True)

    # Load the upload and preprocess it into a single-image batch tensor.
    image_raw = Image.open(imageFile)
    model, preprocess = get_model()
    image = preprocess(image_raw).unsqueeze(0).to(device)

    # Build the description list, optionally prefixing every entry.
    prefix = prefix.strip()
    if len(prefix) > 0:
        categories = [f"{prefix} {x.strip()}" for x in textarea.split(';')]
    else:
        categories = [x.strip() for x in textarea.split(';')]
    text = clip.tokenize(categories).to(device)

    # Encode image and text, normalize the embeddings, and compute the CLIP
    # similarity logits / softmax probabilities (no gradients needed here).
    with torch.no_grad():
        image_features = model.encode_image(image)
        text_features = model.encode_text(text)
        image_features_norm = image_features.norm(dim=-1, keepdim=True)
        image_features_new = image_features / image_features_norm
        text_features_norm = text_features.norm(dim=-1, keepdim=True)
        text_features_new = text_features / text_features_norm
        logit_scale = model.logit_scale.exp()
        logits_per_image = logit_scale * image_features_new @ text_features_new.t()
        probs = logits_per_image.softmax(dim=-1).cpu().numpy().tolist()

    # Grad-CAM with respect to the image embedding itself.
    # BUG FIX: store this heatmap in its own variable. Previously it was kept
    # in `hm`, which the per-category loop below overwrote, so the picture
    # captioned "image gradcam" was actually the LAST text category's heatmap.
    saliency = grad_cam(model.visual, image.type(model.dtype), image_features, saliency_layer=layer)
    image_hm = heatmap(image[0], saliency[0][0,].detach().type(torch.float32).cpu(), alpha=alpha)

    # Grad-CAM per text description: scale each normalized text embedding by
    # the image-feature norm to approximate a matching image embedding, then
    # attribute it back through the visual branch.
    collect_images = []
    for i in range(len(categories)):
        text_prediction = (text_features_new[[i]] * image_features_norm)
        saliency = grad_cam(model.visual, image.type(model.dtype), text_prediction, saliency_layer=layer)
        hm = heatmap(image[0], saliency[0][0,].detach().type(torch.float32).cpu(), alpha=alpha)
        collect_images.append(hm)

    # Raw (pre-softmax) logits for the caption line.
    logits = logits_per_image.cpu().numpy().tolist()[0]
    st.write("### Grad Cam for text embeddings")
    st.image(collect_images,
             width=256,
             caption=[f"{x} - {str(round(y, 3))}/{str(round(l, 2))}" for (x, y, l) in
                      zip(categories, probs[0], logits)])

    st.write("### Original Image and Grad Cam for image embedding")
    st.image([Image.fromarray((torch_to_rgba(image[0]).numpy() * 255.).astype(np.uint8)), image_hm],
             caption=["original", "image gradcam"])
| |
|
| | |
| |
|
| |
|
| | |
def get_readme():
    """Return the contents of README.md with each line stripped of whitespace.

    Stripping removes Markdown indentation so the text renders cleanly via
    ``st.markdown``. FIX: read with an explicit UTF-8 encoding — the previous
    bare ``open()`` used the platform/locale default encoding, which can fail
    or mis-decode non-ASCII characters (e.g. on Windows).
    """
    with open('README.md', encoding='utf-8') as f:
        return "\n".join(line.strip() for line in f)
| |
|
| |
|
# Render the project README in an expander at the bottom of the page,
# below a horizontal-rule separator.
st.markdown("<hr style='border:black solid;'>", unsafe_allow_html=True)
with st.beta_expander('Description', expanded=True):
    st.markdown(get_readme(), unsafe_allow_html=True)
| |
|
# CSS snippet that hides Streamlit's default chrome (the hamburger main menu
# and the "Made with Streamlit" footer) when injected via st.markdown.
hide_streamlit_style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
</style>

"""
st.markdown(hide_streamlit_style, unsafe_allow_html=True)
| |
|