devjas1 commited on
Commit
9e50ae2
Β·
1 Parent(s): 27f8f90

(FEAT): implement confidence calculation and visualization utilities

Browse files
Files changed (1) hide show
  1. utils/confidence.py +163 -0
utils/confidence.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Confidence calculation and visualization utilities.
2
+ Provides normalized softmax confidence and color-coded badges"""
3
+ from typing import Tuple, List
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+
9
+ def calculate_softmax_confidence(logits: torch.Tensor) -> Tuple[np.ndarray, float, str, str]:
10
+ """Calculate normalized confidence using softmax
11
+ Args:
12
+ logits: Raw model logits tensor
13
+ Returns:
14
+ Tuple of (probabilities, max_confidence, confidence_level, confidence_emoji)
15
+ """
16
+ # ===Apply softmax to get probabilities===
17
+ probs_np = F.softmax(logits, dim=1).cpu().numpy().flatten()
18
+
19
+ # ===Get maximum probability as confidence===
20
+ max_confidence = float(np.max(probs_np))
21
+
22
+ # ===Determine confidence level and emoji===
23
+ if max_confidence >= 0.80:
24
+ confidence_level = "HIGH"
25
+ confidence_emoji = "🟒"
26
+ elif max_confidence >= 0.60:
27
+ confidence_level = "MEDIUM"
28
+ confidence_emoji = "🟑"
29
+ else:
30
+ confidence_level = "LOW"
31
+ confidence_emoji = "πŸ”΄"
32
+
33
+ return probs_np, max_confidence, confidence_level, confidence_emoji
34
+
35
+
36
+ def get_confidence_badge(confidence: float) -> Tuple[str, str]:
37
+ """Get confidence badge emoji and level description
38
+ Args:
39
+ confidence: Confidence value (0-1)
40
+ Returns:
41
+ Tuple of (emoji, level)
42
+ """
43
+ if confidence >= 0.80:
44
+ return "🟒", "HIGH"
45
+ elif confidence >= 0.60:
46
+ return "🟑", "MEDIUM"
47
+ else:
48
+ return "πŸ”΄", "LOW"
49
+
50
+
51
+ def format_confidence_display(confidence: float, level: str, emoji: str) -> str:
52
+ """
53
+ Format confidence for display in UI
54
+
55
+ Args:
56
+ confidence: Confidence value (0-1)
57
+ level: Confidence level (HIGH/MEDIUM/LOW)
58
+ emoji: Confidence emoji
59
+
60
+ Returns:
61
+ Formatted confidence string
62
+ """
63
+ return f"{emoji} **{level}** ({confidence:.1%})"
64
+
65
+
66
+ def create_confidence_progress_html(
67
+ probabilities: np.ndarray,
68
+ labels: List[str],
69
+ highlight_idx: int
70
+ ) -> str:
71
+ """
72
+ Create HTML for confidence progress bars
73
+
74
+ Args:
75
+ probabilities: Array of class probabilities
76
+ labels: List of class labels
77
+ highlight_idx: Index of predicted class to highlight
78
+
79
+ Returns:
80
+ HTML string for progress bars
81
+ """
82
+ if len(probabilities) == 0 or len(labels) == 0:
83
+ return "<p>No confidence data available</p>"
84
+
85
+ html_parts = []
86
+
87
+ for i, (prob, label) in enumerate(zip(probabilities, labels)):
88
+ # ===Color based on whether this is the predicted class===
89
+ if i == highlight_idx:
90
+ if prob >= 0.80:
91
+ color = "#22c55e" # green-500
92
+ text_color = "#ffffff"
93
+ elif prob >= 0.60:
94
+ color = "#eab308" # yellow-500
95
+ text_color = "#000000"
96
+ else:
97
+ color = "#ef4444" # red-500
98
+ text_color = "#ffffff"
99
+ else:
100
+ color = "#e5e7eb" # gray-200
101
+ text_color = "#6b7280" # gray-500
102
+
103
+ percentage = prob * 100
104
+
105
+ html_parts.append(f"""
106
+ <div style="margin-bottom: 8px;">
107
+ <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 4px;">
108
+ <span style="font-size: 0.875rem; font-weight: 500; color: #374151;">{label}</span>
109
+ <span style="font-size: 0.875rem; color: #6b7280;">{percentage:.1f}%</span>
110
+ </div>
111
+ <div style="width: 100%; background-color: #f3f4f6; border-radius: 0.375rem; height: 20px; overflow: hidden;">
112
+ <div style="
113
+ width: {percentage}%;
114
+ height: 100%;
115
+ background-color: {color};
116
+ display: flex;
117
+ align-items: center;
118
+ justify-content: center;
119
+ transition: width 0.3s ease;
120
+ ">
121
+ {f'<span style="color: {text_color}; font-size: 0.75rem; font-weight: 600;">{percentage:.1f}%</span>' if percentage > 20 else ''}
122
+ </div>
123
+ </div>
124
+ </div>
125
+ """)
126
+
127
+ return f"""
128
+ <div style="padding: 16px; background-color: #f9fafb; border-radius: 0.5rem; border: 1px solid #e5e7eb;">
129
+ <h4 style="margin: 0 0 12px 0; font-size: 1rem; color: #374151;">Confidence Breakdown</h4>
130
+ {''.join(html_parts)}
131
+ </div>
132
+ """
133
+
134
+
135
+ def calculate_legacy_confidence(logits_list: List[float]) -> Tuple[float, str, str]:
136
+ """
137
+ Calculate confidence using legacy logit margin method for backward compatibility
138
+
139
+ Args:
140
+ logits_list: List of raw logits
141
+
142
+ Returns:
143
+ Tuple of (margin, confidence_level, confidence_emoji)
144
+ """
145
+ if len(logits_list) < 2:
146
+ return 0.0, "LOW", "πŸ”΄"
147
+
148
+ logits_array = np.array(logits_list)
149
+ sorted_logits = np.sort(logits_array)[::-1] # Descending order
150
+ margin = sorted_logits[0] - sorted_logits[1]
151
+
152
+ # ===Define thresholds for margin-based confidence===
153
+ if margin >= 2.0:
154
+ confidence_level = "HIGH"
155
+ confidence_emoji = "🟒"
156
+ elif margin >= 1.0:
157
+ confidence_level = "MEDIUM"
158
+ confidence_emoji = "🟑"
159
+ else:
160
+ confidence_level = "LOW"
161
+ confidence_emoji = "πŸ”΄"
162
+
163
+ return margin, confidence_level, confidence_emoji