devjas1 commited on
Commit
0fec4be
·
1 Parent(s): a427341

(FEAT): Add utilities for multi-file batch processing and inference

Browse files

- Implemented `parse_spectrum_data` to parse spectrum data from text files, handling comments and malformed lines.
- Added `process_single_file` to handle the complete pipeline for a single file, including parsing, resampling, inference, confidence calculation, and ground truth extraction.
- Developed `process_multiple_files` to process multiple uploaded files iteratively, with support for progress tracking and error handling.
- Integrated `ResultsManager` to store successful inference results in session state.
- Added `display_batch_results` to present batch processing results in the Streamlit UI, including success and failure summaries.
- Created `create_batch_uploader` to provide a Streamlit widget for uploading multiple spectrum files.
- Enhanced error handling and logging using `ErrorHandler` for better debugging and user feedback.
- Ensured compatibility with custom model loading, inference, and labeling functions for flexibility.

Files changed (1) hide show
  1. utils/multifile.py +333 -0
utils/multifile.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Multi-file processing utilities for batch inference.
2
+ Handles multiple file uploads and iterative processing."""
3
+
4
+ from typing import List, Dict, Any, Tuple, Optional
5
+ import time
6
+ import streamlit as st
7
+ import numpy as np
8
+
9
+ from .preprocessing import resample_spectrum
10
+ from .errors import ErrorHandler, safe_execute
11
+ from .results_manager import ResultsManager
12
+ from .confidence import calculate_softmax_confidence
13
+
14
def parse_spectrum_data(text_content: str, filename: str = "unknown") -> Tuple[np.ndarray, np.ndarray]:
    """
    Parse spectrum data from text content.

    Accepts two-column data (wavenumber, intensity) separated by commas
    or whitespace. Lines starting with '#' or '%' are treated as comments.

    Args:
        text_content: Raw text content of the spectrum file
        filename: Name of the file for error reporting

    Returns:
        Tuple of (x_values, y_values) as numpy arrays

    Raises:
        ValueError: If the data cannot be parsed
    """
    try:
        lines = text_content.strip().split('\n')

        #==Remove empty lines and comments==
        data_lines = []
        for line in lines:
            line = line.strip()
            # '#' and '%' are both common comment markers in spectrum exports
            if line and not line.startswith('#') and not line.startswith('%'):
                data_lines.append(line)

        if not data_lines:
            raise ValueError("No data lines found in file")

        #==Try to parse==
        x_vals, y_vals = [], []

        for i, line in enumerate(data_lines):
            try:
                #=Try comma separation first, then whitespace=
                if ',' in line:
                    parts = line.split(',')
                else:
                    parts = line.split()

                if len(parts) < 2:
                    ErrorHandler.log_warning(
                        f"Line {i+1} has fewer than 2 columns, skipping",
                        f"Parsing {filename}"
                    )
                    continue

                x_val = float(parts[0].strip())
                # BUG FIX: was float(parts[1].split()) — that passes a *list*
                # to float(), raising TypeError on every line (and TypeError
                # was not caught below, so parsing always failed).
                y_val = float(parts[1].strip())

                x_vals.append(x_val)
                y_vals.append(y_val)

            except (ValueError, IndexError):
                ErrorHandler.log_warning(
                    f"Could not parse line {i+1}: {line}",
                    f"Parsing {filename}"
                )
                continue

        if len(x_vals) < 10:  #==Need minimum points for interpolation==
            raise ValueError(f"Insufficient data points ({len(x_vals)}). Need at least 10 points.")

        return np.array(x_vals), np.array(y_vals)

    except Exception as e:
        # Normalize every failure mode to ValueError so callers handle one type
        raise ValueError(f"Failed to parse spectrum data: {str(e)}") from e
73
+
74
def process_single_file(
    filename: str,
    text_content: str,
    model_choice: str,
    load_model_func,
    run_inference_func,
    label_file_func
) -> Optional[Dict[str, Any]]:
    """
    Process a single spectrum file through the complete pipeline:
    parse -> resample -> inference -> confidence -> ground truth.

    Args:
        filename: Name of the file
        text_content: Raw text content
        model_choice: Selected model name
        load_model_func: Function to load the model (currently unused here;
            kept for API compatibility — model loading is presumably handled
            inside run_inference_func. TODO confirm with callers.)
        run_inference_func: Function to run inference
        label_file_func: Function to extract ground truth label

    Returns:
        Dictionary with processing results, or None if parsing, resampling,
        or inference failed (failures are logged via ErrorHandler).
    """
    start_time = time.time()

    try:
        #==Parse spectrum data==
        # BUG FIX: error contexts were placeholder-less f-strings like
        # f"parsing (unknown)" — the filename was lost from every log entry.
        x_raw, y_raw, success = safe_execute(
            parse_spectrum_data,
            text_content,
            filename,
            error_context=f"parsing {filename}",
            show_error=False
        )

        if not success:
            return None

        #==Resample spectrum onto the fixed-length grid the model expects==
        x_resampled, y_resampled, success = safe_execute(
            resample_spectrum,
            x_raw,
            y_raw,
            500,  # TARGET_LEN
            error_context=f"resampling {filename}",
            show_error=False
        )

        if not success:
            return None

        #==Run inference==
        prediction, logits_list, probs, inference_time, logits, success = safe_execute(
            run_inference_func,
            y_resampled,
            model_choice,
            error_context=f"inference on {filename}",
            show_error=False
        )

        if not success or prediction is None:
            ErrorHandler.log_error(Exception("Inference failed"), f"processing {filename}")
            return None

        #==Calculate confidence from raw logits when available==
        if logits is not None:
            probs_np, max_confidence, confidence_level, confidence_emoji = calculate_softmax_confidence(logits)
        else:
            # No logits: report a conservative zero-confidence result
            probs_np = np.array([])
            max_confidence = 0.0
            confidence_level = "LOW"
            confidence_emoji = "🔴"

        #==Get ground truth (a negative label means "unknown")==
        try:
            ground_truth = label_file_func(filename)
            ground_truth = ground_truth if ground_truth >= 0 else None
        except Exception:
            ground_truth = None

        #==Map the numeric prediction to a human-readable class name==
        label_map = {0: "Stable (Unweathered)", 1: "Weathered (Degraded)"}
        predicted_class = label_map.get(prediction, f"Unknown ({prediction})")

        processing_time = time.time() - start_time

        return {
            "filename": filename,
            "success": True,
            "prediction": prediction,
            "predicted_class": predicted_class,
            "confidence": max_confidence,
            "confidence_level": confidence_level,
            "confidence_emoji": confidence_emoji,
            "logits": logits_list if logits_list else [],
            "probabilities": probs_np.tolist() if len(probs_np) > 0 else [],
            "ground_truth": ground_truth,
            "processing_time": processing_time,
            "x_raw": x_raw,
            "y_raw": y_raw,
            "x_resampled": x_resampled,
            "y_resampled": y_resampled,
        }

    except Exception as e:
        ErrorHandler.log_error(e, f"processing {filename}")
        return {
            "filename": filename,
            "success": False,
            "error": str(e),
            "processing_time": time.time() - start_time
        }
185
+
186
def process_multiple_files(
    uploaded_files: List,
    model_choice: str,
    load_model_func,
    run_inference_func,
    label_file_func,
    progress_callback=None
) -> List[Dict[str, Any]]:
    """
    Process multiple uploaded files iteratively.

    Successful results are additionally stored via ResultsManager for
    session-wide access.

    Args:
        uploaded_files: List of uploaded file objects
        model_choice: Selected model name
        load_model_func: Function to load the model
        run_inference_func: Function to run inference
        label_file_func: Function to extract ground truth label
        progress_callback: Optional callback(index, total, filename)
            invoked before each file and once more on completion

    Returns:
        List of processing results (one entry per uploaded file)
    """
    results = []
    total_files = len(uploaded_files)

    ErrorHandler.log_info(f"Starting batch processing of {total_files} files")

    for i, uploaded_file in enumerate(uploaded_files):
        if progress_callback:
            progress_callback(i, total_files, uploaded_file.name)

        try:
            #==Read file content (uploaders may yield bytes or str)==
            raw = uploaded_file.read()
            text_content = raw.decode('utf-8') if isinstance(raw, bytes) else raw

            #==Process the file==
            result = process_single_file(
                uploaded_file.name,
                text_content,
                model_choice,
                load_model_func,
                run_inference_func,
                label_file_func
            )

            if result:
                results.append(result)

                #==Add successful results to the results manager==
                if result.get("success", False):
                    ResultsManager.add_results(
                        filename=result["filename"],
                        model_name=model_choice,
                        prediction=result["prediction"],
                        predicted_class=result["predicted_class"],
                        confidence=result["confidence"],
                        logits=result["logits"],
                        ground_truth=result["ground_truth"],
                        processing_time=result["processing_time"],
                        metadata={
                            "confidence_level": result["confidence_level"],
                            "confidence_emoji": result["confidence_emoji"]
                        }
                    )
            else:
                # BUG FIX: a None result (parse/resample/inference failure)
                # was previously dropped, so the file disappeared from the
                # results list and the summary counts undercounted failures.
                results.append({
                    "filename": uploaded_file.name,
                    "success": False,
                    "error": "Processing failed (see logs for details)"
                })

        except Exception as e:
            ErrorHandler.log_error(e, f"reading file {uploaded_file.name}")
            results.append({
                "filename": uploaded_file.name,
                "success": False,
                "error": f"Failed to read file: {str(e)}"
            })

    if progress_callback:
        progress_callback(total_files, total_files, "Complete")

    ErrorHandler.log_info(f"Completed batch processing: {sum(1 for r in results if r.get('success', False))}/{total_files} successful")

    return results
266
+
267
def display_batch_results(results: List[Dict[str, Any]]) -> None:
    """
    Render batch processing results in the Streamlit UI.

    Shows three summary metrics (total / successful / failed) followed by
    per-file details split across a "successful" tab and a "failed" tab.

    Args:
        results: List of processing results
    """
    if not results:
        st.warning("No results to display")
        return

    ok = [entry for entry in results if entry.get("success", False)]
    bad = [entry for entry in results if not entry.get("success", False)]
    total = len(results)

    #==Summary metrics row==
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Total Files", total)
    with col2:
        st.metric("Successful", len(ok), delta=f"{len(ok)/total*100:.1f}%")
    with col3:
        st.metric("Failed", len(bad), delta=f"-{len(bad)/total*100:.1f}%" if bad else "0%")

    #==Per-file detail tabs==
    ok_tab, bad_tab = st.tabs(["✅Successful", "❌ Failed"])

    with ok_tab:
        if not ok:
            st.info("No successful results")
        else:
            for entry in ok:
                with st.expander(f"{entry['filename']}", expanded=False):
                    left, right = st.columns(2)
                    with left:
                        st.write(f"**Prediction:** {entry['predicted_class']}")
                        st.write(f"**Confidence:** {entry['confidence_emoji']} {entry['confidence_level']} ({entry['confidence']:.3f})")
                    with right:
                        st.write(f"**Processing Time:** {entry['processing_time']:.3f}s")
                        if entry['ground_truth'] is not None:
                            gt_label = {0: "Stable", 1: "Weathered"}.get(entry['ground_truth'], "Unknown")
                            marker = "✅" if entry['prediction'] == entry['ground_truth'] else "❌"
                            st.write(f"**Ground Truth:** {gt_label} {marker}")

    with bad_tab:
        if not bad:
            st.success("No failed files!")
        else:
            for entry in bad:
                with st.expander(f"❌ {entry['filename']}", expanded=False):
                    st.error(f"Error: {entry.get('error', 'Unknown error')}")
317
+
318
def create_batch_uploader() -> List:
    """
    Render the multi-file uploader widget for Raman spectrum files.

    Returns:
        List of uploaded file objects; an empty list when nothing is selected.
    """
    selection = st.file_uploader(
        "Upload multiple Raman spectrum files (.txt)",
        type="txt",
        accept_multiple_files=True,
        help="Select multiple .txt files with wavenumber and intensity columns",
        key="batch_uploader"
    )
    # file_uploader returns None/empty until files are chosen; normalize to []
    return selection or []