"""
Sparse tensor algebra implementation
"""
import copy
from dataclasses import dataclass
from typing import Callable, Generator, Sequence, Union, Tuple
import jax
import jax.lax as lax
import jax.numpy as jnp
from jax._src.core import ShapedArray
from jax.tree_util import register_pytree_node_class
from chex import Array
from .utils import eye_like_copy, eye_like
# NOTE: a val_dim of None means that we have a possible replication of the tensor
# along the respective dimension `d.size` times to manage broadcasting
# operations such as broadcasted additions or multiplications.
# TODO: what do we do when we have a tensor that consists only of DenseDimensions
# with val_dim=None?
[docs]
@dataclass
class DenseDimension:
id: int
size: int
val_dim: int
# NOTE: a val_dim of None means that we have a factored Kronecker delta in
# our tensor at the respective dimensions.
# Also we can have unmatching `size` and `val.shape[d.val_dim]` for SparseDimensions
# if the size is 1. This is necessary to enable broadcasting operations.
[docs]
@dataclass
class SparseDimension:
id: int
size: int
val_dim: int
other_id: int
Dimension = Union[DenseDimension, SparseDimension]
[docs]
class SparseTensor:
"""
The `SparseTensor object enables` the representation of sparse tensors
that if out_dims or primal_dims is empty, this implies a scalar dependent or
independent variable. if both are empty, then we have a scalar value and
everything becomes trivial and the `val` field contains the value of the
singleton partial
"""
out_dims: Tuple[Dimension]
primal_dims: Tuple[Dimension] # input dimensions
dims: Tuple[Dimension]
shape: Tuple[int] # True shape of the tensor
val: Array
pre_transforms: Sequence[Callable]
post_transforms: Sequence[Callable]
# TODO: Document pre_transforms and post_transforms. What about addition?
# NOTE: We always assume that the dimensions are ordered in ascending order
def __init__(self,
out_dims: Sequence[Dimension],
primal_dims: Sequence[Dimension],
val: Array,
pre_transforms: Sequence[Callable] = None,
post_transforms: Sequence[Callable] = None) -> None:
if pre_transforms is None:
pre_transforms = []
if post_transforms is None:
post_transforms = []
self.out_dims = out_dims if isinstance(out_dims, tuple) else tuple(out_dims)
self.primal_dims = primal_dims if isinstance(primal_dims, tuple) else tuple(primal_dims)
self.dims = self.out_dims + self.primal_dims
self.out_shape = [d.size for d in out_dims]
self.primal_shape = [d.size for d in primal_dims]
self.shape = tuple(self.out_shape + self.primal_shape)
self.val = val
self.pre_transforms = pre_transforms
self.post_transforms = post_transforms
_assert_sparse_tensor_consistency(self)
def __repr__(self) -> str:
def map_str(a: Sequence) -> Generator:
return (str(s) for s in a)
def multiline_seq(s: Sequence, brackets: str) -> str:
lb, rb, *_ = brackets
if s:
return f"{lb}\n " + ",\n ".join(map_str(s)) + f",\n {rb}"
else:
return lb + rb
str_out_shape = ", ".join(map_str(self.out_shape))
str_primal_shape = ", ".join(map_str(self.primal_shape))
multiline_out_dims = multiline_seq(self.out_dims, "()")
multiline_primal_dims = multiline_seq(self.primal_dims, "()")
multiline_pre_transform = multiline_seq(self.pre_transforms, "[]")
multiline_post_transform = multiline_seq(self.post_transforms, "[]")
return f"SparseTensor(\n" \
f" shape = ({str_out_shape} | {str_primal_shape}),\n"\
f" out_dims = {multiline_out_dims},\n"\
f" primal_dims = {multiline_primal_dims},\n"\
f" val = {self.val},\n"\
f" pre_transforms = {multiline_pre_transform},\n"\
f" post_transforms = {multiline_post_transform}\n"\
f")"\
def __add__(self, _tensor):
return _add(self, _tensor)
def __mul__(self, _tensor):
return _mul(self, _tensor)
# TODO: add the case where `val_dim = None` for a `DenseDimension` by
# replicating the tensor `d.size` times using `jnp.tile`.
[docs]
def dense(self, iota: Array) -> Array:
"""
Materializes tensor to actual dense shape.
Args:
iota (Array): The Kronecker matrix/tensor that is used to
materialize the tensor.
Returns:
Array: Dense representation of the sparse tensor.
"""
# Compute shape of the multidimensional eye with which the `val` tensor
# will get multiplied to manifest the sparse dimensions
# If tensor contains SparseDimensions, we have to materialize them
def eye_dim_fn(d: Dimension) -> int:
if isinstance(d, SparseDimension):
return d.size
else:
return 1
eye_shape = [eye_dim_fn(d) for d in self.dims]
eye = eye_like_copy(eye_shape, len(self.out_dims), iota)
# If tensor consists only out of Kronecker Delta's, we can just reshape
# the eye matrix to the shape of the tensor and return it
if self.val is None:
return eye
if self.val.shape == self.shape:
return self.val
# Catching some corner cases
if not self.out_dims and not self.primal_dims:
return self.val
shape = _get_fully_materialized_shape(self)
val = self.val.reshape(shape) * eye
# Get the tiling for DenseDimensions with `val_dim = None`, i.e. replicating
# dimensions
def tile_dim_fn(d: Dimension) -> int:
if isinstance(d, DenseDimension) and d.val_dim is None:
return d.size
else:
return 1
tiling = [tile_dim_fn(d) for d in self.dims]
index_map = eye_like_copy(eye_shape, len(self.out_dims), iota)
return jnp.tile(index_map*val, tiling)
[docs]
def copy(self, val: Array = None):
"""
Function that copies the given sparse tensor object entirely except for
the `val` property which can be replaced by a new value.
Args:
val (Array, optional): The new value of the `val` property. Defaults to None.
Returns:
SparseTensor: A copy of the original SparseTensor object.
"""
out_dims = copy.deepcopy(self.out_dims)
primal_dims = copy.deepcopy(self.primal_dims)
val = self.val if val is None else val
return SparseTensor(out_dims, primal_dims, val, self.pre_transforms, self.post_transforms)
[docs]
def sparse_tensor_zeros_like(st: SparseTensor) -> SparseTensor:
"""
Function that generates a new `SparseTensor` with only zeros with the
same shape as the given `SparseTensor` `st`.
Args:
st (SparseTensor): `SparseTensor` whose shape we want to use to initialize
the new `SparseTensor` object.
Returns:
SparseTensor: A copy of a given `SparseTensor` object with only zeros.
"""
return st.copy(jnp.zeros_like(st.val))
def _assert_sparse_tensor_consistency(st: SparseTensor):
"""
Function that validates the consistency of a `SparseTensor` object,
i.e. checks if the `val` property has the correct shape and if the dimensions
are ordered correctly and sizes match the shape of `val`.
Args:
st (SparseTensor): SparseTensor object we want to validate.
Returns:
bool: True if the `SparseTensor` object is consistent.
"""
# Check if d.size matches val.shape[d.val_dim] for all d
matching_sparse_sizes = all(
d.size == st.val.shape[d.val_dim]
or d.size == 1 # NOTE: required to enable broadcasting operations
if isinstance(d, SparseDimension)
and d.val_dim is not None else True
for d in st.dims
)
matching_dense_sizes = all(
d.size == st.val.shape[d.val_dim]
if isinstance(d, DenseDimension)
and d.val_dim is not None else True
for d in st.dims
)
matching_sizes = matching_sparse_sizes or matching_dense_sizes
unique_out_dims = [d.val_dim for d in st.out_dims if d.val_dim is not None]
unique_primal_dims = [d.val_dim for d in st.primal_dims if d.val_dim is not None]
is_uniqe_out_dims = len(unique_out_dims) == len(set(unique_out_dims))
is_uniqe_primal_dims = len(unique_primal_dims) == len(set(unique_primal_dims))
has_uniqe_dims = is_uniqe_out_dims and is_uniqe_primal_dims
# Check if IDs in out_dims and primal_dims match their index positions
matching_id = all(
od.id == i and pd.id == i + len(st.out_dims)
for i, (od, pd) in enumerate(zip(st.out_dims, st.primal_dims))
)
# Check sparse dimension pairing consistency
matching_sparse_ids = all(
st.primal_dims[d.other_id - len(st.out_dims)].other_id == d.id
if isinstance(d, SparseDimension) else True
for d in st.out_dims
)
assert (matching_sizes
and has_uniqe_dims
and matching_id
and matching_sparse_ids
), f"{st} is not self-consistent!"
# TODO: check if val is consistent with tensor structure?
def _get_fully_materialized_shape(st: SparseTensor) -> Sequence[int]:
"""
Function that returns the shape of a `SparseTensor` object if its `val`
property would be fully materialized. Dimensions of size one are inserted
for one of the two dimensions corresponding to a pair of type `SparseDimension`.
If the `SparseDimension` has val == None, then both are set to one.
This corresponds to a
Args:
st (SparseTensor): The input tensor we want to materialize
swap_sparse_dims (bool, optional): Decides which of the pairs of
SparseDimensions gets the val property. Defaults to False.
Returns:
tuple[int]: The fully materialized shape.
"""
# Compute out_dims full shape-mul
def out_dim_fn(d: Dimension) -> int:
# NOTE we need the case `d.size != st.val.shape[d.val_dim]` because SparseDimensions can be matrialized without
# the correct d.size property
if d.val_dim is None or d.size != st.val.shape[d.val_dim]:
return 1
else:
return d.size
out_shape = tuple(out_dim_fn(d) for d in st.out_dims)
# Compute primal_dims full shape
def primal_dim_fn(d: Dimension) -> int:
if isinstance(d, SparseDimension) or d.val_dim is None:
return 1
else:
return d.size
primal_shape = tuple(primal_dim_fn(d) for d in st.primal_dims)
return out_shape + primal_shape
def _is_pure_dot_product_mul(lhs: SparseTensor, rhs: SparseTensor) -> bool:
"""
Function that checks if two `SparseTensor` objects are compatible for a
dot product multiplication.
Args:
lhs (SparseTensor): The left-hand side `SparseTensor` object.
rhs (SparseTensor): The right-hand side `SparseTensor` object.
Returns:
bool: Are the tensors compatible for multiplication?
"""
return all(isinstance(r, DenseDimension) and isinstance(l, DenseDimension)
for r, l in zip(lhs.primal_dims, rhs.out_dims))
def _is_pure_broadcast_mul(lhs: SparseTensor, rhs: SparseTensor) -> bool:
"""
Function that checks if two `SparseTensor` objects are compatible for a
broadcast multiplication.
Args:
lhs (SparseTensor): The left-hand side `SparseTensor` object.
rhs (SparseTensor): The right-hand side `SparseTensor` object.
Returns:
bool: Are the tensors compatible for multiplication?
"""
return all(isinstance(l, SparseDimension) or isinstance(r, SparseDimension)
for l, r in zip(lhs.primal_dims, rhs.out_dims))
def _mul(lhs: SparseTensor, rhs: SparseTensor) -> SparseTensor:
"""
Function that multiplies two `SparseTensor` objects together. The function
first performs a sequence of checks to guarantee the integrity of both
`SparseTensor` objects. It then proceeds to check if the two tensors are
compatible for multiplication. If they are, it performs the right
multiplicationn type (dot product, broadcast, mixed)
and returns the resulting `SparseTensor` object.
Args:
lhs (SparseTensor): The left-hand side `SparseTensor` object.
rhs (SparseTensor): The right-hand side `SparseTensor` object.
Returns:
SparseTensor: The resulting `SparseTensor` object.
"""
l = len(lhs.out_dims)
r = len(rhs.out_dims)
assert lhs.shape[l:] == rhs.shape[:r], \
f"{lhs.shape} and {rhs.shape} not compatible for multiplication!"
if lhs.shape == () and rhs.shape == ():
# If both tensors are scalars, we can just multiply them directly
res = SparseTensor((), (), lhs.val*rhs.val)
else:
_lhs = lhs.copy()
_rhs = rhs.copy()
if _is_pure_dot_product_mul(_lhs, _rhs):
res = _pure_dot_product_mul(_lhs, _rhs)
elif _is_pure_broadcast_mul(_lhs, _rhs):
res = _pure_broadcast_mul(_lhs, _rhs)
else:
res = _mixed_mul(_lhs, _rhs)
_assert_sparse_tensor_consistency(res)
return res
def _add(lhs: SparseTensor, rhs: SparseTensor) -> SparseTensor:
"""
Function that multiplies two `SparseTensor` objects together. The function
first performs a sequence of checks to guarantee the integrity of both
`SparseTensor` objects. It then proceeds to add both tensors, thereby
possibly materializing certain sparse dimensions.
Args:
lhs (SparseTensor): The left-hand side `SparseTensor` object.
rhs (SparseTensor): The right-hand side `SparseTensor` object.
Returns:
SparseTensor: The resulting `SparseTensor` object.
"""
assert lhs.shape == rhs.shape, \
f"{lhs.shape} and {rhs.shape} not compatible for addition!"
res = _sparse_add(lhs, rhs)
_assert_sparse_tensor_consistency(res)
return res
def _get_new_val_dim(d: Dimension, st: SparseTensor) -> int:
"""
Function that computes the new `val_dim` of a `SparseDimension` object
so that it's position within the `val` property matches the relative position
of the corresponding `SparseDimension` object in the `primal_dims` list.
Args:
d (Dimension): Dimension object whose new `val_dim` we want to compute.
st (SparseTensor): SparseTensor object that contains the `d` object.
Returns:
int: The new `val_dim` of the `d` object.
"""
if d.id < d.other_id:
dims = st.dims[:d.other_id]
else:
dims = st.dims[:d.id]
other_val_dims = [_d.val_dim for _d in dims if _d.val_dim is not None]
if other_val_dims:
return max(other_val_dims)
else:
return None
def _get_padding(lhs_out_dims: Sequence[Dimension],
rhs_primal_dims: Sequence[Dimension]) -> Tuple[Sequence[int], Sequence[int]]:
"""
Function that calculates how many dimensions have to be prepended/appended
to the `val` property of a `SparseTensor` to make it compatible for broadcast
multiplication with another `SparseTensor`.
Removes excess dimensions which are artifacts of `SparseTensor` objects.
Args:
lhs_out_dims (SparseDimension):
SparseDimension object whose `val` property we want to multiply with `rhs.val`.
rhs_primal_dims (SparseDimension):
SparseDimension object whose `val` property we want to multiply with `lhs.val`.
Returns:
Tuple[Sequence[int, ...], Sequence[int, ...]]:
A tuple of integers that tells us how many dimensions we have to
append/prepend to the `val` property of `lhs` and `rhs`.
"""
# Calculate where we have to add additional dimensions to rhs.val
# due to DenseDimensions in lhs.out_dims
lhs_pad = tuple(1 for d in rhs_primal_dims if isinstance(d, DenseDimension) and d.val_dim is not None)
rhs_pad = tuple(1 for d in lhs_out_dims if isinstance(d, DenseDimension) and d.val_dim is not None)
return lhs_pad, rhs_pad
def _assert_broadcast_compatibility(lhs_val: Array, rhs_val: Array):
"""
Function that checks if two arrays are compatible for broadcast multiplication.
Args:
lhs_val (Array): Array that we want to multiply with `rhs_val`.
rhs_val (Array): Array that we want to multiply with `lhs_val`.
"""
assert (
len(lhs_val.shape) == len(rhs_val.shape)
and all(
(ls == rs or ls == 1 or rs == 1)
for (ls, rs) in zip(lhs_val.shape, rhs_val.shape)
)
), f"Shapes {lhs_val.shape} and {rhs_val.shape} not compatible for broadcast multiplication!"
def _get_permutation_from_tensor(st: SparseTensor,
shape: Sequence[int] = None) -> Sequence[int]:
"""
Function that calculates the permutation of the axes of the `val` property
so as that `st.val.shape` matches `shape`. This is necessary to enable proper
broadcasting multiplication.s
Args:
st (SparseTensor): SparseTensor object whose `val` property we want to
compute the permutation for.
shape (Sequence[int], optional): The shape we want to permute the `val`
property of `st` to. Defaults to None.
Returns:
Sequence[int]: Permutation of the axes of `st.val` so that it matches
`shape`.
"""
shape = shape if shape else st.val.shape
permutation = [0]*len(st.val.shape)
i = 0
for d in st.dims:
if d.val_dim is not None and (isinstance(d, DenseDimension) or d.id < d.other_id):
permutation[d.val_dim] = i
i += 1
return permutation
def _get_val_shape(st: SparseTensor) -> Sequence[int]:
"""
Function that computes the shape of the `val` property of a `SparseTensor`
from its corresponding `Dimension` objects.
We assume that for `SparseDimensions` the corresponding dimension is at the
relative position that relates to the entry of the `SparseDimension` in
the `out_dims` list.
Args:
st (SparseTensor): SparseTensor object whose `val` property we want to
compute the shape of.
Returns:
Sequence[int]: Shape of the `val` property of the `SparseTensor` object.
"""
shape = [0]*st.val.ndim
for d in st.out_dims:
if d.val_dim is not None:
shape[d.val_dim] = d.size
for d in st.primal_dims:
if d.val_dim is not None and isinstance(d, DenseDimension):
shape[d.val_dim] = d.size
return shape
def _swap_axes(st: SparseTensor) -> SparseTensor:
"""Function that swaps the axes of the `val` property of a `SparseTensor`
so that the `val_dim`s of SparseDimension objects conincide with the position
in the `primal_dims` list. This is necessary to enable proper broadcasting
multiplication.
Example:
The tensor with `out_dims=(SparseDimension(0, 2, 0, 3), DenseDimension(1, 3, 1))`
and `primal_dims=(DenseDimension(2, 4, 2), SparseDimension(3, 2, 0, 0))` and
`val.shape = (2, 3, 4)` have it's `val` array turned into shape (3, 4, 2).
Args:
st (SparseTensor): SparseTensor object whose `val` property we want to
swap around for broadcasting multiplication.
Returns:
SparseTensor: SparseTensor object with appropriately swapped `val` property.
"""
transposed_shape = [d.size for d in st.out_dims if isinstance(d, DenseDimension)]
transposed_shape += [d.size for d in st.primal_dims if d.val_dim is not None]
l = len(st.out_dims)
for ld in st.out_dims:
# NOTE: not sure if this is a good solution to the problem here:
if transposed_shape == _get_val_shape(st):
break
if isinstance(ld, SparseDimension) and ld.val_dim is not None:
new_val_dim = _get_new_val_dim(ld, st)
for d in st.dims:
if (d.id != ld.id
and d.id != ld.other_id
and d.val_dim is not None
and ld.val_dim <= d.val_dim <= new_val_dim):
d.val_dim -= 1
ld.val_dim = new_val_dim
st.primal_dims[ld.other_id-l].val_dim = new_val_dim
permutation = _get_permutation_from_tensor(st)
st.val = jnp.transpose(st.val, permutation)
return st
def _pad_tensors(lhs: SparseTensor, rhs: SparseTensor):
"""
Function that pads the `val` properties of two `SparseTensor` objects for
proper broadcast multiplication. It does the following three things:
1. It appends new axes to the `lhs` tensor for every `DenseDimension` in
the `rhs.primal_dims` list.
2. It prepends new axes to the `rhs` tensor for every `DenseDimension` in
the `lhs.out_dims` list.
3. It adds new axes to the `lhs.val` for every `SparseDimension` with
`val_dim = None` in the `lhs` tensor where `rhs` tensor has a
`Dimension` object with `val_dim != None` at the same position in
`rhs.out_dims`.
4. It does the same as in 3. for the `rhs` tensor.
5. It checks the `val_dim` property of `lhs` and rhs` at the same index.
If both are None, it does not insert a new axis.
NOTE: The `val_dim` properties of the `Dimension` objects are changed accordingly.
Example:
The `lhs` tensor with
`out_dims=(SparseDimension(0, 2, 0, 3), DenseDimension(1, 3, 1))`
`primal_dims=(DenseDimension(2, 4, 2), SparseDimension(3, 2, 0, 0))`
`val.shape = (2, 3, 4)`
and the `rhs` tensor with
`out_dims=(SparseDimension(0, 4, None, 2), DenseDimension(1, 2, 1))`
`primal_dims=(SparseDimension(2, 4, None, 0), DenseDimension(3, 3, 0))`
`val.shape = (2, 3)`
have their `val` properties turned into shapes (3, 4, 2, 1)
and (1, 1, 2, 3) respectively so that they can be broadcast multiplied.
NOTE: This functions assumes that `_swap_axes` has been applied to the
`lhs` tensor before calling this function.
Args:
lhs (SparseTensor): SparseTensor object whose `val` property we want to
pad for broadcasting multiplication.
rhs (SparseTensor): SparseTensor object whose `val` property we want to
pad for broadcasting multiplication.
Returns:
tuple[SparseTensor, SparseTensor]: tuple of SparseTensor objects with
appropriately padded `val` properties
and corresponding changes to the
`val_dim` properties.
"""
lhs_shape, rhs_shape = list(lhs.val.shape), list(rhs.val.shape)
r = len(rhs.out_dims)
lhs_pad, rhs_pad = _get_padding(lhs.out_dims, rhs.primal_dims)
### Update dimension numbers
for rd in rhs.dims:
if rd.val_dim is not None:
if isinstance(rd, DenseDimension):
rd.val_dim += len(rhs_pad)
elif rd.id < rd.other_id:
rd.val_dim += len(rhs_pad)
primal_dim = rhs.primal_dims[rd.other_id-r]
primal_dim.val_dim += len(rhs_pad)
lhs_shape = list(lhs_shape) + list(lhs_pad)
rhs_shape = list(rhs_pad) + list(rhs_shape)
### Add dimensions where things are sparse
for ld, rd in zip(lhs.primal_dims, rhs.out_dims):
if ld.val_dim is None and rd.val_dim is None:
continue
# ld is sparse
if ld.val_dim is None and isinstance(ld, SparseDimension):
other_val_dim = _get_new_val_dim(ld, lhs)
if other_val_dim is not None:
other_val_dim += 1
else:
other_val_dim = 0
lhs_shape.insert(other_val_dim, 1)
ld.val_dim = other_val_dim
lhs.out_dims[ld.other_id].val_dim = other_val_dim
for d in lhs.dims:
if (d.id != ld.id
and d.id != ld.other_id
and d.val_dim is not None
and d.val_dim >= other_val_dim):
d.val_dim += 1
# rd is sparse
elif rd.val_dim is None and isinstance(rd, SparseDimension):
dims = [d.val_dim for d in rhs.out_dims[:rd.id] if d.val_dim is not None]
new_val_dim = 0
if dims:
new_val_dim = max(dims) + 1
rhs_shape.insert(new_val_dim, 1)
rd.val_dim = new_val_dim
rhs.primal_dims[rd.other_id-r].val_dim = new_val_dim
for d in rhs.dims:
if (d.id != rd.id
and d.id != rd.other_id
and d.val_dim is not None
and d.val_dim >= new_val_dim):
d.val_dim += 1
# ld is replicating
elif ld.val_dim is None and isinstance(ld, DenseDimension):
new_val_dim = _get_val_dim_when_swapped(lhs, ld.id) # cannot use this here!
lhs_shape.insert(new_val_dim, 1)
ld.val_dim = new_val_dim
for d in lhs.dims:
if d.id != ld.id:
if d.val_dim is not None and d.val_dim >= new_val_dim:
d.val_dim += 1
# rd is replicating
elif rd.val_dim is None and isinstance(rd, DenseDimension):
new_val_dim = _get_val_dim(rhs, rd.id)
# dims = [d.val_dim for d in rhs.out_dims[:rd.id] if d.val_dim is not None]
# new_val_dim = 0
# if len(dims) > 0:
# new_val_dim = max(dims) + 1
rhs_shape.insert(new_val_dim, 1)
rd.val_dim = new_val_dim
for d in rhs.dims:
if (d.id != rd.id
and d.val_dim is not None
and d.val_dim >= new_val_dim):
d.val_dim += 1
# Only do a reshape if the shape differs from the unmodified one
if lhs_shape != lhs.val.shape:
lhs.val = lhs.val.reshape(lhs_shape)
if rhs_shape != rhs.val.shape:
rhs.val = rhs.val.reshape(rhs_shape)
return lhs, rhs
def _swap_back_axes(st: SparseTensor) -> SparseTensor:
"""
After two `SparseTensor` objects have been broadcast multiplied, the
resulting tensor usually has the `val` not reshaped so that the dimensions
of it are sorted in ascending order according to the order in which the
corresponding dimensions appear. This function does this.
Example:
We might end up with a `SparseTensor` object that looks like
`out_dims=(SparseDimension(0, 2, 1, 3), DenseDimension(1, 3, 2))`
`primal_dims=(DenseDimension(2, 4, 0), SparseDimension(3, 2, 1, 0))`
`val.shape = (4, 2, 3)`
but we want to have `val.shape = (2, 3, 4)`.
This function computes the necessary permutation and applies it as a
`jnp.transpose` to the `val` property.
Args:
st (SparseTensor): SparseTensor object whose `val` property we want to
swap back around after broadcasting multiplication.
Returns:
SparseTensor: SparseTensor object with `val` property with dimensions
sorted in ascending order.
"""
l = len(st.out_dims)
i = 0
permutation = [0]*len(st.val.shape)
for d in st.dims:
if d.val_dim is not None and (isinstance(d, DenseDimension) or d.id < d.other_id):
permutation[i] = d.val_dim
i += 1
st.val = jnp.transpose(st.val, permutation)
i = 0
for d in st.dims:
if d.val_dim is not None:
if isinstance(d, DenseDimension):
d.val_dim = i
i += 1
else:
if d.id < d.other_id:
d.val_dim = i
primal_dim = st.primal_dims[d.other_id-l]
primal_dim.val_dim = i if primal_dim.val_dim is not None else None
i += 1
return st
def _get_output_tensor(lhs: SparseTensor,
rhs: SparseTensor,
val: Array) -> SparseTensor:
"""Function that computes the `out_dims` and `primal_dims` properties
of a `SparseTensor` object of a broadcast multiplication of two `SparseTensor`
objects. This is separated from the actual multiplication and broadcasting
of the `val` properties to make the code more readable. Also in several
corner cases we actually just need to reassign some `val_dims` and not
perform any actual calculations. The approach here also takes care of this
and saves multiplications by just storing the meta data of some trivial
multiplications.
Example:
TODO put an appropriate example here!
Args:
lhs (SparseTensor): SparseTensor object whose `val` property we want to
pad for broadcasting multiplication.
rhs (SparseTensor): SparseTensor object whose `val` property we want to
pad for broadcasting multiplication.
val (Array): The `val` property of the resulting `SparseTensor` object.
Returns:
SparseTensor: SparseTensor object with `val` property resulting from
broadcasting multiplication of `lhs.val` and `rhs.val`.
"""
new_out_dims, new_primal_dims = [], []
l, r = len(lhs.out_dims), len(rhs.out_dims)
for ld in lhs.out_dims:
if isinstance(ld, DenseDimension):
new_out_dims.append(DenseDimension(ld.id, ld.size, ld.val_dim))
else:
# `d` is a SparseDimension and we know it has a corresponding friend
# in lhs.primal_dims. We now check with what dimension in rhs.out_dims
# it will get contracted.
idx = ld.other_id - l
rd = rhs.out_dims[idx]
if isinstance(rd, DenseDimension):
new_out_dims.append(DenseDimension(ld.id, ld.size, ld.val_dim))
else:
other_id = rd.other_id-r+l
new_out_dims.append(SparseDimension(ld.id, ld.size, ld.val_dim, other_id))
new_primal_dims.insert(other_id-l, SparseDimension(other_id, ld.size, ld.val_dim, ld.id))
# Calculate where we have to add additional dimensions to lhs.val
# due to DenseDimensions in rhs.primal_dims
new_dense_dims = []
for rd in rhs.primal_dims:
if isinstance(rd, DenseDimension):
# shift = sum([1 for dim in new_dense_dims if dim <= rd.val_dim])
new_primal_dims.insert(rd.id-r, DenseDimension(rd.id-r+l, rd.size, rd.val_dim))
else:
idx = rd.other_id - r
ld = lhs.primal_dims[idx]
if isinstance(ld, DenseDimension):
new_dense_dims.append(ld.val_dim)
new_primal_dims.insert(rd.id-r, DenseDimension(rd.id-r+l, ld.size, ld.val_dim))
return SparseTensor(new_out_dims, new_primal_dims, val)
def _pure_broadcast_mul(lhs: SparseTensor, rhs: SparseTensor) -> SparseTensor:
"""Function that executes a pure broadcast multiplication of two `SparseTensor`
objects. This occurs only if for every pair of `Dimension` objects in
`lhs.primal_dims` and `rhs.out_dims` we have at least one `SparseDimension`.
In these cases we do not have to perform a full matrix multiplication and
get away with simple elementwise multiplication given we broadcast the
`val` properties of the `lhs` and `rhs` tensors to the right shape.
This function takes care of this by swapping axes and padding the `val`
properties accordingly.
NOTE: This happens a lot actually!
Args:
lhs (SparseTensor): SparseTensor object whose `val` property we want to
pad for broadcasting multiplication.
rhs (SparseTensor): SparseTensor object whose `val` property we want to
pad for broadcasting multiplication.
Returns:
SparseTensor: SparseTensor object with `val` property resulting from
broadcasting multiplication of `lhs.val` and `rhs.val`.
"""
### Calculate output tensor
if lhs.val is None and rhs.val is None:
return _get_output_tensor(lhs, rhs, None)
elif lhs.val is None:
return _get_output_tensor(lhs, rhs, rhs.val)
elif rhs.val is None:
return _get_output_tensor(lhs, rhs, lhs.val)
else:
# Swap left axes if sparse
lhs = _swap_axes(lhs)
# Add padding
lhs, rhs = _pad_tensors(lhs, rhs)
_assert_broadcast_compatibility(lhs.val, rhs.val)
new_val = lhs.val * rhs.val
out = _get_output_tensor(lhs, rhs, new_val)
res = _swap_back_axes(out)
return res
def _get_val_dim(st: SparseTensor, id: int) -> int:
"""Function to get the `val_dim` of both `SparseDimension` objects in the
`out_dims` and `primal_dims` lists of a `SparseTensor` object.
Args:
st (SparseTensor): SparseTensor object whose `val_dim` we want to know.
id (int): `SparseDimension` object with id `id` we want to compute the
`val_dim` for.
Returns:
int: The `val_dim` of the `SparseDimension` object with id `id`.
"""
dims = st.dims
i = 0
for d in dims[:id]:
if d.val_dim is not None:
if isinstance(d, DenseDimension):
i += 1
elif d.id < d.other_id:
i += 1
return i
def _get_val_dim_when_swapped(st: SparseTensor, id: int) -> int:
""" Function to get the `val_dim` of both `SparseDimension` objects in the
`out_dims` and `primal_dims` lists of a `SparseTensor` object where the axes
have been swapped.
Args:
st (SparseTensor): SparseTensor object whose `val_dim` we want to know.
id (int): `SparseDimension` object with id `id` we want to compute the
`val_dim` for.
Returns:
int: The `val_dim` of the `SparseDimension` object with id `id`.
"""
dims = st.dims
i = 0
for d in dims[:id]:
if d.val_dim is not None:
if isinstance(d, DenseDimension):
i += 1
elif d.id > d.other_id:
i += 1
return i
def _replicate_along_axis(st: SparseTensor, ids: Sequence[int]) -> SparseTensor:
"""Function that replicates the `val` property of a `SparseTensor` object
along a given axis. This is necessary to enable broadcasting multiplication
of two `SparseTensor` objects where one of them has a `DenseDimension` object
in its `out_dims` list and the other one has a `SparseDimension` object in
its `out_dims` list.
Args:
st (SparseTensor): SparseTensor object whose `val` property we want to
replicate along a given axis.
axes (Sequence[int]): Axes along which we want to replicate the `val`
property of `st`.
Returns:
SparseTensor: SparseTensor object with `val` property resulting from
replication of `st.val` along `axis`.
"""
# Expand the dimensions
dims = st.dims
new_dims = []
for id in ids:
d = dims[id]
new_val_dim = _get_val_dim(st, id)
d.val_dim = new_val_dim
new_dims.append(new_val_dim)
for _d in dims[id+1:]:
if _d.val_dim is not None:
if _d.val_dim >= new_val_dim:
_d.val_dim += 1
st.val = jnp.expand_dims(st.val, axis=new_dims)
# Do the tiling
tiling = []
for d in dims:
if d.val_dim is not None:
if isinstance(d, DenseDimension):
if st.val.shape[d.val_dim] != d.size:
tiling.append(d.size)
else:
tiling.append(1)
else:
if d.id < d.other_id:
if st.val.shape[d.val_dim] != d.size:
tiling.append(d.size)
else:
tiling.append(1)
st.val = jnp.tile(st.val, tiling)
return st
def _get_contracting_axes(lhs: SparseTensor, rhs: SparseTensor) -> Tuple[Sequence[int], Sequence[int]]:
"""
Function that computes the axes along which the `val` properties of two
`SparseTensor` objects will get contracted. This is necessary to enable
broadcasting multiplication of two `SparseTensor` objects where both of them
have a `DenseDimension` object in their `out_dims` list.
Args:
lhs (SparseTensor): SparseTensor object whose `val` property we want to
replicate along a given axis.
rhs (SparseTensor): SparseTensor object whose `val` property we want to
replicate along a given axis.
Returns:
Tuple[Sequence[int], Sequence[int]]: A tuple of sequences of integers that
tell us along which axes the `val`
properties of `lhs` and `rhs` will
get contracted.
"""
lcontracting_axes, rcontracting_axes = [], []
for ld, rd in zip(lhs.primal_dims, rhs.out_dims):
if isinstance(ld, DenseDimension) and isinstance(rd, DenseDimension):
# TODO this causes a bug if both axes are replicating axes
if ld.val_dim is not None and rd.val_dim is not None:
lcontracting_axes.append(ld.val_dim)
rcontracting_axes.append(rd.val_dim)
return lcontracting_axes, rcontracting_axes
def _pure_dot_product_mul(lhs: SparseTensor, rhs: SparseTensor) -> SparseTensor:
"""
This function takes care of the cases where all dimensions in
`lhs.primal_dims` and `rhs.out_dims` are of type `DenseDimension`.
Then we only need to do `lax.dot_general` with the right dimension numbers
to get the result.
Args:
lhs (SparseTensor): SparseTensor object whose `val` property we want to
multiply with `rhs.val`.
rhs (SparseTensor): SparseTensor object whose `val` property we want to
multiply with `lhs.val`.
Returns:
SparseTensor: SparseTensor object with `val` property resulting from
the dense dot-product multiplication of `lhs.val` and `rhs.val`.
"""
lcontracting_axes, rcontracting_axes = [], []
lreplication_ids, rreplication_ids = [], []
new_out_dims = lhs.out_dims
l = len(lhs.out_dims)
r = len(rhs.out_dims)
new_primal_dims = []
i = 0
for d in rhs.primal_dims:
if d.val_dim is not None:
new_primal_dims.append(DenseDimension(d.id-r+l, d.size, l+i))
i += 1
else:
new_primal_dims.append(DenseDimension(d.id-r+l, d.size, None))
# Handling contracting variables
for ld, rd in zip(lhs.primal_dims, rhs.out_dims):
if ld.val_dim is None and rd.val_dim is None:
lreplication_ids.append(ld.id-l+len(lhs.out_dims))
rreplication_ids.append(rd.id)
elif ld.val_dim is None:
lreplication_ids.append(ld.id-l+len(lhs.out_dims))
elif rd.val_dim is None:
rreplication_ids.append(rd.id)
else:
lcontracting_axes.append(ld.val_dim)
rcontracting_axes.append(rd.val_dim)
# Reshape lhs.val and rhs.val for tiling and replicate along the
# respective dimensions for dpt_product
if lreplication_ids:
lhs = _replicate_along_axis(lhs, lreplication_ids)
if rreplication_ids:
rhs = _replicate_along_axis(rhs, rreplication_ids)
# Get the contracting axes after the tiling
if lreplication_ids or rreplication_ids:
lcontracting_axes, rcontracting_axes = _get_contracting_axes(lhs, rhs)
# Do the math using dot_general
dimension_numbers = (tuple(lcontracting_axes), tuple(rcontracting_axes))
dimension_numbers = (dimension_numbers, ((), ()))
new_val = lax.dot_general(lhs.val, rhs.val, dimension_numbers)
return SparseTensor(new_out_dims, new_primal_dims, new_val)
def _mixed_mul(lhs: SparseTensor, rhs: SparseTensor) -> SparseTensor:
"""
This is the general case where we have dot-product multiplications as
well as broadcast multiplications. We first do the dot-product multiplications
and then the broadcast multiplications by extracting the diagonal of the
corresponding axes of the resulting tensor of the dot-product contraction.
TODO modularize this!
TODO write docstring
TODO write code for DenseDimension with val_dim = None
Args:
lhs (SparseTensor): SparseTensor object whose `val` property we want to
multiply with `rhs.val`.
rhs (SparseTensor): SparseTensor object whose `val` property we want to
multiply with `lhs.val`.
Returns:
SparseTensor: SparseTensor object with `val` property resulting from
the mixed multiplication of `lhs.val` and `rhs.val`.
"""
new_out_dims, new_primal_dims = [], []
l, r = len(lhs.out_dims), len(rhs.out_dims)
lcontracting_axes, rcontracting_axes = [], []
lreplication_ids, rreplication_ids = [], []
# We do contractions first
for ld, rd in zip(lhs.primal_dims, rhs.out_dims):
if isinstance(ld, DenseDimension) and isinstance(rd, DenseDimension):
if ld.val_dim is None and rd.val_dim is None:
lreplication_ids.append(ld.id-l+len(lhs.out_dims))
rreplication_ids.append(rd.id)
elif ld.val_dim is None:
lreplication_ids.append(ld.id-l+len(lhs.out_dims))
elif rd.val_dim is None:
rreplication_ids.append(rd.id)
else:
lcontracting_axes.append(ld.val_dim)
rcontracting_axes.append(rd.val_dim)
if lreplication_ids:
lhs = _replicate_along_axis(lhs, lreplication_ids)
if rreplication_ids:
rhs = _replicate_along_axis(rhs, rreplication_ids)
# Get the contracting axes after the tiling
if lreplication_ids or rreplication_ids:
lcontracting_axes, rcontracting_axes = _get_contracting_axes(lhs, rhs)
lbroadcasting_axes, rbroadcasting_axes, pos = [], [], []
# Then we do broadcasting by extracting diagonals from the contracted tensor
# TODO: split calculation of Dimension objects and val property!
# NOTE: SparseDimension and a DenseDimension with val_dim = None basically get rid
# of a single jnp.diagonal call!
for ld, rd in zip(lhs.primal_dims, rhs.out_dims):
if isinstance(ld, SparseDimension) or isinstance(rd, SparseDimension):
# Here, we have a broadcasting over two tensors that are not just
# Kronecker deltas
if ld.val_dim is not None and rd.val_dim is not None \
and lhs.val.shape[ld.val_dim] == ld.size \
and rhs.val.shape[rd.val_dim] == rd.size:
lval_dim = ld.val_dim - sum([1 for lc in lcontracting_axes if lc < ld.val_dim])
pos.append(lval_dim)
lbroadcasting_axes.append(ld.val_dim)
rbroadcasting_axes.append(rd.val_dim)
# The following cases cover ...
if isinstance(rd, DenseDimension):
new_out_dims.insert(ld.other_id, DenseDimension(ld.other_id, ld.size, lval_dim))
elif isinstance(ld, DenseDimension):
new_primal_dims.insert(rd.other_id-r, DenseDimension(rd.other_id-r+l, ld.size, lval_dim))
else:
new_out_dims.insert(ld.other_id, SparseDimension(ld.other_id, ld.size, lval_dim, rd.other_id-r+l))
new_primal_dims.insert(rd.other_id-r, SparseDimension(rd.other_id-r+l, ld.size, lval_dim, ld.other_id))
else:
# In this case, one of the two tensors we contract is just a
# Kronecker delta so we can spare ourselves the contraction
# and just reflag the dimension of the new tensor
# NOTE: here we also cover the cases where we have replicating
# dimensions
# The following cases cover ...
# TODO simplify this piece of code.
if isinstance(rd, DenseDimension):
# ld sparse
val_dim = None
if ld.val_dim is not None:
val_dim = ld.val_dim
else:
val_dim = rd.val_dim \
- sum([1 for rc in rcontracting_axes if rc < rd.val_dim]) \
+ lhs.val.ndim \
- sum([1 for lc in lcontracting_axes]) \
- sum([1 for lb in lbroadcasting_axes])
new_out_dims.insert(ld.id, DenseDimension(ld.other_id, ld.size, val_dim))
elif isinstance(ld, DenseDimension):
# rd sparse
val_dim = None
if rd.val_dim is not None:
# TODO This thing fails in many cases!
val_dim = rd.val_dim \
- sum([1 for rc in rcontracting_axes if rc < rd.val_dim]) \
+ sum([1 for ld in lhs.primal_dims if ld.val_dim]) \
- sum([1 for lc in lcontracting_axes])
else:
val_dim = ld.val_dim \
- sum([1 for lc in lcontracting_axes if lc < ld.val_dim])
new_primal_dims.insert(rd.other_id-r, DenseDimension(rd.other_id-r+l, ld.size, val_dim))
else:
# Sparse-Sparse case where either ld.val_dim is None or rd.val_dim is None
val_dim = None
if ld.val_dim is not None:
val_dim = ld.val_dim
elif rd.val_dim is not None:
val_dim = rd.val_dim \
- sum([1 for rc in rcontracting_axes if rc < rd.val_dim]) \
+ lhs.val.ndim - sum([1 for lc in lcontracting_axes]) \
- sum([1 for lb in lbroadcasting_axes])
new_out_dims.insert(ld.other_id, SparseDimension(ld.other_id, ld.size, val_dim, rd.other_id-r+l))
new_primal_dims.insert(rd.other_id-r, SparseDimension(rd.other_id-r+l, ld.size, val_dim, ld.other_id))
if lhs.val is None and rhs.val is None:
new_val = None
elif lhs.val is None:
new_val = rhs.val
elif rhs.val is None:
new_val = lhs.val
else:
dim_numbers = (tuple(lcontracting_axes), tuple(rcontracting_axes))
batch_dimensions = (tuple(lbroadcasting_axes), tuple(rbroadcasting_axes)) # we abuse these guys here to handle the SparseDimensions
dimension_numbers = (dim_numbers, batch_dimensions)
new_val = lax.dot_general(lhs.val, rhs.val, dimension_numbers)
permutation = [None]*new_val.ndim
j = 0
for i in range(new_val.ndim):
if i < len(pos):
permutation[pos[i]] = i
else:
while permutation[j] is not None:
j += 1
permutation[j] = i
new_val = jnp.transpose(new_val, permutation)
# Take care of the old dimensions
for ld in lhs.out_dims:
if isinstance(ld, DenseDimension):
val_dim = None
if ld.val_dim is not None:
val_dim = sum([1 for d in new_out_dims[:ld.id] if d.val_dim is not None])
new_out_dims.insert(ld.id, DenseDimension(ld.id, ld.size, ld.val_dim))
for rd in rhs.primal_dims:
if isinstance(rd, DenseDimension):
val_dim = None
if rd.val_dim is not None:
# TODO add documentation here
num_old_lhs_out_dims = sum([1 for ld in lhs.out_dims
if isinstance(ld, DenseDimension) \
and ld.val_dim is not None])
num_old_rhs_out_dims = sum([1 for rd in rhs.out_dims \
if rd.val_dim is not None])
num_sparse_dims = sum([1 for ld, rd in zip(lhs.primal_dims, rhs.out_dims)
if (isinstance(ld, SparseDimension) \
or isinstance(rd, SparseDimension)) \
and (ld.val_dim is not None \
or rd.val_dim is not None)])
val_dim = rd.val_dim + num_old_lhs_out_dims + num_sparse_dims - num_old_rhs_out_dims
new_primal_dims.insert(rd.id-r, DenseDimension(rd.id-r+l, rd.size, val_dim))
return _swap_back_axes(SparseTensor(new_out_dims, new_primal_dims, new_val))
def _materialize_dimensions(st: SparseTensor, dims: Sequence[int]) -> Array:
"""
Function that materializes the `val` property of a `SparseTensor` object
along a given set of axes. This is necessary to enable broadcasting multiplication
of two `SparseTensor` objects where one of them has a `DenseDimension` object
in its `out_dims` list and the other one has a `SparseDimension` object in
the corresponding `primal_dims` list or vice versa.
Args:
st (SparseTensor): The `SparseTensor` object whose `val` property we want
to materialize along the axes given in `dims`.
dims (Sequence[int]): The axes along which we want to materialize the `val`
property of `st`.
Returns:
Array: The `val` property of `st` materialized along the axes given in `dims`.
"""
if len(dims) == 0:
return st.val
dims = sorted(dims) # reverse=True
# dims = [d if d <= st.val.ndim else -1 for d in dims]
_dims, counter = [], st.val.ndim
for d in dims:
if d <= st.val.ndim:
_dims.append(d)
counter += 1
else:
_dims.append(counter)
counter += 1
return jnp.expand_dims(st.val, axis=_dims)
def _sparse_add(lhs: SparseTensor, rhs: SparseTensor) -> SparseTensor:
"""
TODO write a function that does the addition of two SparseTensor objects
and break it down into several functions that do the different steps of the
process. This is a mess right now!
Args:
lhs (SparseTensor): SparseTensor object whose `val` property we want to
add to `rhs.val`.
rhs (SparseTensor): SparseTensor object whose `val` property we want to
add to `lhs.val`.
Returns:
SparseTensor: SparseTensor object with `val` property resulting from
the sparse addition of `lhs.val` and `rhs.val`.
"""
assert lhs.shape == rhs.shape, \
f"Incompatible shapes {lhs.shape} and {rhs.shape} for addition!"
ldims, rdims = [], []
new_out_dims, new_primal_dims = [], []
_lshape, _rshape = [], []
count = 0
# Check the dimensionality of the `out_dims` of both tensors
for ld, rd in zip(lhs.out_dims, rhs.out_dims):
if ld.val_dim is None and rd.val_dim is None:
dim = count
ldims.append(count)
rdims.append(count)
count += 1
elif ld.val_dim is not None and rd.val_dim is None:
dim = count
rdims.append(count)
count += 1
elif rd.val_dim is not None and ld.val_dim is None:
dim = count
ldims.append(count)
count += 1
else:
dim = count
count += 1
if isinstance(ld, SparseDimension) and isinstance(rd, SparseDimension) \
and ld.other_id == rd.other_id:
new_out_dims.append(SparseDimension(ld.id, ld.size, dim, ld.other_id))
_lshape.append(1)
_rshape.append(1)
elif isinstance(ld, SparseDimension) and isinstance(rd, SparseDimension) \
and ld.other_id != rd.other_id:
new_out_dims.append(DenseDimension(ld.id, ld.size, dim))
_lshape.append(ld.size)
_rshape.append(ld.size)
else:
if isinstance(ld, SparseDimension):
_rshape.append(1)
_lshape.append(ld.size)
elif isinstance(rd, SparseDimension):
_rshape.append(rd.size)
_lshape.append(1)
else:
_lshape.append(1)
_rshape.append(1)
new_out_dims.append(DenseDimension(ld.id, ld.size, dim))
# Check the dimensionality of the `primal_dims` of both tensors
for ld, rd in zip(lhs.primal_dims, rhs.primal_dims):
if isinstance(ld, SparseDimension) and isinstance(rd, SparseDimension) \
and ld.other_id == rd.other_id:
dim = new_out_dims[ld.other_id].val_dim
new_primal_dims.append(SparseDimension(ld.id, ld.size, dim, ld.other_id))
# _lshape.append(1)
# _rshape.append(1)
else:
if ld.val_dim is None and rd.val_dim is None:
dim = count
ldims.append(count)
rdims.append(count)
count += 1
elif ld.val_dim is not None and rd.val_dim is None:
dim = count
rdims.append(count)
count += 1
elif rd.val_dim is not None and ld.val_dim is None:
dim = count
ldims.append(count)
count += 1
else:
dim = count
count += 1
# TODO something here is not right, there is an apparent asymetry
# between the cases for ld and rd!
if isinstance(ld, SparseDimension) and isinstance(rd, SparseDimension) \
and ld.other_id != rd.other_id:
_lshape.append(ld.size)
_rshape.append(ld.size)
elif isinstance(ld, SparseDimension):
_lshape.append(ld.size)
if ld.val_dim is not None:
ldims.append(dim)
_rshape.append(1)
elif isinstance(rd, SparseDimension):
_rshape.append(rd.size)
if rd.val_dim is not None:
rdims.append(dim)
_lshape.append(1)
else:
_lshape.append(1)
_rshape.append(1)
new_primal_dims.append(DenseDimension(ld.id, ld.size, dim))
lhs_val = _materialize_dimensions(lhs, ldims)
rhs_val = _materialize_dimensions(rhs, rdims)
ltiling = [1]*len(lhs_val.shape)
rtiling = [1]*len(rhs_val.shape)
_ldims = lhs.dims
_rdims = rhs.dims
i = 0
for ld, rd in zip(_ldims, _rdims):
if isinstance(ld, DenseDimension) \
and ld.val_dim is None and rd.val_dim is not None:
ltiling[i] = ld.size
i+= 1
elif isinstance(rd, DenseDimension) \
and rd.val_dim is None and ld.val_dim is not None:
rtiling[i] = ld.size
i += 1
if sum(ltiling) > len(ltiling):
lhs_val = jnp.tile(lhs_val, ltiling)
if sum(rtiling) > len(rtiling):
rhs_val = jnp.tile(rhs_val, rtiling)
# We need to materialize sparse dimensions for addition
if sum(_lshape) > len(_lshape):
iota = eye_like(_lshape, len(lhs.out_dims))
lhs_val = iota * lhs_val
if sum(_rshape) > len(_rshape):
iota = eye_like(_rshape, len(rhs.out_dims))
rhs_val = iota*rhs_val
new_val = lhs_val + rhs_val
return SparseTensor(new_out_dims, new_primal_dims, new_val)
[docs]
def get_num_muls(lhs: SparseTensor, rhs: SparseTensor) -> int:
# Function that counts the number of multiplications done by multiplication
# of two SparseTensor objects
num_muls = 1
for d in lhs.out_dims:
if isinstance(d, DenseDimension):
if d.val_dim is not None:
num_muls *= d.size
for ld, rd in zip(lhs.primal_dims, rhs.out_dims):
if isinstance(ld, DenseDimension) and isinstance(rd, DenseDimension):
num_muls *= ld.size
elif isinstance(ld, DenseDimension) and isinstance(rd, SparseDimension):
num_muls *= ld.size
elif isinstance(ld, SparseDimension) and isinstance(rd, DenseDimension):
num_muls *= rd.size
elif isinstance(ld, SparseDimension) and isinstance(rd, SparseDimension):
if ld.val_dim is not None and rd.val_dim is not None:
m = max([lhs.val.shape[ld.val_dim], rhs.val.shape[rd.val_dim]])
num_muls *= m
elif ld.val_dim is not None:
num_muls *= lhs.val.shape[ld.val_dim]
elif rd.val_dim is not None:
num_muls *= rhs.val.shape[rd.val_dim]
else:
# Handle multiplications with a multiple of a Kronecker matrix
if lhs.val is not None and rhs.val is not None:
if lhs.val.size == 1 and rhs.val.size == 1:
num_muls *= 1 # ld.size
elif lhs.val.size == 1:
num_muls *= rd.size
elif rhs.val.size == 1:
num_muls *= ld.size
for d in rhs.primal_dims:
if isinstance(d, DenseDimension):
if d.val_dim is not None:
num_muls *= d.size
return num_muls
# TODO fix this, algorithm might not be correct
[docs]
def get_num_adds(lhs: SparseTensor, rhs: SparseTensor) -> int:
"""Function that counts the number of multiplications done by addition
of two `SparseTensor` objects.
Args:
lhs (SparseTensor): SparseTensor object whose `val` property we want to
add to `rhs.val`.
rhs (SparseTensor): SparseTensor object whose `val` property we want to
add to `lhs.val`.
Returns:
int: The number of additions done by addition of `lhs.val` and `rhs.val`.
"""
num_adds = 1
for ld, rd in zip(lhs.out_dims, rhs.out_dims):
if isinstance(ld, DenseDimension) and isinstance(rd, DenseDimension):
num_adds *= ld.size
elif isinstance(ld, DenseDimension) and isinstance(rd, SparseDimension):
num_adds *= rd.size
elif isinstance(ld, SparseDimension) and isinstance(rd, DenseDimension):
num_adds *= ld.size
elif isinstance(ld, SparseDimension) and isinstance(rd, SparseDimension):
if ld.val_dim is not None and rd.val_dim is not None:
num_adds *= ld.size
for ld, rd in zip(lhs.primal_dims, rhs.primal_dims):
if isinstance(ld, DenseDimension) and isinstance(rd, DenseDimension):
num_adds *= ld.size
elif isinstance(ld, DenseDimension) and isinstance(rd, SparseDimension):
num_adds *= rd.size
elif isinstance(ld, SparseDimension) and isinstance(rd, DenseDimension):
num_adds *= ld.size
return num_adds