File size: 3,601 Bytes
98815d6
 
332cb73
 
1288cce
98815d6
 
332cb73
ecbbafe
150c5ae
ecbbafe
1288cce
ecbbafe
 
4029b73
ecbbafe
 
 
 
 
687f595
ecbbafe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98815d6
4029b73
98815d6
 
 
 
 
 
 
 
 
 
4029b73
98815d6
4029b73
98815d6
ecbbafe
 
 
 
98815d6
 
ecbbafe
 
1288cce
 
ecbbafe
 
1288cce
 
ecbbafe
98815d6
 
 
 
 
ecbbafe
 
 
 
 
 
 
 
 
 
 
98815d6
 
b151c5c
 
 
 
 
 
332cb73
 
 
 
150c5ae
 
 
 
 
 
332cb73
 
 
 
5479120
98815d6
 
1288cce
b151c5c
 
ecbbafe
 
 
 
98815d6
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""App to display images in a gallery"""

import json
import pandas as pd
import numpy as np
import streamlit as st
from image_preprocessing import get_image_caption, get_images, resize_image
import plotly.express as px

from park_statistics import get_plot_from_most_common_elements, aggregate_fauna_elements
from data_models.sql_connection import get_db_connection
from data_models.park_manager import ParkManager
from data_models.image_manager import ImageManager
from data_models.openai_manager import OpenAIManager
from cloud_storage import GoogleCloudStorage

# Shared resources created at module execution time and used by the functions below.
# NOTE(review): if Streamlit re-executes this script on every user interaction (its usual
# execution model), all of these objects are re-created on each rerun. Consider
# st.cache_resource, but first confirm the managers/session are safe to share across sessions.
# The two values returned by get_db_connection(); neither is referenced in the visible code.
engine, session = get_db_connection()
# Used by get_park_list() and display_stats().
park_manager = ParkManager()
# Used by get_park_images() and display_stats().
image_manager = ImageManager()
# Used by get_image_predictions() and display_stats().
openai_manager = OpenAIManager()
# Used by image_gallery() to download image blobs.
gcs = GoogleCloudStorage()


def get_image_predictions(images: list[dict]) -> dict:
    """Fetch the stored predictions for each image.

    Args:
        images (list[dict]): image records; each must have an ``"id"`` key.

    Returns:
        dict: mapping of image id to whatever
            ``openai_manager.get_predictions`` returns for that id. The value
            may presumably be ``None``; ``image_gallery`` treats a ``None``
            value as "no predictions".
    """
    return {
        image["id"]: openai_manager.get_predictions(image["id"])
        for image in images
    }


def image_gallery(
    images: list[dict],
    predictions: dict,
    *,
    bucket_name: str = "suad_park",
    num_columns: int = 3,
) -> None:
    """Display a grid of images with their predictions in a Streamlit app.

    Images are laid out round-robin across ``num_columns`` columns. Each image
    is downloaded via ``gcs.download_blob`` and shown with its id as the
    caption. Its predictions follow as JSON, or a placeholder message is
    shown when none are available.

    Args:
        images (list[dict]): image records; each needs an ``"id"`` key and a
            ``"name"`` key (passed to ``gcs.download_blob`` with the bucket).
        predictions (dict): mapping of image id to predictions, as returned
            by ``get_image_predictions``. A missing id and a ``None`` value
            are both treated as "no predictions".
        bucket_name (str): bucket passed to ``gcs.download_blob``.
        num_columns (int): number of grid columns; must be at least 1.

    Raises:
        ValueError: if ``num_columns`` is less than 1.
    """
    if num_columns < 1:
        raise ValueError(f"num_columns must be >= 1, got {num_columns}")

    st.title("Welcome")

    columns = st.columns(num_columns)
    for index, image in enumerate(images):
        image_id = image["id"]
        with columns[index % num_columns]:
            image_blob = gcs.download_blob(bucket_name, image["name"])
            st.image(image_blob, width=200, caption=f"Image id #{image_id}")
            # .get() covers both "id missing" and "stored value is None" in one lookup.
            image_predictions = predictions.get(image_id)
            if image_predictions is None:
                st.write("No predictions available")
                continue
            st.json(image_predictions)


def get_park_list() -> list:
    """Return the parks reported by ``park_manager.get_parks()``."""
    parks = park_manager.get_parks()
    return parks


def get_park_images(park: str) -> list:
    """Return the image records associated with ``park``.

    Args:
        park: park identifier forwarded to ``image_manager.get_images_by_park``.
            NOTE(review): annotated as ``str``, but ``main()`` passes
            ``park["id"]``; confirm the actual id type.

    Returns:
        list: whatever ``image_manager.get_images_by_park`` returns.
    """
    park_images = image_manager.get_images_by_park(park)
    return park_images


def sidebar() -> dict | None:
    """Render the park picker in the Streamlit sidebar.

    The select box starts with no selection (``index=None``) and labels each
    option with the park record's ``"name"`` key.

    Returns:
        The selected park record, or ``None`` when nothing is selected.
    """
    parks = get_park_list()
    return st.sidebar.selectbox(
        label="Park List",
        options=parks,
        index=None,
        format_func=lambda park: park["name"],
    )


def display_stats() -> None:
    """Render the global statistics page.

    Shows the image and park counts, then charts derived from all stored
    predictions: built elements, a fauna pie chart, and vegetation.
    """
    st.title("Statistics")
    st.write("Number of images: ", image_manager.get_image_count())
    st.write("Number of parks: ", park_manager.get_park_count())

    predictions_df = pd.DataFrame(openai_manager.get_all_predictions())

    st.markdown("## Most common elements")
    built_fig = get_plot_from_most_common_elements(
        predictions_df, "built_elements", "elements"
    )
    st.plotly_chart(built_fig)

    st.markdown("## Fauna identification")
    fauna_counts = aggregate_fauna_elements(predictions_df)
    fauna_df = pd.DataFrame(fauna_counts.items(), columns=["fauna", "count"])
    fauna_pie = px.pie(
        fauna_df,
        names="fauna",
        values="count",
        labels={"count": "# Animals", "fauna": "Fauna"},
    )
    st.plotly_chart(fauna_pie)

    st.markdown("## Vegetation detection")
    vegetation_fig = get_plot_from_most_common_elements(
        predictions_df, "vegetation_detection", "vegetation"
    )
    st.plotly_chart(vegetation_fig)


def main() -> None:
    """Run the app.

    Shows the gallery for the park selected in the sidebar. When no park is
    selected, shows the statistics page instead and stops the script run.
    """
    park = sidebar()
    if not park:
        display_stats()
        st.stop()
        # NOTE(review): st.stop() may return rather than raise outside a running
        # Streamlit session; return explicitly so the gallery code below never
        # runs with ``park`` unset.
        return

    images = get_park_images(park["id"])
    predictions = get_image_predictions(images)
    image_gallery(images, predictions)


# Entry point: run the app when this file is executed as the main script.
if __name__ == "__main__":
    main()