# polymer-aging-ml / modules / enhanced_data_pipeline.py
# devjas1
# FEAT(data_pipeline): implement enhanced data pipeline for polymer ML aging with spectroscopy integration and synthetic data augmentation
# 07fb119
"""
Enhanced Data Pipeline for Polymer ML Aging
Integrates with spectroscopy databases, synthetic data augmentation, and quality control
"""
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Optional, Union, Any
from dataclasses import dataclass, field
from pathlib import Path
import requests
import json
import sqlite3
from datetime import datetime
import hashlib
import warnings
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.decomposition import PCA
from sklearn.cluster import DBSCAN
import pickle
import io
import base64
@dataclass
class SpectralDatabase:
    """Describe one spectroscopy data source and how it is accessed.

    Only ``name`` is required; every other field has a default.
    """

    # Display name of the database.
    name: str
    # Root URL used when the source is queried over HTTP.
    base_url: Optional[str] = None
    # Credential for API access, if one is configured.
    api_key: Optional[str] = None
    # Free-text summary of what the database contains.
    description: str = ""
    # Technique labels offered by this source, e.g. "FTIR".
    supported_formats: List[str] = field(default_factory=list)
    # One of "api", "download" or "local".
    access_method: str = "api"
    # Directory holding the data when access_method is "local".
    local_path: Optional[Path] = None
# -///////////////////////////////////////////////////
@dataclass
class PolymerSample:
    """Record describing one polymer sample and its measured spectra.

    ``sample_id`` and ``polymer_type`` are required; all other fields are
    optional and default to ``None``, an empty container or ``"pending"``.
    """

    # Identifier of the sample.
    sample_id: str
    # Polymer type label.
    polymer_type: str
    molecular_weight: Optional[float] = None
    # Names of additives present in the sample.
    additives: List[str] = field(default_factory=list)
    processing_conditions: Dict[str, Any] = field(default_factory=dict)
    aging_condition: Dict[str, Any] = field(default_factory=dict)
    # Aging duration in hours.
    aging_time: Optional[float] = None
    # Degradation on a 0-1 scale.
    degradation_level: Optional[float] = None
    # Spectral arrays stored by string key.
    spectral_data: Dict[str, np.ndarray] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)
    quality_score: Optional[float] = None
    # One of "pending", "validated" or "rejected".
    validation_status: str = "pending"
# -///////////////////////////////////////////////////
# Registry of the spectroscopy sources known to this module, keyed by a
# short identifier.
SPECTROSCOPY_DATABASES = {
    # FTIR spectra of plastics, read from a local directory.
    "FTIR_PLASTICS": SpectralDatabase(
        name="FTIR Plastics Database",
        access_method="local",
        local_path=Path("data/databases/ftir_plastics"),
        supported_formats=["FTIR", "ATR-FTIR"],
        description="Comprehensive FTIR spectra of plastic materials",
    ),
    # NIST Chemistry WebBook, reached over HTTP.
    "NIST_WEBBOOK": SpectralDatabase(
        name="NIST Chemistry WebBook",
        access_method="api",
        base_url="https://webbook.nist.gov/chemistry",
        supported_formats=["FTIR", "Raman"],
        description="NIST spectroscopic database",
    ),
    # Polymer degradation spectra, read from a local directory.
    "POLYMER_DATABASE": SpectralDatabase(
        name="Polymer Spectroscopy Database",
        access_method="local",
        local_path=Path("data/databases/polymer_degradation"),
        supported_formats=["FTIR", "ATR-FTIR", "Raman"],
        description="Curated polymer degradation spectra",
    ),
}
# -///////////////////////////////////////////////////
class DatabaseConnector:
    """Search one configured spectroscopy database and cache the responses.

    Search results and downloaded spectra are written as JSON files below
    ``data/cache/<lower-cased database name>`` and are served from there on
    repeated requests.
    """

    def __init__(self, database_config: "SpectralDatabase"):
        """Store the configuration and create the cache directory.

        Args:
            database_config: Settings of the database this connector serves.
        """
        self.config = database_config
        self.connection = None
        self.cache_dir = Path("data/cache") / database_config.name.lower().replace(
            " ", "_"
        )
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _read_cache(cache_file: Path) -> Optional[Any]:
        """Return the JSON content of ``cache_file``, or None if unusable."""
        if not cache_file.exists():
            return None
        try:
            with open(cache_file, "r", encoding="utf-8") as handle:
                return json.load(handle)
        except (OSError, ValueError) as e:
            # A truncated or corrupt entry is treated as a cache miss instead
            # of aborting the request.
            print(f"Ignoring unreadable cache file {cache_file}: {e}")
            return None

    @staticmethod
    def _write_cache(cache_file: Path, payload: Any) -> None:
        """Write ``payload`` as JSON to ``cache_file`` (best effort)."""
        try:
            # Serialise first so a failure never leaves a partial file behind.
            text = json.dumps(payload)
            with open(cache_file, "w", encoding="utf-8") as handle:
                handle.write(text)
        except (OSError, TypeError, ValueError) as e:
            print(f"Could not cache results in {cache_file}: {e}")

    def connect(self) -> bool:
        """Check whether the configured database is reachable.

        Returns:
            For ``"local"`` access, True when ``local_path`` exists. For
            ``"api"`` access, True when a GET on ``base_url`` answers with
            status 200. For any other access method, True.
        """
        try:
            method = self.config.access_method
            if method == "local":
                if self.config.local_path and self.config.local_path.exists():
                    return True
                print(f"Local database path not found: {self.config.local_path}")
                return False
            if method == "api":
                if not self.config.base_url:
                    return False
                response = requests.get(self.config.base_url, timeout=10)
                return response.status_code == 200
            return True
        except (requests.RequestException, OSError) as e:
            print(f"Failed to connect to {self.config.name}: {e}")
            return False

    def search_by_polymer_type(self, polymer_type: str, limit: int = 100) -> List[Dict]:
        """Search the database for entries mentioning ``polymer_type``.

        Non-empty results are cached per (polymer type, limit) pair.

        Args:
            polymer_type: Text to look for.
            limit: Maximum number of results to return.

        Returns:
            List of result dictionaries; empty when nothing was found.
        """
        digest = hashlib.md5(polymer_type.encode()).hexdigest()
        # The limit is part of the key so that a cached small search is never
        # served for a later, larger request.
        cache_file = self.cache_dir / f"search{digest}_{limit}.json"

        cached = self._read_cache(cache_file)
        if cached is not None:
            return cached

        results = []
        if self.config.access_method == "local":
            results = self._search_local_database(polymer_type, limit)
        elif self.config.access_method == "api":
            results = self._search_api_database(polymer_type, limit)

        if results:
            self._write_cache(cache_file, results)
        return results

    def _search_local_database(self, polymer_type: str, limit: int) -> List[Dict]:
        """Scan the CSV files in ``local_path`` for rows naming the polymer.

        A row matches when any of its cells, converted to text, contains
        ``polymer_type`` as a literal, case-insensitive substring.

        Args:
            polymer_type: Text to look for.
            limit: Maximum number of rows returned across all files.

        Returns:
            One dictionary per matching row with the keys ``source_file``,
            ``polymer_type`` and ``data`` (the row as a dictionary).
        """
        results = []
        if not self.config.local_path or not self.config.local_path.exists():
            return results

        # Sorted so the outcome does not depend on directory listing order.
        for csv_file in sorted(self.config.local_path.glob("*.csv")):
            remaining = limit - len(results)
            if remaining <= 0:
                break
            try:
                df = pd.read_csv(csv_file)
                # regex=False: polymer names such as "PE(HD)" must be matched
                # literally, not interpreted as a pattern.
                hit_mask = (
                    df.astype(str)
                    .apply(
                        lambda col: col.str.contains(
                            polymer_type, case=False, regex=False
                        )
                    )
                    .any(axis=1)
                )
                for _, row in df[hit_mask].head(remaining).iterrows():
                    results.append(
                        {
                            "source_file": str(csv_file),
                            "polymer_type": polymer_type,
                            "data": row.to_dict(),
                        }
                    )
            except Exception as e:
                # Deliberately broad: one malformed file must not abort the
                # search through the remaining files.
                print(f"Error reading {csv_file}: {e}")
                continue
        return results

    def _search_api_database(self, polymer_type: str, limit: int) -> List[Dict]:
        """Query ``<base_url>/search`` and return its ``results`` entry.

        Returns:
            The ``results`` value of the JSON response, or an empty list when
            no base URL is configured, the request fails, the status is not
            200, or the response is not a JSON object.
        """
        if not self.config.base_url:
            return []

        # TODO: Example API search (would need actual API endpoints)
        search_params = {"query": polymer_type, "limit": limit, "format": "json"}
        if self.config.api_key:
            search_params["api_key"] = self.config.api_key

        try:
            response = requests.get(
                f"{self.config.base_url}/search", params=search_params, timeout=30
            )
            if response.status_code == 200:
                payload = response.json()
                if isinstance(payload, dict):
                    return payload.get("results", [])
        except (requests.RequestException, ValueError) as e:
            print(f"API search failed: {e}")
        return []

    def download_spectrum(self, spectrum_id: str) -> Optional[Dict]:
        """Fetch one spectrum by ID, preferring the cached copy.

        Only the ``"api"`` access method performs a download; for other
        access methods the result is None unless a cached copy exists.

        Args:
            spectrum_id: Identifier used in the URL and the cache file name.

        Returns:
            The decoded JSON of the spectrum, or None when unavailable.
        """
        cache_file = self.cache_dir / f"spectrum_{spectrum_id}.json"

        cached = self._read_cache(cache_file)
        if cached is not None:
            return cached

        spectrum_data = None
        if self.config.access_method == "api":
            try:
                url = f"{self.config.base_url}/spectrum/{spectrum_id}"
                response = requests.get(url, timeout=30)
                if response.status_code == 200:
                    spectrum_data = response.json()
            except (requests.RequestException, ValueError) as e:
                print(f"Failed to download spectrum {spectrum_id}: {e}")

        if spectrum_data:
            self._write_cache(cache_file, spectrum_data)
        return spectrum_data
# -///////////////////////////////////////////////////
class SyntheticDataAugmentation:
    """Create perturbed copies and synthetic aging series of spectra.

    All randomness is drawn from NumPy's global ``np.random`` state.
    """

    def __init__(self):
        """Register the names of the supported augmentation methods."""
        self.augmentation_methods = [
            "noise_addition",
            "baseline_drift",
            "intensity_scaling",
            "wavenumber_shift",
            "peak_broadening",
            "atmospheric_effects",
            "instrumental_response",
            "sample_variations",
        ]

    @staticmethod
    def _float_copy(values) -> np.ndarray:
        """Return a floating-point copy of ``values``.

        Integer (or list) input must become float before the in-place
        arithmetic below, otherwise NumPy refuses the cast.
        """
        arr = np.asarray(values)
        if arr.dtype.kind == "f":
            return arr.copy()
        return arr.astype(float)

    def augment_spectrum(
        self,
        wavenumbers: np.ndarray,
        intensities: np.ndarray,
        method: str = "random",
        num_variations: int = 5,
        intensity_factor: float = 0.1,
    ) -> List[Tuple[np.ndarray, np.ndarray]]:
        """
        Generate augmented versions of a spectrum

        Args:
            wavenumbers: Original wavenumber array
            intensities: Original intensity array
            method: Augmentation method name, or 'random' to pick one of
                ``self.augmentation_methods`` independently per variation
            num_variations: Number of variations to generate
            intensity_factor: Factor controlling augmentation intensity

        Returns:
            List of (wavenumbers, intensities) tuples; the inputs are not
            modified
        """
        augmented_spectra = []
        for _ in range(num_variations):
            if method == "random":
                chosen_method = np.random.choice(self.augmentation_methods)
            else:
                chosen_method = method
            augmented_spectra.append(
                self._apply_augmentation(
                    wavenumbers, intensities, chosen_method, intensity_factor
                )
            )
        return augmented_spectra

    def _apply_augmentation(
        self,
        wavenumbers: np.ndarray,
        intensities: np.ndarray,
        method: str,
        intensity: float,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Apply one augmentation method to float copies of the spectrum.

        Args:
            wavenumbers: Wavenumber axis.
            intensities: Intensity values.
            method: One of the names in ``self.augmentation_methods``. An
                unrecognised name leaves the copies unchanged.
            intensity: Strength of the perturbation.

        Returns:
            (wavenumbers, intensities) as new floating-point arrays.
        """
        wavenumbers = np.asarray(wavenumbers)
        intensities = np.asarray(intensities)
        aug_wavenumbers = self._float_copy(wavenumbers)
        aug_intensities = self._float_copy(intensities)
        n_points = len(intensities)

        if method == "noise_addition":
            # Gaussian noise scaled to the spread of the signal
            noise_level = intensity * np.std(intensities)
            aug_intensities += np.random.normal(0, noise_level, n_points)
        elif method == "baseline_drift":
            # Two sine periods across the spectrum, scaled to mean magnitude
            drift_amplitude = intensity * np.mean(np.abs(intensities))
            aug_intensities += drift_amplitude * np.sin(
                2 * np.pi * np.linspace(0, 2, n_points)
            )
        elif method == "intensity_scaling":
            # Uniform scale factor in [1 - intensity, 1 + intensity)
            scale_factor = 1.0 + intensity * (2 * np.random.random() - 1)
            aug_intensities *= scale_factor
        elif method == "wavenumber_shift":
            # Rigid shift of the axis by at most intensity * 10 cm-1
            shift_range = intensity * 10  # cm-1
            aug_wavenumbers += shift_range * (2 * np.random.random() - 1)
        elif method == "peak_broadening":
            # Broaden peaks by convolving with a normalised Gaussian kernel
            from scipy import signal

            sigma = intensity * 2  # Broadening factor
            kernel_size = int(sigma * 6) + 1
            if kernel_size % 2 == 0:
                kernel_size += 1  # odd length keeps the kernel centred
            if kernel_size >= 3:
                from scipy.signal.windows import gaussian

                kernel = gaussian(kernel_size, sigma)
                kernel = kernel / np.sum(kernel)
                aug_intensities = signal.convolve(
                    aug_intensities, kernel, mode="same"
                )
        elif method == "atmospheric_effects":
            # Attenuate the 2320-2380 and 3200-3800 cm-1 windows
            co2_region = (wavenumbers >= 2320) & (wavenumbers <= 2380)
            h2o_region = (wavenumbers >= 3200) & (wavenumbers <= 3800)
            if np.any(co2_region):
                aug_intensities[co2_region] *= 1 - intensity * 0.1
            if np.any(h2o_region):
                aug_intensities[h2o_region] *= 1 - intensity * 0.05
        elif method == "instrumental_response":
            # Slight wavenumber-dependent gain: one sine period over the range
            if wavenumbers.size:
                lowest = wavenumbers.min()
                span = wavenumbers.max() - lowest
                # A zero span (single point or constant axis) would divide by
                # zero; the response is left flat in that case.
                if span > 0:
                    aug_intensities *= 1.0 + intensity * 0.1 * np.sin(
                        2 * np.pi * (wavenumbers - lowest) / span
                    )
        elif method == "sample_variations":
            # Rescale a few randomly placed windows independently
            num_peaks = min(5, n_points // 100)
            for _ in range(num_peaks):
                peak_center = np.random.randint(0, n_points)
                peak_width = np.random.randint(5, 20)
                peak_variation = intensity * (2 * np.random.random() - 1)
                start_idx = max(0, peak_center - peak_width)
                end_idx = min(n_points, peak_center + peak_width)
                aug_intensities[start_idx:end_idx] *= 1 + peak_variation

        return aug_wavenumbers, aug_intensities

    def generate_synthetic_aging_series(
        self,
        base_spectrum: Tuple[np.ndarray, np.ndarray],
        num_time_points: int = 10,
        max_degradation: float = 0.8,
    ) -> List[Dict]:
        """
        Generate synthetic aging series showing progressive degradation

        Args:
            base_spectrum: (wavenumbers, intensities) for fresh sample
            num_time_points: Number of time points in series; a single point
                yields only the undegraded spectrum
            max_degradation: Maximum degradation level (0-1), reached at the
                last time point

        Returns:
            List of aging data points, each a dictionary with the keys
            ``aging_time``, ``degradation_level``, ``wavenumbers``,
            ``intensities`` and ``spectral_changes`` (relative intensity
            change applied per feature)
        """
        wavenumbers = np.asarray(base_spectrum[0])
        intensities = np.asarray(base_spectrum[1])

        # Spectral regions (cm-1) and the intensity factor each one reaches
        # at a degradation level of 1.0.
        # NOTE(review): "crystrallinity_change" looks like a typo, but the key
        # is emitted in the output, so it is kept unchanged here.
        degradation_features = {
            "carbonyl_growth": {
                "region": (1700, 1750),  # C=O stretch
                "intensity_change": 2.0,  # Factor increase
            },
            "oh_growth": {
                "region": (3200, 3600),  # OH stretch
                "intensity_change": 1.5,
            },
            "ch_decrease": {
                "region": (2800, 3000),  # CH stretch
                "intensity_change": 0.7,  # Factor decrease
            },
            "crystrallinity_change": {
                "region": (1000, 1200),  # Various polymer backbone changes
                "intensity_change": 0.9,
            },
        }
        # The masks do not depend on the time point, so build them once.
        region_masks = {
            feature: (wavenumbers >= params["region"][0])
            & (wavenumbers <= params["region"][1])
            for feature, params in degradation_features.items()
        }

        aging_series = []
        for i in range(num_time_points):
            # With one time point there is no progression to interpolate.
            fraction = i / (num_time_points - 1) if num_time_points > 1 else 0.0
            degradation_level = fraction * max_degradation
            aging_time = i * 100  # hours (arbitrary scale)

            # Relative change per feature: 0.0 at no degradation,
            # (intensity_change - 1) at a degradation level of 1.0.
            relative_changes = {
                feature: degradation_level * (params["intensity_change"] - 1.0)
                for feature, params in degradation_features.items()
            }

            aged_intensities = self._float_copy(intensities)
            for feature, mask in region_masks.items():
                if np.any(mask):
                    aged_intensities[mask] *= 1.0 + relative_changes[feature]

            # Add a small amount of random noise on top
            aug_wavenumbers, aug_intensities = self._apply_augmentation(
                wavenumbers, aged_intensities, "noise_addition", 0.02
            )
            aging_series.append(
                {
                    "aging_time": aging_time,
                    "degradation_level": degradation_level,
                    "wavenumbers": aug_wavenumbers,
                    "intensities": aug_intensities,
                    "spectral_changes": relative_changes,
                }
            )
        return aging_series
# -///////////////////////////////////////////////////
class DataQualityController:
    """Score spectra on several quality criteria and derive a verdict."""

    def __init__(self):
        """Set up the metric names and the validation thresholds."""
        self.quality_metrics = [
            "signal_to_noise_ratio",
            "baseline_stability",
            "peak_resolution",
            "spectral_range_coverage",
            "instrumental_artifacts",
            "data_completeness",
            "metadata_completeness",
        ]
        self.validation_rules = {
            "min_snr": 10.0,
            "max_baseline_variation": 0.1,
            "min_peak_count": 3,
            "min_spectral_range": 1000.0,  # cm-1
            "max_missing_points": 0.05,  # 5% max missing data
        }

    def assess_spectrum_quality(
        self,
        wavenumbers: np.ndarray,
        intensities: np.ndarray,
        metadata: Optional[Dict] = None,
    ) -> Dict[str, Any]:
        """
        Comprehensive quality assessment of spectral data

        Args:
            wavenumbers: Array of wavenumbers
            intensities: Array of intensities
            metadata: Optional metadata dictionary

        Returns:
            Dictionary with ``overall_score`` (mean of the individual
            scores), ``individual_scores``, ``issues_found``,
            ``recommendations``, ``validation_status`` ("validated",
            "conditional" or "rejected") and the measured values ``snr``,
            ``baseline_variation``, ``peak_count``, ``spectral_range``,
            ``missing_fraction`` and ``artifacts_detected``
        """
        wavenumbers = np.asarray(wavenumbers, dtype=float)
        intensities = np.asarray(intensities, dtype=float)
        rules = self.validation_rules

        assessment = {
            "overall_score": 0.0,
            "individual_scores": {},
            "issues_found": [],
            "recommendations": [],
            "validation_status": "pending",
        }
        scores = assessment["individual_scores"]
        issues = assessment["issues_found"]
        advice = assessment["recommendations"]

        # Signal-to-noise
        snr_score, snr_value = self._assess_snr(intensities)
        scores["snr"] = snr_score
        assessment["snr"] = snr_value
        if snr_value < rules["min_snr"]:
            issues.append(f"Low SNR: {snr_value:.1f} (min: {rules['min_snr']})")
            advice.append("Consider noise reduction preprocessing")

        # Baseline stability
        baseline_score, baseline_variation = self._assess_baseline_stability(
            intensities
        )
        scores["baseline"] = baseline_score
        assessment["baseline_variation"] = baseline_variation
        if baseline_variation > rules["max_baseline_variation"]:
            issues.append(f"Unstable baseline: {baseline_variation:.3f}")
            advice.append("Apply baseline correction")

        # Peak resolution and count
        peak_score, peak_count = self._assess_peak_resolution(wavenumbers, intensities)
        scores["peaks"] = peak_score
        assessment["peak_count"] = peak_count
        if peak_count < rules["min_peak_count"]:
            issues.append(f"Few peaks detected: {peak_count}")
            advice.append("Check sample quality or measurement conditions")

        # Spectral range coverage
        range_score, spectral_range = self._assess_spectral_range(wavenumbers)
        scores["range"] = range_score
        assessment["spectral_range"] = spectral_range
        if spectral_range < rules["min_spectral_range"]:
            issues.append(f"Limited spectral range: {spectral_range:.0f} cm-1")

        # Data completeness
        completeness_score, missing_fraction = self._assess_data_completeness(
            intensities
        )
        scores["completeness"] = completeness_score
        assessment["missing_fraction"] = missing_fraction
        if missing_fraction > rules["max_missing_points"]:
            # missing_fraction is a 0-1 fraction; render it as a percentage
            issues.append(f"Missing data points: {missing_fraction:.1%}")
            advice.append("Interpolate missing points or re-measure")

        # Instrumental artifacts
        artifact_score, artifacts = self._detect_instrumental_artifacts(
            wavenumbers, intensities
        )
        scores["artifacts"] = artifact_score
        assessment["artifacts_detected"] = artifacts
        if artifacts:
            issues.extend(f"Artifact detected {artifact}" for artifact in artifacts)
            advice.append("Apply artifact correction")

        # Metadata completeness
        scores["metadata"] = self._assess_metadata_completeness(metadata)

        # Overall score is the unweighted mean of all individual scores
        all_scores = list(scores.values())
        assessment["overall_score"] = float(np.mean(all_scores)) if all_scores else 0.0

        # "validated" requires a high score and no issue at all
        if assessment["overall_score"] >= 0.8 and not issues:
            assessment["validation_status"] = "validated"
        elif assessment["overall_score"] >= 0.6:
            assessment["validation_status"] = "conditional"
        else:
            assessment["validation_status"] = "rejected"
        return assessment

    def _assess_snr(self, intensities: np.ndarray) -> Tuple[float, float]:
        """Estimate the signal-to-noise ratio.

        Noise is taken as the standard deviation of the point-to-point
        differences, signal as the standard deviation of the intensities.

        Returns:
            (score, snr); the score maps log10(snr) from 1..3 onto 0..1.
            Falls back to (0.5, 1.0) if the computation fails.
        """
        try:
            noise_std = np.std(np.diff(intensities))
            signal_std = np.sqrt(np.var(intensities))
            snr = signal_std / noise_std if noise_std > 0 else float("inf")
            # Log scale, 10-1000 range
            score = min(1.0, max(0.0, (np.log10(snr) - 1) / 2))
            return float(score), float(snr)
        except Exception:
            return 0.5, 1.0

    def _assess_baseline_stability(
        self, intensities: np.ndarray
    ) -> Tuple[float, float]:
        """Measure how much the two ends of the spectrum vary.

        Returns:
            (score, variation) where variation is the standard deviation of
            the first and last ten points relative to the mean absolute
            intensity. Falls back to (0.5, 1.0) if the computation fails.
        """
        try:
            # Both slices are needed; a scalar index here would make
            # np.concatenate fail on a zero-dimensional input.
            baseline_points = np.concatenate([intensities[:10], intensities[-10:]])
            scale = np.mean(np.abs(intensities))
            if scale == 0:
                # An all-zero signal has no baseline variation to speak of
                baseline_variation = 0.0
            else:
                baseline_variation = float(np.std(baseline_points) / scale)
            score = max(0.0, 1.0 - baseline_variation * 10)  # Penalty for variation
            return float(score), baseline_variation
        except Exception:
            return 0.5, 1.0

    def _assess_peak_resolution(
        self, wavenumbers: np.ndarray, intensities: np.ndarray
    ) -> Tuple[float, int]:
        """Count prominent peaks and score their number and prominence.

        Returns:
            (score, peak_count). Falls back to (0.5, 0) if peak detection
            fails.
        """
        try:
            from scipy.signal import find_peaks

            spread = np.std(intensities)
            # Peaks need a prominence of 10% of the signal spread and must be
            # at least five points apart.
            peaks, properties = find_peaks(
                intensities, prominence=0.1 * spread, distance=5
            )
            peak_count = len(peaks)
            if peak_count == 0:
                return 0.0, peak_count

            avg_prominence = np.mean(properties["prominences"])
            prominence_score = min(1.0, avg_prominence / (0.2 * spread))
            count_score = min(1.0, peak_count / 10)  # Normalize to ~10 peaks
            return float(0.5 * prominence_score + 0.5 * count_score), peak_count
        except Exception:
            return 0.5, 0

    def _assess_spectral_range(self, wavenumbers: np.ndarray) -> Tuple[float, float]:
        """Score the covered wavenumber span against 4000 cm-1.

        Returns:
            (score, span). Falls back to (0.5, 1000) if the span cannot be
            computed.
        """
        try:
            spectral_range = float(wavenumbers.max() - wavenumbers.min())
            # Score based on typical FTIR range (4000 cm-1)
            return min(1.0, spectral_range / 4000), spectral_range
        except Exception:
            return 0.5, 1000

    def _assess_data_completeness(self, intensities: np.ndarray) -> Tuple[float, float]:
        """Determine the fraction of unusable intensity values.

        NaN, infinite and exactly-zero values all count as missing.

        Returns:
            (score, missing_fraction) with the fraction on a 0-1 scale.
            Falls back to (0.5, 0.0) if the computation fails.
        """
        try:
            invalid_mask = (
                np.isnan(intensities) | np.isinf(intensities) | (intensities == 0)
            )
            missing_fraction = float(np.sum(invalid_mask) / len(intensities))
            # Heavy penalty for missing data
            return max(0.0, 1.0 - missing_fraction * 10), missing_fraction
        except Exception:
            return 0.5, 0.0

    def _detect_instrumental_artifacts(
        self, wavenumbers: np.ndarray, intensities: np.ndarray
    ) -> Tuple[float, List[str]]:
        """Look for spikes, saturation and periodic noise.

        Returns:
            (score, artifacts) where artifacts holds any of
            "excessive_spikes", "high_saturation", "low_saturation" and
            "periodic_noise"; each artifact lowers the score by 0.3. Falls
            back to (0.5, []) if the analysis fails.
        """
        artifacts = []
        try:
            n_points = len(intensities)

            # Spikes: jumps larger than five standard deviations of the
            # point-to-point differences
            steps = np.diff(intensities)
            spikes = np.where(np.abs(steps) > 5 * np.std(steps))[0]
            if len(spikes) > n_points * 0.01:  # More than 1% spikes
                artifacts.append("excessive_spikes")

            # Saturation: many points within 1% of the extreme values
            if np.std(intensities) > 0:
                max_val = np.max(intensities)
                min_val = np.min(intensities)
                # abs() keeps the 1% margin pointing inwards when the extreme
                # value is negative.
                high_cut = max_val - 0.01 * abs(max_val)
                low_cut = min_val + 0.01 * abs(min_val)
                if np.sum(intensities >= high_cut) / n_points > 0.05:
                    artifacts.append("high_saturation")
                if np.sum(intensities <= low_cut) / n_points > 0.05:
                    artifacts.append("low_saturation")

            # Periodic noise: many frequency components far above the mean
            spectrum = np.fft.fft(intensities - np.mean(intensities))
            freq_domain = np.abs(spectrum[: len(spectrum) // 2])
            if len(freq_domain) > 10:
                strong = np.sum(freq_domain > 3 * np.mean(freq_domain))
                if strong > len(freq_domain) * 0.1:
                    artifacts.append("periodic_noise")

            # Score inversely related to number of artifacts
            return max(0.0, 1.0 - len(artifacts) * 0.3), artifacts
        except Exception:
            return 0.5, []

    def _assess_metadata_completeness(self, metadata: Optional[Dict]) -> float:
        """Return the fraction of required metadata fields that are set.

        A field counts as set when its key is present with a value other
        than None. Missing metadata scores 0.0.
        """
        if metadata is None:
            return 0.0

        required_fields = [
            "sample_id",
            "measurement_date",
            "instrument_type",
            "resolution",
            "number_of_scans",
            "sample_type",
        ]
        present_fields = sum(
            1 for key in required_fields if metadata.get(key) is not None
        )
        return present_fields / len(required_fields)
# -///////////////////////////////////////////////////
class EnhancedDataPipeline:
    """Tie database import, quality control and augmentation together.

    Spectra are persisted in the SQLite file ``polymer_spectra.db`` inside
    ``local_database_path``.
    """

    def __init__(
        self, local_database_path: Union[str, Path] = Path("data/enhanced_data")
    ):
        """Create the helper components and the local database.

        Args:
            local_database_path: Directory for the SQLite file; it is
                created if it does not exist.
        """
        # Maps database name -> DatabaseConnector for every source that
        # connect_to_databases() reached successfully.
        self.database_connector = {}
        self.augmentation_engine = SyntheticDataAugmentation()
        self.quality_controller = DataQualityController()
        self.local_database_path = Path(local_database_path)
        self.local_database_path.mkdir(parents=True, exist_ok=True)
        self._init_local_database()

    def _open_db(self):
        """Return a context manager yielding a connection closed on exit.

        ``with sqlite3.connect(...)`` only commits or rolls back; it does not
        close the connection, hence the ``closing`` wrapper.
        """
        from contextlib import closing

        return closing(
            sqlite3.connect(self.local_database_path / "polymer_spectra.db")
        )

    def _init_local_database(self):
        """Create the ``spectra`` and ``aging_data`` tables if missing.

        A database created without the ``aging_conditions`` column in
        ``aging_data`` gets that column added.
        """
        # Second context item: commit on success, roll back on error
        with self._open_db() as conn, conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS spectra (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    sample_id TEXT UNIQUE NOT NULL,
                    polymer_type TEXT NOT NULL,
                    technique TEXT NOT NULL,
                    wavenumbers BLOB,
                    intensities BLOB,
                    metadata TEXT,
                    quality_score REAL,
                    validation_status TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    source_database TEXT
                )
                """
            )
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS aging_data (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    sample_id TEXT,
                    aging_time REAL,
                    degradation_level REAL,
                    aging_conditions TEXT,
                    spectral_changes TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (sample_id) REFERENCES spectra (sample_id)
                )
                """
            )
            # CREATE TABLE IF NOT EXISTS leaves an existing table untouched,
            # so add the column the insert and select statements rely on.
            existing = {row[1] for row in conn.execute("PRAGMA table_info(aging_data)")}
            if "aging_conditions" not in existing:
                conn.execute("ALTER TABLE aging_data ADD COLUMN aging_conditions TEXT")

    def connect_to_databases(self) -> Dict[str, bool]:
        """Try to connect to every database in ``SPECTROSCOPY_DATABASES``.

        Connectors that connect successfully are kept in
        ``self.database_connector`` for later searches.

        Returns:
            Mapping of database name to whether the connection succeeded.
        """
        connection_status = {}
        for db_name, db_config in SPECTROSCOPY_DATABASES.items():
            connector = DatabaseConnector(db_config)
            connected = connector.connect()
            connection_status[db_name] = connected
            if connected:
                self.database_connector[db_name] = connector
            else:
                self.database_connector.pop(db_name, None)
        return connection_status

    def search_and_import_spectra(
        self, polymer_type: str, max_per_database: int = 50
    ) -> Dict[str, int]:
        """Search all connected databases and import what they return.

        Args:
            polymer_type: Polymer type to search for.
            max_per_database: Result limit passed to each database search.

        Returns:
            Mapping of database name to the number of spectra imported;
            a database whose search fails is reported with 0.
        """
        import_counts = {}
        for db_name, connector in self.database_connector.items():
            try:
                search_results = connector.search_by_polymer_type(
                    polymer_type, max_per_database
                )
                import_counts[db_name] = sum(
                    1
                    for result in search_results
                    if self._import_spectrum_to_local(result, db_name)
                )
            except Exception as e:
                # Best effort: one failing source must not stop the others
                print(f"Error importing from {db_name}: {e}")
                import_counts[db_name] = 0
        return import_counts

    def _import_spectrum_to_local(self, spectrum_data: Dict, source_db: str) -> bool:
        """Quality-check one spectrum and store it in the local database.

        Args:
            spectrum_data: Dictionary that must provide ``wavenumbers`` and
                ``intensities``; ``sample_id``, ``polymer_type``,
                ``technique`` and ``metadata`` are used when present.
            source_db: Name of the originating database.

        Returns:
            True when the spectrum was stored; False when the arrays are
            missing or inconsistent, the quality verdict is "rejected", or
            an error occurred.
        """
        try:
            # Only the flat wavenumbers/intensities layout is understood
            if "wavenumbers" not in spectrum_data or "intensities" not in spectrum_data:
                return False
            wavenumbers = np.asarray(spectrum_data["wavenumbers"], dtype=float)
            intensities = np.asarray(spectrum_data["intensities"], dtype=float)
            if wavenumbers.size == 0 or wavenumbers.shape != intensities.shape:
                return False

            # hashlib rather than hash(): string hashes differ between
            # interpreter runs, which would give the same record a new ID on
            # every import.
            fallback_id = hashlib.md5(str(spectrum_data).encode("utf-8")).hexdigest()
            sample_id = spectrum_data.get("sample_id") or f"{source_db}_{fallback_id}"

            metadata = spectrum_data.get("metadata", {})
            quality_assessment = self.quality_controller.assess_spectrum_quality(
                wavenumbers, intensities, metadata
            )
            # Only import if quality is acceptable
            if quality_assessment["validation_status"] == "rejected":
                return False

            row = (
                sample_id,
                spectrum_data.get("polymer_type", "unknown"),
                spectrum_data.get("technique", "FTIR"),
                pickle.dumps(wavenumbers),
                pickle.dumps(intensities),
                json.dumps(metadata),
                quality_assessment["overall_score"],
                quality_assessment["validation_status"],
                source_db,
            )
            with self._open_db() as conn, conn:
                conn.execute(
                    """
                    INSERT OR REPLACE INTO spectra(
                        sample_id, polymer_type, technique,
                        wavenumbers, intensities, metadata,
                        quality_score, validation_status,
                        source_database)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    row,
                )
            return True
        except Exception as e:
            # Best-effort boundary: a bad record is reported and skipped
            print(f"Error importing spectrum: {e}")
            return False

    def generate_synthetic_aging_dataset(
        self,
        base_polymer_type: str,
        num_samples: int = 50,
        aging_conditions: Optional[List[Dict]] = None,
    ) -> int:
        """
        Generate synthetic aging dataset for training

        One aging series is generated for every combination of base spectrum
        and aging condition, and each point is stored labelled with the
        condition it was generated for.

        Args:
            base_polymer_type: Base polymer type to use
            num_samples: Upper bound on the number of synthetic samples
            aging_conditions: List of aging condition dictionaries; four
                built-in conditions are used when omitted or empty

        Returns:
            Number of samples generated
        """
        if not aging_conditions:
            aging_conditions = [
                {"temperature": 60, "humidity": 75, "uv_exposure": True},
                {"temperature": 80, "humidity": 85, "uv_exposure": True},
                {"temperature": 40, "humidity": 95, "uv_exposure": False},
                {"temperature": 100, "humidity": 50, "uv_exposure": True},
            ]

        # NOTE(review): earlier synthetic spectra of this polymer type also
        # match this query and can end up as base spectra - confirm intent.
        base_spectra = self.spectra_by_type(base_polymer_type, limit=10)
        if not base_spectra:
            print(f"No base spectra found for {base_polymer_type}")
            return 0
        if num_samples <= 0:
            return 0

        # At least two points are needed for a series to show progression;
        # the num_samples cap below keeps the total within the request.
        per_series = num_samples // len(aging_conditions) // len(base_spectra)
        num_time_points = max(2, min(10, per_series))

        # sample_id is UNIQUE, so a per-run tag keeps repeated runs from
        # colliding with IDs written earlier.
        run_tag = datetime.now().strftime("%Y%m%d%H%M%S%f")

        generated_count = 0
        for base_spectrum in base_spectra:
            # NOTE: pickle.loads can execute code from a crafted payload;
            # the blobs are assumed to have been written by this pipeline.
            wavenumbers = pickle.loads(base_spectrum["wavenumbers"])
            intensities = pickle.loads(base_spectrum["intensities"])

            for condition in aging_conditions:
                aging_series = self.augmentation_engine.generate_synthetic_aging_series(
                    (wavenumbers, intensities),
                    num_time_points=num_time_points,
                )
                for aging_point in aging_series:
                    if generated_count >= num_samples:
                        return generated_count
                    synthetic_id = (
                        f"synthetic_{base_polymer_type}_{run_tag}_{generated_count}"
                    )
                    metadata = {
                        "synthetic": True,
                        "aging_condition": condition,
                        "aging_time": aging_point["aging_time"],
                        "degradation_level": aging_point["degradation_level"],
                    }
                    if self._store_synthetic_spectrum(
                        synthetic_id, base_polymer_type, aging_point, metadata
                    ):
                        generated_count += 1
        return generated_count

    def _store_synthetic_spectrum(
        self, sample_id: str, polymer_type: str, aging_point: Dict, metadata: Dict
    ) -> bool:
        """Store a synthetic spectrum together with its aging record.

        Both rows are written in one transaction.

        Args:
            sample_id: Unique ID for the new spectrum.
            polymer_type: Polymer type label.
            aging_point: Dictionary with ``wavenumbers``, ``intensities``,
                ``aging_time``, ``degradation_level`` and optionally
                ``spectral_changes``.
            metadata: Metadata stored as JSON; the aging conditions are read
                from ``aging_conditions`` or ``aging_condition``.

        Returns:
            True on success, False if anything failed.
        """
        try:
            quality_assessment = self.quality_controller.assess_spectrum_quality(
                aging_point["wavenumbers"], aging_point["intensities"], metadata
            )
            # Accept both spellings of the metadata key
            conditions = metadata.get(
                "aging_conditions", metadata.get("aging_condition", {})
            )

            spectrum_row = (
                sample_id,
                polymer_type,
                "FTIR_synthetic",
                pickle.dumps(aging_point["wavenumbers"]),
                pickle.dumps(aging_point["intensities"]),
                json.dumps(metadata),
                quality_assessment["overall_score"],
                "validated",  # Synthetic data is pre-validated
                "synthetic",
            )
            aging_row = (
                sample_id,
                aging_point["aging_time"],
                aging_point["degradation_level"],
                json.dumps(conditions),
                json.dumps(aging_point.get("spectral_changes", {})),
            )

            with self._open_db() as conn, conn:
                conn.execute(
                    """
                    INSERT INTO spectra
                    (sample_id, polymer_type, technique, wavenumbers, intensities,
                     metadata, quality_score, validation_status, source_database)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    spectrum_row,
                )
                conn.execute(
                    """
                    INSERT INTO aging_data
                    (sample_id, aging_time, degradation_level, aging_conditions, spectral_changes)
                    VALUES (?, ?, ?, ?, ?)
                    """,
                    aging_row,
                )
            return True
        except Exception as e:
            # Best-effort boundary: report and let the caller carry on
            print(f"Error storing synthetic spectrum: {e}")
            return False

    def spectra_by_type(self, polymer_type: str, limit: int = 100) -> List[Dict]:
        """Retrieve non-rejected spectra whose type contains ``polymer_type``.

        Args:
            polymer_type: Substring matched with SQL ``LIKE``.
            limit: Maximum number of rows.

        Returns:
            Rows as dictionaries keyed by column name, best quality first.
        """
        with self._open_db() as conn:
            cursor = conn.execute(
                """
                SELECT * FROM spectra
                WHERE polymer_type LIKE ? AND validation_status != 'rejected'
                ORDER BY quality_score DESC
                LIMIT ?
                """,
                (f"%{polymer_type}%", limit),
            )
            columns = [description[0] for description in cursor.description]
            return [dict(zip(columns, row)) for row in cursor.fetchall()]

    def get_weathered_samples(self, polymer_type: Optional[str] = None) -> List[Dict]:
        """Get non-rejected spectra that have an aging record.

        Args:
            polymer_type: Optional substring filter on the polymer type.

        Returns:
            Joined spectra/aging rows as dictionaries, ordered by increasing
            degradation level.
        """
        query = """
            SELECT s.*, a.aging_time, a.degradation_level, a.aging_conditions
            FROM spectra s
            JOIN aging_data a ON s.sample_id = a.sample_id
            WHERE s.validation_status != 'rejected'
        """
        params = []
        if polymer_type:
            query += " AND s.polymer_type LIKE ?"
            params.append(f"%{polymer_type}%")
        query += " ORDER BY a.degradation_level"

        with self._open_db() as conn:
            cursor = conn.execute(query, params)
            columns = [description[0] for description in cursor.description]
            return [dict(zip(columns, row)) for row in cursor.fetchall()]

    def get_database_statistics(self) -> Dict[str, Any]:
        """Get statistics about the local database.

        Returns:
            Dictionary with ``total_spectra``, the counts ``by_polymer_type``,
            ``by_technique`` and ``by_validation_status``,
            ``average_quality_score`` (0.0 when no score is stored) and
            ``aging_samples``.
        """

        def grouped_counts(conn, column: str, ordered: bool) -> Dict[str, int]:
            """Count spectra per value of ``column`` (a fixed column name)."""
            sql = f"SELECT {column}, COUNT(*) as count FROM spectra GROUP BY {column}"
            if ordered:
                sql += " ORDER BY count DESC"
            return dict(conn.execute(sql).fetchall())

        with self._open_db() as conn:
            total_spectra = conn.execute("SELECT COUNT(*) FROM spectra").fetchone()[0]
            by_polymer_type = grouped_counts(conn, "polymer_type", True)
            by_technique = grouped_counts(conn, "technique", True)
            by_validation = grouped_counts(conn, "validation_status", False)
            avg_quality = (
                conn.execute(
                    "SELECT AVG(quality_score) FROM spectra WHERE quality_score IS NOT NULL"
                ).fetchone()[0]
                or 0.0
            )
            aging_samples = conn.execute(
                "SELECT COUNT(*) FROM aging_data"
            ).fetchone()[0]

        return {
            "total_spectra": total_spectra,
            "by_polymer_type": by_polymer_type,
            "by_technique": by_technique,
            "by_validation_status": by_validation,
            "average_quality_score": avg_quality,
            "aging_samples": aging_samples,
        }