# Source code for graphax.examples.neuromorphic

import jax
import jax.nn as jnn
import jax.numpy as jnp


def superspike_surrogate(beta=10.):
    """Build a Heaviside step function with a hand-written surrogate JVP.

    The returned callable evaluates ``jnp.heaviside(x, 1)`` in the forward
    pass (so an input of exactly 0 maps to 1). Its derivative is replaced,
    via ``jax.custom_jvp``, by ``1 / (beta*|x| + 1)``.

    Args:
        beta: Steepness of the surrogate derivative; larger values make the
            surrogate decay faster away from ``x == 0``.

    Returns:
        A function of one argument ``x`` with the forward and tangent
        behavior described above.
    """
    # NOTE(review): the published SuperSpike surrogate is usually written
    # with a squared denominator, 1/(beta*|x| + 1)**2. The unsquared form is
    # kept here exactly as in the original -- confirm it is intentional.

    @jax.custom_jvp
    def heaviside_with_super_spike_surrogate(x):
        return jnp.heaviside(x, 1)

    @heaviside_with_super_spike_surrogate.defjvp
    def f_jvp(primals, tangents):
        x, = primals
        x_dot, = tangents
        primal_out = heaviside_with_super_spike_surrogate(x)
        # Smooth stand-in for the (almost everywhere zero) true derivative.
        tangent_out = 1./(beta*jnp.abs(x)+1.) * x_dot
        return primal_out, tangent_out

    return heaviside_with_super_spike_surrogate
# Module-level spike nonlinearity used by lif(): Heaviside forward pass with
# the surrogate derivative built by superspike_surrogate() at its default beta.
surrogate = superspike_surrogate()
def lif(U, I, S, a, b, threshold):
    """Advance a leaky integrate-and-fire neuron state by one update.

    Args:
        U: Current membrane state.
        I: Current synaptic-current state.
        S: Input driving the synaptic current.
        a: Mixing factor for the membrane update (``a*U + (1-a)*I``).
        b: Mixing factor for the current update (``b*I + (1-b)*S``).
        threshold: Value subtracted from the new membrane state before the
            spike nonlinearity is applied.

    Returns:
        Tuple ``(U_next, I_next, S_next)`` of the updated membrane state,
        the updated synaptic current and the output of the module-level
        ``surrogate`` applied to ``U_next - threshold``.
    """
    # Both updates read the *old* I, so their order does not matter.
    U_next = a*U + (1.-a)*I
    I_next = b*I + (1.-b)*S
    S_next = surrogate(U_next - threshold)
    return U_next, I_next, S_next
# From Bellec et al. e-prop paper
def ada_lif(U, a, S, alpha, beta, rho, threshold):
    """Advance an adaptive-threshold LIF neuron state by one update.

    Args:
        U: Current membrane state.
        a: Current adaptation state.
        S: Input added to the decayed membrane state.
        alpha: Decay factor applied to ``U``.
        beta: Weight of the adaptation state in the effective threshold.
        rho: Decay factor applied to ``a``.
        threshold: Baseline threshold.

    Returns:
        Tuple ``(U_next, a_next, S_next)`` of the updated membrane state,
        the updated adaptation state and the neuron output.
    """
    U_next = alpha*U + S
    # Effective threshold: baseline plus the scaled adaptation state.
    A_th = threshold + beta*a
    # This needs to have spiking behavior, e.g.
    # jnp.heaviside(U_next - A_th, 1); a sigmoid is used as a stand-in.
    S_next = jnn.sigmoid(U_next - A_th)
    # NOTE(review): Bellec et al. appear to define the adaptation update
    # with a plus sign (rho*a + spike). The minus sign is kept exactly as in
    # the original -- confirm it is intentional.
    a_next = rho*a - S_next
    return U_next, a_next, S_next
# Single forward pass through a three-layer SNN of LIF neurons
def LIF_SNN(S_in, S_target, U1, U2, U3, I1, I2, I3, W1, W2, W3, alpha, beta, thresh):
    """Run one forward pass through three stacked ``lif`` layers.

    Each layer multiplies its input by a weight matrix (``W @ input``) and
    feeds the result to ``lif`` together with that layer's membrane and
    synaptic-current state. The output of one layer is the input of the
    next.

    Args:
        S_in: Input to the first layer.
        S_target: Target compared against the third layer's output.
        U1, U2, U3: Membrane states of layers 1-3.
        I1, I2, I3: Synaptic-current states of layers 1-3.
        W1, W2, W3: Weight matrices of layers 1-3.
        alpha: Passed to ``lif`` as its membrane mixing factor ``a``.
        beta: Passed to ``lif`` as its current mixing factor ``b``.
        thresh: Threshold shared by all three layers.

    Returns:
        Tuple ``(loss, U1, U2, U3, I1, I2, I3)`` where ``loss`` is
        ``0.5*(s3 - S_target)**2`` (not reduced) and the remaining entries
        are the updated states of the three layers.
    """
    # lif returns (U_next, I_next, S_next); the second entry is the updated
    # synaptic current, so it is bound to I* rather than a*.
    i1 = W1 @ S_in
    U1, I1, s1 = lif(U1, I1, i1, alpha, beta, thresh)

    i2 = W2 @ s1
    U2, I2, s2 = lif(U2, I2, i2, alpha, beta, thresh)

    i3 = W3 @ s2
    U3, I3, s3 = lif(U3, I3, i3, alpha, beta, thresh)

    return .5*(s3 - S_target)**2, U1, U2, U3, I1, I2, I3
# Single forward pass through a three-layer SNN of adaptive LIF neurons
def ADALIF_SNN(S_in, S_target, U1, U2, U3, a1, a2, a3, W1, W2, W3, alpha, beta, rho, thresh):
    """Run one forward pass through three stacked ``ada_lif`` layers.

    Each layer multiplies its input by a weight matrix (``W @ input``) and
    feeds the result to ``ada_lif`` together with that layer's membrane and
    adaptation state. The output of one layer is the input of the next.

    Args:
        S_in: Input to the first layer.
        S_target: Target compared against the third layer's output.
        U1, U2, U3: Membrane states of layers 1-3.
        a1, a2, a3: Adaptation states of layers 1-3.
        W1, W2, W3: Weight matrices of layers 1-3.
        alpha, beta, rho: Forwarded unchanged to ``ada_lif`` for every layer.
        thresh: Baseline threshold shared by all three layers.

    Returns:
        Tuple ``(loss, U1, U2, U3, a1, a2, a3)`` where ``loss`` is
        ``0.5*(s3 - S_target)**2`` (not reduced) and the remaining entries
        are the updated states of the three layers.
    """
    i1 = W1 @ S_in
    U1, a1, s1 = ada_lif(U1, a1, i1, alpha, beta, rho, thresh)

    i2 = W2 @ s1
    U2, a2, s2 = ada_lif(U2, a2, i2, alpha, beta, rho, thresh)

    i3 = W3 @ s2
    U3, a3, s3 = ada_lif(U3, a3, i3, alpha, beta, rho, thresh)

    return .5*(s3 - S_target)**2, U1, U2, U3, a1, a2, a3