graphax package
Subpackages
- graphax.examples package
- Submodules
- graphax.examples.deep_learning module
- graphax.examples.differential_kinematics module
- graphax.examples.easy module
- graphax.examples.economics module
- graphax.examples.minpack module
- graphax.examples.neuromorphic module
- graphax.examples.randoms module
- graphax.examples.roe module
- Module contents
- graphax.sparse package
Submodules
graphax.core module
- graphax.core.jacve(fun: Callable, order: Sequence[int] | str, argnums: Sequence[int] = (0,), has_aux: bool = False, count_ops: bool = False, sparse_representation: bool = False) → Callable
Computes the Jacobian of fun with respect to the arguments given by argnums using the vertex elimination method. The vertex elimination order can be specified as a sequence of integers or as one of the strings “forward”/“fwd” for forward elimination or “reverse”/“rev” for reverse elimination. The forward order corresponds to the elimination order [1, 2, 3, …], while the reverse order corresponds to […, 3, 2, 1]. For a custom order, pass the sequence of integers in the desired order. Setting the count_ops flag to True counts the number of multiplications and additions performed during the elimination process, i.e. during Jacobian accumulation. Setting the sparse_representation flag to True returns the Jacobian in a sparse representation using the SparseTensor class.
- Parameters:
fun (Callable) – Function to differentiate.
order (Union[Sequence[int], str]) – Vertex elimination order. Either pass the desired order directly or specify a string. Allowed options are “forward”, “fwd”, “reverse” and “rev”.
argnums (Sequence[int], optional) – Argument numbers to differentiate with respect to. Defaults to (0,).
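has_aux (bool, optional) – Whether fun returns a pair (output, auxiliary data), analogous to has_aux in jax.jacrev. Defaults to False.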
count_ops (bool, optional) – Count the number of operations during the elimination process. Defaults to False.
sparse_representation (bool, optional) – Return the Jacobian in a sparse representation. Defaults to False.
- Returns:
The function that returns the Jacobian of fun.
- Return type:
Callable
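The string options for order are shorthands for explicit index lists. The sketch below shows one way to expand them; resolve_order is a hypothetical helper, not part of graphax, and it assumes the vertices to eliminate are numbered 1 to n, as in the description of forward order above. The permutation check on custom orders is also an assumption of this sketch, not documented graphax behavior.

```python
from typing import List, Sequence, Union


def resolve_order(order: Union[Sequence[int], str], num_vertices: int) -> List[int]:
    """Hypothetical helper (not graphax API): expand an order spec into vertex indices.

    Assumes the vertices to eliminate are numbered 1..num_vertices, matching
    the description of forward order as [1, 2, 3, ...].
    """
    # Strings are checked first because a str is itself a Sequence.
    if isinstance(order, str):
        if order in ("forward", "fwd"):
            return list(range(1, num_vertices + 1))
        if order in ("reverse", "rev"):
            return list(range(num_vertices, 0, -1))
        raise ValueError(f"Unknown order string: {order!r}")
    # Custom orders are passed through; this sketch additionally checks that
    # every vertex appears exactly once (an assumption, not documented behavior).
    custom = list(order)
    if sorted(custom) != list(range(1, num_vertices + 1)):
        raise ValueError("custom order must be a permutation of 1..num_vertices")
    return custom


print(resolve_order("fwd", 4))         # [1, 2, 3, 4]
print(resolve_order("rev", 4))         # [4, 3, 2, 1]
print(resolve_order([2, 4, 1, 3], 4))  # [2, 4, 1, 3]
```

Forward and reverse order are just two fixed permutations; any other permutation is a custom order with its own operation count.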
- graphax.core.vertex_elimination_jaxpr(jaxpr: Jaxpr, order: Sequence[int] | str, consts: Sequence[Literal], *args, has_aux: bool = False, argnums: Sequence[int] = (0,), count_ops: bool = False, sparse_representation: bool = False) → Sequence[Sequence[Array]]
Generates a new vertex elimination jaxpr based on the jaxpr that JAX obtains by tracing the function fun we intend to differentiate. The function operates in three stages:
1.) It creates a computational graph representation amenable to the vertex elimination rule. This is mainly done by _build_graph.
2.) It applies the vertex elimination rule to every vertex in the given order using _eliminate_vertex.
3.) It performs post-processing. This includes applying several Jacobian transformations, densifying sparse tensors and reordering the output values.
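The three stages can be illustrated on scalars in plain Python. This is a minimal sketch, not graphax's implementation: the equation format and the helpers build_graph, eliminate_vertex and jacobian_by_elimination are hypothetical, each edge carries a single float instead of a (possibly sparse) tensor, and post-processing is reduced to arranging the entries as an outputs-by-inputs list. The program encodes f(x, y) = sin(x·y) + x and compares forward and reverse order.

```python
import math
from typing import Dict, List, Sequence, Tuple, Union

Vertex = Union[int, str]
Edges = Dict[Tuple[Vertex, Vertex], float]
# Hypothetical jaxpr-like equation: (output vertex, primitive name, input vertices).
Equation = Tuple[Vertex, str, Tuple[Vertex, ...]]

# f(x, y) = sin(x * y) + x; vertices 1 and 2 are intermediates, 3 is the output.
EQNS: List[Equation] = [
    (1, "mul", ("x", "y")),
    (2, "sin", (1,)),
    (3, "add", (2, "x")),
]


def build_graph(eqns: Sequence[Equation], env: Dict[Vertex, float]) -> Edges:
    """Stage 1: evaluate the program and store each local partial on its edge."""
    edges: Edges = {}
    for out, prim, ins in eqns:
        vals = [env[v] for v in ins]
        if prim == "mul":
            env[out] = vals[0] * vals[1]
            partials = [vals[1], vals[0]]
        elif prim == "add":
            env[out] = vals[0] + vals[1]
            partials = [1.0, 1.0]
        elif prim == "sin":
            env[out] = math.sin(vals[0])
            partials = [math.cos(vals[0])]
        else:
            raise NotImplementedError(prim)
        for v, d in zip(ins, partials):
            edges[(v, out)] = edges.get((v, out), 0.0) + d
    return edges


def eliminate_vertex(edges: Edges, j: Vertex, counts: Dict[str, int]) -> None:
    """Stage 2: connect every predecessor of j to every successor, then drop j."""
    preds = [i for (i, k) in edges if k == j]
    succs = [k for (i, k) in edges if i == j]
    for i in preds:
        for k in succs:
            prod = edges[(i, j)] * edges[(j, k)]
            counts["mul"] += 1
            if (i, k) in edges:
                edges[(i, k)] += prod  # an existing edge costs one extra addition
                counts["add"] += 1
            else:
                edges[(i, k)] = prod
    for i in preds:
        del edges[(i, j)]
    for k in succs:
        del edges[(j, k)]


def jacobian_by_elimination(eqns: Sequence[Equation], inputs: List[Vertex],
                            outputs: List[Vertex], order: List[Vertex],
                            env: Dict[Vertex, float]):
    edges = build_graph(eqns, dict(env))
    counts = {"mul": 0, "add": 0}
    for j in order:
        eliminate_vertex(edges, j, counts)
    # Stage 3 (simplified): read the input -> output edges as an outputs x inputs list.
    jac = [[edges.get((i, o), 0.0) for i in inputs] for o in outputs]
    return jac, counts


point = {"x": 2.0, "y": 3.0}
jac_fwd, ops_fwd = jacobian_by_elimination(EQNS, ["x", "y"], [3], [1, 2], point)
jac_rev, ops_rev = jacobian_by_elimination(EQNS, ["x", "y"], [3], [2, 1], point)
print(jac_fwd, ops_fwd)  # ops_fwd == {"mul": 4, "add": 1}
print(jac_rev, ops_rev)  # ops_rev == {"mul": 3, "add": 1}
```

Both orders yield the same Jacobian [[3·cos(6) + 1, 2·cos(6)]], but forward order needs 4 multiplications while reverse order needs only 3 for this single-output function. Differences of this kind are what the count_ops flag is meant to expose.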
- Parameters:
jaxpr (core.Jaxpr) – The jaxpr we want to differentiate.
order (Union[Sequence[int], str]) – Vertex elimination order. Either pass the desired order directly or specify a string. Allowed options are “forward”, “fwd”, “reverse” and “rev”.
consts (Sequence[core.Literal]) – The constant arguments of the function.
*args (Any) – The input arguments of the function as a flattened PyTree.
argnums (Sequence[int], optional) – Argument numbers to differentiate with respect to. Defaults to (0,).
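has_aux (bool, optional) – Whether fun returns a pair (output, auxiliary data), analogous to has_aux in jax.jacrev. Defaults to False.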
count_ops (bool, optional) – Count the number of operations during the elimination process. Defaults to False.
sparse_representation (bool, optional) – Return the Jacobian in a sparse representation. Defaults to False.
- Returns:
- The Jacobian of the function fun.
The output is a list of lists that corresponds to a flattened PyTree of the actual input parameters and is reassembled into the correct PyTree by jacve.
- Return type:
Sequence[Sequence[jnp.ndarray]]
graphax.equinox_bindings module
graphax.primitives module
- class graphax.primitives.JacobianTransform(transform: Callable[[SparseTensor, SparseTensor, Array], SparseTensor], inverse_transform: Callable[[SparseTensor, SparseTensor, Array], SparseTensor] = None)
Bases: object
- apply(tensor: SparseTensor, iota: Array) → SparseTensor
- apply_inverse(tensor: SparseTensor, iota: Array) → SparseTensor
- inverse_transform: Callable[[SparseTensor, SparseTensor, Array], SparseTensor]
- transform: Callable[[SparseTensor, SparseTensor, Array], SparseTensor]