
voxelmorph.nn.modules

Neural network building blocks for VoxelMorph.

IntegrateVelocityField

IntegrateVelocityField(shape: tuple, steps: int = 1, interpolation_mode: str = 'bilinear', align_corners: bool = False, device: str = 'cpu')

Bases: Module

Integrates a velocity field over multiple steps using the scaling and squaring method.

This module ensures that the transformation produced by a velocity field is diffeomorphic by compounding small intermediate transformations (recursive scaling and squaring): the field is first scaled down by 2^steps and then repeatedly composed with itself. The result is both smooth and invertible.
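The scaling-and-squaring idea can be sketched on a 1D displacement field. The snippet below is a minimal NumPy illustration under stated assumptions, not the module's actual implementation: the function name `integrate_velocity_1d` is hypothetical, displacements are assumed to be in voxel units, and `np.interp` (linear interpolation that clamps to edge values) stands in for the spatial transformer.

```python
import numpy as np


def integrate_velocity_1d(velocity: np.ndarray, steps: int) -> np.ndarray:
    """Hypothetical 1D scaling-and-squaring sketch (not the VoxelMorph code).

    `velocity[i]` is the stationary velocity at voxel i, in voxel units.
    Returns the displacement u such that phi(x) = x + u(x).
    """
    if steps < 0:
        raise ValueError(f"steps should be >= 0, found: {steps}")

    grid = np.arange(velocity.shape[0], dtype=float)

    # Scaling: start from the small displacement v / 2**steps
    disp = velocity * (1.0 / 2 ** steps)

    for _ in range(steps):
        # Squaring: phi <- phi o phi, i.e. u(x) <- u(x) + u(x + u(x)).
        # u is sampled off-grid with linear interpolation; np.interp
        # clamps to the edge values outside [0, N - 1].
        disp = disp + np.interp(grid + disp, grid, disp)

    return disp


# A constant velocity integrates to the same constant shift
print(integrate_velocity_1d(np.full(5, 1.5), steps=7).tolist())  # [1.5, 1.5, 1.5, 1.5, 1.5]
```

Composing the field with itself `steps` times undoes the initial 1 / 2^steps scaling, so a constant velocity (a pure translation) comes back unchanged. For a smooth, non-constant velocity, the resulting map x + u(x) stays strictly increasing, which is the 1D form of invertibility.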

ATTRIBUTE DESCRIPTION
steps

The number of squaring steps used for integration.

TYPE: int

scale

Scaling factor for the initial velocity field, determined as 1 / (2^steps).

TYPE: float
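As a quick arithmetic check of this formula (a NumPy sketch; the helper name `initial_scale` is hypothetical), `steps=7` gives a factor of 1/128. Very large values shrink the factor so far that it cannot be represented in float32, the default dtype of `torch.randn`; assuming the field is multiplied by `scale` in float32, the scaled field would then be all zeros.

```python
import numpy as np


def initial_scale(steps: int) -> float:
    # Same formula as the `scale` attribute: 1 / (2 ** steps)
    return 1.0 / (2 ** steps)


print(initial_scale(7))  # 0.0078125

# 2 ** -256 is representable in float64 but underflows to 0.0 in float32
print(initial_scale(256) > 0.0)  # True
print(np.float32(initial_scale(256)) == 0.0)  # True
```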

transformer

A spatial transformer module used to iteratively warp the vector field.

TYPE: Module

Examples:

Integrate a 2D velocity field over multiple steps:
>>> import torch
>>> from voxelmorph.nn.modules import IntegrateVelocityField
>>> shape = (128, 128)  # 2D spatial grid
>>> integrator = IntegrateVelocityField(shape, steps=7)
>>> velocity_field = torch.randn(1, 2, 128, 128)  # (B, C, H, W)
>>> disp = integrator(velocity_field)
>>> disp.shape
torch.Size([1, 2, 128, 128])

Perform integration on a 3D velocity field with a single squaring step:
>>> shape = (64, 64, 64)  # 3D spatial grid
>>> integrator = IntegrateVelocityField(shape, steps=1)
>>> velocity_field = torch.randn(1, 3, 64, 64, 64)  # (B, C, D, H, W)
>>> disp = integrator(velocity_field)
>>> disp.shape
torch.Size([1, 3, 64, 64, 64])

Initialize IntegrateVelocityField

PARAMETER DESCRIPTION
shape

Shape of the input velocity field (excluding batch and channel dimensions).

TYPE: tuple

steps

Number of scaling-and-squaring steps. The velocity field is scaled by 1 / 2^steps and then squared steps times; a higher value gives a smoother, more accurate integration at the cost of more computation. Default is 1.

TYPE: int DEFAULT: 1

interpolation_mode

Algorithm used for interpolating the warped image. Default is 'bilinear'. Options are: 'bilinear' | 'nearest' | 'bicubic'.

TYPE: str DEFAULT: 'bilinear'

align_corners

Map the corner points of the moving image to the corner points of the warped image.

TYPE: bool DEFAULT: False

device

Device to construct and hold the identity grid.

TYPE: str DEFAULT: 'cpu'

Source code in voxelmorph/nn/modules.py
def __init__(
    self, shape: tuple,
    steps: int = 1,
    interpolation_mode: str = "bilinear",
    align_corners: bool = False,
    device: str = "cpu"
):
    """
    Initialize `IntegrateVelocityField`

    Parameters
    ----------
    shape : tuple
        Shape of the input velocity field (excluding batch and channel dimensions).
    steps : int, optional
        Number of scaling-and-squaring steps. A higher value leads to a smoother and more
        accurate integration at the cost of more computation. Default is 1.
    interpolation_mode : str
        Algorithm used for interpolating the warped image. Default is 'bilinear'. Options are:
        'bilinear' | 'nearest' | 'bicubic'.
    align_corners : bool
        Map the corner points of the moving image to the corner points of the warped image.
    device : str
        Device to construct and hold the identity grid.
    """

    super().__init__()

    if steps < 0:
        raise ValueError(f"steps should be >= 0, found: {steps}")

    self.steps = steps
    self.scale = 1.0 / (2 ** self.steps)  # Initial downscaling factor

    # Make the transformer which will perform the warping operation
    self.transformer = SpatialTransformer(shape, interpolation_mode, align_corners, device)

ResizeDisplacementField

ResizeDisplacementField(scale_factor: Optional[Union[float, int, Sampler]] = 1.0, interpolation_mode: str = 'bilinear', align_corners: bool = True)

Bases: Module

Resize and rescale a displacement field.

Resizes a displacement field both spatially (via interpolation) and in magnitude (via scaling), so that the displacement values stay consistent with the new grid size.
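The two parts of the resize can be sketched for a 1D field in NumPy. This is an illustration under stated assumptions, not the module's code: the function name `resize_displacement_1d` is hypothetical, `align_corners=True`-style sampling is mimicked with `np.linspace`, the output length is taken as floor(N * scale_factor), and displacements are assumed to be in voxel units.

```python
import math

import numpy as np


def resize_displacement_1d(disp: np.ndarray, scale_factor: float) -> np.ndarray:
    """Hypothetical 1D sketch: resize a displacement field and rescale its values."""
    n = disp.shape[0]
    m = math.floor(n * scale_factor)

    # Spatial resize: sample the old field at m evenly spaced positions so the
    # first and last output voxels land on the first and last input voxels
    positions = np.linspace(0.0, n - 1, m)
    resized = np.interp(positions, np.arange(n, dtype=float), disp)

    # Magnitude rescale: a shift of d voxels on the old grid spans
    # d * scale_factor voxels on the new grid
    return resized * scale_factor


disp = np.full(16, 1.0)  # every voxel moves by one voxel
print(resize_displacement_1d(disp, 2.0).shape)  # (32,)
print(resize_displacement_1d(disp, 2.0)[:3].tolist())  # [2.0, 2.0, 2.0]
```

Skipping the magnitude rescale would leave a field that is the right size but moves every voxel by the old, unscaled number of voxels.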

Examples:

Resize a 2D displacement field:
>>> import torch
>>> from voxelmorph.nn.modules import ResizeDisplacementField
>>> resize_field = ResizeDisplacementField(scale_factor=2.0, interpolation_mode="bilinear")
>>> disp = torch.rand(1, 2, 16, 16)  # Example displacement field in 2D
>>> resized_disp = resize_field(disp)
>>> print(resized_disp.shape)  # Larger than the input when scale_factor > 1
torch.Size([1, 2, 32, 32])

Instantiate the ResizeDisplacementField module.

PARAMETER DESCRIPTION
scale_factor

Factor by which to stretch or shrink the spatial dimensions of the displacement field. Values of scale_factor > 1 stretch/expand the field, and values < 1 shrink it. Default is 1.0.

TYPE: Optional[Union[float, int, Sampler]] DEFAULT: 1.0

interpolation_mode

Algorithm used for interpolating the displacement field. Default is 'bilinear'. Options are: 'bilinear' | 'nearest' | 'bicubic' | 'trilinear'.

TYPE: str DEFAULT: 'bilinear'

align_corners

Map the corner points of the moving image to the corner points of the warped image.

TYPE: bool DEFAULT: True

Source code in voxelmorph/nn/modules.py
def __init__(
    self,
    scale_factor: Optional[Union[float, int, ne.samplers.Sampler]] = 1.0,
    interpolation_mode: str = "bilinear",
    align_corners: bool = True,
):
    """
    Instantiate the `ResizeDisplacementField` module.

    Parameters
    ----------
    scale_factor : Optional[Union[float, int, Sampler]], optional
        Factor by which to stretch or shrink the spatial dimensions of the displacement field.
        Values of `scale_factor` > 1 stretch/expand the field, and values < 1 shrink it.
        Default is 1.0.
    interpolation_mode : str
        Algorithm used for interpolating the displacement field. Default is 'bilinear'. Options
        are: 'bilinear' | 'nearest' | 'bicubic' | 'trilinear'.
    align_corners : bool
        Map the corner points of the moving image to the corner points of the warped image.
    """
    super().__init__()
    self.interpolation_mode = interpolation_mode
    self.align_corners = align_corners
    self.scale_factor = ne.samplers.Fixed.make(scale_factor)

SpatialTransformer

SpatialTransformer(size: Tuple[int], interpolation_mode: str = 'bilinear', align_corners: bool = False, device: Union[str, device] = 'cpu')

Bases: Module

N-D spatial transformation according to a deformation field.

Uses a deformation field to transform the moving image.
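Conceptually, warping samples the moving image at the identity grid plus the displacement. A minimal 1D NumPy sketch of that idea follows; `warp_1d` is a hypothetical name, and out-of-range samples are clamped to the edge value by `np.interp`, which is a simplification and may differ from how the module handles the image border.

```python
import numpy as np


def warp_1d(moving: np.ndarray, disp: np.ndarray) -> np.ndarray:
    """Hypothetical 1D sketch: warped[i] = moving(i + disp[i])."""
    identity = np.arange(moving.shape[0], dtype=float)  # identity grid
    sample_at = identity + disp  # where each output voxel reads from
    # Linear interpolation; np.interp clamps to the edge values
    return np.interp(sample_at, identity, moving)


moving = np.array([0.0, 10.0, 20.0, 30.0, 40.0])
print(warp_1d(moving, np.zeros(5)).tolist())  # [0.0, 10.0, 20.0, 30.0, 40.0]
print(warp_1d(moving, np.full(5, 0.5)).tolist())  # [5.0, 15.0, 25.0, 35.0, 40.0]
```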

References

If you find this helpful, please cite the following paper:

G. Balakrishnan, A. Zhao, M. R. Sabuncu, J. Guttag, A. V. Dalca. VoxelMorph: A Learning Framework for Deformable Medical Image Registration. IEEE Transactions on Medical Imaging (TMI), 38(8), pp. 1788-1800, 2019.

Initialize SpatialTransformer.

PARAMETER DESCRIPTION
size

Expected spatial size (excluding batch and channel dimensions) of moving_image (the input image to be warped) for the forward pass.

TYPE: tuple[int]

interpolation_mode

Algorithm used for interpolating the warped image. Default is 'bilinear'. Options are: 'bilinear' | 'nearest' | 'bicubic'.

TYPE: str DEFAULT: 'bilinear'

align_corners

Map the corner points of the moving image to the corner points of the warped image.

TYPE: bool DEFAULT: False

device

Device to construct and hold the identity grid.

TYPE: Union[str, torch.device] DEFAULT: 'cpu'

Source code in voxelmorph/nn/modules.py
def __init__(
    self,
    size: Tuple[int],
    interpolation_mode: str = "bilinear",
    align_corners: bool = False,
    device: Union[str, torch.device] = "cpu",
):
    """
    Initialize `SpatialTransformer`.

    Parameters
    ----------
    size : tuple[int]
        Expected size of `moving_image` (input image to be warped) for the forward pass.
    interpolation_mode : str
        Algorithm used for interpolating the warped image. Default is  'bilinear'. Options are:
        'bilinear' | 'nearest' | 'bicubic'.
    align_corners : bool
        Map the corner points of the moving image to the corner points of the warped image.
    device : str
        Device to construct and hold the identity grid.
    """
    super().__init__()

    self.size = size
    self.device = device
    self.interpolation_mode = interpolation_mode
    self.align_corners = align_corners

    # Make identity grid (the grid to later warp with deformation field) and register as a
    # buffer (without saving to `state_dict`: persistent=False)
    self.register_buffer(
        name='identity_grid',
        tensor=ne.utils.utils.grid(size=size, device=device),
        persistent=False  # Don't save to this module's state dict!
    )
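The identity grid can be pictured as an array holding each voxel's own coordinates. The NumPy sketch below is an assumption about the layout, not the output of `ne.utils.utils.grid`: it stores the coordinates in the last axis, which matches how `_normalize_warped_grid` indexes `warped_grid[..., i]`, and the function name `identity_grid` is hypothetical.

```python
import numpy as np


def identity_grid(size):
    """Hypothetical sketch: grid[i, j, ...] == (i, j, ...), coordinates in the last axis."""
    axes = [np.arange(s, dtype=float) for s in size]
    # indexing="ij" keeps axis order (rows, cols, ...) rather than (x, y)
    return np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1)


grid = identity_grid((3, 4))
print(grid.shape)  # (3, 4, 2)
print(grid[2, 1].tolist())  # [2.0, 1.0]
```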

_normalize_warped_grid

_normalize_warped_grid(warped_grid: Tensor) -> torch.Tensor

Normalize a warped grid to make PyTorch grid_sample() happy!

PyTorch's grid_sample() requires coordinates in the range [-1, 1]. This function scales and shifts the warped grid accordingly.
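Written out, the rescaling maps voxel coordinate x on an axis of length n to 2 (x / (n - 1) - 0.5) = 2x / (n - 1) - 1, so x = 0 goes to -1 and x = n - 1 goes to 1. A NumPy sketch of the same arithmetic follows; `normalize_grid` is a hypothetical name, and unlike the method's source code it works on a copy rather than modifying its input in place.

```python
import numpy as np


def normalize_grid(warped_grid: np.ndarray, size) -> np.ndarray:
    """Hypothetical sketch: map voxel coordinates to [-1, 1] per spatial axis."""
    out = np.array(warped_grid, dtype=float)  # copy, leave the input untouched
    for i, dim in enumerate(size):
        # 0 -> -1 and dim - 1 -> 1 on axis i
        out[..., i] = 2 * (out[..., i] / (dim - 1) - 0.5)
    return out


coords = np.array([[0.0], [1.0], [2.0], [3.0], [4.0]])  # 5 points on a 1D axis
print(normalize_grid(coords, size=(5,)).ravel().tolist())  # [-1.0, -0.5, 0.0, 0.5, 1.0]
```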

PARAMETER DESCRIPTION
warped_grid

The sum of the identity grid and the deformation field, in voxel coordinates. The tensor is modified in place.

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

The warped grid, rescaled to the range [-1, 1] along each spatial axis.

Source code in voxelmorph/nn/modules.py
def _normalize_warped_grid(
    self,
    warped_grid: torch.Tensor
) -> torch.Tensor:
    """
    Normalize a warped grid to make PyTorch `grid_sample()` happy!

    PyTorch's `grid_sample()` requires coordinates in the range [-1, 1].
    This function scales and shifts the warped grid accordingly.

    Parameters
    ----------
    warped_grid : torch.Tensor
        The resultant of the identity grid and the deformation field.

    Returns
    -------
    torch.Tensor
        The warped grid rescaled to the range [-1, 1] for each spatial axis
    """

    for i, dim in enumerate(self.size):

        # Rescale each dimension individually
        warped_grid[..., i] = 2 * (warped_grid[..., i] / (dim - 1) - 0.5)

    return warped_grid