Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,7 +7,12 @@ import torch.nn as nn
|
|
| 7 |
import copy
|
| 8 |
import pydeck as pdk
|
| 9 |
import numpy as np
|
|
|
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
keep_layer_count=6
|
| 12 |
|
| 13 |
|
|
@@ -75,7 +80,7 @@ class ByT5ForTextGeotagging(PreTrainedModel):
|
|
| 75 |
return logits
|
| 76 |
|
| 77 |
|
| 78 |
-
@st.
|
| 79 |
def load_model_and_tokenizer():
|
| 80 |
byt5_tokenizer = AutoTokenizer.from_pretrained("yachay/byt5-geotagging-es", token=st.secrets['token'])
|
| 81 |
model = ByT5ForTextGeotagging.from_pretrained("yachay/byt5-geotagging-es", token=st.secrets['token'])
|
|
@@ -97,12 +102,21 @@ def geolocate_text_byt5_multiclass(text):
|
|
| 97 |
prob = probas[0][class_idx]
|
| 98 |
cumulative_prob += prob
|
| 99 |
if cumulative_prob > 0.5:
|
| 100 |
-
coordinates = model.config.class_to_location.get(str(class_idx))
|
| 101 |
-
if coordinates:
|
| 102 |
-
results.append((class_idx, prob, coordinates))
|
| 103 |
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
return results
|
|
|
|
| 106 |
|
| 107 |
|
| 108 |
def geolocate_text_byt5(text):
|
|
@@ -147,37 +161,64 @@ if st.button('Submit'):
|
|
| 147 |
|
| 148 |
if st.session_state.text_input:
|
| 149 |
results = geolocate_text_byt5_multiclass(st.session_state.text_input)
|
|
|
|
| 150 |
_, confidence, (lat, lon) = results[0]
|
| 151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
# Render map with pydeck
|
| 154 |
map_data = pd.DataFrame(
|
| 155 |
[[lat, lon]],
|
| 156 |
columns=["lat", "lon"]
|
| 157 |
)
|
| 158 |
-
|
| 159 |
-
layers = [
|
| 160 |
-
# Your existing layer for the geolocated text
|
| 161 |
-
pdk.Layer(
|
| 162 |
-
'ScatterplotLayer',
|
| 163 |
-
data=map_data,
|
| 164 |
-
get_position='[lon, lat]',
|
| 165 |
-
get_color='[200, 30, 0, 160]',
|
| 166 |
-
get_radius=200,
|
| 167 |
-
),
|
| 168 |
-
# Additional layers for other markers
|
| 169 |
-
]
|
| 170 |
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
)
|
| 180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
st.pydeck_chart(pdk.Deck(
|
| 183 |
map_style='mapbox://styles/mapbox/light-v9',
|
|
|
|
| 7 |
import copy
|
| 8 |
import pydeck as pdk
|
| 9 |
import numpy as np
|
| 10 |
+
import base64
|
| 11 |
|
| 12 |
+
def get_base64_encoded_image(image_path):
    """Return the contents of the file at ``image_path`` as a base64 string.

    The file is opened in binary mode and read in full; the raw bytes are
    base64-encoded and the encoded bytes are decoded to ``str`` as UTF-8.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode('utf-8')
|
| 15 |
+
|
| 16 |
keep_layer_count=6
|
| 17 |
|
| 18 |
|
|
|
|
| 80 |
return logits
|
| 81 |
|
| 82 |
|
| 83 |
+
@st.cache_resource(ttl=None)
|
| 84 |
def load_model_and_tokenizer():
|
| 85 |
byt5_tokenizer = AutoTokenizer.from_pretrained("yachay/byt5-geotagging-es", token=st.secrets['token'])
|
| 86 |
model = ByT5ForTextGeotagging.from_pretrained("yachay/byt5-geotagging-es", token=st.secrets['token'])
|
|
|
|
| 102 |
prob = probas[0][class_idx]
|
| 103 |
cumulative_prob += prob
|
| 104 |
if cumulative_prob > 0.5:
|
|
|
|
|
|
|
|
|
|
| 105 |
break
|
| 106 |
+
coordinates = model.config.class_to_location.get(str(class_idx))
|
| 107 |
+
if coordinates:
|
| 108 |
+
results.append((class_idx, prob, coordinates))
|
| 109 |
+
|
| 110 |
+
# Check if at least one result is added; if not, add the highest probability class
|
| 111 |
+
if not results:
|
| 112 |
+
class_idx = sorted_indices[0]
|
| 113 |
+
prob = probas[0][class_idx]
|
| 114 |
+
coordinates = model.config.class_to_location.get(str(class_idx))
|
| 115 |
+
if coordinates:
|
| 116 |
+
results.append((class_idx, prob, coordinates))
|
| 117 |
|
| 118 |
return results
|
| 119 |
+
|
| 120 |
|
| 121 |
|
| 122 |
def geolocate_text_byt5(text):
|
|
|
|
| 161 |
|
| 162 |
if st.session_state.text_input:
|
| 163 |
results = geolocate_text_byt5_multiclass(st.session_state.text_input)
|
| 164 |
+
#st.write(results)
|
| 165 |
_, confidence, (lat, lon) = results[0]
|
| 166 |
+
if len(results) == 1:
|
| 167 |
+
confidence_def = 'High'
|
| 168 |
+
elif len(results) < 50:
|
| 169 |
+
confidence_def = 'Low'
|
| 170 |
+
else:
|
| 171 |
+
confidence_def = 'Very low'
|
| 172 |
+
st.write('Predicted Location: (', lat, lon, '). Confidence: ', confidence_def)
|
| 173 |
+
if confidence_def == 'Low':
|
| 174 |
+
st.write('Multiple possible locations were identified as confidence is low')
|
| 175 |
+
elif confidence_def == 'Very low':
|
| 176 |
+
st.write('There are too many possible locations as confidence is very low')
|
| 177 |
|
| 178 |
# Render map with pydeck
|
| 179 |
map_data = pd.DataFrame(
|
| 180 |
[[lat, lon]],
|
| 181 |
columns=["lat", "lon"]
|
| 182 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
|
| 184 |
+
encoded_image = get_base64_encoded_image("icons8-map-pin-48.png")
|
| 185 |
+
icon_url = f"data:image/png;base64,{encoded_image}"
|
| 186 |
+
|
| 187 |
+
# Example icon data
|
| 188 |
+
icon_data = {
|
| 189 |
+
"url": icon_url, # URL of the icon image
|
| 190 |
+
"width": 128, # Width of the icon in pixels
|
| 191 |
+
"height": 128, # Height of the icon in pixels
|
| 192 |
+
"anchorY": 128 # Anchor point of the icon in pixels (bottom center)
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
# Example location data
|
| 196 |
+
locations = pd.DataFrame({
|
| 197 |
+
'lat': [lat], # Latitude values
|
| 198 |
+
'lon': [lon], # Longitude values
|
| 199 |
+
'icon_data': [icon_data] # Repeating the icon data for each location
|
| 200 |
+
})
|
| 201 |
+
|
| 202 |
+
layers = []
|
| 203 |
+
|
| 204 |
+
if confidence_def != 'Very low':
|
| 205 |
+
# Add layers for each additional marker
|
| 206 |
+
for item in results:
|
| 207 |
+
_, confidence, (lat, lon) = item
|
| 208 |
+
layer = pdk.Layer(
|
| 209 |
+
type='IconLayer',
|
| 210 |
+
data=pd.DataFrame({
|
| 211 |
+
'lat': [lat], # Latitude values
|
| 212 |
+
'lon': [lon], # Longitude values
|
| 213 |
+
'icon_data': [icon_data] # Repeating the icon data for each location
|
| 214 |
+
}),
|
| 215 |
+
get_icon='icon_data',
|
| 216 |
+
get_size=4,
|
| 217 |
+
size_scale=15,
|
| 218 |
+
get_position=['lon', 'lat'],
|
| 219 |
+
pickable=True
|
| 220 |
+
)
|
| 221 |
+
layers.append(layer)
|
| 222 |
|
| 223 |
st.pydeck_chart(pdk.Deck(
|
| 224 |
map_style='mapbox://styles/mapbox/light-v9',
|