DeepEpilepsy V2

Spectrogram-based Vision Transformer (SimpleViT) for EEG representation learning. Produces 512-dimensional embeddings from 60-second, 19-channel EEG segments in the 10-20 referential montage.

Checkpoints

File                          Params  Layers  Heads  Training data
deepepilepsy_v2_large.pt      14M     4       4      CHUM epilepsy (supervised)
deepepilepsy_v2_huge.pt       23M     6       8      CHUM epilepsy (supervised)
deepepilepsy_v2_large_tuh.pt  14M     4       4      TUH abnormal (pretrained)
deepepilepsy_v2_huge_tuh.pt   27M     8       8      TUH abnormal (pretrained)

Note: The TUH huge checkpoint uses 8 encoder layers (not 6), so it does not load into the default DeepEpilepsyV2Huge adapter.

Usage with eegzoo

pip install "eegzoo[spectral]"  # MNE required for spectrogram preprocessing

from eegzoo import get

model = get("deepepilepsy_v2_large", weights="emilelemoine/DeepEpilepsy_V2")

tensor = model.preprocess(segment, sfreq=256, channel_names=ch_names)  # (19, 32, 3000)
embeddings = model.embed(tensor.unsqueeze(0))  # (1, 512)

To load a non-default checkpoint (e.g. TUH large):

model = get(
    "deepepilepsy_v2_large",
    weights="emilelemoine/DeepEpilepsy_V2/deepepilepsy_v2_large_tuh.pt",
)

Architecture

Conv2dTokenizer  (batch, 19, 32, 3000) -> (batch, 60, 512)
  3x strided Conv2d(freq_kernel=3, time_kernel=11, stride=(2,1))
  + Conv2d(4, 1, stride=(1, 50))       # collapse freq, stride time
Sinusoidal positional encoding (1D, not learned)
TransformerEncoder
  N x [LayerNorm -> MultiheadAttention -> LayerNorm -> MLP]
  Final LayerNorm
Mean pooling -> 512-dim embedding
Linear(512, 2) classification head (not used for embeddings)
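
The 60-token sequence length can be checked with the standard convolution output-size formula. The sketch below is pure Python and only does the shape arithmetic. The padding values (kernel // 2 for the three strided layers, none for the last layer) are assumptions that this card does not state; they are chosen because they reproduce the documented (batch, 60, 512) output.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a convolution along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def tokenizer_output_shape(n_freqs: int = 32, n_times: int = 3000) -> tuple[int, int]:
    """Return (freq_bins, tokens) after the four tokenizer convolutions."""
    freq, time = n_freqs, n_times
    # Three strided convs: kernel (3, 11), stride (2, 1).
    # ASSUMPTION: padding = kernel // 2 on both axes.
    for _ in range(3):
        freq = conv_out(freq, kernel=3, stride=2, padding=1)
        time = conv_out(time, kernel=11, stride=1, padding=5)
    # Final conv: kernel (4, 1), stride (1, 50), ASSUMED unpadded.
    freq = conv_out(freq, kernel=4, stride=1, padding=0)
    time = conv_out(time, kernel=1, stride=50, padding=0)
    return freq, time


print(tokenizer_output_shape())  # (1, 60)
```

Under these assumptions the frequency axis shrinks 32 -> 16 -> 8 -> 4 -> 1 while the time axis goes 3000 -> 60, which leaves one 512-dimensional token per second of EEG.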

Intermediate tokenizer channel widths are geometrically interpolated from dim_internal to dim_hidden:

Variant  dim_internal  Channels           dim_mlp
Large    128           128 -> 256 -> 512  768
Huge     256           256 -> 362 -> 512  1024
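
As a quick check on the channel widths listed for each variant, geometric interpolation can be sketched in a few lines. The function name and the use of round() are assumptions made for illustration, not the model's actual code; dim_hidden=512 and the three strided layers come from the architecture summary.

```python
def tokenizer_channels(dim_internal: int, dim_hidden: int = 512, n_layers: int = 3) -> list[int]:
    """Channel widths spaced by a constant ratio from dim_internal to dim_hidden."""
    ratio = (dim_hidden / dim_internal) ** (1 / (n_layers - 1))
    # ASSUMPTION: widths are rounded to the nearest integer.
    return [round(dim_internal * ratio ** i) for i in range(n_layers)]


print(tokenizer_channels(128))  # [128, 256, 512]
print(tokenizer_channels(256))  # [256, 362, 512]
```

For the Huge variant the ratio is sqrt(2), which is where the non-power-of-two width 362 comes from.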

Preprocessing

  1. Select and reorder 19 channels (Fp1, Fp2, F3, F4, C3, C4, P3, P4, F7, F8, T3, T4, T5, T6, O1, O2, Fz, Cz, Pz)
  2. Resample to 200 Hz
  3. Crop to 12,000 samples (60 s at 200 Hz)
  4. Morlet wavelet spectrogram:
    • 60 Hz notch filter (FIR)
    • Morlet CWT: 32 log-spaced frequencies (0.2 to 100 Hz), 7 cycles, decim=4
    • Log transform: log(power + 1e-9)
    • Robust normalization: (x - median) / (Q75 - Q25 + 1e-9) (global)
  5. Output shape: (19, 32, 3000)
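
The numeric parts of the steps above can be illustrated without any EEG library. This is a minimal sketch using only the standard library: the helper names are hypothetical, the frequency grid assumes both endpoints are included, and the quartiles use linear interpolation (method="inclusive"), which may differ slightly from the real pipeline. The wavelet transform itself is not reproduced here.

```python
import math
import statistics


def log_spaced_freqs(f_min=0.2, f_max=100.0, n=32):
    """n frequencies evenly spaced on a log scale (ASSUMPTION: endpoints included)."""
    step = (f_max / f_min) ** (1 / (n - 1))
    return [f_min * step ** i for i in range(n)]


def robust_normalize(values, eps=1e-9):
    """(x - median) / (Q75 - Q25 + eps), with statistics taken over all values."""
    q25, median, q75 = statistics.quantiles(values, n=4, method="inclusive")
    return [(v - median) / (q75 - q25 + eps) for v in values]


# Sample counts implied by steps 2 to 4.
n_samples = 60 * 200       # 12000 samples after resampling and cropping
n_frames = n_samples // 4  # 3000 time frames after decim=4

freqs = log_spaced_freqs()

# Toy power values stand in for the wavelet output.
log_power = [math.log(p + 1e-9) for p in (0.5, 1.0, 2.0, 4.0, 8.0)]
normalized = robust_normalize(log_power)
print(n_frames, len(freqs), [round(v, 3) for v in normalized])
```

In practice the spectrogram would come from MNE. Its mne.time_frequency.tfr_array_morlet function accepts freqs, n_cycles, and decim arguments, so it is a plausible backend, but the exact call used by eegzoo is not shown in this card.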

State dict format

Clean state_dict (no module. prefix, no optimizer state). Keys follow torchvision naming:

stem.layers.conv_stem_{0,1,2}.{0,2}.{weight,bias}
stem.layers.conv_last.{weight,bias}
encoder.layers.encoder_layer_{N}.{ln_1,self_attention,ln_2,mlp}.*
encoder.ln.{weight,bias}
mlp_head.{weight,bias}
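
Because the encoder depth differs between checkpoints, it can be useful to read it from the keys before choosing an adapter. The sketch below assumes only the encoder_layer_{N} naming shown above; the function name and the toy key list are illustrative.

```python
import re


def count_encoder_layers(keys) -> int:
    """Count distinct encoder_layer_{N} indices among state-dict keys."""
    pattern = re.compile(r"^encoder\.layers\.encoder_layer_(\d+)\.")
    indices = {int(m.group(1)) for k in keys if (m := pattern.match(k))}
    return len(indices)


# Toy keys standing in for a real checkpoint.
example_keys = [
    "stem.layers.conv_last.weight",
    "encoder.layers.encoder_layer_0.ln_1.weight",
    "encoder.layers.encoder_layer_0.ln_1.bias",
    "encoder.layers.encoder_layer_1.ln_1.weight",
    "encoder.ln.weight",
    "mlp_head.weight",
]
print(count_encoder_layers(example_keys))  # 2
```

With a real file, the keys would typically come from torch.load(path, map_location="cpu").keys(), assuming the file is a plain state_dict as described. Going by the checkpoint table, the expected counts are 4 for the large checkpoints, 6 for deepepilepsy_v2_huge.pt, and 8 for deepepilepsy_v2_huge_tuh.pt.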

pos_embedding is not stored (deterministic sinusoidal, recomputed at init).
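
To illustrate why a sinusoidal table does not need to be stored, here is a minimal pure-Python sketch that rebuilds one from the sequence length and embedding width alone. The interleaved sin/cos layout and the base of 10000 follow the original Transformer convention and are assumptions here; the model code may arrange the terms differently (for example, concatenated sin and cos halves).

```python
import math


def sinusoidal_positions(n_tokens: int = 60, dim: int = 512, base: float = 10000.0):
    """Return an n_tokens x dim table of fixed sinusoidal position codes (dim must be even)."""
    table = []
    for pos in range(n_tokens):
        row = [0.0] * dim
        for i in range(0, dim, 2):
            angle = pos / base ** (i / dim)
            row[i] = math.sin(angle)      # even index: sine
            row[i + 1] = math.cos(angle)  # odd index: cosine
        table.append(row)
    return table


pos_table = sinusoidal_positions()
print(len(pos_table), len(pos_table[0]))  # 60 512
```

The table depends only on n_tokens, dim, and base, so recomputing it at init gives the same values every time.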
