Structural refinement using gradient descent

Automatic differentiation using JAX is a powerful tool for conducting high-dimensional data analysis. This tutorial demonstrates how automatic differentiation may be leveraged for structure refinement.

The numerical experiment performed here can be described in 3 steps:

  1. Simulate a synthetic dataset of a protein in a "ground truth" conformation
  2. Define an image simulation model using a different conformation
  3. Recover the ground truth by optimizing the atom positions using gradient descent.

We will use thyroglobulin for this example, using structures modified from the Flatiron Institute cryo-EM conformational heterogeneity challenge.

References:

  • Astore, Miro A., et al. "The Inaugural Flatiron Institute Cryo-EM Conformational Heterogeneity Challenge." bioRxiv (2025): 2025-07.

We will start by simulating a synthetic dataset.

# Imports for dataset generation
import math

import cryojax.simulator as cxs
import equinox as eqx
import jax
import jax.numpy as jnp
from cryojax.io import read_atoms_from_pdb
from cryojax.ndimage import transforms as tf
from cryojax.rotations import SO3
from jaxtyping import Array, PRNGKeyArray
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable


def plot_image_stack(images, cmap="gray", **kwargs):
    """Plot the leading images of a stack on a square grid.

    All panels share one color scale, taken from the minimum and maximum
    of the whole stack that is passed in.

    Args:
        images: Stack of 2D images indexed along its first axis. Only the
            first `k * k` images are drawn, where `k` is the integer
            square root of `images.shape[0]`.
        cmap: Colormap forwarded to `imshow`.
        **kwargs: Extra keyword arguments forwarded to `imshow`.

    Raises:
        ValueError: If `images` contains no images.
    """
    n_images_per_side = math.isqrt(images.shape[0])
    if n_images_per_side == 0:
        raise ValueError("`images` must contain at least one image.")
    # `squeeze=False` always returns a 2D array of axes; without it a
    # single-image stack yields a bare `Axes`, which has no `.ravel()`.
    fig, axes = plt.subplots(
        nrows=n_images_per_side, ncols=n_images_per_side, squeeze=False
    )
    vmin, vmax = images.min(), images.max()
    for idx, ax in enumerate(axes.ravel()):
        im = ax.imshow(
            images[idx], cmap=cmap, vmin=vmin, vmax=vmax, origin="lower", **kwargs
        )
        # Attach a narrow colorbar to the right of each panel
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        fig.colorbar(im, cax=cax)
    fig.tight_layout()


# Dataset generation parameters

# -- Dataset size, randomness and noise level
SIMULATE_RNG_SEED = 1234  # seed of the key that is split once per particle
N_PARTICLES = 200  # number of simulated images
SNR = 0.1  # its square root scales the signal against unit-variance noise

# -- Image geometry
IMAGE_DIM = 200  # images are IMAGE_DIM x IMAGE_DIM pixels
PIXEL_SIZE = 2.0  # passed as `pixel_size` to the image config
PAD_SCALE = 1.5  # padded image side is int(PAD_SCALE * IMAGE_DIM)

# -- Microscope and contrast transfer function
VOLTAGE_IN_KILOVOLTS = 300.0
SPHERICAL_ABERRATION_IN_MM = 2.7
AMPLITUDE_CONTRAST_RATIO = 0.1
DEFOCUS_RANGE = (5_000, 15_000)  # (min, max) of the uniformly sampled defocus

# -- Volume model
GMM_WIDTH_IN_ANGSTROMS = 3.0  # squared to give the variance of each gaussian


def make_parameters(
    rng_key: PRNGKeyArray, image_config: cxs.BasicImageConfig
) -> tuple[cxs.EulerAnglePose, cxs.ContrastTransferTheory]:
    """Draw a random pose and a random-defocus CTF for a single particle.

    The rotation is sampled with `SO3.sample_uniform`. The in-plane offset
    is uniform within +/-10% of half the image extent along each axis,
    scaled by the pixel size. The defocus is uniform over `DEFOCUS_RANGE`;
    the remaining CTF settings come from the module-level constants.
    """
    # One fresh subkey per random draw. The order of the splits determines
    # the random stream, so it must not be rearranged.
    carry_key, rotation_key = jax.random.split(rng_key)
    carry_key, offset_key = jax.random.split(carry_key)
    _, defocus_key = jax.random.split(carry_key)

    # Pose: uniformly random orientation plus a small random shift
    rotation = SO3.sample_uniform(rotation_key)
    n_rows, n_cols = image_config.shape
    fractional_shift = jax.random.uniform(offset_key, (2,), minval=-0.1, maxval=0.1)
    offset_in_angstroms = (
        fractional_shift * jnp.asarray((n_cols, n_rows)) / 2 * image_config.pixel_size
    )
    pose = cxs.EulerAnglePose.from_rotation_and_translation(rotation, offset_in_angstroms)

    # CTF: only the defocus is randomized
    defocus_min, defocus_max = DEFOCUS_RANGE
    defocus_in_angstroms = jax.random.uniform(
        defocus_key, (), minval=defocus_min, maxval=defocus_max, dtype=float
    )
    ctf = cxs.AstigmaticCTF(
        defocus_in_angstroms=defocus_in_angstroms,
        spherical_aberration_in_mm=SPHERICAL_ABERRATION_IN_MM,
    )
    transfer_theory = cxs.ContrastTransferTheory(
        ctf=ctf, amplitude_contrast_ratio=AMPLITUDE_CONTRAST_RATIO
    )

    return pose, transfer_theory


@eqx.filter_jit
@eqx.filter_vmap(in_axes=(0, None, None))
def simulate_dataset(
    rng_key: PRNGKeyArray,
    volume: cxs.GaussianMixtureVolume,
    image_config: cxs.BasicImageConfig,
) -> tuple[Array, dict, tf.CircularCosineMask]:
    """Simulate noisy, masked particle images, one per RNG key.

    The function is vmapped over the leading axis of `rng_key` only;
    `volume` and `image_config` are shared by every particle. For each key
    it returns the sampled image, a dict holding the `pose`,
    `transfer_theory` and `image_config` used, and the circular mask
    centered on the particle's in-plane offset.
    """
    # Independent keys for the particle parameters and for the noise
    parameter_key, noise_key = jax.random.split(rng_key, num=2)
    pose, transfer_theory = make_parameters(parameter_key, image_config)

    # Soft-edged circular mask that follows the particle's offset
    mask = tf.CircularCosineMask(
        image_config.coordinate_grid_in_angstroms,
        radius=150.0,
        rolloff_width=10.0,
        xy_offset=pose.offset_in_angstroms,
    )
    # Pixels where the mask is exactly one are handed to the image model
    # as the signal region
    signal_region = mask.get() == 1.0

    image_model = cxs.make_image_model(
        volume,
        image_config,
        pose,
        transfer_theory,
        normalizes_signal=True,
        signal_region=signal_region,
    )
    noise_model = cxs.GaussianWhiteNoiseModel(
        image_model, variance=1.0, signal_scale_factor=jnp.sqrt(SNR)
    )

    noisy_image = noise_model.sample(noise_key, mask=mask)
    particle_parameters = {
        "pose": pose,
        "transfer_theory": transfer_theory,
        "image_config": image_config,
    }
    return noisy_image, particle_parameters, mask


# Build the ground-truth volume as a gaussian mixture. For simplicity,
# only the C-alpha atoms are kept.
ground_truth_positions, _ = read_atoms_from_pdb(
    "./data/thyroglobulin_target.pdb", selection_string="name CA"
)
n_atoms = ground_truth_positions.shape[0]
volume = cxs.GaussianMixtureVolume(
    ground_truth_positions, amplitudes=1.0, variances=GMM_WIDTH_IN_ANGSTROMS**2
)

# Image configuration; the padded shape is PAD_SCALE times the image shape
padded_dim = int(PAD_SCALE * IMAGE_DIM)
pad_options = {"shape": (padded_dim, padded_dim)}
image_config = cxs.BasicImageConfig(
    shape=(IMAGE_DIM, IMAGE_DIM),
    pixel_size=PIXEL_SIZE,
    voltage_in_kilovolts=VOLTAGE_IN_KILOVOLTS,
    pad_options=pad_options,
)

# One RNG key per particle, then simulate the whole dataset in one call
root_key = jax.random.key(SIMULATE_RNG_SEED)
rng_keys = jax.random.split(root_key, num=N_PARTICLES)
dataset_images, dataset_parameters, masks = simulate_dataset(
    rng_keys, volume, image_config
)

# Show the first nine simulated images
plot_image_stack(dataset_images[:9])

img

Next, define the loss function. We will naively do a gradient descent on the positions of the C-alpha atoms to see if we can capture the large-scale difference of the conformational change.

@eqx.filter_value_and_grad
def compute_loss(atom_positions, args):
    """Cross-correlation loss summed over a batch of images.

    `args` is a tuple `(images, particle_parameters, masks)` whose array
    leaves carry a leading batch axis. For every particle, an image is
    simulated from `atom_positions` with that particle's parameters and
    mask, and the loss is the negative sum of its pixelwise product with
    the observed image. Because of the decorator, calling this function
    returns both the loss and its gradient.
    """

    # `atom_positions` is shared by all particles; every array leaf of the
    # per-particle data is mapped over its leading axis.
    @eqx.filter_vmap(in_axes=(None, eqx.if_array(0)))
    def negative_correlation(positions, particle_data):
        observed_image, parameters, particle_mask = particle_data
        candidate_volume = cxs.GaussianMixtureVolume(
            positions,
            amplitudes=1.0,
            variances=GMM_WIDTH_IN_ANGSTROMS**2,
        )
        model = cxs.make_image_model(
            candidate_volume,
            image_config=parameters["image_config"],
            pose=parameters["pose"],
            transfer_theory=parameters["transfer_theory"],
            normalizes_signal=True,
            signal_region=particle_mask.get() == 1.0,
        )
        simulated_image = model.simulate(mask=particle_mask)
        return -jnp.sum(simulated_image * observed_image)

    per_particle_losses = negative_correlation(atom_positions, args)
    return jnp.sum(per_particle_losses)

Finally, perform the gradient descent. To do so, we will implement a stochastic gradient descent, choosing random subsets of the dataset for each gradient step. This will require defining what is called a dataloader, which allows us to loop through these random subsets.

We will do this in a naive way, but to learn about dataloaders in general, see PyTorch's implementation. At the time of writing, the JAX ecosystem has a dataloader library in development called grain.

import numpy as np
import optax
from cryojax.dataset import AbstractDataset


class PyTreeDataset(AbstractDataset):
    """Dataset as an arbitrary in-memory PyTree, where all arrays
    have a batch dimension.

    The pytree is split with `eqx.partition(pytree, eqx.is_array)`. Array
    leaves are indexed along their leading axis by `__getitem__`; every
    other leaf is returned unchanged with each item.
    """

    def __init__(self, pytree):
        """Store `pytree` and check that its arrays share one batch size.

        Raises:
            ValueError: If `pytree` has no array leaves, if any array leaf
                is zero-dimensional, or if the leading dimensions of the
                array leaves differ.
        """
        self.pytree_dynamic, self.pytree_static = eqx.partition(pytree, eqx.is_array)
        leaves = jax.tree.leaves(self.pytree_dynamic)
        # Explicit exceptions instead of `assert`, which is skipped when
        # Python runs with `-O`.
        if not leaves:
            raise ValueError("`pytree` must contain at least one array leaf.")
        if any(leaf.ndim == 0 for leaf in leaves):
            raise ValueError(
                "Every array leaf of `pytree` must have a leading batch dimension; "
                "found a zero-dimensional array."
            )
        batch_sizes = {leaf.shape[0] for leaf in leaves}
        if len(batch_sizes) != 1:
            raise ValueError(
                "All array leaves of `pytree` must share the same leading "
                f"dimension; found sizes {sorted(batch_sizes)}."
            )

    def __getitem__(self, index):
        """Return the pytree with every array leaf indexed by `index`."""
        return eqx.combine(
            jax.tree.map(lambda x: x[index], self.pytree_dynamic), self.pytree_static
        )

    def __len__(self) -> int:
        """Return the leading dimension of the first array leaf."""
        return jax.tree.leaves(self.pytree_dynamic)[0].shape[0]


def dataloader(dataset, batch_size, rng=None):
    """A simple dataloader that splits the dataset into
    random subsets of size `batch_size`.

    The returned iterator never terminates: each pass draws a fresh
    permutation of the indices and yields `dataset[batch_indices]` for
    consecutive slices of it. A trailing slice shorter than `batch_size`
    is dropped.

    Args:
        dataset: Object supporting `len()` and indexing with an integer
            index array.
        batch_size: Number of items per batch; must satisfy
            `1 <= batch_size <= len(dataset)`.
        rng: Optional `numpy.random.Generator` used for shuffling, for
            reproducible batches. If `None`, the global NumPy random
            state is used.

    Raises:
        ValueError: If `batch_size` is outside `[1, len(dataset)]`. This
            is checked when `dataloader` is called, not on first iteration.
    """
    dataset_size = len(dataset)
    # Without this check an oversized batch would loop forever without
    # yielding, and a zero batch size would yield empty batches forever.
    if not 1 <= batch_size <= dataset_size:
        raise ValueError(
            f"`batch_size` must be between 1 and the dataset size "
            f"({dataset_size}), got {batch_size}."
        )
    return _iterate_batches(dataset, dataset_size, batch_size, rng)


def _iterate_batches(dataset, dataset_size, batch_size, rng):
    """Yield shuffled, full-size batches of `dataset` indefinitely."""
    indices = np.arange(dataset_size)
    permute = np.random.permutation if rng is None else rng.permutation
    while True:
        perm = permute(indices)
        # Only start positions that leave room for a full batch
        for start in range(0, dataset_size - batch_size + 1, batch_size):
            yield dataset[perm[start : start + batch_size]]


# Optimisation parameters
NUM_STEPS = 100  # total number of gradient steps
BATCH_SIZE = 20  # images drawn from the dataset per gradient step
PRINT_EVERY = 10  # the loss is printed every this many steps
# AdaBelief optimizer from optax, with the `nesterov` option enabled
OPTIM = optax.adabelief(nesterov=True, learning_rate=0.5)


# Loss, gradient, optimizer update and parameter update are compiled
# together as a single JIT region.
@eqx.filter_jit(donate="all")
def make_step(atom_positions, opt_state, args):
    """Run one gradient step and return `(loss, new_positions, new_opt_state)`.

    `args` is the data batch forwarded to `compute_loss`. The function is
    jitted with `donate="all"`, so the arrays passed in should not be
    reused by the caller afterwards.
    """
    loss_value, gradient = compute_loss(atom_positions, args)
    step_updates, new_opt_state = OPTIM.update(gradient, opt_state)
    new_positions = eqx.apply_updates(atom_positions, step_updates)
    return loss_value, new_positions, new_opt_state


# Initialize optimization using different conformation
initial_positions, _ = read_atoms_from_pdb(
    "./data/thyroglobulin_initial.pdb", selection_string="name CA"
)

# Run optimization
pytree = (dataset_images, dataset_parameters, masks)
dataset = PyTreeDataset(pytree)
# `make_step` is jitted with `donate="all"`, so the buffers handed to it may
# be invalidated. Optimize a private copy so that `initial_positions` stays
# usable for the before/after comparisons plotted later.
opt_positions = jnp.array(initial_positions)
opt_state = OPTIM.init(opt_positions)
loss_values = []
for step, data_batch in zip(range(NUM_STEPS), dataloader(dataset, BATCH_SIZE)):
    loss, opt_positions, opt_state = make_step(opt_positions, opt_state, data_batch)
    # Pull the scalar to the host once and reuse it for logging
    loss_value = loss.item()
    loss_values.append(loss_value)
    # Step 0 is already covered by the modulo test; also report the last step
    if step % PRINT_EVERY == 0 or step == NUM_STEPS - 1:
        print(f"step={step}, loss={loss_value}")

fig, ax = plt.subplots(figsize=(5, 3))
# Index by the number of recorded losses so x and y always have equal length
ax.plot(range(len(loss_values)), loss_values)
ax.set(xlabel="step number", ylabel="cross-correlation loss")
step=0, loss=-82785.3125
step=10, loss=-106068.6015625
step=20, loss=-109953.0234375
step=30, loss=-112941.578125
step=40, loss=-112299.015625
step=50, loss=-112578.9609375
step=60, loss=-112034.578125
step=70, loss=-113446.4375
step=80, loss=-111717.9375
step=90, loss=-113332.6484375
step=99, loss=-112113.4765625
[Text(0.5, 0, 'step number'), Text(0, 0.5, 'cross-correlation loss')]

img

Now let's check the results! Let's first look at FSC curves between the ground truth and structure before and after refinement.

from cryojax.coordinates import make_radial_frequency_grid
from cryojax.ndimage import compute_fourier_shell_correlation, rfftn


def plot_fsc(pos1, pos2, ax, **kwargs):
    """Plot the Fourier shell correlation between two sets of atom positions.

    Each set of positions is rendered onto a cubic voxel grid of side
    `IMAGE_DIM` and voxel size `PIXEL_SIZE` as a gaussian mixture with the
    same amplitudes and variances used elsewhere in this tutorial. The FSC
    of the two Fourier-transformed volumes is then drawn on `ax`.

    Args:
        pos1: Atom positions of the first structure.
        pos2: Atom positions of the second structure.
        ax: Matplotlib axes to draw on.
        **kwargs: Extra keyword arguments forwarded to `ax.plot`.
    """
    shape, voxel_size = (IMAGE_DIM, IMAGE_DIM, IMAGE_DIM), PIXEL_SIZE
    gaussian_render_fn = cxs.GaussianMixtureRenderFn(shape, voxel_size)

    def compute_fourier_voxels(pos):
        # Render the positions as a real-space volume, then transform it
        volume = cxs.GaussianMixtureVolume(
            pos,
            amplitudes=1.0,
            variances=GMM_WIDTH_IN_ANGSTROMS**2,
        )
        return rfftn(gaussian_render_fn(volume))

    frequencies = make_radial_frequency_grid(shape)
    fsc, frequency_bins, _ = compute_fourier_shell_correlation(
        compute_fourier_voxels(pos1),
        compute_fourier_voxels(pos2),
        frequencies,
        maximum_frequency=0.5,
    )
    # NOTE(review): the frequency grid is built without a grid spacing and
    # capped at 0.5, i.e. it appears to be in cycles per voxel (assumes
    # `make_radial_frequency_grid` defaults to unit spacing -- confirm).
    # Divide by the voxel size so the x-axis is in inverse angstroms, which
    # is how the figure's axis is labelled.
    ax.plot(frequency_bins / voxel_size, fsc, **kwargs)


# FSC against the ground truth, before and after refinement, on one axes
fig = plt.figure(figsize=(5, 3))
ax = plt.gca()
fsc_comparisons = (
    (initial_positions, "Before optimization", "orange"),
    (opt_positions, "After optimization", "red"),
)
for positions, curve_label, curve_color in fsc_comparisons:
    plot_fsc(
        ground_truth_positions, positions, ax, label=curve_label, color=curve_color
    )
ax.set(xlabel=r"frequency magnitude [$\AA^{-1}$]", ylabel="fourier shell correlation")
ax.legend()
plt.tight_layout()

img

The above FSC curves quantify the increase in correlation after gradient descent refinement. Indeed, visually the refinement is able to (roughly) recover the ground truth conformation!

def plot_cross_sections(
    pos1, pos2, axes, label1=None, label2=None, color1="b", color2="orange"
):
    """Scatter two point clouds onto three coordinate-plane projections.

    The three panels show, as (horizontal, vertical) axes: (y, x) on
    `axes[0]`, (x, z) on `axes[1]` and (y, z) on `axes[2]`. All panels use
    fixed limits of (-175, 175) on both axes.

    Args:
        pos1: First point cloud, indexed as `pos[:, k]` for coordinate `k`.
        pos2: Second point cloud, indexed the same way.
        axes: Sequence of three matplotlib axes.
        label1: Legend label for `pos1`; shown on the first panel only.
        label2: Legend label for `pos2`; shown on the first panel only.
        color1: Marker color for `pos1`.
        color2: Marker color for `pos2`.

    Returns:
        The tuple `(fig, axes)`, where `fig` is the figure owning `axes[0]`.
    """
    scatter_style = dict(marker=".", alpha=0.1)
    xlim, ylim = (-175, 175), (-175, 175)
    # (horizontal coordinate, vertical coordinate, xlabel, ylabel) per panel
    panels = ((1, 0, "y", "x"), (0, 2, "x", "z"), (1, 2, "y", "z"))

    for pos, label, color in zip((pos1, pos2), (label1, label2), (color1, color2)):
        for panel_idx, (h, v, _, _) in enumerate(panels):
            # Only the first panel carries legend labels
            extra = {"label": label} if panel_idx == 0 else {}
            axes[panel_idx].scatter(
                pos[:, h], pos[:, v], c=color, **extra, **scatter_style
            )

    # Axis labels and limits do not depend on the data, so set them once
    for panel_idx, (_, _, xlabel, ylabel) in enumerate(panels):
        axes[panel_idx].set(xlabel=xlabel, ylabel=ylabel, xlim=xlim, ylim=ylim)

    # Skip the legend when there is nothing to show in it
    if label1 is not None or label2 is not None:
        axes[0].legend(loc="upper left")

    # Return the figure that owns the axes rather than a module-level `fig`
    return axes[0].figure, axes


# Top row: ground truth vs. initial positions.
# Bottom row: ground truth vs. optimized positions.
fig, axes = plt.subplots(figsize=(10, 6), ncols=3, nrows=2)
cross_section_rows = (
    (initial_positions, "Initial positions", "orange"),
    (opt_positions, "Optimized positions", "red"),
)
for row_axes, (positions, row_label, row_color) in zip(axes, cross_section_rows):
    plot_cross_sections(
        ground_truth_positions,
        positions,
        axes=row_axes,
        label1="Ground truth",
        label2=row_label,
        color2=row_color,
    )
fig.suptitle("Central cross sections of GMM positions")
plt.tight_layout()

img

Interactive plots of the 3D point clouds can be generated locally using plotly with the following code block:

import plotly.graph_objects as go
from plotly.subplots import make_subplots


def plot_3d_point_clouds_interactive(
    pos1,
    pos2,
    label1="Ground truth",
    label2="Comparison",
    pos3=None,
    label3="Second comparison",
):
    """Create interactive 3D scatter plots with plotly.

    The reference cloud `pos1` is drawn in blue in every panel. The first
    panel compares it against `pos2` (orange); if `pos3` is given, a second
    panel compares it against `pos3` (red).

    Args:
        pos1: Reference point cloud, indexed as `pos[:, k]` for coordinate `k`.
        pos2: Point cloud compared against `pos1` in the first panel.
        label1: Legend name of `pos1`, also used in the panel titles.
        label2: Legend name of `pos2`, also used in the first panel title.
        pos3: Optional point cloud for a second comparison panel.
        label3: Legend name of `pos3`, also used in the second panel title.

    Returns:
        The plotly figure.
    """
    # All data comes from the arguments; nothing is read from module-level
    # variables.
    comparisons = [(pos2, label2, "orange")]
    if pos3 is not None:
        comparisons.append((pos3, label3, "red"))
    n_panels = len(comparisons)

    # Create subplots with 3D scene type
    fig = make_subplots(
        rows=1,
        cols=n_panels,
        subplot_titles=[f"{label1} vs {label}" for _, label, _ in comparisons],
        specs=[[{"type": "scatter3d"} for _ in range(n_panels)]],
    )

    def add_cloud(points, name, color, col, showlegend=True):
        # Add one point cloud as a marker trace to the panel in column `col`
        fig.add_trace(
            go.Scatter3d(
                x=points[:, 0],
                y=points[:, 1],
                z=points[:, 2],
                mode="markers",
                marker=dict(size=3, color=color, opacity=0.7),
                name=name,
                showlegend=showlegend,
            ),
            row=1,
            col=col,
        )

    for col, (points, label, color) in enumerate(comparisons, start=1):
        # The reference cloud appears in every panel but only once in the legend
        add_cloud(pos1, label1, "blue", col, showlegend=(col == 1))
        add_cloud(points, label, color, col)

    # Update layout: the same axis titles and aspect mode for every 3D scene.
    # Plotly names the scenes "scene", "scene2", ...
    scene_layout = dict(
        xaxis_title="X (Å)",
        yaxis_title="Y (Å)",
        zaxis_title="Z (Å)",
        aspectmode="cube",
    )
    scenes = {
        ("scene" if col == 1 else f"scene{col}"): dict(scene_layout)
        for col in range(1, n_panels + 1)
    }
    fig.update_layout(
        title="Interactive 3D Point Cloud Comparison",
        width=600 * n_panels,
        height=600,
        **scenes,
    )

    return fig


# Generate interactive plot:
fig = plot_3d_point_clouds_interactive(
    ground_truth_positions,
    initial_positions,
    label1="Ground truth",
    label2="Initial positions",
    pos3=opt_positions,
    label3="Optimized positions",
)
fig.show()