graphax.sparse package

Submodules

graphax.sparse.tensor module

Sparse tensor algebra implementation

class graphax.sparse.tensor.DenseDimension(id: int, size: int, val_dim: int)[source]

Bases: object

id: int
size: int
val_dim: int
class graphax.sparse.tensor.SparseDimension(id: int, size: int, val_dim: int, other_id: int)[source]

Bases: object

id: int
other_id: int
size: int
val_dim: int
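
The two dimension classes carry only the fields listed above. The following minimal sketch mirrors those fields with local stand-in dataclasses (`DenseDim` and `SparseDim` are hypothetical names, not the graphax classes) and describes a 3x3 diagonal matrix. The reading of the fields is an assumption not stated in this reference: `val_dim` is taken to be the axis of `val` that stores the dimension, and `other_id` the `id` of the dimension it is paired with.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DenseDim:
    """Local stand-in mirroring the documented DenseDimension fields."""
    id: int
    size: int
    val_dim: int


@dataclass(frozen=True)
class SparseDim:
    """Local stand-in mirroring the documented SparseDimension fields."""
    id: int
    size: int
    val_dim: int
    other_id: int


# Assumed interpretation: a 3x3 diagonal matrix has one output dimension (id 0)
# and one primal dimension (id 1) that are paired with each other and share
# the same axis of the stored values.
out_dim = SparseDim(id=0, size=3, val_dim=0, other_id=1)
primal_dim = SparseDim(id=1, size=3, val_dim=0, other_id=0)

# A dense dimension has no partner, so it carries no other_id.
dense_dim = DenseDim(id=0, size=3, val_dim=0)
```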
class graphax.sparse.tensor.SparseTensor(out_dims: Sequence[DenseDimension | SparseDimension], primal_dims: Sequence[DenseDimension | SparseDimension], val: Array | ndarray | bool | number, pre_transforms: Sequence[Callable] = None, post_transforms: Sequence[Callable] = None)[source]

Bases: object

The SparseTensor object enables the representation of sparse tensors. If out_dims or primal_dims is empty, this implies a scalar dependent or independent variable, respectively. If both are empty, the tensor is a scalar value, everything becomes trivial, and the val field contains the value of the singleton partial.

copy(val: Array | ndarray | bool | number = None)[source]

Function that copies the given sparse tensor object entirely except for the val property which can be replaced by a new value.

Parameters:

val (Array, optional) – The new value of the val property. Defaults to None.

Returns:

A copy of the original SparseTensor object.

Return type:

SparseTensor
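
The copy-with-replacement behaviour described above can be sketched with a small stand-in class. `TensorSketch` is hypothetical and only carries three of the documented fields. It is not the graphax implementation, and building it on `dataclasses.replace` is an assumption made for brevity.

```python
from dataclasses import dataclass, replace
from typing import Any, Optional, Tuple


@dataclass(frozen=True)
class TensorSketch:
    """Hypothetical stand-in for SparseTensor; not the graphax class."""
    out_dims: Tuple[Any, ...]
    primal_dims: Tuple[Any, ...]
    val: Any

    def copy(self, val: Optional[Any] = None) -> "TensorSketch":
        # Keep every field; swap in the new value only if one was given.
        if val is None:
            return replace(self)
        return replace(self, val=val)


# Both dimension tuples are empty, so this stands for a plain scalar partial.
original = TensorSketch(out_dims=(), primal_dims=(), val=2.0)
same = original.copy()
scaled = original.copy(val=4.0)
```

In this sketch `same` is a new object that compares equal to `original`, while `scaled` differs only in its `val` field.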

dense(iota: Array | ndarray | bool | number) Array | ndarray | bool | number[source]

Materializes the tensor in its actual dense shape.

Parameters:

iota (Array) – The Kronecker matrix/tensor that is used to materialize the tensor.

Returns:

Dense representation of the sparse tensor.

Return type:

Array
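
To illustrate what materializing with a Kronecker matrix can look like, the sketch below expands a diagonal matrix that is stored as a vector of its diagonal entries. It uses NumPy in place of `jax.numpy` and does not call graphax. The storage layout (one value per paired index) and the helper name `materialize_diagonal` are assumptions made for this example.

```python
import numpy as np


def materialize_diagonal(val: np.ndarray, iota: np.ndarray) -> np.ndarray:
    """Expand stored diagonal entries to a dense matrix (illustrative only)."""
    # dense[i, j] = val[i] * delta[i, j], with iota playing the role of delta.
    return iota * val[:, None]


val = np.array([1.0, 2.0, 3.0])   # stored diagonal entries
iota = np.eye(3)                  # Kronecker delta as an identity matrix
dense = materialize_diagonal(val, iota)
```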

dims: Tuple[DenseDimension | SparseDimension]
out_dims: Tuple[DenseDimension | SparseDimension]
post_transforms: Sequence[Callable]
pre_transforms: Sequence[Callable]
primal_dims: Tuple[DenseDimension | SparseDimension]
shape: Tuple[int]
val: Array | ndarray | bool | number
graphax.sparse.tensor.get_num_adds(lhs: SparseTensor, rhs: SparseTensor) int[source]

Function that counts the number of additions done by addition of two SparseTensor objects.

Parameters:
  • lhs (SparseTensor) – SparseTensor object whose val property we want to add to rhs.val.

  • rhs (SparseTensor) – SparseTensor object whose val property we want to add to lhs.val.

Returns:

The number of additions done by addition of lhs.val and rhs.val.

Return type:

int

graphax.sparse.tensor.get_num_muls(lhs: SparseTensor, rhs: SparseTensor) int[source]
graphax.sparse.tensor.sparse_tensor_zeros_like(st: SparseTensor) SparseTensor[source]

Function that generates a new SparseTensor with only zeros with the same shape as the given SparseTensor st.

Parameters:

st (SparseTensor) – SparseTensor whose shape we want to use to initialize the new SparseTensor object.

Returns:

A copy of the given SparseTensor object filled with zeros.

Return type:

SparseTensor

graphax.sparse.utils module

graphax.sparse.utils.count_muls(eqn: JaxprEqn) int[source]

Function that counts the number of multiplications done by a jaxpr equation. The implementation treats every primitive as zero multiplications except for the lax.dot_general and lax.mul primitives. For these, simple algorithms for counting the number of multiplications are implemented.

Parameters:

eqn (core.JaxprEqn) – The JaxprEqn of which we want to know how many multiplications are happening.

Returns:

The number of multiplications done inside the jaxpr equation.

Return type:

int
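
The counting rule described above can be sketched on plain shape tuples. This is a simplified illustration, not the graphax implementation: `count_muls_sketch` is a hypothetical helper, it takes a primitive name and shapes instead of a real `JaxprEqn`, and it only handles the special case of `dot_general` that corresponds to an ordinary 2-D matrix product.

```python
from math import prod
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


def count_muls_sketch(primitive: str,
                      in_shapes: Sequence[Shape],
                      out_shape: Shape) -> int:
    """Estimate the multiplications of one equation (illustrative only)."""
    if primitive == "mul":
        # Elementwise product: one multiplication per output element.
        return prod(out_shape)
    if primitive == "dot_general":
        # Assumed special case (m, k) @ (k, n): each of the m * n outputs
        # needs k multiplications.
        (m, k), (k_rhs, n) = in_shapes
        if k != k_rhs:
            raise ValueError("contracted dimensions must match")
        return m * k * n
    # Every other primitive is treated as performing no multiplications.
    return 0


elementwise = count_muls_sketch("mul", [(2, 3), (2, 3)], (2, 3))
matmul = count_muls_sketch("dot_general", [(2, 3), (3, 4)], (2, 4))
other = count_muls_sketch("add", [(2, 3), (2, 3)], (2, 3))
```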

graphax.sparse.utils.count_muls_jaxpr(jaxpr: ClosedJaxpr) int[source]

Function that counts the number of multiplications done within a jaxpr.

Parameters:

jaxpr (core.ClosedJaxpr) – The ClosedJaxpr of which we want to know how many multiplications are performed.

Returns:

The number of multiplications done within the jaxpr.

Return type:

int

graphax.sparse.utils.eye_like(shape: Sequence[int], out_len: int) Array[source]

Function that creates a higher order tensor that is a product of Kronecker deltas.

Parameters:
  • shape (Sequence[int]) – The shape of the higher order tensor that we want to create.

  • out_len (int) – The length of the output tensor, i.e. the number of output dimensions.

Returns:

The higher order tensor that is a product of Kronecker deltas.

Return type:

jnp.ndarray
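
As an illustration of a product of Kronecker deltas, the sketch below builds the tensor delta[i, k] * delta[j, l] for shape (2, 3, 2, 3). It uses NumPy in place of `jax.numpy`, and `kronecker_product_tensor` is a hypothetical helper. It also assumes that the first `out_len` entries of `shape` are the output dimensions and that they match the remaining dimensions one to one, which this reference does not spell out.

```python
import numpy as np


def kronecker_product_tensor(shape, out_len):
    """Product of Kronecker deltas for matching shapes (illustrative only)."""
    out_shape = tuple(shape[:out_len])
    primal_shape = tuple(shape[out_len:])
    if out_shape != primal_shape:
        raise ValueError("this sketch only handles matching output and primal shapes")
    # A flat identity matrix reshaped to the full shape is 1 exactly where
    # every output index equals its paired primal index.
    n = int(np.prod(out_shape, dtype=int))
    return np.eye(n).reshape(*out_shape, *primal_shape)


delta = kronecker_product_tensor((2, 3, 2, 3), out_len=2)
```

Here `delta[i, j, k, l]` is 1 when `i == k` and `j == l`, and 0 otherwise.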

graphax.sparse.utils.eye_like_copy(shape: Sequence[int], out_len: int, iota: Array) Array[source]

Function that creates a higher order tensor that is a product of Kronecker deltas. It tries to reuse the identity matrix iota as much as possible to create the higher order tensor. If iota is too small, it creates a new identity matrix of the appropriate size.

Parameters:
  • shape (Sequence[int]) – The shape of the higher order tensor that we want to create.

  • out_len (int) – The length of the output tensor, i.e. the number of output dimensions.

  • iota (jnp.ndarray) – The identity matrix that we use to create the higher order tensor.

Returns:

The higher order tensor that is a product of Kronecker deltas.

Return type:

jnp.ndarray
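
The reuse strategy described above can be sketched as follows: slice the existing identity matrix when it is large enough, and allocate a new one only when it is not. This is a NumPy illustration under stated assumptions, not the graphax code. `identity_from` is a hypothetical helper and only covers the 2-D identity step, not the construction of the full higher-order tensor.

```python
import numpy as np


def identity_from(iota: np.ndarray, n: int) -> np.ndarray:
    """Return an n x n identity, reusing iota when possible (illustrative only)."""
    rows, cols = iota.shape
    if n <= min(rows, cols):
        # The top-left block of an identity matrix is itself an identity.
        # In NumPy this slice is a view, so no new memory is allocated.
        return iota[:n, :n]
    # iota is too small: fall back to a freshly created identity matrix.
    return np.eye(n, dtype=iota.dtype)


iota = np.eye(4)
small = identity_from(iota, 3)   # reuses iota
large = identity_from(iota, 6)   # iota is too small, new matrix
```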

graphax.sparse.utils.get_largest_tensor(tensors: Sequence[ShapedArray]) int[source]

Function that computes the size of the largest tensor in a list of tensors.

Parameters:

tensors (Sequence) – A list of tensors for which we want to know the size of the largest tensor.

Returns:

The size of the largest tensor in the list of tensors.

Return type:

int
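
A minimal sketch of this computation on plain shape tuples is shown below. It assumes that "size" means the total number of elements, and it returns 0 for an empty list, which is a choice made for this example only. `largest_size` is a hypothetical helper that works on shapes instead of `ShapedArray` objects.

```python
from math import prod
from typing import Sequence, Tuple


def largest_size(shapes: Sequence[Tuple[int, ...]]) -> int:
    """Number of elements of the largest shape (illustrative only)."""
    # prod(()) is 1, so a scalar shape counts as a single element.
    return max((prod(shape) for shape in shapes), default=0)


biggest = largest_size([(2, 3), (4, 4), ()])
```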

graphax.sparse.utils.zeros_like(invar: ShapedArray, outvar: ShapedArray) Array[source]

Function that creates an array of zeros. The shape of the array is the concatenation of the shapes of the input and output dimensions.

Parameters:
  • invar (ShapedArray) – The input variable.

  • outvar (ShapedArray) – The output variable.

Returns:

An array of zeros with the shape of the concatenation of the shapes of the input and output dimensions.

Return type:

jnp.ndarray

Module contents