DeepEpilepsy V2
Spectrogram-based Vision Transformer (SimpleViT) for EEG representation learning. Produces 512-dimensional embeddings from 60-second, 19-channel EEG segments in the 10-20 referential montage.
Checkpoints
| File | Params | Layers | Heads | Training data |
|---|---|---|---|---|
| deepepilepsy_v2_large.pt | 14M | 4 | 4 | CHUM epilepsy (supervised) |
| deepepilepsy_v2_huge.pt | 23M | 6 | 8 | CHUM epilepsy (supervised) |
| deepepilepsy_v2_large_tuh.pt | 14M | 4 | 4 | TUH abnormal (pretrained) |
| deepepilepsy_v2_huge_tuh.pt | 27M | 8 | 8 | TUH abnormal (pretrained) |
Note: The TUH huge checkpoint uses 8 encoder layers (not 6), so it does not load into the default DeepEpilepsyV2Huge adapter.
Usage with eegzoo
pip install eegzoo[spectral] # MNE required for spectrogram preprocessing
from eegzoo import get
model = get("deepepilepsy_v2_large", weights="emilelemoine/DeepEpilepsy_V2")
tensor = model.preprocess(segment, sfreq=256, channel_names=ch_names) # (19, 32, 3000)
embeddings = model.embed(tensor.unsqueeze(0)) # (1, 512)
To load a non-default checkpoint (e.g. TUH large):
model = get(
    "deepepilepsy_v2_large",
    weights="emilelemoine/DeepEpilepsy_V2/deepepilepsy_v2_large_tuh.pt",
)
Architecture
Conv2dTokenizer (19, 32, 3000) -> (batch, 60, 512)
    3x strided Conv2d(freq_kernel=3, time_kernel=11, stride=(2,1))
    + Conv2d(4, 1, stride=(1, 50)) # collapse freq, stride time
Sinusoidal positional encoding (1D, not learned)
TransformerEncoder
    N x [LayerNorm -> MultiheadAttention -> LayerNorm -> MLP]
    Final LayerNorm
Mean pooling -> 512-dim embedding
Linear(512, 2) classification head (not used for embeddings)
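The token count of 60 follows from the strides listed above. The sketch below works through that shape arithmetic in plain Python. The padding values, (1, 5) for the strided stages and none for the final stage, are assumptions chosen so the numbers match; the model card does not state them. `conv_out` and `tokenizer_output_shape` are hypothetical helper names, not part of eegzoo.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Output length of a convolution along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def tokenizer_output_shape(n_freqs: int = 32, n_times: int = 3000) -> tuple[int, int]:
    """Return (frequency bins, time tokens) left after the tokenizer."""
    freq, time = n_freqs, n_times
    # Three strided stages: kernel (3, 11), stride (2, 1).
    # ASSUMPTION: padding (1, 5); not stated in the model card.
    for _ in range(3):
        freq = conv_out(freq, kernel=3, stride=2, padding=1)
        time = conv_out(time, kernel=11, stride=1, padding=5)
    # Final stage: kernel (4, 1), stride (1, 50), assumed unpadded.
    freq = conv_out(freq, kernel=4, stride=1)
    time = conv_out(time, kernel=1, stride=50)
    return freq, time


freq_bins, n_tokens = tokenizer_output_shape()
print(freq_bins, n_tokens)  # 1 60
```

Under these assumptions the frequency axis shrinks 32 -> 16 -> 8 -> 4 -> 1, and the time axis stays at 3000 until the final stride of 50 reduces it to 60 tokens.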
Intermediate tokenizer channel widths are geometrically interpolated from dim_internal to dim_hidden:
| Variant | dim_internal | Channels | dim_mlp |
|---|---|---|---|
| Large | 128 | 128 -> 256 -> 512 | 768 |
| Huge | 256 | 256 -> 362 -> 512 | 1024 |
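The channel widths in the table can be reproduced with a geometric progression between the two endpoints. This is a minimal sketch using NumPy's `geomspace`; the function name `tokenizer_channels`, the fixed `dim_hidden=512`, and rounding to the nearest integer are assumptions made for illustration, and the actual implementation may differ.

```python
import numpy as np


def tokenizer_channels(dim_internal: int, dim_hidden: int = 512, n_stages: int = 3) -> list[int]:
    """Geometrically spaced channel widths from dim_internal to dim_hidden."""
    widths = np.geomspace(dim_internal, dim_hidden, num=n_stages)
    # ASSUMPTION: widths are rounded to the nearest integer.
    return [int(round(float(w))) for w in widths]


print(tokenizer_channels(128))  # [128, 256, 512]
print(tokenizer_channels(256))  # [256, 362, 512]
```

For the Huge variant the middle width is 256 * sqrt(2), about 362.04, which rounds to the 362 shown in the table.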
Preprocessing
- Select and reorder 19 channels (Fp1, Fp2, F3, F4, C3, C4, P3, P4, F7, F8, T3, T4, T5, T6, O1, O2, Fz, Cz, Pz)
- Resample to 200 Hz
- Crop to 12,000 samples (60s)
- Morlet wavelet spectrogram:
  - 60 Hz notch filter (FIR)
  - Morlet CWT: 32 log-spaced frequencies (0.2 to 100 Hz), 7 cycles, decim=4
  - Log transform: log(power + 1e-9)
  - Robust normalization: (x - median) / (Q75 - Q25 + 1e-9) (global)
- Output shape: (19, 32, 3000)
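Some of the steps above can be sketched with NumPy alone: channel selection and cropping, then the log transform and robust normalization. Resampling, the notch filter, and the Morlet CWT are left out. The helper names are hypothetical, cropping is assumed to keep the first 12,000 samples, and "global" is assumed to mean statistics computed over the whole array; in practice `model.preprocess` handles all of this.

```python
import numpy as np

CHANNELS = ["Fp1", "Fp2", "F3", "F4", "C3", "C4", "P3", "P4", "F7", "F8",
            "T3", "T4", "T5", "T6", "O1", "O2", "Fz", "Cz", "Pz"]


def select_and_crop(data, ch_names, n_samples=12_000):
    """Reorder rows of (n_channels, n_times) data to CHANNELS and crop in time."""
    index = {name: i for i, name in enumerate(ch_names)}
    missing = [name for name in CHANNELS if name not in index]
    if missing:
        raise ValueError(f"missing channels: {missing}")
    if data.shape[1] < n_samples:
        raise ValueError("segment shorter than n_samples")
    rows = [index[name] for name in CHANNELS]
    # ASSUMPTION: the crop keeps the first n_samples samples.
    return data[rows, :n_samples]


def log_robust_normalize(power, eps=1e-9):
    """Log-transform power, then centre on the median and scale by the IQR."""
    log_power = np.log(power + eps)
    # ASSUMPTION: "global" means one median and IQR over the whole array.
    median = np.median(log_power)
    q25, q75 = np.percentile(log_power, [25, 75])
    return (log_power - median) / (q75 - q25 + eps)


rng = np.random.default_rng(0)
raw = rng.standard_normal((21, 12_500))       # synthetic stand-in, already at 200 Hz
names = CHANNELS[::-1] + ["A1", "A2"]          # shuffled order plus two extra channels
segment = select_and_crop(raw, names)
print(segment.shape)  # (19, 12000)

power = rng.random((19, 32, 300)) + 0.1       # synthetic stand-in for CWT power
normalized = log_robust_normalize(power)
print(normalized.shape)  # (19, 32, 300)
```

With decim=4, the 12,000 cropped samples give the 3,000 time bins of the final (19, 32, 3000) output.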
State dict format
Clean state_dict (no "module." prefix, no optimizer state). Keys follow torchvision's VisionTransformer naming:
stem.layers.conv_stem_{0,1,2}.{0,2}.{weight,bias}
stem.layers.conv_last.{weight,bias}
encoder.layers.encoder_layer_{N}.{ln_1,self_attention,ln_2,mlp}.*
encoder.ln.{weight,bias}
mlp_head.{weight,bias}
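Because the encoder keys carry the layer index, the depth of a checkpoint can be read from its key names before choosing an adapter, which helps with the 8-layer TUH huge file. The following is a small sketch that relies only on the key pattern listed above; `count_encoder_layers` is a hypothetical helper, and the example keys are a hand-written toy list, not an actual checkpoint dump.

```python
import re

# Matches keys such as "encoder.layers.encoder_layer_3.ln_1.weight".
LAYER_KEY = re.compile(r"^encoder\.layers\.encoder_layer_(\d+)\.")


def count_encoder_layers(keys) -> int:
    """Number of encoder layers implied by state_dict key names."""
    indices = {int(m.group(1)) for key in keys if (m := LAYER_KEY.match(key))}
    return max(indices) + 1 if indices else 0


# Toy key list for illustration only.
example_keys = [
    "stem.layers.conv_last.weight",
    "encoder.layers.encoder_layer_0.ln_1.weight",
    "encoder.layers.encoder_layer_1.ln_1.weight",
    "encoder.layers.encoder_layer_2.ln_1.weight",
    "encoder.layers.encoder_layer_3.ln_1.weight",
    "encoder.ln.weight",
    "mlp_head.weight",
]
print(count_encoder_layers(example_keys))  # 4
```

For a real file, the keys could come from a state dict loaded with `torch.load(path, map_location="cpu")`, assuming the checkpoint is a plain state_dict as described.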
pos_embedding is not stored (deterministic sinusoidal, recomputed at init).
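A deterministic table of this kind can be rebuilt from the token count and embedding width alone, which is why it does not need to be saved. The sketch below uses the common interleaved sine/cosine formulation with base 10000. That choice is an assumption: the exact layout used by this model (for example interleaved versus concatenated sine and cosine halves) is not stated here, so the sketch illustrates the idea and is not a drop-in replacement.

```python
import numpy as np


def sinusoidal_positions(n_tokens: int = 60, dim: int = 512) -> np.ndarray:
    """1D sinusoidal position table of shape (n_tokens, dim)."""
    positions = np.arange(n_tokens)[:, None]                      # (n_tokens, 1)
    # ASSUMPTION: base 10000 and interleaved sin/cos columns.
    freqs = np.exp(-np.log(10000.0) * np.arange(0, dim, 2) / dim)  # (dim / 2,)
    table = np.zeros((n_tokens, dim))
    table[:, 0::2] = np.sin(positions * freqs)
    table[:, 1::2] = np.cos(positions * freqs)
    return table


pos = sinusoidal_positions()
print(pos.shape)  # (60, 512)
```

Calling the function twice yields identical arrays, so the table can be recomputed at init instead of being loaded from the checkpoint.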