
Modeling cryo-EM volumes¤

There are many different volume representations of biological structures for cryo-EM, including atomic models, voxel maps, and neural network representations. Further, there are many ways to generate these volumes, such as from protein generative modeling and molecular dynamics. The optimal implementation to use depends on the user's needs. Therefore, cryoJAX supports a variety of these representations as well as a modeling interface for usage downstream. This page discusses how to use this interface and documents the volumes included in the library.

Core base classes¤

cryojax.simulator.AbstractVolumeParametrization ¤

Abstract interface for a parametrization of a volume. Specifically, the cryo-EM image formation process typically starts with a scattering potential. "Volumes" and "scattering potentials" in cryoJAX are synonymous.

Info

In cryoJAX, potentials should be built in units of inverse length squared, \([L]^{-2}\). This rescaled potential is defined to be

\[U(\mathbf{r}) = \frac{m_0 e}{2 \pi \hbar^2} V(\mathbf{r}),\]

where \(V\) is the electrostatic potential energy, \(\mathbf{r}\) is a positional coordinate, \(m_0\) is the electron rest mass, and \(e\) is the electron charge.

For a single atom, this rescaled potential has the advantage that under usual scattering approximations (i.e. the first Born approximation), the Fourier transform of this quantity is closely related to tabulated electron scattering factors. In particular, for a single atom with scattering factor \(f^{(e)}(\mathbf{q})\) and scattering vector \(\mathbf{q}\), its rescaled potential is equal to

\[U(\mathbf{r}) = \mathcal{F}^{-1}[f^{(e)}(\boldsymbol{\xi} / 2)](\mathbf{r}),\]

where \(\boldsymbol{\xi} = 2 \mathbf{q}\) is the wave vector coordinate and \(\mathcal{F}^{-1}\) is the inverse of the Fourier transform operator \(\mathcal{F}\), defined in the convention

\[\mathcal{F}[f](\boldsymbol{\xi}) = \int d^3\mathbf{r} \ \exp(2\pi i \boldsymbol{\xi}\cdot\mathbf{r}) f(\mathbf{r}).\]

The rescaled potential \(U\) gives the following time-independent Schrödinger equation for the scattering problem,

\[(\nabla^2 + k^2) \psi(\mathbf{r}) = - 4 \pi U(\mathbf{r}) \psi(\mathbf{r}),\]

where \(k\) is the incident wavenumber of the electron beam.

References:

  • For the definition of the rescaled potential, see Chapter 69, Page 2003, Equation 69.6 from Hawkes, Peter W., and Erwin Kasper. Principles of Electron Optics, Volume 4: Advanced Wave Optics. Academic Press, 2022.
  • To work out the correspondence between the rescaled potential and the electron scattering factors, see the supplementary information from Vulović, Miloš, et al. "Image formation modeling in cryo-electron microscopy." Journal of structural biology 183.1 (2013): 19-32.
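The Fourier convention above can be checked numerically on a single Gaussian. The following is a minimal 1D sketch using only the Python standard library; the values of `sigma`, `center`, and `xi` are arbitrary choices for illustration. With the \(\exp(+2\pi i \boldsymbol{\xi}\cdot\mathbf{r})\) kernel, a normalized Gaussian of variance \(\sigma^2\) centered at \(r'\) should transform to \(\exp(-2\pi^2\sigma^2\xi^2)\exp(2\pi i \xi r')\), which reduces to \(\exp(-b q^2)\) times a phase when \(\sigma^2 = b / 8\pi^2\) and \(q = \xi / 2\).

```python
import cmath
import math

sigma = 0.7   # standard deviation of the Gaussian, in angstroms (arbitrary)
center = 0.4  # position r' of the Gaussian, in angstroms (arbitrary)
xi = 0.3      # wave vector coordinate, in inverse angstroms (arbitrary)


def gaussian(r):
    """Normalized 1D Gaussian with variance sigma**2 centered at `center`."""
    return math.exp(-((r - center) ** 2) / (2 * sigma**2)) / math.sqrt(
        2 * math.pi * sigma**2
    )


# Riemann sum for F[f](xi) = integral of exp(+2 pi i xi r) f(r) dr
half_width, n_samples = 10.0, 4001
dr = 2 * half_width / (n_samples - 1)
numeric = sum(
    cmath.exp(2j * math.pi * xi * (-half_width + i * dr))
    * gaussian(-half_width + i * dr)
    * dr
    for i in range(n_samples)
)

# Closed form: a Gaussian envelope times a phase set by the center position
analytic = math.exp(-2 * math.pi**2 * sigma**2 * xi**2) * cmath.exp(
    2j * math.pi * xi * center
)

print(abs(numeric - analytic))  # difference between the two evaluations
```

The sign of the phase factor is what distinguishes this convention from the more common \(\exp(-2\pi i \boldsymbol{\xi}\cdot\mathbf{r})\) kernel, so it is worth keeping in mind when comparing against other software.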
get_representation(rng_key: PRNGKeyArray | None = None) -> cryojax.simulator.AbstractVolumeRepresentation ¤

Core interface for computing the representation of the volume.

Arguments:

  • rng_key: An optional RNG key for including noise or other stochastic elements in the volume simulation.
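To illustrate the pattern this interface describes, here is a toy sketch in plain Python. It does not use cryoJAX: the classes `ToyParametrization`, `ToyPointCloud`, and `ToyDiscreteEnsemble` are hypothetical stand-ins, and an integer seed plays the role of a JAX PRNG key. The idea is that a parametrization may hold more than one structure, and `get_representation` resolves it to a single concrete representation, optionally using randomness.

```python
import random
from abc import ABC, abstractmethod
from dataclasses import dataclass


class ToyParametrization(ABC):
    """Hypothetical stand-in for the role of a volume parametrization."""

    @abstractmethod
    def get_representation(self, rng_key=None):
        ...


@dataclass(frozen=True)
class ToyPointCloud(ToyParametrization):
    """Hypothetical concrete representation: a tuple of 3D positions."""

    positions: tuple

    def get_representation(self, rng_key=None):
        # A representation is already concrete, so this is the identity.
        return self


@dataclass(frozen=True)
class ToyDiscreteEnsemble(ToyParametrization):
    """Hypothetical parametrization holding several conformations."""

    conformations: tuple

    def get_representation(self, rng_key=None):
        # Without a key, fall back to the first conformation. With a key,
        # draw one conformation at random (the key is an int seed here).
        if rng_key is None:
            return self.conformations[0]
        index = random.Random(rng_key).randrange(len(self.conformations))
        return self.conformations[index]


open_state = ToyPointCloud(((0.0, 0.0, 0.0), (1.5, 0.0, 0.0)))
closed_state = ToyPointCloud(((0.0, 0.0, 0.0), (1.0, 0.0, 0.0)))
ensemble = ToyDiscreteEnsemble((open_state, closed_state))

default_volume = ensemble.get_representation()
sampled_volume = ensemble.get_representation(rng_key=0)
```

Downstream code only needs to call `get_representation` and can stay agnostic about whether it was handed a single structure or something richer.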

cryojax.simulator.AbstractVolumeRepresentation(cryojax.simulator.AbstractVolumeParametrization) ¤

Abstract interface for the representation of a volume, such as atomic coordinates, voxels, or a neural network.

Volume representations contain coordinate information and may be passed to AbstractVolumeIntegrator classes for imaging.

rotate_to_pose(pose: AbstractPose, inverse: bool = False) -> typing.Self ¤

Rotate the coordinate system of the volume.

Volume representations¤

Atom-based volumes¤

cryojax.simulator.AbstractAtomVolume(cryojax.simulator.AbstractVolumeRepresentation) ¤

Abstract interface for a volume represented as a point-cloud.

translate_to_pose(pose: cryojax.simulator.AbstractPose) -> typing.Self ¤

cryojax.simulator.GaussianMixtureVolume(cryojax.simulator.AbstractAtomVolume) ¤

A representation of a volume as a mixture of gaussians, with multiple gaussians used per position.

The convention of allowing multiple gaussians per position follows "Robust Parameterization of Elastic and Absorptive Electron Atomic Scattering Factors" by Peng et al. (1996). The \(a\) and \(b\) parameters in this work correspond to amplitudes \(= a\) and variances \(= b / 8\pi^2\).
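The correspondence between \((a, b)\) and (amplitudes, variances) can be written out in a few lines. This is a minimal sketch with the standard library only; the numbers in `a` and `b` are placeholders chosen for the example and are not values from the Peng et al. tables. The check at the end assumes the relation \(\boldsymbol{\xi} = 2\mathbf{q}\) stated earlier on this page.

```python
import math

# Illustrative five-term parameters for one atom type. These are placeholders
# for the example, NOT tabulated values from Peng et al. (1996).
a = [0.1, 0.3, 0.7, 1.0, 0.4]    # angstroms
b = [0.2, 1.5, 6.0, 18.0, 50.0]  # angstroms squared

amplitudes = list(a)
variances = [b_i / (8 * math.pi**2) for b_i in b]


def scattering_factor(q):
    """Sum of gaussians in the scattering vector magnitude q."""
    return sum(a_i * math.exp(-b_i * q**2) for a_i, b_i in zip(a, b))


def fourier_potential(xi):
    """The same quantity written with variances and wave vector magnitude xi."""
    return sum(
        amp * math.exp(-2 * math.pi**2 * var * xi**2)
        for amp, var in zip(amplitudes, variances)
    )


q = 0.25  # inverse angstroms
print(scattering_factor(q), fourier_potential(2 * q))  # the two should agree
```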

Info

Use the following to load a GaussianMixtureVolume from these tabulated electron scattering factors.

from cryojax.constants import PengScatteringFactorParameters
from cryojax.io import read_atoms_from_pdb
from cryojax.simulator import GaussianMixtureVolume

# Load positions of atoms and one-hot encoded atom names
atom_positions, atom_types = read_atoms_from_pdb(...)
parameters = PengScatteringFactorParameters(atom_types)
potential = GaussianMixtureVolume.from_tabulated_parameters(
    atom_positions, parameters
)
__init__(positions: Float[Array, 'n_positions 3'] | Float[ndarray, 'n_positions 3'], amplitudes: float | Float[Array, ''] | Float[ndarray, ''] | Float[Array, 'n_positions'] | Float[ndarray, 'n_positions'] | Float[Array, 'n_positions n_gaussians'] | Float[ndarray, 'n_positions n_gaussians'], variances: float | Float[Array, ''] | Float[ndarray, ''] | Float[Array, 'n_positions'] | Float[ndarray, 'n_positions'] | Float[Array, 'n_positions n_gaussians'] | Float[ndarray, 'n_positions n_gaussians']) ¤

Arguments:

  • positions: The coordinates of the gaussians in units of angstroms.
  • amplitudes: The amplitude for each gaussian. To simulate in physical units of a scattering potential, this should have units of angstroms.
  • variances: The variance for each gaussian. This has units of angstroms squared.
from_tabulated_parameters(atom_positions: Float[Array, 'n_atoms 3'] | Float[ndarray, 'n_atoms 3'], parameters: cryojax.constants.PengScatteringFactorParameters, extra_b_factors: float | Float[Array, ''] | Float[ndarray, ''] | Float[Array, 'n_atoms'] | Float[ndarray, 'n_atoms'] | None = None) -> typing.Self classmethod ¤

Initialize a GaussianMixtureVolume from tabulated electron scattering factor parameters (Peng et al. 1996). This treats the scattering potential as a mixture of five gaussians per atom.

References:

  • Peng, L-M. "Electron atomic scattering factors and scattering potentials of crystals." Micron 30.6 (1999): 625-648.
  • Peng, L-M., et al. "Robust parameterization of elastic and absorptive electron atomic scattering factors." Acta Crystallographica Section A: Foundations of Crystallography 52.2 (1996): 257-276.

Arguments:

  • atom_positions: The coordinates of the atoms in units of angstroms.
  • parameters: A pytree for the scattering factor parameters from Peng et al. (1996).
  • extra_b_factors: Additional per-atom B-factors that are added to the values in parameters.b.
get_representation(rng_key: PRNGKeyArray | None = None) -> typing.Self ¤

Since this class is itself an AbstractVolumeRepresentation, this function maps to the identity.

Arguments:

  • rng_key: Not used in this implementation.
rotate_to_pose(pose: cryojax.simulator.AbstractPose, inverse: bool = False) -> typing.Self ¤

Return a new potential with rotated positions.

translate_to_pose(pose: cryojax.simulator.AbstractPose) -> typing.Self ¤

Return a new potential with translated positions.
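For a point cloud, rotating and translating amount to simple array operations on the positions. The sketch below uses NumPy and assumes, for illustration only, that a pose can be reduced to a 3x3 rotation matrix and a 3-vector offset; the helper names are hypothetical and the conventions of AbstractPose (for example, which axes an offset acts on) are not reproduced here.

```python
import numpy as np


def rotation_about_z(angle_in_degrees):
    """Rotation matrix for a counterclockwise rotation about the z-axis."""
    theta = np.deg2rad(angle_in_degrees)
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def rotate_positions(positions, rotation_matrix, inverse=False):
    """Return new positions r' = R r (or R^T r if inverse) for each row r."""
    matrix = rotation_matrix.T if inverse else rotation_matrix
    return positions @ matrix.T


def translate_positions(positions, offset):
    """Return new positions shifted by a constant offset."""
    return positions + offset


positions = np.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
rotation = rotation_about_z(90.0)

rotated = rotate_positions(positions, rotation)
# (1, 0, 0) maps to (0, 1, 0) and (0, 2, 0) maps to (-2, 0, 0)
restored = rotate_positions(rotated, rotation, inverse=True)
shifted = translate_positions(rotated, np.array([0.5, -0.5, 0.0]))
```

Both helpers return new arrays rather than modifying their input, mirroring the "return a new potential" wording of the methods above.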


cryojax.simulator.IndependentAtomVolume(cryojax.simulator.AbstractAtomVolume) ¤

IndependentAtomVolume(position_pytree: PyTree[Float[Array, '_ 3'] | Float[ndarray, '_ 3'], 'T'], scattering_factor_pytree: PyTree[cryojax.ndimage.operators.AbstractFourierOperator, 'T'])

__init__(position_pytree: PyTree[Float[Array, '_ 3'] | Float[ndarray, '_ 3'], 'T'], scattering_factor_pytree: PyTree[cryojax.ndimage.operators.AbstractFourierOperator, 'T']) ¤
from_tabulated_parameters(positions_by_element: tuple[Float[Array, '_ 3'] | Float[ndarray, '_ 3'], ...], parameters: cryojax.constants.PengScatteringFactorParameters, *, b_factor_by_element: tuple[float | Float[Array, ''] | Float[ndarray, ''], ...] | None = None) -> typing.Self classmethod ¤
get_representation(rng_key: PRNGKeyArray | None = None) -> typing.Self ¤

Since this class is itself an AbstractVolumeRepresentation, this function maps to the identity.

Arguments:

  • rng_key: Not used in this implementation.
rotate_to_pose(pose: cryojax.simulator.AbstractPose, inverse: bool = False) -> typing.Self ¤

Return a new potential with rotated positions.

translate_to_pose(pose: cryojax.simulator.AbstractPose) -> typing.Self ¤

Return a new potential with translated positions.

Voxel-based volumes¤

Fourier-space¤

Fourier-space conventions

  • The fourier_voxel_grid and frequency_slice_in_pixels arguments to FourierVoxelGridVolume.__init__ should be loaded with the zero frequency component in the center of the box. This is returned by the
  • The parameters in an AbstractPose represent a rotation in real-space. This means that when calling FourierVoxelGridVolume.rotate_to_pose, frequencies are rotated by the inverse rotation as stored in the pose.
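The second convention can be made concrete with a small numerical check. The sketch below is a 2D NumPy example (2D only to keep it cheap) using an anisotropic Gaussian as a stand-in volume; all names and parameter values are illustrative. If a volume is rotated by \(R\) in real space, \(g(\mathbf{r}) = f(R^{-1}\mathbf{r})\), then its Fourier transform at \(\boldsymbol{\xi}\) is the original transform evaluated at \(R^{-1}\boldsymbol{\xi}\).

```python
import numpy as np

covariance = np.diag([1.0, 0.25])  # anisotropic, so rotation is visible


def volume(points):
    """f(r) = exp(-r^T C^{-1} r / 2) for points of shape (..., 2)."""
    precision = np.linalg.inv(covariance)
    return np.exp(-0.5 * np.einsum("...i,ij,...j->...", points, precision, points))


def volume_ft(frequency):
    """Analytic Fourier transform of `volume` at a single frequency."""
    norm = 2 * np.pi * np.sqrt(np.linalg.det(covariance))
    return norm * np.exp(-2 * np.pi**2 * frequency @ covariance @ frequency)


theta = np.deg2rad(30.0)
rotation = np.array(
    [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
)

# Sample the rotated volume g(r) = f(R^{-1} r) on a real-space grid.
axis = np.linspace(-8.0, 8.0, 321)
spacing = axis[1] - axis[0]
grid = np.stack(np.meshgrid(axis, axis, indexing="ij"), axis=-1)
rotated_samples = volume(grid @ rotation)  # row-vector form of R^T r

# Direct Fourier sum of the rotated volume at one frequency.
xi = np.array([0.2, -0.1])
numeric = np.sum(rotated_samples * np.exp(2j * np.pi * (grid @ xi))) * spacing**2

# Original transform evaluated at the inverse-rotated frequency.
analytic = volume_ft(rotation.T @ xi)
```

Evaluating `volume_ft(rotation @ xi)` instead gives a visibly different number, which is why the direction of the rotation matters when working with frequency coordinates.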

cryojax.simulator.FourierVoxelGridVolume(cryojax.simulator.AbstractVolumeRepresentation) ¤

A 3D voxel grid in fourier-space.

shape property ¤

The shape of the fourier_voxel_grid.

__init__(fourier_voxel_grid: Complex[Array, 'dim dim dim'] | Complex[ndarray, 'dim dim dim'], frequency_slice_in_pixels: Float[Array, '1 dim dim 3'] | Float[ndarray, '1 dim dim 3']) ¤

Arguments:

  • fourier_voxel_grid: The cubic voxel grid in fourier space.
  • frequency_slice_in_pixels: The frequency slice coordinate system.
from_real_voxel_grid(real_voxel_grid: Float[Array, 'dim dim dim'] | Float[ndarray, 'dim dim dim'], *, pad_scale: float = 1.0, pad_mode: str = 'constant', filter: cryojax.ndimage.transforms.AbstractFilter | None = None) -> typing.Self classmethod ¤

Load from a real-valued 3D voxel grid.

Arguments:

  • real_voxel_grid: A voxel grid in real space.
  • pad_scale: Scale factor at which to pad real_voxel_grid before fourier transform. Must be a value greater than or equal to 1.0.
  • pad_mode: Padding method. See jax.numpy.pad for documentation.
  • filter: A filter to apply to the result of the fourier transform of real_voxel_grid, i.e. fftn(real_voxel_grid). Note that the zero frequency component is assumed to be in the corner.
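As a rough picture of what padding followed by a Fourier transform looks like, here is a NumPy sketch. It is not the library's implementation: `pad_and_transform` and `filter_array` are hypothetical names, splitting the padding evenly on both sides is an assumption, and details such as how the real-space origin is handled are ignored.

```python
import numpy as np


def pad_and_transform(
    real_voxel_grid, pad_scale=1.0, pad_mode="constant", filter_array=None
):
    """Pad a cubic real-space grid, transform it, and center zero frequency."""
    dim = real_voxel_grid.shape[0]
    padded_dim = int(pad_scale * dim)
    n_before = (padded_dim - dim) // 2
    n_after = padded_dim - dim - n_before
    padded = np.pad(real_voxel_grid, [(n_before, n_after)] * 3, mode=pad_mode)
    # fftn places the zero frequency component in the corner
    fourier = np.fft.fftn(padded)
    if filter_array is not None:
        # Any filter is applied while zero frequency is still in the corner
        fourier = filter_array * fourier
    # Move the zero frequency component to the center of the box
    return np.fft.fftshift(fourier)


real_voxel_grid = np.random.default_rng(0).normal(size=(8, 8, 8))
fourier_voxel_grid = pad_and_transform(real_voxel_grid, pad_scale=2.0)
print(fourier_voxel_grid.shape)
```

Padding in real space samples the Fourier transform more finely, which generally helps when the grid is later interpolated at rotated frequencies.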
get_representation(rng_key: PRNGKeyArray | None = None) -> typing.Self ¤

Since this class is itself an AbstractVolumeRepresentation, this function maps to the identity.

Arguments:

  • rng_key: Not used in this implementation.
rotate_to_pose(pose: cryojax.simulator.AbstractPose, inverse: bool = False) -> typing.Self ¤

Return a new volume with a rotated frequency_slice_in_pixels.


cryojax.simulator.FourierVoxelSplineVolume(cryojax.simulator.AbstractVolumeRepresentation) ¤

A 3D voxel grid in fourier-space, represented by spline coefficients.

shape property ¤

The shape of the original fourier_voxel_grid from which coefficients were computed.

__init__(spline_coefficients: Complex[Array, 'coeff_dim coeff_dim coeff_dim'] | Complex[ndarray, 'coeff_dim coeff_dim coeff_dim'], frequency_slice_in_pixels: Float[Array, '1 dim dim 3'] | Float[ndarray, '1 dim dim 3']) ¤

Arguments:

  • spline_coefficients: The spline coefficients computed from the cubic voxel grid in fourier space. See cryojax.ndimage.compute_spline_coefficients.
  • frequency_slice_in_pixels: Frequency slice coordinate system. See cryojax.coordinates.make_frequency_slice.
from_real_voxel_grid(real_voxel_grid: Float[Array, 'dim dim dim'] | Float[ndarray, 'dim dim dim'], *, pad_scale: float = 1.0, pad_mode: str = 'constant', filter: cryojax.ndimage.transforms.AbstractFilter | None = None) -> typing.Self classmethod ¤

Load from a real-valued 3D voxel grid.

Arguments:

  • real_voxel_grid: A voxel grid in real space.
  • pad_scale: Scale factor at which to pad real_voxel_grid before fourier transform. Must be a value greater than or equal to 1.0.
  • pad_mode: Padding method. See jax.numpy.pad for documentation.
  • filter: A filter to apply to the result of the fourier transform of real_voxel_grid, i.e. fftn(real_voxel_grid). Note that the zero frequency component is assumed to be in the corner.
get_representation(rng_key: PRNGKeyArray | None = None) -> typing.Self ¤

Since this class is itself an AbstractVolumeRepresentation, this function maps to the identity.

Arguments:

  • rng_key: Not used in this implementation.
rotate_to_pose(pose: cryojax.simulator.AbstractPose, inverse: bool = False) -> typing.Self ¤

Return a new volume with a rotated frequency_slice_in_pixels.

Real-space¤

cryojax.simulator.RealVoxelGridVolume(cryojax.simulator.AbstractVolumeRepresentation) ¤

A 3D voxel grid in real-space.

shape property ¤

The shape of the real_voxel_grid.

__init__(real_voxel_grid: Float[Array, 'dim dim dim'] | Float[ndarray, 'dim dim dim'], coordinate_grid_in_pixels: Float[Array, 'dim dim dim 3'] | Float[ndarray, 'dim dim dim 3']) ¤

Arguments:

  • real_voxel_grid: The voxel grid in real space.
  • coordinate_grid_in_pixels: A coordinate grid.
from_real_voxel_grid(real_voxel_grid: Float[Array, 'dim dim dim'] | Float[ndarray, 'dim dim dim'], *, coordinate_grid_in_pixels: Float[Array, 'dim dim dim 3'] | None = None, crop_scale: float | None = None) -> typing.Self classmethod ¤

Load a RealVoxelGridVolume from a real-valued 3D voxel grid.

Arguments:

  • real_voxel_grid: A voxel grid in real space.
  • crop_scale: Scale factor at which to crop real_voxel_grid. Must be a value greater than 1.
get_representation(rng_key: PRNGKeyArray | None = None) -> typing.Self ¤

Since this class is itself an AbstractVolumeRepresentation, this function maps to the identity.

Arguments:

  • rng_key: Not used in this implementation.
rotate_to_pose(pose: cryojax.simulator.AbstractPose, inverse: bool = False) -> typing.Self ¤

Return a new volume with a rotated coordinate_grid_in_pixels.

Volume rendering¤

cryojax.simulator.AbstractVolumeRenderFn ¤

Base class for rendering a volume onto voxels.

__call__(volume_representation: ~VolRepT, *, outputs_real_space: bool = True, outputs_rfft: bool = False, fftshifted: bool = False) -> Array ¤

cryojax.simulator.GaussianMixtureRenderFn(cryojax.simulator.AbstractVolumeRenderFn) ¤

Render a voxel grid from the GaussianMixtureVolume.

If GaussianMixtureVolume is instantiated from electron scattering factors via from_tabulated_parameters, this renders an electrostatic potential as tabulated in Peng et al. 1996. The elastic electron scattering factors defined in this work are

\[f^{(e)}(\mathbf{q}) = \sum\limits_{i = 1}^5 a_i \exp(- b_i |\mathbf{q}|^2),\]

where \(a_i\) is stored as GaussianMixtureVolume.amplitudes, \(b_i / 8 \pi^2\) are the GaussianMixtureVolume.variances, and \(\mathbf{q}\) is the scattering vector.

Under usual scattering approximations (i.e. the first Born approximation), the rescaled electrostatic potential energy \(U(\mathbf{r})\) for a given atom type is \(\mathcal{F}^{-1}[f^{(e)}(\boldsymbol{\xi} / 2)](\mathbf{r})\), which is computed analytically as

\[U(\mathbf{r}) = \sum\limits_{i = 1}^5 \frac{a_i}{(2\pi (b_i / 8 \pi^2))^{3/2}} \exp(- \frac{|\mathbf{r} - \mathbf{r}'|^2}{2 (b_i / 8 \pi^2)}),\]

where \(\mathbf{r}'\) is the position of the atom. Including an additional B-factor (denoted by \(B\)) gives the expression for the potential \(U(\mathbf{r})\) of a single atom type and its fourier transform pair \(\tilde{U}(\boldsymbol{\xi}) \equiv \mathcal{F}[U](\boldsymbol{\xi})\),

\[U(\mathbf{r}) = \sum\limits_{i = 1}^5 \frac{a_i}{(2\pi ((b_i + B) / 8 \pi^2))^{3/2}} \exp(- \frac{|\mathbf{r} - \mathbf{r}'|^2}{2 ((b_i + B) / 8 \pi^2)}),\]
\[\tilde{U}(\boldsymbol{\xi}) = \sum\limits_{i = 1}^5 a_i \exp(- (b_i + B) |\boldsymbol{\xi}|^2 / 4) \exp(2 \pi i \boldsymbol{\xi}\cdot\mathbf{r}'),\]

where \(\mathbf{q} = \boldsymbol{\xi} / 2\) gives the relationship between the wave vector and the scattering vector.
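One consequence of these expressions is that a B-factor broadens each gaussian and lowers the peak of the potential while leaving its integral over all space equal to \(\sum_i a_i = \tilde{U}(\mathbf{0})\). The sketch below checks this numerically with the standard library, using a radial integral for an atom at the origin; the \(a_i\) and \(b_i\) values are placeholders for illustration, not tabulated parameters.

```python
import math

# Placeholder parameters for the example, NOT values from Peng et al. (1996).
a = [0.1, 0.3, 0.7, 1.0, 0.4]    # angstroms
b = [0.2, 1.5, 6.0, 18.0, 50.0]  # angstroms squared


def potential(radius, b_factor=0.0):
    """U(r) for an atom at the origin, with variances (b_i + B) / (8 pi^2)."""
    total = 0.0
    for a_i, b_i in zip(a, b):
        variance = (b_i + b_factor) / (8 * math.pi**2)
        normalization = (2 * math.pi * variance) ** 1.5
        total += a_i * math.exp(-(radius**2) / (2 * variance)) / normalization
    return total


def integrate_over_space(b_factor, step=0.001, r_max=12.0):
    """Radial quadrature of 4 pi r^2 U(r) from 0 to r_max."""
    n_steps = int(r_max / step)
    return sum(
        4 * math.pi * (i * step) ** 2 * potential(i * step, b_factor) * step
        for i in range(1, n_steps + 1)
    )


integral_without_b = integrate_over_space(b_factor=0.0)
integral_with_b = integrate_over_space(b_factor=20.0)
peak_without_b = potential(0.0, b_factor=0.0)
peak_with_b = potential(0.0, b_factor=20.0)
```

Both integrals should come out close to `sum(a)`, while `peak_with_b` is smaller than `peak_without_b`.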

In practice, for a discretization on a grid with voxel size \(\Delta r\) and grid point \(\mathbf{r}_{\ell}\), the potential is evaluated as the average value inside the voxel

\[U_{\ell} = \frac{1}{\Delta r^3} \sum\limits_{i = 1}^5 a_i \prod\limits_{j = 1}^3 \int_{r^{\ell}_j-\Delta r/2}^{r^{\ell}_j+\Delta r/2} dr_j \ \frac{1}{{\sqrt{2\pi ((b_i + B) / 8 \pi^2)}}} \exp(- \frac{(r_j - r'_j)^2}{2 ((b_i + B) / 8 \pi^2)}),\]

where \(j\) indexes the components of the spatial coordinate vector \(\mathbf{r}\). The above expression is evaluated using the error function as

\[U_{\ell} = \frac{1}{(2 \Delta r)^3} \sum\limits_{i = 1}^5 a_i \prod\limits_{j = 1}^3 \left[\textrm{erf}\left(\frac{r_j^{\ell} - r'_j + \Delta r / 2}{\sqrt{2 ((b_i + B) / 8\pi^2)}}\right) - \textrm{erf}\left(\frac{r_j^{\ell} - r'_j - \Delta r / 2}{\sqrt{2 ((b_i + B) / 8\pi^2)}}\right)\right].\]
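The error-function expression can be verified against direct quadrature. The following standard-library sketch evaluates the per-axis factor (the difference of error functions divided by \(2 \Delta r\)) and compares it with a midpoint-rule average of the gaussian over the voxel; the function names and numerical values are illustrative choices, with `variance` standing for \((b_i + B) / 8\pi^2\).

```python
import math


def voxel_average_1d(grid_point, atom_coordinate, variance, voxel_size):
    """Average of a normalized 1D gaussian over one voxel, via erf."""
    scale = math.sqrt(2 * variance)
    upper = math.erf((grid_point - atom_coordinate + voxel_size / 2) / scale)
    lower = math.erf((grid_point - atom_coordinate - voxel_size / 2) / scale)
    return (upper - lower) / (2 * voxel_size)


def voxel_average_1d_quadrature(
    grid_point, atom_coordinate, variance, voxel_size, n_sub=2000
):
    """The same average computed with a midpoint rule, for comparison."""
    sub_width = voxel_size / n_sub
    start = grid_point - voxel_size / 2
    total = 0.0
    for k in range(n_sub):
        r = start + (k + 0.5) * sub_width
        total += math.exp(-((r - atom_coordinate) ** 2) / (2 * variance))
    return total * sub_width / (math.sqrt(2 * math.pi * variance) * voxel_size)


def voxel_average(grid_point, atom_position, amplitudes, variances, voxel_size):
    """U_l for one atom: sum over gaussians of the product of axis factors."""
    total = 0.0
    for amplitude, variance in zip(amplitudes, variances):
        factor = 1.0
        for j in range(3):
            factor *= voxel_average_1d(
                grid_point[j], atom_position[j], variance, voxel_size
            )
        total += amplitude * factor
    return total


variance, voxel_size, atom_coordinate = 0.3, 1.0, 0.2
from_erf = voxel_average_1d(1.0, atom_coordinate, variance, voxel_size)
from_quadrature = voxel_average_1d_quadrature(
    1.0, atom_coordinate, variance, voxel_size
)

# Summing voxel averages times the voxel size over a wide grid recovers the
# unit integral of the gaussian, since the erf terms telescope.
recovered_integral = sum(
    voxel_average_1d(float(n), atom_coordinate, variance, voxel_size) * voxel_size
    for n in range(-20, 21)
)
```

Averaging over the voxel, rather than sampling at the voxel center, keeps the total integrated potential on the grid independent of where the atom sits relative to the grid points.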
__init__(shape: tuple[int, int, int], voxel_size: Float[Array, ''] | Float[ndarray, ''] | float, *, batch_options: dict[str, Any] = {}) ¤

Arguments:

  • shape: The shape of the resulting voxel grid.
  • voxel_size: The voxel size of the resulting voxel grid.
  • batch_options: Advanced options for controlling batching. This is a dictionary with the following keys:
    • "batch_size": The number of z-planes to evaluate in parallel with jax.vmap. By default, 1.
    • "n_batches": The number of iterations used to evaluate the volume, where the iteration is taken over groups of atoms. This is useful if batch_size = 1 and GPU memory is exhausted. By default, 1.
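The reason iterating over groups of atoms helps with memory is that the rendered volume is a sum over atoms, so it can be accumulated group by group. The NumPy sketch below shows this for a simplified 1D grid with a single gaussian per atom; `render_1d` and `render_1d_in_batches` are hypothetical helpers and are not how cryoJAX implements batching internally.

```python
import numpy as np


def render_1d(grid, positions, amplitudes, variance):
    """Sum of gaussians on a 1D grid, forming an (n_grid, n_atoms) array."""
    differences = grid[:, None] - positions[None, :]
    gaussians = np.exp(-(differences**2) / (2 * variance)) / np.sqrt(
        2 * np.pi * variance
    )
    return np.sum(amplitudes * gaussians, axis=1)


def render_1d_in_batches(grid, positions, amplitudes, variance, n_batches):
    """Same result, but only one group of atoms is in memory at a time."""
    output = np.zeros_like(grid)
    position_groups = np.array_split(positions, n_batches)
    amplitude_groups = np.array_split(amplitudes, n_batches)
    for position_group, amplitude_group in zip(position_groups, amplitude_groups):
        output += render_1d(grid, position_group, amplitude_group, variance)
    return output


rng = np.random.default_rng(0)
grid = np.linspace(-10.0, 10.0, 201)
positions = rng.uniform(-5.0, 5.0, size=1000)
amplitudes = rng.uniform(0.5, 1.5, size=1000)

all_at_once = render_1d(grid, positions, amplitudes, variance=0.5)
in_batches = render_1d_in_batches(
    grid, positions, amplitudes, variance=0.5, n_batches=7
)
```

The intermediate array shrinks roughly in proportion to `n_batches`, at the cost of running the loop sequentially.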
__call__(volume_representation: cryojax.simulator.GaussianMixtureVolume, *, outputs_real_space: bool = True, outputs_rfft: bool = False, fftshifted: bool = False) -> Float[Array, '{self.shape[0]} {self.shape[1]} {self.shape[2]}'] ¤

Arguments:

  • volume_representation: The GaussianMixtureVolume.
  • outputs_real_space: If True, return a voxel grid in real-space.
  • outputs_rfft: If True, return a fourier-space voxel grid transformed with cryojax.ndimage.rfftn. Otherwise, use fftn. Does nothing if outputs_real_space = True.
  • fftshifted: If True, return a fourier-space voxel grid with the zero frequency component in the center of the grid via jax.numpy.fft.fftshift. Otherwise, the zero frequency component is in the corner. Does nothing if outputs_real_space = True.
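The three output flags select between array layouts that are easy to confuse. The NumPy sketch below mimics their meaning on an existing real-space grid, using `numpy.fft.rfftn` as a stand-in for cryojax.ndimage.rfftn; `format_output` is a hypothetical helper, and shifting only the full-length axes when combining `outputs_rfft` with `fftshifted` is an assumption rather than documented behavior.

```python
import numpy as np


def format_output(
    real_voxel_grid, outputs_real_space=True, outputs_rfft=False, fftshifted=False
):
    """Return the grid in the layout selected by the three flags."""
    if outputs_real_space:
        # The other two flags have no effect in this case
        return real_voxel_grid
    if outputs_rfft:
        # Half-spectrum transform: the last axis has length n // 2 + 1
        fourier = np.fft.rfftn(real_voxel_grid)
        # Assumption: only the full-length axes are shifted
        return np.fft.fftshift(fourier, axes=(0, 1)) if fftshifted else fourier
    # Full transform: zero frequency in the corner unless shifted
    fourier = np.fft.fftn(real_voxel_grid)
    return np.fft.fftshift(fourier) if fftshifted else fourier


real_voxel_grid = np.random.default_rng(0).normal(size=(6, 6, 6))

full = format_output(real_voxel_grid, outputs_real_space=False)
centered = format_output(real_voxel_grid, outputs_real_space=False, fftshifted=True)
half = format_output(real_voxel_grid, outputs_real_space=False, outputs_rfft=True)
print(full.shape, centered.shape, half.shape)
```

In the full transform the zero frequency component sits at index `[0, 0, 0]`; after shifting a grid of even size `n`, it sits at `[n // 2, n // 2, n // 2]`.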

cryojax.simulator.FFTAtomRenderFn(cryojax.simulator.AbstractVolumeRenderFn) ¤

Render a voxel grid using non-uniform FFTs and convolution.

__init__(shape: tuple[int, int, int], voxel_size: Float[Array, ''] | Float[ndarray, ''] | float, *, frequency_grid: Float[Array, '_ _ _ 3'] | None = None, antialias: bool = True, eps: float = 1e-06, opts: Any = None) ¤

Arguments:

  • shape: The shape of the resulting voxel grid.
  • voxel_size: The voxel size of the resulting voxel grid.
  • frequency_grid: An optional frequency grid for rendering the volume. If None, compute on the fly. The grid should be in inverse angstroms and have the zero frequency component in the center, i.e.

    frequency_grid = jnp.fft.fftshift(
        make_frequency_grid(shape, voxel_size, outputs_rfft=False),
        axes=(0, 1, 2),
    )
    
  • antialias: If True, apply an anti-aliasing filter to more accurately sample the volume.

  • eps: See jax-finufft for documentation.
  • opts: A jax_finufft.options.Opts or jax_finufft.options.NestedOpts dataclass. See jax-finufft for documentation.
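For readers who want to see what a centered frequency grid in inverse angstroms looks like, here is a NumPy-only sketch. `make_centered_frequency_grid` is a hypothetical helper built on `numpy.fft.fftfreq`; it is an assumption that this matches the layout produced by make_frequency_grid, and in particular the ordering of components in the last axis may differ in cryoJAX.

```python
import numpy as np


def make_centered_frequency_grid(shape, voxel_size):
    """Frequency coordinates with zero frequency at the center of the box."""
    axes = [np.fft.fftshift(np.fft.fftfreq(n, d=voxel_size)) for n in shape]
    return np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1)


shape, voxel_size = (4, 6, 8), 2.0  # voxel size in angstroms
frequency_grid = make_centered_frequency_grid(shape, voxel_size)

# Zero frequency sits at index n // 2 along each axis, and no component
# exceeds the Nyquist frequency 1 / (2 * voxel_size).
center_index = tuple(n // 2 for n in shape)
nyquist = 1.0 / (2.0 * voxel_size)
print(frequency_grid.shape, frequency_grid[center_index], nyquist)
```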
__call__(volume_representation: cryojax.simulator.IndependentAtomVolume, *, outputs_real_space: bool = True, outputs_rfft: bool = False, fftshifted: bool = False) -> Float[Array, '{self.shape[0]} {self.shape[1]} {self.shape[2]}'] ¤

Arguments:

  • volume_representation: The IndependentAtomVolume.
  • outputs_real_space: If True, return a voxel grid in real-space.
  • outputs_rfft: If True, return a fourier-space voxel grid transformed with cryojax.ndimage.rfftn. Otherwise, use fftn. Does nothing if outputs_real_space = True.
  • fftshifted: If True, return a fourier-space voxel grid with the zero frequency component in the center of the grid via jax.numpy.fft.fftshift. Otherwise, the zero frequency component is in the corner. Does nothing if outputs_real_space = True.