Extra JAX/Equinox
cryojax.jax_util provides a collection of useful functions, not found in JAX, that are often important when using cryoJAX in practice. Depending on developments in core JAX/Equinox and other factors, these functions may be deprecated in future releases of cryojax. Use with caution!
Equinox extensions
Helpers for filtering
To make full use of JAX, it is highly recommended to learn about equinox. CryoJAX implements its models as pytrees using equinox.Module. These pytrees can be operated on like any other pytree in JAX (e.g. with jax.tree.map). Complementary to the equinox.Module interface, equinox introduces the idea of filtering, which separates pytree leaves into different groups. This is a core task when using JAX (e.g. separating traced and static arguments to jax.jit). In particular, the grouping is done with the functions equinox.partition and equinox.combine. This section describes utilities in cryojax for working with equinox.partition and equinox.combine.
cryojax.jax_util.make_filter_spec(pytree: PyTree, where: Callable[[PyTree], Any | Sequence[Any]], *, inverse: bool = False, is_leaf: Callable[[Any], bool] | None = None) -> PyTree[bool]
A lightweight wrapper around equinox for creating a "filter specification".
A filter specification, or filter_spec, is a pytree whose
leaves are either True or False. These are commonly used with
equinox filtering.
In cryojax, it is common to need fine-grained control over which
leaves a JAX transformation is taken with respect to. This is done with a
function that points to individual leaves, referred to as a where function. See
the equinox documentation for an example.
Returns:
The filter specification. This is a pytree with the same structure as pytree, with
True at the leaves the where function points to and False everywhere else
(or the opposite, if inverse=True).
Batched loops
cryojax.jax_util.filter_bmap(f: Callable[[PyTree[Shaped[Array, '_ ...'], 'X']], PyTree[Shaped[Array, '_ ...'], 'Y']], xs: PyTree[Shaped[Array, '_ ...'], 'X'], *, batch_size: int = 1) -> PyTree[Shaped[Array, '_ ...'], 'Y']
Like jax.lax.map(..., batch_size=...), but f accepts x
with the same rank as xs, i.e. f is called on a whole batch at once.
xs is filtered in the usual equinox.filter_* way.
Arguments:
- f: As jax.lax.map with format f(x), except vmapped over the first axis of the arrays of x.
- xs: As jax.lax.map.
- batch_size: Compute a loop of vmaps over xs in chunks of batch_size.
Returns:
As jax.lax.map.
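The batching idea can be sketched with plain JAX: reshape the leading axis into (n_batches, batch_size) and loop over batches with jax.lax.map, so that f sees a whole batch. This is a hypothetical helper (bmap_sketch), not cryojax's implementation; it assumes the leading axis is divisible by batch_size and does no equinox filtering of non-array leaves.

```python
import jax
import jax.numpy as jnp


def bmap_sketch(f, xs, *, batch_size=1):
    """Hypothetical batched map: f receives arrays with the same rank as xs."""
    # Assumption: all leaves share a leading axis divisible by batch_size
    n = jax.tree_util.tree_leaves(xs)[0].shape[0]
    n_batches = n // batch_size
    # (n, ...) -> (n_batches, batch_size, ...)
    batched = jax.tree_util.tree_map(
        lambda x: x.reshape(n_batches, batch_size, *x.shape[1:]), xs
    )
    # Sequential loop over batches; each call to f handles batch_size elements
    ys = jax.lax.map(f, batched)
    # (n_batches, batch_size, ...) -> (n, ...)
    return jax.tree_util.tree_map(
        lambda y: y.reshape(n_batches * batch_size, *y.shape[2:]), ys
    )


xs = jnp.arange(12.0).reshape(6, 2)
# f is vmapped, so it receives arrays of shape (batch_size, 2)
f = jax.vmap(lambda x: jnp.dot(x, x))
ys = bmap_sketch(f, xs, batch_size=3)
```

With jax.lax.map(f, xs, batch_size=3), f would instead be written for a single row of shape (2,) and vmapped internally.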
cryojax.jax_util.filter_bscan(f: Callable[[~Carry, ~X], tuple[~Carry, ~Y]], init: ~Carry, xs: ~X, length: int | None = None, unroll: int | bool = 1, *, batch_size: int = 1) -> tuple[~Carry, ~Y]
Like jax.lax.scan, but with a batch_size argument analogous
to that of jax.lax.map(..., batch_size=...). Additionally, unlike
jax.lax.map, f(carry, x) accepts x with the same
rank as xs (e.g. because f is vmapped over x).
xs and carry are filtered in the usual equinox.filter_* way.
Arguments:
- f: As jax.lax.scan with format f(carry, x).
- init: As jax.lax.scan.
- xs: As jax.lax.scan.
- length: As jax.lax.scan.
- unroll: As jax.lax.scan.
- batch_size: Compute a loop of vmaps over xs in chunks of batch_size.
Returns:
As jax.lax.scan.
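A sketch of a batched scan under the same simplifying assumptions as a plain-JAX illustration: the leading axis of xs is divisible by batch_size, and there is no equinox filtering. The helper bscan_sketch and the step function f are hypothetical; each scan step receives a batch of shape (batch_size, ...), so x has the same rank as xs.

```python
import jax
import jax.numpy as jnp


def bscan_sketch(f, init, xs, *, batch_size=1):
    """Hypothetical batched scan, not cryojax's implementation."""
    n = jax.tree_util.tree_leaves(xs)[0].shape[0]
    n_batches = n // batch_size
    # (n, ...) -> (n_batches, batch_size, ...): each scan step sees a whole batch
    batched = jax.tree_util.tree_map(
        lambda x: x.reshape(n_batches, batch_size, *x.shape[1:]), xs
    )
    carry, ys = jax.lax.scan(f, init, batched)
    # (n_batches, batch_size, ...) -> (n, ...)
    ys = jax.tree_util.tree_map(
        lambda y: y.reshape(n_batches * batch_size, *y.shape[2:]), ys
    )
    return carry, ys


def f(carry, x):
    # x has shape (batch_size,), the same rank as xs
    return carry + jnp.sum(x), 2.0 * x


xs = jnp.arange(6.0)
carry, ys = bscan_sketch(f, jnp.zeros(()), xs, batch_size=2)
```

The carry accumulates a running sum over all six elements, while ys is reassembled to the unbatched shape (6,).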
Debugging and runtime errors
cryojax.jax_util.maybe_error_if(x: ~T, pred_fn: Callable[[~T], Bool[ArrayLike, '...']], msg: str) -> PyTree
Applies equinox.error_if
depending on the value of the environment variable CRYOJAX_ENABLE_CHECKS.
- If CRYOJAX_ENABLE_CHECKS=true: this function is equivalent to equinox.error_if, except that the input pred is replaced with pred_fn, where pred = pred_fn(x). This way, pred is only evaluated if checks are enabled.
- If CRYOJAX_ENABLE_CHECKS=false: this function is the identity, i.e. lambda x: x.
By default, CRYOJAX_ENABLE_CHECKS=false because checks may cause slowdowns, particularly
on GPU.
Within cryojax, this function serves a purpose similar to JAX's
JAX_ENABLE_CHECKS
flag, and it is exposed as public API for downstream development.
Interoperability with lineax
cryojax.jax_util.make_linear_operator(fn: Callable[[~Args], Array], args: ~Args, where_vector: Callable[[~Args], Any], *, tags: object | Iterable[object] = ()) -> tuple[lineax.FunctionLinearOperator, ~Args]
Instantiate a lineax.FunctionLinearOperator
from a function that takes an arbitrary pytree as input.
This is useful for converting from the cryoJAX abstraction for image simulation
to the lineax abstraction of a matrix-vector product. Backprojection operators
are then easy to obtain with lineax, which calls jax.linear_transpose
under the hood.
Backprojection with lineax
```python
import cryojax.simulator as cxs
import cryojax.jax_util as jxu
import lineax as lx

# Instantiate a linear operator
volume = cxs.FourierVoxelGridVolume.from_real_voxel_grid(...)
image_model = cxs.make_image_model(volume, ...)
where_vector = lambda x: x.volume.fourier_voxel_grid
operator, vector = jxu.make_linear_operator(
    fn=lambda x: x.simulate(),
    args=image_model,
    where_vector=where_vector,
)
# Simulate an image
image = operator.mv(vector)
# Compute backprojection
adjoint = lx.conj(operator.T)
backprojection = where_vector(adjoint.mv(image))
```
Warning
This function assumes that fn is linear with respect to the input arguments
at where_vector. CryoJAX does not explicitly check whether this is the case,
so if fn is not linear, JAX will throw errors downstream.
Arguments:
- fn: A function with signature out = fn(args).
- args: Input arguments to fn.
- where_vector: A pointer to the leaves of args that form the input vector space of the operator (e.g. the volume).
- tags: See lineax.FunctionLinearOperator for documentation.
Returns:
A tuple whose first element is the lineax.FunctionLinearOperator and whose second element is
a pytree with the same structure as args, partitioned to include only the
arguments at where_vector.
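To show what jax.linear_transpose provides on its own, here is a toy sketch in plain JAX, assuming a hypothetical linear forward model (project) that sums a real (z, y, x) grid along z. Its transpose "smears" an image back along z, i.e. a backprojection. For real arrays the transpose is also the adjoint; for complex-valued operators, such as those acting on Fourier-space grids, the adjoint additionally needs complex conjugation, which is why the lineax example above applies lx.conj.

```python
import jax
import jax.numpy as jnp


def project(volume):
    # Hypothetical toy linear forward model: integrate a (z, y, x) grid along z
    return jnp.sum(volume, axis=0)


volume = jnp.arange(24.0).reshape(2, 3, 4)
image = project(volume)

# linear_transpose traces `project` at example inputs to build its transpose
backproject = jax.linear_transpose(project, volume)
(backprojection,) = backproject(image)  # one cotangent per primal input

# Dot-product (adjoint) test: <A v, u> == <v, A^T u>
u = jnp.ones_like(image)
(atu,) = backproject(u)
lhs = jnp.sum(project(volume) * u)
rhs = jnp.sum(volume * atu)
```

The dot-product test is a cheap way to sanity-check that a transpose behaves as expected before using it inside an iterative solver.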
Extra type hints using jaxtyping
cryojax.jax_util.NDArrayLike
A type hint for a JAX or NumPy array.
cryojax.jax_util.FloatLike
A type hint for a Python float or a
JAX/NumPy float scalar.
cryojax.jax_util.ComplexLike
A type hint for a Python complex or a
JAX/NumPy complex scalar.
cryojax.jax_util.InexactLike
A type hint for a Python float / complex or a
JAX/NumPy float / complex scalar.
cryojax.jax_util.IntLike
A type hint for a Python int or a
JAX/NumPy integer scalar.
cryojax.jax_util.BoolLike
A type hint for a Python bool or a
JAX/NumPy boolean scalar.