devjas1 committed
Commit 50c7ef1 · 1 Parent(s): 8dd961f

Adds Transparent AI Reasoning Engine for explainable predictions


Introduces a new module that implements a Transparent AI Reasoning Engine, enabling explainable predictions with uncertainty quantification and hypothesis generation.

This addition improves model interpretability by providing human-readable reasoning chains, confidence assessments, and automatically generated, testable scientific hypotheses for spectral data.

Updates requirements.txt to add SHAP and the other new dependencies.
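
A minimal usage sketch (illustrative, not part of this commit). The two-layer dropout classifier and the 500-point input are assumptions; any dropout-enabled torch classifier matching your spectrum length will do:

    import numpy as np
    import torch
    import torch.nn as nn

    from modules.transparent_ai import TransparentAIEngine

    # Hypothetical classifier: 500 spectral points -> 2 classes (Stable/Weathered).
    # Dropout is what makes the Monte Carlo uncertainty estimates meaningful.
    model = nn.Sequential(
        nn.Linear(500, 64), nn.ReLU(), nn.Dropout(0.2), nn.Linear(64, 2)
    )

    engine = TransparentAIEngine(model)
    spectrum = torch.randn(1, 500)             # placeholder spectrum
    wavenumbers = np.linspace(400, 4000, 500)  # matching wavenumber axis

    explanation = engine.predict_with_explanation(spectrum, wavenumbers)
    print(explanation.confidence_level)        # "HIGH", "MEDIUM", or "LOW"
    for step in explanation.reasoning_chain:
        print("-", step)

    for hypothesis in engine.generate_hypotheses(explanation):
        print(hypothesis.statement)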

Files changed (2)
  1. modules/transparent_ai.py +493 -0
  2. requirements.txt +5 -0
modules/transparent_ai.py ADDED
@@ -0,0 +1,493 @@
+ """
+ Transparent AI Reasoning Engine for POLYMEROS
+ Provides explainable predictions with uncertainty quantification and hypothesis generation
+ """
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from typing import Dict, List, Any, Tuple, Optional
+ from dataclasses import dataclass
+ import warnings
+
+ try:
+     import shap
+
+     SHAP_AVAILABLE = True
+ except ImportError:
+     SHAP_AVAILABLE = False
+     warnings.warn("SHAP not available. Install with: pip install shap")
+
+
+ @dataclass
+ class PredictionExplanation:
+     """Comprehensive explanation for a model prediction"""
+
+     prediction: int
+     confidence: float
+     confidence_level: str
+     probabilities: np.ndarray
+     feature_importance: Dict[str, float]
+     reasoning_chain: List[str]
+     uncertainty_sources: List[str]
+     similar_cases: List[Dict[str, Any]]
+     confidence_intervals: Dict[str, Tuple[float, float]]
+
+
+ @dataclass
+ class Hypothesis:
+     """AI-generated scientific hypothesis"""
+
+     statement: str
+     confidence: float
+     supporting_evidence: List[str]
+     testable_predictions: List[str]
+     suggested_experiments: List[str]
+     related_literature: List[str]
+
+
+ class UncertaintyEstimator:
+     """Bayesian uncertainty estimation for model predictions"""
+
+     def __init__(self, model, n_samples: int = 100):
+         self.model = model
+         self.n_samples = n_samples
+         self.epistemic_uncertainty = None
+         self.aleatoric_uncertainty = None
+
+     def estimate_uncertainty(self, x: torch.Tensor) -> Dict[str, float]:
+         """Estimate prediction uncertainty using Monte Carlo dropout"""
+         self.model.train()  # Enable dropout
+
+         predictions = []
+         with torch.no_grad():
+             for _ in range(self.n_samples):
+                 pred = F.softmax(self.model(x), dim=1)
+                 predictions.append(pred.cpu().numpy())
+
+         predictions = np.array(predictions)
+
+         # Calculate uncertainties
+         mean_pred = np.mean(predictions, axis=0)
+         epistemic = np.var(predictions, axis=0)  # Model uncertainty
+         aleatoric = np.mean(predictions * (1 - predictions), axis=0)  # Data uncertainty
+
+         total_uncertainty = epistemic + aleatoric
+
+         return {
+             "epistemic": float(np.mean(epistemic)),
+             "aleatoric": float(np.mean(aleatoric)),
+             "total": float(np.mean(total_uncertainty)),
+             "prediction_variance": float(np.var(mean_pred)),
+         }
+
+     def confidence_intervals(
+         self, x: torch.Tensor, confidence_level: float = 0.95
+     ) -> Dict[str, Tuple[float, float]]:
+         """Calculate confidence intervals for predictions"""
+         self.model.train()
+
+         predictions = []
+         with torch.no_grad():
+             for _ in range(self.n_samples):
+                 pred = F.softmax(self.model(x), dim=1)
+                 predictions.append(pred.cpu().numpy().flatten())
+
+         predictions = np.array(predictions)
+
+         alpha = 1 - confidence_level
+         lower_percentile = (alpha / 2) * 100
+         upper_percentile = (1 - alpha / 2) * 100
+
+         intervals = {}
+         for i in range(predictions.shape[1]):
+             lower = np.percentile(predictions[:, i], lower_percentile)
+             upper = np.percentile(predictions[:, i], upper_percentile)
+             intervals[f"class_{i}"] = (lower, upper)
+
+         return intervals
+
+
+ class FeatureImportanceAnalyzer:
+     """Advanced feature importance analysis for spectral data"""
+
+     def __init__(self, model):
+         self.model = model
+         self.shap_explainer = None
+
+         if SHAP_AVAILABLE:
+             try:
+                 # Initialize SHAP explainer for the model
+                 self.shap_explainer = shap.DeepExplainer(
+                     model, torch.zeros(1, 500)
+                 )
+             except (ValueError, RuntimeError) as e:
+                 warnings.warn(f"Could not initialize SHAP explainer: {e}")
+
+     def analyze_feature_importance(
+         self, x: torch.Tensor, wavenumbers: Optional[np.ndarray] = None
+     ) -> Dict[str, Any]:
+         """Comprehensive feature importance analysis"""
+         importance_data = {}
+
+         # SHAP analysis (if available)
+         if self.shap_explainer is not None:
+             try:
+                 shap_values = self.shap_explainer.shap_values(x)
+                 importance_data["shap_values"] = shap_values
+                 importance_data["shap_available"] = True
+             except (ValueError, RuntimeError) as e:
+                 warnings.warn(f"SHAP analysis failed: {e}")
+                 importance_data["shap_available"] = False
+         else:
+             importance_data["shap_available"] = False
+
+         # Gradient-based importance
+         x.requires_grad_(True)
+         self.model.eval()
+
+         output = self.model(x)
+         predicted_class = torch.argmax(output, dim=1)
+
+         # Calculate gradients
+         self.model.zero_grad()
+         output[0, predicted_class].backward()
+
+         if x.grad is not None:
+             gradients = x.grad.detach().abs().cpu().numpy().flatten()
+         else:
+             raise RuntimeError(
+                 "Gradients were not computed. Ensure x.requires_grad_(True) is set correctly."
+             )
+
+         importance_data["gradient_importance"] = gradients
+
+         # Integrated gradients approximation
+         integrated_grads = self._integrated_gradients(x, predicted_class)
+         importance_data["integrated_gradients"] = integrated_grads
+
+         # Spectral region importance
+         if wavenumbers is not None:
+             region_importance = self._analyze_spectral_regions(gradients, wavenumbers)
+             importance_data["spectral_regions"] = region_importance
+
+         return importance_data
+
+     def _integrated_gradients(
+         self, x: torch.Tensor, target_class: torch.Tensor, steps: int = 50
+     ) -> np.ndarray:
+         """Calculate integrated gradients for feature importance"""
+         baseline = torch.zeros_like(x)
+
+         integrated_grads = np.zeros(x.shape[1])
+
+         for i in range(steps):
+             alpha = i / steps
+             # Detach so interpolated is a leaf tensor; otherwise .grad is never
+             # populated and the attributions silently come out as zeros.
+             interpolated = (baseline + alpha * (x - baseline)).detach()
+             interpolated.requires_grad_(True)
+
+             output = self.model(interpolated)
+             self.model.zero_grad()
+             output[0, target_class].backward(retain_graph=True)
+
+             if interpolated.grad is not None:
+                 grads = interpolated.grad.cpu().numpy().flatten()
+                 integrated_grads += grads
+
+         integrated_grads = (
+             integrated_grads * (x - baseline).detach().cpu().numpy().flatten() / steps
+         )
+         return integrated_grads
+
+     def _analyze_spectral_regions(
+         self, importance: np.ndarray, wavenumbers: np.ndarray
+     ) -> Dict[str, float]:
+         """Analyze importance by common spectral regions"""
+         regions = {
+             "fingerprint": (400, 1500),
+             "ch_stretch": (2800, 3100),
+             "oh_stretch": (3200, 3700),
+             "carbonyl": (1600, 1800),
+             "aromatic": (1450, 1650),
+         }
+
+         region_importance = {}
+
+         for region_name, (low, high) in regions.items():
+             mask = (wavenumbers >= low) & (wavenumbers <= high)
+             if np.any(mask):
+                 region_importance[region_name] = float(np.mean(importance[mask]))
+             else:
+                 region_importance[region_name] = 0.0
+
+         return region_importance
+
+
+ class HypothesisGenerator:
+     """AI-driven scientific hypothesis generation"""
+
+     def __init__(self):
+         self.hypothesis_templates = [
+             "The spectral differences in the {region} region suggest {mechanism} as a primary degradation pathway",
+             "Enhanced intensity at {wavenumber} cm⁻¹ indicates {chemical_change} in weathered samples",
+             "The correlation between {feature1} and {feature2} suggests {relationship}",
+             "Baseline shifts in {region} region may indicate {structural_change}",
+         ]
+
+     def generate_hypotheses(
+         self, explanation: PredictionExplanation, spectral_data: Dict[str, Any]
+     ) -> List[Hypothesis]:
+         """Generate testable hypotheses based on model predictions and explanations"""
+         hypotheses = []
+
+         # Analyze feature importance for hypothesis generation
+         important_features = self._identify_key_features(explanation.feature_importance)
+
+         for feature_info in important_features:
+             hypothesis = self._generate_single_hypothesis(feature_info, explanation)
+             if hypothesis:
+                 hypotheses.append(hypothesis)
+
+         return hypotheses
+
+     def _identify_key_features(
+         self, feature_importance: Dict[str, float]
+     ) -> List[Dict[str, Any]]:
+         """Identify key features for hypothesis generation"""
+         # Sort features by importance
+         sorted_features = sorted(
+             feature_importance.items(), key=lambda x: abs(x[1]), reverse=True
+         )
+
+         key_features = []
+         for feature_name, importance in sorted_features[:5]:  # Top 5 features
+             feature_info = {
+                 "name": feature_name,
+                 "importance": importance,
+                 "type": self._classify_feature_type(feature_name),
+                 "chemical_significance": self._get_chemical_significance(feature_name),
+             }
+             key_features.append(feature_info)
+
+         return key_features
+
+     def _classify_feature_type(self, feature_name: str) -> str:
+         """Classify spectral feature type"""
+         if "fingerprint" in feature_name.lower():
+             return "fingerprint"
+         elif "stretch" in feature_name.lower():
+             return "vibrational"
+         elif "carbonyl" in feature_name.lower():
+             return "functional_group"
+         else:
+             return "general"
+
+     def _get_chemical_significance(self, feature_name: str) -> str:
+         """Get chemical significance of spectral feature"""
+         significance_map = {
+             "fingerprint": "molecular backbone structure",
+             "ch_stretch": "aliphatic chain integrity",
+             "oh_stretch": "hydrogen bonding and hydration",
+             "carbonyl": "oxidative degradation products",
+             "aromatic": "aromatic ring preservation",
+         }
+
+         for key, significance in significance_map.items():
+             if key in feature_name.lower():
+                 return significance
+
+         return "structural changes"
+
+     def _generate_single_hypothesis(
+         self, feature_info: Dict[str, Any], explanation: PredictionExplanation
+     ) -> Optional[Hypothesis]:
+         """Generate a single hypothesis from feature information"""
+         if feature_info["importance"] < 0.1:  # Skip low-importance features
+             return None
+
+         # Create hypothesis statement
+         statement = f"Changes in {feature_info['name']} region indicate {feature_info['chemical_significance']} during polymer weathering"
+
+         # Generate supporting evidence
+         evidence = [
+             f"Feature importance score: {feature_info['importance']:.3f}",
+             f"Classification confidence: {explanation.confidence:.3f}",
+             f"Chemical significance: {feature_info['chemical_significance']}",
+         ]
+
+         # Generate testable predictions
+         predictions = [
+             f"Controlled weathering experiments should show progressive changes in {feature_info['name']} region",
+             f"Different polymer types should exhibit varying {feature_info['name']} responses to weathering",
+         ]
+
+         # Suggest experiments
+         experiments = [
+             f"Time-series weathering study monitoring {feature_info['name']} region",
+             f"Comparative analysis across polymer types focusing on {feature_info['chemical_significance']}",
+             "Cross-validation with other analytical techniques (DSC, GPC, etc.)",
+         ]
+
+         return Hypothesis(
+             statement=statement,
+             confidence=min(0.9, feature_info["importance"] * explanation.confidence),
+             supporting_evidence=evidence,
+             testable_predictions=predictions,
+             suggested_experiments=experiments,
+             related_literature=[],  # Could be populated with literature search
+         )
+
+
+ class TransparentAIEngine:
+     """Main transparent AI engine combining all reasoning components"""
+
+     def __init__(self, model):
+         self.model = model
+         self.uncertainty_estimator = UncertaintyEstimator(model)
+         self.feature_analyzer = FeatureImportanceAnalyzer(model)
+         self.hypothesis_generator = HypothesisGenerator()
+
+     def predict_with_explanation(
+         self, x: torch.Tensor, wavenumbers: Optional[np.ndarray] = None
+     ) -> PredictionExplanation:
+         """Generate comprehensive prediction with full explanation"""
+         self.model.eval()
+
+         # Get basic prediction
+         with torch.no_grad():
+             logits = self.model(x)
+             probabilities = F.softmax(logits, dim=1).cpu().numpy().flatten()
+             prediction = int(torch.argmax(logits, dim=1).item())
+             confidence = float(np.max(probabilities))
+
+         # Determine confidence level
+         if confidence >= 0.80:
+             confidence_level = "HIGH"
+         elif confidence >= 0.60:
+             confidence_level = "MEDIUM"
+         else:
+             confidence_level = "LOW"
+
+         # Get uncertainty estimation
+         uncertainties = self.uncertainty_estimator.estimate_uncertainty(x)
+         confidence_intervals = self.uncertainty_estimator.confidence_intervals(x)
+
+         # Analyze feature importance
+         importance_data = self.feature_analyzer.analyze_feature_importance(
+             x, wavenumbers
+         )
+
+         # Create feature importance dictionary
+         if wavenumbers is not None and "spectral_regions" in importance_data:
+             feature_importance = importance_data["spectral_regions"]
+         else:
+             # Use gradient importance
+             gradients = importance_data.get("gradient_importance", [])
+             feature_importance = {
+                 f"feature_{i}": float(val) for i, val in enumerate(gradients[:10])
+             }
+
+         # Generate reasoning chain
+         reasoning_chain = self._generate_reasoning_chain(
+             prediction, confidence, feature_importance, uncertainties
+         )
+
+         # Identify uncertainty sources
+         uncertainty_sources = self._identify_uncertainty_sources(uncertainties)
+
+         # Create explanation object
+         explanation = PredictionExplanation(
+             prediction=prediction,
+             confidence=confidence,
+             confidence_level=confidence_level,
+             probabilities=probabilities,
+             feature_importance=feature_importance,
+             reasoning_chain=reasoning_chain,
+             uncertainty_sources=uncertainty_sources,
+             similar_cases=[],  # Could be populated with case-based reasoning
+             confidence_intervals=confidence_intervals,
+         )
+
+         return explanation
+
+     def generate_hypotheses(
+         self, explanation: PredictionExplanation
+     ) -> List[Hypothesis]:
+         """Generate scientific hypotheses based on prediction explanation"""
+         return self.hypothesis_generator.generate_hypotheses(explanation, {})
+
+     def _generate_reasoning_chain(
+         self,
+         prediction: int,
+         confidence: float,
+         feature_importance: Dict[str, float],
+         uncertainties: Dict[str, float],
+     ) -> List[str]:
+         """Generate human-readable reasoning chain"""
+         reasoning = []
+
+         # Start with prediction
+         class_names = ["Stable", "Weathered"]
+         reasoning.append(
+             f"Model predicts: {class_names[prediction]} (confidence: {confidence:.3f})"
+         )
+
+         # Add feature analysis
+         top_features = sorted(
+             feature_importance.items(), key=lambda x: abs(x[1]), reverse=True
+         )[:3]
+
+         for feature, importance in top_features:
+             reasoning.append(
+                 f"Key evidence: {feature} region shows importance score {importance:.3f}"
+             )
+
+         # Add uncertainty analysis
+         total_uncertainty = uncertainties.get("total", 0)
+         if total_uncertainty > 0.1:
+             reasoning.append(
+                 f"High uncertainty detected ({total_uncertainty:.3f}) - suggests ambiguous case"
+             )
+
+         # Add confidence assessment
+         if confidence > 0.8:
+             reasoning.append(
+                 "High confidence: Strong spectral signature for classification"
+             )
+         elif confidence > 0.6:
+             reasoning.append("Medium confidence: Some ambiguity in spectral features")
+         else:
+             reasoning.append("Low confidence: Weak or conflicting spectral evidence")
+
+         return reasoning
+
+     def _identify_uncertainty_sources(
+         self, uncertainties: Dict[str, float]
+     ) -> List[str]:
+         """Identify sources of prediction uncertainty"""
+         sources = []
+
+         epistemic = uncertainties.get("epistemic", 0)
+         aleatoric = uncertainties.get("aleatoric", 0)
+
+         if epistemic > 0.05:
+             sources.append(
+                 "Model uncertainty: Limited training data for this type of spectrum"
+             )
+
+         if aleatoric > 0.05:
+             sources.append("Data uncertainty: Noisy or degraded spectral quality")
+
+         if uncertainties.get("prediction_variance", 0) > 0.1:
+             sources.append("Prediction instability: Multiple possible interpretations")
+
+         if not sources:
+             sources.append("Low uncertainty: Clear and unambiguous classification")
+
+         return sources
requirements.txt CHANGED
@@ -7,8 +7,13 @@ pydantic
  scikit-learn
  seaborn
  scipy
+ shap
  streamlit
  torch
  torchvision
  uvicorn
  matplotlib
+ xgboost
+ requests
+ Pillow
+ plotly