graphax.sparse package
Submodules
graphax.sparse.tensor module
Sparse tensor algebra implementation
- class graphax.sparse.tensor.SparseDimension(id: int, size: int, val_dim: int, other_id: int)
Bases: object
- class graphax.sparse.tensor.SparseTensor(out_dims: Sequence[DenseDimension | SparseDimension], primal_dims: Sequence[DenseDimension | SparseDimension], val: Array | ndarray | bool | number, pre_transforms: Sequence[Callable] = None, post_transforms: Sequence[Callable] = None)
Bases: object
The SparseTensor object enables the representation of sparse tensors. If out_dims or primal_dims is empty, this implies a scalar dependent or independent variable, respectively. If both are empty, the tensor is a scalar value, everything becomes trivial, and the val field contains the value of the singleton partial derivative.
- copy(val: Array | ndarray | bool | number = None)
Function that copies the given SparseTensor object entirely, except for the val property, which can be replaced by a new value.
- Parameters:
val (Array, optional) – The new value of the val property. Defaults to None.
- Returns:
A copy of the original SparseTensor object.
- Return type:
SparseTensor
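As a rough illustration of the copy semantics, the sketch below uses a hypothetical frozen dataclass, ToySparseTensor, whose dimensions are plain tuples of ints rather than graphax DenseDimension/SparseDimension objects. It also assumes that passing val=None keeps the original value; that behavior is an assumption, not something stated above.

```python
import dataclasses
from typing import Any, Optional, Tuple

import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class ToySparseTensor:
    # Hypothetical stand-in for SparseTensor: dimensions are plain int tuples.
    out_dims: Tuple[int, ...]
    primal_dims: Tuple[int, ...]
    val: Any

    def copy(self, val: Optional[Any] = None) -> "ToySparseTensor":
        # Assumption: val=None keeps the original value.
        new_val = self.val if val is None else val
        # dataclasses.replace copies every field except the overridden ones.
        return dataclasses.replace(self, val=new_val)


st = ToySparseTensor(out_dims=(3,), primal_dims=(3,), val=jnp.ones(3))
doubled = st.copy(2.0 * st.val)
```

Here doubled shares the dimension metadata of st but carries a new val.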
- dense(iota: Array | ndarray | bool | number) → Array | ndarray | bool | number
Materializes the tensor into its actual dense shape.
- Parameters:
iota (Array) – The Kronecker matrix/tensor that is used to materialize the tensor.
- Returns:
Dense representation of the sparse tensor.
- Return type:
Array
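The idea behind materializing with a Kronecker matrix can be illustrated with plain JAX: the Jacobian of an elementwise function such as jnp.sin is diagonal, so only the diagonal cos(x) needs to be stored, and the dense matrix is recovered by multiplying with an identity matrix. The helper materialize_diagonal is hypothetical and covers only this single-sparse-dimension case; it is a sketch of the principle, not the graphax routine.

```python
import jax
import jax.numpy as jnp


def materialize_diagonal(val: jnp.ndarray, iota: jnp.ndarray) -> jnp.ndarray:
    # Hypothetical helper: expand a stored diagonal into a dense matrix.
    n = val.shape[0]
    # iota[:n, :n] is the Kronecker delta delta_ij; scaling row i by val[i]
    # gives the dense entries delta_ij * val_i.
    return iota[:n, :n] * val[:, None]


x = jnp.linspace(0.0, 1.0, 4)
val = jnp.cos(x)   # diagonal of the Jacobian of jnp.sin at x
iota = jnp.eye(8)  # shared identity matrix, larger than needed
dense = materialize_diagonal(val, iota)
```

The result matches jax.jacfwd(jnp.sin)(x), which builds the same 4 x 4 diagonal matrix directly.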
- dims: Tuple[DenseDimension | SparseDimension]
- out_dims: Tuple[DenseDimension | SparseDimension]
- primal_dims: Tuple[DenseDimension | SparseDimension]
- graphax.sparse.tensor.get_num_adds(lhs: SparseTensor, rhs: SparseTensor) → int
Function that counts the number of additions performed when adding two SparseTensor objects.
- Parameters:
lhs (SparseTensor) – SparseTensor object whose val property we want to add to rhs.val.
rhs (SparseTensor) – SparseTensor object whose val property we want to add to lhs.val.
- Returns:
The number of additions performed when adding lhs.val and rhs.val.
- Return type:
int
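A minimal sketch of an addition count, assuming both val arrays are added elementwise with NumPy-style broadcasting and that each element of the result costs one scalar addition. This counting rule is an assumption and may differ from what graphax does for sparse operands; count_elementwise_adds is a hypothetical name.

```python
import math

import jax.numpy as jnp


def count_elementwise_adds(lhs_shape, rhs_shape) -> int:
    # Assumed rule: one scalar addition per element of the broadcast result.
    out_shape = jnp.broadcast_shapes(tuple(lhs_shape), tuple(rhs_shape))
    # math.prod(()) == 1, so adding two scalars counts as one addition.
    return math.prod(out_shape)


n_adds = count_elementwise_adds((3, 4), (4,))  # 12 under this rule
```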
- graphax.sparse.tensor.get_num_muls(lhs: SparseTensor, rhs: SparseTensor) → int
- graphax.sparse.tensor.sparse_tensor_zeros_like(st: SparseTensor) → SparseTensor
Function that generates a new SparseTensor filled with zeros that has the same shape as the given SparseTensor st.
- Parameters:
st (SparseTensor) – SparseTensor whose shape we want to use to initialize the new SparseTensor object.
- Returns:
A copy of the given SparseTensor object whose values are all zeros.
- Return type:
SparseTensor
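In a toy model (a hypothetical ToySparseTensor dataclass with plain tuple dimensions, not the graphax class), a zeros-like operation only has to keep the dimension metadata and swap val for zeros of the same shape and dtype:

```python
import dataclasses
from typing import Any, Tuple

import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class ToySparseTensor:
    # Hypothetical stand-in: only the fields needed for this sketch.
    out_dims: Tuple[int, ...]
    primal_dims: Tuple[int, ...]
    val: Any


def toy_zeros_like(st: ToySparseTensor) -> ToySparseTensor:
    # Keep all dimension metadata; replace val by zeros of the same shape and dtype.
    return dataclasses.replace(st, val=jnp.zeros_like(st.val))


st = ToySparseTensor(out_dims=(2, 3), primal_dims=(3,), val=jnp.arange(6.0).reshape(2, 3))
z = toy_zeros_like(st)
```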
graphax.sparse.utils module
- graphax.sparse.utils.count_muls(eqn: JaxprEqn) → int
Function that counts the number of multiplications done by a jaxpr equation. The implementation treats every primitive as performing zero multiplications, except for the lax.dot_general and lax.mul primitives, for which simple counting rules are implemented.
- Parameters:
eqn (core.JaxprEqn) – The JaxprEqn of which we want to know how many multiplications are happening.
- Returns:
The number of multiplications done inside the jaxpr equation.
- Return type:
int
- graphax.sparse.utils.count_muls_jaxpr(jaxpr: ClosedJaxpr) → int
Function that counts the number of multiplications done within a jaxpr.
- Parameters:
jaxpr (core.ClosedJaxpr) – The ClosedJaxpr of which we want to know how many multiplications are performed.
- Returns:
The number of multiplications done within the jaxpr.
- Return type:
int
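The counting rule described for count_muls and count_muls_jaxpr can be sketched with public JAX objects: jax.make_jaxpr returns a ClosedJaxpr whose .jaxpr.eqns are JaxprEqn objects. The two functions below are hypothetical re-implementations, not graphax code. They assume mul costs one multiplication per output element and dot_general costs one per output element times the size of the contracted dimensions; lax primitives are called directly so that no nested sub-jaxprs appear.

```python
import math

import jax
import jax.numpy as jnp
from jax import lax


def count_muls_sketch(eqn) -> int:
    name = eqn.primitive.name
    out_shape = eqn.outvars[0].aval.shape
    if name == "mul":
        # Elementwise product: one multiplication per output element.
        return math.prod(out_shape)
    if name == "dot_general":
        # dimension_numbers = ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))
        (lhs_contract, _), _ = eqn.params["dimension_numbers"]
        lhs_shape = eqn.invars[0].aval.shape
        contract_size = math.prod(lhs_shape[d] for d in lhs_contract)
        # Each output element is a sum of contract_size products.
        return math.prod(out_shape) * contract_size
    # Every other primitive is treated as multiplication-free.
    return 0


def count_muls_jaxpr_sketch(closed_jaxpr) -> int:
    # Sum the per-equation counts over the top-level equations.
    return sum(count_muls_sketch(eqn) for eqn in closed_jaxpr.jaxpr.eqns)


def f(a, b):
    c = lax.dot_general(a, b, (((1,), (0,)), ((), ())))  # (3, 4) @ (4, 5)
    return lax.mul(c, c)


closed = jax.make_jaxpr(f)(jnp.ones((3, 4)), jnp.ones((4, 5)))
total = count_muls_jaxpr_sketch(closed)  # 3*5*4 + 3*5 = 75
```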
- graphax.sparse.utils.eye_like(shape: Sequence[int], out_len: int) → Array
Function that creates a higher-order tensor that is a product of Kronecker deltas.
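For a square product of Kronecker deltas, E[i1..ik, j1..jk] = delta(i1, j1) * ... * delta(ik, jk), one simple construction is to reshape an identity matrix of size d1 * ... * dk: two flattened indices are equal exactly when every pair i_m, j_m is equal. The helper kronecker_eye is hypothetical and does not take the out_len argument of eye_like, whose exact role is not documented here.

```python
import math

import jax
import jax.numpy as jnp


def kronecker_eye(dims) -> jnp.ndarray:
    # Hypothetical helper: product of Kronecker deltas with shape dims + dims.
    n = math.prod(dims)
    # Row-major reshape: entry [i, j] is 1 iff ravel(i) == ravel(j), i.e. iff i == j.
    return jnp.eye(n).reshape(tuple(dims) + tuple(dims))


E = kronecker_eye((2, 3))  # shape (2, 3, 2, 3)
```

This is the same tensor as the Jacobian of the identity map on a (2, 3) array.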
- graphax.sparse.utils.eye_like_copy(shape: Sequence[int], out_len: int, iota: Array) → Array
Function that creates a higher-order tensor that is a product of Kronecker deltas. It tries to reuse the identity matrix iota as much as possible to create the higher-order tensor. If iota is too small, it creates a new identity matrix of the appropriate size.
- Parameters:
- Returns:
The higher-order tensor that is a product of Kronecker deltas.
- Return type:
jnp.ndarray
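The reuse strategy described for eye_like_copy can be sketched for the two-dimensional case: the top-left n x n block of a larger identity matrix is itself an identity matrix, so it can be sliced out of iota instead of being rebuilt, and a fresh one is created only when iota is too small. eye_from_iota is a hypothetical name and ignores the shape/out_len arguments of the real function.

```python
import jax.numpy as jnp


def eye_from_iota(n: int, iota: jnp.ndarray) -> jnp.ndarray:
    # Reuse the top-left block of the shared identity matrix when it is large enough.
    if iota.shape[0] >= n:
        return iota[:n, :n]
    # Otherwise fall back to building a new identity of the required size.
    return jnp.eye(n, dtype=iota.dtype)


iota = jnp.eye(16)
small = eye_from_iota(4, iota)   # sliced from iota
big = eye_from_iota(32, iota)    # freshly created
```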
- graphax.sparse.utils.get_largest_tensor(tensors: Sequence[ShapedArray]) → int
Function that computes the size of the largest tensor in a list of tensors.
- Parameters:
tensors (Sequence) – A list of tensors for which we want to know the size of the largest tensor.
- Returns:
The size of the largest tensor in the list of tensors.
- Return type:
int
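A sketch of the size computation, assuming that "size" means the number of elements (the product of the shape) rather than bytes. jax.ShapeDtypeStruct stands in for ShapedArray here, since both expose a .shape attribute; largest_tensor_size is a hypothetical name.

```python
import math

import jax
import jax.numpy as jnp


def largest_tensor_size(avals) -> int:
    # Number of elements of each abstract value; keep the maximum.
    return max(math.prod(a.shape) for a in avals)


avals = [
    jax.ShapeDtypeStruct((3, 4), jnp.float32),
    jax.ShapeDtypeStruct((10,), jnp.float32),
]
largest = largest_tensor_size(avals)  # 12
```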
- graphax.sparse.utils.zeros_like(invar: ShapedArray, outvar: ShapedArray) → Array
Function that creates an array of zeros. The shape of the array is the concatenation of the shapes of the input and output dimensions.
- Parameters:
invar (ShapedArray) – The input variable.
outvar (ShapedArray) – The output variable.
- Returns:
An array of zeros with the shape of the concatenation of the shapes of the input and output dimensions.
- Return type:
jnp.ndarray
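A minimal sketch of the shape concatenation, following the order given above (input shape first, then output shape). Using the output variable's dtype is an assumption, as is the hypothetical name zeros_for_pair; jax.ShapeDtypeStruct stands in for ShapedArray.

```python
import jax
import jax.numpy as jnp


def zeros_for_pair(invar, outvar) -> jnp.ndarray:
    # Concatenate the input shape and the output shape, in that order.
    shape = tuple(invar.shape) + tuple(outvar.shape)
    # Assumption: the zeros take the dtype of the output variable.
    return jnp.zeros(shape, dtype=outvar.dtype)


invar = jax.ShapeDtypeStruct((3,), jnp.float32)
outvar = jax.ShapeDtypeStruct((2, 4), jnp.float32)
z = zeros_for_pair(invar, outvar)  # shape (3, 2, 4)
```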