import jax
import jax.nn as jnn
import jax.numpy as jnp
[docs]
def sigmoid(x):
return 1. / (1. + jnp.exp(-x))
[docs]
def SiLU(x):
return x*sigmoid(x)
[docs]
def softmax_cross_entropy(logits, labels):
return -jnp.sum(labels * jnp.log(jnn.softmax(logits, axis=-1)), axis=-1)
[docs]
def variance(x, mu, axis=0):
n = x.shape[axis]
return jnp.sum((x - mu)**2, axis=axis)/n
[docs]
def layer_norm(x, gamma, beta):
mu = jnp.mean(x, axis=-1)
sigma = variance(x, mu, axis=-1)
return (x - mu)/jnp.sqrt(sigma + 1e-6) * gamma + beta
[docs]
def attn(q, k, v):
a = q.T @ k
z = jnn.softmax(a, axis=1)
return z @ v
[docs]
def Perceptron(x, y, W1, b1, W2, b2, gamma, beta):
out = jnp.tanh(W1 @ x + b1)
out = layer_norm(out, gamma, beta)
out = jnp.tanh(W2 @ out + b2)
return softmax_cross_entropy(out, y)
[docs]
def encoder_block(x, WQ, WK, WV, W, b, gamma, beta):
q = WQ @ x
k = WK @ x
v = WV @ x
a = x + attn(q, k, v)
c = layer_norm(a, gamma, beta)
return SiLU(W @ c + b)
[docs]
def decoder_block(x, q, k, WQ1, WK1, WV1, WQ2, WK2, WV2, W, b, gamma0, gamma1, beta0, beta1):
q1 = WQ1 @ x
k1 = WK1 @ x
v1 = WV1 @ x
a1 = x + attn(q1, k1, v1)
c1 = layer_norm(a1, gamma0, beta0)
q2 = WQ2 @ q
k2 = WK2 @ k
v2 = WV2 @ c1
a2 = c1 + attn(q2, k2, v2)
c2 = layer_norm(a2, gamma1, beta1)
return SiLU(W @ c2 + b)
[docs]
def Encoder(x, y, WQ1, WQ2, WK1, WK2, WV1, WV2, W1, W2, b1, b2, gamma0, beta0, gamma1, beta1):
z1 = encoder_block(x, WQ1, WK1, WV1, W1, b1, gamma0, beta0)
z2 = encoder_block(z1, WQ2, WK2, WV2, W2, b2, gamma1, beta1)
return softmax_cross_entropy(z2, y)
[docs]
def EncoderDecoder(x, y, WQ1, WQ2, WQ3, WK1, WK2, WK3, WV1, WV2, WV3, W1, W2, b1, b2, gamma0, beta0, gamma1, beta1, gamma2, beta2):
z1 = encoder_block(x, WQ1, WK1, WV1, W1, b1, gamma0, beta0)
z2 = decoder_block(x, z1, z1, WQ2, WQ3, WK2, WK3, WV2, WV3, W2, b2, gamma1, gamma2, beta1, beta2)
return .5*(z2 - y)**2