|
import streamlit as st |
|
import pandas as pd, numpy as np |
|
from html import escape |
|
import os |
|
from transformers import CLIPProcessor, CLIPModel |
|
|
|
|
|
@st.cache(
    show_spinner=False,
    hash_funcs={
        CLIPModel: lambda _: None,
        CLIPProcessor: lambda _: None,
        dict: lambda _: None,
    },
)
def load():
    """Load the CLIP model, its processor, the image tables and embeddings.

    Returns:
        A ``(model, processor, df, embeddings)`` tuple. ``df`` and
        ``embeddings`` are dicts keyed by 0 (``data.csv`` /
        ``embeddings.npy``) and 1 (``data2.csv`` / ``embeddings2.npy``).
        Every row of each embedding matrix is rescaled to unit L2 norm.
    """
    checkpoint = "openai/clip-vit-base-patch16"
    model = CLIPModel.from_pretrained(checkpoint)
    processor = CLIPProcessor.from_pretrained(checkpoint)

    df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}
    raw_vectors = {0: np.load("embeddings.npy"), 1: np.load("embeddings2.npy")}

    embeddings = {}
    for key, vectors in raw_vectors.items():
        # Row-wise Euclidean length, kept 2-D so it broadcasts over columns.
        row_norms = np.sqrt(np.sum(vectors ** 2, axis=1, keepdims=True))
        embeddings[key] = vectors / row_norms

    return model, processor, df, embeddings
|
|
|
|
|
# Shared resources used by the search helpers below; load() is wrapped in
# st.cache.
model, processor, df, embeddings = load()

# Attribution text appended to each image tooltip, keyed like df/embeddings.
source = {
    0: "\nSource: Unsplash",
    1: "\nSource: The Movie Database (TMDB)",
}
|
|
|
|
|
def get_html(url_list, height=200):
    """Render image results as a single HTML flexbox gallery.

    Args:
        url_list: Iterable of ``(url, title, link)`` triples. ``url`` is the
            image source and ``title`` its tooltip text. ``link`` is an
            optional target URL: when it is an empty string or not a string
            at all (``None``, a NaN standing in for a missing value), the
            image is emitted without a surrounding anchor.
        height: Height applied to every image, in pixels.

    Returns:
        The gallery markup as one string. All interpolated values are
        HTML-escaped (quotes included) before being placed in attributes.
    """
    parts = [
        "<div style='margin-top: 20px; max-width: 1200px; display: flex; "
        "flex-wrap: wrap; justify-content: space-evenly'>"
    ]
    for url, title, link in url_list:
        tag = (
            f"<img title='{escape(title)}' "
            f"style='height: {height}px; margin: 5px' src='{escape(url)}'>"
        )
        # Only wrap in <a> when there is a usable link; len() on a
        # non-string placeholder would raise.
        if isinstance(link, str) and link:
            tag = f"<a href='{escape(link)}' target='_blank'>{tag}</a>"
        parts.append(tag)
    parts.append("</div>")
    # Join once instead of growing a string inside the loop.
    return "".join(parts)
|
|
|
|
|
def compute_text_embeddings(list_of_strings):
    """Embed strings with the CLIP text encoder.

    Args:
        list_of_strings: Texts to encode; they are tokenized together with
            padding so they can be batched.

    Returns:
        The tensor produced by ``model.get_text_features`` for the batch.
        It is computed without gradient tracking.
    """
    # Function-scope import keeps this block self-contained; torch is already
    # required because the processor is asked for PyTorch ("pt") tensors.
    import torch

    inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        return model.get_text_features(**inputs)
|
|
|
|
|
# Applied as a decorator so repeated searches are served from the cache.
# The CLIP objects, embedding matrices and metadata tables referenced by the
# search are loaded once by load() and not modified in this file, so they are
# excluded from st.cache's hashing of referenced variables.
@st.cache(
    show_spinner=False,
    hash_funcs={
        CLIPModel: lambda _: None,
        CLIPProcessor: lambda _: None,
        np.ndarray: lambda _: None,
        pd.DataFrame: lambda _: None,
    },
)
def image_search(query, corpus, n_results=24):
    """Return the images whose embeddings best match a text query.

    Args:
        query: Free-text search string.
        corpus: ``"Unsplash"`` selects dataset 0; any other value selects
            dataset 1.
        n_results: Maximum number of hits to return.

    Returns:
        A list of ``(path, tooltip, link)`` tuples ordered from highest to
        lowest score, where ``tooltip`` has the dataset's source attribution
        appended.
    """
    text_embeddings = compute_text_embeddings([query]).detach().numpy()
    k = 0 if corpus == "Unsplash" else 1

    # One score per image: dot product of each stored embedding with the query.
    scores = (embeddings[k] @ text_embeddings.T)[:, 0]
    # argsort is ascending, so walk backwards from the end for the top hits.
    best = np.argsort(scores)[-1 : -n_results - 1 : -1]

    table = df[k]
    hits = []
    for i in best:
        row = table.iloc[i]  # fetch the row once instead of once per column
        hits.append((row["path"], row["tooltip"] + source[k], row["link"]))
    return hits
|
|
|
|
|
# Markdown blurb rendered in the sidebar by main().
description = """
# Semantic image search

**Enter your query and hit enter**

*Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*

*Inspired by [Unsplash Image Search](https://github.com/haltakov/natural-language-image-search) from Vladimir Haltakov and [Alph, The Sacred River](https://github.com/thoppe/alph-the-sacred-river) from Travis Hoppe*
"""
|
|
|
|
|
def main():
    """Draw the page: custom CSS, sidebar blurb, query controls and results."""
    page_css = """
    <style>
    .block-container{
      max-width: 1200px;
    }
    div.row-widget.stRadio > div{
      flex-direction:row;
      display: flex;
      justify-content: center;
    }
    div.row-widget.stRadio > div > label{
      margin-left: 5px;
      margin-right: 5px;
    }
    section.main>div:first-child {
      padding-top: 0px;
    }
    section:not(.main)>div:first-child {
      padding-top: 30px;
    }
    div.reportview-container > section:first-child{
      max-width: 320px;
    }
    #MainMenu {
      visibility: hidden;
    }
    footer {
      visibility: hidden;
    }
    </style>"""
    st.markdown(page_css, unsafe_allow_html=True)
    st.sidebar.markdown(description)

    # The query box lives in the middle of three columns (1:3:1 widths).
    _, center, _ = st.columns((1, 3, 1))
    query = center.text_input("", value="clouds at sunset")
    corpus = st.radio("", ["Unsplash", "Movies"])

    # Nothing to search for until the user has typed something.
    if not query:
        return

    gallery = get_html(image_search(query, corpus))
    st.markdown(gallery, unsafe_allow_html=True)


if __name__ == "__main__":
    main()
|
|