Spaces:
Running
Running
devjas1
commited on
Commit
·
50c7ef1
1
Parent(s):
8dd961f
Adds Transparent AI Reasoning Engine for explainable predictions
Browse filesIntroduces a new module that implements a Transparent AI Reasoning Engine, enabling explainable predictions with uncertainty quantification and hypothesis generation.
This addition enhances the model's interpretability by providing comprehensive explanations, confidence assessments, and analytical features for scientific hypotheses related to spectral data.
Updates the requirements to include necessary dependencies for SHAP and other libraries.
- modules/transparent_ai.py +493 -0
- requirements.txt +5 -0
modules/transparent_ai.py
ADDED
@@ -0,0 +1,493 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Transparent AI Reasoning Engine for POLYMEROS
|
3 |
+
Provides explainable predictions with uncertainty quantification and hypothesis generation
|
4 |
+
"""
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from typing import Dict, List, Any, Tuple, Optional
|
10 |
+
from dataclasses import dataclass
|
11 |
+
import warnings
|
12 |
+
|
13 |
+
try:
|
14 |
+
import shap
|
15 |
+
|
16 |
+
SHAP_AVAILABLE = True
|
17 |
+
except ImportError:
|
18 |
+
SHAP_AVAILABLE = False
|
19 |
+
warnings.warn("SHAP not available. Install with: pip install shap")
|
20 |
+
|
21 |
+
|
22 |
+
@dataclass
|
23 |
+
class PredictionExplanation:
|
24 |
+
"""Comprehensive explanation for a model prediction"""
|
25 |
+
|
26 |
+
prediction: int
|
27 |
+
confidence: float
|
28 |
+
confidence_level: str
|
29 |
+
probabilities: np.ndarray
|
30 |
+
feature_importance: Dict[str, float]
|
31 |
+
reasoning_chain: List[str]
|
32 |
+
uncertainty_sources: List[str]
|
33 |
+
similar_cases: List[Dict[str, Any]]
|
34 |
+
confidence_intervals: Dict[str, Tuple[float, float]]
|
35 |
+
|
36 |
+
|
37 |
+
@dataclass
|
38 |
+
class Hypothesis:
|
39 |
+
"""AI-generated scientific hypothesis"""
|
40 |
+
|
41 |
+
statement: str
|
42 |
+
confidence: float
|
43 |
+
supporting_evidence: List[str]
|
44 |
+
testable_predictions: List[str]
|
45 |
+
suggested_experiments: List[str]
|
46 |
+
related_literature: List[str]
|
47 |
+
|
48 |
+
|
49 |
+
class UncertaintyEstimator:
|
50 |
+
"""Bayesian uncertainty estimation for model predictions"""
|
51 |
+
|
52 |
+
def __init__(self, model, n_samples: int = 100):
|
53 |
+
self.model = model
|
54 |
+
self.n_samples = n_samples
|
55 |
+
self.epistemic_uncertainty = None
|
56 |
+
self.aleatoric_uncertainty = None
|
57 |
+
|
58 |
+
def estimate_uncertainty(self, x: torch.Tensor) -> Dict[str, float]:
|
59 |
+
"""Estimate prediction uncertainty using Monte Carlo dropout"""
|
60 |
+
self.model.train() # Enable dropout
|
61 |
+
|
62 |
+
predictions = []
|
63 |
+
with torch.no_grad():
|
64 |
+
for _ in range(self.n_samples):
|
65 |
+
pred = F.softmax(self.model(x), dim=1)
|
66 |
+
predictions.append(pred.cpu().numpy())
|
67 |
+
|
68 |
+
predictions = np.array(predictions)
|
69 |
+
|
70 |
+
# Calculate uncertainties
|
71 |
+
mean_pred = np.mean(predictions, axis=0)
|
72 |
+
epistemic = np.var(predictions, axis=0) # Model uncertainty
|
73 |
+
aleatoric = np.mean(predictions * (1 - predictions), axis=0) # Data uncertainty
|
74 |
+
|
75 |
+
total_uncertainty = epistemic + aleatoric
|
76 |
+
|
77 |
+
return {
|
78 |
+
"epistemic": float(np.mean(epistemic)),
|
79 |
+
"aleatoric": float(np.mean(aleatoric)),
|
80 |
+
"total": float(np.mean(total_uncertainty)),
|
81 |
+
"prediction_variance": float(np.var(mean_pred)),
|
82 |
+
}
|
83 |
+
|
84 |
+
def confidence_intervals(
|
85 |
+
self, x: torch.Tensor, confidence_level: float = 0.95
|
86 |
+
) -> Dict[str, Tuple[float, float]]:
|
87 |
+
"""Calculate confidence intervals for predictions"""
|
88 |
+
self.model.train()
|
89 |
+
|
90 |
+
predictions = []
|
91 |
+
with torch.no_grad():
|
92 |
+
for _ in range(self.n_samples):
|
93 |
+
pred = F.softmax(self.model(x), dim=1)
|
94 |
+
predictions.append(pred.cpu().numpy().flatten())
|
95 |
+
|
96 |
+
predictions = np.array(predictions)
|
97 |
+
|
98 |
+
alpha = 1 - confidence_level
|
99 |
+
lower_percentile = (alpha / 2) * 100
|
100 |
+
upper_percentile = (1 - alpha / 2) * 100
|
101 |
+
|
102 |
+
intervals = {}
|
103 |
+
for i in range(predictions.shape[1]):
|
104 |
+
lower = np.percentile(predictions[:, i], lower_percentile)
|
105 |
+
upper = np.percentile(predictions[:, i], upper_percentile)
|
106 |
+
intervals[f"class_{i}"] = (lower, upper)
|
107 |
+
|
108 |
+
return intervals
|
109 |
+
|
110 |
+
|
111 |
+
class FeatureImportanceAnalyzer:
|
112 |
+
"""Advanced feature importance analysis for spectral data"""
|
113 |
+
|
114 |
+
def __init__(self, model):
|
115 |
+
self.model = model
|
116 |
+
self.shap_explainer = None
|
117 |
+
|
118 |
+
if SHAP_AVAILABLE:
|
119 |
+
try:
|
120 |
+
# Initialize SHAP explainer for the model
|
121 |
+
if SHAP_AVAILABLE:
|
122 |
+
if SHAP_AVAILABLE:
|
123 |
+
self.shap_explainer = shap.DeepExplainer(
|
124 |
+
model, torch.zeros(1, 500)
|
125 |
+
)
|
126 |
+
else:
|
127 |
+
self.shap_explainer = None
|
128 |
+
else:
|
129 |
+
self.shap_explainer = None
|
130 |
+
except (ValueError, RuntimeError) as e:
|
131 |
+
warnings.warn(f"Could not initialize SHAP explainer: {e}")
|
132 |
+
|
133 |
+
def analyze_feature_importance(
|
134 |
+
self, x: torch.Tensor, wavenumbers: Optional[np.ndarray] = None
|
135 |
+
) -> Dict[str, Any]:
|
136 |
+
"""Comprehensive feature importance analysis"""
|
137 |
+
importance_data = {}
|
138 |
+
|
139 |
+
# SHAP analysis (if available)
|
140 |
+
if self.shap_explainer is not None:
|
141 |
+
try:
|
142 |
+
shap_values = self.shap_explainer.shap_values(x)
|
143 |
+
importance_data["shap_values"] = shap_values
|
144 |
+
importance_data["shap_available"] = True
|
145 |
+
except (ValueError, RuntimeError) as e:
|
146 |
+
warnings.warn(f"SHAP analysis failed: {e}")
|
147 |
+
importance_data["shap_available"] = False
|
148 |
+
else:
|
149 |
+
importance_data["shap_available"] = False
|
150 |
+
|
151 |
+
# Gradient-based importance
|
152 |
+
x.requires_grad_(True)
|
153 |
+
self.model.eval()
|
154 |
+
|
155 |
+
output = self.model(x)
|
156 |
+
predicted_class = torch.argmax(output, dim=1)
|
157 |
+
|
158 |
+
# Calculate gradients
|
159 |
+
self.model.zero_grad()
|
160 |
+
output[0, predicted_class].backward()
|
161 |
+
|
162 |
+
if x.grad is not None:
|
163 |
+
gradients = x.grad.detach().abs().cpu().numpy().flatten()
|
164 |
+
else:
|
165 |
+
raise RuntimeError(
|
166 |
+
"Gradients were not computed. Ensure x.requires_grad_(True) is set correctly."
|
167 |
+
)
|
168 |
+
|
169 |
+
importance_data["gradient_importance"] = gradients
|
170 |
+
|
171 |
+
# Integrated gradients approximation
|
172 |
+
integrated_grads = self._integrated_gradients(x, predicted_class)
|
173 |
+
importance_data["integrated_gradients"] = integrated_grads
|
174 |
+
|
175 |
+
# Spectral region importance
|
176 |
+
if wavenumbers is not None:
|
177 |
+
region_importance = self._analyze_spectral_regions(gradients, wavenumbers)
|
178 |
+
importance_data["spectral_regions"] = region_importance
|
179 |
+
|
180 |
+
return importance_data
|
181 |
+
|
182 |
+
def _integrated_gradients(
|
183 |
+
self, x: torch.Tensor, target_class: torch.Tensor, steps: int = 50
|
184 |
+
) -> np.ndarray:
|
185 |
+
"""Calculate integrated gradients for feature importance"""
|
186 |
+
baseline = torch.zeros_like(x)
|
187 |
+
|
188 |
+
integrated_grads = np.zeros(x.shape[1])
|
189 |
+
|
190 |
+
for i in range(steps):
|
191 |
+
alpha = i / steps
|
192 |
+
interpolated = baseline + alpha * (x - baseline)
|
193 |
+
interpolated.requires_grad_(True)
|
194 |
+
|
195 |
+
output = self.model(interpolated)
|
196 |
+
self.model.zero_grad()
|
197 |
+
output[0, target_class].backward(retain_graph=True)
|
198 |
+
|
199 |
+
if interpolated.grad is not None:
|
200 |
+
grads = interpolated.grad.cpu().numpy().flatten()
|
201 |
+
integrated_grads += grads
|
202 |
+
|
203 |
+
integrated_grads = (
|
204 |
+
integrated_grads * (x - baseline).detach().cpu().numpy().flatten() / steps
|
205 |
+
)
|
206 |
+
return integrated_grads
|
207 |
+
|
208 |
+
def _analyze_spectral_regions(
|
209 |
+
self, importance: np.ndarray, wavenumbers: np.ndarray
|
210 |
+
) -> Dict[str, float]:
|
211 |
+
"""Analyze importance by common spectral regions"""
|
212 |
+
regions = {
|
213 |
+
"fingerprint": (400, 1500),
|
214 |
+
"ch_stretch": (2800, 3100),
|
215 |
+
"oh_stretch": (3200, 3700),
|
216 |
+
"carbonyl": (1600, 1800),
|
217 |
+
"aromatic": (1450, 1650),
|
218 |
+
}
|
219 |
+
|
220 |
+
region_importance = {}
|
221 |
+
|
222 |
+
for region_name, (low, high) in regions.items():
|
223 |
+
mask = (wavenumbers >= low) & (wavenumbers <= high)
|
224 |
+
if np.any(mask):
|
225 |
+
region_importance[region_name] = float(np.mean(importance[mask]))
|
226 |
+
else:
|
227 |
+
region_importance[region_name] = 0.0
|
228 |
+
|
229 |
+
return region_importance
|
230 |
+
|
231 |
+
|
232 |
+
class HypothesisGenerator:
|
233 |
+
"""AI-driven scientific hypothesis generation"""
|
234 |
+
|
235 |
+
def __init__(self):
|
236 |
+
self.hypothesis_templates = [
|
237 |
+
"The spectral differences in the {region} region suggest {mechanism} as a primary degradation pathway",
|
238 |
+
"Enhanced intensity at {wavenumber} cm⁻¹ indicates {chemical_change} in weathered samples",
|
239 |
+
"The correlation between {feature1} and {feature2} suggests {relationship}",
|
240 |
+
"Baseline shifts in {region} region may indicate {structural_change}",
|
241 |
+
]
|
242 |
+
|
243 |
+
def generate_hypotheses(
|
244 |
+
self, explanation: PredictionExplanation, spectral_data: Dict[str, Any]
|
245 |
+
) -> List[Hypothesis]:
|
246 |
+
"""Generate testable hypotheses based on model predictions and explanations"""
|
247 |
+
hypotheses = []
|
248 |
+
|
249 |
+
# Analyze feature importance for hypothesis generation
|
250 |
+
important_features = self._identify_key_features(explanation.feature_importance)
|
251 |
+
|
252 |
+
for feature_info in important_features:
|
253 |
+
hypothesis = self._generate_single_hypothesis(feature_info, explanation)
|
254 |
+
if hypothesis:
|
255 |
+
hypotheses.append(hypothesis)
|
256 |
+
|
257 |
+
return hypotheses
|
258 |
+
|
259 |
+
def _identify_key_features(
|
260 |
+
self, feature_importance: Dict[str, float]
|
261 |
+
) -> List[Dict[str, Any]]:
|
262 |
+
"""Identify key features for hypothesis generation"""
|
263 |
+
# Sort features by importance
|
264 |
+
sorted_features = sorted(
|
265 |
+
feature_importance.items(), key=lambda x: abs(x[1]), reverse=True
|
266 |
+
)
|
267 |
+
|
268 |
+
key_features = []
|
269 |
+
for feature_name, importance in sorted_features[:5]: # Top 5 features
|
270 |
+
feature_info = {
|
271 |
+
"name": feature_name,
|
272 |
+
"importance": importance,
|
273 |
+
"type": self._classify_feature_type(feature_name),
|
274 |
+
"chemical_significance": self._get_chemical_significance(feature_name),
|
275 |
+
}
|
276 |
+
key_features.append(feature_info)
|
277 |
+
|
278 |
+
return key_features
|
279 |
+
|
280 |
+
def _classify_feature_type(self, feature_name: str) -> str:
|
281 |
+
"""Classify spectral feature type"""
|
282 |
+
if "fingerprint" in feature_name.lower():
|
283 |
+
return "fingerprint"
|
284 |
+
elif "stretch" in feature_name.lower():
|
285 |
+
return "vibrational"
|
286 |
+
elif "carbonyl" in feature_name.lower():
|
287 |
+
return "functional_group"
|
288 |
+
else:
|
289 |
+
return "general"
|
290 |
+
|
291 |
+
def _get_chemical_significance(self, feature_name: str) -> str:
|
292 |
+
"""Get chemical significance of spectral feature"""
|
293 |
+
significance_map = {
|
294 |
+
"fingerprint": "molecular backbone structure",
|
295 |
+
"ch_stretch": "aliphatic chain integrity",
|
296 |
+
"oh_stretch": "hydrogen bonding and hydration",
|
297 |
+
"carbonyl": "oxidative degradation products",
|
298 |
+
"aromatic": "aromatic ring preservation",
|
299 |
+
}
|
300 |
+
|
301 |
+
for key, significance in significance_map.items():
|
302 |
+
if key in feature_name.lower():
|
303 |
+
return significance
|
304 |
+
|
305 |
+
return "structural changes"
|
306 |
+
|
307 |
+
def _generate_single_hypothesis(
|
308 |
+
self, feature_info: Dict[str, Any], explanation: PredictionExplanation
|
309 |
+
) -> Optional[Hypothesis]:
|
310 |
+
"""Generate a single hypothesis from feature information"""
|
311 |
+
if feature_info["importance"] < 0.1: # Skip low-importance features
|
312 |
+
return None
|
313 |
+
|
314 |
+
# Create hypothesis statement
|
315 |
+
statement = f"Changes in {feature_info['name']} region indicate {feature_info['chemical_significance']} during polymer weathering"
|
316 |
+
|
317 |
+
# Generate supporting evidence
|
318 |
+
evidence = [
|
319 |
+
f"Feature importance score: {feature_info['importance']:.3f}",
|
320 |
+
f"Classification confidence: {explanation.confidence:.3f}",
|
321 |
+
f"Chemical significance: {feature_info['chemical_significance']}",
|
322 |
+
]
|
323 |
+
|
324 |
+
# Generate testable predictions
|
325 |
+
predictions = [
|
326 |
+
f"Controlled weathering experiments should show progressive changes in {feature_info['name']} region",
|
327 |
+
f"Different polymer types should exhibit varying {feature_info['name']} responses to weathering",
|
328 |
+
]
|
329 |
+
|
330 |
+
# Suggest experiments
|
331 |
+
experiments = [
|
332 |
+
f"Time-series weathering study monitoring {feature_info['name']} region",
|
333 |
+
f"Comparative analysis across polymer types focusing on {feature_info['chemical_significance']}",
|
334 |
+
"Cross-validation with other analytical techniques (DSC, GPC, etc.)",
|
335 |
+
]
|
336 |
+
|
337 |
+
return Hypothesis(
|
338 |
+
statement=statement,
|
339 |
+
confidence=min(0.9, feature_info["importance"] * explanation.confidence),
|
340 |
+
supporting_evidence=evidence,
|
341 |
+
testable_predictions=predictions,
|
342 |
+
suggested_experiments=experiments,
|
343 |
+
related_literature=[], # Could be populated with literature search
|
344 |
+
)
|
345 |
+
|
346 |
+
|
347 |
+
class TransparentAIEngine:
|
348 |
+
"""Main transparent AI engine combining all reasoning components"""
|
349 |
+
|
350 |
+
def __init__(self, model):
|
351 |
+
self.model = model
|
352 |
+
self.uncertainty_estimator = UncertaintyEstimator(model)
|
353 |
+
self.feature_analyzer = FeatureImportanceAnalyzer(model)
|
354 |
+
self.hypothesis_generator = HypothesisGenerator()
|
355 |
+
|
356 |
+
def predict_with_explanation(
|
357 |
+
self, x: torch.Tensor, wavenumbers: Optional[np.ndarray] = None
|
358 |
+
) -> PredictionExplanation:
|
359 |
+
"""Generate comprehensive prediction with full explanation"""
|
360 |
+
self.model.eval()
|
361 |
+
|
362 |
+
# Get basic prediction
|
363 |
+
with torch.no_grad():
|
364 |
+
logits = self.model(x)
|
365 |
+
probabilities = F.softmax(logits, dim=1).cpu().numpy().flatten()
|
366 |
+
prediction = int(torch.argmax(logits, dim=1).item())
|
367 |
+
confidence = float(np.max(probabilities))
|
368 |
+
|
369 |
+
# Determine confidence level
|
370 |
+
if confidence >= 0.80:
|
371 |
+
confidence_level = "HIGH"
|
372 |
+
elif confidence >= 0.60:
|
373 |
+
confidence_level = "MEDIUM"
|
374 |
+
else:
|
375 |
+
confidence_level = "LOW"
|
376 |
+
|
377 |
+
# Get uncertainty estimation
|
378 |
+
uncertainties = self.uncertainty_estimator.estimate_uncertainty(x)
|
379 |
+
confidence_intervals = self.uncertainty_estimator.confidence_intervals(x)
|
380 |
+
|
381 |
+
# Analyze feature importance
|
382 |
+
importance_data = self.feature_analyzer.analyze_feature_importance(
|
383 |
+
x, wavenumbers
|
384 |
+
)
|
385 |
+
|
386 |
+
# Create feature importance dictionary
|
387 |
+
if wavenumbers is not None and "spectral_regions" in importance_data:
|
388 |
+
feature_importance = importance_data["spectral_regions"]
|
389 |
+
else:
|
390 |
+
# Use gradient importance
|
391 |
+
gradients = importance_data.get("gradient_importance", [])
|
392 |
+
feature_importance = {
|
393 |
+
f"feature_{i}": float(val) for i, val in enumerate(gradients[:10])
|
394 |
+
}
|
395 |
+
|
396 |
+
# Generate reasoning chain
|
397 |
+
reasoning_chain = self._generate_reasoning_chain(
|
398 |
+
prediction, confidence, feature_importance, uncertainties
|
399 |
+
)
|
400 |
+
|
401 |
+
# Identify uncertainty sources
|
402 |
+
uncertainty_sources = self._identify_uncertainty_sources(uncertainties)
|
403 |
+
|
404 |
+
# Create explanation object
|
405 |
+
explanation = PredictionExplanation(
|
406 |
+
prediction=prediction,
|
407 |
+
confidence=confidence,
|
408 |
+
confidence_level=confidence_level,
|
409 |
+
probabilities=probabilities,
|
410 |
+
feature_importance=feature_importance,
|
411 |
+
reasoning_chain=reasoning_chain,
|
412 |
+
uncertainty_sources=uncertainty_sources,
|
413 |
+
similar_cases=[], # Could be populated with case-based reasoning
|
414 |
+
confidence_intervals=confidence_intervals,
|
415 |
+
)
|
416 |
+
|
417 |
+
return explanation
|
418 |
+
|
419 |
+
def generate_hypotheses(
|
420 |
+
self, explanation: PredictionExplanation
|
421 |
+
) -> List[Hypothesis]:
|
422 |
+
"""Generate scientific hypotheses based on prediction explanation"""
|
423 |
+
return self.hypothesis_generator.generate_hypotheses(explanation, {})
|
424 |
+
|
425 |
+
def _generate_reasoning_chain(
|
426 |
+
self,
|
427 |
+
prediction: int,
|
428 |
+
confidence: float,
|
429 |
+
feature_importance: Dict[str, float],
|
430 |
+
uncertainties: Dict[str, float],
|
431 |
+
) -> List[str]:
|
432 |
+
"""Generate human-readable reasoning chain"""
|
433 |
+
reasoning = []
|
434 |
+
|
435 |
+
# Start with prediction
|
436 |
+
class_names = ["Stable", "Weathered"]
|
437 |
+
reasoning.append(
|
438 |
+
f"Model predicts: {class_names[prediction]} (confidence: {confidence:.3f})"
|
439 |
+
)
|
440 |
+
|
441 |
+
# Add feature analysis
|
442 |
+
top_features = sorted(
|
443 |
+
feature_importance.items(), key=lambda x: abs(x[1]), reverse=True
|
444 |
+
)[:3]
|
445 |
+
|
446 |
+
for feature, importance in top_features:
|
447 |
+
reasoning.append(
|
448 |
+
f"Key evidence: {feature} region shows importance score {importance:.3f}"
|
449 |
+
)
|
450 |
+
|
451 |
+
# Add uncertainty analysis
|
452 |
+
total_uncertainty = uncertainties.get("total", 0)
|
453 |
+
if total_uncertainty > 0.1:
|
454 |
+
reasoning.append(
|
455 |
+
f"High uncertainty detected ({total_uncertainty:.3f}) - suggests ambiguous case"
|
456 |
+
)
|
457 |
+
|
458 |
+
# Add confidence assessment
|
459 |
+
if confidence > 0.8:
|
460 |
+
reasoning.append(
|
461 |
+
"High confidence: Strong spectral signature for classification"
|
462 |
+
)
|
463 |
+
elif confidence > 0.6:
|
464 |
+
reasoning.append("Medium confidence: Some ambiguity in spectral features")
|
465 |
+
else:
|
466 |
+
reasoning.append("Low confidence: Weak or conflicting spectral evidence")
|
467 |
+
|
468 |
+
return reasoning
|
469 |
+
|
470 |
+
def _identify_uncertainty_sources(
|
471 |
+
self, uncertainties: Dict[str, float]
|
472 |
+
) -> List[str]:
|
473 |
+
"""Identify sources of prediction uncertainty"""
|
474 |
+
sources = []
|
475 |
+
|
476 |
+
epistemic = uncertainties.get("epistemic", 0)
|
477 |
+
aleatoric = uncertainties.get("aleatoric", 0)
|
478 |
+
|
479 |
+
if epistemic > 0.05:
|
480 |
+
sources.append(
|
481 |
+
"Model uncertainty: Limited training data for this type of spectrum"
|
482 |
+
)
|
483 |
+
|
484 |
+
if aleatoric > 0.05:
|
485 |
+
sources.append("Data uncertainty: Noisy or degraded spectral quality")
|
486 |
+
|
487 |
+
if uncertainties.get("prediction_variance", 0) > 0.1:
|
488 |
+
sources.append("Prediction instability: Multiple possible interpretations")
|
489 |
+
|
490 |
+
if not sources:
|
491 |
+
sources.append("Low uncertainty: Clear and unambiguous classification")
|
492 |
+
|
493 |
+
return sources
|
requirements.txt
CHANGED
@@ -7,8 +7,13 @@ pydantic
|
|
7 |
scikit-learn
|
8 |
seaborn
|
9 |
scipy
|
|
|
10 |
streamlit
|
11 |
torch
|
12 |
torchvision
|
13 |
uvicorn
|
14 |
matplotlib
|
|
|
|
|
|
|
|
|
|
7 |
scikit-learn
|
8 |
seaborn
|
9 |
scipy
|
10 |
+
shap
|
11 |
streamlit
|
12 |
torch
|
13 |
torchvision
|
14 |
uvicorn
|
15 |
matplotlib
|
16 |
+
xgboost
|
17 |
+
requests
|
18 |
+
Pillow
|
19 |
+
plotly
|