devjas1 committed on
Commit 9d0759c · 1 Parent(s): 2fb5cb5

(FEAT): Add batch processing utilities for multi-file uploads


- Implement `create_batch_uploader` to handle batch file uploads.
- Add `process_multiple_files` for processing multiple files in batch mode.
- Include `display_batch_results` to render batch processing results in the UI.
- Enhance error handling for batch operations with `safe_execute`.
- Improve user experience with streamlined batch file management and result visualization.
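
A rough sketch of how these utilities might be wired together in the host Streamlit app (`core_logic`, its `load_model` / `run_inference` / `label_file` hooks, and the `"model_a"` key are hypothetical placeholders; only the `utils.multifile` functions are part of this commit):

```python
import streamlit as st

from core_logic import load_model, run_inference, label_file  # hypothetical app hooks
from utils.multifile import (
    create_batch_uploader,
    display_batch_results,
    process_multiple_files,
)

uploaded = create_batch_uploader()  # multi-file uploader (key="batch_uploader")
if uploaded and st.button("Process batch"):
    bar = st.progress(0.0)

    def on_progress(current: int, total: int, name: str) -> None:
        # Called once per file and once more with (total, total, "Complete")
        bar.progress(current / total if total else 0.0,
                     text=f"{name} ({current}/{total})")

    results = process_multiple_files(
        uploaded,        # list returned by the uploader widget
        "model_a",       # placeholder model key
        load_model,
        run_inference,
        label_file,
        progress_callback=on_progress,
    )
    display_batch_results(results)
```

The `on_progress` signature mirrors the `(index, total, filename)` arguments that `process_multiple_files` passes to its `progress_callback`.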

Files changed (1)
  1. utils/multifile.py +93 -60
utils/multifile.py CHANGED
@@ -3,32 +3,33 @@ Handles multiple file uploads and iterative processing."""
 
 from typing import List, Dict, Any, Tuple, Optional
 import time
-import streamlit as st
-import numpy as np
+import streamlit as st
+import numpy as np
 
 from .preprocessing import resample_spectrum
 from .errors import ErrorHandler, safe_execute
 from .results_manager import ResultsManager
 from .confidence import calculate_softmax_confidence
 
+
 def parse_spectrum_data(text_content: str, filename: str = "unknown") -> Tuple[np.ndarray, np.ndarray]:
     """
     Parse spectrum data from text content
-
+
     Args:
         text_content: Raw text content of the spectrum file
        filename: Name of the file for error reporting
-
+
     Returns:
         Tuple of (x_values, y_values) as numpy arrays
-
+
     Raises:
         ValueError: If the data cannot be parsed
     """
     try:
         lines = text_content.strip().split('\n')
 
-        #==Remove empty lines and comments==
+        # ==Remove empty lines and comments==
         data_lines = []
         for line in lines:
             line = line.strip()
@@ -38,39 +39,52 @@ def parse_spectrum_data(text_content: str, filename: str = "unknown") -> Tuple[np.ndarray, np.ndarray]:
         if not data_lines:
             raise ValueError("No data lines found in file")
 
-        #==Try to parse==
+        # ==Try to parse==
         x_vals, y_vals = [], []
 
         for i, line in enumerate(data_lines):
             try:
-                #=Try comma separation first, then space=
-                if ',' in line:
-                    parts = line.split(',')
-                else:
-                    parts = line.split()
-
-                if len(parts) < 2:
-                    ErrorHandler.log_warning(f"Line {i+1} has fewer than 2 columns, skipping", f"Parsing {filename}")
-                    continue
-
-                x_val = float(parts[0].strip())
-                y_val = float(parts[1].split())
-
-                x_vals.append(x_val)
-                y_vals.append(y_val)
-
-            except (ValueError, IndexError) as e:
-                ErrorHandler.log_warning(f"Could not parse line {i+1}: {line}", f"Parsing {filename}")
-                continue
+                # Handle different separators
+                parts = line.replace(",", " ").split()
+                numbers = [p for p in parts if p.replace('.', '', 1).replace(
+                    '-', '', 1).replace('+', '', 1).isdigit()]
+                if len(numbers) >= 2:
+                    x_val = float(numbers[0])
+                    y_val = float(numbers[1])
+                    x_vals.append(x_val)
+                    y_vals.append(y_val)
+
+            except ValueError:
+                ErrorHandler.log_warning(
+                    f"Could not parse line {i+1}: {line}", f"Parsing {filename}")
+                continue
 
-        if len(x_vals) < 10: #==Need minimum points for interpolation==
-            raise ValueError(f"Insufficient data points ({len(x_vals)}). Need at least 10 points.")
+        if len(x_vals) < 10:  # ==Need minimum points for interpolation==
+            raise ValueError(
+                f"Insufficient data points ({len(x_vals)}). Need at least 10 points.")
 
-        return np.array(x_vals), np.array(y_vals)
-
+        x = np.array(x_vals)
+        y = np.array(y_vals)
+
+        # Check for NaNs
+        if np.any(np.isnan(x)) or np.any(np.isnan(y)):
+            raise ValueError("Input data contains NaN values")
+
+        # Check monotonic increasing x
+        if not np.all(np.diff(x) > 0):
+            raise ValueError("Wavenumbers must be strictly increasing")
+
+        # Check reasonable range for Raman spectroscopy
+        if min(x) < 0 or max(x) > 10000 or (max(x) - min(x)) < 100:
+            raise ValueError(
+                f"Invalid wavenumber range: {min(x)} - {max(x)}. Expected ~400-4000 cm⁻¹ with span >100")
+
+        return x, y
+
     except Exception as e:
         raise ValueError(f"Failed to parse spectrum data: {str(e)}")
 
+
 def process_single_file(
     filename: str,
     text_content: str,
@@ -81,7 +95,7 @@ def process_single_file(
 ) -> Optional[Dict[str, Any]]:
     """
     Process a single spectrum file
-
+
     Args:
         filename: Name of the file
         text_content: Raw text content
@@ -89,15 +103,15 @@ def process_single_file(
         load_model_func: Function to load the model
         run_inference_func: Function to run inference
         label_file_func: Function to extract ground truth label
-
+
     Returns:
         Dictionary with processing results or None if failed
     """
     start_time = time.time()
 
     try:
-        #==Parse spectrum data==
-        x_raw, y_raw, success = safe_execute(
+        # ==Parse spectrum data==
+        result, success = safe_execute(
             parse_spectrum_data,
             text_content,
             filename,
@@ -105,11 +119,13 @@ def process_single_file(
             show_error=False
         )
 
-        if not success:
+        if not success or result is None:
             return None
 
-        #==Resample spectrum==
-        x_resampled, y_resampled, success = safe_execute(
+        x_raw, y_raw = result
+
+        # ==Resample spectrum==
+        result, success = safe_execute(
             resample_spectrum,
             x_raw,
             y_raw,
@@ -118,11 +134,13 @@ def process_single_file(
             show_error=False
         )
 
-        if not success:
+        if not success or result is None:
             return None
 
-        #==Run inference==
-        prediction, logits_list, probs, inference_time, logits, success = safe_execute(
+        x_resampled, y_resampled = result
+
+        # ==Run inference==
+        result, success = safe_execute(
             run_inference_func,
             y_resampled,
             model_choice,
@@ -130,27 +148,31 @@ def process_single_file(
             show_error=False
         )
 
-        if not success or prediction is None:
-            ErrorHandler.log_error(Exception("Inference failed"), f"processing {filename}")
+        if not success or result is None:
+            ErrorHandler.log_error(
+                Exception("Inference failed"), f"processing {filename}")
             return None
 
-        #==Calculate confidence==
+        prediction, logits_list, probs, inference_time, logits = result
+
+        # ==Calculate confidence==
         if logits is not None:
-            probs_np, max_confidence, confidence_level, confidence_emoji = calculate_softmax_confidence(logits)
+            probs_np, max_confidence, confidence_level, confidence_emoji = calculate_softmax_confidence(
+                logits)
         else:
             probs_np = np.array([])
             max_confidence = 0.0
             confidence_level = "LOW"
             confidence_emoji = "🔴"
 
-        #==Get ground truth==
+        # ==Get ground truth==
         try:
             ground_truth = label_file_func(filename)
             ground_truth = ground_truth if ground_truth >= 0 else None
         except Exception:
             ground_truth = None
 
-        #==Get predicted class==
+        # ==Get predicted class==
         label_map = {0: "Stable (Unweathered)", 1: "Weathered (Degraded)"}
         predicted_class = label_map.get(prediction, f"Unknown ({prediction})")
 
@@ -183,6 +205,7 @@ def process_single_file(
         "processing_time": time.time() - start_time
     }
 
+
 def process_multiple_files(
     uploaded_files: List,
     model_choice: str,
@@ -193,7 +216,7 @@ def process_multiple_files(
 ) -> List[Dict[str, Any]]:
     """
     Process multiple uploaded files
-
+
     Args:
         uploaded_files: List of uploaded file objects
         model_choice: Selected model name
@@ -201,7 +224,7 @@ def process_multiple_files(
         run_inference_func: Function to run inference
         label_file_func: Function to extract ground truth label
         progress_callback: Optional callback to update progress
-
+
     Returns:
         List of processing results
     """
@@ -215,11 +238,12 @@ def process_multiple_files(
             progress_callback(i, total_files, uploaded_file.name)
 
         try:
-            #==Read file content==
+            # ==Read file content==
             raw = uploaded_file.read()
-            text_content = raw.decode('utf-8') if isinstance(raw, bytes) else raw
+            text_content = raw.decode(
+                'utf-8') if isinstance(raw, bytes) else raw
 
-            #==Process the file==
+            # ==Process the file==
             result = process_single_file(
                 uploaded_file.name,
                 text_content,
@@ -232,7 +256,7 @@ def process_multiple_files(
             if result:
                 results.append(result)
 
-                #==Add successful results to the results manager==
+                # ==Add successful results to the results manager==
                 if result.get("success", False):
                     ResultsManager.add_results(
                         filename=result["filename"],
@@ -260,14 +284,16 @@ def process_multiple_files(
     if progress_callback:
         progress_callback(total_files, total_files, "Complete")
 
-    ErrorHandler.log_info(f"Completed batch processing: {sum(1 for r in results if r.get('success', False))}/{total_files} successful")
+    ErrorHandler.log_info(
+        f"Completed batch processing: {sum(1 for r in results if r.get('success', False))}/{total_files} successful")
 
     return results
 
+
 def display_batch_results(results: List[Dict[str, Any]]) -> None:
     """
     Display batch processing results in the UI
-
+
     Args:
         results: List of processing results
     """
@@ -278,16 +304,18 @@ def display_batch_results(results: List[Dict[str, Any]]) -> None:
     successful = [r for r in results if r.get("success", False)]
     failed = [r for r in results if not r.get("success", False)]
 
-    #==Summary==
+    # ==Summary==
     col1, col2, col3 = st.columns(3)
     with col1:
         st.metric("Total Files", len(results))
     with col2:
-        st.metric("Successful", len(successful), delta=f"{len(successful)/len(results)*100:.1f}%")
+        st.metric("Successful", len(successful),
+                  delta=f"{len(successful)/len(results)*100:.1f}%")
     with col3:
-        st.metric("Failed", len(failed), delta=f"-{len(failed)/len(results)*100:.1f}%" if failed else "0%")
+        st.metric("Failed", len(
+            failed), delta=f"-{len(failed)/len(results)*100:.1f}%" if failed else "0%")
 
-    #==Results tabs==
+    # ==Results tabs==
     tab1, tab2 = st.tabs(["✅Successful", "❌ Failed"])
 
     with tab1:
@@ -296,12 +324,16 @@ def display_batch_results(results: List[Dict[str, Any]]) -> None:
             with st.expander(f"{result['filename']}", expanded=False):
                 col1, col2 = st.columns(2)
                 with col1:
-                    st.write(f"**Prediction:** {result['predicted_class']}")
-                    st.write(f"**Confidence:** {result['confidence_emoji']} {result['confidence_level']} ({result['confidence']:.3f})")
+                    st.write(
+                        f"**Prediction:** {result['predicted_class']}")
+                    st.write(
+                        f"**Confidence:** {result['confidence_emoji']} {result['confidence_level']} ({result['confidence']:.3f})")
                 with col2:
-                    st.write(f"**Processing Time:** {result['processing_time']:.3f}s")
+                    st.write(
+                        f"**Processing Time:** {result['processing_time']:.3f}s")
                     if result['ground_truth'] is not None:
-                        gt_label = {0: "Stable", 1: "Weathered"}.get(result['ground_truth'], "Unknown")
+                        gt_label = {0: "Stable", 1: "Weathered"}.get(
+                            result['ground_truth'], "Unknown")
                         correct = "✅" if result['prediction'] == result['ground_truth'] else "❌"
                         st.write(f"**Ground Truth:** {gt_label} {correct}")
                     else:
@@ -315,10 +347,11 @@ def display_batch_results(results: List[Dict[str, Any]]) -> None:
     else:
         st.success("No failed files!")
 
+
 def create_batch_uploader() -> List:
     """
     Create multi-file uploader widget
-
+
     Returns:
         List of uploaded files
     """
@@ -330,4 +363,4 @@ def create_batch_uploader() -> List:
         key="batch_uploader"
     )
 
-    return uploaded_files if uploaded_files else []
+    return uploaded_files if uploaded_files else []
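
`utils/errors.py` is not part of this diff, but the refactor standardizes every call site on unpacking a `(result, success)` pair from `safe_execute` instead of unpacking the wrapped function's tuple and the flag in one assignment. As a point of reference, here is a minimal sketch of the contract these call sites appear to assume; the `context` parameter, the logging fallback, and the `st.error` call are guesses, not the project's actual implementation:

```python
import logging
from typing import Any, Callable, Optional, Tuple

import streamlit as st

logger = logging.getLogger(__name__)  # stand-in; the real module logs via ErrorHandler


def safe_execute(func: Callable[..., Any], *args: Any,
                 context: str = "", show_error: bool = True,
                 **kwargs: Any) -> Tuple[Optional[Any], bool]:
    """Call func(*args, **kwargs) without letting exceptions escape.

    Returns (result, True) on success and (None, False) on failure,
    i.e. the pair the new call sites unpack as `result, success = ...`.
    """
    try:
        return func(*args, **kwargs), True
    except Exception as exc:  # broad on purpose: one bad file must not abort the batch
        logger.error("Error in %s: %s", context or func.__name__, exc)
        if show_error:
            st.error(f"Error in {context or func.__name__}: {exc}")
        return None, False
```

Returning a flag instead of raising is what lets `process_single_file` bail out stage by stage with `return None` while `process_multiple_files` keeps iterating over the remaining uploads.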