devjas1 commited on
Commit
a427341
·
1 Parent(s): 177dc98

(FEAT) Adds comprehensive 'ResultsManager' class for session-wide results management for multi-file inference.

Browse files

- Implemented 'ResultsManager' class to handle inference results in Streamlit session state
- Added methods to initialize, add, retrieve, and clear results (`init_results_table`, `add_resu
lts`, `get_results`, `clear_results')
- Introduced functionality to convert results into a pandas 'DataFrame' for display and export (`get_results_dataframe`)
- Added export capabilities for results in CSV and JSON formats (`export_to_csv`, `export_to_json`)
- Implemented summary statistics calculation, including: accuracy, average confidence, and processing time (`get_summary_stats`)
- Provided a method to remove results by filename ('remove_results_by_filename')
- Integrated a Streamlit UI for displaying results, summary metrics, and export/download options (`display_results_table`)
- Ensured robust handling of empty results and edge cases for better user experience"

Files changed (1) hide show
  1. utils/results_manager.py +213 -0
utils/results_manager.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Session results management for multi-file inference.
2
+ Handles in-memory results table and export functionality"""
3
+
4
+ import streamlit as st
5
+ import pandas as pd
6
+ import json
7
+ from datetime import datetime
8
+ from typing import Dict, List, Any, Optional
9
+ from pathlib import Path
10
+ import io
11
+
12
+ class ResultsManager:
13
+ """Manages session-wide results for multi-file inference"""
14
+
15
+ RESULTS_KEY = "inference_results"
16
+
17
+ @staticmethod
18
+ def init_results_table() -> None:
19
+ """Initialize the results table in session state"""
20
+ if ResultsManager.RESULTS_KEY not in st.session_state:
21
+ st.session_state[ResultsManager.RESULTS_KEY] = []
22
+
23
+ @staticmethod
24
+ def add_results(
25
+ filename: str,
26
+ model_name: str,
27
+ prediction: int,
28
+ predicted_class: str,
29
+ confidence: float,
30
+ logits: List[float],
31
+ ground_truth: Optional[int] = None,
32
+ processing_time: float = 0.0,
33
+ metadata: Optional[Dict[str, Any]] = None
34
+ ) -> None:
35
+ """Add a single inference result to the results table"""
36
+ ResultsManager.init_results_table()
37
+
38
+ result = {
39
+ "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
40
+ "filename": filename,
41
+ "model": model_name,
42
+ "prediction": prediction,
43
+ "predicted_class": predicted_class,
44
+ "confidence": confidence,
45
+ "logits": logits,
46
+ "ground_truth": ground_truth,
47
+ "processing_time": processing_time,
48
+ "metadata": metadata or {}
49
+ }
50
+
51
+ st.session_state[ResultsManager.RESULTS_KEY].append(result)
52
+
53
+ @staticmethod
54
+ def get_results() -> List[Dict[str, Any]]:
55
+ """Get all inference results"""
56
+ ResultsManager.init_results_table()
57
+ return st.session_state[ResultsManager.RESULTS_KEY]
58
+
59
+ @staticmethod
60
+ def get_results_count() -> int:
61
+ """Get the number of stored results"""
62
+ return len(ResultsManager.get_results())
63
+
64
+ @staticmethod
65
+ def clear_results() -> None:
66
+ """Clear all stored results"""
67
+ st.session_state[ResultsManager.RESULTS_KEY] = []
68
+
69
+ @staticmethod
70
+ def get_results_dataframe() -> pd.DataFrame:
71
+ """Convert results to pandas DataFrame for display and export"""
72
+ results = ResultsManager.get_results()
73
+ if not results:
74
+ return pd.DataFrame()
75
+
76
+ #===Flatten the results for DataFrame===
77
+ df_data = []
78
+ for result in results:
79
+ row = {
80
+ "Timestamp": result["timestamp"],
81
+ "Filename": result["filename"],
82
+ "Model": result["model"],
83
+ "Prediction": result["prediction"],
84
+ "Predicted Class": result["predicted_class"],
85
+ "Confidence": f"{result['confidence']:.3f}",
86
+ "Stable Logit": f"{result['logits'][0]:.3f}" if len(result['logits']) > 0 else "N/A",
87
+ "Weathered Logit": f"{result['logits'][1]:.3f}" if len(result['logits']) > 1 else "N/A",
88
+ "Ground Truth": result["ground_truth"] if result["ground_truth"] is not None else "Unknown",
89
+ "Processing Time (s)": f"{result['processing_time']:.3f}",
90
+ }
91
+ df_data.append(row)
92
+
93
+ return pd.DataFrame(df_data)
94
+
95
+ @staticmethod
96
+ def export_to_csv() -> bytes:
97
+ """Export results to CSV format"""
98
+ df = ResultsManager.get_results_dataframe()
99
+ if df.empty:
100
+ return b""
101
+
102
+ #===Use StringIO to create CSV in memory===
103
+ csv_buffer = io.StringIO()
104
+ df.to_csv(csv_buffer, index=False)
105
+ return csv_buffer.getvalue().encode('utf-8')
106
+
107
+ @staticmethod
108
+ def export_to_json() -> str:
109
+ """Export results to JSON format"""
110
+ results = ResultsManager.get_results()
111
+ return json.dumps(results, indent=2, default=str)
112
+
113
+ @staticmethod
114
+ def get_summary_stats() -> Dict[str, Any]:
115
+ """Get summary statistics for the results"""
116
+ results = ResultsManager.get_results()
117
+ if not results:
118
+ return {}
119
+
120
+ df = ResultsManager.get_results_dataframe()
121
+
122
+ stats = {
123
+ "total_files": len(results),
124
+ "models_used": list(set(r["model"] for r in results)),
125
+ "stable_predictions": sum(1 for r in results if r["prediction"] == 0),
126
+ "weathered_predictions": sum(1 for r in results if r["prediction"] == 1),
127
+ "avg_confidence": sum(r["confidence"] for r in results) / len(results),
128
+ "avg_processing_time": sum(r["processing_time"] for r in results) / len(results),
129
+ "files_with_ground_truth": sum(1 for r in results if r["ground_truth"] is not None),
130
+ }
131
+ #===Calculate accuracy if ground truth is available===
132
+ correct_predictions = sum(
133
+ 1 for r in results
134
+ if r["ground_truth"] is not None and r["prediction"] == r["ground_truth"]
135
+ )
136
+ total_with_gt = stats["files_with_ground_truth"]
137
+ if total_with_gt > 0:
138
+ stats["accuracy"] = correct_predictions / total_with_gt
139
+ else:
140
+ stats["accuracy"] = None
141
+
142
+ return stats
143
+
144
+ @staticmethod
145
+ def remove_result_by_filename(filename: str) -> bool:
146
+ """Remove a result by filename. Returns True if removed, False if not found."""
147
+ results = ResultsManager.get_results()
148
+ original_length = len(results)
149
+
150
+ # Filter out results with matching filename
151
+ st.session_state[ResultsManager.RESULTS_KEY] = [
152
+ r for r in results if r["filename"] != filename
153
+ ]
154
+
155
+ return len(st.session_state[ResultsManager.RESULTS_KEY]) < original_length
156
+
157
+ @staticmethod
158
+ def display_results_table() -> None:
159
+ """Display the results table in Streamlit UI"""
160
+ df = ResultsManager.get_results_dataframe()
161
+
162
+ if df.empty:
163
+ st.info("No inference results yet. Upload files and run analysis to see results here.")
164
+ return
165
+
166
+ st.subheader(f"Inference Results ({len(df)} files)")
167
+
168
+ #==Summary stats==
169
+ stats = ResultsManager.get_summary_stats()
170
+ if stats:
171
+ col1, col2, col3, col4 = st.columns(4)
172
+ with col1:
173
+ st.metric("Total Files", stats["total_files"])
174
+ with col2:
175
+ st.metric("Avg Confidence", f"{stats['avg_confidence']:.3f}")
176
+ with col3:
177
+ st.metric("Stable/Weathered", f"{stats['stable_predictions']}/{stats['weathered_predictions']}")
178
+ with col4:
179
+ if stats["accuracy"] is not None:
180
+ st.metric("Accuracy", f"{stats['accuracy']:.3f}")
181
+ else:
182
+ st.metric("Accuracy", "N/A")
183
+
184
+ #==Results Table==
185
+ st.dataframe(df, use_container_width=True)
186
+
187
+ #==Export Button==
188
+ col1, col2, col3 = st.columns([1, 1, 2])
189
+
190
+ with col1:
191
+ csv_data = ResultsManager.export_to_csv()
192
+ if csv_data:
193
+ st.download_button(
194
+ label="Download CSV",
195
+ data=csv_data,
196
+ file_name=f"polymer_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
197
+ mime="text/csv"
198
+ )
199
+
200
+ with col2:
201
+ json_data = ResultsManager.export_to_json()
202
+ if json_data:
203
+ st.download_button(
204
+ label="📥 Download JSON",
205
+ data=json_data,
206
+ file_name=f"polymer_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
207
+ mime="application/json"
208
+ )
209
+
210
+ with col3:
211
+ if st.button("Clear All Results", help="Clear all stored results"):
212
+ ResultsManager.clear_results()
213
+ st.rerun()