Skip to content

voxelmorph.nn.functional

Functions containing the core operations and logic for image registration in voxelmorph, written in PyTorch.

affine_to_disp

affine_to_disp(affine: Tensor, meshgrid: Tensor, rotate_around_center: Optional[bool] = True) -> Tensor

Convert an affine transformation matrix to a displacement field.

PARAMETER DESCRIPTION
affine

Affine transformation matrix. It is expected to be a vox2vox target to source transformation.

TYPE: Tensor

meshgrid

The meshgrid tensor of shape (W, H[, D], N), where N is the spatial dimensionality.

TYPE: Tensor

rotate_around_center

If True, the rotation will be around the center of the image, otherwise around the origin.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
Tensor

The generated displacement field, with the same shape as meshgrid, i.e. (W, H[, D], N).

Source code in voxelmorph/nn/functional.py
def affine_to_disp(
    affine: Tensor,
    meshgrid: Tensor,
    rotate_around_center: Optional[bool] = True
) -> Tensor:
    """
    Convert an affine transformation matrix to a displacement field.

    Parameters
    ----------
    affine : Tensor
        Affine transformation matrix of shape `(N + 1, N + 1)` (an `(N, N + 1)`
        matrix also works, since only the first N output rows are used). It is
        expected to be a vox2vox target to source transformation. It is cast to
        the dtype and device of the homogeneous coordinates before use, so e.g. a
        float64 matrix can be applied to a float32 grid.
    meshgrid : Tensor
        The meshgrid tensor of shape `(W, H[, D], N)`, where N is the spatial
        dimensionality. It is never modified in place and need not be contiguous.
    rotate_around_center : bool, optional
        If True, the rotation will be around the center of the image, otherwise
        around the origin.

    Returns
    -------
    Tensor
        The generated displacement field, with the same shape as `meshgrid`
        (`(W, H[, D], N)`).
    """
    ndim = meshgrid.shape[-1]
    shape = meshgrid.shape[:-1]

    # if rotate_around_center is enabled, shift the grid so that voxel
    # (shape - 1) / 2 becomes the origin; the linear part of the affine then acts
    # around the image center. the grid is cloned so the caller's tensor is untouched
    if rotate_around_center:
        grid = meshgrid.clone()
        for d in range(ndim):
            grid[..., d] -= (shape[d] - 1) / 2
    else:
        grid = meshgrid

    # flatten to (num_points, ndim) and append a column of ones to get homogeneous
    # coordinates. reshape (unlike view) also accepts non-contiguous meshgrids
    coords = grid.reshape(-1, ndim)
    ones = torch.ones((coords.shape[0], 1), device=coords.device)
    coords = torch.cat([coords, ones], dim=-1)

    # matmul does not promote dtypes, so align the matrix with the coordinates
    affine = torch.as_tensor(affine, dtype=coords.dtype, device=coords.device)

    # apply the transform: the result is the absolute crs in source space
    # (affine needs to be a vox2vox matrix mapping target to source)
    mapped = (affine @ coords.T)[:ndim].T

    # subtract the (possibly centered) input positions to obtain displacements
    return mapped.reshape(*shape, ndim) - grid

angles_to_rotation_matrix

angles_to_rotation_matrix(rotation: Tensor, degrees: bool = True) -> Tensor

Compute a rotation matrix from the given rotation angles.

PARAMETER DESCRIPTION
rotation

A tensor containing the rotation angles. If degrees is True, the angles are in degrees, otherwise they are in radians.

TYPE: Tensor

degrees

Whether to interpret the rotation angles as degrees.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
Tensor

The computed (ndim, ndim) rotation matrix (2 x 2 for one angle, 3 x 3 for three angles).

Source code in voxelmorph/nn/functional.py
def angles_to_rotation_matrix(
    rotation: Tensor,
    degrees: bool = True
) -> Tensor:
    """
    Compute a rotation matrix from the given rotation angles.

    Parameters
    ----------
    rotation : Tensor
        The rotation angle(s). A scalar or a single angle gives a 2D rotation, and
        three angles give a 3D rotation, composed as `rx @ ry @ rz`. Python
        scalars, sequences and NumPy arrays are also accepted, and integer angles
        are promoted to float64. If `degrees` is True, the angles are in degrees,
        otherwise they are in radians.
    degrees : bool, optional
        Whether to interpret the rotation angles as degrees.

    Returns
    -------
    Tensor
        The computed `(ndim, ndim)` float64 rotation matrix, placed on the device
        of `rotation`.

    Raises
    ------
    ValueError
        If the number of angles is neither 1 nor 3.
    """
    # convert first so non-tensor inputs work with deg2rad, and promote integer
    # angles so the trigonometric functions are well defined
    rotation = torch.as_tensor(rotation)
    if not torch.is_floating_point(rotation):
        rotation = rotation.to(torch.float64)
    if degrees:
        rotation = torch.deg2rad(rotation)

    # scalar value allowed for 2D transforms
    if rotation.ndim == 0:
        rotation = rotation.view(1)
    num_angles = len(rotation)

    # build the matrix
    if num_angles == 1:
        c, s = torch.cos(rotation[0]), torch.sin(rotation[0])
        matrix = torch.tensor([[c, -s], [s, c]], dtype=torch.float64)
    elif num_angles == 3:
        # NOTE(review): rx and rz use the [[c, s], [-s, c]] sign pattern, whereas ry
        # (and the 2D case) use [[c, -s], [s, c]]. confirm this mix of rotation
        # directions is intended before relying on the sign of the angles
        c, s = torch.cos(rotation[0]), torch.sin(rotation[0])
        rx = torch.tensor([[1, 0, 0], [0, c, s], [0, -s, c]], dtype=torch.float64)
        c, s = torch.cos(rotation[1]), torch.sin(rotation[1])
        ry = torch.tensor([[c, 0, s], [0, 1, 0], [-s, 0, c]], dtype=torch.float64)
        c, s = torch.cos(rotation[2]), torch.sin(rotation[2])
        rz = torch.tensor([[c, s, 0], [-s, c, 0], [0, 0, 1]], dtype=torch.float64)
        matrix = rx @ ry @ rz
    else:
        raise ValueError(f'expected 1 (2D) or 3 (3D) rotation angles, got {num_angles}')

    return matrix.to(rotation.device)

chance

chance(prob: float) -> bool

Returns True with given probability.

PARAMETER DESCRIPTION
prob

Probability of returning True. Must be in the range [0, 1].

TYPE: float

RETURNS DESCRIPTION
bool

True with probability prob.

Source code in voxelmorph/nn/functional.py
def chance(prob: float) -> bool:
    """
    Draw a random boolean that is True with probability `prob`.

    Parameters
    ----------
    prob : float
        Probability of returning True. Values below 0 or above 1 are rejected.

    Returns
    -------
    bool
        The outcome of comparing a single draw of `frandom.rand()` (presumably
        uniform on [0, 1); verify against the module import) against `prob`.

    Raises
    ------
    ValueError
        If `prob` is smaller than 0 or larger than 1.
    """
    out_of_range = prob < 0.0 or prob > 1.0
    if out_of_range:
        raise ValueError(f'chance() expected a value in the range [0, 1], but got {prob}')

    draw = frandom.rand()
    return draw < prob

compose_affine

compose_affine(ndim: int, translation: Tensor = None, rotation: Tensor = None, scale: Tensor = None, shear: Tensor = None, degrees: bool = True, device: device = None) -> Tensor

Composes an affine matrix from a set of translation, rotation, scale, and shear transform components.

PARAMETER DESCRIPTION
ndim

The number of dimensions of the affine matrix. Must be 2 or 3.

TYPE: int

translation

The translation vector. Must be a vector of size ndim.

TYPE: Tensor DEFAULT: None

rotation

The rotation angles. Must be a scalar value for 2D affine matrices, and a tensor of size 3 for 3D affine matrices.

TYPE: Tensor DEFAULT: None

scale

The scaling factor. Can be scalar or vector of size ndim.

TYPE: Tensor DEFAULT: None

shear

The shearing factor. Must be a scalar value for 2D affine matrices, and a tensor of size 3 for 3D affine matrices.

TYPE: Tensor DEFAULT: None

degrees

Whether to interpret the rotation angles as degrees.

TYPE: bool DEFAULT: True

device

The device of the returned matrix.

TYPE: device DEFAULT: None

RETURNS DESCRIPTION
Tensor

The composed affine matrix, as a tensor of shape (ndim + 1, ndim + 1).

Source code in voxelmorph/nn/functional.py
def compose_affine(
    ndim: int,
    translation: Tensor = None,
    rotation: Tensor = None,
    scale: Tensor = None,
    shear: Tensor = None,
    degrees: bool = True,
    device: torch.device = None
) -> Tensor:
    """
    Composes an affine matrix from a set of translation, rotation, scale,
    and shear transform components.

    The components are combined as `T @ R @ Z @ S`: shear first, then scaling,
    then rotation, with translation applied last.

    Parameters
    ----------
    ndim : int
        The number of dimensions of the affine matrix. Must be 2 or 3.
    translation : Tensor, optional
        The translation vector. Must be a vector of size `ndim`. Defaults to zeros.
    rotation : Tensor, optional
        The rotation angles. Must be a scalar value (or single-element vector) for
        2D affine matrices, and a tensor of size 3 for 3D affine matrices.
        Defaults to zeros.
    scale : Tensor, optional
        The scaling factor. Can be scalar or vector of size `ndim`. Defaults to ones.
    shear : Tensor, optional
        The shearing factor. Must be a scalar value (or single-element vector) for
        2D affine matrices, and a tensor of size 3 for 3D affine matrices.
        Defaults to zeros.
    degrees : bool, optional
        Whether to interpret the rotation angles as degrees.
    device : torch.device, optional
        The device of the returned matrix.

    Returns
    -------
    Tensor
        The composed float32 affine matrix, as a tensor of shape
        `(ndim + 1, ndim + 1)`.

    Raises
    ------
    ValueError
        If `ndim` is not 2 or 3, or if any component has the wrong shape.
    """
    if ndim not in (2, 3):
        raise ValueError(f'affine transform must be 2D or 3D, got ndim {ndim}')

    # the matrix is assembled on the CPU and only moved to `device` at the end;
    # user-provided components are brought to the CPU first so that inputs living
    # on another device (e.g. CUDA) cannot cause device-mismatch errors below

    # check translation (a 0-d value has no length, so check ndim explicitly)
    translation = torch.zeros(ndim) if translation is None else torch.as_tensor(translation).cpu()
    if translation.ndim != 1 or translation.shape[0] != ndim:
        raise ValueError(f'translation must be of shape ({ndim},)')

    # check rotation angles
    expected = 3 if ndim == 3 else 1
    rotation = torch.zeros(expected) if rotation is None else torch.as_tensor(rotation).cpu()
    if rotation.ndim == 0 and ndim == 3 or rotation.ndim != 0 and rotation.shape[0] != expected:
        raise ValueError(f'rotation must be of shape ({expected},)')

    # check scaling factor (a scalar is broadcast to every axis)
    scale = torch.ones(ndim) if scale is None else torch.as_tensor(scale).cpu()
    if scale.ndim == 0:
        scale = scale.repeat(ndim)
    if scale.ndim != 1 or scale.shape[0] != ndim:
        raise ValueError(f'scale must be of size {ndim}')

    # check shearing
    shear = torch.zeros(expected) if shear is None else torch.as_tensor(shear).cpu()
    if shear.ndim == 0:
        shear = shear.view(1)
    if shear.shape[0] != expected:
        raise ValueError(f'shear must be of shape ({expected},)')

    # start from translation
    T = torch.eye(ndim + 1, dtype=torch.float64)
    T[:ndim, -1] = translation

    # rotation matrix
    R = torch.eye(ndim + 1, dtype=torch.float64)
    R[:ndim, :ndim] = angles_to_rotation_matrix(rotation, degrees=degrees)

    # scaling
    Z = torch.diag(torch.cat([scale.to(torch.float64), torch.ones(1, dtype=torch.float64)]))

    # shear matrix (upper-triangular off-diagonal terms)
    S = torch.eye(ndim + 1, dtype=torch.float64)
    S[0, 1] = shear[0]
    if ndim == 3:
        S[0, 2] = shear[1]
        S[1, 2] = shear[2]

    # compose component matrices
    matrix = T @ R @ Z @ S

    return torch.as_tensor(matrix, dtype=torch.float32, device=device)

coords_to_disp

coords_to_disp(coords, meshgrid=None) -> Tensor

Not yet implemented: always raises NotImplementedError.

Source code in voxelmorph/nn/functional.py
def coords_to_disp(coords, meshgrid=None) -> Tensor:
    """
    Placeholder for converting a coordinate field into a displacement field.

    Not implemented yet. If `meshgrid` is omitted, a default grid is first built
    from `coords.shape[:-1]` on `coords.device`. After that, NotImplementedError
    is raised unconditionally.

    Parameters
    ----------
    coords : Tensor
        Coordinate field with the coordinate components in its last dimension.
        NOTE(review): presumably meant as the inverse of `disp_to_coords`; the
        intended convention (voxel vs. normalized, axis order) is unconfirmed.
    meshgrid : Tensor, optional
        Coordinate grid for the image shape.

    Raises
    ------
    NotImplementedError
        Always.
    """
    meshgrid = (
        grid_coordinates(coords.shape[:-1], device=coords.device)
        if meshgrid is None
        else meshgrid
    )
    message = (
        'coords_to_disp is not yet implemented. '
        'contact andrew if you get this... or implement it :)'
    )
    raise NotImplementedError(message)

disp_to_coords

disp_to_coords(disp, meshgrid=None) -> Tensor

Convert the displacement crs to absolute crs scaled to range [-1, 1].

Parameters:

disp (torch.Tensor): Displacement crs field. meshgrid (torch.Tensor, optional): crs grid for the image shape.

Returns:

torch.Tensor: The absolute crs field scaled to range [-1, 1].

Source code in voxelmorph/nn/functional.py
def disp_to_coords(disp, meshgrid=None) -> Tensor:
    """
    Convert a displacement field into absolute coordinates scaled to [-1, 1].

    Parameters
    ----------
    disp : torch.Tensor
        Displacement field of shape `(*spatial, ndim)`.
    meshgrid : torch.Tensor, optional
        Coordinate grid for the spatial shape of `disp`. Built with
        `grid_coordinates` when omitted.

    Returns
    -------
    torch.Tensor
        The absolute coordinates `meshgrid + disp`. Each component along an axis of
        size `n > 1` is mapped linearly so that 0 -> -1 and `n - 1` -> 1, and
        components along axes of size 1 are multiplied by 0. The last dimension is
        reversed, which is the component order `torch.nn.functional.grid_sample()`
        expects.
    """
    spatial_shape = disp.shape[:-1]
    num_axes = disp.shape[-1]

    if meshgrid is None:
        meshgrid = grid_coordinates(spatial_shape, device=disp.device)

    # absolute coordinates; each component is then rescaled in place
    absolute = meshgrid + disp
    for axis in range(num_axes):
        component = absolute[..., axis]
        if spatial_shape[axis] == 1:
            # a singleton axis has no extent to normalize over
            component.mul_(0)
        else:
            component.mul_(2 / (spatial_shape[axis] - 1)).sub_(1)

    # grid_sample() wants the coordinate components in reverse axis order
    return absolute.flip(-1)

gaussian_blur

gaussian_blur(image: Tensor, sigma: List[float], batched: bool = False, truncate: int = 3) -> Tensor

Apply Gaussian blurring to an image.

PARAMETER DESCRIPTION
image

An input tensor of shape (C, W, H[, D]) to blur. A batch dimension can be included by setting batched to True.

TYPE: Tensor

sigma

Standard deviation(s) of the Gaussian filter along each dimension.

TYPE: float or List[float]

batched

Whether the input tensor includes a batch dimension.

TYPE: bool DEFAULT: False

truncate

The number of standard deviations to extend the kernel before truncating.

TYPE: int DEFAULT: 3

RETURNS DESCRIPTION
Tensor

The blurred tensor with the same shape as the input tensor.

Notes

The Gaussian filter is applied using convolution. The size of the filter kernel is determined by the standard deviation and the truncation factor.

Source code in voxelmorph/nn/functional.py
def gaussian_blur(
    image: Tensor,
    sigma: List[float],
    batched: bool = False,
    truncate: int = 3,
) -> Tensor:
    """
    Apply Gaussian blurring to an image.

    Parameters
    ----------
    image : Tensor
        An input tensor of shape `(C, W, H[, D])` to blur. A batch dimension
        can be included by setting `batched` to `True`, i.e. `(B, C, W, H[, D])`.
    sigma : float or List[float]
        Standard deviation(s) of the Gaussian filter along each spatial dimension.
        A scalar is used for every dimension.
    batched : bool, optional
        Whether the input tensor includes a batch dimension.
    truncate : int, optional
        The number of standard deviations to extend the kernel before truncating.

    Returns
    -------
    Tensor
        The blurred tensor with the same shape as the input tensor.

    Raises
    ------
    ValueError
        If a 5D input is given with `batched=False`, or if the number of sigmas
        does not match the number of spatial dimensions.

    Notes
    -----
    The Gaussian is separable, so it is applied as one 1D convolution per spatial
    dimension. Each channel is filtered independently: the convolution is
    depthwise, with `groups` equal to the number of channels. `'same'` padding
    keeps the spatial shape unchanged. The size of each kernel is determined by
    the standard deviation and the truncation factor.
    """
    ndim = image.ndim - (2 if batched else 1)

    # sanity check for common mistake
    if ndim == 4 and not batched:
        raise ValueError(
            f'gaussian blur input has {image.ndim} dims, but batched option is False'
        )

    # normalize sigmas
    if torch.as_tensor(sigma).ndim == 0:
        sigma = [sigma] * ndim
    if len(sigma) != ndim:
        raise ValueError(f'sigma must be {ndim}D, but got length {len(sigma)}')

    # work on a (B, C, *spatial) tensor
    blurred = image if batched else image.unsqueeze(0)

    # every channel is convolved separately, so both the number of filters and
    # `groups` must equal the channel count (not the batch size)
    channels = blurred.shape[1]

    if all(s == sigma[0] for s in sigma):
        # isotropic: reuse one kernel for every direction. creating the kernel is
        # one of the most time-intensive steps, so this is worth exploiting
        kernel_vec = gaussian_kernel_1d(
            sigma[0],
            truncate,
            device=blurred.device,
            dtype=blurred.dtype,
        )
        kernel_vecs = [kernel_vec] * ndim
    else:
        # anisotropic: one kernel per spatial dimension
        kernel_vecs = [
            gaussian_kernel_1d(
                s,
                truncate,
                device=blurred.device,
                dtype=blurred.dtype,
            )
            for s in sigma
        ]

    conv = getattr(torch.nn.functional, f'conv{ndim}d')
    for dim, kernel in enumerate(kernel_vecs):
        # shape the 1D kernel as (1, 1, ..., k, ..., 1) with k on spatial axis
        # `dim`, then repeat it into one filter per channel: (C, 1, ..., k, ..., 1)
        index = [None] * (ndim + 2)
        index[dim + 2] = slice(None)
        weight = kernel[tuple(index)].repeat(channels, *([1] * (ndim + 1)))
        blurred = conv(blurred, weight, groups=channels, padding='same')

    if not batched:
        blurred = blurred.squeeze(0)

    return blurred

gaussian_kernel_1d

gaussian_kernel_1d(sigma, truncate: int = 3, device=None, dtype=None)

Generate a 1D Gaussian kernel with the specified standard deviations.

PARAMETER DESCRIPTION
sigma

The standard deviation of the Gaussian (a single value).

TYPE: float

truncate

The number of standard deviations to extend the kernel before truncating.

TYPE: int DEFAULT: 3

device

The device on which to create the kernel.

TYPE: device DEFAULT: None

dtype

Data type of the returned kernel.

TYPE: dtype | None DEFAULT: None

RETURNS DESCRIPTION
Tensor

A 1D kernel of length 2 * int(truncate * sigma + 0.5) + 1, normalized to sum to 1.

Notes

The kernel is truncated at truncate standard deviations from its center. When evaluating the Gaussian, sigma is clipped to a minimum of 1e-5.

Source code in voxelmorph/nn/functional.py
def gaussian_kernel_1d(sigma, truncate: int = 3, device=None, dtype=None):
    """
    Generate a normalized 1D Gaussian kernel with the given standard deviation.

    Parameters
    ----------
    sigma : float
        Standard deviation of the Gaussian, in samples. When evaluating the
        Gaussian it is clipped to a minimum of `1e-5`, so `sigma = 0` yields the
        single-tap kernel `[1]`.
    truncate : int, optional
        The number of standard deviations to extend the kernel before truncating.
    device : torch.device, optional
        The device on which to create the kernel.
    dtype : torch.dtype | None, optional
        Data type of the kernel's sample positions. For floating-point dtypes this
        is also the dtype of the returned kernel. If None, the result dtype follows
        PyTorch type promotion (float32 for a Python float `sigma`).

    Returns
    -------
    Tensor
        A 1D kernel of length `2 * int(truncate * sigma + 0.5) + 1`, centered on
        its middle tap and normalized to sum to 1.

    Notes
    -----
    The kernel is truncated at `truncate` standard deviations from its center.
    """
    radius = int(truncate * sigma + 0.5)
    x = torch.arange(-radius, radius + 1, device=device, dtype=dtype)

    # inverse variance; sigma is clipped away from zero so the exponent stays finite.
    # it is created on the kernel's device so a tensor `sigma` on another device
    # cannot cause a device mismatch
    inv_var = 1 / torch.clip(torch.as_tensor(sigma, device=x.device), min=1e-5).pow(2)
    pdf = torch.exp(-0.5 * (x.pow(2) * inv_var))
    return pdf / pdf.sum()

grid_coordinates

grid_coordinates(shape: Sequence[int], indexing: Optional[Literal['ij', 'xy']] = 'ij', dtype: Optional[dtype] = torch.float32, device: Optional[device] = None) -> Tensor

Generates a grid of coordinates with the specified spatial shape.

PARAMETER DESCRIPTION
shape

The spatial shape of the grid to generate.

TYPE: tuple of int

indexing

The indexing convention to use. 'ij' for matrix indexing, 'xy' for Cartesian indexing. Default is 'ij'.

TYPE: (ij, xy) DEFAULT: 'ij'

dtype

The desired data type of the output tensor. Default is torch.float32.

TYPE: dtype DEFAULT: float32

device

The device on which to create the tensor. Default is None, which uses the current device.

TYPE: device DEFAULT: None

RETURNS DESCRIPTION
Tensor

A tensor of shape (*shape, len(shape)) containing the grid coordinates.

Examples:

>>> grid_coordinates((2, 3))
tensor([[[0., 0.],
         [0., 1.],
         [0., 2.]],
        [[1., 0.],
         [1., 1.],
         [1., 2.]]])
>>> grid_coordinates((1, 2, 2), device=torch.device('cuda:0'))
tensor([[[[0., 0., 0.],
           [0., 0., 1.]],
          [[0., 1., 0.],
           [0., 1., 1.]]]], device='cuda:0')
Source code in voxelmorph/nn/functional.py
def grid_coordinates(
    shape: Sequence[int],
    indexing: Optional[Literal['ij', 'xy']] = 'ij',
    dtype: Optional[torch.dtype] = torch.float32,
    device: Optional[torch.device] = None
) -> Tensor:
    """
    Build a dense grid of voxel coordinates for the given spatial shape.

    Parameters
    ----------
    shape : sequence of int
        Size of the grid along each spatial axis.
    indexing : {'ij', 'xy'}, optional
        Forwarded to `torch.meshgrid`: 'ij' for matrix indexing (default), 'xy'
        for Cartesian indexing.
    dtype : torch.dtype, optional
        Data type of the coordinates. Default is torch.float32.
    device : torch.device, optional
        Device of the coordinates. Default is None, i.e. PyTorch's current
        default device.

    Returns
    -------
    torch.Tensor
        With 'ij' indexing, a tensor of shape `(*shape, len(shape))` whose last
        axis holds the coordinates of each voxel. With 'xy' indexing, the first
        two axes are swapped, following `torch.meshgrid`.

    Examples
    --------
    >>> grid_coordinates((2, 3)).shape
    torch.Size([2, 3, 2])
    >>> grid_coordinates((2, 3))[1, 2]
    tensor([1., 2.])
    """
    one_d = tuple(torch.arange(n, dtype=dtype, device=device) for n in shape)
    return torch.stack(torch.meshgrid(*one_d, indexing=indexing), dim=-1)

integrate_disp

integrate_disp(disp: Tensor, steps: int, meshgrid: Tensor = None) -> Tensor

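Integrates a displacement field by scaling and squaring. The field is divided by 2 ** steps and then composed with itself steps times; steps = 0 returns the input unchanged.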

Source code in voxelmorph/nn/functional.py
def integrate_disp(
    disp: Tensor,
    steps: int,
    meshgrid: Tensor = None
) -> Tensor:
    """
    Integrate a displacement field by scaling and squaring.

    The field is divided by `2 ** steps` and then updated `steps` times as
    `disp = disp + spatial_transform(disp, disp)`. Here `spatial_transform` is
    presumably resampling the field at the locations displaced by itself.

    Parameters
    ----------
    disp : Tensor
        Displacement field with its vector components in the last dimension.
        Presumably of shape `(W, H[, D], N)`, matching the default meshgrid built
        from `disp.shape[:-1]`.
    steps : int
        Number of scaling-and-squaring steps. 0 returns `disp` itself, unchanged.
    meshgrid : Tensor, optional
        Coordinate grid for the spatial shape of `disp`, forwarded to
        `spatial_transform`. Built with `grid_coordinates` when omitted.

    Returns
    -------
    Tensor
        The integrated displacement field.

    Raises
    ------
    ValueError
        If `steps` is negative.
    """
    if steps < 0:
        raise ValueError(f'integrate_disp() expected non-negative steps, got {steps}')
    if steps == 0:
        return disp

    if meshgrid is None:
        # generate a crs grid
        meshgrid = grid_coordinates(disp.shape[:-1], device=disp.device)

    disp = disp / (2 ** steps)
    for _ in range(steps):
        # out-of-place update: `disp` is an input to spatial_transform, which may
        # keep it for the backward pass, so modifying it in place would break autograd
        warped = spatial_transform(disp.movedim(-1, 0), disp, meshgrid=meshgrid).movedim(0, -1)
        disp = disp + warped

    return disp

perlin

perlin(shape, smoothing: Union[float, List[float]] = None, magnitude: Union[float, List[float]] = 1.0, weights=None, device=None, method='blur')

Generates a perlin noise image.

PARAMETER DESCRIPTION
shape

The desired shape of the output tensor. Can be 2D or 3D.

TYPE: List[int]

smoothing

The spatial smoothing sigma(s) in voxel coordinates.

TYPE: float or List[float] DEFAULT: None

magnitude

The standard deviation of the noise.

TYPE: float DEFAULT: 1.0

weights

The weights of the smoothing components (scales). If None, defaults to monotonically increasing weights.

TYPE: float or List[float] DEFAULT: None

device

The device on which the output tensor is allocated. If None, defaults to CPU.

TYPE: device or None DEFAULT: None

method

Method for noise generation. Upsampling is much faster and more memory efficient for larger sigma values, but at the cost of quality.

TYPE: blur or upsample DEFAULT: 'blur'

RETURNS DESCRIPTION
Tensor

A Perlin noise image of shape shape.

Source code in voxelmorph/nn/functional.py
def perlin(
    shape,
    smoothing: Union[float, List[float]] = None,
    magnitude: Union[float, List[float]] = 1.0,
    weights=None,
    device=None,
    method='blur'
):
    """
    Generates a perlin noise image.

    Smooth Gaussian noise fields are generated at one or more smoothing scales,
    combined as a weighted sum, and normalized to zero mean and standard deviation
    `magnitude`. A scalar `smoothing` returns a single `smooth_gaussian` field
    directly.

    Parameters
    ----------
    shape : List[int]
        The desired shape of the output tensor. Can be 2D or 3D.
    smoothing : float or List[float], optional
        The spatial smoothing sigma(s) in voxel coordinates. If None, the powers
        of two from 2 up to (but excluding) `max(shape)` are used.
    magnitude : float
        The standard deviation of the noise.
    weights : float or List[float], optional
        The weights of the smoothing components (scales). A scalar gives every
        scale the same weight. If None, defaults to monotonically increasing
        weights `1, 2, ..., len(smoothing)`. Otherwise there must be one weight
        per scale. Ignored when only one scale is given, since the final
        normalization cancels it.
    device : torch.device or None, optional
        The device on which the output tensor is allocated. If None, defaults to CPU.
    method : 'blur' or 'upsample'
        Method for noise generation. Upsampling is much faster and more memory
        efficient for larger sigma values, but at the cost of quality.

    Returns
    -------
    Tensor
        A Perlin noise image of shape `shape`.

    Raises
    ------
    ValueError
        If there are no smoothing scales (e.g. the default for `max(shape) <= 2`),
        or if the number of weights does not match the number of scales.
    """
    if smoothing is None:
        smoothing = 2 ** np.arange(np.log2(max(shape)))[1:]

    elif np.isscalar(smoothing):
        return smooth_gaussian(
            shape, smoothing, magnitude, device=device, method=method
        )

    num_scales = len(smoothing)
    if num_scales == 0:
        raise ValueError(
            f'perlin() needs at least one smoothing scale, but none were given '
            f'or derived from shape {tuple(shape)}'
        )

    if num_scales == 1:
        # a single scale is rescaled by the final normalization anyway
        weights = [None]

    elif weights is None:
        weights = np.arange(num_scales) + 1

    elif np.isscalar(weights):
        weights = [weights] * num_scales

    # zip() would silently drop unmatched scales, so require one weight per scale
    if len(weights) != num_scales:
        raise ValueError(
            f'perlin() expected {num_scales} weights (one per smoothing scale), '
            f'but got {len(weights)}'
        )

    noise = None
    for s, w in zip(smoothing, weights):

        # generate smooth field
        sample = smooth_gaussian(shape, s, device=device, method=method)
        if w is not None:
            sample *= w

        # merge the noise at this scale with the rest
        if noise is None:
            noise = sample
        else:
            noise += sample

    # in-place normalize to zero mean and standard deviation `magnitude`
    noise -= noise.mean()
    noise *= magnitude / noise.std()
    return noise

random_affine

random_affine(ndim: int, max_translation: float = 0, max_rotation: float = 0, max_scaling: float = 1, device: device = None, sampling: bool = True) -> Tensor
PARAMETER DESCRIPTION
ndim

Dimensionality of target transform.

TYPE: int

max_translation

Range to sample translation parameters from. Scalar values define the max deviation from 0.0 (-max_translation, max_translation).

TYPE: float DEFAULT: 0

max_rotation

Range to sample rotation parameters from. Scalar values define the max deviation from 0.0 (-max_rotation, max_rotation).

TYPE: float DEFAULT: 0

max_scaling

Maximum scale factor; it must be at least 1 when sampling. Each axis draws a factor uniformly from [1, max_scaling] and inverts it with probability 1/2, so scales fall in [1 / max_scaling, max_scaling].

TYPE: float DEFAULT: 1

RETURNS DESCRIPTION
Tensor

vox2vox affine matrix rotating around the image center

Source code in voxelmorph/nn/functional.py
def random_affine(
    ndim: int,
    max_translation: float = 0,
    max_rotation: float = 0,
    max_scaling: float = 1,
    device: torch.device = None,
    sampling: bool = True
) -> Tensor:
    """
    Build an affine matrix from (optionally random) translation, rotation and
    scaling parameters.

    Parameters
    ----------
    ndim : int
        Dimensionality of target transform.
    max_translation : float
        Range to sample translation parameters from. Scalar values define the max
        deviation from 0.0 (-max_translation, max_translation).
    max_rotation : float
        Range to sample rotation parameters from. Scalar values define the max
        deviation from 0.0 (-max_rotation, max_rotation). The angles are passed to
        `compose_affine` with its default `degrees=True`.
    max_scaling : float
        Maximum scale factor; it must be at least 1 when sampling. Each axis draws
        a factor uniformly from [1, max_scaling] and inverts it with probability
        1/2, so scales fall in [1 / max_scaling, max_scaling].
    device : torch.device, optional
        The device of the returned matrix.
    sampling : bool, optional
        If True (default), sample the parameters randomly as described above. If
        False, use `max_translation`, `max_rotation` and `max_scaling` directly as
        the value for every axis/angle.

    Returns
    -------
    Tensor
        vox2vox affine matrix of shape `(ndim + 1, ndim + 1)`, meant to be applied
        around the image center (e.g. via `affine_to_disp` with its default
        `rotate_around_center=True`).

    Raises
    ------
    ValueError
        If `sampling` is True and `max_scaling` is less than 1.
    """
    num_angles = 1 if ndim == 2 else 3

    # the draw order (translation, rotation, inversion signs, scales) is kept
    # fixed so results stay reproducible for a given numpy seed
    if sampling:
        low, high = sorted([-max_translation, max_translation])
        translation = np.random.uniform(low, high, size=ndim)

        low, high = sorted([-max_rotation, max_rotation])
        rotation = np.random.uniform(low, high, size=num_angles)

        if max_scaling < 1:
            raise ValueError('max scaling to random affine cannot be less than 1, '
                             'see function doc for more info')
        inv = np.random.choice([-1, 1], size=ndim)
        scale = np.random.uniform(1, max_scaling, size=ndim) ** inv
    else:
        translation = np.array([max_translation] * ndim)
        rotation = np.array([max_rotation] * num_angles)
        # repeat the factor for every axis (rather than multiplying it by ndim)
        scale = np.array([max_scaling] * ndim)

    # compose from the chosen parameters
    return compose_affine(
        ndim=ndim,
        translation=translation,
        rotation=rotation,
        scale=scale,
        device=device)

random_disp

random_disp(shape: List[int], smoothing: Union[float, List[float]] = 10, magnitude: Union[float, List[float]] = 10, integrations: int = 0, voxsize: float = 1, meshgrid: Tensor = None, device: device = None, perlin_method: str = 'upsample') -> Tensor

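Samples a random smooth displacement field of shape (*shape, len(shape)). Each vector component is an independent Perlin noise image, and the field is optionally integrated with integrate_disp.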

Source code in voxelmorph/nn/functional.py
def random_disp(
    shape: List[int],
    smoothing: Union[float, List[float]] = 10,
    magnitude: Union[float, List[float]] = 10,
    integrations: int = 0,
    voxsize: float = 1,
    meshgrid: Tensor = None,
    device: torch.device = None,
    perlin_method: str = 'upsample'
) -> Tensor:
    """
    Sample a random smooth displacement field.

    Each of the `len(shape)` vector components is an independent `perlin` noise
    image. The stacked field is optionally integrated with `integrate_disp`.

    Parameters
    ----------
    shape : List[int]
        Spatial shape of the field (2D or 3D).
    smoothing : float or List[float], optional
        Smoothing sigma(s) of the Perlin noise, in the units of `voxsize`. It is
        converted to voxels by dividing by `voxsize`. A list is forwarded to
        `perlin` as multiple smoothing scales.
    magnitude : float, optional
        Standard deviation of each displacement component, in the units of
        `voxsize`. It is converted to voxels by dividing by `voxsize`.
    integrations : int, optional
        Number of integration steps passed to `integrate_disp`. 0 (default)
        disables integration.
    voxsize : float, optional
        Voxel size used to convert `smoothing` and `magnitude` to voxel units.
    meshgrid : Tensor, optional
        Coordinate grid forwarded to `integrate_disp`.
    device : torch.device, optional
        Device of the returned field.
    perlin_method : str, optional
        Noise generation method forwarded to `perlin` ('blur' or 'upsample').

    Returns
    -------
    Tensor
        Displacement field of shape `(*shape, len(shape))`, in voxel units.
    """
    # np.asarray lets `smoothing` be a list of scales (which perlin supports);
    # scalar input comes back as a numpy scalar, which perlin still treats as scalar
    smoothing = np.asarray(smoothing) / voxsize
    magnitude = magnitude / voxsize

    # randomly sample one smooth noise image per displacement component
    ndim = len(shape)
    components = [
        perlin(shape, smoothing, magnitude, method=perlin_method, device=device)
        for _ in range(ndim)
    ]
    disp = torch.stack(components, dim=-1)

    if integrations > 0:
        disp = integrate_disp(disp, integrations, meshgrid)

    return disp

random_transform

random_transform(shape: List[int], affine_probability: float = 1.0, max_translation: float = 5.0, max_rotation: float = 5.0, max_scaling: float = 1.1, warp_probability: float = 1.0, warp_integrations: int = 5, warp_smoothing_range: List[int] = [10, 20], warp_magnitude_range: List[int] = [1, 2], voxsize: int = 1, device: device = None, isdisp: bool = True, perlin_method: str = 'upsample', sampling: bool = True) -> Tensor

generate a randomly sampled transform

Parameters:

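shape: List[int] Spatial shape of the transform. affine_probability / warp_probability: float Probabilities of including the random affine and nonlinear components. max_translation, max_rotation, max_scaling: float Affine parameter ranges (see random_affine). warp_smoothing_range, warp_magnitude_range: (low, high) Ranges from which the warp smoothing and magnitude are drawn uniformly. voxsize: float Voxel size used to convert physical units to voxels. isdisp: bool Return a displacement field (True) or absolute coordinates scaled to [-1, 1] (False).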

Returns:

torch.Tensor: displacement crs field, or absolute crs field scaled to range [-1, 1] if isdisp is False

Source code in voxelmorph/nn/functional.py
def random_transform(
    shape: List[int],
    affine_probability: float = 1.0,
    max_translation: float = 5.0,
    max_rotation: float = 5.0,
    max_scaling: float = 1.1,
    warp_probability: float = 1.0,
    warp_integrations: int = 5,
    warp_smoothing_range: Sequence[float] = (10, 20),
    warp_magnitude_range: Sequence[float] = (1, 2),
    voxsize: int = 1,
    device: torch.device = None,
    isdisp: bool = True,
    perlin_method: str = 'upsample',
    sampling: bool = True,
) -> Tensor:
    """
    Generate a randomly sampled transform.

    With probability `affine_probability`, a random affine (see `random_affine`)
    is converted to a displacement field about the image center. With probability
    `warp_probability`, a random smooth warp (see `random_disp`) is sampled. If an
    affine was drawn as well, the warp is composed with it via `spatial_transform`.

    Parameters
    ----------
    shape : List[int]
        Spatial shape of the transform (2D or 3D).
    affine_probability : float, optional
        Probability of including a random affine component.
    max_translation : float, optional
        Maximum translation, in the same physical units as `voxsize` (e.g. mm).
        It is converted to voxels by dividing by `voxsize`.
    max_rotation : float, optional
        Maximum rotation angle, forwarded to `random_affine`.
    max_scaling : float, optional
        Maximum scaling factor, forwarded to `random_affine`.
    warp_probability : float, optional
        Probability of including a random nonlinear warp.
    warp_integrations : int, optional
        Number of integration steps for the warp, forwarded to `random_disp`.
    warp_smoothing_range : sequence of two floats, optional
        `(low, high)` range from which the warp smoothing is drawn uniformly.
    warp_magnitude_range : sequence of two floats, optional
        `(low, high)` range from which the warp magnitude is drawn uniformly.
    voxsize : float, optional
        Voxel size used to convert translation, smoothing and magnitude to voxels.
    device : torch.device, optional
        Device of the returned transform.
    isdisp : bool, optional
        If True (default), return a displacement field. Otherwise return absolute
        coordinates scaled to [-1, 1] via `disp_to_coords`.
    perlin_method : str, optional
        Noise generation method forwarded to `random_disp`.
    sampling : bool, optional
        Forwarded to `random_affine`. If False, the affine parameters are used as
        given instead of being sampled.

    Returns
    -------
    Tensor or None
        The sampled transform, of shape `(*shape, len(shape))`: a displacement field
        in voxels, or scaled absolute coordinates if `isdisp` is False. None if
        neither the affine nor the warp component was drawn.
    """
    ndim = len(shape)
    trf = None
    meshgrid = None

    # generate a random affine
    if chance(affine_probability):

        # compute meshgrid, it is the target crs
        meshgrid = grid_coordinates(shape, device=device)

        # the matrix returned from random_affine() is vox2vox. affine_to_disp() uses
        # it as the target-to-source transformation, rotating around the image
        # center, to convert it to a displacement field (translation mm -> voxels)
        matrix = random_affine(
            ndim=ndim,
            max_translation=max_translation / voxsize,
            max_rotation=max_rotation,
            max_scaling=max_scaling,
            device=device,
            sampling=sampling)
        trf = affine_to_disp(matrix, meshgrid)

    # generate a nonlinear transform
    if chance(warp_probability):
        disp = random_disp(
            shape=shape,
            smoothing=np.random.uniform(*warp_smoothing_range),
            magnitude=np.random.uniform(*warp_magnitude_range),
            integrations=warp_integrations,
            voxsize=voxsize,
            device=device,
            perlin_method=perlin_method)

        # merge with the affine transform if necessary (out-of-place update)
        if trf is None:
            trf = disp
        else:
            warped = spatial_transform(disp.movedim(-1, 0), trf, meshgrid=meshgrid).movedim(0, -1)
            trf = trf + warped

    # convert to coordinates if specified
    if trf is not None and not isdisp:
        # compute the absolute crs field scaled to range [-1, 1]
        trf = disp_to_coords(trf)

    return trf

resize

resize(image: Tensor, scale_factor: List[float] = None, shape: List[int] = None, nearest: bool = False) -> Tensor

Resize an image with the option of scaling and/or setting to a new shape.

Parameters:

image: torch.Tensor An input tensor with shape (C, H, W[, D]) to resize. scale_factor: float or List[float], optional Multiplicative factor(s) for scaling the input tensor. If a float, then the same scale factor is applied to all spatial dimensions. If a tuple, then the scaling factor for each dimension should be provided. shape: List[int], optional Target shape of the output tensor. nearest: bool, optional If True, use nearest neighbor interpolation. Otherwise, use linear interpolation.

Returns:

torch.Tensor: The resized tensor with the shape specified by shape or scaled by scale_factor.

Source code in voxelmorph/nn/functional.py
def resize(
    image: Tensor,
    scale_factor: List[float] = None,
    shape: List[int] = None,
    nearest: bool = False
) -> Tensor:
    """
    Resize an image with the option of scaling and/or setting to a new shape.

    Scaling (if requested) is applied first by interpolation. The result is then
    center-padded with zeros and/or center-cropped to `shape` (if given).

    Parameters
    ----------
    image : Tensor
        Input tensor with shape `(C, H, W[, D])` to resize.
    scale_factor : float or List[float], optional
        Multiplicative factor(s) for scaling the spatial dimensions. A scalar applies
        the same factor to every spatial dimension; a sequence must provide exactly one
        factor per spatial dimension. Each target size is rounded to the nearest integer.
    shape : List[int], optional
        Target spatial shape. Dimensions smaller than the target are zero-padded and
        larger ones are cropped, both split evenly around the center (for an odd
        difference, the extra voxel is padded/cropped at the end of the axis).
    nearest : bool, optional
        If True, use nearest-neighbor interpolation. Otherwise, use linear
        (1D), bilinear (2D) or trilinear (3D) interpolation.

    Returns
    -------
    Tensor
        The resized tensor. Integer inputs keep their dtype with nearest interpolation
        and are returned as float32 with linear interpolation.

    Raises
    ------
    ValueError
        If `scale_factor` is a sequence whose length differs from the number of
        spatial dimensions.
    """
    ndim = image.ndim - 1

    if scale_factor is not None:

        # accept either a scalar (applied to every axis) or one factor per spatial axis
        if np.ndim(scale_factor) == 0:
            factors = [float(scale_factor)] * ndim
        else:
            factors = [float(f) for f in scale_factor]
            if len(factors) != ndim:
                raise ValueError(
                    f'expected {ndim} scale factors, got {len(factors)}')

        # only interpolate when at least one axis is actually rescaled
        if any(f != 1 for f in factors):
            target_shape = [int(s * f + 0.5) for s, f in zip(image.shape[1:], factors)]

            # interpolation requires floating point input; nearest-neighbor output only
            # contains values from the input, so integer images can be cast back after
            reset_type = None
            if not torch.is_floating_point(image):
                if nearest:
                    reset_type = image.dtype
                image = image.type(torch.float32)

            linear = {1: 'linear', 3: 'trilinear'}.get(ndim, 'bilinear')
            mode = 'nearest' if nearest else linear

            # interpolate expects a batch axis, so add and remove one around the call
            image = torch.nn.functional.interpolate(
                image.unsqueeze(0), target_shape, mode=mode).squeeze(0)

            if reset_type is not None:
                image = image.type(reset_type)

    if shape is not None:
        target = [int(s) for s in shape]

        # zero-pad every axis that is too small; F.pad takes (before, after) pairs
        # starting from the LAST dimension, so walk the spatial axes in reverse
        padding = []
        for d in reversed(range(ndim)):
            deficit = max(target[d] - image.shape[d + 1], 0)
            padding.extend([deficit // 2, deficit - deficit // 2])
        image = torch.nn.functional.pad(image, padding)

        # crop every axis that is too large, using the same floor/ceil split as padding
        slicing = [slice(None)]
        for d in range(ndim):
            size = image.shape[d + 1]
            excess = max(size - target[d], 0)
            start = excess // 2
            slicing.append(slice(start, size - (excess - start)))
        image = image[tuple(slicing)]

    return image

smooth_gaussian

smooth_gaussian(shape, sigma, magnitude=1.0, device=None, method='blur')

Generates a smooth Gaussian noise image.

PARAMETER DESCRIPTION
shape

The desired shape of the output tensor. Can be 2D or 3D.

TYPE: List[int]

sigma

The spatial smoothing sigma in voxel coordinates.

TYPE: float

magnitude

The standard deviation of the noise.

TYPE: float DEFAULT: 1.0

device

The device on which the output tensor is allocated. If None, defaults to CPU.

TYPE: device or None DEFAULT: None

method

Method for noise generation. Upsampling is much faster and more memory efficient for larger sigma values, but at the cost of quality.

TYPE: blur or upsample DEFAULT: 'blur'

RETURNS DESCRIPTION
Tensor

A smooth Gaussian noise image of shape shape.

Source code in voxelmorph/nn/functional.py
def smooth_gaussian(shape, sigma, magnitude=1.0, device=None, method='blur'):
    """
    Draw a spatially smooth, zero-mean Gaussian noise field.

    Parameters
    ----------
    shape : List[int]
        Spatial shape of the output field (2D or 3D).
    sigma : float
        Spatial smoothing sigma in voxel units. For `'upsample'`, the coarse grid is
        roughly `sigma` times smaller than `shape` along each axis (at least 2 voxels).
    magnitude : float
        Standard deviation of the returned field.
    device : torch.device or None, optional
        Device on which the noise is allocated.
    method : 'blur' or 'upsample'
        `'blur'` smooths full-resolution white noise with `gaussian_blur`;
        `'upsample'` interpolates coarse white noise up to `shape`, which is cheaper
        for large sigma but yields lower-quality noise.

    Returns
    -------
    Tensor
        A smooth noise field of shape `shape`, normalized to zero mean and standard
        deviation `magnitude`.

    Raises
    ------
    ValueError
        If `method` is not one of the supported options.
    """
    if method not in ('blur', 'upsample'):
        raise ValueError(f'unknown smooth gaussian method `{method}`')

    if method == 'upsample':
        # white noise on a coarse grid, linearly upsampled to the requested shape
        coarse = tuple(max(int(size // sigma), 2) for size in shape)
        field = torch.normal(0, 1, size=(1, 1, *coarse), device=device)
        interp = 'trilinear' if len(shape) == 3 else 'bilinear'
        field = torch.nn.functional.interpolate(field, shape, mode=interp).view(shape)
    else:
        # full-resolution white noise smoothed with a Gaussian kernel
        field = torch.normal(0, 1, size=shape, device=device)
        field = gaussian_blur(field.unsqueeze(0), sigma).squeeze(0)

    # rescale in place to zero mean and the requested standard deviation
    field -= field.mean()
    field *= magnitude / field.std()
    return field

spatial_transform

spatial_transform(image: Tensor, trf: Tensor, method: str = 'linear', isdisp: bool = True, meshgrid: Tensor = None, rotate_around_center: bool = True) -> Tensor

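Resample an image with an affine matrix, a dense displacement field, or an absolute coordinate field. If trf is None, the image is returned unchanged.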

Source code in voxelmorph/nn/functional.py
def spatial_transform(
    image: Tensor,
    trf: Tensor,
    method: str = 'linear',
    isdisp: bool = True,
    meshgrid: Tensor = None,
    rotate_around_center: bool = True
) -> Tensor:
    """
    Resample an image with an affine matrix, a displacement field, or a coordinate field.

    Parameters
    ----------
    image : Tensor
        Image to resample. `image.shape[1:]` is treated as the spatial shape and a batch
        axis is added internally, so a channel-first tensor without a batch axis is
        expected (e.g. `(C, H, W[, D])` -- confirm against callers).
    trf : Tensor or None
        Transform to apply. If None, `image` is returned unchanged. A 2D tensor is
        treated as an affine matrix: it is inverted with `torch.linalg.inv` and expanded
        into a displacement field by `affine_to_disp`. Any other tensor is a dense field
        that, after adding a batch axis, becomes the `grid` argument of `grid_sample`.
    method : str, optional
        Interpolation mode. `'linear'` is mapped to `grid_sample`'s `'bilinear'`; any
        other value (e.g. `'nearest'`) is passed through unchanged.
    isdisp : bool, optional
        If True, a dense `trf` is a displacement field and is converted with
        `disp_to_coords` to absolute coordinates in `[-1, 1]`. If False, `trf` must
        already be such a coordinate field. Forced to True for affine matrices.
    meshgrid : Tensor, optional
        Voxel coordinate grid. When `trf` is affine and no grid is given, one is built
        with `grid_coordinates(image.shape[1:])`. The grid is also forwarded to
        `disp_to_coords`.
    rotate_around_center : bool, optional
        Forwarded to `affine_to_disp`. Only used when `trf` is an affine matrix.

    Returns
    -------
    Tensor
        The resampled image without the added batch axis. Integer images keep their
        dtype with `'nearest'` interpolation and come back as float32 otherwise.
    """
    if trf is None:
        return image

    # affine matrix: invert it and expand it into a dense displacement field
    if trf.ndim == 2:
        if meshgrid is None:
            meshgrid = grid_coordinates(image.shape[1:], device=image.device)
        inverse = torch.linalg.inv(trf)
        trf = affine_to_disp(inverse, meshgrid, rotate_around_center=rotate_around_center)
        isdisp = True

    # grid_sample consumes absolute coordinates scaled to [-1, 1]
    grid = disp_to_coords(trf, meshgrid=meshgrid) if isdisp else trf

    mode = 'bilinear' if method == 'linear' else method

    # grid_sample needs floating-point input; nearest-neighbor sampling only returns
    # existing values, so integer images can safely be cast back afterwards
    original_dtype = None
    if not torch.is_floating_point(image):
        original_dtype = image.dtype if mode == 'nearest' else None
        image = image.type(torch.float32)

    batched_image = image.unsqueeze(0)
    batched_grid = grid.unsqueeze(0)
    sampled = torch.nn.functional.grid_sample(
        batched_image, batched_grid, align_corners=True, mode=mode)
    sampled = sampled.squeeze(0)

    return sampled if original_dtype is None else sampled.type(original_dtype)