import copy
from functools import reduce
from typing import Sequence
import jax.lax as lax
import jax.numpy as jnp
from jax._src.core import ClosedJaxpr, JaxprEqn, ShapedArray
[docs]
def zeros_like(invar: ShapedArray, outvar: ShapedArray) -> jnp.ndarray:
    """
    Build an all-zero array whose shape is the input shape followed by the
    output shape.

    Args:
        invar (ShapedArray): Variable whose `aval.shape` provides the leading
            axes of the result.
        outvar (ShapedArray): Variable whose `aval.shape` provides the
            trailing axes of the result.

    Returns:
        jnp.ndarray: Zeros of shape `(*invar.aval.shape, *outvar.aval.shape)`.
            When both shapes are the empty tuple, the plain Python float
            `0.0` is returned instead of an array.
    """
    leading = invar.aval.shape
    trailing = outvar.aval.shape

    # Anything that is not scalar-to-scalar gets a real zero array.
    if not (leading == () and trailing == ()):
        return jnp.zeros((*leading, *trailing))

    # Scalar-to-scalar: a bare float is used rather than a 0-d array.
    return 0.0
[docs]
def eye_like(shape: Sequence[int], out_len: int) -> jnp.ndarray:
    """
    Create a higher order tensor of Kronecker deltas by reshaping an identity
    matrix.

    The first `out_len` entries of `shape` are treated as the output axes and
    the remaining entries as the primal axes.

    Args:
        shape (Sequence[int]): Shape of the tensor to create, output axes first.
        out_len (int): Number of leading entries of `shape` that belong to the
            output.

    Returns:
        jnp.ndarray: When the output and primal shapes are equal, an identity
            matrix reshaped to `(*out_shape, *primal_shape)`; if they hold a
            single element, an array of ones of shape `(1, *primal_shape)` is
            returned instead. When the two shapes differ, a square identity
            of side `prod(out_shape)` is reshaped to `shape`, which requires
            both halves to hold the same number of elements.
    """

    def _num_elements(dims):
        # Product of all extents; an empty shape counts as a single element.
        return reduce(lambda acc, dim: acc * dim, dims, 1)

    leading = shape[:out_len]
    trailing = shape[out_len:]
    rows = _num_elements(leading)

    shapes_match = trailing == leading
    if not shapes_match:
        # Differing layouts: reshape a square identity directly.
        return jnp.eye(rows).reshape(*leading, *trailing)

    cols = _num_elements(trailing)
    if rows == 1:
        return jnp.ones((1,) + tuple(trailing))
    if cols == 1:
        return jnp.ones(tuple(leading) + (1,))
    return jnp.eye(rows, cols).reshape(*leading, *trailing)
[docs]
def eye_like_copy(shape: Sequence[int], out_len: int, iota: jnp.ndarray) -> jnp.ndarray:
    """
    Function that creates a higher order tensor that is a product of Kronecker deltas.
    It tries to reuse the identity matrix `iota` as much as possible to create the
    higher order tensor. If `iota` is too small, it creates a new identity matrix
    of the appropriate size.

    Args:
        shape (Sequence[int]): The shape of the higher order tensor that we want
            to create. Any sequence (list or tuple) is accepted and it is never
            modified.
        out_len (int): The length of the output tensor, i.e. the number of
            output dimensions. The first `out_len` entries of `shape` are the
            output axes, the rest are the primal axes.
        iota (jnp.ndarray): The 2d identity matrix that we use to create the
            higher order tensor.

    Returns:
        jnp.ndarray: The higher order tensor that is a product of Kronecker deltas.
            If output and primal shapes are equal, this is an identity matrix
            reshaped to `shape` (or ones of shape `(1, *primal_shape)` when
            they hold a single element). Otherwise every output axis is paired
            with the first not-yet-used primal axis of the same size and the
            corresponding deltas are multiplied together; axes without a
            partner keep extent 1 in the result, and if no axis can be paired
            at all the Python float `1.0` is returned.
    """
    out_shape = tuple(shape[:out_len])
    primal_shape = tuple(shape[out_len:])

    if primal_shape == out_shape:
        # Equal shapes imply equal element counts, so one size is enough.
        size = reduce(lambda x, y: x * y, out_shape, 1)
        if size == 1:
            return jnp.ones((1,) + primal_shape)
        if iota.shape[0] < size or iota.shape[1] < size:
            iota = jnp.eye(size)
        sub_iota = lax.slice(iota, (0, 0), (size, size))
        return sub_iota.reshape(*shape)

    # Build the tensor as a product of 2d Kronecker deltas, each reshaped so
    # that it only spans one output axis and its partner primal axis while
    # broadcasting over all other axes.
    num_out = len(out_shape)
    val = 1.0
    # (original axis position, extent) of every primal axis not yet paired.
    # Keeping the original position alongside the extent ensures that
    # repeated extents are mapped to distinct primal axes.
    unmatched = list(enumerate(primal_shape))
    for i, o in enumerate(out_shape):
        slot = next(
            (k for k, (_, extent) in enumerate(unmatched) if extent == o),
            None,
        )
        if slot is None:
            # No primal axis of this size is left; leave the axis at extent 1.
            continue
        primal_axis, _ = unmatched.pop(slot)

        delta_shape = [1] * len(shape)
        delta_shape[i] = o
        delta_shape[num_out + primal_axis] = o

        if o == 1:
            kronecker = jnp.ones(delta_shape)
        elif iota.shape[0] < o or iota.shape[1] < o:
            # `iota` is too small to cut the delta out of it.
            kronecker = jnp.eye(o).reshape(delta_shape)
        else:
            kronecker = lax.slice(iota, (0, 0), (o, o)).reshape(delta_shape)

        # NOTE: This product is expensive to compute and not always necessary.
        val = val * kronecker
    return val
[docs]
def get_largest_tensor(tensors: Sequence[ShapedArray]) -> int:
    """
    Function that computes the size of the largest tensor in a list of tensors.

    Args:
        tensors (Sequence): A list of tensors for which we want to know the size
            of the largest tensor. The size of each entry is read from its
            `aval.size` attribute.

    Returns:
        int: The size of the largest tensor in the list of tensors, or 0 if
            the list is empty.
    """
    # `default=0` keeps an empty sequence from raising ValueError.
    return max((t.aval.size for t in tensors), default=0)
[docs]
def count_muls(eqn: JaxprEqn) -> int:
    """
    Count the multiplications performed by a single jaxpr equation.

    Only the `lax.dot_general` and `lax.mul` primitives contribute; every other
    primitive is counted as zero multiplications.

    Args:
        eqn (core.JaxprEqn): The equation to inspect.

    Returns:
        int: For `lax.mul`, the number of elements of the first output. For
            `lax.dot_general`, the number of elements of the first operand
            times the number of elements of the second operand after its
            contracting and batch axes have been collapsed to extent 1.
            Zero for any other primitive.
    """

    def _product(extents):
        # Product of all extents; an empty shape yields 1.
        return reduce(lambda acc, extent: acc * extent, extents, 1)

    primitive = eqn.primitive

    if primitive is lax.mul_p:
        return _product(eqn.outvars[0].aval.shape)

    if primitive is not lax.dot_general_p:
        return 0

    dimension_numbers = eqn.params["dimension_numbers"]
    rhs_contracting = dimension_numbers[0][1]
    rhs_batch = dimension_numbers[1][1]

    lhs, rhs = eqn.invars
    lhs_extents = list(lhs.aval.shape)
    rhs_extents = list(rhs.aval.shape)

    # Contracting and batch axes are already accounted for by the first
    # operand, so they must not be counted a second time.
    for axis in rhs_contracting + rhs_batch:
        rhs_extents[axis] = 1

    return _product(lhs_extents) * _product(rhs_extents)
[docs]
def count_muls_jaxpr(jaxpr: ClosedJaxpr) -> int:
    """
    Count the multiplications performed by all equations of a jaxpr.

    Args:
        jaxpr (core.ClosedJaxpr): The `ClosedJaxpr` whose equations (`jaxpr.eqns`)
            are inspected one by one with `count_muls`.

    Returns:
        int: The sum of the per-equation multiplication counts; 0 when the
            jaxpr has no equations.
    """
    total = 0
    for eqn in jaxpr.eqns:
        total += count_muls(eqn)
    return total