""" | |
Transparent AI Reasoning Engine for POLYMEROS | |
Provides explainable predictions with uncertainty quantification and hypothesis generation | |
""" | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from typing import Dict, List, Any, Tuple, Optional | |
from dataclasses import dataclass | |
import warnings | |
try: | |
import shap | |
SHAP_AVAILABLE = True | |
except ImportError: | |
SHAP_AVAILABLE = False | |
warnings.warn("SHAP not available. Install with: pip install shap") | |


@dataclass
class PredictionExplanation:
    """Comprehensive explanation for a model prediction"""

    prediction: int
    confidence: float
    confidence_level: str
    probabilities: np.ndarray
    feature_importance: Dict[str, float]
    reasoning_chain: List[str]
    uncertainty_sources: List[str]
    similar_cases: List[Dict[str, Any]]
    confidence_intervals: Dict[str, Tuple[float, float]]


@dataclass
class Hypothesis:
    """AI-generated scientific hypothesis"""

    statement: str
    confidence: float
    supporting_evidence: List[str]
    testable_predictions: List[str]
    suggested_experiments: List[str]
    related_literature: List[str]


class UncertaintyEstimator:
    """Bayesian uncertainty estimation for model predictions"""

    def __init__(self, model, n_samples: int = 100):
        self.model = model
        self.n_samples = n_samples
        self.epistemic_uncertainty = None
        self.aleatoric_uncertainty = None

    def estimate_uncertainty(self, x: torch.Tensor) -> Dict[str, float]:
        """Estimate prediction uncertainty using Monte Carlo dropout"""
        self.model.train()  # Enable dropout at inference time (MC dropout)
        predictions = []
        with torch.no_grad():
            for _ in range(self.n_samples):
                pred = F.softmax(self.model(x), dim=1)
                predictions.append(pred.cpu().numpy())
        predictions = np.array(predictions)
        # Decompose predictive uncertainty over the MC samples:
        #   epistemic ~ Var[p]      (model uncertainty: disagreement between samples)
        #   aleatoric ~ E[p(1 - p)] (data uncertainty: spread inherent to each sample)
        mean_pred = np.mean(predictions, axis=0)
        epistemic = np.var(predictions, axis=0)  # Model uncertainty
        aleatoric = np.mean(predictions * (1 - predictions), axis=0)  # Data uncertainty
        total_uncertainty = epistemic + aleatoric
        return {
            "epistemic": float(np.mean(epistemic)),
            "aleatoric": float(np.mean(aleatoric)),
            "total": float(np.mean(total_uncertainty)),
            "prediction_variance": float(np.var(mean_pred)),
        }

    def confidence_intervals(
        self, x: torch.Tensor, confidence_level: float = 0.95
    ) -> Dict[str, Tuple[float, float]]:
        """Calculate confidence intervals for predictions"""
        self.model.train()  # Keep dropout active for MC sampling
        predictions = []
        with torch.no_grad():
            for _ in range(self.n_samples):
                pred = F.softmax(self.model(x), dim=1)
                predictions.append(pred.cpu().numpy().flatten())
        predictions = np.array(predictions)
        alpha = 1 - confidence_level
        lower_percentile = (alpha / 2) * 100
        upper_percentile = (1 - alpha / 2) * 100
        intervals = {}
        for i in range(predictions.shape[1]):
            lower = np.percentile(predictions[:, i], lower_percentile)
            upper = np.percentile(predictions[:, i], upper_percentile)
            intervals[f"class_{i}"] = (lower, upper)
        return intervals
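

# Illustrative usage sketch (not part of the original module): the toy
# classifier and the `_demo_mc_dropout` name below are assumptions for
# demonstration only, not the production POLYMEROS model.
def _demo_mc_dropout() -> None:
    """Minimal sketch of MC-dropout uncertainty estimation on a toy model."""
    toy_model = torch.nn.Sequential(
        torch.nn.Linear(500, 64),
        torch.nn.ReLU(),
        torch.nn.Dropout(p=0.5),  # stays active: estimate_uncertainty() calls model.train()
        torch.nn.Linear(64, 2),
    )
    estimator = UncertaintyEstimator(toy_model, n_samples=50)
    spectrum = torch.randn(1, 500)  # stand-in for a preprocessed spectrum
    print(estimator.estimate_uncertainty(spectrum))
    print(estimator.confidence_intervals(spectrum, confidence_level=0.95))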


class FeatureImportanceAnalyzer:
    """Advanced feature importance analysis for spectral data"""

    def __init__(self, model):
        self.model = model
        self.shap_explainer = None
        if SHAP_AVAILABLE:
            try:
                # Initialize SHAP explainer with an all-zeros background
                # (assumes a 500-point input spectrum)
                self.shap_explainer = shap.DeepExplainer(  # type: ignore
                    model, torch.zeros(1, 500)
                )
            except (ValueError, RuntimeError) as e:
                warnings.warn(f"Could not initialize SHAP explainer: {e}")

    def analyze_feature_importance(
        self, x: torch.Tensor, wavenumbers: Optional[np.ndarray] = None
    ) -> Dict[str, Any]:
        """Comprehensive feature importance analysis"""
        importance_data = {}
        # SHAP analysis (if available)
        if self.shap_explainer is not None:
            try:
                shap_values = self.shap_explainer.shap_values(x)
                importance_data["shap_values"] = shap_values
                importance_data["shap_available"] = True
            except (ValueError, RuntimeError) as e:
                warnings.warn(f"SHAP analysis failed: {e}")
                importance_data["shap_available"] = False
        else:
            importance_data["shap_available"] = False
        # Gradient-based importance
        x.requires_grad_(True)
        self.model.eval()
        output = self.model(x)
        predicted_class = torch.argmax(output, dim=1)
        # Calculate gradients of the predicted-class score w.r.t. the input
        self.model.zero_grad()
        output[0, predicted_class].backward()
        if x.grad is not None:
            gradients = x.grad.detach().abs().cpu().numpy().flatten()
        else:
            raise RuntimeError(
                "Gradients were not computed. Ensure x.requires_grad_(True) is set correctly."
            )
        importance_data["gradient_importance"] = gradients
        # Integrated gradients approximation
        integrated_grads = self._integrated_gradients(x, predicted_class)
        importance_data["integrated_gradients"] = integrated_grads
        # Spectral region importance
        if wavenumbers is not None:
            region_importance = self._analyze_spectral_regions(gradients, wavenumbers)
            importance_data["spectral_regions"] = region_importance
        return importance_data

    def _integrated_gradients(
        self, x: torch.Tensor, target_class: torch.Tensor, steps: int = 50
    ) -> np.ndarray:
        """Calculate integrated gradients for feature importance"""
        # Riemann-sum approximation of integrated gradients:
        #   IG_i(x) ≈ (x_i - x'_i) * (1/m) * sum_k dF(x' + (k/m)(x - x')) / dx_i
        # with an all-zeros baseline x' and m = steps.
        baseline = torch.zeros_like(x)
        integrated_grads = np.zeros(x.shape[1])
        for i in range(steps):
            alpha = i / steps
            # Detach so the interpolated point is a leaf tensor whose .grad is
            # populated by backward(); requires_grad_() would raise on the
            # non-leaf result of this expression when x itself requires grad.
            interpolated = (baseline + alpha * (x - baseline)).detach()
            interpolated.requires_grad_(True)
            output = self.model(interpolated)
            self.model.zero_grad()
            output[0, target_class].backward(retain_graph=True)
            if interpolated.grad is not None:
                grads = interpolated.grad.cpu().numpy().flatten()
                integrated_grads += grads
        integrated_grads = (
            integrated_grads * (x - baseline).detach().cpu().numpy().flatten() / steps
        )
        return integrated_grads

    def _analyze_spectral_regions(
        self, importance: np.ndarray, wavenumbers: np.ndarray
    ) -> Dict[str, float]:
        """Analyze importance by common spectral regions (cm⁻¹)"""
        regions = {
            "fingerprint": (400, 1500),
            "ch_stretch": (2800, 3100),
            "oh_stretch": (3200, 3700),
            "carbonyl": (1600, 1800),
            "aromatic": (1450, 1650),
        }
        region_importance = {}
        for region_name, (low, high) in regions.items():
            mask = (wavenumbers >= low) & (wavenumbers <= high)
            if np.any(mask):
                region_importance[region_name] = float(np.mean(importance[mask]))
            else:
                region_importance[region_name] = 0.0
        return region_importance
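

# Illustrative usage sketch (assumption: a toy linear model; the
# `_demo_feature_importance` name is hypothetical). SHAP is skipped here to
# keep the sketch deterministic regardless of whether shap is installed.
def _demo_feature_importance() -> None:
    """Sketch: gradient-based and region-level importance on a toy model."""
    toy_model = torch.nn.Sequential(torch.nn.Linear(500, 2))
    analyzer = FeatureImportanceAnalyzer(toy_model)
    analyzer.shap_explainer = None  # skip the SHAP branch in this sketch
    x = torch.randn(1, 500)
    wavenumbers = np.linspace(400, 4000, 500)  # synthetic wavenumber axis
    report = analyzer.analyze_feature_importance(x, wavenumbers)
    print(report["spectral_regions"])  # mean |gradient| per spectral region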


class HypothesisGenerator:
    """AI-driven scientific hypothesis generation"""

    def __init__(self):
        self.hypothesis_templates = [
            "The spectral differences in the {region} region suggest {mechanism} as a primary degradation pathway",
            "Enhanced intensity at {wavenumber} cm⁻¹ indicates {chemical_change} in weathered samples",
            "The correlation between {feature1} and {feature2} suggests {relationship}",
            "Baseline shifts in {region} region may indicate {structural_change}",
        ]

    def generate_hypotheses(
        self, explanation: PredictionExplanation
    ) -> List[Hypothesis]:
        """Generate testable hypotheses based on model predictions and explanations"""
        hypotheses = []
        # Analyze feature importance for hypothesis generation
        important_features = self._identify_key_features(explanation.feature_importance)
        for feature_info in important_features:
            hypothesis = self._generate_single_hypothesis(feature_info, explanation)
            if hypothesis:
                hypotheses.append(hypothesis)
        return hypotheses

    def _identify_key_features(
        self, feature_importance: Dict[str, float]
    ) -> List[Dict[str, Any]]:
        """Identify key features for hypothesis generation"""
        # Sort features by importance
        sorted_features = sorted(
            feature_importance.items(), key=lambda x: abs(x[1]), reverse=True
        )
        key_features = []
        for feature_name, importance in sorted_features[:5]:  # Top 5 features
            feature_info = {
                "name": feature_name,
                "importance": importance,
                "type": self._classify_feature_type(feature_name),
                "chemical_significance": self._get_chemical_significance(feature_name),
            }
            key_features.append(feature_info)
        return key_features

    def _classify_feature_type(self, feature_name: str) -> str:
        """Classify spectral feature type"""
        if "fingerprint" in feature_name.lower():
            return "fingerprint"
        elif "stretch" in feature_name.lower():
            return "vibrational"
        elif "carbonyl" in feature_name.lower():
            return "functional_group"
        else:
            return "general"

    def _get_chemical_significance(self, feature_name: str) -> str:
        """Get chemical significance of spectral feature"""
        significance_map = {
            "fingerprint": "molecular backbone structure",
            "ch_stretch": "aliphatic chain integrity",
            "oh_stretch": "hydrogen bonding and hydration",
            "carbonyl": "oxidative degradation products",
            "aromatic": "aromatic ring preservation",
        }
        for key, significance in significance_map.items():
            if key in feature_name.lower():
                return significance
        return "structural changes"

    def _generate_single_hypothesis(
        self, feature_info: Dict[str, Any], explanation: PredictionExplanation
    ) -> Optional[Hypothesis]:
        """Generate a single hypothesis from feature information"""
        if feature_info["importance"] < 0.1:  # Skip low-importance features
            return None
        # Create hypothesis statement
        statement = (
            f"Changes in {feature_info['name']} region indicate "
            f"{feature_info['chemical_significance']} during polymer weathering"
        )
        # Generate supporting evidence
        evidence = [
            f"Feature importance score: {feature_info['importance']:.3f}",
            f"Classification confidence: {explanation.confidence:.3f}",
            f"Chemical significance: {feature_info['chemical_significance']}",
        ]
        # Generate testable predictions
        predictions = [
            f"Controlled weathering experiments should show progressive changes in {feature_info['name']} region",
            f"Different polymer types should exhibit varying {feature_info['name']} responses to weathering",
        ]
        # Suggest experiments
        experiments = [
            f"Time-series weathering study monitoring {feature_info['name']} region",
            f"Comparative analysis across polymer types focusing on {feature_info['chemical_significance']}",
            "Cross-validation with other analytical techniques (DSC, GPC, etc.)",
        ]
        return Hypothesis(
            statement=statement,
            confidence=min(0.9, feature_info["importance"] * explanation.confidence),
            supporting_evidence=evidence,
            testable_predictions=predictions,
            suggested_experiments=experiments,
            related_literature=[],  # Could be populated with literature search
        )
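

# Illustrative usage sketch (the hand-built explanation values and the
# `_demo_hypothesis_generation` name are assumptions for demonstration).
def _demo_hypothesis_generation() -> None:
    """Sketch: generate hypotheses from a manually constructed explanation."""
    explanation = PredictionExplanation(
        prediction=1,
        confidence=0.85,
        confidence_level="HIGH",
        probabilities=np.array([0.15, 0.85]),
        feature_importance={"carbonyl": 0.42, "ch_stretch": 0.31, "oh_stretch": 0.05},
        reasoning_chain=[],
        uncertainty_sources=[],
        similar_cases=[],
        confidence_intervals={},
    )
    # oh_stretch yields no hypothesis: its importance (0.05) is below the 0.1 cutoff
    for hyp in HypothesisGenerator().generate_hypotheses(explanation):
        print(hyp.statement, f"(confidence: {hyp.confidence:.2f})")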


class TransparentAIEngine:
    """Main transparent AI engine combining all reasoning components"""

    def __init__(self, model):
        self.model = model
        self.uncertainty_estimator = UncertaintyEstimator(model)
        self.feature_analyzer = FeatureImportanceAnalyzer(model)
        self.hypothesis_generator = HypothesisGenerator()

    def predict_with_explanation(
        self, x: torch.Tensor, wavenumbers: Optional[np.ndarray] = None
    ) -> PredictionExplanation:
        """Generate comprehensive prediction with full explanation"""
        self.model.eval()
        # Get basic prediction
        with torch.no_grad():
            logits = self.model(x)
            probabilities = F.softmax(logits, dim=1).cpu().numpy().flatten()
            prediction = int(torch.argmax(logits, dim=1).item())
            confidence = float(np.max(probabilities))
        # Determine confidence level
        if confidence >= 0.80:
            confidence_level = "HIGH"
        elif confidence >= 0.60:
            confidence_level = "MEDIUM"
        else:
            confidence_level = "LOW"
        # Get uncertainty estimation
        uncertainties = self.uncertainty_estimator.estimate_uncertainty(x)
        confidence_intervals = self.uncertainty_estimator.confidence_intervals(x)
        # Analyze feature importance
        importance_data = self.feature_analyzer.analyze_feature_importance(
            x, wavenumbers
        )
        # Create feature importance dictionary
        if wavenumbers is not None and "spectral_regions" in importance_data:
            feature_importance = importance_data["spectral_regions"]
        else:
            # Fall back to per-feature gradient importance
            gradients = importance_data.get("gradient_importance", [])
            feature_importance = {
                f"feature_{i}": float(val) for i, val in enumerate(gradients[:10])
            }
        # Generate reasoning chain
        reasoning_chain = self._generate_reasoning_chain(
            prediction, confidence, feature_importance, uncertainties
        )
        # Identify uncertainty sources
        uncertainty_sources = self._identify_uncertainty_sources(uncertainties)
        # Create explanation object
        explanation = PredictionExplanation(
            prediction=prediction,
            confidence=confidence,
            confidence_level=confidence_level,
            probabilities=probabilities,
            feature_importance=feature_importance,
            reasoning_chain=reasoning_chain,
            uncertainty_sources=uncertainty_sources,
            similar_cases=[],  # Could be populated with case-based reasoning
            confidence_intervals=confidence_intervals,
        )
        return explanation

    def generate_hypotheses(
        self, explanation: PredictionExplanation
    ) -> List[Hypothesis]:
        """Generate scientific hypotheses based on prediction explanation"""
        return self.hypothesis_generator.generate_hypotheses(explanation)

    def _generate_reasoning_chain(
        self,
        prediction: int,
        confidence: float,
        feature_importance: Dict[str, float],
        uncertainties: Dict[str, float],
    ) -> List[str]:
        """Generate human-readable reasoning chain"""
        reasoning = []
        # Start with prediction (binary classification: stable vs. weathered)
        class_names = ["Stable", "Weathered"]
        reasoning.append(
            f"Model predicts: {class_names[prediction]} (confidence: {confidence:.3f})"
        )
        # Add feature analysis
        top_features = sorted(
            feature_importance.items(), key=lambda x: abs(x[1]), reverse=True
        )[:3]
        for feature, importance in top_features:
            reasoning.append(
                f"Key evidence: {feature} region shows importance score {importance:.3f}"
            )
        # Add uncertainty analysis
        total_uncertainty = uncertainties.get("total", 0)
        if total_uncertainty > 0.1:
            reasoning.append(
                f"High uncertainty detected ({total_uncertainty:.3f}) - suggests ambiguous case"
            )
        # Add confidence assessment
        if confidence > 0.8:
            reasoning.append(
                "High confidence: Strong spectral signature for classification"
            )
        elif confidence > 0.6:
            reasoning.append("Medium confidence: Some ambiguity in spectral features")
        else:
            reasoning.append("Low confidence: Weak or conflicting spectral evidence")
        return reasoning

    def _identify_uncertainty_sources(
        self, uncertainties: Dict[str, float]
    ) -> List[str]:
        """Identify sources of prediction uncertainty"""
        sources = []
        epistemic = uncertainties.get("epistemic", 0)
        aleatoric = uncertainties.get("aleatoric", 0)
        if epistemic > 0.05:
            sources.append(
                "Model uncertainty: Limited training data for this type of spectrum"
            )
        if aleatoric > 0.05:
            sources.append("Data uncertainty: Noisy or degraded spectral quality")
        if uncertainties.get("prediction_variance", 0) > 0.1:
            sources.append("Prediction instability: Multiple possible interpretations")
        if not sources:
            sources.append("Low uncertainty: Clear and unambiguous classification")
        return sources
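

# End-to-end sketch (illustrative; the toy model below is an assumption,
# not the production POLYMEROS classifier).
if __name__ == "__main__":
    toy_model = torch.nn.Sequential(
        torch.nn.Linear(500, 64),
        torch.nn.ReLU(),
        torch.nn.Dropout(p=0.3),
        torch.nn.Linear(64, 2),
    )
    engine = TransparentAIEngine(toy_model)
    engine.feature_analyzer.shap_explainer = None  # keep the sketch SHAP-free
    spectrum = torch.randn(1, 500)  # stand-in for a preprocessed spectrum
    wavenumbers = np.linspace(400, 4000, 500)
    explanation = engine.predict_with_explanation(spectrum, wavenumbers)
    for step in explanation.reasoning_chain:
        print(step)
    for hyp in engine.generate_hypotheses(explanation):
        print(hyp.statement)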