Extra JAX/Equinox
cryojax.jax_util supports downstream applications with helpers for common patterns that arise when using cryoJAX for cryo-EM data analysis.
Equinox extensions
Helpers for filtering
To make full use of JAX, it is highly recommended to learn about equinox. cryoJAX implements its models as equinox.Modules, which are pytrees, so they can be operated on like any other pytree in JAX (e.g. with jax.tree.map). Complementary to the equinox.Module interface, equinox introduces the idea of filtering. Filtering separates the leaves of a pytree into different groups, which is a core task when using JAX (e.g. separating traced and static arguments to jax.jit). eqx.partition splits a pytree into two pytrees of the same structure according to a filter, and eqx.combine merges them back together. This section describes utilities in cryojax for working with equinox.partition and equinox.combine.
cryojax.jax_util.make_filter_spec
make_filter_spec(pytree: PyTree, where: Callable[[PyTree], Any | Sequence[Any]], *, inverse: bool = False, is_leaf: Callable[[Any], bool] | None = None) -> PyTree[bool]
A lightweight wrapper around equinox for creating a "filter specification".
A filter specification, or filter_spec, is a pytree whose
leaves are either True or False. Filter specifications are commonly
passed to equinox filtering functions such as equinox.partition.
In cryojax, a common pattern is to specify precisely which leaves
a JAX transformation (e.g. jax.grad) should act on. This is done with a
function that points to individual leaves, referred to as a where function.
See the equinox documentation for an example.
Returns:
The filter specification: a pytree with the same structure as pytree, with
True at the leaves selected by the where function and False elsewhere
(or the reverse, if inverse=True).
Batched loops
cryojax provides a collection of looping utilities, not found in JAX, that are often important for cryo-EM workloads. Depending on developments in core JAX/Equinox and other factors, these functions may be removed in future releases of cryojax, so use them with caution.
cryojax.jax_util.filter_bmap
filter_bmap(f: Callable[[PyTree[Shaped[Array, '_ ...'], 'X']], PyTree[Shaped[Array, '_ ...'], 'Y']], xs: PyTree[Shaped[Array, '_ ...'], 'X'], *, batch_size: int = 1) -> PyTree[Shaped[Array, '_ ...'], 'Y']
Like jax.lax.map(..., batch_size=...), but f receives chunks x
that have the same rank as xs (i.e. they keep a leading batch axis)
rather than single elements. xs is filtered in the usual
equinox.filter_* way.
Arguments:
f: As jax.lax.map, with format f(x), except vmapped over the first axis of the arrays of x.
xs: As jax.lax.map.
batch_size: Compute a loop of vmaps over xs in chunks of batch_size.
Returns:
As jax.lax.map.
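The following minimal sketch of the batched-map idea uses only jax.lax.map and jax.vmap. batched_map_sketch is a hypothetical helper, not cryojax's implementation. It handles only pytrees of arrays, with no equinox filtering. Any leftover elements that do not fill a whole chunk are processed in one final, smaller call to f.

```python
import jax
import jax.numpy as jnp


def batched_map_sketch(f, xs, *, batch_size=1):
    """Hypothetical sketch of a batched map (arrays only, no filtering).

    `f` is called on chunks of `xs` that keep the rank of `xs`: each chunk has
    a leading axis of length `batch_size`, so `f` must be vectorized over it.
    """
    tree_map = jax.tree_util.tree_map
    n = jax.tree_util.tree_leaves(xs)[0].shape[0]
    n_batches, remainder = divmod(n, batch_size)
    n_full = n_batches * batch_size

    # Reshape the first n_full elements to (n_batches, batch_size, ...) and
    # loop over the chunks sequentially with jax.lax.map.
    chunks = tree_map(
        lambda x: x[:n_full].reshape(n_batches, batch_size, *x.shape[1:]), xs
    )
    ys = jax.lax.map(f, chunks)
    ys = tree_map(lambda y: y.reshape(n_full, *y.shape[2:]), ys)

    if remainder > 0:
        # Elements that do not fill a whole chunk go through f in one last call.
        ys_tail = f(tree_map(lambda x: x[n_full:], xs))
        ys = tree_map(lambda a, b: jnp.concatenate([a, b]), ys, ys_tail)
    return ys


# Usage: f is vmapped, so it maps a chunk of shape (b, 2) to an output of shape (b,).
f = jax.vmap(lambda x: jnp.sum(x**2))
xs = jnp.arange(10.0).reshape(5, 2)
ys = batched_map_sketch(f, xs, batch_size=2)  # two chunks of 2, then 1 leftover
```

Compared with a single jax.vmap over all of xs, this holds only one chunk's intermediates in memory at a time, at the cost of a sequential loop over chunks.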
cryojax.jax_util.filter_bscan
filter_bscan(f: Callable[[~Carry, ~X], tuple[~Carry, ~Y]], init: ~Carry, xs: ~X, length: int | None = None, unroll: int | bool = 1, *, batch_size: int = 1) -> tuple[~Carry, ~Y]
Adds a batch_size to jax.lax.scan, analogous to
jax.lax.map(..., batch_size=...). Unlike jax.lax.map,
f(carry, x) receives chunks x with the same rank as xs
(for example, f may be vmapped over x).
xs and carry are filtered in the usual equinox.filter_* way.
Arguments:
f: As jax.lax.scan, with format f(carry, x).
init: As jax.lax.scan.
xs: As jax.lax.scan.
length: As jax.lax.scan.
unroll: As jax.lax.scan.
batch_size: Compute a loop of vmaps over xs in chunks of batch_size.
Returns:
As jax.lax.scan.
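The batched scan can be sketched in a similar way. The sketch reshapes xs into chunks and passes them to jax.lax.scan, so f(carry, x) sees a whole chunk at each step. batched_scan_sketch is hypothetical and simplified. It ignores equinox filtering, length, and unroll, and it assumes the leading axis of xs is divisible by batch_size.

```python
import jax
import jax.numpy as jnp


def batched_scan_sketch(f, init, xs, *, batch_size=1):
    """Hypothetical sketch of a batched scan (arrays only, no filtering).

    Assumes the leading axis of `xs` is divisible by `batch_size`.
    """
    tree_map = jax.tree_util.tree_map
    n = jax.tree_util.tree_leaves(xs)[0].shape[0]
    if n % batch_size != 0:
        raise ValueError("this sketch requires len(xs) % batch_size == 0")
    n_batches = n // batch_size

    # Each scan step receives a chunk of shape (batch_size, ...).
    chunks = tree_map(lambda x: x.reshape(n_batches, batch_size, *x.shape[1:]), xs)
    carry, ys = jax.lax.scan(f, init, chunks)
    # Undo the chunking on the stacked per-step outputs.
    ys = tree_map(lambda y: y.reshape(n, *y.shape[2:]), ys)
    return carry, ys


# Usage: accumulate a running total while also returning per-element squares.
def step(carry, x_chunk):
    squares = x_chunk**2  # elementwise, so it works on the whole chunk
    return carry + jnp.sum(squares), squares


total, squares = batched_scan_sketch(step, jnp.array(0.0), jnp.arange(6.0), batch_size=3)
```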
Grid search interface
Reparameterizing pytrees
Extra type hints using jaxtyping
cryojax.jax_util.NDArrayLike
A type hint for a JAX or NumPy array.
cryojax.jax_util.FloatLike
A type hint for a Python float or a
JAX/NumPy float scalar.
cryojax.jax_util.ComplexLike
A type hint for a Python complex or a
JAX/NumPy complex scalar.
cryojax.jax_util.InexactLike
A type hint for a Python float or complex, or a
JAX/NumPy float or complex scalar.
cryojax.jax_util.IntLike
A type hint for a Python int or a
JAX/NumPy integer scalar.
cryojax.jax_util.BoolLike
A type hint for a Python bool or a
JAX/NumPy boolean scalar.