Source code for graphax.equinox_bindings

import functools as ft
from functools import wraps
from typing import Any, Callable, Dict, Union, Sequence

import jax
import jax.tree_util as jtu

from equinox import is_array
from equinox._filters import combine, partition, is_inexact_array
from equinox._module import Module, Partial, module_update_wrapper
from equinox._custom_types import sentinel
from equinox import filter_make_jaxpr

from .core import vertex_elimination_jaxpr


class _JacveWrapper(Module):
    """Callable returned by ``filter_jacve``.

    Calling it evaluates ``eqx_jacve`` on the wrapped function, differentiating
    with respect to every leaf of the first positional argument for which
    ``is_inexact_array`` holds.
    """

    _fun: Callable  # the user function being differentiated
    _gradkwargs: Dict[str, Any]  # extra keyword arguments forwarded to eqx_jacve

    @property
    def __wrapped__(self):
        # Expose the wrapped function, following the functools.wraps convention.
        return self._fun

    def __call__(self, *args, **kwargs):
        """Compute the Jacobian of ``_fun`` at ``args`` / ``kwargs``.

        The first positional argument is the pytree to differentiate. The
        remaining positional and keyword arguments are passed through to
        ``_fun`` unchanged.

        Raises:
            ValueError: if called without any positional argument.
        """
        if not args:
            raise ValueError(
                "filter_jacve-wrapped functions must be called with the object "
                "to differentiate as their first positional argument."
            )
        x = args[0]
        # Indices are counted over the leaves of `x`. Because `x` is the first
        # positional argument, they are also indices into the leaves of all
        # positional arguments, which is the convention eqx_jacve uses.
        argnums = [
            i for i, leaf in enumerate(jtu.tree_leaves(x)) if is_inexact_array(leaf)
        ]
        return eqx_jacve(self._fun, argnums=argnums, **self._gradkwargs)(
            *args, **kwargs
        )

    def __get__(self, instance, owner):
        # Support use as a method: bind the instance as the first argument.
        if instance is None:
            return self
        return Partial(self, instance)


def filter_jacve(fun=sentinel, **gradkwargs) -> Callable:
    """Jacobian transform in the style of ``equinox.filter_grad``.

    Can be used directly as ``filter_jacve(fun, **kw)`` or as a decorator
    factory ``@filter_jacve(**kw)``. The result is a ``_JacveWrapper`` around
    ``fun``. The wrapper differentiates with respect to the inexact-array
    leaves of its first positional argument and forwards ``gradkwargs`` to
    ``eqx_jacve``.

    Raises:
        ValueError: if a non-``None`` ``argnums`` is supplied.
    """
    if fun is sentinel:
        # Called without a function: return a decorator awaiting it.
        return ft.partial(filter_jacve, **gradkwargs)

    # The wrapper computes ``argnums`` itself, so a user-supplied one is
    # rejected up front (an explicit ``None`` is silently discarded).
    if gradkwargs.pop("argnums", None) is not None:
        raise ValueError(
            "`argnums` should not be passed. If you need to differentiate "
            "multiple objects then collect them into a tuple and pass that "
            "as the first argument."
        )

    jacve_wrapper = _JacveWrapper(fun, gradkwargs)
    return module_update_wrapper(jacve_wrapper, fun)
# TODO: the pytree flattening and result repackaging in eqx_jacve below needs an overhaul.
def eqx_jacve(
    fun: Callable,
    order: Union[Sequence[int], str],
    argnums: Sequence[int] = (0,),
    count_ops: bool = False,
    sparse_representation: bool = False,
) -> Callable:
    """Return a function computing Jacobians of ``fun`` via vertex elimination.

    The returned function traces ``fun`` with ``equinox.filter_make_jaxpr`` on
    the arguments it is called with. It then passes the jaxpr, the elimination
    ``order`` and the traced argument values to ``vertex_elimination_jaxpr``.

    Args:
        fun: The function to differentiate.
        order: Vertex elimination order, forwarded unchanged to
            ``vertex_elimination_jaxpr``.
        argnums: Inputs to differentiate with respect to. They are given as
            indices into ``jax.tree_util.tree_leaves(args)``, i.e. the
            flattened leaves of all positional arguments, first argument first
            (the convention ``_JacveWrapper`` uses). Negative indices count
            from the end of that leaf list. Every index must refer to an array
            leaf.
        count_ops: If True, ``vertex_elimination_jaxpr`` is asked to count
            operations, and the result is returned as
            ``(result, op_counts)``.
        sparse_representation: Forwarded unchanged to
            ``vertex_elimination_jaxpr``.

    Returns:
        A function taking the same arguments as ``fun``. Its return value is:

        * ``count_ops=True``: ``(out[0], op_counts)`` when ``fun`` has a
          single output, otherwise ``(tuple(out), op_counts)``, where ``out``
          is what ``vertex_elimination_jaxpr`` produced.
        * ``count_ops=False`` with a single output: a pytree shaped like the
          first positional argument. Each of its leaves listed in ``argnums``
          holds the corresponding entry of ``out[0]``; every other leaf is
          ``None``.
        * ``count_ops=False`` with several outputs: ``out`` unflattened with
          the structure of the positional arguments (see NOTE in the body).

    Raises:
        ValueError: if an entry of ``argnums`` does not refer to an array leaf
            of the positional arguments.
    """

    @wraps(fun)
    def wrapped(*args, **kwargs):
        # TODO Make repackaging work properly with one input value only
        in_tree = jtu.tree_structure(args)
        closed_jaxpr, _, _ = filter_make_jaxpr(fun)(*args, **kwargs)
        jaxpr = closed_jaxpr.jaxpr

        # filter_make_jaxpr traces every array leaf of (args, kwargs), integer
        # arrays included, and holds all other leaves static. The jaxpr's
        # inputs are therefore exactly the array leaves collected here, in
        # this order.
        arg_leaves = jtu.tree_leaves(args)
        traced_pos = {}  # leaf index in arg_leaves -> position among traced inputs
        traced = []
        for j, leaf in enumerate(arg_leaves):
            if is_array(leaf):
                traced_pos[j] = len(traced)
                traced.append(leaf)
        traced.extend(leaf for leaf in jtu.tree_leaves(kwargs) if is_array(leaf))

        # argnums count *all* leaves, but non-array leaves are not jaxpr
        # inputs. Translate to positions among the traced inputs.
        num_leaves = len(arg_leaves)
        leaf_argnums = [j + num_leaves if j < 0 else j for j in argnums]
        invalid = [j for j in leaf_argnums if j not in traced_pos]
        if invalid:
            raise ValueError(
                f"argnums {list(argnums)} must index array leaves of the "
                f"positional arguments; leaf indices {invalid} do not."
            )
        traced_argnums = tuple(traced_pos[j] for j in leaf_argnums)

        out = vertex_elimination_jaxpr(
            jaxpr,
            order,
            closed_jaxpr.literals,
            *traced,
            argnums=traced_argnums,
            count_ops=count_ops,
            sparse_representation=sparse_representation,
        )
        single_output = len(jaxpr.outvars) == 1

        if count_ops:
            out, op_counts = out
            if single_output:
                return out[0], op_counts
            out_tree = jtu.tree_structure(tuple(jaxpr.outvars))
            return jtu.tree_unflatten(out_tree, out), op_counts

        if single_output:
            # Place the k-th entry of out[0] at the k-th differentiated leaf
            # (in leaf order) and None everywhere else. Then keep only the
            # first positional argument's subtree.
            # Assumes argnums are sorted ascending, as _JacveWrapper produces
            # them -- TODO confirm ordering of vertex_elimination_jaxpr output.
            wanted = set(leaf_argnums)
            packed = []
            k = 0
            for j in range(num_leaves):
                if j in wanted:
                    packed.append(out[0][k])
                    k += 1
                else:
                    packed.append(None)
            return jtu.tree_unflatten(in_tree, packed)[0]

        # NOTE(review): this unflattens `out` with the *input* structure. That
        # only succeeds when len(out) equals the number of positional-argument
        # leaves, which looks unintended -- confirm against
        # vertex_elimination_jaxpr's return format.
        return jtu.tree_unflatten(in_tree, out)

    return wrapped