Useful JAX functions
cryojax provides a collection of useful functions not found in JAX that tend to be important for tasks in cryo-EM. Depending on developments with core JAX/Equinox and other factors, these functions could be removed in future releases of cryojax. Use with caution!
cryojax.jax_util.filter_bmap(f: Callable[[PyTree[Shaped[Array, '_ ...'], 'X']], PyTree[Shaped[Array, '_ ...'], 'Y']], xs: PyTree[Shaped[Array, '_ ...'], 'X'], *, batch_size: int = 1) -> PyTree[Shaped[Array, '_ ...'], 'Y']
Like jax.lax.map(..., batch_size=...), except that f accepts x
with the same rank as xs (a chunk along the leading axis rather than
a single slice). xs is filtered in the usual equinox.filter_* way.
Arguments:
f: As jax.lax.map with format f(x), except vmapped over the first axis of the arrays of x.
xs: As jax.lax.map.
batch_size: Compute a loop of vmaps over xs in chunks of batch_size.
Returns:
As jax.lax.map.
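The batching behaviour described above can be illustrated with a minimal plain-JAX sketch. This is not cryojax's implementation. The name bmap_sketch is hypothetical, and the sketch assumes every leaf of xs is an array sharing a common leading axis (no Equinox filtering). It also assumes that, when the leading axis is not divisible by batch_size, the leftover elements are handled by one extra call to f. The point it shows is that jax.lax.map loops over chunks, and f sees each chunk whole, so its input keeps the rank of xs.

```python
import jax
import jax.numpy as jnp
from jax import lax


def bmap_sketch(f, xs, batch_size=1):
    """Hypothetical sketch of a batched map in plain JAX.

    Assumes every leaf of `xs` is an array with a common leading axis and
    that `f` maps a chunk of shape (batch_size, ...) to outputs that keep
    that leading batch axis (e.g. `f` is vmapped). Unlike cryojax's
    filter_bmap, no Equinox-style filtering of non-array leaves is done.
    """
    n = jax.tree_util.tree_leaves(xs)[0].shape[0]
    n_full = n // batch_size  # number of complete chunks
    cut = n_full * batch_size  # elements covered by complete chunks

    # Reshape each leaf to (n_full, batch_size, ...) so lax.map loops
    # over chunks while f sees each chunk with the same rank as xs.
    chunked = jax.tree_util.tree_map(
        lambda x: x[:cut].reshape(n_full, batch_size, *x.shape[1:]), xs
    )
    ys = lax.map(f, chunked)
    # Merge the (n_full, batch_size) axes back into one leading axis.
    ys = jax.tree_util.tree_map(lambda y: y.reshape(cut, *y.shape[2:]), ys)

    if cut < n:
        # Assumption: leftover elements are handled by one extra call to f.
        rem = jax.tree_util.tree_map(lambda x: x[cut:], xs)
        ys = jax.tree_util.tree_map(
            lambda a, b: jnp.concatenate([a, b], axis=0), ys, f(rem)
        )
    return ys


# Usage: square 10 elements, three at a time (the last chunk has one).
xs = jnp.arange(10.0)
ys = bmap_sketch(jax.vmap(jnp.square), xs, batch_size=3)
```

Because f is already vmapped, a larger batch_size means fewer sequential loop steps, at the cost of more memory per step.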
cryojax.jax_util.filter_bscan(f: Callable[[~Carry, ~X], tuple[~Carry, ~Y]], init: ~Carry, xs: ~X, length: int | None = None, unroll: int | bool = 1, *, batch_size: int = 1) -> tuple[~Carry, ~Y]
Like jax.lax.map(..., batch_size=...), but adds a batch_size
option to jax.lax.scan. Additionally, unlike jax.lax.map, f(carry, x)
accepts x with the same rank as xs (for example, because f is itself
vmapped over x).
xs and carry are filtered in the usual equinox.filter_* way.
Arguments:
f: As jax.lax.scan with format f(carry, x).
init: As jax.lax.scan.
xs: As jax.lax.scan.
length: As jax.lax.scan.
unroll: As jax.lax.scan.
batch_size: Compute a loop of vmaps over xs in chunks of batch_size.
Returns:
As jax.lax.scan.
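A batched scan can be sketched the same way: jax.lax.scan steps over chunks of xs, and f(carry, x) receives each chunk with the rank of xs intact. This is a hypothetical illustration, not cryojax's code. The name bscan_sketch and the running_sum example are made up for the sketch. It omits Equinox filtering and the length and unroll options, and it assumes that leftover elements get one extra call to f.

```python
import jax
import jax.numpy as jnp
from jax import lax


def bscan_sketch(f, init, xs, batch_size=1):
    """Hypothetical sketch of a scan over chunks of `xs` in plain JAX.

    `f(carry, x)` receives `x` of shape (batch_size, ...), the same rank
    as `xs`, and returns `(new_carry, y)` with `y` batched along its first
    axis. Assumes every leaf of `xs` is an array with a common leading
    axis; Equinox-style filtering and the `length`/`unroll` options are
    omitted.
    """
    n = jax.tree_util.tree_leaves(xs)[0].shape[0]
    n_full = n // batch_size
    cut = n_full * batch_size

    # Scan over complete chunks of shape (batch_size, ...).
    chunked = jax.tree_util.tree_map(
        lambda x: x[:cut].reshape(n_full, batch_size, *x.shape[1:]), xs
    )
    carry, ys = lax.scan(f, init, chunked)
    # Merge the (n_full, batch_size) axes back into one leading axis.
    ys = jax.tree_util.tree_map(lambda y: y.reshape(cut, *y.shape[2:]), ys)

    if cut < n:
        # Assumption: leftover elements get one extra call to f.
        rem = jax.tree_util.tree_map(lambda x: x[cut:], xs)
        carry, ys_rem = f(carry, rem)
        ys = jax.tree_util.tree_map(
            lambda a, b: jnp.concatenate([a, b], axis=0), ys, ys_rem
        )
    return carry, ys


def running_sum(carry, x):
    # x is a whole chunk, so the within-chunk cumsum is vectorized.
    cs = carry + jnp.cumsum(x)
    return cs[-1], cs


# Usage: cumulative sum of 1..10 in chunks of 4 (the last chunk has two).
xs = jnp.arange(1.0, 11.0)
carry, ys = bscan_sketch(running_sum, jnp.zeros(()), xs, batch_size=4)
```

The carry is threaded sequentially from chunk to chunk, while the work inside each chunk can be vectorized. batch_size therefore trades the number of sequential steps against per-step memory.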