import jax
import jax.numpy as jnp
### Examples from the book "Evaluating Derivatives" by Andreas Griewank and Andrea Walther
[docs]
def Simple(x, y):
    """Small two-output example.

    Returns ``(sin(x*y) + x*y, log(sin(x*y)))``; the second output is NaN
    wherever ``sin(x*y)`` is negative.
    """
    product = x * y
    sine = jnp.sin(product)
    first = sine + product
    second = jnp.log(sine)
    return first, second
[docs]
def Lighthouse(nu, gamma, omega, t):
    """Lighthouse example.

    Returns ``(y1, y2)`` with
    ``y1 = nu*tan(omega*t) / (gamma - tan(omega*t))`` and ``y2 = gamma*y1``.
    """
    # tan(omega*t) appears twice in y1; evaluate it once.
    tangent = jnp.tan(omega * t)
    y1 = nu * tangent / (gamma - tangent)
    return y1, gamma * y1
[docs]
def Hole(x, y, z, w):
    """Three outputs built from a shared cosine and exponential.

    With ``s = y*z``, ``d = cos(s + x)`` and ``e = exp(s + w)`` this returns
    ``(d - e, d / e, d * e)``.
    """
    shared = y * z
    cos_term = jnp.cos(shared + x)
    exp_term = jnp.exp(shared + w)
    return cos_term - exp_term, cos_term / exp_term, cos_term * exp_term
### General Relativity
# Parameters read by KerrSenn_metric: presumably the spin `a`, the Sen
# parameter `b` and the mass `M` of the black hole (units not stated — confirm).
a, b, M = 0.5, 0.9, 1.0
[docs]
def KerrSenn_metric(t, r, theta, phi, *, spin=None, dilaton=None, mass=None):
    """Return the non-zero components of the Kerr-Sen metric at one point.

    With ``Sigma = r**2 + 2*b*r + a**2*cos(theta)**2`` and
    ``Delta = r**2 + 2*b*r - 2*M*r + a**2`` the components are::

        g_tt         = -(1 - 2*M*r/Sigma)
        g_rr         = Sigma / Delta
        g_thetatheta = Sigma
        g_phiphi     = (Sigma + a**2*s2 + 2*M*r*a**2*s2/Sigma) * s2
        g_tphi       = -2*M*r*a*s2 / Sigma

    where ``s2 = sin(theta)**2``. The components do not depend on ``t`` or
    ``phi``; those arguments are accepted but unused, so the function takes
    the full coordinate tuple.

    Args:
        t, r, theta, phi: coordinates of the evaluation point.
        spin, dilaton, mass: optional keyword-only overrides for the
            module-level parameters ``a``, ``b`` and ``M``. ``None`` (the
            default) uses the module value. (``b`` is presumably the Sen /
            dilaton parameter, ``b = Q**2/(2*M)`` in the usual convention —
            confirm.)

    Returns:
        Tuple ``(gtt, grr, gthetatheta, gphiphi, gphit)``.
    """
    # The module globals are only looked up when no override is given.
    a_ = a if spin is None else spin
    b_ = b if dilaton is None else dilaton
    M_ = M if mass is None else mass

    sintheta2 = jnp.sin(theta)**2
    sigma = r**2 + 2.*b_*r + a_**2 * jnp.cos(theta)**2
    k = r**2 + 2.*b_*r - 2.*M_*r + a_**2
    gtt = -(1. - 2.*M_*r/sigma)
    grr = sigma/k
    gthetatheta = sigma
    gphiphi = (sigma + a_**2*sintheta2
               + 2.*M_*r*a_**2*sintheta2/sigma)*sintheta2
    # Cross term g_{t phi} = -2 M r a sin^2(theta) / Sigma. The previous code
    # divided by sin^2(theta) instead, which is dimensionally inconsistent
    # with the other components and diverges at the poles (theta = 0, pi).
    gphit = -2.*M_*r*a_*sintheta2/sigma
    return gtt, grr, gthetatheta, gphiphi, gphit
[docs]
def KerrSenn_Jacobian(t, r, theta, phi, *, argnums=0):
    """Forward-mode Jacobian of ``KerrSenn_metric`` at one point.

    Args:
        t, r, theta, phi: coordinates, forwarded unchanged to
            ``KerrSenn_metric``. JAX requires the differentiated arguments
            to be real floating-point values.
        argnums: positional argument index (or tuple of indices) to
            differentiate with respect to, passed straight to
            ``jax.jacfwd``. The default ``0`` keeps the historical behaviour,
            which is the derivative w.r.t. ``t`` only. ``KerrSenn_metric``
            never reads ``t`` or ``phi``, so those derivatives are
            identically zero. Pass e.g. ``argnums=(1, 2)`` for the ``r`` and
            ``theta`` derivatives.

    Returns:
        A 5-tuple matching ``KerrSenn_metric``'s outputs. When ``argnums``
        is a tuple, each entry is itself a tuple with one derivative per
        requested argument.
    """
    return jax.jacfwd(KerrSenn_metric, argnums=argnums)(t, r, theta, phi)
### Thermodynamics and Statistical Mechanics
[docs]
def Helmholtz(x):
    """Elementwise terms ``x_i * log(x_i / (1 - sum(x)))``.

    The normalizer ``1 - sum(x)`` is shared by all entries; the result has
    the same shape as ``x``.
    """
    remainder = 1.0 - jnp.sum(x)
    return x * jnp.log(x / remainder)
[docs]
def FreeEnergy(x):
    """Scalar total: the sum of all ``Helmholtz`` terms for ``x``."""
    terms = Helmholtz(x)
    total = jnp.sum(terms)
    return total
### Meteorology
# Fixed model state (presumably cloud water, rain water and vapour — confirm)
# and scheme constants read by the meteorology functions in this file.
# NOTE(review): qv and gTw are not read by any function visible here.
qc = qr = qv = 1.0
c = S = 1.0
B = 0.0
gTw = 1.0
[docs]
def condensation(qc):
    """Condensation rate ``c * S * qc**0.33333``.

    Uses the module-level constants ``c`` and ``S``. The exponent is the
    truncated decimal ``0.33333`` (close to, but not exactly, a cube root).
    """
    coefficient = c * S
    return coefficient * qc ** 0.33333
[docs]
def accretion(a2, bc, br, qc, qr):
    """Accretion rate ``a2 * qc**bc * qr**br`` (power law in ``qc`` and ``qr``)."""
    scaled_cloud = a2 * qc ** bc
    return scaled_cloud * qr ** br
[docs]
def autoconversion(a1, gamma, qc):
    """Autoconversion rate ``a1 * qc**gamma``."""
    powered = qc ** gamma
    return a1 * powered
[docs]
def evaporation(e1, d1, e2, d2, qr):
    """Evaporation rate as a sum of two power laws: ``e1*qr**d1 + e2*qr**d2``."""
    first = e1 * qr ** d1
    second = e2 * qr ** d2
    return first + second
# Taken from https://gmd.copernicus.org/preprints/gmd-2019-140/gmd-2019-140.pdf
[docs]
def CloudSchemes_step(a1, a2, e1, e2, delta, gamma, bc, br, d1, d2, chi):
    """Return the tendencies ``(dqc, dqr, dqv)`` of the simple cloud scheme.

    The state ``qc``/``qr`` and the constant ``B`` are the module-level
    values; only the process parameters are arguments.

    Args:
        a1, gamma: autoconversion coefficient and exponent.
        a2, bc, br: accretion coefficient and exponents for ``qc`` and ``qr``.
        e1, d1, e2, d2: evaporation coefficients and exponents.
        delta, chi: coefficient and exponent of the ``-delta*qr**chi`` term.

    Returns:
        Tuple ``(dqc, dqr, dqv)``.

    Accretion moves mass from ``qc`` to ``qr``, so it is subtracted from
    ``dqc`` and added to ``dqr``. The previous version only subtracted it,
    so water was lost. Now ``dqc + dqr + dqv`` reduces to
    ``B - delta*qr**chi``.
    """
    # Evaluate each process rate once and reuse it in every tendency.
    cond = condensation(qc)
    auto = autoconversion(a1, gamma, qc)
    accr = accretion(a2, bc, br, qc, qr)
    # NOTE(review): evaporation enters dqr with + and dqv with -; confirm this
    # sign convention against the referenced paper.
    evap = evaporation(e1, d1, e2, d2, qr)

    dqc = cond - auto - accr
    dqr = auto + accr + evap + B - delta*qr**chi
    dqv = -cond - evap
    return dqc, dqr, dqv