# Spaces: Runtime error
import json
import tempfile
from pathlib import Path

import folium
import gradio as gr
import reverse_geocoder as rg
import spaces
import torch
from geoclip import GeoCLIP
from geopy.exc import GeocoderServiceError, GeocoderTimedOut
from geopy.geocoders import Nominatim
from PIL import Image
from torchvision import transforms
from transformers import CLIPModel, CLIPProcessor

from models.huggingface import Geolocalizer
# GeoCLIP is moved to the GPU when CUDA is available; otherwise it stays on
# whatever device the constructor picked.
geoclip_model = GeoCLIP().to("cuda") if torch.cuda.is_available() else GeoCLIP()

# Nominatim client shared by the forward and reverse geocoding helpers below.
geolocator = Nominatim(user_agent="predictGeolocforImage")

# StreetCLIP weights and the processor that prepares text/image inputs for it.
streetclip_model = CLIPModel.from_pretrained("geolocal/StreetCLIP")
streetclip_processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")
# Candidate country names that StreetCLIP scores each image against.
# The order matters: classification maps a score index back into this list.
labels = [
    "Albania", "Andorra", "Argentina", "Australia", "Austria", "Bangladesh",
    "Belgium", "Bermuda", "Bhutan", "Bolivia", "Botswana", "Brazil",
    "Bulgaria", "Cambodia", "Canada", "Chile", "China", "Colombia",
    "Croatia", "Czech Republic", "Denmark", "Dominican Republic", "Ecuador",
    "Estonia", "Finland", "France", "Germany", "Ghana", "Greece",
    "Greenland", "Guam", "Guatemala", "Hungary", "Iceland", "India",
    "Indonesia", "Iran", "Ireland", "Israel", "Italy", "Japan", "Jordan",
    "Kenya", "Kyrgyzstan", "Laos", "Latvia", "Lesotho", "Lithuania",
    "Luxembourg", "Macedonia", "Madagascar", "Malaysia", "Malta", "Mexico",
    "Monaco", "Mongolia", "Montenegro", "Netherlands", "New Zealand",
    "Nigeria", "Norway", "Pakistan", "Palestine", "Peru", "Philippines",
    "Poland", "Portugal", "Puerto Rico", "Romania", "Russia", "Rwanda",
    "Senegal", "Serbia", "Singapore", "Slovakia", "Slovenia", "South Africa",
    "South Korea", "Spain", "Sri Lanka", "Swaziland", "Sweden",
    "Switzerland", "Taiwan", "Thailand", "Tunisia", "Turkey", "Uganda",
    "Ukraine", "United Arab Emirates", "United Kingdom", "United States",
    "Uruguay",
]
# Checkpoint identifier handed to Geolocalizer.from_pretrained.
GEOLOC_MODEL_NAME = "osv5m/baseline"

# Size passed to transforms.Resize when preparing images for the OSV-5M model.
IMAGE_SIZE = (224, 224)

# Load the OSV-5M baseline once at import time and switch it to eval mode.
geoloc_model = Geolocalizer.from_pretrained(GEOLOC_MODEL_NAME)
geoloc_model.eval()
def transform_image(image):
    """Turn *image* into a normalised tensor with a leading batch axis.

    The image is resized to ``IMAGE_SIZE``, converted to a tensor and
    normalised per channel with the mean/std values below (the standard
    ImageNet statistics), then ``unsqueeze(0)`` adds the batch dimension.
    """
    steps = [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    pipeline = transforms.Compose(steps)
    tensor = pipeline(image)
    return tensor.unsqueeze(0)
def create_map(lat, lon):
    """Return the HTML of a folium map centred on (lat, lon) with one marker.

    The map opens at zoom level 4; the returned string is what folium's
    ``_repr_html_`` produces, suitable for a ``gr.HTML`` output.
    """
    world_map = folium.Map(location=[lat, lon], zoom_start=4)
    marker = folium.Marker([lat, lon])
    marker.add_to(world_map)
    return world_map._repr_html_()
def get_country_coordinates(country_name):
    """Look up a representative coordinate for *country_name*.

    Args:
        country_name: Free-text name passed to the Nominatim geocoder.

    Returns:
        A ``(latitude, longitude)`` tuple, or ``None`` when the geocoder has
        no match or the geocoding service fails.
    """
    try:
        location = geolocator.geocode(country_name, timeout=10)
    except GeocoderServiceError:
        # This is a best-effort lookup: timeouts (GeocoderTimedOut is a
        # subclass), rate limiting and service outages all mean "no map",
        # not a crashed request.
        return None
    if location is None:
        return None
    return location.latitude, location.longitude
def predict_geoclip(image):
    """Predict the most likely GPS location of *image* with GeoCLIP.

    Args:
        image: PIL image supplied by the Gradio input component.

    Returns:
        A ``(text, map_html)`` tuple: a one-line description containing the
        predicted latitude, longitude and country, and the HTML of a map
        centred on the prediction. The country is left empty when reverse
        geocoding finds nothing or the geocoding service fails.
    """
    # GeoCLIP's predict() takes a file path, so round-trip through a
    # temporary JPEG that is removed as soon as the prediction is done.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmppath = Path(tmp_dir) / "tmp.jpg"
        # JPEG cannot store alpha or palette modes; normalise to RGB first.
        image.convert("RGB").save(str(tmppath))
        # Only the best candidate is reported, so only one is requested.
        top_pred_gps, _ = geoclip_model.predict(str(tmppath), top_k=1)

    # Plain floats keep string formatting and folium independent of torch.
    lat, lon = (float(value) for value in top_pred_gps[0])

    country = ""
    try:
        location = geolocator.reverse((lat, lon), exactly_one=True, timeout=10)
    except GeocoderServiceError:
        # Still return the coordinates and map if the geocoder is down.
        location = None
    if location is not None:
        # Points over open water etc. may come back without an address.
        country = location.raw.get("address", {}).get("country", "")

    prediction = f"Latitude: {lat:.6f}, Longitude: {lon:.6f} - Country: {country}"
    return prediction, create_map(lat, lon)
def classify_streetclip(image):
    """Pick the most likely country for *image* with StreetCLIP.

    The image is scored against every entry of ``labels`` and the
    highest-scoring country is reported.

    Args:
        image: PIL image supplied by the Gradio input component.

    Returns:
        A ``(text, map_html)`` tuple: ``"Country: <name>"`` and either the
        HTML of a map centred on that country or the literal
        ``"Map not available"`` when its coordinates cannot be looked up.
    """
    inputs = streetclip_processor(
        text=labels, images=image, return_tensors="pt", padding=True
    )
    with torch.no_grad():
        outputs = streetclip_model(**inputs)

    # Softmax is monotonic, so the arg-max of the raw logits is the same
    # label that the highest softmax probability would select.
    best_index = int(outputs.logits_per_image[0].argmax())
    top_label = labels[best_index]

    coords = get_country_coordinates(top_label)
    map_html = create_map(*coords) if coords else "Map not available"
    return f"Country: {top_label}", map_html
def infer(image):
    """Predict the location of *image* with the OSV-5M baseline model.

    Args:
        image: PIL image supplied by the Gradio input component.

    Returns:
        A ``(text, map_html)`` tuple. On success the text holds the predicted
        latitude/longitude plus the ``admin1`` and ``cc`` fields of the
        nearest ``reverse_geocoder`` entry, and ``map_html`` is a map centred
        on the prediction. On any failure the text describes the error and
        ``map_html`` is ``None``.
    """
    try:
        img_tensor = transform_image(image)
        # Inference only: skip autograd bookkeeping, as classify_streetclip does.
        with torch.no_grad():
            gps_radians = geoloc_model(img_tensor)
        # The code treats the model output as radians and converts to degrees.
        gps_degrees = torch.rad2deg(gps_radians).squeeze(0).cpu().tolist()
        lat, lon = gps_degrees[0], gps_degrees[1]
        location_query = rg.search((lat, lon))[0]
        map_html = create_map(lat, lon)
        return f"Latitude: {lat:.6f}, Longitude: {lon:.6f} - Country: {location_query['admin1']} - {location_query['cc']}", map_html
    except Exception as e:
        # UI boundary: report the failure in the textbox instead of raising.
        return f"Failed to predict the location: {e}", None
def _build_tab(fn, title, image_id, text_id, map_id):
    """Create one image-in, text-plus-map-out Gradio interface."""
    return gr.Interface(
        fn=fn,
        inputs=gr.Image(type="pil", label="Upload Image", elem_id=image_id),
        outputs=[
            gr.Textbox(label="Prediction", elem_id=text_id),
            gr.HTML(label="Map", elem_id=map_id),
        ],
        title=title,
    )


# One tab per model; each takes an uploaded image and shows text plus a map.
geoclip_interface = _build_tab(
    predict_geoclip,
    "GeoCLIP",
    "geoclip_image_input",
    "geoclip_output",
    "geoclip_map_output",
)
streetclip_interface = _build_tab(
    classify_streetclip,
    "StreetCLIP",
    "streetclip_image_input",
    "streetclip_output",
    "streetclip_map_output",
)
osv5m_interface = _build_tab(
    infer,
    "OSV-5M Baseline",
    "osv5m_image_input",
    "result_text",
    "map_output",
)

demo = gr.TabbedInterface(
    [geoclip_interface, streetclip_interface, osv5m_interface],
    tab_names=["GeoCLIP", "StreetCLIP", "OSV-5M Baseline"],
)
demo.launch()