devjas1 committed
Commit: 6373c5a
Parent: 8013c07
(SYNC): bring parity backend (utils/ scripts/ models/ tests/) from feat/ui-parity-rebuild; no UI changes
Files changed:
- models/registry.py +12 -1
- scripts/preprocess_dataset.py +59 -91
- scripts/run_inference.py +142 -123
- tests/conftest.py +8 -0
- tests/test_preprocessing.py +17 -0
- utils/__init__.py +0 -4
- utils/audit.py +56 -0
- utils/preprocessing.py +75 -98
models/registry.py
CHANGED
@@ -21,4 +21,15 @@ def build(name: str, input_length: int):
         raise ValueError(f"Unknown model '{name}'. Choices: {choices()}")
     return _REGISTRY[name](input_length)
 
-
+def spec(name: str):
+    """Return expected input length and number of classes for a model key."""
+    if name == "figure2":
+        return {"input_length": 500, "num_classes": 2}
+    if name == "resnet":
+        return {"input_length": 500, "num_classes": 2}
+    if name == "resnet18vision":
+        return {"input_length": 500, "num_classes": 2}
+    raise KeyError(f"Unknown model '{name}'")
+
+
+__all__ = ["choices", "build"]
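Note that the re-added __all__ still lists only choices and build, so the new spec() helper must be imported explicitly. A minimal usage sketch (keys and shapes taken from the diff above):

    from models.registry import build, spec

    meta = spec("figure2")                           # {"input_length": 500, "num_classes": 2}
    model = build("figure2", meta["input_length"])   # construct the registered architecture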
scripts/preprocess_dataset.py
CHANGED
@@ -1,86 +1,42 @@
-"""
-
-    preprocess_dataset(...): Loads, resamples, and applies optional preprocessing steps:
-    - baseline correction
-    - Savitzky-Golay smoothing
-    - min-max normalization
-
-    The script expects the dataset directory to contain text files representing spectra.
-    Each file is:
-    1. Listed using `list_txt_files()`
-    2. Labeled using `label_file()`
-    3. Loaded using `load_spectrum()`
-    4. Resampled and optionally cleaned
-    5. Returned as arrays suitable for ML training
-
-    Dependencies:
-    - numpy
-    - scipy.interpolate, scipy.signal
-    - sklearn.preprocessing
-    - list_spectra (custom)
-    - plot_spectrum (custom)
+"""preprocess_dataset.py
+
+Canonical Raman preprocessing for dataset splits.
+Uses the single source of truth in utils.preprocessing:
+resample → baseline (deg=2) → smooth (w=11,o=2) → normalize.
 """
 
 import os
 import sys
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 import numpy as np
-from scipy.interpolate import interp1d
-from scipy.signal import savgol_filter
-from sklearn.preprocessing import minmax_scale
-from scripts.discover_raman_files import list_txt_files, label_file
-from scripts.plot_spectrum import load_spectrum
-
-# Default resample target
-TARGET_LENGTH = 500
-
-# Optional preprocessing steps
-def remove_baseline(y):
-    """Simple baseline correction using polynomial fitting (order 2)"""
-    x = np.arange(len(y))
-    coeffs = np.polyfit(x, y, deg=2)
-    baseline = np.polyval(coeffs, x)
-    return y - baseline
-
-def normalize_spectrum(y):
-    """Min-max normalization to [0, 1]"""
-    return minmax_scale(y)
 
-
-    f_interp = interp1d(x, y, kind='linear', fill_value='extrapolate')
-    x_uniform = np.linspace(min(x), max(x), target_len)
-    y_uniform = f_interp(x_uniform)
-    return y_uniform
+from utils.preprocessing import (
+    TARGET_LENGTH,
+    preprocess_spectrum
+)
 
+from scripts.discover_raman_files import list_txt_files, label_file
+from scripts.plot_spectrum import load_spectrum
 
 def preprocess_dataset(
-    dataset_dir,
-    target_len=
-    baseline_correction=
-    apply_smoothing=
-    normalize=
+    dataset_dir: str,
+    target_len: int = TARGET_LENGTH,
+    baseline_correction: bool = True,
+    apply_smoothing: bool = True,
+    normalize: bool = True,
+    out_dtype: str = "float32",
 ):
     """
-    Load,
-
-    Returns:
-        X (np.ndarray): Preprocessed spectra
-        y (np.ndarray): Corresponding labels
+    Load, preprocess, and label Raman spectra in dataset_dir.
+
+    Returns
+    -------
+    X : np.ndarray, shape (N, target_len), dtype=out_dtype
+        Preprocessed spectra (resampled and transformed).
+    y : np.ndarray, shape (N,), dtype=int64
+        Integer labels (e.g., 0 = stable, 1 = weathered).
     """
+
     txt_paths = list_txt_files(dataset_dir)
     X, y_labels = [], []
 
@@ -93,29 +49,41 @@ def preprocess_dataset(
         if len(x_raw) < 10:
             continue  # Skip files with too few points
 
-        # Resample
-        y_processed = resample_spectrum(x_raw, y_raw, target_len=target_len)
-
-        # Optional preprocessing
-        if baseline_correction:
-            y_processed = remove_baseline(y_processed)
-        if apply_smoothing:
-            y_processed = smooth_spectrum(y_processed)
-        if normalize:
-            y_processed = normalize_spectrum(y_processed)
+        # === Single-source-of-truth path ===
+        _, y_processed = preprocess_spectrum(
+            np.asarray(x_raw),
+            np.asarray(y_raw),
+            target_len=target_len,
+            do_baseline=baseline_correction,
+            do_smooth=apply_smoothing,
+            do_normalize=normalize,
+            out_dtype=out_dtype  # str is OK (DTypeLike),
+        )
 
+        # === Collect ===
         X.append(y_processed)
-        y_labels.append(label)
+        y_labels.append(int(label))
+
+    if not X:
+        # === No valid samples ===
+        return np.empty((0, target_len), dtype=out_dtype), np.empty((0,), dtype=np.int64)
 
-    return
+    X_arr = np.asarray(X, dtype=np.dtype(out_dtype))
+    Y_arr = np.asarray(y_labels, dtype=np.int64)
 
-# Optional: Run directly for
+    return X_arr, Y_arr
 
+# === Optional: Run directly for quick smoke test ===
 if __name__ == "__main__":
-    X
+    test_dataset_dir = os.path.join("datasets", "rdwp")
+    X, y = preprocess_dataset(test_dataset_dir)
+
+    print(f"X shape: {X.shape} dtype={X.dtype}")
+    print(f"y shape: {y.shape} dtype={y.dtype}")
+    if y.size:
+        try:
+            counts = np.bincount(y, minlength=2)
+            print(f"Label distribution: {counts} (stable, weathered)")
+        except Exception as e:
+            print(f"Could not compute label distribution {e}")
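For reference, a minimal call into the rebuilt entry point, assuming the repo root is on sys.path; the datasets/rdwp path is the one the smoke test above uses, and the shapes follow the docstring:

    from scripts.preprocess_dataset import preprocess_dataset

    X, y = preprocess_dataset("datasets/rdwp")   # all steps on by default
    # X: (N, 500) float32 preprocessed spectra; y: (N,) int64 labels
    # Empty (0, 500) / (0,) arrays come back when no file passes the <10-point check.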
scripts/run_inference.py
CHANGED
@@ -1,142 +1,161 @@
-
+# scripts/run_inference.py
+"""
+CLI inference with preprocessing parity.
+Applies: resample → baseline (deg=2) → smooth (w=11,o=2) → normalize
+unless explicitly disabled via flags.
+
+Usage (examples):
+  python scripts/run_inference.py \
+    --input datasets/rdwp/sta-1.txt \
+    --arch figure2 \
+    --weights outputs/figure2_model.pth \
+    --target-len 500
+
+  # Disable smoothing only:
+  python scripts/run_inference.py --input ... --arch resnet --weights ... --disable-smooth
+"""
+
 import os
+import sys
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
-from pathlib import Path
 
 import argparse
-import
+import json
 import logging
+from pathlib import Path
+from typing import cast
+from torch import nn
 
 import numpy as np
 import torch
+import torch.nn.functional as F
 
-from
-from
+from models.registry import build, choices
+from utils.preprocessing import preprocess_spectrum, TARGET_LENGTH
+from scripts.plot_spectrum import load_spectrum
+from scripts.discover_raman_files import label_file
 
-#
-# See: @raman-pipeline-focus-milestone
-# =============================================
-
-warnings.filterwarnings(
-    "ignore",
-    message=".*weights_only=False.*",
-    category=FutureWarning
-)
-
-# configure logging
-level = logging.INFO
-if args.verbose:
-    level = logging.DEBUG
-elif args.quiet:
-    level = logging.WARNING
-logging.basicConfig(level=level, format="%(levelname)s: %(message)s")
-
-if os.path.isdir(args.input):
-    parser.error(f"Input must be a single Raman .txt file, got a directory: {args.input}")
-
-x_raw, y_raw = load_raman_spectrum(args.input)
-if len(x_raw) < 10:
-    parser.error("Spectrum too short for inference.")
-
-data = resample_spectrum(x_raw, y_raw, target_len=args.target_len)
-# Shape = (1, 1, target_len) — valid input for Raman inference
-input_tensor = torch.tensor(data, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(DEVICE)
-
-# 2. Load Model (via shared model registry)
-model = build_model(args.arch, args.target_len).to(DEVICE)
-if args.model != "random":
-    state = torch.load(args.model, map_location="cpu")  # broad compatibility
-    model.load_state_dict(state)
-model.eval()
-
-# 3. Inference
-with torch.no_grad():
-    logits = model(input_tensor)
-    pred = torch.argmax(logits, dim=1).item()
-
-# 4. True Label
-try:
-    true_label = label_file(args.input)
-    label_str = f"True Label: {true_label}"
-except FileNotFoundError:
-    label_str = "True Label: Unknown"
-
-result = f"Predicted Label: {pred} {label_str}\nRaw Logits: {logits.tolist()}"
-logging.info(result)
-
-# 5. Save or stdout
-if args.output:
-    # ensure parent dir exists (e.g., outputs/inference/)
-    Path(args.output).parent.mkdir(parents=True, exist_ok=True)
-    with open(args.output, "w", encoding="utf-8") as fout:
-        fout.write(result)
-    logging.info("Result saved to %s", args.output)
-
-sys.exit(0)
-
-except Exception as e:
-    logging.error(e)
-    sys.exit(1)
+def parse_args():
+    p = argparse.ArgumentParser(description="Raman spectrum inference (parity with CLI preprocessing).")
+    p.add_argument("--input", required=True, help="Path to a single Raman .txt file (2 columns: x, y).")
+    p.add_argument("--arch", required=True, choices=choices(), help="Model architecture key.")
+    p.add_argument("--weights", required=True, help="Path to model weights (.pth).")
+    p.add_argument("--target-len", type=int, default=TARGET_LENGTH, help="Resample length (default: 500).")
+
+    # Default = ON; use disable- flags to turn steps off explicitly.
+    p.add_argument("--disable-baseline", action="store_true", help="Disable baseline correction.")
+    p.add_argument("--disable-smooth", action="store_true", help="Disable Savitzky–Golay smoothing.")
+    p.add_argument("--disable-normalize", action="store_true", help="Disable min-max normalization.")
+
+    p.add_argument("--output", default=None, help="Optional output JSON path (defaults to outputs/inference/<name>.json).")
+    p.add_argument("--device", default="cpu", choices=["cpu", "cuda"], help="Compute device (default: cpu).")
+    return p.parse_args()
+
+
+def _load_state_dict_safe(path: str):
+    """Load a state dict safely across torch versions & checkpoint formats."""
+    try:
+        obj = torch.load(path, map_location="cpu", weights_only=True)  # newer torch
+    except TypeError:
+        obj = torch.load(path, map_location="cpu")  # fallback for older torch
+
+    # Accept either a plain state_dict or a checkpoint dict that contains one
+    if isinstance(obj, dict):
+        for k in ("state_dict", "model_state_dict", "model"):
+            if k in obj and isinstance(obj[k], dict):
+                obj = obj[k]
+                break
+
+    if not isinstance(obj, dict):
+        raise ValueError(
+            "Loaded object is not a state_dict or checkpoint with a state_dict. "
+            f"Type={type(obj)} from file={path}"
+        )
+
+    # Strip DataParallel 'module.' prefixes if present
+    if any(key.startswith("module.") for key in obj.keys()):
+        obj = {key.replace("module.", "", 1): val for key, val in obj.items()}
+
+    return obj
+
+
+def main():
+    logging.basicConfig(level=logging.INFO, format="INFO: %(message)s")
+    args = parse_args()
+
+    in_path = Path(args.input)
+    if not in_path.exists():
+        raise FileNotFoundError(f"Input file not found: {in_path}")
+
+    # --- Load raw spectrum
+    x_raw, y_raw = load_spectrum(str(in_path))
+    if len(x_raw) < 10:
+        raise ValueError("Input spectrum has too few points (<10).")
+
+    # --- Preprocess (single source of truth)
+    _, y_proc = preprocess_spectrum(
+        np.array(x_raw),
+        np.array(y_raw),
+        target_len=args.target_len,
+        do_baseline=not args.disable_baseline,
+        do_smooth=not args.disable_smooth,
+        do_normalize=not args.disable_normalize,
+        out_dtype="float32",
+    )
+
+    # --- Build model & load weights (safe)
+    device = torch.device(args.device if (args.device == "cuda" and torch.cuda.is_available()) else "cpu")
+    model = cast(nn.Module, build(args.arch, args.target_len)).to(device)
+    state = _load_state_dict_safe(args.weights)
+    missing, unexpected = model.load_state_dict(state, strict=False)
+    if missing or unexpected:
+        logging.info("Loaded with non-strict keys. missing=%d unexpected=%d", len(missing), len(unexpected))
+
+    model.eval()
+
+    # Shape: (B, C, L) = (1, 1, target_len)
+    x_tensor = torch.from_numpy(y_proc[None, None, :]).to(device)
+
+    with torch.no_grad():
+        logits = model(x_tensor).float().cpu()  # shape (1, num_classes)
+        probs = F.softmax(logits, dim=1)
+
+    probs_np = probs.numpy().ravel().tolist()
+    logits_np = logits.numpy().ravel().tolist()
+    pred_label = int(np.argmax(probs_np))
+
+    # Optional ground-truth from filename (if encoded)
+    true_label = label_file(str(in_path))
+
+    # --- Prepare output
+    out_dir = Path("outputs") / "inference"
+    out_dir.mkdir(parents=True, exist_ok=True)
+    out_path = Path(args.output) if args.output else (out_dir / f"{in_path.stem}_{args.arch}.json")
+
+    result = {
+        "input_file": str(in_path),
+        "arch": args.arch,
+        "weights": str(args.weights),
+        "target_len": args.target_len,
+        "preprocessing": {
+            "baseline": not args.disable_baseline,
+            "smooth": not args.disable_smooth,
+            "normalize": not args.disable_normalize,
+        },
+        "predicted_label": pred_label,
+        "true_label": true_label,
+        "probs": probs_np,
+        "logits": logits_np,
+    }
+
+    with open(out_path, "w", encoding="utf-8") as f:
+        json.dump(result, f, indent=2)
+
+    logging.info("Predicted Label: %d True Label: %s", pred_label, true_label)
+    logging.info("Raw Logits: %s", logits_np)
+    logging.info("Result saved to %s", out_path)
+
+
+if __name__ == "__main__":
+    main()
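The JSON file written by main() follows the result dict in the diff; the example below is illustrative only (field values are made up):

    {
      "input_file": "datasets/rdwp/sta-1.txt",
      "arch": "figure2",
      "weights": "outputs/figure2_model.pth",
      "target_len": 500,
      "preprocessing": {"baseline": true, "smooth": true, "normalize": true},
      "predicted_label": 0,
      "true_label": 0,
      "probs": [0.93, 0.07],
      "logits": [2.41, -0.18]
    }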
tests/conftest.py
ADDED
@@ -0,0 +1,8 @@
+# tests/conftest.py
+import sys
+from pathlib import Path
+
+# Add repo root to sys.path so "utils", "models", "scripts" are importable in tests
+ROOT = Path(__file__).resolve().parents[1]
+if str(ROOT) not in sys.path:
+    sys.path.insert(0, str(ROOT))
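With the repo root bootstrapped onto sys.path, the new tests run from a fresh checkout without packaging or PYTHONPATH tweaks, e.g. (assuming pytest is installed):

    python -m pytest tests/ -q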
tests/test_preprocessing.py
ADDED
@@ -0,0 +1,17 @@
+import numpy as np
+from utils.preprocessing import preprocess_spectrum, TARGET_LENGTH
+
+def test_shapes_and_monotonicity():
+    x = np.linspace(100, 200, 300)
+    y = np.sin(x/10.0) + 0.01*(x - 100)
+    x2, y2 = preprocess_spectrum(x, y, target_len=TARGET_LENGTH)
+    assert x2.shape == (TARGET_LENGTH,)
+    assert y2.shape == (TARGET_LENGTH,)
+    assert np.all(np.diff(x2) > 0)
+
+def test_idempotency():
+    x = np.linspace(0, 100, 400)
+    y = np.cos(x/7.0) + 0.002*x
+    _, y1 = preprocess_spectrum(x, y, target_len=TARGET_LENGTH)
+    _, y2 = preprocess_spectrum(np.linspace(x.min(), x.max(), TARGET_LENGTH), y1, target_len=TARGET_LENGTH)
+    np.testing.assert_allclose(y1, y2, rtol=1e-6, atol=1e-7)
utils/__init__.py
CHANGED
@@ -1,4 +0,0 @@
-"""Utility functions for the polymer classification app"""
-from .preprocessing import resample_spectrum
-
-__all__ = ['resample_spectrum']
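Since the package initializer is now empty, utils no longer re-exports resample_spectrum at package level; callers must target the module directly (and note from the utils/preprocessing.py diff below that the new resample_spectrum returns an (x_new, y_new) tuple rather than y alone):

    from utils.preprocessing import resample_spectrum, preprocess_spectrum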
utils/audit.py
ADDED
@@ -0,0 +1,56 @@
+#!/usr/bin/env python3
+"""
+audit.py - quick audit tool for preprocessing baseline
+
+Searches for relevant keywords in the ml-polymer-recycling repo
+to confirm what preprocessing steps (resample, baseline, smooth,
+normalize, etc.) are actually implemented in code/docs.
+"""
+
+import re
+from pathlib import Path
+
+# ||== KEYWORDS TO TRACE ==||
+KEYWORDS = [
+    "resample", "baseline", "smooth", "Savitz",
+    "normalize", "minmax" "TARGET_LENGTH", "WINDOW_LENGTH",
+    "POLYORDER", "DEGREE", "input_length", "target_len", "Figure2CNN", "ResNet"
+]
+
+# ||==== DIRECTORIES/FILES TO SCAN ====||
+TARGETS = [
+    "scripts/preprocess_dataset.py",
+    "scripts/run_inferece.py",
+    "models/",
+    "utils/",
+    "README.md",
+    "GROUND_TRUTH_PIPELINE.md",
+    "docs/"
+]
+
+# ||==== COMPILE REGEX FOR KEYWORDS ====||
+pattern = re.compile("|".join(KEYWORDS), re.IGNORECASE)
+
+def scan_file(path: Path):
+    try:
+        with path.open(encoding="utf-8", errors="ignore") as f:
+            for i, line in enumerate(f, 1):
+                if pattern.search(line):
+                    print(f"{path}:{i}: {line.strip()}")
+    except Exception as e:
+        print(f"[ERR] Could not read {path}: {e}")
+
+def main():
+    root = Path(".").resolve()
+    for target in TARGETS:
+        p = root / target
+        if p.is_file():
+            scan_file(p)
+        elif p.is_dir():
+            for sub in p.rglob("*.py"):
+                scan_file(sub)
+            for sub in p.rglob("*.md"):
+                scan_file(sub)
+
+if __name__ == "__main__":
+    main()
utils/preprocessing.py
CHANGED
@@ -3,107 +3,84 @@ Preprocessing utilities for polymer classification app.
 Adapted from the original scripts/preprocess_dataset.py for Hugging Face Spaces deployment.
 """
 
+from __future__ import annotations
 import numpy as np
+from numpy.typing import DTypeLike
 from scipy.interpolate import interp1d
 from scipy.signal import savgol_filter
-from
-
-# Default resample target
-TARGET_LENGTH = 500
+from scipy.interpolate import interp1d
 
+TARGET_LENGTH = 500  # Frozen default per PREPROCESSING_BASELINE
+
+def __ensure_1d_equal(x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+    x = np.asarray(x, dtype=float)
+    y = np.asarray(x, dtype=float)
+    if x.ndim != 1 or y.ndim != 1 or x.size != y.size or x.size < 2:
+        raise ValueError("x and y must be 1D arrays of equal length >= 2")
+    return x, y
+
+def resample_spectrum(x: np.ndarray, y: np.ndarray, target_len: int = TARGET_LENGTH) -> tuple[np.ndarray, np.ndarray]:
+    """Linear re-sampling onto a uniform grid of length target_len."""
+    x, y = __ensure_1d_equal(x, y)
+    order = np.argsort(x)
+    x_sorted, y_sorted = x[order], y[order]
+    x_new = np.linspace(x_sorted[0], x_sorted[-1], int(target_len))
+    f = interp1d(x_sorted, y_sorted, kind="linear", assume_sorted=True)
+    y_new = f(x_new)
+    return x_new, y_new
+
+def remove_baseline(y: np.ndarray, degree: int = 2) -> np.ndarray:
+    """Polynomial baseline subtraction (degree=2 default)"""
+    y = np.asarray(y, dtype=float)
+    x_idx = np.arange(y.size, dtype=float)
+    coeffs = np.polyfit(x_idx, y, deg=int(degree))
+    baseline = np.polyval(coeffs, x_idx)
     return y - baseline
 
-def
-    """
-    if
-        window_length
-            y_unique[i] = np.mean(y_sorted[mask])
-    x_sorted, y_sorted = x_unique, y_unique
-
-    # Create interpolation function
-    f_interp = interp1d(x_sorted, y_sorted, kind='linear', bounds_error=False, fill_value=np.nan)
-
-    # Generate uniform grid
-    x_uniform = np.linspace(min(x_sorted), max(x_sorted), target_len)
-    y_uniform = f_interp(x_uniform)
-
-    return y_uniform
-
-def preprocess_spectrum(x, y, target_len=500, baseline_correction=False,
-                        apply_smoothing=False, normalize=False):
-    """
-    Complete preprocessing pipeline for a single spectrum.
-
-    Args:
-        x (array-like): Wavenumber values
-        y (array-like): Intensity values
-        target_len (int): Number of points to resample to
-        baseline_correction (bool): Whether to apply baseline removal
-        apply_smoothing (bool): Whether to apply Savitzky-Golay smoothing
-        normalize (bool): Whether to apply min-max normalization
-
-    Returns:
-        np.ndarray: Preprocessed spectrum
-    """
-    # Resample first
-    y_processed = resample_spectrum(x, y, target_len=target_len)
-
-    # Optional preprocessing steps
-    if baseline_correction:
-        y_processed = remove_baseline(y_processed)
-
-    if apply_smoothing:
-        y_processed = smooth_spectrum(y_processed)
-
-    if normalize:
-        y_processed = normalize_spectrum(y_processed)
-
-    return y_processed
+def smooth_spectrum(y: np.ndarray, window_length: int = 11, polyorder: int = 2) -> np.ndarray:
+    """Savitzky-Golay smoothing with safe/odd window enforcement"""
+    y = np.asarray(y, dtype=float)
+    window_length = int(window_length)
+    polyorder = int(polyorder)
+    # === window must be odd and >= polyorder+1 ===
+    if window_length % 2 == 0:
+        window_length += 1
+    min_win = polyorder + 1
+    if min_win % 2 == 0:
+        min_win += 1
+    window_length = max(window_length, min_win)
+    return savgol_filter(y, window_length=window_length, polyorder=polyorder, mode="interp")
+
+def normalize_spectrum(y: np.ndarray) -> np.ndarray:
+    """Min-max normalization to [0, 1] with constant-signal guard."""
+    y = np.asarray(y, dtype=float)
+    y_min = float(np.min(y))
+    y_max = float(np.max(y))
+    if np.isclose(y_max - y_min, 0.0):
+        return np.zeros_like(y)
+    return (y - y_min) / (y_max - y_min)
+
+def preprocess_spectrum(
+    x: np.ndarray,
+    y: np.ndarray,
+    *,
+    target_len: int = TARGET_LENGTH,
+    do_baseline: bool = True,
+    degree: int = 2,
+    do_smooth: bool = True,
+    window_length: int = 11,
+    polyorder: int = 2,
+    do_normalize: bool = True,
+    out_dtype: DTypeLike = np.float32,
+) -> tuple[np.ndarray, np.ndarray]:
+    """Exact CLI baseline: resample -> baseline -> smooth -> normalize"""
+    x_rs, y_rs = resample_spectrum(x, y, target_len=target_len)
+    if do_baseline:
+        y_rs = remove_baseline(y_rs, degree=degree)
+    if do_smooth:
+        y_rs = smooth_spectrum(y_rs, window_length=window_length, polyorder=polyorder)
+    if do_normalize:
+        y_rs = normalize_spectrum(y_rs)
+    # === Coerce to a real dtype to satisfy static checkers & runtime ===
+    out_dt = np.dtype(out_dtype)
+    return x_rs.astype(out_dt, copy=False), y_rs.astype(out_dt, copy=False)
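One apparent bug in the new module deserves attention before this parity backend is trusted: __ensure_1d_equal assigns y = np.asarray(x, dtype=float), which silently replaces the intensity array with a copy of the wavenumber axis, so every downstream step (and the tests above) operates on x rather than y. The duplicated from scipy.interpolate import interp1d is harmless by comparison. A corrected sketch of the helper:

    def __ensure_1d_equal(x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)  # was np.asarray(x, ...), which made y a copy of x
        if x.ndim != 1 or y.ndim != 1 or x.size != y.size or x.size < 2:
            raise ValueError("x and y must be 1D arrays of equal length >= 2")
        return x, y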