import os

# fairseq2 (required by SONAR) is not pre-installed on the host, so install it at startup.
os.system("pip install fairseq2 --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/pt2.6.0/cu124 -q")

from io import BytesIO

import gradio as gr
import requests
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
from transformers import SiglipImageProcessor, SiglipVisionModel

# Cosine similarity over the embedding dimension (broadcasts one image against N captions).
cos = nn.CosineSimilarity(dim=-1)

# Projection-head checkpoint that maps SigLIP image features into the SONAR text space.
model_path = hf_hub_download(
    repo_id="Sibgat-Ul/SONAR-Image_enc",
    filename="best_sonar.pth",
    repo_type="model",
)

# Display name -> SONAR/NLLB language code.
language_mapping = {
    "English": "eng_Latn",
    "Bengali": "ben_Beng",
    "French": "fra_Latn",
}


# -------- Image Encoder --------
class SonarImageEnc(nn.Module):
    """Frozen SigLIP vision tower plus a projection head into the 1024-d SONAR space."""

    def __init__(self, path="google/siglip2-base-patch16-384", initial_temperature=0.07):
        super().__init__()
        self.model = SiglipVisionModel.from_pretrained(path, torch_dtype=torch.float32)

        # The vision backbone stays frozen; only the projection head is trained.
        for param in self.model.parameters():
            param.requires_grad = False

        self.projection = nn.Sequential(
            nn.Linear(self.model.config.hidden_size, 2048),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(2048, 1024),
            nn.LayerNorm(1024, eps=1e-5),
        )

        # Learnable CLIP-style temperature, stored as log(1 / temperature).
        self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1.0) / initial_temperature))

    def forward(self, pixel_values):
        with torch.no_grad():
            vision_outputs = self.model(pixel_values=pixel_values)
            pooled_output = vision_outputs.pooler_output

        embeddings = self.projection(pooled_output)

        # Keep the temperature within [0.001, 100] (lower bound must be the smaller log value).
        self.logit_scale.data.clamp_(
            min=torch.log(torch.tensor(1.0) / torch.tensor(100.0)),
            max=torch.log(torch.tensor(1.0) / torch.tensor(0.001)),
        )

        return embeddings, torch.exp(self.logit_scale)


# -------- Load processor and models --------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = SiglipImageProcessor.from_pretrained("google/siglip2-base-patch16-384")

# SONAR text encoder: multilingual sentences -> 1024-d embeddings.
t2t_model_emb = TextToEmbeddingModelPipeline(
    encoder="text_sonar_basic_encoder",
    tokenizer="text_sonar_basic_encoder",
    device=device,
    dtype=torch.float16,
)

img_encoder = SonarImageEnc().to(device).eval()
img_encoder.load_state_dict(torch.load(model_path, map_location=device))


# -------- Similarity Scoring --------
def compute_similarity(
    image, image_url,
    option_a, option_b, option_c, option_d,
    lang_opt_a, lang_opt_b, lang_opt_c, lang_opt_d,
):
    """Score an image against four captions and return (image, {option: score})."""
    # Fall back to the URL if no image was uploaded.
    if image is None:
        try:
            headers = {"User-Agent": "Mozilla/5.0"}
            response = requests.get(image_url, headers=headers, timeout=10)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            return None, {"Error": f"Image could not be loaded: {e}"}

    # Embed the image.
    inputs = processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        image_emb, _ = img_encoder(inputs.pixel_values)
    image_emb = image_emb.to(device, torch.float16)

    # Embed each caption with its selected language.
    lang_codes = [
        language_mapping[lang_opt_a],
        language_mapping[lang_opt_b],
        language_mapping[lang_opt_c],
        language_mapping[lang_opt_d],
    ]
    texts = [option_a, option_b, option_c, option_d]

    text_embeddings = []
    for text, lang in zip(texts, lang_codes):
        emb = t2t_model_emb.predict([text], source_lang=lang)
        text_embeddings.append(emb)
    text_embeddings = torch.cat(text_embeddings, dim=0).to(device)

    # Cosine similarity between the image and each caption, reported as percentages,
    # sorted from best to worst match.
    scores = cos(image_emb, text_embeddings)
    results = {
        f"Option {chr(65 + i)}": round(score.item(), 3)
        for i, score in enumerate(scores)
    }
    results = {
        k: f"{round(v * 100, 2)}%"
        for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)
    }

    return image, results
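# -------- Optional: programmatic scoring (illustrative sketch) --------
# A minimal helper showing how the encoders loaded above can be reused outside the
# Gradio UI. The helper name `score_captions`, the example file name, and the example
# captions are placeholders for illustration, not fixed by the app.
def score_captions(image, captions, lang_code="eng_Latn"):
    """Return {caption: cosine similarity} for a PIL image and a list of captions."""
    inputs = processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        image_emb, _ = img_encoder(inputs.pixel_values)
    image_emb = image_emb.to(device, torch.float16)
    # SONAR embeds all captions in a single call when they share one language code.
    text_emb = t2t_model_emb.predict(captions, source_lang=lang_code).to(device)
    sims = cos(image_emb, text_emb)
    return {caption: round(sim.item(), 3) for caption, sim in zip(captions, sims)}

# Example usage (not executed by the app):
# img = Image.open("two_cats.jpg").convert("RGB")
# print(score_captions(img, ["Two cats.", "A dog on a beach."]))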
# -------- Gradio UI --------
with gr.Blocks() as demo:
    gr.Markdown("## 🔍 SONAR: Image-Text Similarity Scorer")
    gr.Markdown("#### Upload an image or provide a URL.")

    with gr.Row():
        with gr.Column():
            image_url = gr.Textbox(
                label="Image URL",
                value="http://images.cocodataset.org/val2017/000000039769.jpg",
            )

            # Four candidate captions, each with its own language selector.
            with gr.Row():
                option_a = gr.Textbox(label="Option A", value="A cat with two remotes.")
                lang_opt_a = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")

                option_b = gr.Textbox(label="Option B", value="Two cats with two remotes.")
                lang_opt_b = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")

                option_c = gr.Textbox(label="Option C", value="Two remotes.")
                lang_opt_c = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")

                option_d = gr.Textbox(label="Option D", value="Two cats.")
                lang_opt_d = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")

            # Global selector; not wired into compute_similarity (the per-option
            # dropdowns above control the language of each caption).
            language = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Select Language")

        with gr.Column():
            image_input = gr.Image(label="Upload an image", type="pil")
            btn = gr.Button("Compute Similarity")

    with gr.Row():
        img_output = gr.Image(label="Input Image", type="pil", width=300, height=300)
        result_output = gr.JSON(label="Similarity Scores")

    btn.click(
        fn=compute_similarity,
        inputs=[
            image_input, image_url,
            option_a, option_b, option_c, option_d,
            lang_opt_a, lang_opt_b, lang_opt_c, lang_opt_d,
        ],
        outputs=[img_output, result_output],
    )

demo.launch()
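# Deployment note (illustrative, not prescriptive): launch() blocks and serves the demo
# locally on port 7860 by default; Gradio's launch() also accepts options such as
# server_name="0.0.0.0" or share=True when the app needs to be reachable from elsewhere.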