Source code for graphax.examples.easy

import jax
import jax.numpy as jnp


### Examples from the book "Evaluating Derivatives" by Andreas Griewank and Andrea Walther

def Simple(x, y):
    """Two-input, two-output toy function.

    Args:
        x: First input.
        y: Second input.

    Returns:
        A 2-tuple ``(sin(x*y) + x*y, log(sin(x*y)))``.
    """
    product = x * y
    sine = jnp.sin(product)
    # Both outputs reuse the same sine value.
    return sine + product, jnp.log(sine)
def Lighthouse(nu, gamma, omega, t):
    """Evaluate the two outputs of the lighthouse example.

    Args:
        nu: Multiplies the tangent in the numerator.
        gamma: Minuend of the denominator and scale factor of the second output.
        omega: Multiplied with ``t`` to form the tangent's argument.
        t: Multiplied with ``omega`` to form the tangent's argument.

    Returns:
        A 2-tuple ``(y1, y2)`` with
        ``y1 = nu*tan(omega*t) / (gamma - tan(omega*t))`` and ``y2 = gamma*y1``.
    """
    # The tangent is deliberately evaluated twice, exactly as in the
    # original expression, so the traced computation stays the same.
    numerator = nu * jnp.tan(omega * t)
    denominator = gamma - jnp.tan(omega * t)
    y1 = numerator / denominator
    return y1, gamma * y1
def Hole(x, y, z, w):
    """Four-input, three-output toy function.

    With ``p = y*z``, ``d = cos(p + x)`` and ``e = exp(p + w)`` the outputs
    are the difference, quotient and product of ``d`` and ``e``.

    Args:
        x: Added to ``y*z`` inside the cosine.
        y: First factor of the shared product.
        z: Second factor of the shared product.
        w: Added to ``y*z`` inside the exponential.

    Returns:
        A 3-tuple ``(d - e, d / e, d * e)``.
    """
    shared_product = y * z
    cos_argument = shared_product + x
    exp_argument = shared_product + w
    cos_value = jnp.cos(cos_argument)
    exp_value = jnp.exp(exp_argument)
    return cos_value - exp_value, cos_value / exp_value, cos_value * exp_value
### General Relativity

# Module-level parameters read by KerrSenn_metric at call time.
# NOTE(review): presumably a is the spin parameter, b the charge-related
# parameter and M the mass of the Kerr-Sen solution -- confirm.
a = .5
b = .9
M = 1.


def KerrSenn_metric(t, r, theta, phi):
    """Evaluate the non-zero components of the Kerr-Sen metric.

    The parameters ``a``, ``b`` and ``M`` are taken from module scope.

    Args:
        t: Time coordinate. Not used by any component.
        r: Radial coordinate.
        theta: Polar angle.
        phi: Azimuthal angle. Not used by any component.

    Returns:
        A 5-tuple ``(gtt, grr, gthetatheta, gphiphi, gphit)``.
    """
    sin_sq = jnp.sin(theta)**2
    sigma = r**2 + 2.*b*r + a**2 * jnp.cos(theta)**2
    delta = r**2 + 2*b*r - 2.*M*r + a**2

    gtt = -(1. - 2.*M*r/sigma)
    grr = sigma/delta
    gthetatheta = sigma
    gphiphi = (sigma + a**2*sin_sq + 2.*M*r*a**2*sin_sq/sigma)*sin_sq
    # Fixed: the mixed component is -2*M*r*a*sin^2(theta)/sigma. The previous
    # expression divided by sin^2(theta) and omitted sigma, which diverges at
    # the poles and does not match the Kerr-Sen line element.
    gphit = -2.*M*r*a*sin_sq/sigma
    return gtt, grr, gthetatheta, gphiphi, gphit
def KerrSenn_Jacobian(t, r, theta, phi):
    """Forward-mode derivative of ``KerrSenn_metric`` at the given point.

    Args:
        t: Time coordinate.
        r: Radial coordinate.
        theta: Polar angle.
        phi: Azimuthal angle.

    Returns:
        Whatever ``jax.jacfwd(KerrSenn_metric)`` yields for these arguments.

    NOTE(review): ``jax.jacfwd`` is used with its default ``argnums`` (0), so
    only the derivative with respect to ``t`` is taken. If the full Jacobian
    over all four coordinates is intended, ``argnums=(0, 1, 2, 3)`` is needed;
    that would change the structure of the return value, so confirm with
    callers first.
    """
    metric_derivative = jax.jacfwd(KerrSenn_metric)
    return metric_derivative(t, r, theta, phi)
### Thermodynamics and Statistical Mechanics
def Helmholtz(x):
    """Element-wise Helmholtz term ``x * log(x / (1 - sum(x)))``.

    Args:
        x: Array whose entries are summed for the shared denominator.

    Returns:
        Array of the same shape as ``x`` after broadcasting.
    """
    # Written as "1 + (-sum)" to keep the same sequence of primitive
    # operations as the original expression.
    denominator = 1. + -jnp.sum(x)
    ratio = x / denominator
    return x * jnp.log(ratio)
def FreeEnergy(x):
    """Sum of all entries of ``Helmholtz(x)``.

    Args:
        x: Array forwarded unchanged to ``Helmholtz``.

    Returns:
        The scalar total of the element-wise Helmholtz terms.
    """
    per_component = Helmholtz(x)
    return jnp.sum(per_component)
### Meteorology

# Module-level values read by the cloud-scheme functions at call time.
# NOTE(review): presumably qc, qr, qv are cloud, rain and vapour quantities
# -- confirm against the referenced paper.
qc = 1.
qr = 1.
qv = 1.
# c and S scale condensation; B is an additive term in CloudSchemes_step.
c, S, B = 1., 1., 0.
gTw = 1.
def condensation(qc):
    """Condensation term ``c * S * qc**0.33333``.

    ``c`` and ``S`` are module-level constants. The ``qc`` argument shadows
    the module-level ``qc`` inside this function.

    Args:
        qc: Base of the power law.

    Returns:
        The scaled power of ``qc``.
    """
    prefactor = c * S
    return prefactor * qc**0.33333
def accretion(a2, bc, br, qc, qr):
    """Accretion term ``a2 * qc**bc * qr**br``.

    Args:
        a2: Multiplicative coefficient.
        bc: Exponent applied to ``qc``.
        br: Exponent applied to ``qr``.
        qc: Base of the first power.
        qr: Base of the second power.

    Returns:
        The product of the coefficient and both powers.
    """
    scaled_qc_power = a2 * qc**bc
    return scaled_qc_power * qr**br
def autoconversion(a1, gamma, qc):
    """Autoconversion term ``a1 * qc**gamma``.

    Args:
        a1: Multiplicative coefficient.
        gamma: Exponent applied to ``qc``.
        qc: Base of the power.

    Returns:
        The coefficient times the power of ``qc``.
    """
    qc_power = qc**gamma
    return a1 * qc_power
def evaporation(e1, d1, e2, d2, qr):
    """Evaporation term ``e1 * qr**d1 + e2 * qr**d2``.

    Args:
        e1: Coefficient of the first power.
        d1: Exponent of the first power.
        e2: Coefficient of the second power.
        d2: Exponent of the second power.
        qr: Base shared by both powers.

    Returns:
        The sum of the two scaled powers of ``qr``.
    """
    first_term = e1 * qr**d1
    second_term = e2 * qr**d2
    return first_term + second_term
# Taken from https://gmd.copernicus.org/preprints/gmd-2019-140/gmd-2019-140.pdf
def CloudSchemes_step(a1, a2, e1, e2, delta, gamma, bc, br, d1, d2, chi):
    """Compute the three tendencies ``(dqc, dqr, dqv)`` of the cloud scheme.

    ``qc``, ``qr`` and ``B`` are read from module scope; only the
    coefficients and exponents are passed in.

    Args:
        a1: Coefficient passed to ``autoconversion``.
        a2: Coefficient passed to ``accretion``.
        e1: First coefficient passed to ``evaporation``.
        e2: Second coefficient passed to ``evaporation``.
        delta: Coefficient of the ``qr**chi`` term subtracted in ``dqr``.
        gamma: Exponent passed to ``autoconversion``.
        bc: ``qc`` exponent passed to ``accretion``.
        br: ``qr`` exponent passed to ``accretion``.
        d1: First exponent passed to ``evaporation``.
        d2: Second exponent passed to ``evaporation``.
        chi: Exponent of ``qr`` in the term scaled by ``delta``.

    Returns:
        A 3-tuple ``(dqc, dqr, dqv)``.
    """
    # Each helper is called again wherever the original expression called it,
    # so the sequence of evaluations is unchanged.
    cond_term = condensation(qc)
    auto_term = autoconversion(a1, gamma, qc)
    accr_term = accretion(a2, bc, br, qc, qr)
    dqc = cond_term - auto_term - accr_term

    rain_gain = autoconversion(a1, gamma, qc) + evaporation(e1, d1, e2, d2, qr) + B
    dqr = rain_gain - delta*qr**chi

    dqv = -condensation(qc) - evaporation(e1, d1, e2, d2, qr)
    return dqc, dqr, dqv