Source code for graphax.examples.deep_learning

import jax
import jax.nn as jnn
import jax.numpy as jnp


[docs] def sigmoid(x): return 1. / (1. + jnp.exp(-x))
[docs] def SiLU(x): return x*sigmoid(x)
[docs] def softmax_cross_entropy(logits, labels): return -jnp.sum(labels * jnp.log(jnn.softmax(logits, axis=-1)), axis=-1)
[docs] def variance(x, mu, axis=0): n = x.shape[axis] return jnp.sum((x - mu)**2, axis=axis)/n
[docs] def layer_norm(x, gamma, beta): mu = jnp.mean(x, axis=-1) sigma = variance(x, mu, axis=-1) return (x - mu)/jnp.sqrt(sigma + 1e-6) * gamma + beta
[docs] def attn(q, k, v): a = q.T @ k z = jnn.softmax(a, axis=1) return z @ v
[docs] def Perceptron(x, y, W1, b1, W2, b2, gamma, beta): out = jnp.tanh(W1 @ x + b1) out = layer_norm(out, gamma, beta) out = jnp.tanh(W2 @ out + b2) return softmax_cross_entropy(out, y)
[docs] def encoder_block(x, WQ, WK, WV, W, b, gamma, beta): q = WQ @ x k = WK @ x v = WV @ x a = x + attn(q, k, v) c = layer_norm(a, gamma, beta) return SiLU(W @ c + b)
[docs] def decoder_block(x, q, k, WQ1, WK1, WV1, WQ2, WK2, WV2, W, b, gamma0, gamma1, beta0, beta1): q1 = WQ1 @ x k1 = WK1 @ x v1 = WV1 @ x a1 = x + attn(q1, k1, v1) c1 = layer_norm(a1, gamma0, beta0) q2 = WQ2 @ q k2 = WK2 @ k v2 = WV2 @ c1 a2 = c1 + attn(q2, k2, v2) c2 = layer_norm(a2, gamma1, beta1) return SiLU(W @ c2 + b)
[docs] def Encoder(x, y, WQ1, WQ2, WK1, WK2, WV1, WV2, W1, W2, b1, b2, gamma0, beta0, gamma1, beta1): z1 = encoder_block(x, WQ1, WK1, WV1, W1, b1, gamma0, beta0) z2 = encoder_block(z1, WQ2, WK2, WV2, W2, b2, gamma1, beta1) return softmax_cross_entropy(z2, y)
[docs] def EncoderDecoder(x, y, WQ1, WQ2, WQ3, WK1, WK2, WK3, WV1, WV2, WV3, W1, W2, b1, b2, gamma0, beta0, gamma1, beta1, gamma2, beta2): z1 = encoder_block(x, WQ1, WK1, WV1, W1, b1, gamma0, beta0) z2 = decoder_block(x, z1, z1, WQ2, WQ3, WK2, WK3, WV2, WV3, W2, b2, gamma1, gamma2, beta1, beta2) return .5*(z2 - y)**2