import jax
import jax.numpy as jnp
import jax.lax as lax
[docs]
def f(v0,v1,v2,v4):
v5 = jnp.power(v2,3)
v6 = lax.dot_general(v4,v5,dimension_numbers=(((0,), (0,)), ((), ())))
v7 = lax.dot_general(v5,v6,dimension_numbers=(((0,), (1,)), ((), ())))
v8 = jnp.power(v6,2)
v9 = lax.dot_general(v0,v8,dimension_numbers=(((0,), (1,)), ((), ())))
v10 = jnp.sqrt(v0)
v11 = jnp.arccos(v7)
v12 = jnp.transpose(v1,axes=[0, 1])
v13 = lax.dot_general(v11,v8,dimension_numbers=(((1,), (0,)), ((), ())))
_v9 = v9.reshape([1, 1])
v14 = jnp.add(v11,_v9)
_v9 = v9.reshape([1, 1])
v15 = jnp.subtract(v14,_v9)
v16 = jnp.cos(v9)
_v10 = jnp.reshape(v10, [4, 1])
v17 = jnp.add(_v10,v13)
v18 = jnp.arcsinh(v17)
_v16 = v16.reshape([1, 1])
v19 = jnp.add(v15,_v16)
v20 = jnp.negative(v12)
v21 = jnp.sqrt(v10)
v22 = lax.dot_general(v21,v18,dimension_numbers=(((0,), (1,)), ((), ())))
v23 = jnp.tanh(v12)
v24 = jnp.arctan(v19)
v25 = jnp.square(v24)
v26 = jnp.cos(v22)
v27 = jnp.subtract(_v16,v25)
v28 = jnp.cosh(v23)
v29 = jnp.sum(v20, axis=0)
v30 = jnp.transpose(v28,axes=[1, 0])
v31 = jnp.sum(v16, axis=0)
v32 = jnp.squeeze(v26)
_v29 = jnp.reshape(v29, [3, 1])
v33 = jnp.subtract(_v29,v30)
v34 = jnp.divide(v31,v33)
_v10 = v10.reshape([4, 1])
v35 = jnp.add(v13,_v10)
_v32 = jnp.reshape(v32, [4, 1])
v36 = jnp.multiply(_v32,v35)
v37 = jnp.power(v32,2)
v38 = jnp.amax(v34, axis=1)
v39 = jnp.sinh(v36)
v40 = jnp.exp(v27)
v41 = lax.slice(v39, start_indices=[3, 3], limit_indices=[4, 4])
v42 = jnp.sum(v37, axis=0)
v43 = jnp.add(v42,v38)
v44 = lax.dot_general(v32,v40,dimension_numbers=(((0,), (0,)), ((), ())))
v45 = jnp.arcsinh(v43)
v46 = lax.dot_general(v24,v15,dimension_numbers=(((1,), (1,)), ((), ())))
v47 = lax.stop_gradient(v41)
v48 = lax.logistic(v45)
v49 = lax.dot_general(v29,v48,dimension_numbers=(((0,), (0,)), ((), ())))
v50 = lax.dot_general(v22,v21,dimension_numbers=(((0,), (0,)), ((), ())))
v51 = lax.dot_general(v44,v47,dimension_numbers=(((0,), (1,)), ((), ())))
v52 = lax.dot_general(v30,v33,dimension_numbers=(((0,), (0,)), ((), ())))
v53 = jnp.multiply(v46,v50)
v54 = jnp.sqrt(v37)
v55 = jnp.sqrt(jnp.abs(v51))
v56 = jnp.subtract(v52,v49)
v57 = jnp.cosh(v54)
v58 = jnp.negative(v37)
_v57 = jnp.reshape(v57, [4, 1])
v59 = jnp.add(_v57,v53)
v60 = jnp.square(v57)
v61 = jnp.sin(v58)
v62 = jnp.arctan(v59)
v63 = jnp.sum(v62, axis=0)
v64 = jnp.amin(v63, axis=0)
v65 = jnp.subtract(v64,v55)
v66 = jnp.sqrt(v2)
v67 = jnp.sum(v56, axis=1)
_v61 = jnp.reshape(v61, [4, 1])
v68 = jnp.subtract(_v61,v66)
v69 = lax.dot_general(v60,v61,dimension_numbers=(((0,), (0,)), ((), ())))
v70 = lax.dot_general(v68,v62,dimension_numbers=(((0,), (0,)), ((), ())))
_v22 = jnp.reshape(v22, [4, 1])
v71 = jnp.add(_v22,v2)
v72 = lax.logistic(v70)
v73 = lax.dot_general(v72,v71,dimension_numbers=(((0,), (0,)), ((), ())))
_v37 = v37.reshape([1, 4])
v74 = jnp.divide(v6,_v37)
return v74,v65,v67,v69,v73
# fwd: 632, rev: 566, mM: 451
[docs]
def g(v0,v1,v2,v3,v4,v5,v6,v7,v8,v9,v10,v11,v12,v13,v14):
v15 = jnp.ones(())
v16 = jnp.ones(())
v17 = jnp.arctan2(v15,v3)
v18 = jnp.power(v15,v14)
v19 = jnp.arctan2(v13,v17)
v20 = jnp.power(v17,v5)
v21 = jnp.subtract(v9,v8)
v22 = jnp.arctan(v10)
v23 = jnp.multiply(v21,v0)
v24 = jnp.subtract(v22,v18)
v25 = jnp.power(v24,v11)
v26 = jnp.add(v21,v12)
v27 = jnp.exp(v2)
v28 = jnp.add(v21,v26)
v29 = jnp.power(v16,v17)
v30 = jnp.sin(v3)
v31 = jnp.divide(v2,v23)
v32 = jnp.arctan2(v24,v26)
v33 = jnp.subtract(v29,v25)
v34 = jnp.arctan(v33)
v35 = jnp.multiply(v7,v27)
v36 = jnp.tan(v13)
v37 = jnp.add(v27,v19)
v38 = jnp.divide(v7,v6)
v39 = jnp.subtract(v25,v34)
v40 = jnp.subtract(v29,v6)
v41 = jnp.power(v35,v4)
v42 = jnp.cos(v0)
v43 = jnp.subtract(v40,v19)
v44 = jnp.divide(v41,v7)
v45 = jnp.subtract(v40,v8)
v46 = jnp.arctan2(v16,v43)
v47 = jnp.divide(v8,v30)
v48 = jnp.power(v37,v36)
v49 = jnp.power(v47,v44)
v50 = jnp.arctan2(v45,v40)
v51 = jnp.arctan2(v44,v13)
v52 = jnp.divide(v4,v34)
v53 = jnp.arctan2(v41,v1)
v54 = jnp.arctan2(v25,v53)
v55 = jnp.add(v12,v51)
v56 = jnp.power(v3,v48)
v57 = jnp.subtract(v24,v32)
v58 = jnp.arctanh(v48)
v59 = jnp.arctanh(v54)
v60 = jnp.arctan2(v2,v26)
v61 = jnp.subtract(v52,v18)
v62 = jnp.cos(v38)
v63 = jnp.divide(v8,v32)
v64 = jnp.arctan2(v42,v46)
v65 = jnp.sinh(v36)
v66 = jnp.subtract(v61,v50)
v67 = jnp.power(v13,v39)
v68 = jnp.power(v37,v26)
v69 = jnp.subtract(v68,v43)
v70 = jnp.log(v65)
v71 = jnp.power(v23,v58)
v72 = jnp.arctan2(v50,v69)
v73 = jnp.divide(v20,v72)
v74 = jnp.multiply(v41,v56)
v75 = jnp.multiply(v39,v33)
v76 = jnp.multiply(v61,v15)
v77 = jnp.power(v66,v64)
v78 = jnp.arctan2(v53,v25)
v79 = jnp.subtract(v59,v60)
v80 = jnp.arctan2(v73,v41)
v81 = jnp.multiply(v74,v46)
v82 = jnp.square(v5)
v83 = jnp.arctan2(v28,v62)
v84 = jnp.arctan2(v39,v9)
v85 = jnp.multiply(v4,v12)
v86 = jnp.divide(v57,v56)
v87 = jnp.arctan2(v61,v63)
v88 = jnp.arcsinh(v14)
v89 = jnp.power(v86,v83)
v90 = jnp.arcsin(v66)
v91 = jnp.subtract(v70,v79)
v92 = jnp.arctan2(v37,v65)
v93 = jnp.multiply(v67,v77)
v94 = jnp.power(v87,v55)
v95 = jnp.square(v49)
v96 = jnp.divide(v92,v94)
v97 = jnp.add(v15,v95)
v98 = jnp.divide(v91,v76)
v99 = jnp.arctanh(v46)
v100 = jnp.multiply(v75,v45)
v101 = jnp.divide(v100,v28)
v102 = jnp.add(v90,v29)
v103 = jnp.arctan2(v80,v84)
v104 = jnp.subtract(v20,v91)
v105 = jnp.add(v82,v90)
v106 = jnp.divide(v102,v103)
v107 = jnp.power(v75,v42)
v108 = jnp.multiply(v68,v62)
v109 = jnp.divide(v102,v1)
v110 = jnp.power(v98,v101)
v111 = jnp.negative(v100)
v112 = jnp.divide(v31,v89)
v113 = jnp.divide(v105,v108)
v114 = jnp.subtract(v22,v70)
v115 = jnp.multiply(v106,v44)
v116 = jnp.subtract(v107,v81)
v117 = jnp.power(v58,v115)
v118 = jnp.subtract(v78,v108)
v119 = jnp.tan(v87)
v120 = jnp.power(v87,v48)
v121 = jnp.add(v105,v38)
return v121,v71,v85,v88,v93,v96,v97,v99,v104,v109,v110,v111,v112,v113,v114,v116,v117,v118,v119,v120