Spaces:
Runtime error
Runtime error
import os

import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image, ImageDraw
from transformers import AutoFeatureExtractor, AutoModel
# --- Embedding model --------------------------------------------------------
# ViT checkpoint used to embed query sketches; presumably the same checkpoint
# that produced the precomputed "embeddings" column of the dataset below.
print('Load model for computing embeddings of the candidate images')
model_ckpt = "google/vit-base-patch16-224"
extractor, model = (
    AutoFeatureExtractor.from_pretrained(model_ckpt),
    AutoModel.from_pretrained(model_ckpt),
)
hidden_dim = model.config.hidden_size  # width of each last_hidden_state vector

# --- Candidate catalogue ----------------------------------------------------
# The access token is read from the TOKEN environment variable (None when
# unset); presumably the dataset is private or gated.
dataset_with_embeddings = load_dataset(
    "LucyintheSky/24-1-30-ds-embeddings",
    split="train",
    token=os.environ.get('TOKEN'),
)
# Build a FAISS index over the "embeddings" column; get_nearest_examples
# queries the index by that column name.
dataset_with_embeddings.add_faiss_index(column='embeddings')
def get_neighbors(query_image, top_k=8):
    """Return the ``top_k`` dataset examples whose embeddings are nearest to the query.

    The query is embedded with the module-level ``extractor``/``model`` pair.
    The embedding is the final-layer hidden state at sequence position 0,
    which is the [CLS] token for ViT. It is then looked up in the FAISS index
    attached to ``dataset_with_embeddings``.

    Args:
        query_image: An image accepted by ``extractor``. ``search`` passes an
            RGB PIL image.
        top_k: Number of neighbours to retrieve.

    Returns:
        ``(scores, retrieved_examples)`` as returned by
        ``Dataset.get_nearest_examples``. ``retrieved_examples`` maps column
        names to lists with one entry per neighbour.
    """
    inputs = extractor(query_image, return_tensors="pt")
    # Inference only: avoid building an autograd graph for the forward pass.
    with torch.no_grad():
        outputs = model(**inputs)
    # Take position 0 of every sequence; squeeze() drops the batch axis of size 1.
    qi_embedding = outputs.last_hidden_state[:, 0].numpy().squeeze()
    scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples(
        'embeddings', qi_embedding, k=top_k
    )
    return scores, retrieved_examples
def search(image_dict):
    """Find catalogue items that resemble the sketch in ``image_dict``.

    Args:
        image_dict: Value produced by the ``gr.ImageEditor`` input
            (``type='filepath'``). Its ``'composite'`` entry is the path of the
            flattened image that is used as the query.

    Returns:
        ``(result, final_md)``, where:

        * ``result`` is a list of ``(image_link, name)`` tuples for
          ``gr.Gallery``.
        * ``final_md`` is an HTML string of linked thumbnails for
          ``gr.Markdown``.

        Both are empty when no composite image was submitted.
    """
    image_path = image_dict.get('composite') if image_dict else None
    if image_path is None:
        # Nothing to search with: show empty results instead of crashing in Image.open.
        return [], ""

    # Use a context manager so the file handle is released once the RGB copy exists.
    with Image.open(image_path) as raw_image:
        query_image = raw_image.convert(mode='RGB')

    _scores, retrieved_examples = get_neighbors(query_image)

    names = retrieved_examples["name"]
    image_links = retrieved_examples["image_link"]
    links = retrieved_examples["link"]

    # Gallery entries: (image URL, caption).
    result = list(zip(image_links, names))
    # NOTE(review): URLs are interpolated unescaped; they are assumed to be
    # trusted dataset values that contain no quote characters.
    final_md = "".join(
        f"<a href='{link}'> <img src='{image_link}' width='200'/> </a> "
        for link, image_link in zip(links, image_links)
    )
    return result, final_md
# Header shown above the app: the brand logo (linking to the store) and a tagline.
# NOTE(review): the logo is served from a Discord CDN attachment URL carrying
# ex=/is=/hm= signature parameters; such links appear to expire, so the image
# may stop loading. Consider hosting it somewhere stable.
_DESCRIPTION = """
<center><a href="https://www.lucyinthesky.com/"><img width="500" src="https://cdn.discordapp.com/attachments/1120417968032063538/1201666647157657640/LucyITS-2022-blk.png?ex=65caa646&is=65b83146&hm=09ad6fe279edc3a32981306d563e63af815d760fc0d8d0a3fbef4e4553c0a83a&"> </a> </center>
<br>
<center> Sketch to find your favorite Lucy in the Sky dress! </center>
<br>
"""

# Sketch input (pre-filled with the template background, resolved relative to
# the working directory) -> gallery of matches plus an HTML strip of linked
# thumbnails.
iface = gr.Interface(
    fn=search,
    description=_DESCRIPTION,
    inputs=gr.ImageEditor(
        label='Sketchpad',
        type='filepath',
        value={'background': './template.JPG', 'layers': None, 'composite': None},
        sources=['upload'],
        transforms=[],
    ),
    outputs=[
        gr.Gallery(label='Similar', object_fit='contain', height=1200),
        gr.Markdown(),
    ],
    theme=gr.themes.Base(
        primary_hue="teal",
        secondary_hue="teal",
        neutral_hue="slate",
    ),
)
iface.launch()