
Useful JAX functions

cryojax provides a collection of utility functions, not found in JAX itself, that are often important for cryo-EM tasks. Depending on developments in core JAX/Equinox and other factors, these functions could be removed in future releases of cryojax. Use with caution!

cryojax.jax_util.filter_bmap(f: Callable[[PyTree[Shaped[Array, '_ ...'], 'X']], PyTree[Shaped[Array, '_ ...'], 'Y']], xs: PyTree[Shaped[Array, '_ ...'], 'X'], *, batch_size: int = 1) -> PyTree[Shaped[Array, '_ ...'], 'Y']

Like jax.lax.map(..., batch_size=...), except that f accepts x with the same rank as xs, i.e. f receives a whole batch at a time rather than a single element. xs is filtered in the usual equinox.filter_* way.
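In the equinox.filter_* convention, array leaves of a PyTree are the data that get mapped over, while all other leaves (Python scalars, strings, functions) are held fixed. The sketch below illustrates that split using only jax.tree_util. The helpers is_array, partition, and combine are hypothetical stand-ins written in the spirit of equinox.partition(tree, equinox.is_array) and equinox.combine; they are not cryojax's internals, and treating "array" as "instance of jax.Array" is an assumption.

```python
import jax
import jax.numpy as jnp


def is_array(leaf):
    # Assumption: JAX arrays are the "dynamic" leaves; everything else is static
    return isinstance(leaf, jax.Array)


def partition(tree):
    """Hypothetical helper: split a PyTree into (dynamic, static) halves.

    Each half has the same structure as `tree`, with None in the positions
    that belong to the other half.
    """
    dynamic = jax.tree_util.tree_map(lambda x: x if is_array(x) else None, tree)
    static = jax.tree_util.tree_map(lambda x: None if is_array(x) else x, tree)
    return dynamic, static


def combine(dynamic, static):
    """Hypothetical helper: merge the two halves back into a single PyTree."""
    return jax.tree_util.tree_map(
        lambda d, s: s if d is None else d,
        dynamic,
        static,
        is_leaf=lambda x: x is None,  # keep None placeholders as leaves
    )


xs = {"weights": jnp.arange(6.0).reshape(3, 2), "label": "particle"}
dynamic, static = partition(xs)
# Only `dynamic` would be sliced along its leading axis; `static` is held fixed
restored = combine(dynamic, static)
```

Under this convention, a batched map only ever slices the dynamic half, so non-array leaves such as the "label" string pass through unchanged.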

Arguments:

  • f: As jax.lax.map, with format f(x), except that f should be vmapped over the first axis of the arrays in x (it receives a whole batch at once).
  • xs: As jax.lax.map.
  • batch_size: Compute a loop of vmaps over xs in chunks of batch_size.

Returns:

As jax.lax.map.
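The chunking behaviour described above can be sketched for the simplest case of a single array xs, with no PyTree filtering. The function chunked_map is hypothetical and is not how cryojax implements filter_bmap. It assumes f maps a chunk of shape (k, ...) to outputs with leading axis k, and it handles a final partial chunk when the length of xs is not a multiple of batch_size.

```python
import jax
import jax.numpy as jnp


def chunked_map(f, xs, *, batch_size=1):
    """Hypothetical sketch of batched mapping for a single array `xs`.

    Unlike jax.lax.map, `f` receives a whole chunk with the same rank as `xs`
    (leading axis of length <= batch_size), so it is typically vmapped.
    """
    n = xs.shape[0]
    n_batches, remainder = divmod(n, batch_size)
    n_full = n_batches * batch_size
    # Reshape the full chunks to (n_batches, batch_size, ...) and loop over them sequentially
    chunks = xs[:n_full].reshape(n_batches, batch_size, *xs.shape[1:])
    ys = jax.lax.map(f, chunks)
    # Merge the chunk axis back into a single leading axis
    ys = ys.reshape(n_full, *ys.shape[2:])
    if remainder > 0:
        # Leftover elements form one smaller final chunk
        ys = jnp.concatenate([ys, f(xs[n_full:])], axis=0)
    return ys


# `f` processes a whole chunk: here, a vmapped per-row sum of squares
f = jax.vmap(lambda row: jnp.sum(row**2))
xs = jnp.arange(14.0).reshape(7, 2)
# Two full chunks of 3 rows, plus a final chunk of 1 row
ys = chunked_map(f, xs, batch_size=3)
```

Choosing batch_size trades memory for speed: work within a chunk is vectorized, while chunks are processed one after another.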

cryojax.jax_util.filter_bscan(f: Callable[[~Carry, ~X], tuple[~Carry, ~Y]], init: ~Carry, xs: ~X, length: int | None = None, unroll: int | bool = 1, *, batch_size: int = 1) -> tuple[~Carry, ~Y]

Like jax.lax.scan, but with a batch_size argument analogous to jax.lax.map(..., batch_size=...). Additionally, unlike jax.lax.map, f(carry, x) accepts x with the same rank as xs (e.g. because f is vmapped over x). xs and carry are filtered in the usual equinox.filter_* way.

Arguments:

  • f: As jax.lax.scan with format f(carry, x).
  • init: As jax.lax.scan.
  • xs: As jax.lax.scan.
  • length: As jax.lax.scan.
  • unroll: As jax.lax.scan.
  • batch_size: Compute a loop of vmaps over xs in chunks of batch_size.

Returns:

As jax.lax.scan.
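A scan over chunks can be sketched in the same way for a single array xs. The name chunked_scan is hypothetical, and the sketch ignores length, unroll, and PyTree filtering. It assumes f(carry, x_chunk) returns a new carry and outputs whose leading axis matches x_chunk. Because the carry is only threaded from chunk to chunk, f itself must combine the elements inside a chunk, as the running-sum example does with jnp.cumsum.

```python
import jax
import jax.numpy as jnp


def chunked_scan(f, init, xs, *, batch_size=1):
    """Hypothetical sketch of a scan whose step function sees a whole chunk of `xs`."""
    n = xs.shape[0]
    n_batches, remainder = divmod(n, batch_size)
    n_full = n_batches * batch_size
    # Scan sequentially over chunks of shape (batch_size, ...)
    chunks = xs[:n_full].reshape(n_batches, batch_size, *xs.shape[1:])
    carry, ys = jax.lax.scan(f, init, chunks)
    ys = ys.reshape(n_full, *ys.shape[2:])
    if remainder > 0:
        # Continue from the last carry with the smaller final chunk
        carry, ys_rem = f(carry, xs[n_full:])
        ys = jnp.concatenate([ys, ys_rem], axis=0)
    return carry, ys


def running_sum(carry, x_chunk):
    # x_chunk has the same rank as xs; cumsum handles the elements within the chunk
    partial = carry + jnp.cumsum(x_chunk)
    # New carry is the total so far; outputs are the running sums for this chunk
    return partial[-1], partial


xs = jnp.arange(1.0, 11.0)  # 10 elements: chunks of 4, 4, then 2
total, cumsums = chunked_scan(running_sum, jnp.zeros(()), xs, batch_size=4)
```

The result matches an element-by-element scan (here, jnp.cumsum(xs)), but the step function runs only three times instead of ten.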