from html import escape
import re
import streamlit as st
import pandas as pd, numpy as np
import torch
from transformers import CLIPProcessor, CLIPModel
from st_clickable_images import clickable_images
# CLIP ViT variants to load; each entry is the suffix of the Hugging Face
# checkpoint "openai/clip-vit-<name>" and needs matching
# "embeddings-vit-<name>.npy" / "embeddings2-vit-<name>.npy" files.
# The last entry is the model used for searching. Other variants
# ("base-patch32", "base-patch16", "large-patch14") are currently disabled.
MODEL_NAMES = ["large-patch14-336"]
@st.cache(allow_output_mutation=True)
def load():
    """Load result metadata, CLIP models/processors and image embeddings.

    Wrapped in ``st.cache(allow_output_mutation=True)`` so Streamlit reuses
    the loaded objects across reruns instead of reloading them.

    Returns:
        ``(models, processors, df, embeddings)`` where

        * ``models[name]`` is the CLIP model for each entry of MODEL_NAMES,
          switched to eval mode, and ``processors[name]`` its processor;
        * ``df[0]`` / ``df[1]`` are the tables read from ``data.csv`` /
          ``data2.csv``;
        * ``embeddings[name][0]`` / ``embeddings[name][1]`` are the
          precomputed image embeddings for the two corpora, with every row
          scaled to unit L2 norm.
    """
    df = {
        table_idx: pd.read_csv(csv_path)
        for table_idx, csv_path in enumerate(("data.csv", "data2.csv"))
    }

    models, processors, embeddings = {}, {}, {}
    for model_name in MODEL_NAMES:
        checkpoint = f"openai/clip-vit-{model_name}"
        models[model_name] = CLIPModel.from_pretrained(checkpoint).eval()
        processors[model_name] = CLIPProcessor.from_pretrained(checkpoint)

        npy_files = (
            f"embeddings-vit-{model_name}.npy",
            f"embeddings2-vit-{model_name}.npy",
        )
        per_corpus = {}
        for corpus_idx, npy_path in enumerate(npy_files):
            raw = np.load(npy_path)
            # Unit-normalise rows so a dot product is a cosine similarity.
            per_corpus[corpus_idx] = raw / np.linalg.norm(raw, axis=1, keepdims=True)
        embeddings[model_name] = per_corpus

    return models, processors, df, embeddings
# Load models, metadata tables and embeddings at module level (load() is
# cached by Streamlit, so reruns reuse the same objects).
models, processors, df, embeddings = load()

# Attribution appended to each result's tooltip, keyed by corpus index
# (0 = Unsplash, 1 = The Movie Database).
source = {
    0: "\nSource: Unsplash",
    1: "\nSource: The Movie Database (TMDB)",
}
def compute_text_embeddings(list_of_strings, name):
    """Embed each string with the text encoder of CLIP model *name*.

    Args:
        list_of_strings: Texts to embed; padded to a common length.
        name: Key into the module-level ``processors`` / ``models`` dicts.

    Returns:
        A NumPy array with one row per input string, each row scaled to unit
        L2 norm so it can be compared to image embeddings by dot product.
    """
    processor = processors[name]
    model = models[name]
    tokenized = processor(text=list_of_strings, return_tensors="pt", padding=True)
    with torch.no_grad():
        features = model.get_text_features(**tokenized)
    features = features.detach().numpy()
    row_norms = np.linalg.norm(features, axis=1, keepdims=True)
    return features / row_norms
def image_search(query, corpus, name, n_results=24):
    """Return the best-matching images of *corpus* for a compound *query*.

    Query syntax:
      * ``;`` separates positive sub-queries. Each is scored separately and an
        image is ranked by its weakest sub-query score, so it has to match all
        of them.
      * A sub-query may start with ``[Unsplash:<row>]`` or ``[Movies:<row>]`` to
        use the stored embedding of that image; any text after the reference
        is added as one more positive sub-query. References whose row is out
        of range for that corpus are ignored.
      * Text after ``"EXCLUDING "`` is a ``;``-separated list of negative text
        queries; images matching them are pushed down the ranking.
      * Empty sub-queries (e.g. from a trailing ``;``) are ignored.

    Args:
        query: Raw query string typed (or generated by a click) in the UI.
        corpus: ``"Unsplash"`` searches corpus 0; any other value corpus 1.
        name: Model name used to index ``embeddings`` and passed to
            ``compute_text_embeddings``.
        n_results: Maximum number of results to return.

    Returns:
        A list of ``(path, tooltip, row)`` tuples, best match first. ``tooltip``
        ends with the corpus attribution from ``source`` and ``row`` is the
        image's row index in ``df[k]``.
    """
    k = 0 if corpus == "Unsplash" else 1
    corpus_embeddings = embeddings[name][k]
    splitted_query = query.split("EXCLUDING ")

    # Collect every positive embedding, then concatenate once.
    positive_parts = []
    for positive_query in splitted_query[0].split(";"):
        # Strip so that "cats; [Unsplash:5]" still recognises the reference.
        positive_query = positive_query.strip()
        if not positive_query:
            continue
        match = re.match(r"\[(Movies|Unsplash):(\d{1,5})\](.*)", positive_query)
        if match:
            corpus2, idx, remainder = match.groups()
            idx, remainder = int(idx), remainder.strip()
            k2 = 0 if corpus2 == "Unsplash" else 1
            # An out-of-range row would give an empty slice and later make
            # np.min fail on a zero-size axis, so skip such references.
            if idx < len(embeddings[name][k2]):
                positive_parts.append(embeddings[name][k2][idx : idx + 1, :])
            if remainder:
                positive_parts.append(compute_text_embeddings([remainder], name))
        else:
            positive_parts.append(compute_text_embeddings([positive_query], name))

    # Neutral per-image score, so purely negative queries still rank images.
    dot_product = np.zeros(len(corpus_embeddings))
    if positive_parts:
        positive_embeddings = np.concatenate(positive_parts, axis=0)
        scores = corpus_embeddings @ positive_embeddings.T
        # Centre each sub-query's scores on its median and scale by its max so
        # sub-queries are comparable, then keep the weakest one per image.
        scores = scores - np.median(scores, axis=0)
        scores = scores / np.max(scores, axis=0, keepdims=True)
        dot_product = np.min(scores, axis=1)

    if len(splitted_query) > 1:
        negative_queries = [
            q.strip()
            for q in " ".join(splitted_query[1:]).split(";")
            if q.strip()
        ]
        if negative_queries:
            negative_embeddings = compute_text_embeddings(negative_queries, name)
            scores2 = corpus_embeddings @ negative_embeddings.T
            scores2 = scores2 - np.median(scores2, axis=0)
            scores2 = scores2 / np.max(scores2, axis=0, keepdims=True)
            # Penalise only above-median matches, by the strongest negative.
            dot_product = dot_product - np.max(np.maximum(scores2, 0), axis=1)

    # Indices sorted by descending score, truncated to n_results.
    results = np.argsort(dot_product)[::-1][:n_results]
    return [
        (
            df[k].iloc[i]["path"],
            df[k].iloc[i]["tooltip"] + source[k],
            i,
        )
        for i in results
    ]
# Markdown rendered in the sidebar: title, usage hint and credits.
description = """
# Semantic image search
**Enter your query and hit enter**
*Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*
*Inspired by [Unsplash Image Search](https://github.com/haltakov/natural-language-image-search) from Vladimir Haltakov and [Alph, The Sacred River](https://github.com/thoppe/alph-the-sacred-river) from Travis Hoppe*
"""
# Markdown for the sidebar's "Advanced use" expander; it documents the query
# syntax that image_search parses (";" separators and "EXCLUDING").
howto = """
- Click on an image to use it as a query and find similar images
- Several queries, including one based on an image, can be combined (use "**;**" as a separator)
- If the input includes "**EXCLUDING**", the part right of it will be used as a negative query
"""
# CSS for the container div passed to clickable_images: a centred flex
# layout whose thumbnails wrap onto new lines.
div_style = {
    "display": "flex",
    "justify-content": "center",
    "flex-wrap": "wrap",
}
def main():
    """Render the search page and handle clicks on results.

    Shows the sidebar help, a query box, a corpus selector and a grid of
    clickable results. Clicking a result replaces the query with a reference
    to that image (``[<corpus>:<row>]``), stored in
    ``st.session_state["query"]``, and reruns the script so it is searched.
    """
    # NOTE(review): this renders only a newline; it presumably once carried
    # custom HTML/CSS - confirm whether it can be removed.
    st.markdown(
        """
""",
        unsafe_allow_html=True,
    )
    st.sidebar.markdown(description)
    with st.sidebar.expander("Advanced use"):
        st.markdown(howto)

    # Side-by-side comparison of two models is disabled: load() only loads
    # MODEL_NAMES, so most entries of models_dict are not available.
    compare_models = False

    _, c, _ = st.columns((1, 3, 1))
    # A clicked image stores its reference query in session state; otherwise
    # start from an example query.
    if "query" in st.session_state:
        default_query = st.session_state["query"]
    else:
        default_query = "clouds at sunset"
    query = c.text_input("", value=default_query)
    corpus = st.radio("", ["Unsplash", "Movies"])

    models_dict = {
        "ViT-B/32 (quicker)": "base-patch32",
        "ViT-B/16 (average)": "base-patch16",
        # "ViT-L/14 (slow)": "large-patch14",
        "ViT-L/14@336px (slower)": "large-patch14-336",
    }
    if compare_models:
        c1, c2 = st.columns((1, 1))
        selection1 = c1.selectbox("", models_dict.keys(), index=0)
        selection2 = c2.selectbox("", models_dict.keys(), index=2)
        name1 = models_dict[selection1]
        name2 = models_dict[selection2]
    else:
        name1 = MODEL_NAMES[-1]

    # Nothing to search or click on for an empty query (previously this fell
    # through to the click handling with clicked1/clicked2 unbound).
    if not query:
        return

    results1 = image_search(query, corpus, name1)
    results2 = []
    clicked2 = -1
    if compare_models:
        with c1:
            clicked1 = clickable_images(
                [result[0] for result in results1],
                titles=[result[1] for result in results1],
                div_style=div_style,
                img_style={"margin": "2px", "height": "150px"},
                key=query + corpus + name1 + "1",
            )
        results2 = image_search(query, corpus, name2)
        with c2:
            clicked2 = clickable_images(
                [result[0] for result in results2],
                titles=[result[1] for result in results2],
                div_style=div_style,
                img_style={"margin": "2px", "height": "150px"},
                key=query + corpus + name2 + "2",
            )
    else:
        clicked1 = clickable_images(
            [result[0] for result in results1],
            titles=[result[1] for result in results1],
            div_style=div_style,
            img_style={"margin": "2px", "height": "200px"},
            key=query + corpus + name1 + "1",
        )

    if clicked1 >= 0:
        new_query = f"[{corpus}:{results1[clicked1][2]}]"
    elif clicked2 >= 0:
        new_query = f"[{corpus}:{results2[clicked2][2]}]"
    else:
        return

    # Rerun only when the click changes the query. If the clicked image is the
    # one the query already refers to, the widget key (built from the query)
    # stays the same; assuming the component keeps reporting its last click,
    # an unconditional rerun would then repeat indefinitely.
    if new_query != query:
        st.session_state["query"] = new_query
        st.experimental_rerun()
# Entry point: build the page only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()