# CLIP + Gradio demo: retrieve images from a precomputed embedding index by text query.
import gradio as gr
import torch
import clip
import numpy as np
import pandas as pd



# Use Apple's Metal backend (MPS) when available, otherwise fall back to the CPU
device = "mps" if torch.backends.mps.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
print(f"Using {device}")

features_path = 'features/'



# Precomputed CLIP image embeddings plus a CSV mapping each row to a photo filename and caption
photo_features = np.load(features_path + "features.npy")
photo_ids = pd.read_csv(features_path + "updated_file.csv")
descriptions = list(photo_ids['description'])
photo_filenames = list(photo_ids['photo_id'])
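
# --- Illustrative sketch only (not called by the app): one way the precomputed
# 'features.npy' could have been built from the images listed in the CSV.
# It assumes the images live under 'images/'; the actual preprocessing pipeline may differ.
def build_photo_features(out_path=features_path + 'features.npy'):
    from PIL import Image
    encoded = []
    with torch.no_grad():
        for photo_id in photo_filenames:
            image = preprocess(Image.open('images/' + str(photo_id))).unsqueeze(0).to(device)
            image_features = model.encode_image(image)
            # Normalize so the dot product in clip_search behaves as a cosine similarity
            image_features /= image_features.norm(dim=-1, keepdim=True)
            encoded.append(image_features.cpu().numpy())
    np.save(out_path, np.concatenate(encoded, axis=0))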



def clip_search(search_string):
    with torch.no_grad():
        # Encode the query text with CLIP and normalize the embedding
        text_encoded = model.encode_text(clip.tokenize(search_string).to(device))
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
    text_features = text_encoded.cpu().numpy()

    # Compute the cosine similarity between the query and each photo embedding
    similarities = list((text_features @ photo_features.T).squeeze(0))

    # Sort the photos by their similarity score
    candidates = sorted(zip(similarities, range(photo_features.shape[0])), key=lambda x: x[0], reverse=True)

    # Return the top matches as [image_path, caption] pairs for the gallery
    images = []
    for i in range(min(30, len(candidates))):
        idx = candidates[i][1]
        photo_id = photo_filenames[idx]
        caption = descriptions[idx]
        images.append(['images/' + str(photo_id), caption])
    return images
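
# Example (illustrative): clip_search("rococo") returns up to 30 [image_path, caption]
# pairs, the list-of-(image, caption) format that gr.Gallery accepts below.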

css = "footer {display: none !important;} .gradio-container {min-height: 0px !important;}"
with gr.Blocks(css=css) as demo:
    with gr.Column(variant="panel"):
        with gr.Row(variant="compact"):
            search_string = gr.Textbox(
                label="Evocative Search",
                show_label=True,
                max_lines=1,
                placeholder="Type something abstruse, or click a suggested search below.",
            )
            btn = gr.Button("Retrieve Images", variant="primary")
        with gr.Row(variant="compact"):
            suggest1 = gr.Button("rococo", variant="secondary")    
            suggest2 = gr.Button("brutalism", variant="secondary")    
            suggest3 = gr.Button("classical", variant="secondary")       
            suggest4 = gr.Button("gothic", variant="secondary")    
            suggest5 = gr.Button("foliate", variant="secondary")  
        gallery = gr.Gallery(show_label=False, elem_id="gallery")

    # A Button used as an input passes its label text, so each suggestion button searches for its own label
    suggest1.click(clip_search, inputs=suggest1, outputs=gallery)
    suggest2.click(clip_search, inputs=suggest2, outputs=gallery)
    suggest3.click(clip_search, inputs=suggest3, outputs=gallery)
    suggest4.click(clip_search, inputs=suggest4, outputs=gallery)
    suggest5.click(clip_search, inputs=suggest5, outputs=gallery)
    btn.click(clip_search, inputs=search_string, outputs=gallery)
    search_string.submit(clip_search, inputs=search_string, outputs=gallery)



if __name__ == "__main__":
    demo.launch()