
Extra JAX/Equinox

cryojax.jax_util provides a collection of useful functions not found in JAX that tend to be important for using cryoJAX in practice. Depending on developments with core JAX/Equinox and other factors, these functions could be deprecated in future releases of cryojax. Use with caution!

Equinox extensions

Helpers for filtering

To make full use of JAX, it is highly recommended to learn about equinox. CryoJAX implements its models as pytrees by subclassing equinox.Module, so they can be operated on like any other pytree in JAX (e.g. with jax.tree.map). Complementary to the equinox.Module interface, equinox introduces the idea of filtering: separating pytree leaves into different groups, which is a core task when using JAX (e.g. separating traced and static arguments to jax.jit). This grouping is done with the functions equinox.partition and equinox.combine. This section describes utilities in cryojax for working with them.

cryojax.jax_util.make_filter_spec(pytree: PyTree, where: Callable[[PyTree], Any | Sequence[Any]], *, inverse: bool = False, is_leaf: Callable[[Any], bool] | None = None) -> PyTree[bool]

A lightweight wrapper around equinox for creating a "filter specification".

A filter specification, or filter_spec, is a pytree whose leaves are either True or False. These are commonly used with equinox filtering.

In cryojax, a common pattern is to specify precisely which leaves a JAX transformation should act on (for example, which leaves to differentiate with respect to). This is done with a function that points to individual leaves, referred to as a where function. See the equinox documentation for an example.

Returns:

The filter specification. This is a pytree with the same structure as pytree, with True at the leaves the where function points to and False everywhere else (or the opposite, if inverse=True).

Batched loops

cryojax.jax_util.filter_bmap(f: Callable[[PyTree[Shaped[Array, '_ ...'], 'X']], PyTree[Shaped[Array, '_ ...'], 'Y']], xs: PyTree[Shaped[Array, '_ ...'], 'X'], *, batch_size: int = 1) -> PyTree[Shaped[Array, '_ ...'], 'Y']

Like jax.lax.map(..., batch_size=...), except that f receives a batch x with the same rank as xs, rather than a single element. xs is filtered in the usual equinox.filter_* way.

Arguments:

  • f: As jax.lax.map, with signature f(x), except that f must act on a whole batch, i.e. it should be vmapped over the first axis of the arrays in x.
  • xs: As jax.lax.map.
  • batch_size: Compute a loop of vmaps over xs in chunks of batch_size.

Returns:

As jax.lax.map.
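To illustrate the chunking strategy on a single array, the hypothetical bmap_sketch below splits xs into chunks of batch_size, loops over the chunks with jax.lax.map, and calls f once more on any remainder. Unlike filter_bmap, it does no equinox filtering and assumes xs is a single array; the real implementation may differ.

```python
import jax
import jax.numpy as jnp


def bmap_sketch(f, xs, *, batch_size=1):
    # Hypothetical sketch of a batched map over a single array `xs`.
    # `f` receives a chunk whose leading axis has length `batch_size`.
    n = xs.shape[0]
    n_batches, remainder = divmod(n, batch_size)
    n_main = n_batches * batch_size
    # Reshape the evenly divisible part to (n_batches, batch_size, ...)
    chunks = xs[:n_main].reshape(n_batches, batch_size, *xs.shape[1:])
    # Sequential loop over chunks; `f` is vectorized within each chunk
    ys = jax.lax.map(f, chunks)
    ys = ys.reshape(n_main, *ys.shape[2:])
    if remainder > 0:
        # Process the leftover elements with one extra call to `f`
        ys = jnp.concatenate([ys, f(xs[n_main:])], axis=0)
    return ys


# `f` must act on a whole chunk, e.g. by vmapping a per-element function
f = jax.vmap(lambda x: jnp.sum(x**2))
xs = jnp.arange(20.0).reshape(10, 2)
ys = bmap_sketch(f, xs, batch_size=4)  # chunks of 4 and 4, then a remainder of 2
```

Larger batch_size values trade memory for parallelism: each chunk is evaluated in a single vectorized call, while the loop over chunks runs sequentially.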

cryojax.jax_util.filter_bscan(f: Callable[[~Carry, ~X], tuple[~Carry, ~Y]], init: ~Carry, xs: ~X, length: int | None = None, unroll: int | bool = 1, *, batch_size: int = 1) -> tuple[~Carry, ~Y]

Adds a batch_size argument to jax.lax.scan, analogous to jax.lax.map(..., batch_size=...). Additionally, unlike in jax.lax.scan or jax.lax.map, f(carry, x) receives x with the same rank as xs (e.g. f may be vmapped over x). xs and carry are filtered in the usual equinox.filter_* way.

Arguments:

  • f: As jax.lax.scan with format f(carry, x).
  • init: As jax.lax.scan.
  • xs: As jax.lax.scan.
  • length: As jax.lax.scan.
  • unroll: As jax.lax.scan.
  • batch_size: Compute a loop of vmaps over xs in chunks of batch_size.

Returns:

As jax.lax.scan.
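A minimal sketch of the same idea for scans, assuming xs is a single array: the hypothetical bscan_sketch below scans over chunks of batch_size, threading the carry between chunks, then continues on any remainder. It omits the length and unroll arguments and all equinox filtering, so it should not be read as filter_bscan's actual implementation.

```python
import jax
import jax.numpy as jnp


def bscan_sketch(f, init, xs, *, batch_size=1):
    # Hypothetical sketch of jax.lax.scan with a batch_size, for a single array `xs`.
    # `f(carry, x)` receives a chunk `x` whose leading axis has length batch_size.
    n = xs.shape[0]
    n_batches, remainder = divmod(n, batch_size)
    n_main = n_batches * batch_size
    chunks = xs[:n_main].reshape(n_batches, batch_size, *xs.shape[1:])
    # Scan over chunks, threading the carry between them
    carry, ys = jax.lax.scan(f, init, chunks)
    ys = ys.reshape(n_main, *ys.shape[2:])
    if remainder > 0:
        # Continue from the final carry on the leftover elements
        carry, ys_rem = f(carry, xs[n_main:])
        ys = jnp.concatenate([ys, ys_rem], axis=0)
    return carry, ys


def f(carry, x):
    # Square a whole chunk at once and add its total to the running sum
    y = x**2
    return carry + jnp.sum(y), y


xs = jnp.arange(10.0)
total, ys = bscan_sketch(f, jnp.zeros(()), xs, batch_size=3)
```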

Debugging and runtime errors

cryojax.jax_util.maybe_error_if(x: ~T, pred_fn: Callable[[~T], Bool[ArrayLike, '...']], msg: str) -> PyTree

Applies equinox.error_if depending on the value of the environment variable CRYOJAX_ENABLE_CHECKS.

  • If CRYOJAX_ENABLE_CHECKS=true: This function is equivalent to equinox.error_if, except that the pred argument is replaced by pred_fn, with pred = pred_fn(x). This way, pred is only evaluated if checks are enabled.
  • If CRYOJAX_ENABLE_CHECKS=false: This function is the identity, i.e. lambda x: x.

By default, CRYOJAX_ENABLE_CHECKS=false because checks may cause slowdowns, particularly on GPU.

cryojax uses this function to achieve behavior similar to JAX_ENABLE_CHECKS, and exposes it as public API for downstream development.

Interoperability with lineax

cryojax.jax_util.make_linear_operator(fn: Callable[[~Args], Array], args: ~Args, where_vector: Callable[[~Args], Any], *, tags: object | Iterable[object] = ()) -> tuple[lineax.FunctionLinearOperator, ~Args]

Instantiate a lineax.FunctionLinearOperator from a function that takes an arbitrary pytree as input.

This is useful for converting from the cryoJAX abstraction for image simulation to the lineax abstraction of matrix-vector multiplication. In particular, lineax makes it easy to get backprojection operators, since it calls jax.linear_transpose under the hood.

Backprojection with lineax

import cryojax.simulator as cxs
import cryojax.jax_util as jxu
import lineax as lx

# Instantiate a linear operator
volume = cxs.FourierVoxelGridVolume.from_real_voxel_grid(...)
image_model = cxs.make_image_model(volume, ...)
where_vector = lambda x: x.volume.fourier_voxel_grid
operator, vector = jxu.make_linear_operator(
    fn=lambda x: x.simulate(),
    args=image_model,
    where_vector=where_vector,
)
# Simulate an image
image = operator.mv(vector)
# Compute backprojection
adjoint = lx.conj(operator.T)
backprojection = where_vector(adjoint.mv(image))

Warning

Calling this function promises that fn can be expressed as a linear operator with respect to the input arguments at where_vector. CryoJAX does not explicitly check whether this is the case, so if it is not, JAX will throw errors downstream.

Arguments:

  • fn: A function with signature out = fn(args).
  • args: Input arguments to fn.
  • where_vector: A where function pointing to the arguments in args that span the input vector space (e.g. the volume).
  • tags: See lineax.FunctionLinearOperator for documentation.

Returns:

A tuple whose first element is the lineax.FunctionLinearOperator and whose second element is a pytree with the same structure as args, partitioned to include only the arguments at where_vector.
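The transpose that lineax relies on can be illustrated with JAX alone. In the sketch below, a hypothetical toy projection (summing a 3D volume along its first axis) stands in for a cryoJAX image simulation, and jax.linear_transpose produces its backprojection directly, without lineax or cryojax.

```python
import jax
import jax.numpy as jnp


def project(volume):
    # Toy linear "projection": integrate a 3D volume along its first axis.
    # This stands in for a cryoJAX image simulation and is hypothetical.
    return jnp.sum(volume, axis=0)


volume = jnp.arange(24.0).reshape(2, 3, 4)
image = project(volume)

# jax.linear_transpose uses the example primal only for its shape and dtype
project_transpose = jax.linear_transpose(project, volume)
# The transpose returns a tuple with one cotangent per primal input
(backprojection,) = project_transpose(image)
```

For this projection, the backprojection copies the image into every slice along the first axis, and the adjoint identity holds: the inner product of project(volume) with image equals the inner product of volume with backprojection.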

Extra type hints using jaxtyping

cryojax.jax_util.NDArrayLike

A type hint for a JAX or numpy array.

cryojax.jax_util.FloatLike

A type hint for a python float or a JAX/numpy float scalar.

cryojax.jax_util.ComplexLike

A type hint for a python complex or a JAX/numpy complex scalar.

cryojax.jax_util.InexactLike

A type hint for a python float / complex, or a JAX/numpy float / complex scalar.

cryojax.jax_util.IntLike

A type hint for a python int, or a JAX/numpy integer scalar.

cryojax.jax_util.BoolLike

A type hint for a python bool or a JAX/numpy boolean scalar.