Image and volume manipulation¤
cryojax.ndimage implements routines for image and volume arrays, such as coordinate creation, downsampling, filters, and masks. This is a key submodule for supporting cryojax.simulator.
Coordinate systems¤
This page collects the functions used to work with coordinate systems in cryojax's conventions. The most important are make_coordinate_grid and make_frequency_grid.
Creating coordinate systems¤
cryojax.ndimage.make_coordinate_grid
¤
make_coordinate_grid(shape: tuple[int, ...], grid_spacing: float | Float[ndarray, ''] | Float[Array, ''] = 1.0) -> Float[Array, '*shape ndim']
Create a real-space cartesian coordinate system on a grid.
Arguments:
shape: Shape of the grid, with ndim = len(shape).
grid_spacing: The grid spacing (i.e. pixel/voxel size), in units of length.
Returns:
A cartesian coordinate system in real space.
cryojax.ndimage.make_frequency_grid
¤
make_frequency_grid(shape: tuple[int, ...], grid_spacing: float | Float[ndarray, ''] | Float[Array, ''] = 1.0, outputs_rfftfreqs: bool = True) -> Float[Array, '*shape ndim']
Create a fourier-space cartesian coordinate system on a grid. The zero-frequency component is in the corner.
Arguments:
shape: Shape of the grid, with ndim = len(shape).
grid_spacing: The grid spacing (i.e. pixel/voxel size), in units of length.
outputs_rfftfreqs: Return a frequency grid for use with jax.numpy.fft.rfftn. shape[-1] is the axis on which the negative frequencies are omitted.
Returns:
A cartesian coordinate system in frequency space.
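To illustrate the corner convention and the effect of outputs_rfftfreqs, here is a minimal NumPy sketch. It is not cryojax's implementation: the helper name sketch_frequency_grid and the (x, y) ordering of the last axis are assumptions made for the example.

```python
import numpy as np


def sketch_frequency_grid(shape, grid_spacing=1.0, outputs_rfftfreqs=True):
    """Hypothetical helper: 2D frequency grid with the zero frequency in the corner."""
    y_dim, x_dim = shape
    # Full set of frequencies along y: zero first, then positive, then negative.
    ky = np.fft.fftfreq(y_dim, d=grid_spacing)
    # Along the last axis, optionally keep only the half-spectrum used by rfftn.
    if outputs_rfftfreqs:
        kx = np.fft.rfftfreq(x_dim, d=grid_spacing)
    else:
        kx = np.fft.fftfreq(x_dim, d=grid_spacing)
    kx_grid, ky_grid = np.meshgrid(kx, ky, indexing="xy")
    # Assumed ordering of the last axis: (x, y).
    return np.stack([kx_grid, ky_grid], axis=-1)


full_grid = sketch_frequency_grid((4, 6), grid_spacing=2.0, outputs_rfftfreqs=False)
half_grid = sketch_frequency_grid((4, 6), grid_spacing=2.0, outputs_rfftfreqs=True)
```

In this sketch, half_grid has the same leading shape as the output of an rfftn applied to a (4, 6) image, and the entry at index [0, 0] is the zero frequency in both grids.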
cryojax.ndimage.make_radial_coordinate_grid
¤
make_radial_coordinate_grid(shape: tuple[int, ...], grid_spacing: float | Float[ndarray, ''] | Float[Array, ''] = 1.0) -> Float[Array, '*shape']
Create a real-space radial coordinate system on a grid.
This wraps the function make_coordinate_grid to compute
the coordinate vector magnitude.
Arguments:
shape: Shape of the grid, with ndim = len(shape).
grid_spacing: The grid spacing (i.e. pixel/voxel size), in units of length.
Returns:
A radial coordinate system in real space.
cryojax.ndimage.make_radial_frequency_grid
¤
make_radial_frequency_grid(shape: tuple[int, ...], grid_spacing: float | Float[ndarray, ''] | Float[Array, ''] = 1.0, outputs_rfftfreqs: bool = True) -> Float[Array, '*shape']
Create a fourier-space radial coordinate system on a grid. The zero-frequency component is in the corner.
This wraps the function make_frequency_grid to compute
the frequency magnitude, which is a common use case for
things like computing fourier shell correlations and power spectra.
Arguments:
shape: Shape of the grid, with ndim = len(shape).
grid_spacing: The grid spacing (i.e. pixel/voxel size), in units of length.
outputs_rfftfreqs: Return a frequency grid for use with jax.numpy.fft.rfftn. shape[-1] is the axis on which the negative frequencies are omitted.
Returns:
A radial coordinate system in frequency space.
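As a rough illustration of the use case mentioned above, the sketch below builds a radial frequency grid with plain NumPy and uses it to bin a power spectrum into shells. The grid construction, the choice of four shells, and the random test image are assumptions for the example, not cryojax code.

```python
import numpy as np

shape = (8, 8)
ky = np.fft.fftfreq(shape[0])
kx = np.fft.fftfreq(shape[1])
kx_grid, ky_grid = np.meshgrid(kx, ky, indexing="xy")
frequency_grid = np.stack([kx_grid, ky_grid], axis=-1)  # shape (8, 8, 2)

# The radial grid is the magnitude of each frequency vector.
radial_frequency_grid = np.linalg.norm(frequency_grid, axis=-1)  # shape (8, 8)

# Example use: a radially averaged power spectrum of a random image.
rng = np.random.default_rng(0)
image = rng.normal(size=shape)
power = np.abs(np.fft.fftn(image)) ** 2
bin_edges = np.linspace(0.0, 0.5, 5)  # four shells below the Nyquist frequency
shell_index = np.digitize(radial_frequency_grid, bin_edges) - 1
radial_power = np.array(
    [power[shell_index == i].mean() for i in range(len(bin_edges) - 1)]
)
```

Frequencies at or beyond the last bin edge (the Nyquist frequency and the corners of the grid) fall outside the four shells and are simply ignored here.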
cryojax.ndimage.make_frequency_slice
¤
make_frequency_slice(shape: tuple[int, int], grid_spacing: float | Float[ndarray, ''] | Float[Array, ''] = 1.0, outputs_rfftfreqs: bool = False) -> Float[Array, '1 {shape[0]} {shape[1]} 3']
Create a fourier-space cartesian coordinate system on a grid, where the zero-frequency component is in the center of the grid.
Warning
In the function make_frequency_grid, the convention is that
the grid is returned with the zero-frequency component in the
corner. In this function, as mentioned above, frequency slices are
returned with the zero-frequency component in the center. To convert
between the two conventions, run
import jax.numpy as jnp
from cryojax.ndimage import make_frequency_slice
frequency_slice_with_zero_in_center = make_frequency_slice((100, 100)) # Shape (1, 100, 100, 3)
frequency_slice_with_zero_in_corner = jnp.fft.ifftshift(frequency_slice_with_zero_in_center, axes=(1, 2))
The reason for the difference is so that this function can be used to
directly pass a frequency_slice to the cryojax.simulator.FourierVoxelGridPotential,
which requires that the zero is in the center of the grid.
Arguments:
shape: Shape of the frequency slice, e.g. shape = (100, 100).
grid_spacing: The grid spacing (i.e. voxel size), in units of length.
outputs_rfftfreqs: Return a frequency grid for use with jax.numpy.fft.rfftn. shape[-1] is the axis on which the negative frequencies are omitted.
Returns:
The central, \(q_z = 0\) slice of a 3D frequency grid \((q_x, q_y, q_z)\), where the zero-frequency component is in the center of the grid.
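The two conventions can be checked on a 1D frequency axis with NumPy alone. This small sketch does not call cryojax. It only shows why ifftshift, rather than a second fftshift, is the correct way to move the zero frequency from the center back to the corner when the size is odd.

```python
import numpy as np

n = 6
# Zero frequency in the corner (the make_frequency_grid convention).
corner = np.fft.fftfreq(n)
# Zero frequency in the center (the make_frequency_slice convention).
center = np.fft.fftshift(corner)
# ifftshift undoes fftshift and moves the zero back to the corner.
restored = np.fft.ifftshift(center)

# For odd sizes, fftshift and ifftshift are no longer interchangeable.
corner_odd = np.fft.fftfreq(5)
center_odd = np.fft.fftshift(corner_odd)
restored_odd = np.fft.ifftshift(center_odd)   # equals corner_odd
wrongly_shifted = np.fft.fftshift(center_odd)  # does not equal corner_odd
```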
cryojax.ndimage.make_1d_coordinate_grid
¤
make_1d_coordinate_grid(size: int, grid_spacing: float | Float[ndarray, ''] | Float[Array, ''] = 1.0) -> Float[Array, '*shape ndim']
Create a 1D real-space cartesian coordinate array.
Arguments:
size: Size of the coordinate array.
grid_spacing: The grid spacing (i.e. pixel/voxel size), in units of length.
Returns:
A 1D cartesian coordinate array in real space.
cryojax.ndimage.make_1d_frequency_grid
¤
make_1d_frequency_grid(size: int, grid_spacing: float | Float[ndarray, ''] | Float[Array, ''] = 1.0, outputs_rfftfreqs: bool = True) -> Float[Array, '*shape ndim']
Create a 1D fourier-space cartesian coordinate array.
If outputs_rfftfreqs = False, the zero-frequency component is at the beginning.
Arguments:
size: Size of the coordinate array.
grid_spacing: The grid spacing (i.e. pixel/voxel size), in units of length.
outputs_rfftfreqs: Return a frequency grid for use with jax.numpy.fft.rfftn. shape[-1] is the axis on which the negative frequencies are omitted.
Returns:
A 1D cartesian coordinate array in frequency space.
Transforming coordinate systems¤
cryojax also provides functions that transform between coordinate conventions.
cryojax.ndimage.cartesian_to_polar
¤
cartesian_to_polar(coordinate_or_frequency_grid: Float[Array, 'y_dim x_dim 2'], square: bool = False) -> tuple[Inexact[Array, 'y_dim x_dim'], Inexact[Array, 'y_dim x_dim']]
Convert from cartesian to polar coordinates.
Arguments:
coordinate_or_frequency_grid: The cartesian coordinate system in real or fourier space.
square: If True, return the square of the radial coordinate \(|r|^2\). Otherwise, return \(|r|\).
Returns:
A tuple (r, theta), where r is the radial coordinate system and
theta is the angular coordinate system. If square=True, return a
tuple (r_squared, theta).
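The conversion can be sketched in NumPy as follows. The helper name sketch_cartesian_to_polar, the (x, y) ordering of the last axis, and the use of arctan2(y, x) for the angle are assumptions for illustration. The source does not state the angle convention cryojax uses.

```python
import numpy as np


def sketch_cartesian_to_polar(grid, square=False):
    """Hypothetical helper mirroring the documented interface."""
    x, y = grid[..., 0], grid[..., 1]  # assumed ordering of the last axis
    r_squared = x**2 + y**2
    theta = np.arctan2(y, x)  # assumed angle convention
    radial = r_squared if square else np.sqrt(r_squared)
    return radial, theta


grid = np.array([[[1.0, 0.0], [0.0, 2.0]]])  # shape (1, 2, 2)
r, theta = sketch_cartesian_to_polar(grid)
r_squared, _ = sketch_cartesian_to_polar(grid, square=True)
```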
Image transforms (e.g. filters and masks)¤
cryojax.ndimage.AbstractImageTransform
¤
cryojax.ndimage.ScaleImage(cryojax.ndimage.AbstractImageTransform)
¤
ScaleImage(scale: cryojax.jax_util._typing.FloatLike = 1.0, offset: cryojax.jax_util._typing.FloatLike = 0.0)
Filters¤
cryojax.ndimage.AbstractFilter(cryojax.ndimage.AbstractImageTransform)
¤
Base class for computing and applying an image filter.
cryojax.ndimage.LowpassFilter(cryojax.ndimage.AbstractFilter)
¤
Apply a low-pass filter to an image or volume, with a cosine soft-edge.
__init__
¤
__init__(frequency_grid_in_angstroms_or_pixels: Float[Array, 'y_dim x_dim 2'] | Float[Array, 'z_dim y_dim x_dim 3'], grid_spacing: cryojax.jax_util.FloatLike = 1.0, frequency_cutoff_fraction: cryojax.jax_util.FloatLike = 0.95, rolloff_width_fraction: cryojax.jax_util.FloatLike = 0.05)
Arguments:
frequency_grid_in_angstroms_or_pixels: The frequency grid of the image or volume.
grid_spacing: The pixel or voxel size of frequency_grid_in_angstroms_or_pixels.
frequency_cutoff_fraction: The cutoff frequency as a fraction of the Nyquist frequency. By default, 0.95.
rolloff_width_fraction: The rolloff width as a fraction of the Nyquist frequency. By default, 0.05.
__call__
¤
__call__(image: Complex[Array, 'y_dim x_dim'] | Complex[Array, 'z_dim y_dim x_dim']) -> Complex[Array, 'y_dim x_dim'] | Complex[Array, 'z_dim y_dim x_dim']
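To make the roles of frequency_cutoff_fraction and rolloff_width_fraction concrete, the following NumPy sketch builds one possible cosine-edged low-pass filter. The helper name sketch_lowpass_filter and the exact placement of the rolloff band (starting at the cutoff and extending outward) are assumptions. cryojax's implementation may differ in detail.

```python
import numpy as np


def sketch_lowpass_filter(radial_frequency, grid_spacing=1.0,
                          frequency_cutoff_fraction=0.95,
                          rolloff_width_fraction=0.05):
    """Hypothetical cosine-edged low-pass filter (an assumed functional form)."""
    nyquist = 0.5 / grid_spacing
    cutoff = frequency_cutoff_fraction * nyquist
    width = rolloff_width_fraction * nyquist  # must be positive
    # Position within the rolloff band, clipped to [0, 1].
    t = np.clip((radial_frequency - cutoff) / width, 0.0, 1.0)
    # 1 below the cutoff, 0 beyond cutoff + width, half-cosine in between.
    return 0.5 * (1.0 + np.cos(np.pi * t))


k = np.fft.fftfreq(16)
kx_grid, ky_grid = np.meshgrid(k, k, indexing="xy")
radial = np.hypot(kx_grid, ky_grid)
lowpass = sketch_lowpass_filter(
    radial, frequency_cutoff_fraction=0.5, rolloff_width_fraction=0.1
)

# The filter multiplies an image that is already in Fourier space.
fourier_image = np.fft.fftn(np.ones((16, 16)))
filtered = lowpass * fourier_image
```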
cryojax.ndimage.HighpassFilter(cryojax.ndimage.AbstractFilter)
¤
Apply a high-pass filter to an image or volume, with a cosine soft-edge.
__init__
¤
__init__(frequency_grid_in_angstroms_or_pixels: Float[Array, 'y_dim x_dim 2'] | Float[Array, 'z_dim y_dim x_dim 3'], grid_spacing: cryojax.jax_util.FloatLike = 1.0, frequency_cutoff_fraction: cryojax.jax_util.FloatLike = 0.95, rolloff_width_fraction: cryojax.jax_util.FloatLike = 0.05)
Arguments:
frequency_grid_in_angstroms_or_pixels: The frequency grid of the image or volume.
grid_spacing: The pixel or voxel size of frequency_grid_in_angstroms_or_pixels.
frequency_cutoff_fraction: The cutoff frequency as a fraction of the Nyquist frequency. By default, 0.95.
rolloff_width_fraction: The rolloff width as a fraction of the Nyquist frequency. By default, 0.05.
__call__
¤
__call__(image: Complex[Array, 'y_dim x_dim'] | Complex[Array, 'z_dim y_dim x_dim']) -> Complex[Array, 'y_dim x_dim'] | Complex[Array, 'z_dim y_dim x_dim']
cryojax.ndimage.WhiteningFilter(cryojax.ndimage.AbstractFilter)
¤
Compute a whitening filter from an image. This is taken to be the inverse square root of the 2D radially averaged power spectrum.
This implementation follows the cisTEM whitening filter algorithm.
__init__
¤
__init__(images: Float[NDArrayLike, '_ _'] | Float[NDArrayLike, '_ _ _'], shape: tuple[int, int] | None = None, *, interpolation_mode: str = 'linear', outputs_squared: bool = False)
Arguments:
images: The image (or stack of images) from which to compute the power spectrum.
shape: The shape of the resulting filter. This downsamples or upsamples the filter by cropping or padding in real space.
interpolation_mode: The method of interpolating the binned, radially averaged power spectrum onto a 2D grid. Either nearest or linear.
outputs_squared: If False, the whitening filter is the inverse square root of the image power. If True, the filter is the inverse of the image power.
__call__
¤
__call__(image: Complex[Array, 'y_dim x_dim'] | Complex[Array, 'z_dim y_dim x_dim']) -> Complex[Array, 'y_dim x_dim'] | Complex[Array, 'z_dim y_dim x_dim']
cryojax.ndimage.CustomFilter(cryojax.ndimage.AbstractFilter)
¤
Pass a custom filter as an array.
Masks¤
cryojax.ndimage.AbstractMask(cryojax.ndimage.AbstractImageTransform)
¤
Base class for computing and applying an image mask.
cryojax.ndimage.CircularCosineMask(cryojax.ndimage.AbstractMask)
¤
Apply a circular mask to an image with a cosine soft-edge.
__init__
¤
__init__(coordinate_grid: Float[Array, 'y_dim x_dim 2'], radius: cryojax.jax_util.FloatLike, rolloff_width: cryojax.jax_util.FloatLike, xy_offset: tuple[float, float] | Float[NDArrayLike, '2'] = (0.0, 0.0))
Arguments:
coordinate_grid: The image coordinates.
radius: The radius of the circular mask.
rolloff_width: The rolloff width of the soft edge.
__call__
¤
__call__(image: Float[Array, 'y_dim x_dim'] | Float[Array, 'z_dim y_dim x_dim']) -> Float[Array, 'y_dim x_dim'] | Float[Array, 'z_dim y_dim x_dim']
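A soft-edged circular mask of this kind can be sketched in NumPy as below. The helper name sketch_circular_cosine_mask, the centered pixel coordinates, and the choice to start the cosine rolloff at radius and end it at radius + rolloff_width are assumptions. The source does not specify where the soft edge sits relative to the radius.

```python
import numpy as np


def sketch_circular_cosine_mask(coordinate_grid, radius, rolloff_width):
    """Hypothetical mask: 1 inside the radius, cosine falloff to 0 outside it."""
    r = np.linalg.norm(coordinate_grid, axis=-1)
    t = np.clip((r - radius) / rolloff_width, 0.0, 1.0)
    return 0.5 * (1.0 + np.cos(np.pi * t))


n = 9
axis = np.arange(n) - n // 2  # pixel coordinates centered on the middle pixel
x, y = np.meshgrid(axis, axis, indexing="xy")
coordinate_grid = np.stack([x, y], axis=-1).astype(float)

mask = sketch_circular_cosine_mask(coordinate_grid, radius=2.0, rolloff_width=1.5)
# Unlike the filters above, the mask multiplies an image in real space.
masked_image = mask * np.ones((n, n))
```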
cryojax.ndimage.SphericalCosineMask(cryojax.ndimage.AbstractMask)
¤
Apply a spherical mask to a volume with a cosine soft-edge.
__init__
¤
__init__(coordinate_grid: Float[Array, 'z_dim y_dim x_dim 3'], radius: cryojax.jax_util.FloatLike, rolloff_width: cryojax.jax_util.FloatLike)
Arguments:
coordinate_grid: The volume coordinates.
radius: The radius of the spherical mask.
rolloff_width: The rolloff width of the soft edge.
__call__
¤
__call__(image: Float[Array, 'y_dim x_dim'] | Float[Array, 'z_dim y_dim x_dim']) -> Float[Array, 'y_dim x_dim'] | Float[Array, 'z_dim y_dim x_dim']
cryojax.ndimage.SquareCosineMask(cryojax.ndimage.AbstractMask)
¤
Apply a square mask to an image with a cosine soft-edge.
__init__
¤
__init__(coordinate_grid: Float[Array, 'y_dim x_dim 2'], side_length: cryojax.jax_util.FloatLike, rolloff_width: cryojax.jax_util.FloatLike)
Arguments:
coordinate_grid: The image coordinates.
side_length: The side length of the square.
rolloff_width: The rolloff width of the soft edge.
__call__
¤
__call__(image: Float[Array, 'y_dim x_dim'] | Float[Array, 'z_dim y_dim x_dim']) -> Float[Array, 'y_dim x_dim'] | Float[Array, 'z_dim y_dim x_dim']
cryojax.ndimage.Rectangular2DCosineMask(cryojax.ndimage.AbstractMask)
¤
Apply a rectangular mask in 2D to an image with a cosine soft-edge. Optionally, rotate the rectangle by an angle.
__init__
¤
__init__(coordinate_grid: Float[Array, 'y_dim x_dim 2'], x_width: cryojax.jax_util.FloatLike, y_width: cryojax.jax_util.FloatLike, rolloff_width: cryojax.jax_util.FloatLike, in_plane_rotation_angle: cryojax.jax_util.FloatLike = 0.0)
Arguments:
coordinate_grid: The image coordinates.
x_width: The width of the rectangle along the x-axis.
y_width: The width of the rectangle along the y-axis.
rolloff_width: The rolloff width of the soft edge.
in_plane_rotation_angle: The in-plane rotation angle of the rectangle in degrees. By default, 0.0.
__call__
¤
__call__(image: Float[Array, 'y_dim x_dim'] | Float[Array, 'z_dim y_dim x_dim']) -> Float[Array, 'y_dim x_dim'] | Float[Array, 'z_dim y_dim x_dim']
cryojax.ndimage.Rectangular3DCosineMask(cryojax.ndimage.AbstractMask)
¤
Apply a rectangular mask to a volume with a cosine soft-edge.
__init__
¤
__init__(coordinate_grid: Float[Array, 'z_dim y_dim x_dim 3'], x_width: cryojax.jax_util.FloatLike, y_width: cryojax.jax_util.FloatLike, z_width: cryojax.jax_util.FloatLike, rolloff_width: cryojax.jax_util.FloatLike)
Arguments:
coordinate_grid: The volume coordinates.
x_width: The width of the rectangle along the x-axis.
y_width: The width of the rectangle along the y-axis.
z_width: The width of the rectangle along the z-axis.
rolloff_width: The rolloff width of the soft edge.
__call__
¤
__call__(image: Float[Array, 'y_dim x_dim'] | Float[Array, 'z_dim y_dim x_dim']) -> Float[Array, 'y_dim x_dim'] | Float[Array, 'z_dim y_dim x_dim']
cryojax.ndimage.Cylindrical2DCosineMask(cryojax.ndimage.AbstractMask)
¤
Apply a cylindrical mask to an image with a cosine soft-edge. This implements an infinite in-plane cylinder, rotated at a given angle.
__init__
¤
__init__(coordinate_grid: Float[Array, 'y_dim x_dim 2'], radius: cryojax.jax_util.FloatLike, rolloff_width: cryojax.jax_util.FloatLike, in_plane_rotation_angle: cryojax.jax_util.FloatLike = 0.0, length: cryojax.jax_util.FloatLike | None = None)
Arguments:
coordinate_grid: The image coordinates.
radius: The radius of the cylinder.
rolloff_width: The rolloff width of the soft edge.
in_plane_rotation_angle: The in-plane rotation angle of the cylinder in degrees. By default, 0.0.
length: The length of the cylinder. If None, do not mask the cylinder length-wise.
__call__
¤
__call__(image: Float[Array, 'y_dim x_dim'] | Float[Array, 'z_dim y_dim x_dim']) -> Float[Array, 'y_dim x_dim'] | Float[Array, 'z_dim y_dim x_dim']
cryojax.ndimage.CustomMask(cryojax.ndimage.AbstractMask)
¤
Pass a custom mask as an array.
__call__
¤
__call__(image: Float[Array, 'y_dim x_dim'] | Float[Array, 'z_dim y_dim x_dim']) -> Float[Array, 'y_dim x_dim'] | Float[Array, 'z_dim y_dim x_dim']
cryojax.ndimage.SincCorrectionMask(cryojax.ndimage.AbstractMask)
¤
Divide an image or volume by a 2D or 3D rectangular
squared sinc function computed on the unit box. This is used
for correcting scaling in cryojax.simulator.FourierSliceExtraction.
Linear interpolation in the Fourier domain can be thought of as a convolution with a triangular kernel, whose Fourier transform pair is the squared sinc. This mask acts to deconvolve this function with a division in real-space.
__init__
¤
__init__(coordinate_grid_in_pixels: Float[Array, 'y_dim x_dim 2'] | Float[Array, 'z_dim y_dim x_dim 3'])
Arguments:
coordinate_grid_in_pixels: The image or volume coordinates. This should be generated with make_coordinate_grid(shape), not make_coordinate_grid(shape, grid_spacing).
__call__
¤
__call__(image: Float[Array, 'y_dim x_dim'] | Float[Array, 'z_dim y_dim x_dim']) -> Float[Array, 'y_dim x_dim'] | Float[Array, 'z_dim y_dim x_dim']
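The idea of dividing by a squared sinc on the unit box can be sketched with NumPy. The rescaling of pixel coordinates by the image size and the centered coordinate layout are assumptions for illustration. The source only states that the sinc is computed on the unit box.

```python
import numpy as np

n = 8
axis = np.arange(n) - n // 2  # pixel coordinates, i.e. a grid with unit spacing
x, y = np.meshgrid(axis, axis, indexing="xy")

# Rescale pixel coordinates to the unit box, then form the separable squared sinc.
# np.sinc(u) is sin(pi u) / (pi u).
sinc_squared = (np.sinc(x / n) * np.sinc(y / n)) ** 2

# Division in real space deconvolves the triangular interpolation kernel
# that was applied in Fourier space.
image = np.ones((n, n))
corrected = image / sinc_squared
```

Because the rescaled coordinates stay within half a unit of the origin, the squared sinc never reaches zero and the division is well defined. The correction is largest at the edges of the box.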
Operators¤
Fourier-space¤
cryojax.ndimage.AbstractFourierOperator
¤
The base class for all fourier-based operators.
By convention, operators should be defined to be dimensionless (up to a scale factor).
To create a subclass,
1) Include the necessary parameters in
the class definition.
2) Overwrite the `__call__` method.
__call__
¤
__call__(frequency_grid: Float[Array, 'y_dim x_dim 2'] | Float[Array, 'z_dim y_dim x_dim 3']) -> Inexact[Array, 'y_dim x_dim'] | Inexact[Array, 'z_dim y_dim x_dim']
cryojax.ndimage.FourierGaussian(cryojax.ndimage.AbstractFourierOperator)
¤
This operator represents a simple gaussian. Specifically, this is
\[P(k) = \kappa \exp(- \beta k^2 / 4),\]
where \(k^2 = k_x^2 + k_y^2\) is the squared length of the
wave vector. Here, \(\beta\) has dimensions of length
squared.
__init__
¤
__init__(amplitude: cryojax.jax_util.FloatLike = 1.0, b_factor: cryojax.jax_util.FloatLike = 1.0)
Arguments:
amplitude: The amplitude of the operator, equal to \(\kappa\) in the above equation.
b_factor: The B-factor of the gaussian, equal to \(\beta\) in the above equation.
__call__
¤
__call__(frequency_grid: Float[Array, 'y_dim x_dim 2'] | Float[Array, 'z_dim y_dim x_dim 3']) -> Float[Array, 'y_dim x_dim'] | Float[Array, 'z_dim y_dim x_dim']
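The gaussian \(P(k) = \kappa \exp(-\beta k^2 / 4)\) is straightforward to evaluate on a frequency grid. The sketch below uses NumPy only. The grid construction and the parameter values are arbitrary choices for illustration.

```python
import numpy as np

amplitude, b_factor = 2.0, 10.0  # kappa and beta (beta in units of length squared)
k = np.fft.fftfreq(8, d=1.5)     # frequencies in inverse length units
kx_grid, ky_grid = np.meshgrid(k, k, indexing="xy")
k_squared = kx_grid**2 + ky_grid**2

# P(k) = kappa * exp(-beta * k^2 / 4)
gaussian = amplitude * np.exp(-b_factor * k_squared / 4.0)
```

Since b_factor multiplies a squared frequency, the product is dimensionless as long as the grid spacing and the B-factor use the same length unit.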
cryojax.ndimage.PeakedFourierGaussian(cryojax.ndimage.AbstractFourierOperator)
¤
This operator represents a gaussian with a peak at a given frequency shell.
__init__
¤
__init__(amplitude: cryojax.jax_util.FloatLike = 1.0, b_factor: cryojax.jax_util.FloatLike = 1.0, radial_peak: cryojax.jax_util.FloatLike = 0.0)
Arguments:
amplitude: The amplitude of the operator, equal to \(\kappa\) in the above equation.
b_factor: The B-factor of the gaussian, equal to \(\beta\) in the above equation.
radial_peak: The frequency shell of the gaussian peak.
__call__
¤
__call__(frequency_grid: Float[Array, 'y_dim x_dim 2'] | Float[Array, 'z_dim y_dim x_dim 3']) -> Float[Array, 'y_dim x_dim'] | Float[Array, 'z_dim y_dim x_dim']
cryojax.ndimage.FourierConstant(cryojax.ndimage.AbstractFourierOperator)
¤
cryojax.ndimage.FourierSinc(cryojax.ndimage.AbstractFourierOperator)
¤
The separable sinc function is the Fourier transform of the box function and is commonly used for anti-aliasing applications. In 2D, this is
\[\sinc(w q_x) \, \sinc(w q_y),\]
and in 3D this is
\[\sinc(w q_x) \, \sinc(w q_y) \, \sinc(w q_z),\]
where \(\sinc(x) = \frac{\sin(\pi x)}{\pi x}\), \(\vec{q} = (q_x, q_y)\) or \(\vec{q} = (q_x, q_y, q_z)\) are spatial frequency coordinates for 2D and 3D respectively, and \(w\) is the width of the real-space box function.
__init__
¤
__init__(box_width: cryojax.jax_util.FloatLike = 1.0)
Arguments:
box_width: If the inverse fourier transform of this class is the rectangular function, its interval is - box_width / 2 to + box_width / 2.
__call__
¤
__call__(frequency_grid: Float[Array, 'y_dim x_dim 2'] | Float[Array, 'z_dim y_dim x_dim 3']) -> Float[Array, 'y_dim x_dim'] | Float[Array, 'z_dim y_dim x_dim']
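A separable sinc can be evaluated as a product of one sinc per axis. In the NumPy sketch below, the argument of each sinc is taken to be box_width times the frequency coordinate, and no overall normalization factor is applied. Both choices are assumptions, since the source does not show the prefactor cryojax uses.

```python
import numpy as np

box_width = 2.0  # w, the width of the real-space box
q = np.fft.fftfreq(8)
qx_grid, qy_grid = np.meshgrid(q, q, indexing="xy")

# np.sinc(x) = sin(pi x) / (pi x), matching the definition in the text.
separable_sinc = np.sinc(box_width * qx_grid) * np.sinc(box_width * qy_grid)
```

With these values the first zero of the sinc falls at a frequency of 1 / box_width = 0.5, which is the Nyquist frequency of a grid with unit spacing.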
cryojax.ndimage.FourierExp2D(cryojax.ndimage.AbstractFourierOperator)
¤
This operator, in real space, represents a function equal to an exponential decay, given by
\[g(|r|) = \frac{\kappa}{2 \pi \xi^2} \exp(- |r| / \xi),\]
where \(|r| = \sqrt{x^2 + y^2}\) is a radial coordinate.
Here, \(\xi\) has dimensions of length and \(g(r)\)
has dimensions of inverse area. The power spectrum from such
a correlation function (in two-dimensions) is given by its
Hankel transform pair
\[P(|k|) = \frac{\kappa}{2 \pi \xi^3} \frac{1}{(\xi^{-2} + |k|^2)^{3/2}}.\]
Here \(\kappa\) is a scale factor and \(\xi\) is a length
scale.
__init__
¤
__init__(amplitude: cryojax.jax_util.FloatLike = 1.0, length_scale: cryojax.jax_util.FloatLike = 1.0)
Arguments:
amplitude: The amplitude of the operator, equal to \(\kappa\) in the above equation.
length_scale: The length scale of the operator, equal to \(\xi\) in the above equation.
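The power spectrum \(P(|k|)\) given above can be evaluated directly. This NumPy sketch uses arbitrary parameter values and a 1D array of radial frequencies for brevity. It transcribes the documented formula and makes no claim about how cryojax evaluates it internally.

```python
import numpy as np

amplitude, length_scale = 1.0, 3.0  # kappa and xi
k = np.linspace(0.0, 0.5, 6)        # radial frequencies |k|

# P(|k|) = kappa / (2 pi xi^3) * (xi^-2 + |k|^2)^(-3/2)
power = (
    amplitude
    / (2 * np.pi * length_scale**3)
    / (length_scale**-2 + k**2) ** 1.5
)
```

At zero frequency the factors of the length scale cancel, so the value there depends only on the amplitude. The spectrum then decays monotonically with increasing frequency.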
cryojax.ndimage.FourierDC(cryojax.ndimage.AbstractFourierOperator)
¤
This operator returns a constant in the DC component.
cryojax.ndimage.CustomFourierOperator(cryojax.ndimage.AbstractFourierOperator)
¤
An operator that calls a custom function.
__init__
¤
__init__(fn: Callable[..., Inexact[Array, 'y_dim x_dim'] | Inexact[Array, 'z_dim y_dim x_dim']], *args: Any, **kwargs: Any)
Arguments:
fn: The Callable wrapped into an AbstractFourierOperator. Has signature out = fn(frequency_grid, *args, **kwargs).
args: Passed to fn.
kwargs: Passed to fn.
__call__
¤
__call__(frequency_grid: Float[Array, 'y_dim x_dim 2'] | Float[Array, 'z_dim y_dim x_dim 3']) -> Inexact[Array, 'y_dim x_dim'] | Inexact[Array, 'z_dim y_dim x_dim']
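The calling convention out = fn(frequency_grid, *args, **kwargs) can be illustrated with a plain-Python stand-in. The class SketchCustomOperator and the function lorentzian below are hypothetical names invented for this example. They are not part of cryojax, and the real class presumably integrates with the rest of the library in ways this sketch does not.

```python
import numpy as np


class SketchCustomOperator:
    """Hypothetical stand-in: stores fn with extra arguments and calls it on a grid."""

    def __init__(self, fn, *args, **kwargs):
        self.fn = fn
        self.args = args
        self.kwargs = kwargs

    def __call__(self, frequency_grid):
        # The grid is always the first argument; stored arguments follow.
        return self.fn(frequency_grid, *self.args, **self.kwargs)


def lorentzian(frequency_grid, width, amplitude=1.0):
    """Example function with the required signature (illustrative only)."""
    k_squared = np.sum(frequency_grid**2, axis=-1)
    return amplitude / (1.0 + k_squared / width**2)


k = np.fft.fftfreq(4)
frequency_grid = np.stack(np.meshgrid(k, k, indexing="xy"), axis=-1)
operator = SketchCustomOperator(lorentzian, 0.25, amplitude=3.0)
values = operator(frequency_grid)  # shape (4, 4)
```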
Real-space¤
cryojax.ndimage.AbstractRealOperator
¤
The base class for all real operators.
By convention, operators should be defined to have units of inverse area (up to a scale factor).
To create a subclass,
1) Include the necessary parameters in
the class definition.
2) Overwrite the `__call__` method.
__call__
¤
__call__(coordinate_grid: Float[Array, 'y_dim x_dim 2'] | Float[Array, 'z_dim y_dim x_dim 3']) -> Inexact[Array, 'y_dim x_dim'] | Float[Array, 'z_dim y_dim x_dim']
cryojax.ndimage.RealGaussian(cryojax.ndimage.AbstractRealOperator)
¤
This operator is a normalized gaussian in real space
\[g(r) = \frac{\kappa}{2 \pi \sigma} \exp(- (r - r_0)^2 / (2 \sigma)),\]
where \(r^2 = x^2 + y^2\).
__init__
¤
__init__(amplitude: cryojax.jax_util.FloatLike = 1.0, variance: cryojax.jax_util.FloatLike = 1.0, offset: Float[NDArrayLike, '_'] | Sequence[float] | None = None)
Arguments:
amplitude: The amplitude of the operator, equal to \(\kappa\) in the above equation.
variance: The variance of the gaussian, equal to \(\sigma\) in the above equation.
offset: An offset to the origin, equal to \(r_0\) in the above equation.
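As a numerical check of what "normalized" means here, the NumPy sketch below evaluates a 2D gaussian of the assumed form \(\frac{\kappa}{2 \pi \sigma} \exp(-|r - r_0|^2 / (2 \sigma))\), with \(\sigma\) a variance, and confirms that it sums to the amplitude on a grid with unit spacing. The functional form, the (x, y) ordering of the offset, and the grid extent are assumptions for illustration.

```python
import numpy as np

amplitude, variance = 1.0, 4.0   # kappa and sigma (a variance, in length squared)
offset = np.array([1.0, -2.0])   # r_0, assumed ordered as (x, y)

axis = np.arange(-32, 33, dtype=float)
x, y = np.meshgrid(axis, axis, indexing="xy")
r_squared = (x - offset[0]) ** 2 + (y - offset[1]) ** 2

# Assumed form: g(r) = kappa / (2 pi sigma) * exp(-|r - r_0|^2 / (2 sigma))
gaussian = amplitude / (2 * np.pi * variance) * np.exp(-r_squared / (2 * variance))

# With unit grid spacing, the sum approximates the integral over the plane.
total = gaussian.sum()
peak_index = np.unravel_index(gaussian.argmax(), gaussian.shape)
```

The peak sits at the pixel whose coordinates equal the offset, and total is close to the amplitude because the grid comfortably contains the gaussian.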