Source code for graphax.examples.randoms

import jax
import jax.numpy as jnp
import jax.lax as lax


[docs] def f(v0,v1,v2,v4): v5 = jnp.power(v2,3) v6 = lax.dot_general(v4,v5,dimension_numbers=(((0,), (0,)), ((), ()))) v7 = lax.dot_general(v5,v6,dimension_numbers=(((0,), (1,)), ((), ()))) v8 = jnp.power(v6,2) v9 = lax.dot_general(v0,v8,dimension_numbers=(((0,), (1,)), ((), ()))) v10 = jnp.sqrt(v0) v11 = jnp.arccos(v7) v12 = jnp.transpose(v1,axes=[0, 1]) v13 = lax.dot_general(v11,v8,dimension_numbers=(((1,), (0,)), ((), ()))) _v9 = v9.reshape([1, 1]) v14 = jnp.add(v11,_v9) _v9 = v9.reshape([1, 1]) v15 = jnp.subtract(v14,_v9) v16 = jnp.cos(v9) _v10 = jnp.reshape(v10, [4, 1]) v17 = jnp.add(_v10,v13) v18 = jnp.arcsinh(v17) _v16 = v16.reshape([1, 1]) v19 = jnp.add(v15,_v16) v20 = jnp.negative(v12) v21 = jnp.sqrt(v10) v22 = lax.dot_general(v21,v18,dimension_numbers=(((0,), (1,)), ((), ()))) v23 = jnp.tanh(v12) v24 = jnp.arctan(v19) v25 = jnp.square(v24) v26 = jnp.cos(v22) v27 = jnp.subtract(_v16,v25) v28 = jnp.cosh(v23) v29 = jnp.sum(v20, axis=0) v30 = jnp.transpose(v28,axes=[1, 0]) v31 = jnp.sum(v16, axis=0) v32 = jnp.squeeze(v26) _v29 = jnp.reshape(v29, [3, 1]) v33 = jnp.subtract(_v29,v30) v34 = jnp.divide(v31,v33) _v10 = v10.reshape([4, 1]) v35 = jnp.add(v13,_v10) _v32 = jnp.reshape(v32, [4, 1]) v36 = jnp.multiply(_v32,v35) v37 = jnp.power(v32,2) v38 = jnp.amax(v34, axis=1) v39 = jnp.sinh(v36) v40 = jnp.exp(v27) v41 = lax.slice(v39, start_indices=[3, 3], limit_indices=[4, 4]) v42 = jnp.sum(v37, axis=0) v43 = jnp.add(v42,v38) v44 = lax.dot_general(v32,v40,dimension_numbers=(((0,), (0,)), ((), ()))) v45 = jnp.arcsinh(v43) v46 = lax.dot_general(v24,v15,dimension_numbers=(((1,), (1,)), ((), ()))) v47 = lax.stop_gradient(v41) v48 = lax.logistic(v45) v49 = lax.dot_general(v29,v48,dimension_numbers=(((0,), (0,)), ((), ()))) v50 = lax.dot_general(v22,v21,dimension_numbers=(((0,), (0,)), ((), ()))) v51 = lax.dot_general(v44,v47,dimension_numbers=(((0,), (1,)), ((), ()))) v52 = lax.dot_general(v30,v33,dimension_numbers=(((0,), (0,)), ((), ()))) v53 = jnp.multiply(v46,v50) v54 = jnp.sqrt(v37) v55 = jnp.sqrt(jnp.abs(v51)) v56 = jnp.subtract(v52,v49) v57 = jnp.cosh(v54) v58 = jnp.negative(v37) _v57 = jnp.reshape(v57, [4, 1]) v59 = jnp.add(_v57,v53) v60 = jnp.square(v57) v61 = jnp.sin(v58) v62 = jnp.arctan(v59) v63 = jnp.sum(v62, axis=0) v64 = jnp.amin(v63, axis=0) v65 = jnp.subtract(v64,v55) v66 = jnp.sqrt(v2) v67 = jnp.sum(v56, axis=1) _v61 = jnp.reshape(v61, [4, 1]) v68 = jnp.subtract(_v61,v66) v69 = lax.dot_general(v60,v61,dimension_numbers=(((0,), (0,)), ((), ()))) v70 = lax.dot_general(v68,v62,dimension_numbers=(((0,), (0,)), ((), ()))) _v22 = jnp.reshape(v22, [4, 1]) v71 = jnp.add(_v22,v2) v72 = lax.logistic(v70) v73 = lax.dot_general(v72,v71,dimension_numbers=(((0,), (0,)), ((), ()))) _v37 = v37.reshape([1, 4]) v74 = jnp.divide(v6,_v37) return v74,v65,v67,v69,v73
# fwd: 632, rev: 566, mM: 451
[docs] def g(v0,v1,v2,v3,v4,v5,v6,v7,v8,v9,v10,v11,v12,v13,v14): v15 = jnp.ones(()) v16 = jnp.ones(()) v17 = jnp.arctan2(v15,v3) v18 = jnp.power(v15,v14) v19 = jnp.arctan2(v13,v17) v20 = jnp.power(v17,v5) v21 = jnp.subtract(v9,v8) v22 = jnp.arctan(v10) v23 = jnp.multiply(v21,v0) v24 = jnp.subtract(v22,v18) v25 = jnp.power(v24,v11) v26 = jnp.add(v21,v12) v27 = jnp.exp(v2) v28 = jnp.add(v21,v26) v29 = jnp.power(v16,v17) v30 = jnp.sin(v3) v31 = jnp.divide(v2,v23) v32 = jnp.arctan2(v24,v26) v33 = jnp.subtract(v29,v25) v34 = jnp.arctan(v33) v35 = jnp.multiply(v7,v27) v36 = jnp.tan(v13) v37 = jnp.add(v27,v19) v38 = jnp.divide(v7,v6) v39 = jnp.subtract(v25,v34) v40 = jnp.subtract(v29,v6) v41 = jnp.power(v35,v4) v42 = jnp.cos(v0) v43 = jnp.subtract(v40,v19) v44 = jnp.divide(v41,v7) v45 = jnp.subtract(v40,v8) v46 = jnp.arctan2(v16,v43) v47 = jnp.divide(v8,v30) v48 = jnp.power(v37,v36) v49 = jnp.power(v47,v44) v50 = jnp.arctan2(v45,v40) v51 = jnp.arctan2(v44,v13) v52 = jnp.divide(v4,v34) v53 = jnp.arctan2(v41,v1) v54 = jnp.arctan2(v25,v53) v55 = jnp.add(v12,v51) v56 = jnp.power(v3,v48) v57 = jnp.subtract(v24,v32) v58 = jnp.arctanh(v48) v59 = jnp.arctanh(v54) v60 = jnp.arctan2(v2,v26) v61 = jnp.subtract(v52,v18) v62 = jnp.cos(v38) v63 = jnp.divide(v8,v32) v64 = jnp.arctan2(v42,v46) v65 = jnp.sinh(v36) v66 = jnp.subtract(v61,v50) v67 = jnp.power(v13,v39) v68 = jnp.power(v37,v26) v69 = jnp.subtract(v68,v43) v70 = jnp.log(v65) v71 = jnp.power(v23,v58) v72 = jnp.arctan2(v50,v69) v73 = jnp.divide(v20,v72) v74 = jnp.multiply(v41,v56) v75 = jnp.multiply(v39,v33) v76 = jnp.multiply(v61,v15) v77 = jnp.power(v66,v64) v78 = jnp.arctan2(v53,v25) v79 = jnp.subtract(v59,v60) v80 = jnp.arctan2(v73,v41) v81 = jnp.multiply(v74,v46) v82 = jnp.square(v5) v83 = jnp.arctan2(v28,v62) v84 = jnp.arctan2(v39,v9) v85 = jnp.multiply(v4,v12) v86 = jnp.divide(v57,v56) v87 = jnp.arctan2(v61,v63) v88 = jnp.arcsinh(v14) v89 = jnp.power(v86,v83) v90 = jnp.arcsin(v66) v91 = jnp.subtract(v70,v79) v92 = jnp.arctan2(v37,v65) v93 = jnp.multiply(v67,v77) v94 = jnp.power(v87,v55) v95 = jnp.square(v49) v96 = jnp.divide(v92,v94) v97 = jnp.add(v15,v95) v98 = jnp.divide(v91,v76) v99 = jnp.arctanh(v46) v100 = jnp.multiply(v75,v45) v101 = jnp.divide(v100,v28) v102 = jnp.add(v90,v29) v103 = jnp.arctan2(v80,v84) v104 = jnp.subtract(v20,v91) v105 = jnp.add(v82,v90) v106 = jnp.divide(v102,v103) v107 = jnp.power(v75,v42) v108 = jnp.multiply(v68,v62) v109 = jnp.divide(v102,v1) v110 = jnp.power(v98,v101) v111 = jnp.negative(v100) v112 = jnp.divide(v31,v89) v113 = jnp.divide(v105,v108) v114 = jnp.subtract(v22,v70) v115 = jnp.multiply(v106,v44) v116 = jnp.subtract(v107,v81) v117 = jnp.power(v58,v115) v118 = jnp.subtract(v78,v108) v119 = jnp.tan(v87) v120 = jnp.power(v87,v48) v121 = jnp.add(v105,v38) return v121,v71,v85,v88,v93,v96,v97,v99,v104,v109,v110,v111,v112,v113,v114,v116,v117,v118,v119,v120