Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -69,11 +69,16 @@ img_encoder = SonarImageEnc().to(device).eval()
|
|
69 |
img_encoder.load_state_dict(torch.load(model_path, map_location=device))
|
70 |
|
71 |
# -------- Similarity Scoring --------
|
72 |
-
def compute_similarity(
|
|
|
|
|
|
|
|
|
|
|
73 |
if not image:
|
74 |
try:
|
75 |
headers = {
|
76 |
-
"User-Agent": "Mozilla/5.0
|
77 |
}
|
78 |
response = requests.get(image_url, headers=headers)
|
79 |
response.raise_for_status()
|
@@ -88,16 +93,38 @@ def compute_similarity(image, image_url, option_a, option_b, option_c, option_d,
|
|
88 |
image_emb, _ = img_encoder(inputs.pixel_values)
|
89 |
image_emb = image_emb.to(device, torch.float16)
|
90 |
|
91 |
-
#
|
92 |
-
|
|
|
|
|
|
|
|
|
|
|
93 |
texts = [option_a, option_b, option_c, option_d]
|
94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
96 |
scores = cos(image_emb, text_embeddings)
|
97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
|
99 |
return image, results
|
100 |
|
|
|
101 |
# -------- Gradio UI --------
|
102 |
with gr.Blocks() as demo:
|
103 |
gr.Markdown("## π SONAR: Image-Text Similarity Scorer")
|
@@ -106,10 +133,20 @@ with gr.Blocks() as demo:
|
|
106 |
with gr.Row():
|
107 |
with gr.Column():
|
108 |
image_url = gr.Textbox(label="Image URL", value="http://images.cocodataset.org/val2017/000000039769.jpg")
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
language = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Select Language")
|
114 |
|
115 |
with gr.Column():
|
@@ -120,8 +157,15 @@ with gr.Blocks() as demo:
|
|
120 |
img_output = gr.Image(label="Input Image", type="pil", width=300, height=300)
|
121 |
result_output = gr.JSON(label="Similarity Scores")
|
122 |
|
123 |
-
btn.click(
|
124 |
-
|
125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
demo.launch()
|
|
|
69 |
img_encoder.load_state_dict(torch.load(model_path, map_location=device))
|
70 |
|
71 |
# -------- Similarity Scoring --------
|
72 |
+
def compute_similarity(
|
73 |
+
image, image_url,
|
74 |
+
option_a, option_b, option_c, option_d,
|
75 |
+
lang_opt_a, lang_opt_b, lang_opt_c, lang_opt_d
|
76 |
+
):
|
77 |
+
|
78 |
if not image:
|
79 |
try:
|
80 |
headers = {
|
81 |
+
"User-Agent": "Mozilla/5.0"
|
82 |
}
|
83 |
response = requests.get(image_url, headers=headers)
|
84 |
response.raise_for_status()
|
|
|
93 |
image_emb, _ = img_encoder(inputs.pixel_values)
|
94 |
image_emb = image_emb.to(device, torch.float16)
|
95 |
|
96 |
+
# Map languages
|
97 |
+
lang_codes = [
|
98 |
+
language_mapping[lang_opt_a],
|
99 |
+
language_mapping[lang_opt_b],
|
100 |
+
language_mapping[lang_opt_c],
|
101 |
+
language_mapping[lang_opt_d],
|
102 |
+
]
|
103 |
texts = [option_a, option_b, option_c, option_d]
|
104 |
+
|
105 |
+
# Get embeddings per option with corresponding language
|
106 |
+
text_embeddings = []
|
107 |
+
for text, lang in zip(texts, lang_codes):
|
108 |
+
emb = t2t_model_emb.predict([text], source_lang=lang)
|
109 |
+
text_embeddings.append(emb)
|
110 |
+
|
111 |
+
text_embeddings = torch.cat(text_embeddings, dim=0).to(device)
|
112 |
|
113 |
scores = cos(image_emb, text_embeddings)
|
114 |
+
|
115 |
+
results = {
|
116 |
+
f"Option {chr(65+i)}": round(score.item(), 3)
|
117 |
+
for i, score in enumerate(scores)
|
118 |
+
}
|
119 |
+
|
120 |
+
results = {
|
121 |
+
k: f"{round(v * 100, 2)}%"
|
122 |
+
for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)
|
123 |
+
}
|
124 |
|
125 |
return image, results
|
126 |
|
127 |
+
|
128 |
# -------- Gradio UI --------
|
129 |
with gr.Blocks() as demo:
|
130 |
gr.Markdown("## π SONAR: Image-Text Similarity Scorer")
|
|
|
133 |
with gr.Row():
|
134 |
with gr.Column():
|
135 |
image_url = gr.Textbox(label="Image URL", value="http://images.cocodataset.org/val2017/000000039769.jpg")
|
136 |
+
|
137 |
+
with gr.Row():
|
138 |
+
option_a = gr.Textbox(label="Option A", value="A cat with two remotes.")
|
139 |
+
lang_opt_a = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")
|
140 |
+
|
141 |
+
option_b = gr.Textbox(label="Option B", value="Two cat with two remotes.")
|
142 |
+
lang_opt_b = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")
|
143 |
+
|
144 |
+
option_c = gr.Textbox(label="Option C", value="Two remotes.")
|
145 |
+
lang_opt_c = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")
|
146 |
+
|
147 |
+
option_d = gr.Textbox(label="Option D", value="Two cats.")
|
148 |
+
lang_opt_d = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")
|
149 |
+
|
150 |
language = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Select Language")
|
151 |
|
152 |
with gr.Column():
|
|
|
157 |
img_output = gr.Image(label="Input Image", type="pil", width=300, height=300)
|
158 |
result_output = gr.JSON(label="Similarity Scores")
|
159 |
|
160 |
+
btn.click(
|
161 |
+
fn=compute_similarity,
|
162 |
+
inputs=[
|
163 |
+
image_input, image_url,
|
164 |
+
option_a, option_b, option_c, option_d,
|
165 |
+
lang_opt_a, lang_opt_b, lang_opt_c, lang_opt_d
|
166 |
+
],
|
167 |
+
outputs=[img_output, result_output]
|
168 |
+
)
|
169 |
+
|
170 |
|
171 |
demo.launch()
|