Exact Gradients for Stochastic Spiking Neural Networks Driven by Rough Signals
Abstract
We introduce a mathematically rigorous framework based on rough path theory to model stochastic spiking neural networks (SSNNs) as stochastic differential equations with event discontinuities (Event SDEs) driven by càdlàg rough paths. Our formalism is general enough to allow for jumps both in the solution trajectories and in the driving noise. We then identify a set of sufficient conditions ensuring the existence of pathwise gradients of solution trajectories and event times with respect to the network's parameters, and show that these gradients satisfy a recursive relation. Furthermore, we introduce a general-purpose loss function defined by means of a new class of signature kernels indexed on càdlàg rough paths and use it to train SSNNs as generative models. We provide an end-to-end autodifferentiable solver for Event SDEs and make its implementation available as part of the diffrax library. Our framework is, to our knowledge, the first to enable gradient-based training of SSNNs with noise affecting both the spike timing and the network's dynamics.
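To make the Event SDE setup concrete, the sketch below is a minimal, illustrative example (not the authors' code) of how one might pose a single noisy leaky integrate-and-fire neuron as an Event SDE in diffrax: a float-valued `cond_fn` passed to `diffrax.Event` triggers root-finding for the threshold-crossing (spike) time, which is what makes the event time differentiable. The neuron model and all parameter values (`mu`, `sigma`, `v_th`) are hypothetical, chosen only for illustration.

```python
# Hedged sketch: a leaky integrate-and-fire neuron with additive Brownian
# noise, solved as an Event SDE in diffrax. Integration halts when the
# membrane potential crosses the firing threshold; the float-valued cond_fn
# makes diffrax locate the crossing by root-finding, so the spike time is
# differentiable. All model parameters here are hypothetical.
import jax.numpy as jnp
import jax.random as jr
import diffrax
import optimistix as optx

mu, sigma, v_th = 1.2, 0.3, 1.0  # input drift, noise scale, firing threshold

def drift(t, v, args):
    return mu - v  # leaky integration toward the input current mu

def diffusion(t, v, args):
    return sigma  # additive noise on the membrane potential

brownian = diffrax.VirtualBrownianTree(
    t0=0.0, t1=5.0, tol=1e-3, shape=(), key=jr.PRNGKey(0)
)
terms = diffrax.MultiTerm(
    diffrax.ODETerm(drift), diffrax.ControlTerm(diffusion, brownian)
)

def cond_fn(t, y, args, **kwargs):
    return y - v_th  # sign change => spike; root-finding locates the time

event = diffrax.Event(cond_fn, root_finder=optx.Newton(rtol=1e-5, atol=1e-5))
sol = diffrax.diffeqsolve(
    terms, diffrax.Euler(), t0=0.0, t1=5.0, dt0=0.01,
    y0=jnp.array(0.0), event=event,
)
print(sol.ts, sol.ys)  # solve terminates at the detected spike time
```

Wrapping such a solve in `jax.grad` with respect to the model parameters would then propagate pathwise gradients through both the continuous dynamics and the spike time, in the spirit of the recursive gradient relation described in the abstract.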