Spaces:
Sleeping
Sleeping
devjas1
(FIX/UX)[Update Image Upload Interface]: Change image display to use container width for better responsiveness + fix jitter bug on Image Analysis page
5e6ebd2
""" | |
Image loading and transformation utilities for polymer classification. | |
Supports conversion of spectral images to processable data. | |
""" | |
from typing import Tuple, Optional, List, Dict | |
import numpy as np | |
from PIL import Image, ImageEnhance, ImageFilter | |
import matplotlib.pyplot as plt | |
from matplotlib.figure import Figure | |
import streamlit as st | |
import pandas as pd | |
# Use existing inference pipeline | |
from utils.preprocessing import preprocess_spectrum | |
from core_logic import run_inference | |
class SpectralImageProcessor:
    """Load spectral images and convert them into 1D spectrum-like profiles.

    Typical pipeline: ``load_image`` -> ``preprocess_image`` ->
    ``image_to_spectrum`` -> ``detect_spectral_peaks`` ->
    ``create_visualization``.
    """

    def __init__(self):
        # Accepted file extensions.
        # NOTE(review): not read by any method of this class.
        self.support_formats = [".png", ".jpg", ".jpeg", ".tiff", ".bmp"]
        # (width, height) -- the ordering PIL's ``Image.resize`` expects.
        self.default_target_size = (224, 224)

    def load_image(self, image_source) -> Optional[np.ndarray]:
        """Load an image into a NumPy array.

        Args:
            image_source: A filesystem path (``str`` or ``os.PathLike``), a
                readable file-like object (e.g. a Streamlit uploaded file), or
                an ``np.ndarray``, which is returned unchanged.

        Returns:
            An RGB array for path/file inputs, the input itself for ndarray
            inputs, or ``None`` when loading fails. Failures are reported to
            the user via ``st.error`` instead of being raised.
        """
        import os  # local import, mirroring the scipy import below

        if isinstance(image_source, np.ndarray):
            return image_source

        try:
            if isinstance(image_source, (str, os.PathLike)):
                source = os.fspath(image_source)
            elif hasattr(image_source, "read"):
                source = image_source
            else:
                raise ValueError("Unsupported image source type")

            # The context manager closes the file Pillow opens for path
            # inputs. For caller-supplied file objects Pillow leaves the
            # stream open, so an uploaded file remains usable afterwards.
            with Image.open(source) as img:
                rgb = img if img.mode == "RGB" else img.convert("RGB")
                return np.array(rgb)
        except (OSError, ValueError) as e:  # OSError covers FileNotFoundError
            st.error(f"Error loading image: {e}")
            return None

    @staticmethod
    def _to_uint8(image: np.ndarray) -> np.ndarray:
        """Convert an image array to uint8 without integer wrap-around."""
        arr = np.asarray(image)
        if arr.dtype == np.uint8:
            return arr
        if np.issubdtype(arr.dtype, np.floating) and arr.size and arr.max() <= 1.0:
            # Data already normalized to [0, 1] (e.g. the output of
            # preprocess_image(normalize=True)); a bare astype would
            # truncate every pixel to 0.
            arr = arr * 255.0
        # Clip so out-of-range values saturate instead of wrapping modulo 256.
        return np.clip(arr, 0, 255).astype(np.uint8)

    def preprocess_image(
        self,
        image: np.ndarray,
        target_size: Optional[Tuple[int, int]] = None,
        enhance_contrast: bool = True,
        apply_gaussian_blur: bool = False,
        normalize: bool = True,
    ) -> np.ndarray:
        """Resize and optionally enhance an image for analysis.

        Args:
            image: Image array. uint8 data is used as-is. Floating-point data
                whose maximum is <= 1.0 is treated as normalized and rescaled
                to 0-255. Other values are clipped to 0-255.
            target_size: Output ``(width, height)``. Defaults to
                ``self.default_target_size``.
            enhance_contrast: Increase contrast by a factor of 1.2.
            apply_gaussian_blur: Apply a Gaussian blur of radius 1.
            normalize: Return float32 values in [0, 1] instead of uint8.

        Returns:
            The processed image array.
        """
        if target_size is None:
            target_size = self.default_target_size

        img = Image.fromarray(self._to_uint8(image))
        img = img.resize(target_size, Image.Resampling.LANCZOS)

        if enhance_contrast:
            img = ImageEnhance.Contrast(img).enhance(1.2)

        if apply_gaussian_blur:
            img = img.filter(ImageFilter.GaussianBlur(radius=1))

        processed = np.array(img)
        if normalize:
            processed = processed.astype(np.float32) / 255.0
        return processed

    def extract_spectral_profile(
        self,
        image: np.ndarray,
        method: str = "average",
        roi: Optional[Tuple[int, int, int, int]] = None,
    ) -> np.ndarray:
        """Extract a 1D profile from a 2D (or colour) image.

        The profile has one value per image column (after ROI cropping).

        Args:
            image: Input image array. 3D inputs are averaged over the last
                axis to obtain a single-channel image.
            method: ``'average'`` (mean over rows), ``'center_line'`` (the
                middle row) or ``'max_intensity'`` (max over rows).
            roi: Optional region of interest ``(x1, y1, x2, y2)``, used as
                ``image[y1:y2, x1:x2]``.

        Raises:
            ValueError: If ``method`` is not one of the supported names.
        """
        if roi:
            x1, y1, x2, y2 = roi
            region = image[y1:y2, x1:x2]
        else:
            region = image

        if region.ndim == 3:
            region = np.mean(region, axis=2)

        if method == "average":
            return np.mean(region, axis=0)
        if method == "center_line":
            return region[region.shape[0] // 2, :]
        if method == "max_intensity":
            return np.max(region, axis=0)
        raise ValueError(f"Unknown method: {method}")

    def image_to_spectrum(
        self,
        image: np.ndarray,
        wavenumber_range: Tuple[float, float] = (400, 4000),
        method: str = "average",
        roi: Optional[Tuple[int, int, int, int]] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Convert an image into ``(wavenumbers, intensities)`` arrays.

        The wavenumber axis is a linear grid spanning ``wavenumber_range``
        with one point per profile sample. No calibration is applied.

        Args:
            image: Input image array.
            wavenumber_range: ``(start, end)`` of the synthetic axis.
            method: Profile extraction method (see
                ``extract_spectral_profile``).
            roi: Optional ``(x1, y1, x2, y2)`` crop applied before extraction.
        """
        profile = self.extract_spectral_profile(image, method=method, roi=roi)
        wavenumbers = np.linspace(
            wavenumber_range[0], wavenumber_range[1], len(profile)
        )
        return wavenumbers, profile

    def detect_spectral_peaks(
        self,
        spectrum: np.ndarray,
        wavenumbers: np.ndarray,
        prominence: float = 0.1,
        height: float = 0.1,
    ) -> List[Dict[str, Optional[float]]]:
        """Detect peaks with ``scipy.signal.find_peaks``.

        Args:
            spectrum: 1D intensity values.
            wavenumbers: x-axis values aligned with ``spectrum``.
            prominence: Minimum peak prominence, or ``None`` for no limit.
            height: Minimum peak height, or ``None`` for no limit.

        Returns:
            One dict per peak with ``wavenumber``, ``intensity``,
            ``prominence`` (``None`` when ``prominence`` is ``None``) and
            ``width`` (always ``None``, because no width criterion is passed
            to ``find_peaks``).
        """
        from scipy.signal import find_peaks

        peaks, properties = find_peaks(spectrum, prominence=prominence, height=height)

        # find_peaks only reports a property when its criterion was given.
        prominences = properties.get("prominences")
        widths = properties.get("widths")

        return [
            {
                "wavenumber": float(wavenumbers[idx]),
                "intensity": float(spectrum[idx]),
                "prominence": None if prominences is None else float(prominences[i]),
                "width": None if widths is None else float(widths[i]),
            }
            for i, idx in enumerate(peaks)
        ]

    def create_visualization(
        self,
        image: np.ndarray,
        spectrum_x: np.ndarray,
        spectrum_y: np.ndarray,
        peaks: Optional[List[Dict]] = None,
    ) -> "Figure":
        """Plot the image next to its extracted spectrum.

        The figure is built with the object-oriented ``Figure`` API rather
        than ``pyplot``. As a result it is not registered in pyplot's global
        figure list, so repeated Streamlit reruns do not accumulate open
        figures, and layout is applied to this figure rather than to whatever
        pyplot considers "current".

        Args:
            image: Image to show on the left. 2D images use ``viridis``.
            spectrum_x: Wavenumber axis.
            spectrum_y: Intensity values.
            peaks: Optional peak dicts with ``wavenumber`` and ``intensity``
                keys, drawn as red markers.

        Returns:
            The matplotlib ``Figure``.
        """
        fig = Figure(figsize=(12, 5))
        ax_image, ax_spectrum = fig.subplots(1, 2)

        ax_image.imshow(image, cmap="viridis" if image.ndim == 2 else None)
        ax_image.set_title("Input Image")
        ax_image.axis("off")

        ax_spectrum.plot(
            spectrum_x, spectrum_y, "b-", linewidth=1.5, label="Extracted Spectrum"
        )
        if peaks:
            ax_spectrum.plot(
                [p["wavenumber"] for p in peaks],
                [p["intensity"] for p in peaks],
                "ro",
                markersize=6,
                label="Detected Peaks",
            )
        ax_spectrum.set_xlabel("Wavenumber (cm⁻¹)")
        ax_spectrum.set_ylabel("Intensity")
        ax_spectrum.set_title("Extracted Spectral Profile")
        ax_spectrum.grid(True, alpha=0.3)
        ax_spectrum.legend()

        fig.tight_layout()
        return fig
def render_image_upload_interface():
    """Render the Streamlit UI for image upload, spectrum extraction and inference.

    Flow:
    1. The user uploads an image and picks processing options.
    2. "Process Image" extracts a spectrum and stores it in ``st.session_state``.
    3. On every rerun for the same upload, the stored result (plot and peak
       table) is shown together with a "Run Inference on Extracted Spectrum"
       button.

    Results live in session state because Streamlit reruns the script on
    every widget interaction, and ``st.button`` is ``True`` only for the
    rerun triggered by its own click. A button nested inside another
    button's branch could therefore never fire.

    Session-state keys written: ``image_spectrum_x``, ``image_spectrum_y``,
    ``image_peaks``, ``image_processed``, ``image_spectrum_source``.
    Keys read: ``modality_select``, ``model_select``.
    """
    st.markdown("#### Image-Based Spectral Analysis")
    st.markdown(
        "Upload spectral images for analysis and conversion to spectroscopic data."
    )

    processor = SpectralImageProcessor()

    uploaded_image = st.file_uploader(
        "Upload spectral image",
        type=["png", "jpg", "jpeg", "tiff", "bmp"],
        help="Upload an image containing spectral data",
    )
    if uploaded_image is None:
        return

    image = processor.load_image(uploaded_image)
    if image is None:
        return  # load_image has already reported the error via st.error

    # Tags stored results with the upload they came from, so a spectrum
    # extracted from a previously uploaded image is never shown or sent to
    # inference for the current one.
    source_id = f"{uploaded_image.name}:{uploaded_image.size}"

    col1, col2 = st.columns([1, 1])

    with col1:
        st.markdown("##### Original Image")
        st.image(image, use_container_width=True)
        st.write(f"**Dimensions**: {image.shape}")
        st.write(f"**Size**: {uploaded_image.size} bytes")

    with col2:
        st.markdown("##### Processing Options")

        target_width = st.slider("Target Width", 100, 1000, 500)
        target_height = st.slider("Target Height", 100, 1000, 300)
        enhance_contrast = st.checkbox("Enhance Contrast", value=True)
        apply_blur = st.checkbox("Apply Gaussian Blur", value=False)

        extraction_method = st.selectbox(
            "Spectrum Extraction Method",
            ["average", "center_line", "max_intensity"],
            help="Method for converting 2D image to 1D spectrum",
        )

        st.markdown("**Wavenumber Range (cm⁻¹)**")
        wn_col1, wn_col2 = st.columns(2)
        with wn_col1:
            wn_min = st.number_input("Min", value=400.0, step=10.0)
        with wn_col2:
            wn_max = st.number_input("Max", value=4000.0, step=10.0)

        if st.button("Process Image", type="primary"):
            with st.spinner("Processing image..."):
                processed_image = processor.preprocess_image(
                    image,
                    target_size=(target_width, target_height),
                    enhance_contrast=enhance_contrast,
                    apply_gaussian_blur=apply_blur,
                )
                wavenumbers, spectrum = processor.image_to_spectrum(
                    processed_image,
                    wavenumber_range=(wn_min, wn_max),
                    method=extraction_method,
                )
                peaks = processor.detect_spectral_peaks(spectrum, wavenumbers)

            st.session_state["image_spectrum_x"] = wavenumbers
            st.session_state["image_spectrum_y"] = spectrum
            st.session_state["image_peaks"] = peaks
            st.session_state["image_processed"] = processed_image
            st.session_state["image_spectrum_source"] = source_id
            st.success(
                "Image processing complete! You can now use this data for model inference."
            )

        if st.session_state.get("image_spectrum_source") != source_id:
            return

        wavenumbers = st.session_state["image_spectrum_x"]
        spectrum = st.session_state["image_spectrum_y"]
        _render_image_results(
            processor,
            st.session_state["image_processed"],
            wavenumbers,
            spectrum,
            st.session_state["image_peaks"],
        )

        # Lives outside the "Process Image" branch so its own click-triggered
        # rerun actually reaches this code.
        if st.button("Run Inference on Extracted Spectrum"):
            _run_image_spectrum_inference(wavenumbers, spectrum)


def _render_image_results(processor, processed_image, wavenumbers, spectrum, peaks):
    """Show the image/spectrum figure and, if any peaks exist, the peak table."""
    fig = processor.create_visualization(processed_image, wavenumbers, spectrum, peaks)
    st.pyplot(fig)
    # Drop the figure from pyplot's registry if it is tracked there (no-op
    # otherwise), so repeated reruns do not accumulate open figures.
    plt.close(fig)

    if peaks:
        st.markdown("##### Detected Peaks")
        peak_df = pd.DataFrame(peaks)
        peak_df["wavenumber"] = peak_df["wavenumber"].round(2)
        peak_df["intensity"] = peak_df["intensity"].round(4)
        st.dataframe(peak_df)


def _run_image_spectrum_inference(wavenumbers, spectrum):
    """Preprocess the extracted spectrum, run the selected model and show results."""
    modality = st.session_state.get("modality_select", "raman")
    _, y_processed = preprocess_spectrum(
        wavenumbers, spectrum, modality=modality, target_len=500
    )

    model_choice = st.session_state.get("model_select", "figure2")
    # NOTE(review): presumably the model selector label carries a prefix
    # before the model name; everything up to the first space is dropped.
    if " " in model_choice:
        model_choice = model_choice.split(" ", 1)[1]

    prediction, _, probs, inference_time, _ = run_inference(y_processed, model_choice)
    if prediction is None:
        st.warning("Inference did not return a prediction.")
        return

    class_names = ["Stable", "Weathered"]
    pred_idx = int(prediction)
    predicted_class = (
        class_names[pred_idx]
        if 0 <= pred_idx < len(class_names)
        else f"Class_{prediction}"
    )

    # Flatten probs to a plain list so truthiness/len checks work whether
    # run_inference returns a list or a NumPy array (type not visible here).
    prob_values = [] if probs is None else np.asarray(probs, dtype=float).ravel().tolist()
    confidence = max(prob_values) if prob_values else 0.0

    st.markdown("##### Inference Results")
    result_col1, result_col2 = st.columns(2)
    with result_col1:
        st.metric("Prediction", predicted_class)
        st.metric("Confidence", f"{confidence:.3f}")
    with result_col2:
        st.metric("Model Used", model_choice)
        st.metric("Processing Time", f"{inference_time:.3f}s")

    if prob_values:
        st.markdown("**Class Probabilities**")
        # zip stops at the shorter sequence: probabilities without a class
        # name are not listed.
        for name, prob in zip(class_names, prob_values):
            st.write(f"- {name}: {prob:.4f}")
def image_to_spectrum_converter(
    image_path: str,
    wavenumber_range: Tuple[float, float] = (400, 4000),
    method: str = "average",
) -> Tuple[np.ndarray, np.ndarray]:
    """Read an image file and return the spectrum extracted from it.

    Convenience wrapper around ``SpectralImageProcessor`` for code that does
    not drive the Streamlit UI. Note that a load failure is also reported
    through ``st.error`` by ``SpectralImageProcessor.load_image`` before the
    exception below is raised.

    Args:
        image_path: Path of the image file to read.
        wavenumber_range: ``(start, end)`` of the synthetic wavenumber axis.
        method: Profile extraction method (``'average'``, ``'center_line'``
            or ``'max_intensity'``).

    Returns:
        ``(wavenumbers, intensities)`` arrays of equal length.

    Raises:
        ValueError: If the image cannot be loaded.
    """
    converter = SpectralImageProcessor()

    loaded = converter.load_image(image_path)
    if loaded is None:
        raise ValueError(f"Could not load image from {image_path}.")

    return converter.image_to_spectrum(
        loaded, wavenumber_range=wavenumber_range, method=method
    )