graphax package

Subpackages

Submodules

graphax.core module

graphax.core.append_pre_transforms(pre, out, iota)
graphax.core.jacve(fun: Callable, order: Sequence[int] | str, argnums: Sequence[int] = (0,), has_aux: bool = False, count_ops: bool = False, sparse_representation: bool = False) → Callable

Returns a function that computes the Jacobian of fun with respect to the arguments given by argnums, using the vertex elimination method. The vertex elimination order can be specified either as a sequence of integers or as one of the strings “forward” or “fwd” for forward elimination and “reverse” or “rev” for reverse elimination. The forward order corresponds to the elimination order [1, 2, 3, …], while the reverse order corresponds to […, 3, 2, 1]. For a custom order, pass the sequence of integers in the desired order. Setting the count_ops flag to True counts the number of multiplications and additions performed during the elimination process, i.e. during Jacobian accumulation. Setting the sparse_representation flag to True returns the Jacobian in a sparse representation using the SparseTensor class.

Parameters:
  • fun (Callable) – Function to differentiate.

  • order (Union[Sequence[int], str]) – Vertex elimination order. Either pass the desired order directly or specify a string. Allowed options are “forward”, “fwd”, “reverse” and “rev”.

  • argnums (Sequence[int], optional) – Argument numbers to differentiate with respect to. Defaults to (0,).

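  • has_aux (bool, optional) – Whether fun returns a pair (output, auxiliary data), following the has_aux convention of jax.jacrev. Defaults to False.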

  • count_ops (bool, optional) – Count the number of operations during the elimination process. Defaults to False.

  • sparse_representation (bool, optional) – Return the Jacobian in a sparse representation. Defaults to False.

Returns:

The function that returns the Jacobian of fun.

Return type:

Callable
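The string values of order expand to explicit index sequences as described above. The following is a minimal pure-Python sketch of that mapping, assuming the vertices to eliminate are numbered 1 to n. resolve_order is a hypothetical helper, not part of graphax, and the permutation check on custom orders is an added assumption.

```python
from typing import Sequence, Union


def resolve_order(order: Union[Sequence[int], str], num_vertices: int) -> list[int]:
    """Turn an elimination-order spec into an explicit list of vertex indices.

    Hypothetical helper: vertices are assumed to be numbered 1..num_vertices.
    """
    if isinstance(order, str):
        if order in ("forward", "fwd"):
            return list(range(1, num_vertices + 1))
        if order in ("reverse", "rev"):
            return list(range(num_vertices, 0, -1))
        raise ValueError(f"Unknown elimination order: {order!r}")
    # Assumption: a custom order lists every vertex exactly once.
    order = list(order)
    if sorted(order) != list(range(1, num_vertices + 1)):
        raise ValueError("Custom order must be a permutation of 1..num_vertices")
    return order


print(resolve_order("fwd", 4))         # [1, 2, 3, 4]
print(resolve_order("rev", 4))         # [4, 3, 2, 1]
print(resolve_order([2, 4, 1, 3], 4))  # [2, 4, 1, 3]
```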

graphax.core.prepend_post_transforms(post, out, iota)
graphax.core.tree_allclose(tree1, tree2, equal_nan: bool = False) → bool
graphax.core.unload_post_transforms(post, pre, iota)
graphax.core.unload_pre_transforms(post, pre, iota)
graphax.core.vertex_elimination_jaxpr(jaxpr: Jaxpr, order: Sequence[int] | str, consts: Sequence[Literal], *args, has_aux: bool = False, argnums: Sequence[int] = (0,), count_ops: bool = False, sparse_representation: bool = False) → Sequence[Sequence[Array]]

Generates a new jaxpr that computes the Jacobian by vertex elimination, starting from the jaxpr that JAX obtains by tracing the function fun we intend to differentiate. The function operates in three stages:

1.) It creates a computational graph representation amenable to the vertex elimination rule. This is mainly facilitated through _build_graph.

2.) It applies the vertex elimination rule to every vertex following the given order using _eliminate_vertex.

3.) It performs post-processing. This includes applying several Jacobian transformations, densifying sparse tensors and reordering the output values.

Parameters:
  • jaxpr (core.Jaxpr) – The jaxpr we want to differentiate.

  • order (Union[Sequence[int], str]) – Vertex elimination order. Either pass the desired order directly or specify a string. Allowed options are “forward”, “fwd”, “reverse” and “rev”.

  • consts (Sequence[core.Literal]) – The constant arguments of the function.

  • *args (Any) – The input arguments of the function as a flattened PyTree.

  • argnums (Sequence[int], optional) – Argument numbers to differentiate with respect to. Defaults to (0,).

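  • has_aux (bool, optional) – Whether fun returns a pair (output, auxiliary data), following the has_aux convention of jax.jacrev. Defaults to False.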

  • count_ops (bool, optional) – Count the number of operations during the elimination process. Defaults to False.

  • sparse_representation (bool, optional) – Return the Jacobian in a sparse representation. Defaults to False.

Returns:

The Jacobian of the function fun.

The output is a list of lists that corresponds to a flattened PyTree of the actual input parameters; jacve reassembles it into the correct PyTree.

Return type:

Sequence[Sequence[jnp.ndarray]]
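The three stages described for vertex_elimination_jaxpr can be illustrated on a scalar toy problem. This is a pure-Python sketch under stated assumptions, not graphax internals: the equation list EQNS stands in for a jaxpr, every vertex holds a scalar, and build_graph, eliminate and jacobian are hypothetical names. The multiplication and addition counters only mimic count_ops under a simple scalar cost model (one multiply per predecessor/successor pair, one add when the fill-in edge already exists).

```python
import math

# Hypothetical toy "jaxpr": (output vertex, primitive, operand vertices), in order.
EQNS = [
    ("v1", "mul", ("x", "y")),
    ("v2", "sin", ("v1",)),
    ("v3", "cos", ("y",)),
    ("o1", "add", ("v2", "v3")),
    ("o2", "mul", ("v1", "v3")),
]
INPUTS, OUTPUTS = ("x", "y"), ("o1", "o2")

PRIMAL = {"mul": lambda a, b: a * b, "add": lambda a, b: a + b,
          "sin": math.sin, "cos": math.cos}
# Local partial derivative of each primitive w.r.t. each of its operands.
PARTIALS = {"mul": lambda a, b: (b, a), "add": lambda a, b: (1.0, 1.0),
            "sin": lambda a: (math.cos(a),), "cos": lambda a: (-math.sin(a),)}


def build_graph(env):
    """Stage 1: evaluate the equations and label each edge (i, j) with dv_j/dv_i."""
    edges = {}
    for out, prim, args in EQNS:
        vals = [env[a] for a in args]
        env[out] = PRIMAL[prim](*vals)
        for a, partial in zip(args, PARTIALS[prim](*vals)):
            edges[(a, out)] = edges.get((a, out), 0.0) + partial
    return edges


def eliminate(edges, j, counts):
    """Stage 2: eliminate vertex j via c_ik += c_jk * c_ij for every path i -> j -> k."""
    preds = [(i, c) for (i, dst), c in edges.items() if dst == j]
    succs = [(k, c) for (src, k), c in edges.items() if src == j]
    for i, c_ij in preds:
        for k, c_jk in succs:
            counts["mul"] += 1
            if (i, k) in edges:
                counts["add"] += 1  # fill-in onto an existing edge
                edges[(i, k)] += c_jk * c_ij
            else:
                edges[(i, k)] = c_jk * c_ij
    for i, _ in preds:
        del edges[(i, j)]
    for k, _ in succs:
        del edges[(j, k)]


def jacobian(x, y, order):
    """Stage 3: once all intermediates are gone, input -> output edges are the Jacobian."""
    edges = build_graph({"x": x, "y": y})
    counts = {"mul": 0, "add": 0}
    for j in order:
        eliminate(edges, j, counts)
    jac = [[edges.get((i, o), 0.0) for i in INPUTS] for o in OUTPUTS]
    return jac, counts


fwd_jac, fwd_counts = jacobian(0.5, 2.0, ["v1", "v2", "v3"])  # forward order
rev_jac, rev_counts = jacobian(0.5, 2.0, ["v3", "v2", "v1"])  # reverse order
print(fwd_counts, rev_counts)  # {'mul': 8, 'add': 2} {'mul': 7, 'add': 2}
```

Both orders yield the same Jacobian up to floating-point rounding; only the amount of work differs, which is what makes the choice of elimination order worth exposing.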

graphax.equinox_bindings module

graphax.equinox_bindings.eqx_jacve(fun: Callable, order: Sequence[int] | str, argnums: Sequence[int] = (0,), count_ops: bool = False, sparse_representation: bool = False) → Callable
graphax.equinox_bindings.filter_jacve(fun=<object object>, **gradkwargs) → Callable


graphax.primitives module

class graphax.primitives.JacobianTransform(transform: Callable[[SparseTensor, SparseTensor, Array], SparseTensor], inverse_transform: Callable[[SparseTensor, SparseTensor, Array], SparseTensor] = None)

Bases: object

apply(tensor: SparseTensor, iota: Array) → SparseTensor
apply_inverse(tensor: SparseTensor, iota: Array) → SparseTensor
inverse_transform: Callable[[SparseTensor, SparseTensor, Array], SparseTensor]
transform: Callable[[SparseTensor, SparseTensor, Array], SparseTensor]
graphax.primitives.add_elemental_rule(*operands, **params) → tuple[Array, ...]
graphax.primitives.atan2_elemental_rule(*operands, **params) → tuple[Array, ...]
graphax.primitives.broadcast_elemental_rule(primals, **params)
graphax.primitives.concatenate_elemental_rule(primals, **params)
graphax.primitives.convert_element_type_rule(primals, **params)
graphax.primitives.defelemental(primitive, elementalrule)
graphax.primitives.defelemental2(primitive, elementalrule)
graphax.primitives.device_put_elemental_rule(primals, **params)
graphax.primitives.div_elemental_rule(*operands, **params) → tuple[Array, ...]
graphax.primitives.dot_general_elemental_rule(primals, **params)
graphax.primitives.eq_elemental_rule(*operands, **params) → tuple[Array, ...]
graphax.primitives.get_aval_shape(val)
graphax.primitives.get_ndim(arr)
graphax.primitives.get_shape(arr)
graphax.primitives.iota_elemental_rule(primals, **params)
graphax.primitives.make_parallel_jacobian(i, primals, val_out, elemental)
graphax.primitives.max_elemental_rule(*operands, **params) → tuple[Array, ...]
graphax.primitives.min_elemental_rule(*operands, **params) → tuple[Array, ...]
graphax.primitives.mul_elemental_rule(*operands, **params) → tuple[Array, ...]
graphax.primitives.pjit_elemental_rule(primals, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline)
graphax.primitives.pow_elemental_rule(*operands, **params) → tuple[Array, ...]
graphax.primitives.reduce_elemental_rule(primals, agg, **params)
graphax.primitives.reduce_max_elemental_rule(primals, **params)
graphax.primitives.reduce_min_elemental_rule(primals, **params)
graphax.primitives.reduce_sum_elemental_rule(primals, **params)
graphax.primitives.reshape_elemental_rule(primals, **params)
graphax.primitives.select_elemental_rule(primals, **params)
graphax.primitives.slice_elemental_rule(primals, **params)
graphax.primitives.squeeze_elemental_rule(primals, **params)
graphax.primitives.standard_elemental(elementalrule, primitive, primals, **params)
graphax.primitives.standard_elemental2(elementalrule, primitive, primals, **params)
graphax.primitives.stop_gradient_elemental_rule(primals, **params)
graphax.primitives.sub_elemental_rule(*operands, **params) → tuple[Array, ...]
graphax.primitives.transpose_elemental_rule(primals, **params)
graphax.primitives.with_type_promotion(fn: Callable) → Callable
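The binary rules listed above (add, sub, mul, div, pow, atan2, eq, max, min) are annotated as returning tuple[Array, ...], which is consistent with one local partial derivative per operand. Assuming that reading, the following pure-Python sketch registers scalar partials for six of those primitives and checks them against central finite differences. register_rule, ELEMENTAL_RULES, PRIMALS and check_rule are hypothetical names, not graphax API; the actual rules receive JAX arrays plus primitive parameters, which this sketch ignores.

```python
import math

# Hypothetical registry: primitive name -> rule returning one local partial per operand.
ELEMENTAL_RULES = {}


def register_rule(name, rule):
    """Toy stand-in for registering an elemental rule (not graphax's defelemental)."""
    ELEMENTAL_RULES[name] = rule


register_rule("add", lambda x, y: (1.0, 1.0))
register_rule("sub", lambda x, y: (1.0, -1.0))
register_rule("mul", lambda x, y: (y, x))
register_rule("div", lambda x, y: (1.0 / y, -x / y**2))
# d/dy of x**y uses log(x), so this partial is only valid for x > 0.
register_rule("pow", lambda x, y: (y * x ** (y - 1), x**y * math.log(x)))
# math.atan2(a, b) equals atan(a / b) for b > 0.
register_rule("atan2", lambda a, b: (b / (a**2 + b**2), -a / (a**2 + b**2)))


def check_rule(name, fn, x, y, eps=1e-6, tol=1e-6):
    """Compare a rule's partials with central finite differences of fn."""
    dx = (fn(x + eps, y) - fn(x - eps, y)) / (2 * eps)
    dy = (fn(x, y + eps) - fn(x, y - eps)) / (2 * eps)
    px, py = ELEMENTAL_RULES[name](x, y)
    return abs(px - dx) < tol and abs(py - dy) < tol


PRIMALS = {
    "add": lambda x, y: x + y,
    "sub": lambda x, y: x - y,
    "mul": lambda x, y: x * y,
    "div": lambda x, y: x / y,
    "pow": math.pow,
    "atan2": math.atan2,
}
for name, fn in PRIMALS.items():
    print(name, check_rule(name, fn, 1.5, 2.5))
```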

graphax.utils module

Module contents