Source code for graphax.primitives

import copy
import inspect

from typing import Callable
from collections import defaultdict
from functools import partial
from jax._src.util import safe_map

import numpy as np

import jax
import jax.lax as lax
import jax.numpy as jnp

import jax._src.core as core
from jax._src.pjit import jit_p

from .sparse.tensor import (
    DenseDimension,
    SparseDimension,
    SparseTensor,
    _materialize_dimensions,
    _swap_back_axes,
)

Array = jax.Array


[docs] def get_ndim(arr): if isinstance(arr, (float, int, jax._src.literals.TypedFloat)): return 0 else: return arr.ndim
[docs] def get_shape(arr): if isinstance(arr, (float, int, jax._src.literals.TypedFloat)): return () else: return arr.shape
def get_aval_shape(val):
    """Return ``val.shape`` for NumPy arrays and ``()`` for anything else."""
    return val.shape if isinstance(val, np.ndarray) else ()
def make_parallel_jacobian(i, primals, val_out, elemental):
    """Wrap the elemental partial of an elementwise primitive in a ``SparseTensor``.

    Args:
        i: Index into ``primals`` of the operand the partial belongs to.
        primals: Input values of the primitive; at most two are supported.
        val_out: Output value of the primitive.
        elemental: Value of the partial of ``val_out`` with respect to
            ``primals[i]``.

    Returns:
        A ``SparseTensor`` built from the output dimensions, the primal
        dimensions and ``elemental``.

    Raises:
        NotImplementedError: If more than two primals are passed.
    """
    # NOTE(review): the dimension constructors are presumably called as
    # DenseDimension(id, size, val_dim) and
    # SparseDimension(id, size, val_dim, other_id) -- confirm against
    # `.sparse.tensor`.
    if len(primals) > 2:
        raise NotImplementedError(f"Parallel Jacobians with {len(primals)} inputs not yet supported!")

    primal = primals[i]
    primal_size = get_ndim(primal)
    out_size = get_ndim(val_out)
    out_shape = get_shape(val_out)  # NOTE(review): not used below

    if primal_size == 0 and out_size == 0:
        # Singletons
        return SparseTensor([], [], elemental)

    if primal_size == 0:
        # Broadcast singleton
        out_dims = [DenseDimension(j, e, j) for j, e in enumerate(val_out.aval.shape)]
        return SparseTensor(out_dims, [], elemental)

    if len(primals) == 2 and get_shape(primal) != get_shape(val_out):
        # Broadcasting case
        out_dims, primal_dims = [], []
        for j, (os, ps) in enumerate(zip(val_out.aval.shape, primal.aval.shape)):
            n_out, n_primal = len(out_dims), len(primal_dims)
            if ps != os:
                # Sizes differ along this axis: a pair of dense dimensions is used.
                val_dim = sum(1 for d in out_dims if d.val_dim is not None)
                out_dims.append(DenseDimension(j, os, val_dim))
                primal_dims.append(
                    DenseDimension(n_out + n_primal + 1, ps, None)
                )
            else:
                # Sizes agree: a pair of sparse dimensions that reference each
                # other through their last argument.
                # NOTE(review): this counts on `d.size` while the branch above
                # counts on `d.val_dim` -- confirm this is intended.
                val_dim = sum(1 for d in out_dims if d.size is not None)
                out_dims.append(
                    SparseDimension(j, os, val_dim, n_out + n_primal + 1)
                )
                primal_dims.append(
                    SparseDimension(n_out + n_primal + 1, os, val_dim, j)
                )

            # A new output dimension was just appended, so every primal
            # dimension created in an earlier iteration moves up by one id and
            # the paired output dimension is updated to follow it.
            for d in primal_dims[:-1]:
                d.id += 1
                if isinstance(d, SparseDimension):
                    out_dims[d.other_id].other_id += 1
        return _swap_back_axes(SparseTensor(out_dims, primal_dims, elemental))

    if len(primals) == 2 and (type(elemental) is float or elemental.size == 1):
        if type(elemental) is not float:
            # TODO dirty quick fix that needs to be properly addressed
            elemental = jnp.squeeze(elemental)
        # A single-element partial gets no value dimension on any axis.
        val_dim_fn = lambda _j: None
    else:
        val_dim_fn = lambda j: j

    shape = primal.aval.shape
    out_dims = [SparseDimension(j, e, val_dim_fn(j), out_size + j) for j, e in enumerate(shape)]
    primal_dims = [SparseDimension(out_size + j, e, val_dim_fn(j), j) for j, e in enumerate(shape)]
    return SparseTensor(out_dims, primal_dims, elemental)
elemental_rules = {} def _filter_params(fn, params): """Filter params to only those accepted by fn, to handle new JAX params gracefully.""" try: sig = inspect.signature(fn) if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()): return params valid = set(sig.parameters.keys()) return {k: v for k, v in params.items() if k in valid} except (ValueError, TypeError): return params
def defelemental(primitive, elementalrule):
    """Register ``elementalrule`` for ``primitive`` in ``elemental_rules``.

    The stored entry is ``standard_elemental`` with ``elementalrule`` and
    ``primitive`` already bound. ``primitive`` must be a ``core.Primitive``
    with a single result.
    """
    assert isinstance(primitive, core.Primitive)
    assert not primitive.multiple_results
    bound_rule = partial(standard_elemental, elementalrule, primitive)
    elemental_rules[primitive] = bound_rule
def standard_elemental(elementalrule, primitive, primals, **params):
    """Evaluate ``primitive`` and the elemental partials of its inputs.

    Args:
        elementalrule: Callable invoked as ``elementalrule(*primals, **kw)``
            where ``kw`` holds the entries of ``params`` that the rule
            declares. It returns one partial, or a tuple with one partial
            per primal.
        primitive: The primitive to bind.
        primals: Input values of the primitive.
        **params: Parameters of the primitive.

    Returns:
        ``(val_out, elementals_out)`` where ``elementals_out`` contains one
        tensor from ``make_parallel_jacobian`` per partial whose primal is
        not a ``float``, ``np.ndarray`` or ``np.float32``.
    """
    assert elementalrule is not None, f"Elemental rule does not exist for {primitive}!"
    val_out = primitive.bind(*primals, **params)

    elementals = elementalrule(*primals, **_filter_params(elementalrule, params))
    if not isinstance(elementals, tuple):
        elementals = (elementals,)

    # Constant operands do not get a partial.
    elementals_out = [
        make_parallel_jacobian(i, primals, val_out, elemental)
        for i, elemental in enumerate(elementals)
        if not isinstance(primals[i], (float, np.ndarray, np.float32))
    ]
    return val_out, elementals_out
# NOTE: Useful for stuff such as exp_p
def defelemental2(primitive, elementalrule):
    """Register ``elementalrule`` for ``primitive`` in ``elemental_rules``.

    The stored entry is ``standard_elemental2`` with ``elementalrule`` and
    ``primitive`` already bound, so the rule also receives the primitive's
    output. ``primitive`` must be a ``core.Primitive`` with a single result.
    """
    assert isinstance(primitive, core.Primitive)
    assert not primitive.multiple_results
    bound_rule = partial(standard_elemental2, elementalrule, primitive)
    elemental_rules[primitive] = bound_rule
def standard_elemental2(elementalrule, primitive, primals, **params):
    """Evaluate ``primitive`` and elemental partials that reuse its output.

    Args:
        elementalrule: Callable invoked as
            ``elementalrule(val_out, *primals, **kw)`` where ``kw`` holds the
            entries of ``params`` that the rule declares. It returns one
            partial, or a tuple with one partial per primal.
        primitive: The primitive to bind.
        primals: Input values of the primitive.
        **params: Parameters of the primitive.

    Returns:
        ``(val_out, elementals_out)`` where ``elementals_out`` contains one
        tensor from ``make_parallel_jacobian`` per partial whose primal is
        not a ``float``, ``np.ndarray`` or ``np.float32``.
    """
    assert elementalrule is not None, f"Elemental rule does not exist for {primitive}!"
    val_out = primitive.bind(*primals, **params)

    filtered_params = _filter_params(elementalrule, params)
    elementals = elementalrule(val_out, *primals, **filtered_params)
    if not isinstance(elementals, tuple):
        elementals = (elementals,)

    # Constant operands do not get a partial.
    elementals_out = [
        make_parallel_jacobian(i, primals, val_out, elemental)
        for i, elemental in enumerate(elementals)
        if not isinstance(primals[i], (float, np.ndarray, np.float32))
    ]
    return val_out, elementals_out
# Define elemental partials
#
# Rules declare only the keyword parameters they actually use. The standard
# wrappers pass the primitive's parameters through `_filter_params`, which
# drops every parameter a rule does not name (e.g. `accuracy`), so the same
# rule works whether or not the primitive carries that parameter.
defelemental(lax.neg_p, lambda x: -jnp.ones_like(x))
defelemental2(lax.abs_p, lambda out, primal: primal / out)  # NOTE: not differentiable at 0!
# The exponent of `integer_pow_p` is the static parameter `y`; the rule has to
# use that exact name to receive it.
defelemental(
    lax.integer_pow_p,
    lambda x, y: jnp.zeros_like(x) if y == 0 else y * lax.integer_pow(x, y - 1),
)
defelemental2(lax.exp_p, lambda out, primal: out)
defelemental(lax.log_p, lambda x: 1.0 / x)
defelemental2(lax.sqrt_p, lambda out, primal: 0.5 / out)
defelemental(lax.square_p, lambda x: 2.0 * x)
defelemental2(lax.logistic_p, lambda out, primal: out * (1.0 - out))
defelemental(lax.log1p_p, lambda x: 1.0 / (1.0 + x))

defelemental(lax.sin_p, lax.cos)
defelemental(lax.asin_p, lambda x: 1.0 / lax.sqrt(1.0 - x**2))
defelemental(lax.cos_p, lambda x: -lax.sin(x))
defelemental(lax.acos_p, lambda x: -1.0 / lax.sqrt(1.0 - x**2))
defelemental2(lax.tan_p, lambda out, primal: 1.0 + out**2)
defelemental(lax.atan_p, lambda x: 1.0 / (1.0 + x**2))

defelemental(lax.sinh_p, lax.cosh)
# d/dx asinh(x) = 1 / sqrt(1 + x^2)
defelemental(lax.asinh_p, lambda x: 1.0 / lax.sqrt(1.0 + x**2))
defelemental(lax.cosh_p, lax.sinh)
defelemental(lax.acosh_p, lambda x: 1.0 / lax.sqrt(x**2 - 1.0))
defelemental2(lax.tanh_p, lambda out, primal: 1.0 - out**2)
defelemental(lax.atanh_p, lambda x: 1.0 / (1.0 - x**2))

defelemental(lax.erf_p, lambda x: 2.0 * lax.exp(-(x**2)) / lax.sqrt(jnp.pi))
def with_type_promotion(fn: Callable) -> Callable:
    """Decorate an elemental rule so all its results share one dtype.

    The wrapped rule is called unchanged; every element of its result is
    then converted to the dtype obtained by promoting all operands.

    Args:
        fn: Elemental rule returning an iterable of partials.

    Returns:
        A function with the signature of ``fn`` that returns a tuple of
        partials converted to the promoted operand dtype.
    """
    # Local import: the module itself only imports `partial` from functools.
    from functools import wraps

    # `wraps` keeps the signature of `fn` visible to `inspect.signature`.
    @wraps(fn)
    def promoted_fn(*operands, **params) -> tuple[Array, ...]:
        res = fn(*operands, **params)
        # `result_type` accepts any number of operands, including Python
        # scalars, whereas `promote_types` takes exactly two dtypes.
        dtype = jnp.result_type(*operands)
        return tuple(lax.convert_element_type(el, dtype) for el in res)

    return promoted_fn
# TODO: This can be optimized significantly. Currently a new array of ones is
# created on every call.
@with_type_promotion
def add_elemental_rule(x, y):
    """Elemental partials of ``x + y``.

    Returns a pair of arrays of ones: the first is shaped like ``y`` and
    the second like ``x``.
    """
    ones_like_y = jnp.ones_like(y)
    ones_like_x = jnp.ones_like(x)
    return (ones_like_y, ones_like_x)


defelemental(lax.add_p, add_elemental_rule)

# TODO this can also be optimized significantly
@with_type_promotion
def sub_elemental_rule(x, y):
    """Elemental partials of ``x - y``.

    Returns ones shaped like ``y`` for the first operand and negative ones
    shaped like ``x`` for the second operand.
    """
    plus_ones = jnp.ones_like(y)
    minus_ones = -jnp.ones_like(x)
    return (plus_ones, minus_ones)


defelemental(lax.sub_p, sub_elemental_rule)
@with_type_promotion
def mul_elemental_rule(x, y):
    """Elemental partials of ``x * y``: ``y`` for ``x`` and ``x`` for ``y``."""
    partial_x, partial_y = y, x
    return (partial_x, partial_y)


defelemental(lax.mul_p, mul_elemental_rule)
@with_type_promotion
def div_elemental_rule(x, y):
    """Elemental partials of ``x / y``: ``1 / y`` and ``-x / y**2``."""
    partial_x = 1.0 / y
    partial_y = -x / y**2
    return (partial_x, partial_y)


defelemental(lax.div_p, div_elemental_rule)
@with_type_promotion
def atan2_elemental_rule(x, y):
    """Elemental partials of ``atan2(x, y)``.

    Returns ``y / (x**2 + y**2)`` for ``x`` and ``-x / (x**2 + y**2)``
    for ``y``.
    """
    squared_norm = x**2 + y**2
    partial_x = y / squared_norm
    partial_y = -x / squared_norm
    return (partial_x, partial_y)


defelemental(lax.atan2_p, atan2_elemental_rule)
@with_type_promotion
def max_elemental_rule(x, y):
    """Elemental partials of ``max(x, y)``.

    The partial with respect to ``x`` is 1 where ``x`` is the strictly
    larger operand and 0 elsewhere; the partial with respect to ``y`` is
    its complement, so ties are attributed to ``y`` exactly as in
    ``min_elemental_rule``.
    """
    x_is_larger = x > y
    return (jnp.where(x_is_larger, 1, 0), jnp.where(x_is_larger, 0, 1))


defelemental(lax.max_p, max_elemental_rule)
@with_type_promotion
def min_elemental_rule(x, y):
    """Elemental partials of ``min(x, y)``.

    The partial for ``x`` is 1 where ``x < y`` and 0 elsewhere; the partial
    for ``y`` is the complement.
    """
    x_is_smaller = x < y
    partial_x = jnp.where(x_is_smaller, 1, 0)
    partial_y = jnp.where(x_is_smaller, 0, 1)
    return (partial_x, partial_y)


defelemental(lax.min_p, min_elemental_rule)
@with_type_promotion
def eq_elemental_rule(x, y):
    """Elemental partials of a comparison: zeros for both operands.

    The first result is shaped like ``y`` and the second like ``x``.
    """
    zeros_like_y = jnp.zeros_like(y)
    zeros_like_x = jnp.zeros_like(x)
    return (zeros_like_y, zeros_like_x)


# The same all-zero rule serves every registered comparison primitive.
defelemental(lax.eq_p, eq_elemental_rule)
defelemental(lax.gt_p, eq_elemental_rule)
defelemental(lax.lt_p, eq_elemental_rule)
def select_elemental_rule(primals, **params):
    """Elemental rule for ``select_n``: all-zero Jacobians for every case.

    Returns the primitive's output together with ``len(primals) - 1``
    tensors. All tensors share the same dimension lists and the same
    ``(size, size)`` array of zeros, where ``size`` is the number of
    elements of the first primal.
    """
    val_out = lax.select_n_p.bind(*primals, **params)

    n_elements = primals[0].size
    out_dims = [SparseDimension(0, 1, n_elements, 1)]
    primal_dims = [SparseDimension(1, 1, n_elements, 0)]
    zero_jacobian = jnp.zeros((n_elements, n_elements))

    # One tensor per case operand; the first primal does not get one.
    tensors = []
    for _ in range(len(primals) - 1):
        tensors.append(SparseTensor(out_dims, primal_dims, zero_jacobian))
    return val_out, tensors


elemental_rules[lax.select_n_p] = select_elemental_rule
@with_type_promotion
def pow_elemental_rule(out, x, y):
    """Elemental partials of ``x ** y`` given the result ``out``.

    Returns ``y * x ** (y - 1)`` for the base and ``log(x) * out`` for the
    exponent.
    """
    partial_base = y * x ** (y - 1)
    partial_exponent = jnp.log(x) * out
    return (partial_base, partial_exponent)


defelemental2(lax.pow_p, pow_elemental_rule)

# TODO Create a general reduce rule with a custom derivative!
def reduce_sum_elemental_rule(primals, **params):
    """Elemental rule for ``reduce_sum``.

    Args:
        primals: Sequence whose first entry is the array being reduced.
        **params: Parameters of ``reduce_sum_p``; ``axes`` may be ``None``
            (all axes), an ``int`` or a collection of ints.

    Returns:
        ``(val_out, [tensor])`` where ``tensor`` pairs every kept axis as
        sparse dimensions, adds one dense primal dimension per reduced axis
        and holds a ``float32`` array of ones shaped like the reduced axes.
    """
    val_out = lax.reduce_sum_p.bind(*primals, **params)
    primal = primals[0]
    axes = params["axes"]

    # The lists must exist before the `axes is None` branch appends to them.
    new_out_dims, new_primal_dims, shape = [], [], []
    if axes is None:
        axes = tuple(range(primal.ndim))
        # NOTE(review): this extra output dimension uses id 0 and val_dim 0,
        # which the first reduced primal dimension below also uses when
        # `val_out` is 0-d -- confirm the intended layout for `axes=None`.
        new_out_dims.append(DenseDimension(0, 1, 0))
    elif isinstance(axes, int):
        axes = (axes,)

    out_ndim = get_ndim(val_out)
    count = 0  # position of the next reduced axis inside `val`
    for i, size in enumerate(get_shape(primal)):
        if i in axes:
            new_primal_dims.append(DenseDimension(out_ndim + i, size, count))
            shape.append(size)
            count += 1
        else:
            out_id = len(new_out_dims)
            new_out_dims.append(SparseDimension(out_id, size, None, out_ndim + i))
            new_primal_dims.append(SparseDimension(out_ndim + i, size, None, out_id))

    val = jnp.ones(shape, dtype=jnp.float32)
    return val_out, [SparseTensor(new_out_dims, new_primal_dims, val)]


elemental_rules[lax.reduce_sum_p] = reduce_sum_elemental_rule
def reduce_max_elemental_rule(primals, **params):
    """Elemental rule for ``reduce_max``.

    Args:
        primals: Sequence whose first entry is the array being reduced.
        **params: Parameters of ``reduce_max_p``; ``axes`` may be ``None``
            (all axes), an ``int`` or a collection of ints.

    Returns:
        ``(val_out, [tensor])``. The tensor's values are 1 where the primal
        equals the maximum along the reduced axes and 0 elsewhere, divided
        by the number of such positions.
    """
    val_out = lax.reduce_max_p.bind(*primals, **params)
    primal = primals[0]
    axes = params["axes"]

    # `shape` is grown into the output shape with size-1 entries re-inserted
    # at the reduced axes, so that `val_out` broadcasts against `primal`.
    shape = list(get_shape(val_out))
    # NOTE(review): `_shape` is filled below but never read.
    new_out_dims, new_primal_dims, _shape = [], [], []
    if axes is None:
        axes = tuple(range(primal.ndim))
        new_out_dims.append(DenseDimension(0, 1, 0, True))
    elif isinstance(axes, int):
        axes = (axes,)

    l = get_ndim(val_out)  # TODO rename l, bad name ...
    for i, size in enumerate(get_shape(primal)):
        if i in axes:
            # Reduced axis: dense primal dimension only.
            shape.insert(i, 1)
            idx = len(new_out_dims) + len(new_primal_dims)
            idx = max(idx, 1) if val_out.ndim > 0 else idx
            new_primal_dims.append(DenseDimension(idx, size, i))
            _shape.append(size)
        else:
            # Kept axis: a pair of sparse dimensions referencing each other.
            ll = len(new_out_dims)
            new_out_dims.append(SparseDimension(ll, size, i, l + i))
            new_primal_dims.append(SparseDimension(l + i, size, i, ll))

    _val_out = val_out.reshape(shape)
    new_val = jnp.where(primal == _val_out, 1, 0)
    # NOTE: Normalization is important if the maximum is not unique
    norm = jnp.sum(new_val, axis=axes, keepdims=True)
    new_val = new_val / norm
    return val_out, [
        _swap_back_axes(SparseTensor(new_out_dims, new_primal_dims, new_val))
    ]


elemental_rules[lax.reduce_max_p] = reduce_max_elemental_rule
def reduce_min_elemental_rule(primals, **params):
    """Elemental rule for ``reduce_min``.

    Args:
        primals: Sequence whose first entry is the array being reduced.
        **params: Parameters of ``reduce_min_p``; ``axes`` may be ``None``
            (all axes), an ``int`` or a collection of ints.

    Returns:
        ``(val_out, [tensor])``. The tensor's values are 1 where the primal
        equals the minimum along the reduced axes and 0 elsewhere, divided
        by the number of such positions.
    """
    val_out = lax.reduce_min_p.bind(*primals, **params)
    primal = primals[0]
    axes = params["axes"]

    # Output shape with size-1 entries re-inserted at the reduced axes, so
    # that the minimum broadcasts against `primal` along the right axes
    # (same treatment as in `reduce_max_elemental_rule`).
    keepdims_shape = list(get_shape(val_out))
    new_out_dims, new_primal_dims = [], []
    if axes is None:
        axes = tuple(range(primal.ndim))
        new_out_dims.append(DenseDimension(0, 1, 0, True))
    elif isinstance(axes, int):
        axes = (axes,)

    out_ndim = get_ndim(val_out)
    for i, size in enumerate(get_shape(primal)):
        if i in axes:
            # Reduced axis: dense primal dimension only.
            keepdims_shape.insert(i, 1)
            idx = len(new_out_dims) + len(new_primal_dims)
            idx = max(idx, 1) if val_out.ndim > 0 else idx
            new_primal_dims.append(DenseDimension(idx, size, i))
        else:
            # Kept axis: a pair of sparse dimensions referencing each other.
            out_id = len(new_out_dims)
            new_out_dims.append(SparseDimension(out_id, size, i, out_ndim + i))
            new_primal_dims.append(SparseDimension(out_ndim + i, size, i, out_id))

    new_val = jnp.where(primal == val_out.reshape(keepdims_shape), 1, 0)
    # NOTE: Normalization is important if the minimum is not unique
    norm = jnp.sum(new_val, axis=axes, keepdims=True)
    new_val = new_val / norm
    return val_out, [
        _swap_back_axes(SparseTensor(new_out_dims, new_primal_dims, new_val))
    ]


elemental_rules[lax.reduce_min_p] = reduce_min_elemental_rule

# first draft unified reduce, TODO: test!
def reduce_elemental_rule(primals, agg, **params):
    """Draft of a unified elemental rule for ``reduce_sum/min/max``.

    Args:
        primals: Sequence whose first entry is the array being reduced.
        agg: One of ``"sum"``, ``"min"`` or ``"max"``; selects the
            ``lax.reduce_{agg}_p`` primitive.
        **params: Parameters of the primitive; ``axes`` may be ``None``
            (all axes), an ``int`` or a collection of ints.

    Returns:
        ``(val_out, [tensor])``. For ``"sum"`` the tensor holds a
        ``float32`` array of ones shaped like the reduced axes; for
        ``"min"``/``"max"`` it holds the normalized indicator of the
        positions that attain the extremum.
    """
    assert agg in {"sum", "min", "max"}, (
        f"{agg} is not one of the valid aggregate functions `sum`, `min`, `max`"
    )
    is_sum = agg == "sum"
    val_out = getattr(lax, f"reduce_{agg}_p").bind(*primals, **params)
    # Grown into the output shape with size-1 entries at the reduced axes
    # (only needed for min/max).
    shape = list(get_shape(val_out))
    primal = primals[0]
    axes = params["axes"]

    new_out_dims, new_primal_dims, _shape = [], [], []
    if axes is None:
        axes = tuple(range(primal.ndim))
        new_out_dims.append(DenseDimension(0, 1, 0))
    elif isinstance(axes, int):
        axes = (axes,)

    l = get_ndim(val_out)
    for i, size in enumerate(get_shape(primal)):
        if i in axes:
            if is_sum:
                idx = l + i
                # The ones-array only has the reduced axes, so the value
                # dimension is the position among them, as in
                # `reduce_sum_elemental_rule`.
                val_dim = len(_shape)
            else:
                shape.insert(i, 1)
                idx = len(new_out_dims) + len(new_primal_dims)
                idx = max(idx, 1) if val_out.ndim > 0 else idx
                val_dim = i
            new_primal_dims.append(DenseDimension(idx, size, val_dim))
            _shape.append(size)
        else:
            ll = len(new_out_dims)
            # Kept axes carry no value dimension for sums.
            val = None if is_sum else i
            new_out_dims.append(SparseDimension(ll, size, val, l + i))
            new_primal_dims.append(SparseDimension(l + i, size, val, ll))

    if is_sum:
        new_val = jnp.ones(_shape, dtype=jnp.float32)
    else:
        _val_out = val_out.reshape(shape)
        new_val = jnp.where(primal == _val_out, 1, 0)
        # Normalize so that ties share the partial.
        norm = jnp.sum(new_val, axis=axes, keepdims=True)
        new_val = new_val / norm

    # NOTE(review): `reduce_sum_elemental_rule` returns its tensor without
    # `_swap_back_axes` -- confirm which behavior is wanted for "sum" before
    # enabling the registrations below.
    return val_out, [
        _swap_back_axes(SparseTensor(new_out_dims, new_primal_dims, new_val))
    ]


# elemental_rules[lax.reduce_sum_p] = partial(reduce_elemental_rule, agg="sum")
# elemental_rules[lax.reduce_min_p] = partial(reduce_elemental_rule, agg="min")
# elemental_rules[lax.reduce_max_p] = partial(reduce_elemental_rule, agg="max")
def dot_general_elemental_rule(primals, **params):
    """Elemental rule for ``dot_general``.

    Args:
        primals: The pair ``(lhs, rhs)``.
        **params: Parameters of ``dot_general_p``; ``dimension_numbers`` is
            read as ``((lhs_contracting, rhs_contracting),
            (lhs_batch, rhs_batch))``.

    Returns:
        ``(val_out, [lhs_tensor, rhs_tensor])``. The tensor for ``lhs``
        holds ``rhs`` as its values and the tensor for ``rhs`` holds
        ``lhs``.
    """
    val_out = lax.dot_general_p.bind(*primals, **params)
    lhs, rhs = primals

    # Which dimensions of the tensors are contracted
    dimension_numbers = params["dimension_numbers"][0]
    batch_dims = params["dimension_numbers"][1]

    # NOTE: Batch dimensions are just treated as SparseDimensions.
    lhs_contracting_dims = dimension_numbers[0]
    rhs_contracting_dims = dimension_numbers[1]

    lhs_batch_dims = batch_dims[0]
    rhs_batch_dims = batch_dims[1]

    lhs_shape = list(get_shape(lhs))
    rhs_shape = list(get_shape(rhs))
    out_shape = list(get_shape(val_out))

    lhs_out_dims, rhs_out_dims = [], []
    lhs_primal_dims, rhs_primal_dims = [], []

    num_out_dims = len(out_shape)

    # NOTE(review): `i`/`ii` pair the k-th contracting/batch axis met while
    # walking `lhs` in axis order with the k-th entry of the `rhs` tuples.
    # That only matches the pairing of `dimension_numbers` if the `lhs`
    # tuples are sorted in increasing order -- confirm.
    i, ii = 0, 0
    batch_dim_counter = 0
    for lid, ld in enumerate(lhs_shape):
        other_lid = lid + len(out_shape)
        if lid in lhs_contracting_dims:
            # Contracting dimension
            dim = rhs_contracting_dims[i]
            lhs_primal_dims.append(DenseDimension(other_lid, rhs_shape[dim], dim))
            i += 1
        else:
            if lid in lhs_batch_dims:
                # If it is a batch dimension, we need to treat it as a SparseDimension
                # with a valid `val_dim`
                dim = rhs_batch_dims[ii]
                ii += 1
                lhs_out_dims.insert(
                    batch_dim_counter, SparseDimension(batch_dim_counter, ld, dim, other_lid)
                )
                lhs_primal_dims.append(
                    SparseDimension(other_lid, ld, dim, batch_dim_counter)
                )
                batch_dim_counter += 1
                # The insertion moved all later output dimensions one
                # position up; bump their ids and the back-references held
                # by the paired primal dimensions.
                for d in lhs_out_dims[batch_dim_counter:]:
                    d.id += 1
                    if isinstance(d, SparseDimension):
                        _d = lhs_primal_dims[d.other_id - num_out_dims]
                        _d.other_id += 1
            else:
                # Otherwise, we can just set `val_dim` to None
                _lid = len(lhs_out_dims)
                lhs_out_dims.append(SparseDimension(_lid, ld, None, other_lid))
                lhs_primal_dims.append(SparseDimension(other_lid, ld, None, _lid))
                # A free `lhs` axis appears as a dense output dimension in
                # the tensor for `rhs`, whose values are `lhs`.
                rhs_out_dims.append(DenseDimension(len(rhs_out_dims), ld, lid))

    # Same walk for `rhs`, with the roles of the two operands exchanged.
    j, jj = 0, 0
    batch_dim_counter = 0
    for rid, rd in enumerate(rhs_shape):
        other_rid = rid + len(out_shape)
        if rid in rhs_contracting_dims:
            # Contracting dimension
            dim = lhs_contracting_dims[j]
            rhs_primal_dims.append(DenseDimension(other_rid, lhs_shape[dim], dim))
            j += 1
        else:
            if rid in rhs_batch_dims:
                # If it is a batch dimension, we need to treat it as a
                # SparseDimension with a valid `val_dim`
                dim = lhs_batch_dims[jj]
                jj += 1
                rhs_out_dims.insert(
                    batch_dim_counter, SparseDimension(batch_dim_counter, rd, dim, other_rid)
                )
                rhs_primal_dims.append(
                    SparseDimension(other_rid, rd, dim, batch_dim_counter)
                )
                batch_dim_counter += 1
                for d in rhs_out_dims[batch_dim_counter:]:
                    d.id += 1
                    if isinstance(d, SparseDimension):
                        _d = rhs_primal_dims[d.other_id - num_out_dims]
                        _d.other_id += 1
            else:
                # Otherwise, we can just set `val_dim` to None
                _rid = len(rhs_out_dims)
                rhs_out_dims.append(SparseDimension(_rid, rd, None, other_rid))
                rhs_primal_dims.append(SparseDimension(other_rid, rd, None, _rid))
                lhs_out_dims.append(DenseDimension(len(lhs_out_dims), rd, rid))

    lhs_tensor = SparseTensor(lhs_out_dims, lhs_primal_dims, rhs)
    rhs_tensor = SparseTensor(rhs_out_dims, rhs_primal_dims, lhs)
    lhs_tensor = _swap_back_axes(lhs_tensor)
    rhs_tensor = _swap_back_axes(rhs_tensor)
    return val_out, [lhs_tensor, rhs_tensor]


elemental_rules[lax.dot_general_p] = dot_general_elemental_rule
def iota_elemental_rule(primals, **params):
    """Elemental rule for ``iota``: bind the primitive, emit no partials."""
    return lax.iota_p.bind(*primals, **params), []


elemental_rules[lax.iota_p] = iota_elemental_rule
def device_put_elemental_rule(primals, **params):
    """Elemental rule for ``device_put``: bind the primitive, emit no partials."""
    return lax.device_put_p.bind(*primals, **params), []


elemental_rules[lax.device_put_p] = device_put_elemental_rule
def stop_gradient_elemental_rule(primals, **params):
    """Elemental rule for ``stop_gradient``: bind the primitive, emit no partials."""
    return lax.stop_gradient_p.bind(*primals, **params), []


elemental_rules[lax.stop_gradient_p] = stop_gradient_elemental_rule


### Transforms
Transform = Callable[[SparseTensor, SparseTensor, jnp.ndarray], SparseTensor]
class JacobianTransform:
    """Pair of a transform and an optional inverse transform.

    Both callables are invoked with ``(tensor, iota)``; a missing callable
    raises ``NotImplementedError`` when it is applied.
    """

    transform: Transform
    inverse_transform: Transform

    def __init__(
        self, transform: Transform, inverse_transform: Transform = None
    ) -> None:
        self.transform = transform
        self.inverse_transform = inverse_transform

    def __repr__(self) -> str:
        forward = f"transform={self.transform}"
        backward = f"inverse_transform={self.inverse_transform}"
        return f"JacobianTransform({forward}, {backward})"

    def apply(self, tensor: SparseTensor, iota: jnp.ndarray) -> SparseTensor:
        """Return ``self.transform(tensor, iota)``."""
        if self.transform is not None:
            return self.transform(tensor, iota)
        raise NotImplementedError("Transform not implemented!")

    def apply_inverse(self, tensor: SparseTensor, iota: jnp.ndarray) -> SparseTensor:
        """Return ``self.inverse_transform(tensor, iota)``."""
        if self.inverse_transform is not None:
            return self.inverse_transform(tensor, iota)
        raise NotImplementedError("Inverse transform not implemented!")
def _inverse_permutation(permutation): inverse = [0] * len(permutation) for i, p in enumerate(permutation): inverse[p] = i return inverse # NOTE: Proper pjit and custom grad implementation only possible with a proper tracing system def _trace_subjaxpr(jaxpr, args, consts): env = {} # env stores the primal value associated with the core.Var object graph = defaultdict(lambda: defaultdict()) # Input connectivity transpose_graph = defaultdict(lambda: defaultdict()) # Output connectivity vo_vertices = set() # contains all intermediate and output vertices counter = 1 # vertex id counter var_id = {} # associates every application of a JaxprEqn with a unique integer # identifier that is later used when using the vertex elimination order. # NOTE: This only works well if the output is a single value. # It is ill-defined when having functions with more than one output!. # Reads variable and corresponding traced shaped array def read(var): if isinstance(var, core.Literal): return var.val return env[var] # Adds new variable and corresponding traced shaped array def write(var, val): env[var] = val # Writes a new elemental partial to the graph and transpose_graph def write_elemental(outvar, invar, val): # _checkify_tensor(val) if isinstance(invar, core.Var): graph[invar][outvar] = val transpose_graph[outvar][invar] = val safe_map(write, jaxpr.invars, args) safe_map(write, jaxpr.constvars, consts) # NOTE: this is essentially the tracing part. Probably should write a proper # tracing system with lift etc. 
for better compatibility with JAX # Loop though elemental partials and create an abstract representation of # the computational graph for eqn in jaxpr.eqns: # Treatment of intermediate variables that are also output variables for outvar in eqn.outvars: if isinstance(outvar, core.Var) and outvar not in var_id.keys(): var_id[outvar] = counter counter += 1 for invar in eqn.invars: if invar in jaxpr._outvars: vertex = var_id[invar] vo_vertices.add(vertex) invals = safe_map(read, eqn.invars) if eqn.primitive not in elemental_rules: raise NotImplementedError( f"{eqn.primitive} does not have registered elemental partial." ) cce = elemental_rules.get(eqn.primitive) primal_outvals, elemental_outvals = cce(invals, **eqn.params) if eqn.primitive.multiple_results: safe_map(write, eqn.outvars, primal_outvals) else: safe_map(write, eqn.outvars, [primal_outvals]) invars = [invar for invar in eqn.invars if isinstance(invar, core.Var)] # NOTE: Currently only able to treat one output variable _write_elemental = partial(write_elemental, eqn.outvars[0]) if len(invars) == len(elemental_outvals): safe_map(_write_elemental, invars, elemental_outvals) return eqn.outvars, graph, transpose_graph, vo_vertices # TODO: this is a very ugly hack that treats pjit as a normal primitive with a stop_grad
[docs] def pjit_elemental_rule( primals, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, ): # TODO Jamie: How do we handle the gradients here? # jaxpr_cce = cce_core.cce_jaxpr(jaxpr) # print("pjit primals", primals) # print("pjit zero", zero_elementals) # print("pjit jaxpr", jaxpr) # outs, elementals, subgraph, transpose_subgraph, vo_vertices = _trace_subjaxpr(jaxpr.jaxpr, primals, ()) # print("### pjit outs", outs) # print("### pjit elementals", elementals) # print("### pjit jaxpr", jaxpr) outputs = jit_p.bind( *primals, jaxpr=jaxpr, in_shardings=(*in_shardings,), out_shardings=(*out_shardings,), in_layouts=(*in_layouts,), out_layouts=(*out_layouts,), resource_env=resource_env, donated_invars=(*donated_invars,), name=name, keep_unused=keep_unused, inline=inline, ) # print("pjit val_out:", outputs) out_primals = outputs return out_primals, []
# Register the pjit rule under its primitive in the `elemental_rules` table.
elemental_rules[jit_p] = pjit_elemental_rule


# Should work for high-dimensional stuff
def transpose_elemental_rule(primals, **params):
    """Evaluate ``lax.transpose_p`` and describe its Jacobian as a transform.

    Instead of a concrete Jacobian, the rule returns a ``SparseTensor`` with no
    dimensions and no value that carries a ``JacobianTransform``. The forward
    transform permutes the ``out_dims`` of the tensor it is applied to; the
    inverse transform permutes the ``primal_dims`` with the inverse permutation.

    Both transforms operate on deep copies of the dimension objects, so the
    tensor they receive is left unmodified.

    Args:
        primals: Operands of the transpose primitive.
        **params: Primitive parameters; ``params["permutation"]`` is read.

    Returns:
        ``(val_out, [tensor])`` where ``tensor`` carries the transform.
    """
    val_out = lax.transpose_p.bind(*primals, **params)
    permutation = params["permutation"]

    def transpose_transform(pre, iota):
        # Copy first: the loop below rewrites `id` / `other_id` in place.
        old_out_dims = copy.deepcopy(pre.out_dims)
        new_primal_dims = list(copy.deepcopy(pre.primal_dims))
        num_out_dims = len(old_out_dims)

        new_out_dims = []
        for new_id, p in enumerate(permutation):
            out_dim = old_out_dims[p]
            out_dim.id = new_id
            if isinstance(out_dim, SparseDimension):
                # Keep the partner primal dimension pointing at the moved axis.
                partner = new_primal_dims[out_dim.other_id - num_out_dims]
                partner.other_id = new_id
            new_out_dims.append(out_dim)

        return _swap_back_axes(
            SparseTensor(new_out_dims, new_primal_dims, pre.val)
        )

    def inverse_transpose_transform(post, iota):
        # Copy first: the loop below rewrites `id` / `other_id` in place.
        new_out_dims = list(copy.deepcopy(post.out_dims))
        old_primal_dims = copy.deepcopy(post.primal_dims)

        # NOTE: the original author flagged this implementation as faulty;
        # only the in-place mutation of `post` has been addressed here.
        inv_permutation = _inverse_permutation(permutation)

        new_primal_dims = []
        # Primal ids continue after the out-dimension ids.
        for new_id, p in enumerate(inv_permutation, start=len(new_out_dims)):
            primal_dim = old_primal_dims[p]
            primal_dim.id = new_id
            if isinstance(primal_dim, SparseDimension):
                new_out_dims[primal_dim.other_id].other_id = new_id
            new_primal_dims.append(primal_dim)

        return _swap_back_axes(
            SparseTensor(new_out_dims, new_primal_dims, post.val)
        )

    transform = JacobianTransform(transpose_transform, inverse_transpose_transform)
    return val_out, [SparseTensor([], [], None, [transform])]
# Register the transpose rule under its primitive in the `elemental_rules` table.
elemental_rules[lax.transpose_p] = transpose_elemental_rule
def reshape_elemental_rule(primals, **params):
    """Evaluate ``lax.reshape_p`` and describe its Jacobian as a transform.

    Both directions densify the tensor they receive via ``tensor.dense(iota)``
    and reshape the dense value; every resulting dimension is a
    ``DenseDimension``.

    Args:
        primals: Operands of the reshape primitive; ``primals[0].shape`` is
            used by the inverse transform.
        **params: Primitive parameters, forwarded to ``bind``.

    Returns:
        ``(val_out, [tensor])`` where ``tensor`` is a ``SparseTensor`` with no
        dimensions and no value that carries the transform.
    """
    val_out = lax.reshape_p.bind(*primals, **params)

    # TODO: dimensional collapse is not covered here!
    # Implement sparsity-aware version for significant speedup!
    def _dense_tensor(full_val, out_sizes, primal_sizes):
        # Dimension ids and val_dims coincide: out axes first, primal axes after.
        num_out = len(out_sizes)
        out_dims = [DenseDimension(i, s, i) for i, s in enumerate(out_sizes)]
        primal_dims = [
            DenseDimension(num_out + j, s, num_out + j)
            for j, s in enumerate(primal_sizes)
        ]
        reshaped = full_val.reshape([*out_sizes, *primal_sizes])
        return SparseTensor(out_dims, primal_dims, reshaped)

    def reshape_transform(pre, iota):
        # NOTE array is not correctly materialized sometimes!
        dense_val = pre.dense(iota)
        return _dense_tensor(
            dense_val, list(val_out.shape), [d.size for d in pre.primal_dims]
        )

    def inverse_reshape_transform(post, iota):
        dense_val = post.dense(iota)
        return _dense_tensor(
            dense_val, [d.size for d in post.out_dims], list(primals[0].shape)
        )

    transform = JacobianTransform(reshape_transform, inverse_reshape_transform)
    return val_out, [SparseTensor([], [], None, [transform])]
# Register the reshape rule under its primitive in the `elemental_rules` table.
elemental_rules[lax.reshape_p] = reshape_elemental_rule
def slice_elemental_rule(primals, **params):
    """Evaluate ``lax.slice_p`` and describe its Jacobian as a transform.

    The slice primitive is written in such a way that it just densifies the
    Jacobian and then slices it. This is not efficient and there might be ways
    to make this more efficient by checking if sparse dimensions are untouched
    and how this changes the Jacobian.

    Args:
        primals: Operands of the slice primitive; ``primals[0].shape`` is used
            by the inverse transform.
        **params: Primitive parameters; ``start_indices`` and ``limit_indices``
            are read.

    Returns:
        ``(val_out, [tensor])`` where ``tensor`` is a ``SparseTensor`` with no
        dimensions and no value that carries the transform.
    """
    val_out = lax.slice_p.bind(*primals, **params)

    def _dense_dims(out_sizes, primal_sizes):
        # Dimension ids and val_dims coincide: out axes first, primal axes after.
        num_out = len(out_sizes)
        out_dims = [DenseDimension(i, s, i) for i, s in enumerate(out_sizes)]
        primal_dims = [
            DenseDimension(num_out + j, s, num_out + j)
            for j, s in enumerate(primal_sizes)
        ]
        return out_dims, primal_dims

    def slice_transform(pre, iota):
        starts = list(params["start_indices"])
        limits = list(params["limit_indices"])
        dense_val = pre.dense(iota)

        primal_sizes = [d.size for d in pre.primal_dims]
        out_dims, primal_dims = _dense_dims(list(val_out.shape), primal_sizes)

        # The primal axes are taken in full: start at 0, stop at their size.
        starts.extend(0 for _ in primal_sizes)
        limits.extend(primal_sizes)

        sliced = lax.slice(dense_val, starts, limits)
        return SparseTensor(out_dims, primal_dims, sliced)

    def inverse_slice_transform(post, iota):
        starts = list(params["start_indices"])
        dense_val = post.dense(iota)

        out_sizes = [d.size for d in post.out_dims]
        primal_sizes = list(primals[0].shape)
        out_dims, primal_dims = _dense_dims(out_sizes, primal_sizes)

        # Scatter the dense value into a zero tensor: offset 0 along the out
        # axes and `start_indices` along the primal axes.
        zeros = jnp.zeros(out_sizes + primal_sizes)
        all_axes = tuple(range(zeros.ndim))
        scatter_dims = lax.ScatterDimensionNumbers(all_axes, (), all_axes)
        scatter_indices = jnp.concatenate(
            [
                jnp.zeros(len(out_sizes), dtype=jnp.int32),
                jnp.array(starts, dtype=jnp.int32),
            ]
        )
        scattered = lax.scatter(zeros, scatter_indices, dense_val, scatter_dims)
        return SparseTensor(out_dims, primal_dims, scattered)

    transform = JacobianTransform(slice_transform, inverse_slice_transform)
    return val_out, [SparseTensor([], [], None, [transform])]
# Register the slice rule under its primitive in the `elemental_rules` table.
elemental_rules[lax.slice_p] = slice_elemental_rule
def broadcast_elemental_rule(primals, **params):
    """Evaluate ``lax.broadcast_in_dim_p`` and describe its Jacobian as a transform.

    The forward transform inserts a size-1 ``DenseDimension`` into the out
    dimensions for every output axis that is not listed in
    ``broadcast_dimensions``, shifts the ids / val_dims of the dimensions that
    follow, and broadcasts the value accordingly. The inverse transform removes
    the corresponding primal dimensions and squeezes (all removed axes of size
    1) or sums over the removed value axes.

    Args:
        primals: Operands of the broadcast primitive.
        **params: Primitive parameters; ``broadcast_dimensions`` and ``shape``
            are read.

    Returns:
        ``(val_out, [tensor])`` where ``tensor`` is a ``SparseTensor`` with no
        dimensions and no value that carries the transform.
    """
    val_out = lax.broadcast_in_dim_p.bind(*primals, **params)
    dims = sorted(params["broadcast_dimensions"])
    shape = params["shape"]

    def broadcast_transform(pre, iota):
        # Work on copies; the bookkeeping below mutates the dimension objects.
        new_out_dims = list(copy.deepcopy(pre.out_dims))
        new_primal_dims = list(copy.deepcopy(pre.primal_dims))

        # val_dims of the newly inserted (broadcast) axes
        non_broadcast_dims = []
        counter = 0
        l = len(pre.out_dims)
        # Output axes that do not come from an operand axis
        insert_dims = [i for i, s in enumerate(shape) if i not in dims]
        for dim in insert_dims:
            val_dim = sum(1 for d in new_out_dims[:dim+counter] if d.val_dim is not None)
            non_broadcast_dims.append(val_dim)
            new_out_dims.insert(dim, DenseDimension(dim + counter, 1, val_dim))
            counter += 1

            # Shift the out dimensions behind the inserted one
            for d in new_out_dims[dim + counter :]:
                d.id += 1
                if d.val_dim is not None:
                    d.val_dim += 1
                if isinstance(d, SparseDimension):
                    _d = new_primal_dims[d.other_id - l]
                    # _d.id += 1
                    d.other_id += 1
                    _d.other_id += 1
                    if _d.val_dim is not None:
                        _d.val_dim += 1

            # Every primal dimension moves one id further back
            for d in new_primal_dims:
                d.id += 1
                if isinstance(d, DenseDimension):
                    if d.val_dim is not None:
                        d.val_dim += 1
                else:
                    _d = new_out_dims[d.other_id]
                    # if _d.id > dim + counter:
                    #     _d.id += 1
                    if d.other_id < dim:
                        _d.other_id += 1

        # Shape of the broadcast value: materialized out dims, then dense primal dims
        broadcast_shape = [d.size for d in new_out_dims if d.val_dim is not None]
        broadcast_shape += [
            d.size
            for d in new_primal_dims
            if d.val_dim is not None and isinstance(d, DenseDimension)
        ]
        # Value axes that already exist in `pre.val`
        broadcast_dims = [
            d.val_dim for d in new_out_dims if d.val_dim not in non_broadcast_dims
        ]
        broadcast_dims += [
            d.val_dim
            for d in new_primal_dims
            if d.val_dim not in non_broadcast_dims and isinstance(d, DenseDimension)
        ]
        broadcast_dims = [d for d in broadcast_dims if d is not None]
        # TODO check this quick hack in the second argument of the or!
        if len(broadcast_dims) > 0 or pre.val.shape == ():
            new_val = lax.broadcast_in_dim(
                pre.val, shape=broadcast_shape, broadcast_dimensions=broadcast_dims
            )
        else:
            new_val = pre.val
        return SparseTensor(new_out_dims, new_primal_dims, new_val)

    def inverse_broadcast_transform(post, iota):
        # Output axes that were created by the broadcast
        rm_dims = [d for d in range(val_out.ndim) if d not in dims]
        new_out_dims = list(copy.deepcopy(post.out_dims))
        new_primal_dims = list(copy.deepcopy(post.primal_dims))
        # NOTE(review): `primal_shape` is not read anywhere below.
        primal_shape = [d.size for d in post.primal_dims]

        # Value axes to squeeze / sum away
        _rm_dims = []
        counter = 0
        for dim in rm_dims:
            if new_primal_dims[dim-counter].val_dim is not None:
                _rm_dims.append(new_primal_dims[dim-counter].val_dim)
            if isinstance(new_primal_dims[dim-counter], DenseDimension):
                has_smaller_dims = sum(1 for d in new_primal_dims[:dim+1] if d.val_dim is not None) > 0
                old_val_dim = new_primal_dims[dim-counter].val_dim
                del new_primal_dims[dim-counter]

                # Shift the primal dimensions behind the removed one
                for d in new_primal_dims[dim-counter:]:
                    d.id -= 1
                    if d.val_dim is not None and old_val_dim is not None:
                        d.val_dim -= 1
                    if isinstance(d, SparseDimension):
                        _d = new_out_dims[d.other_id]
                        _d.other_id -= 1
            else:
                # Removing a sparse primal dimension: its partner out dimension
                # becomes a DenseDimension without a value axis.
                # NOTE(review): `id` shadows the builtin of the same name.
                id = new_primal_dims[dim - counter].id
                other_id = new_primal_dims[dim - counter].other_id
                old_dim = new_out_dims[other_id]
                new_out_dims[other_id] = DenseDimension(old_dim.id, old_dim.size, None)
                has_smaller_dims = (
                    sum(
                        [1 for d in new_primal_dims[: dim + 1] if d.val_dim is not None]
                    )
                    > 0
                )
                # NOTE(review): this branch deletes index `dim`, whereas the
                # lookups above use `dim - counter` — confirm this is intended.
                del new_primal_dims[dim]

                for d in new_out_dims + new_primal_dims:
                    if d.id > id:
                        d.id -= 1
                        if isinstance(d, SparseDimension):
                            _d = new_out_dims[d.other_id]
                            _d.other_id -= 1
                            if d.val_dim is not None and has_smaller_dims:
                                d.val_dim -= 1
                                _d.val_dim -= 1
                        else:
                            if d.val_dim is not None and has_smaller_dims:
                                d.val_dim -= 1
            counter += 1

        new_out_dims = tuple(new_out_dims)
        new_primal_dims = tuple(new_primal_dims)
        if len(_rm_dims) > 0:
            # Size-1 axes can simply be dropped; larger ones are summed over.
            if all([post.val.shape[d] == 1 for d in _rm_dims]):
                new_val = jnp.squeeze(post.val, axis=_rm_dims)
            else:
                new_val = jnp.sum(post.val, axis=_rm_dims)
        else:
            new_val = post.val
        return SparseTensor(new_out_dims, new_primal_dims, new_val)

    transform = JacobianTransform(broadcast_transform, inverse_broadcast_transform)
    return val_out, [SparseTensor([], [], None, [transform])]
# Register the broadcast_in_dim rule under its primitive in the `elemental_rules` table.
elemental_rules[lax.broadcast_in_dim_p] = broadcast_elemental_rule
def squeeze_elemental_rule(primals, **params):
    """Evaluate ``lax.squeeze_p`` and describe its Jacobian as a transform.

    The forward transform deletes the out dimensions whose ids are listed in
    ``params["dimensions"]``, renumbers the remaining ids / val_dims and
    squeezes the matching value axes. The inverse transform inserts size-1
    ``DenseDimension`` entries into the primal dimensions and expands the value.

    Args:
        primals: Operands of the squeeze primitive.
        **params: Primitive parameters; ``dimensions`` is read.

    Returns:
        ``(val_out, [tensor])`` where ``tensor`` is a ``SparseTensor`` with no
        dimensions and no value that carries the transform.
    """
    # NOTE: squeeze is basically just the inverse operation to broadcast_in_dim
    # since it just adds a DenseDimension of size 1
    val_out = lax.squeeze_p.bind(*primals, **params)

    def squeeze_transform(pre, iota):
        dims = sorted(params["dimensions"])
        # Work on copies; the bookkeeping below mutates the dimension objects.
        new_out_dims = list(copy.deepcopy(pre.out_dims))
        new_primal_dims = list(copy.deepcopy(pre.primal_dims))

        # Value axes that have to be squeezed out of `pre.val`
        squeeze_dims = []
        # NOTE(review): `counter` is incremented but never read.
        counter = 0
        # NOTE(review): `id` shadows the builtin of the same name.
        for id in dims:
            idx = [j for j, d in enumerate(new_out_dims) if d.id == id][0]
            val_dim = new_out_dims[idx].val_dim
            squeeze_dims.append(val_dim)
            if isinstance(new_out_dims[idx], SparseDimension):
                # The partner primal dimension loses its sparse link and
                # becomes a DenseDimension without a value axis.
                def _check(d, id):
                    if isinstance(d, SparseDimension):
                        return d.other_id == id
                    else:
                        return False
                other_idx = [j for j, d in enumerate(new_primal_dims) if _check(d, id)][0]
                other_dim = new_primal_dims[other_idx]
                new_primal_dims[other_idx] = DenseDimension(
                    other_dim.id, other_dim.size, None
                )
            del new_out_dims[idx]
            counter += 1

        # Renumber ids and val_dims so that they are contiguous again. The
        # lists below capture the old numbering; `.index` yields the new one.
        out_ids = [d.id for d in new_out_dims]
        primal_ids = [d.id for d in new_primal_dims]
        new_val_dims = [d.val_dim for d in new_out_dims if d.val_dim is not None]
        new_val_dims += [d.val_dim for d in new_primal_dims if isinstance(d, DenseDimension) and d.val_dim is not None]

        for d in new_out_dims:
            d.id = out_ids.index(d.id)
            if d.val_dim is not None:
                d.val_dim = new_val_dims.index(d.val_dim)
            if isinstance(d, SparseDimension):
                d.other_id = len(new_out_dims) + primal_ids.index(d.other_id)

        for d in new_primal_dims:
            d.id = len(new_out_dims) + primal_ids.index(d.id)
            if d.val_dim is not None:
                d.val_dim = new_val_dims.index(d.val_dim)
            if isinstance(d, SparseDimension):
                d.other_id = out_ids.index(d.other_id)

        # Dimensions without a value axis need no squeeze of `pre.val`.
        squeeze_dims = [d for d in squeeze_dims if d is not None]
        if len(squeeze_dims) > 0:
            new_val = jnp.squeeze(pre.val, axis=squeeze_dims)
        else:
            new_val = pre.val
        return SparseTensor(new_out_dims, new_primal_dims, new_val)

    def inverse_squeeze_transform(post, iota):
        new_dims = params["dimensions"]
        # Work on copies; the bookkeeping below mutates the dimension objects.
        new_out_dims = list(copy.deepcopy(post.out_dims))
        new_primal_dims = list(copy.deepcopy(post.primal_dims))

        for dim in new_dims:
            # Value axis of the new dimension: after all materialized out dims
            # and the dense primal dims that precede it.
            val_dim = sum(1 for d in new_out_dims if d.val_dim is not None)
            val_dim += sum(1 for d in new_primal_dims[:dim] if d.val_dim is not None and isinstance(d, DenseDimension))
            new_primal_dims.insert(dim, DenseDimension(dim, 1, val_dim))
            # NOTE(review): the slice starts at `dim`, so the dimension that was
            # just inserted is shifted as well — confirm this is intended.
            for d in new_primal_dims[dim:]:
                d.id += 1
                if d.val_dim is not None:
                    d.val_dim += 1
                if isinstance(d, SparseDimension):
                    _d = new_out_dims[d.other_id]
                    _d.other_id += 1
                    if _d.val_dim is not None:
                        _d.val_dim += 1
        new_val = jnp.expand_dims(post.val, axis=new_dims)
        return SparseTensor(new_out_dims, new_primal_dims, new_val)

    transform = JacobianTransform(squeeze_transform, inverse_squeeze_transform)
    return val_out, [SparseTensor([], [], None, [transform])]
# Register the squeeze rule under its primitive in the `elemental_rules` table.
elemental_rules[lax.squeeze_p] = squeeze_elemental_rule
def concatenate_elemental_rule(primals, **params):
    """Evaluate ``lax.concatenate_p`` and return one transform per operand.

    For every operand a ``SparseTensor`` with no dimensions and no value is
    returned that carries a ``JacobianTransform``; the operand is bound to the
    transform functions with ``functools.partial`` and identified again inside
    them by identity (``p is primal``).

    Args:
        primals: Operands of the concatenate primitive.
        **params: Primitive parameters; ``dimension`` is read.

    Returns:
        ``(val_out, tensors)`` with one entry in ``tensors`` per operand.

    Raises:
        NotImplementedError: From the inverse transform, when the primal
            dimension along the concatenation axis is sparse and has no value
            axis.
    """
    # This gradient transformation is designed to take an post edge and
    # decompose it into the pre edges. This is done by densifying the post along
    # the respective axes and then use jnp.split to split the tensor.
    # TODO DynamicJaxprTracer is now a unhashable type, so we can no longer use
    # it as a key in the dict. We need to find another way of doing this.
    val_out = lax.concatenate_p.bind(*primals, **params)
    dim = params["dimension"]

    # slices[i] = [start, stop] of operand i along the concatenation axis
    offset = primals[0].shape[dim]
    slices = {0: [0, offset]}
    for i, val in enumerate(primals[1:], start=1):
        slices[i] = [offset, offset + val.shape[dim]]
        offset += val.shape[dim]

    def concatenate_transform(primal, pre, iota):
        # Work on copies; the bookkeeping below mutates the dimension objects.
        new_out_dims = list(copy.deepcopy(pre.out_dims))
        new_primal_dims = list(copy.deepcopy(pre.primal_dims))
        l = len(pre.out_dims)
        d = new_out_dims[dim]
        dim_id = d.id

        # Position of this operand among the primals and its slice bounds
        primal_idx = [idx for idx, p in enumerate(primals) if p is primal][0]
        idx, _idx = slices[primal_idx]

        if isinstance(d, DenseDimension):
            if d.val_dim is not None:
                # Pad the value with zero blocks before and after the slice
                lshape = list(pre.val.shape)
                rshape = list(pre.val.shape)
                lshape[d.val_dim] = idx
                rshape[d.val_dim] = val_out.shape[dim] - _idx
                lcat_zeros = jnp.zeros(lshape)
                rcat_zeros = jnp.zeros(rshape)
                new_val = jnp.concatenate(
                    [lcat_zeros, pre.val, rcat_zeros], axis=d.val_dim
                )
                new_out_dims[dim].size = new_val.shape[d.val_dim]
            else:
                # val_dim=None: this output dimension is a Kronecker factor not stored in val.
                # Materialize it: broadcast pre.val to the primal's slice size (zero-copy),
                # then pad with zeros in a single lax.pad pass.
                val_size = _idx - idx
                new_val_dim = sum(1 for dd in new_out_dims[:dim] if dd.val_dim is not None)
                new_val = jnp.expand_dims(pre.val, axis=new_val_dim)
                new_val = jnp.broadcast_to(
                    new_val,
                    (*pre.val.shape[:new_val_dim], val_size, *pre.val.shape[new_val_dim:])
                )
                pad_config = [(0, 0, 0)] * new_val.ndim
                pad_config[new_val_dim] = (idx, val_out.shape[dim] - _idx, 0)
                new_val = lax.pad(
                    new_val, jnp.zeros((), dtype=new_val.dtype), pad_config
                )
                # Inserting a new axis shifts all subsequent val_dims up by 1
                for _dim in new_out_dims[dim + 1:]:
                    if _dim.val_dim is not None:
                        _dim.val_dim += 1
                for _dim in new_primal_dims:
                    if isinstance(_dim, DenseDimension) and _dim.val_dim is not None:
                        _dim.val_dim += 1
                new_out_dims[dim].val_dim = new_val_dim
                new_out_dims[dim].size = val_out.shape[dim]
        else:
            other_id = d.other_id
            if d.val_dim is not None:
                _d = new_primal_dims[other_id - l]
                # Calculate the new val_dim of the primal dimension
                val_dim = sum(1 for dd in new_out_dims if dd.val_dim is not None)
                val_dim += sum(1 for dd in new_primal_dims[:other_id - l] if dd.val_dim is not None and isinstance(dd, DenseDimension))

                # Update the val_dim of all following dimensions
                for _dim in new_primal_dims[dim + 1:]:
                    if isinstance(_dim, DenseDimension) and _dim.val_dim is not None:
                        _dim.val_dim += 1

                # Materialize the sparse dimensions related to the concatenation dimension
                new_val = _materialize_dimensions(pre, [d.id])
                # Use a slice of `iota` when it is large enough, else build an
                # identity matrix of the required size.
                if iota.shape[0] < d.size or iota.shape[1] < d.size:
                    sub_iota = jnp.eye(d.size, dtype=jnp.float32)
                else:
                    sub_iota = lax.slice(iota, [0, 0], [d.size, d.size])
                shape = [1 for _ in range(pre.val.ndim)]
                shape[_d.val_dim] = _d.size
                shape.insert(val_dim, d.size)
                sub_iota = sub_iota.reshape(shape)
                new_val = new_val * sub_iota

                # Make zeros for insertion
                _size = val_out.shape[dim]
                _shape = list(new_val.shape)
                _shape[d.val_dim] = _size
                _shape[val_dim] = d.size
                zeros = jnp.zeros(_shape, dtype=jnp.float32)

                # scatter_indices: where in `zeros` to place `new_val`
                scatter_indices = [0 for _ in _shape]
                scatter_indices[d.val_dim] = idx
                scatter_indices[val_dim] = 0
                update_window_dims = tuple(range(len(_shape)))
                scatter_dims_to_operand_dims = tuple(range(len(_shape)))
                scatter_dims = lax.ScatterDimensionNumbers(
                    update_window_dims, (), scatter_dims_to_operand_dims
                )
                new_val = lax.scatter(
                    zeros, jnp.array(scatter_indices), new_val, scatter_dims,
                    indices_are_sorted=True, unique_indices=True
                )

                # Both linked dimensions are dense from here on
                new_out_dims[dim_id] = DenseDimension(
                    dim_id, val_out.shape[dim], d.val_dim
                )
                new_primal_dims[other_id - l] = DenseDimension(
                    other_id, d.size, val_dim
                )
            else:
                _d = new_primal_dims[other_id - l]
                _size = val_out.shape[dim]
                # Calculate the new val_dim of the out dimension
                out_val_dim = sum(1 for dd in new_out_dims[:dim] if dd.val_dim is not None)
                # Calculate the new val_dim of the primal dimension
                primal_val_dim = sum(1 for dd in new_out_dims if dd.val_dim is not None)
                primal_val_dim += sum(1 for dd in new_primal_dims[:other_id - l] if dd.val_dim is not None and isinstance(dd, DenseDimension))
                primal_val_dim = max(1, primal_val_dim)

                # Update the val_dim of all following dimensions
                for _dim in new_primal_dims[dim + 1:]:
                    if isinstance(_dim, DenseDimension) and _dim.val_dim is not None:
                        _dim.val_dim += 1

                # Materialize the sparse dimensions related to the concatenation dimension
                if pre.val.shape != ():
                    new_val = _materialize_dimensions(pre, [d.id, d.other_id])
                else:
                    new_val = pre.val
                if iota.shape[0] < d.size or iota.shape[1] < d.size:
                    sub_iota = jnp.eye(d.size, dtype=jnp.float32)
                else:
                    sub_iota = lax.slice(iota, [0, 0], [d.size, d.size])
                # NOTE(review): `shape` is built here but never used; unlike the
                # branch above, `sub_iota` is not reshaped before the product.
                shape = [1 for _ in range(pre.val.ndim)]
                shape.insert(out_val_dim, _d.size)
                shape.insert(primal_val_dim, d.size)
                new_val = new_val * sub_iota

                # Make zeros for insertion
                _shape = list(pre.val.shape)
                _shape.insert(out_val_dim, _size)
                _shape.insert(primal_val_dim, _d.size)
                zeros = jnp.zeros(_shape, dtype=jnp.float32)

                scatter_dims = lax.ScatterDimensionNumbers(
                    [out_val_dim, primal_val_dim], [], [out_val_dim, primal_val_dim]
                )
                new_val = lax.scatter(
                    zeros, jnp.array([idx, 0]), new_val, scatter_dims,
                    indices_are_sorted=True, unique_indices=True
                )

                # Both linked dimensions are dense from here on
                new_out_dims[dim_id] = DenseDimension(
                    dim_id, val_out.shape[dim], out_val_dim
                )
                new_primal_dims[other_id - l] = DenseDimension(
                    other_id, d.size, primal_val_dim
                )
        return SparseTensor(new_out_dims, new_primal_dims, new_val)

    def inverse_concatenate_transform(primal, post, iota):
        # Work on copies; the bookkeeping below mutates the dimension objects.
        new_out_dims = list(copy.deepcopy(post.out_dims))
        new_primal_dims = list(copy.deepcopy(post.primal_dims))

        # Position of this operand among the primals
        primal_idx = next(idx for idx, p in enumerate(primals) if p is primal)

        # NOTE(review): when `post` has no primal dimensions `d` stays None and
        # the `else` branch below dereferences it — confirm this cannot occur.
        d = None
        if len(new_primal_dims) > 0:
            d = new_primal_dims[dim]

        if isinstance(d, DenseDimension):
            if d.val_dim is not None:
                # Cut this operand's slice out of the value
                new_val = lax.slice_in_dim(
                    post.val, *slices[primal_idx], axis=d.val_dim
                )
                d.size = new_val.shape[d.val_dim]
            else:
                # val_dim=None: the primal dimension is a Kronecker factor not stored in val.
                # There is no axis to slice — just narrow the size to this primal's contribution.
                d.size = slices[primal_idx][1] - slices[primal_idx][0]
                new_val = post.val
        else:
            _d = new_out_dims[d.other_id]
            if d.val_dim is not None:
                new_out_dims[d.other_id] = DenseDimension(_d.id, _d.size, _d.val_dim)
                size = slices[primal_idx][1] - slices[primal_idx][0]
                # Calculate the new val_dim of the primal dimension
                val_dim = sum(1 for dd in new_out_dims if dd.val_dim is not None)
                val_dim += sum(1 for dd in new_primal_dims[:dim] if dd.val_dim is not None and type(dd) is DenseDimension)
                new_primal_dims[dim] = DenseDimension(_d.other_id, size, val_dim)

                # Update the val_dim of all following dimensions
                for _dim in new_primal_dims[dim + 1:]:
                    if type(_dim) is DenseDimension and _dim.val_dim is not None:
                        _dim.val_dim += 1

                # Materialize the sparse dimensions related to the concatenation dimension
                new_val = _materialize_dimensions(post, [d.id])
                if iota.shape[0] < d.size or iota.shape[1] < d.size:
                    sub_iota = jnp.eye(d.size, dtype=jnp.float32)
                else:
                    sub_iota = lax.slice(iota, [0, 0], [d.size, d.size])
                shape = [1 for _ in range(post.val.ndim)]
                shape[_d.val_dim] = _d.size
                shape.insert(val_dim, d.size)
                sub_iota = sub_iota.reshape(shape)
                new_val = new_val * sub_iota
                new_val = lax.slice_in_dim(
                    new_val, *slices[primal_idx], axis=val_dim
                )
                # NOTE(review): `d` and `_d` are the objects that were replaced
                # in the lists above, so these size updates do not reach the
                # returned dimensions — confirm this is intended.
                d.size = new_val.shape[d.val_dim]
                _d.size = new_val.shape[d.val_dim]
            else:
                # TODO: complete the implementation here at some point
                raise NotImplementedError("Finish the implementation!")
                # NOTE(review): everything below the raise is unreachable. It
                # also rebinds `d` in its loops and indexes `slices` with
                # `primal` instead of `primal_idx`.
                _d = new_out_dims[d.other_id]
                if d.val_dim is not None:
                    size = slices[primal_idx][1] - slices[primal_idx][0]
                    out_val_dim = sum(1 for d in new_out_dims[:d.other_id] if d.val_dim is not None)
                    primal_val_dim = sum(1 for d in new_out_dims if d.val_dim is not None)
                    primal_val_dim += sum(1 for d in new_primal_dims[:dim] if d.val_dim is not None and type(d) is DenseDimension)
                    new_out_dims[d.other_id] = DenseDimension(_d.id, _d.size, out_val_dim)
                    new_primal_dims[dim] = DenseDimension(_d.other_id, size, primal_val_dim)

                    # TODO finish this!
                    for d in new_out_dims[d.other_id :]:
                        if type(d) is DenseDimension:
                            if d.val_dim is not None:
                                d.val_dim += 1

                    # increase the val_dim of all following dimensions
                    for d in new_primal_dims[dim + 1 :]:
                        if type(d) is DenseDimension:
                            if d.val_dim is not None:
                                d.val_dim += 1

                    # The following piece of code materialized the particular set
                    # of sparse dimensions related to the concatenation dimension
                    new_val = _materialize_dimensions(post, [d.id, d.other_id])
                    if iota.shape[0] < d.size or iota.shape[1] < d.size:
                        sub_iota = jnp.eye(d.size, dtype=jnp.float32)
                    else:
                        sub_iota = lax.slice(iota, [0, 0], [d.size, d.size])
                    shape = [1 for _ in range(post.val.ndim)]
                    shape.insert(out_val_dim, _d.size)
                    shape.insert(primal_val_dim, size)
                    sub_iota = sub_iota.reshape(shape)
                    new_val = new_val * sub_iota
                    new_val = lax.slice_in_dim(
                        new_val, *slices[primal], axis=primal_val_dim
                    )
                    d.size = new_val.shape[d.val_dim]
                    _d.size = new_val.shape[d.val_dim]
        return SparseTensor(new_out_dims, new_primal_dims, new_val)

    # One transform-carrying tensor per operand; the operand is bound as the
    # first argument of both transform functions.
    return val_out, [
        SparseTensor(
            [],
            [],
            None,
            [
                JacobianTransform(
                    partial(concatenate_transform, p),
                    partial(inverse_concatenate_transform, p),
                )
            ],
        )
        for p in primals
    ]
# Register the concatenate rule under its primitive in the `elemental_rules` table.
elemental_rules[lax.concatenate_p] = concatenate_elemental_rule
def convert_element_type_rule(primals, **params):
    """Evaluate ``lax.convert_element_type_p`` and describe its Jacobian as a transform.

    Both the forward and the inverse transform cast the value of the tensor
    they receive to ``params["new_dtype"]`` and return it together with deep
    copies of the tensor's dimensions.

    Args:
        primals: Operands of the conversion primitive.
        **params: Primitive parameters; ``new_dtype`` is read.

    Returns:
        ``(val_out, [tensor])`` where ``tensor`` is a ``SparseTensor`` with no
        dimensions and no value that carries the transform.
    """
    val_out = lax.convert_element_type_p.bind(*primals, **params)
    target_dtype = params["new_dtype"]

    def _cast_tensor(tensor):
        # Cast the value first, then copy the dimensions (out before primal).
        cast_val = lax.convert_element_type(tensor.val, target_dtype)
        out_dims = copy.deepcopy(tensor.out_dims)
        primal_dims = copy.deepcopy(tensor.primal_dims)
        return SparseTensor(out_dims, primal_dims, cast_val)

    def convert_element_type_transform(pre, iota):
        return _cast_tensor(pre)

    def inverse_convert_element_type_transform(post, iota):
        return _cast_tensor(post)

    return val_out, [
        SparseTensor(
            [],
            [],
            None,
            [
                JacobianTransform(
                    convert_element_type_transform,
                    inverse_convert_element_type_transform,
                )
            ],
        )
    ]
# Register the convert_element_type rule under its primitive in the `elemental_rules` table.
elemental_rules[lax.convert_element_type_p] = convert_element_type_rule