Simulate an image
This tutorial demonstrates how to get started simulating an image with cryojax.
# Jax imports
import jax
# Plotting imports and function definitions
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
def plot_image(image, fig, ax, cmap="gray", label=None, **kwargs):
    """Draw `image` on `ax` and attach a colorbar to the right of the axes.

    Parameters
    ----------
    image :
        The image array to display; passed straight to `ax.imshow`.
    fig :
        The matplotlib figure that owns `ax`; used to draw the colorbar.
    ax :
        The matplotlib axes to draw into.
    cmap :
        Colormap forwarded to `ax.imshow`.
    label :
        Optional axes title. No title is set when this is `None`.
    **kwargs :
        Extra keyword arguments forwarded to `ax.imshow`.

    Returns
    -------
    The same `(fig, ax)` pair that was passed in.
    """
    # `origin="lower"` places array index (0, 0) at the bottom-left corner.
    mappable = ax.imshow(image, cmap=cmap, origin="lower", **kwargs)
    # Split a thin axes (5% width) off the right edge of `ax` to hold the colorbar.
    colorbar_ax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    fig.colorbar(mappable, cax=colorbar_ax)
    if label is not None:
        ax.set(title=label)
    return fig, ax
# Import the cryoJAX simulator
import cryojax.simulator as cxs
This tutorial starts by instantiating the representation of the volume: a Fourier voxel grid representation of the protein's electrostatic potential. See the tutorial on rendering volumes from atomic models for more information.
from cryojax.io import read_array_from_mrc
# Scattering potential stored in MRC format
filename = "./data/groel_5w0s_scattering_potential.mrc"
# ... read the real-space voxel grid together with its voxel size, then build
# a `FourierVoxelGridVolume` from it
real_voxel_grid, voxel_size = read_array_from_mrc(filename, loads_grid_spacing=True)
# NOTE(review): `pad_scale=2` presumably zero-pads the real-space grid by a
# factor of 2 before the Fourier transform -- confirm in the cryojax docs.
volume = cxs.FourierVoxelGridVolume.from_real_voxel_grid(real_voxel_grid, pad_scale=2)
Next, instantiate the image formation parameters. This includes parameters for the pose, CTF, and image configuration.
# Now, instantiate the pose. Angles are given in degrees and the in-plane
# offsets in Angstroms
pose = cxs.EulerAnglePose(
    phi_angle=20.0,
    theta_angle=80.0,
    psi_angle=-5.0,
    offset_x_in_angstroms=5.0,
    offset_y_in_angstroms=-3.0,
)
# Then the contrast transfer theory: an astigmatic CTF combined with a
# 10% amplitude contrast ratio
ctf = cxs.AstigmaticCTF(
    defocus_in_angstroms=10000.0,
    astigmatism_in_angstroms=-100.0,
    astigmatism_angle=10.0,
)
transfer_theory = cxs.ContrastTransferTheory(ctf, amplitude_contrast_ratio=0.1)
# Then the configuration. Pad the simulation out to the in-plane shape of the
# volume to avoid periodic artifacts due to CTF rings
pad_options = {"shape": volume.shape[:2]}
image_config = cxs.BasicImageConfig(
    shape=(80, 80),
    pixel_size=voxel_size,
    voltage_in_kilovolts=300.0,
    pad_options=pad_options,
)
Finally, instantiate the image model using cryojax.simulator.make_image_model.
What is an image_model?
These are subclasses of the abstract base class AbstractImageModel. CryoJAX uses equinox.Modules for python classes, which implement PyTorch-like syntax while maintaining smooth integration with JAX functional programming. To learn more, see here.
The make_image_model function is just a convenience wrapper to construct an AbstractImageModel, the core cryoJAX class for image simulation. The important things to know about it are that 1) it contains the parameters necessary for simulating an image, such as the pose, the volume, and the CTF, and 2) images are simulated with image_model.simulate().
# Make the image model. By default, cryoJAX simulates the contrast in
# physical units; here, instead, normalize the simulated image.
image_model = cxs.make_image_model(
    volume,
    image_config,
    pose,
    transfer_theory,
    normalizes_signal=True,
)
# Inspect the resulting model and its parameters
print(image_model)
LinearImageModel(
volume_parametrization=FourierVoxelGridVolume(
fourier_voxel_grid=c64[160,160,160],
frequency_slice_in_pixels=f32[1,160,160,3]
),
pose=EulerAnglePose(
offset_in_angstroms=f32[2],
phi_angle=f32[],
theta_angle=f32[],
psi_angle=f32[]
),
volume_integrator=AutoVolumeProjection(options={}),
transfer_theory=ContrastTransferTheory(
ctf=AstigmaticCTF(
defocus_in_angstroms=f32[],
astigmatism_in_angstroms=f32[],
astigmatism_angle=f32[],
spherical_aberration_in_mm=f32[]
),
envelope=None,
amplitude_contrast_ratio=weak_f32[],
phase_shift=f32[]
),
image_config=BasicImageConfig(
shape=(80, 80),
pixel_size=f32[],
voltage_in_kilovolts=f32[],
grid_helper=None,
pad_options={'shape': (160, 160), 'grid_helper': None}
),
transform=None,
normalizes_signal=True,
signal_region=None,
signal_centering='mean',
translate_mode='fft'
)
Now, in JAX-style functional programming, we need to define a function to simulate an image. We think of the image_model as a collection of arguments to pass to a function for image simulation.
import equinox as eqx
@eqx.filter_jit
def simulate_fn(image_model):
    """Simulate the image described by `image_model`.

    The image model is passed in as an argument, rather than captured from the
    enclosing scope, so that the function can be JIT-compiled with
    `eqx.filter_jit` and reused for different parameter values.
    """
    image = image_model.simulate()
    return image
What's with the eqx.filter_jit?
This is an example of a JAX transformation for JIT compilation (i.e. jax.jit). If you aren't familiar with jax.jit, then start by reading the JAX documentation. In particular, eqx.filter_jit is a lightweight Equinox wrapper of jax.jit, called a filtered transformation. To learn more, including about JAX transformations for automatic differentiation and vectorization, see the next cryoJAX tutorial.
# Simulate the image, then display it with a colorbar
image = simulate_fn(image_model)
fig, ax = plt.subplots(figsize=(3, 3))
im1 = plot_image(image, fig, ax, label="Image contrast")
plt.tight_layout()

Alternatively, the user can simulate an image with noise from a particular statistical distribution. In this example, we use the GaussianWhiteNoiseModel.
Here, we directly control the image SNR by transforming the final image output with a cryojax.ndimage.ScaleImage.
import cryojax.ndimage as im
import equinox as eqx
import numpy as np
from jaxtyping import PRNGKeyArray
@eqx.filter_jit
def simulate_fn_noisy(key: PRNGKeyArray, noise_model: cxs.AbstractNoiseModel):
    """Simulate an image with noise from a `noise_model`.

    Parameters
    ----------
    key :
        JAX PRNG key used to draw the noise realization.
    noise_model :
        The noise model to sample from (e.g. a `cxs.GaussianWhiteNoiseModel`
        wrapping an image model).

    Returns
    -------
    The noisy image returned by `noise_model.sample(key)`.
    """
    return noise_model.sample(key)
# Instantiate a `noise_model` and simulate
rng_key = jax.random.key(seed=0)
snr = 0.1
# The signal is normalized and then scaled by sqrt(snr). With the unit noise
# variance below, this targets an image SNR of `snr`. NOTE(review): this
# assumes `normalizes_signal=True` gives the signal unit variance -- confirm.
image_model = cxs.make_image_model(
    volume,
    image_config,
    pose,
    transfer_theory,
    normalizes_signal=True,
    transform=im.ScaleImage(scale=np.sqrt(snr)),
)
noise_model = cxs.GaussianWhiteNoiseModel(image_model, variance=1.0)
# ... then, plot the underlying image next to an image simulated from this
# `noise_model`
fig, axes = plt.subplots(ncols=2, figsize=(6, 3))
ax1, ax2 = axes
im1 = plot_image(
    simulate_fn(image_model),
    fig,
    ax1,
    label="Underlying image",
)
im2 = plot_image(
    simulate_fn_noisy(rng_key, noise_model),
    fig,
    ax2,
    label="Image with white noise",
)
plt.tight_layout()
