leo-bourrel commited on
Commit
332cb73
·
1 Parent(s): b151c5c

feat: add statistics

Browse files
src/app.py CHANGED
@@ -1,9 +1,13 @@
1
  """App to display images in a gallery"""
2
 
 
 
3
  import numpy as np
4
  import streamlit as st
5
  from image_preprocessing import get_image_caption, get_images, resize_image
 
6
 
 
7
  from data_models.sql_connection import get_db_connection
8
  from data_models.park_manager import ParkManager
9
  from data_models.image_manager import ImageManager
@@ -86,6 +90,14 @@ def display_stats() -> None:
86
  st.write("Number of images: ", image_manager.get_image_count())
87
  st.write("Number of parks: ", park_manager.get_park_count())
88
 
 
 
 
 
 
 
 
 
89
  def main() -> None:
90
  """Main function to run the app"""
91
  park = sidebar()
 
1
  """App to display images in a gallery"""
2
 
3
+ import json
4
+ import pandas as pd
5
  import numpy as np
6
  import streamlit as st
7
  from image_preprocessing import get_image_caption, get_images, resize_image
8
+ import plotly.express as px
9
 
10
+ from statistics import get_plot_from_most_common_elements
11
  from data_models.sql_connection import get_db_connection
12
  from data_models.park_manager import ParkManager
13
  from data_models.image_manager import ImageManager
 
90
  st.write("Number of images: ", image_manager.get_image_count())
91
  st.write("Number of parks: ", park_manager.get_park_count())
92
 
93
+ predictions = openai_manager.get_all_predictions()
94
+ df = pd.DataFrame(predictions)
95
+ st.markdown("## Most common elements")
96
+ st.plotly_chart(get_plot_from_most_common_elements(df, "built_elements", "elements"))
97
+ st.markdown("## Vegetation detection")
98
+ st.plotly_chart(get_plot_from_most_common_elements(df, "vegetation_detection", "vegetation"))
99
+
100
+
101
  def main() -> None:
102
  """Main function to run the app"""
103
  park = sidebar()
src/data_models/openai_manager.py CHANGED
@@ -77,6 +77,20 @@ class OpenAIManager:
77
  except Exception as e:
78
  raise Exception(f"An error occurred while getting the predictions: {e}")
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  def close_connection(self):
81
  """Close the connection."""
82
  self.session.close()
 
77
  except Exception as e:
78
  raise Exception(f"An error occurred while getting the predictions: {e}")
79
 
80
+ def get_all_predictions(self):
81
+ """
82
+ Get all predictions from the `openai_predictions` table.
83
+
84
+ Returns:
85
+ list: List of predictions.
86
+ """
87
+ query = text("SELECT * FROM openai_predictions")
88
+ try:
89
+ result = self.session.execute(query).fetchall()
90
+ return [row._asdict() for row in result]
91
+ except Exception as e:
92
+ raise Exception(f"An error occurred while getting the predictions: {e}")
93
+
94
  def close_connection(self):
95
  """Close the connection."""
96
  self.session.close()
src/statistics.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import itertools
3
+ from collections import Counter
4
+ import pandas as pd
5
+ import plotly.express as px
6
+ from plotly.graph_objs import Figure
7
+ import streamlit as st
8
+ from typing import Any
9
+
10
+
11
+ def get_elements(x: str, key: str = "elements") -> Any:
12
+ try:
13
+ result = ast.literal_eval(x)
14
+ if isinstance(result, dict):
15
+ return result[key]
16
+ elif isinstance(result, str):
17
+ return [result]
18
+ elif isinstance(result, list):
19
+ return result
20
+ except:
21
+ return []
22
+
23
+
24
+ def get_most_common_elements(df: pd.DataFrame, column: str, key: str = "elements") -> list:
25
+ built_elements = df[column].apply(get_elements, args=(key,))
26
+ most_common_elements = Counter(itertools.chain.from_iterable(built_elements)).most_common()
27
+ return most_common_elements
28
+
29
+
30
+ def get_plot_from_most_common_elements(df: pd.DataFrame, column: str, key: str = "elements") -> Figure:
31
+ most_common_elements = get_most_common_elements(df, column, key)
32
+ most_common_elements = pd.DataFrame(most_common_elements, columns=[column, "count"])
33
+ return px.bar(
34
+ most_common_elements,
35
+ x=column,
36
+ y="count",
37
+ labels={"count": "# Objects", column: "Built Elements"},
38
+ )