Spaces:
Sleeping
Sleeping
"""App to display images in a gallery""" | |
import json | |
import pandas as pd | |
import numpy as np | |
import streamlit as st | |
from image_preprocessing import get_image_caption, get_images, resize_image | |
import plotly.express as px | |
from park_statistics import get_plot_from_most_common_elements, aggregate_fauna_elements | |
from data_models.sql_connection import get_db_connection | |
from data_models.park_manager import ParkManager | |
from data_models.image_manager import ImageManager | |
from data_models.openai_manager import OpenAIManager | |
from cloud_storage import GoogleCloudStorage | |
# --- Shared service objects ---------------------------------------------------
# Created at module level, in this order, every time the script executes.
# NOTE(review): Streamlit re-runs the whole script on each widget interaction,
# so these objects are rebuilt on every rerun; consider st.cache_resource if
# construction is expensive (verify the DB session is safe to share first).
engine, session = get_db_connection()  # not referenced by the functions below
park_manager, image_manager, openai_manager = (
    ParkManager(),
    ImageManager(),
    OpenAIManager(),
)
gcs = GoogleCloudStorage()  # image blobs are downloaded through this client
def get_image_predictions(images: list[dict]) -> dict:
    """Fetch the stored predictions for each image.

    Args:
        images (list[dict]): Image records; each must have an ``"id"`` key.

    Returns:
        dict: Mapping of image id to whatever
        ``openai_manager.get_predictions(image_id)`` returns for that image.
        ``image_gallery`` treats a ``None`` value as "no predictions".
    """
    # One lookup per image, in input order; a repeated id keeps the last result.
    return {
        image["id"]: openai_manager.get_predictions(image["id"])
        for image in images
    }
def image_gallery(
    images: list[dict],
    predictions: dict,
    *,
    num_columns: int = 3,
    bucket_name: str = "suad_park",
) -> None:
    """Display a grid of images in the Streamlit app, each with its predictions.

    Args:
        images (list[dict]): Image records to display. Each needs an ``"id"``
            key (used in the caption and prediction lookup) and a ``"name"``
            key (the blob name passed to ``gcs.download_blob``).
        predictions (dict): Mapping of image id to its predictions, as built
            by ``get_image_predictions``. A missing id or a ``None`` value is
            shown as "No predictions available".
        num_columns (int): Number of grid columns; images fill them
            round-robin. Defaults to 3.
        bucket_name (str): First argument passed to ``gcs.download_blob``
            (presumably the storage bucket name). Defaults to ``"suad_park"``.
    """
    st.title("Welcome")
    columns = st.columns(num_columns)
    for index, image in enumerate(images):
        # Round-robin placement: image i goes into column i mod num_columns.
        with columns[index % num_columns]:
            image_blob = gcs.download_blob(bucket_name, image["name"])
            st.image(image_blob, width=200, caption=f"Image id #{image['id']}")
            image_predictions = predictions.get(image["id"])
            # A missing key and an explicit None both mean "nothing to show".
            if image_predictions is None:
                st.write("No predictions available")
            else:
                st.json(image_predictions)
def get_park_list() -> list:
    """Return all parks as reported by ``park_manager.get_parks()``.

    ``sidebar`` offers these entries as options and displays each entry's
    ``"name"`` key.
    """
    parks = park_manager.get_parks()
    return parks
def get_park_images(park: str) -> list:
    """Return the image records for one park.

    Args:
        park: Park identifier forwarded to
            ``image_manager.get_images_by_park``. ``main`` passes
            ``park["id"]`` here. NOTE(review): the ``str`` hint may not match
            the id's real type; confirm against ParkManager.

    Returns:
        list: Whatever ``image_manager.get_images_by_park`` returns.
        ``image_gallery`` expects records with ``"id"`` and ``"name"`` keys.
    """
    park_images = image_manager.get_images_by_park(park)
    return park_images
def sidebar() -> dict | None:
    """Render the park picker in the Streamlit sidebar.

    Each option is labelled with its ``"name"`` key.

    Returns:
        dict | None: The park record the user selected, or ``None`` while
        nothing is selected (the selectbox starts with ``index=None``).
    """
    return st.sidebar.selectbox(
        label="Park List",
        options=get_park_list(),
        index=None,
        format_func=lambda park: park["name"],
    )
def display_stats() -> None:
    """Display overall statistics about the images and their predictions.

    Shows image and park counts, then three charts built from all stored
    predictions: most common built elements, fauna distribution (pie chart)
    and vegetation detection. If no predictions are stored yet, a short
    message is shown instead of the charts.
    """
    st.title("Statistics")
    st.write("Number of images: ", image_manager.get_image_count())
    st.write("Number of parks: ", park_manager.get_park_count())

    predictions_df = pd.DataFrame(openai_manager.get_all_predictions())
    if predictions_df.empty:
        # With no predictions the frame has no columns, so the column-based
        # charts below have nothing to aggregate.
        st.write("No predictions available")
        return

    st.markdown("## Most common elements")
    st.plotly_chart(
        get_plot_from_most_common_elements(predictions_df, "built_elements", "elements")
    )

    st.markdown("## Fauna identification")
    fauna_counts = aggregate_fauna_elements(predictions_df)
    # Materialise the (fauna, count) pairs before building the frame.
    fauna_df = pd.DataFrame(list(fauna_counts.items()), columns=["fauna", "count"])
    fauna_pie = px.pie(
        fauna_df,
        names="fauna",
        values="count",
        labels={"count": "# Animals", "fauna": "Fauna"},
    )
    st.plotly_chart(fauna_pie)

    st.markdown("## Vegetation detection")
    st.plotly_chart(
        get_plot_from_most_common_elements(
            predictions_df, "vegetation_detection", "vegetation"
        )
    )
def main() -> None:
    """Run the app.

    Shows global statistics until a park is selected in the sidebar, then
    shows that park's image gallery with predictions.
    """
    park = sidebar()
    if not park:
        display_stats()
        st.stop()
        # st.stop() halts the script under `streamlit run`, but return
        # explicitly so control can never fall through to park["id"] with
        # park being None if stop() does not halt (e.g. run outside Streamlit).
        return
    images = get_park_images(park["id"])
    predictions = get_image_predictions(images)
    image_gallery(images, predictions)


# `streamlit run` executes this file with __name__ == "__main__", so the app
# starts there; a plain import of the module does not call main().
if __name__ == "__main__":
    main()