
Extra JAX/Equinox

cryojax.jax_util supports downstream applications with helpers for common JAX/Equinox patterns that arise when using cryoJAX for cryo-EM data analysis.

Equinox extensions

Helpers for filtering

To make use of the full power of JAX, it is highly recommended to learn about equinox. cryoJAX implements its models as pytrees using equinox.Module, so they can be operated on like any other pytree in JAX (e.g. with jax.tree.map). Complementary to the equinox.Module interface, equinox introduces the idea of filtering: separating the leaves of a pytree into different groups, which is a core task when using JAX (e.g. separating traced and static arguments to jax.jit). This grouping is achieved with the functions equinox.partition and equinox.combine. This section describes utilities in cryojax for working with them.

cryojax.jax_util.make_filter_spec

make_filter_spec(pytree: PyTree, where: Callable[[PyTree], Any | Sequence[Any]], *, inverse: bool = False, is_leaf: Callable[[Any], bool] | None = None) -> PyTree[bool]

A lightweight wrapper around equinox for creating a "filter specification".

A filter specification, or filter_spec, is a pytree whose leaves are either True or False. These are commonly used with equinox filtering.

In cryojax, a common pattern is to specify precisely which leaves of a pytree a JAX transformation should act on (for example, which leaves to differentiate with respect to). This is done with a function that points to individual leaves, referred to as a where function. See the equinox documentation for an example.

Returns:

The filter specification. This is a pytree with the same structure as pytree, with True at the leaves the where function points to and False elsewhere (or the opposite, if inverse=True).

Batched loops

cryojax provides a collection of useful functions not found in JAX that tend to be important for tasks in cryo-EM. Depending on developments with core JAX/Equinox and other factors, these functions could be removed in future releases of cryojax. Use with caution!

cryojax.jax_util.filter_bmap

filter_bmap(f: Callable[[PyTree[Shaped[Array, '_ ...'], 'X']], PyTree[Shaped[Array, '_ ...'], 'Y']], xs: PyTree[Shaped[Array, '_ ...'], 'X'], *, batch_size: int = 1) -> PyTree[Shaped[Array, '_ ...'], 'Y']

Like jax.lax.map(..., batch_size=...), except that f receives each batch x with the same rank as xs (that is, x keeps a leading batch axis). xs is filtered in the usual equinox.filter_* way.

Arguments:

  • f: As in jax.lax.map with format f(x), except that f acts on a whole batch, i.e. it is vectorized (e.g. vmapped) over the first axis of the arrays in x.
  • xs: As jax.lax.map.
  • batch_size: Compute a loop of vmaps over xs in chunks of batch_size.

Returns:

As jax.lax.map.
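A minimal sketch of the batching idea, assuming every leaf of xs is a JAX array (so no equinox filtering of non-array leaves is needed) and that f is already vectorized over the leading axis. The name bmap_sketch is hypothetical and this is not cryojax's implementation: the leading axis is reshaped into (n_batches, batch_size), jax.lax.map loops over the batches, and any leftover elements are passed to f as one smaller final batch.

```python
import jax
import jax.numpy as jnp


def bmap_sketch(f, xs, *, batch_size=1):
    # Hypothetical sketch: all leaves of `xs` are arrays sharing the same
    # leading-axis length, and `f` is vectorized over that leading axis.
    n = jax.tree.leaves(xs)[0].shape[0]
    n_batches, remainder = divmod(n, batch_size)
    n_main = n_batches * batch_size

    # Reshape the leading axis to (n_batches, batch_size) and loop over batches;
    # `f` receives one batch at a time, i.e. arrays with the same rank as `xs`.
    main = jax.tree.map(
        lambda x: x[:n_main].reshape(n_batches, batch_size, *x.shape[1:]), xs
    )
    ys = jax.lax.map(f, main)
    ys = jax.tree.map(lambda y: y.reshape(n_main, *y.shape[2:]), ys)

    if remainder > 0:
        # Leftover elements are processed as one smaller, final batch.
        ys_rest = f(jax.tree.map(lambda x: x[n_main:], xs))
        ys = jax.tree.map(lambda a, b: jnp.concatenate([a, b]), ys, ys_rest)
    return ys


# Usage: `f` already acts on a batch of 3-vectors, returning one norm per row.
xs = jnp.arange(21.0).reshape(7, 3)
norms = bmap_sketch(lambda x: jnp.linalg.norm(x, axis=-1), xs, batch_size=3)
```

Because f sees a full batch, memory use scales with batch_size while the Python-level loop is replaced by a single compiled jax.lax.map.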

cryojax.jax_util.filter_bscan

filter_bscan(f: Callable[[~Carry, ~X], tuple[~Carry, ~Y]], init: ~Carry, xs: ~X, length: int | None = None, unroll: int | bool = 1, *, batch_size: int = 1) -> tuple[~Carry, ~Y]

Adds a batch_size option to jax.lax.scan, analogous to jax.lax.map(..., batch_size=...). Additionally, unlike jax.lax.map, f(carry, x) receives x with the same rank as xs (for example, f might be vmapped over x). xs and carry are filtered in the usual equinox.filter_* way.

Arguments:

  • f: As jax.lax.scan with format f(carry, x).
  • init: As jax.lax.scan.
  • xs: As jax.lax.scan.
  • length: As jax.lax.scan.
  • unroll: As jax.lax.scan.
  • batch_size: Compute a loop of vmaps over xs in chunks of batch_size.

Returns:

As jax.lax.scan.
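The same reshaping idea extends to a scan. The sketch below is hypothetical (bscan_sketch is not a cryojax function), omits the length and unroll arguments, skips equinox filtering by assuming every leaf of xs is an array, and assumes f returns per-element outputs y with the same leading axis as its batch x. Each scan step handles one batch, and the carry is threaded through the batches in order, including a smaller final batch when the length is not divisible by batch_size.

```python
import jax
import jax.numpy as jnp


def bscan_sketch(f, init, xs, *, batch_size=1):
    # Hypothetical sketch: all leaves of `xs` share the same leading-axis length,
    # and `f(carry, x)` returns outputs `y` with the same leading axis as `x`.
    n = jax.tree.leaves(xs)[0].shape[0]
    n_batches, remainder = divmod(n, batch_size)
    n_main = n_batches * batch_size

    main = jax.tree.map(
        lambda x: x[:n_main].reshape(n_batches, batch_size, *x.shape[1:]), xs
    )
    # One scan step per batch; the carry flows from batch to batch.
    carry, ys = jax.lax.scan(f, init, main)
    ys = jax.tree.map(lambda y: y.reshape(n_main, *y.shape[2:]), ys)

    if remainder > 0:
        # The final, smaller batch continues from the last carry.
        carry, ys_rest = f(carry, jax.tree.map(lambda x: x[n_main:], xs))
        ys = jax.tree.map(lambda a, b: jnp.concatenate([a, b]), ys, ys_rest)
    return carry, ys


def step(carry, x):
    # Accumulate a running total and return a per-element output for the batch.
    return carry + jnp.sum(x), 2.0 * x


total, doubled = bscan_sketch(step, jnp.zeros(()), jnp.arange(10.0), batch_size=4)
```

Using a strongly typed init (jnp.zeros(())) keeps the carry dtype identical across scan steps.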

Grid search interface

Reparameterizing pytrees

Extra type hints using jaxtyping

cryojax.jax_util.NDArrayLike

A type hint for a JAX or numpy array.

cryojax.jax_util.FloatLike

A type hint for a python float or a JAX/numpy float scalar.

cryojax.jax_util.ComplexLike

A type hint for a python complex or a JAX/numpy complex scalar.

cryojax.jax_util.InexactLike

A type hint for a python float / complex, or a JAX/numpy float / complex scalar.

cryojax.jax_util.IntLike

A type hint for a python int or a JAX/numpy integer scalar.

cryojax.jax_util.BoolLike

A type hint for a python bool or a JAX/numpy boolean scalar.