"""Classes and functions to perform serial non-rigid registration of a set of images
"""
import numpy as np
from skimage import io
from tqdm import tqdm
import os
from time import time
import pathlib
import pandas as pd
import pickle
import pyvips
import inspect
from . import warp_tools
from . import non_rigid_registrars
from . import valtils
from . import serial_rigid
from . import viz
from . import preprocessing
from . import slide_tools
IMG_LIST_KEY = "img_list"
IMG_F_LIST_KEY = "img_f_list"
IMG_NAME_KEY = "name_list"
MASK_LIST_KEY = "mask_list"
def get_matching_xy_from_rigid_registrar(rigid_registrar, ref_img_name=None):
"""Get matching keypoints to use in serial non-rigid registration
Parameters
----------
rigid_registrar : SerialRigidRegistrar
SerialRigidRegistrar that has aligned a series of images
ref_img_name : str, optional
Name of image that will be treated as the center of the stack.
If None, the middle image will be used as the center
Returns
-------
from_to_kp_dict : dict of list
Key = image name, value = list of matched and aligned keypoints between
each registered moving image and the registered fixed image.
Each element in the list contains 2 arrays:
#. Rigid registered xy in moving/current/from image
#. Rigid registered xy in fixed/next/to image
"""
img_f_list = [img_obj.full_img_f for img_obj in rigid_registrar.img_obj_list]
ref_img_idx = warp_tools.get_ref_img_idx(img_f_list, ref_img_name)
n_imgs = len(img_f_list)
from_to_indices = warp_tools.get_alignment_indices(n_imgs, ref_img_idx)
from_to_kp_dict = {}
for idx in from_to_indices:
moving_obj = rigid_registrar.img_obj_list[idx[0]]
fixed_obj = rigid_registrar.img_obj_list[idx[1]]
current_match_dict = moving_obj.match_dict[fixed_obj]
moving_kp = current_match_dict.matched_kp1_xy
fixed_kp = current_match_dict.matched_kp2_xy
assert moving_kp.shape[0] == fixed_kp.shape[0]
registered_moving = warp_tools.warp_xy(moving_kp, M=moving_obj.M)
registered_fixed = warp_tools.warp_xy(fixed_kp, M=fixed_obj.M)
from_to_kp_dict[moving_obj.name] = [registered_moving, registered_fixed]
return from_to_kp_dict
def get_imgs_from_dir(src_dir):
"""Get images from source directory.
Parameters
----------
src_dir : str
Location of images to be registered.
Returns
-------
img_list : list of ndarray
List of images to be registered
img_f_list : list of str
List of image file names
img_names : list of str
List of names for each image. Created by removing the extension
mask_list : list of ndarray
List of masks used for registration
"""
img_f_list = [f for f in os.listdir(src_dir) if
slide_tools.get_img_type(os.path.join(src_dir, f)) is not None]
valtils.sort_nicely(img_f_list)
img_list = [io.imread(os.path.join(src_dir, f)) for f in img_f_list]
img_names = [valtils.get_name(f) for f in img_f_list]
mask_list = [None] * len(img_f_list)
return img_list, img_f_list, img_names, mask_list
def get_imgs_rigid_reg(serial_rigid_reg):
"""Get images from SerialRigidRegistrar
Parameters
----------
serial_rigid_reg : SerialRigidRegistrar
SerialRigidRegistrar that has rigidly aligned images
Returns
-------
img_list : list of ndarray
List of images to be registered
img_f_list : list of str
List of image file names
img_names : list of str
List of names for each image. Created by removing the extension
mask_list : list of ndarray
List of masks used for registration
"""
img_list = [None] * serial_rigid_reg.size
img_names = [None] * serial_rigid_reg.size
img_f_list = [None] * serial_rigid_reg.size
mask_list = [None] * serial_rigid_reg.size
for i, img_obj in enumerate(serial_rigid_reg.img_obj_list):
img_list[i] = img_obj.registered_img
img_names[i] = img_obj.name
img_f_list[i] = img_obj.full_img_f
# Moving mask
temp_mask = np.full_like(img_obj.image, 255)
img_mask = warp_tools.warp_img(temp_mask, M=img_obj.M,
out_shape_rc=img_obj.registered_img.shape,
interp_method="nearest")
mask_list[i] = img_mask
return img_list, img_f_list, img_names, mask_list
def get_imgs_from_dict(img_dict):
"""Get images from source directory.
Parameters
----------
img_dict : dictionary
Dictionary containing the following key : value pairs
"img_list" : list of images to register
"img_f_list" : list of filenames of each image
"name_list" : list of image names. If not provided, will come from file names
"mask_list" list of masks for each image
All of the above are optional, except `img_list`.
Returns
-------
img_list : list of ndarray
List of images to be registered
img_f_list : list of str
List of image file names
img_names : list of str
List of names for each image. Created by removing the extension
mask_list : list of ndarray
List of masks used for registration
"""
img_list = img_dict[IMG_LIST_KEY]
names_provided = IMG_NAME_KEY in img_dict.keys()
files_provided = IMG_F_LIST_KEY in img_dict.keys()
masks_provided = MASK_LIST_KEY in img_dict.keys()
n_imgs = len(img_list)
if files_provided:
img_f_list = img_dict[IMG_F_LIST_KEY]
else:
img_f_list = [None] * n_imgs
if names_provided:
img_names = img_dict[IMG_NAME_KEY]
else:
if files_provided:
img_names = [valtils.get_name(f) for f in img_f_list]
else:
img_names = [None] * n_imgs
if masks_provided:
mask_list = img_dict[MASK_LIST_KEY]
else:
mask_list = [None] * n_imgs
return img_list, img_f_list, img_names, mask_list
[docs]
class NonRigidZImage(object):
""" Class that store info about an image, including both
rigid and non-rigid registration parameters
Attributes
----------
image : ndarray
Original, unwarped image with shape (P, Q)
name : str
Name of image.
stack_idx : int
Position of image in the stack
moving_xy : ndarray, optional
(V, 2) array containing points in the moving image that correspond
to those in the fixed image. If these are provided, non_rigid_reg_class
should be a subclass of non_rigid_registrars.NonRigidRegistrarXY
fixed_xy : ndarray, optional
(V, 2) array containing points in the fixed image that correspond
to those in the moving image
bk_dxdy : ndarray
(2, N, M) numpy array of pixel displacements in
the x and y directions from the reference image.
dx = bk_dxdy[0], and dy=bk_dxdy[1].
Used to warp images
fwd_dxdy : ndarray
Inversion of bk_dxdy. dx = fwd_dxdy[0], and dy=fwd_dxdy[1].
Used to warp points
warped_grid : ndarray
Image showing deformation applied to a regular grid.
"""
[docs]
def __init__(self, reg_obj, image, name, stack_idx, moving_xy=None, fixed_xy=None, mask=None):
"""
Parameters
----------
image : ndarray
Original, unwarped image with shape (P, Q)
name : str
Name of image.
stack_idx : int
Position of image in the stack
moving_xy : ndarray, optional
(V, 2) array containing points in the moving image that correspond
to those in the fixed image. If these are provided, non_rigid_reg_class
should be a subclass of non_rigid_registrars.NonRigidRegistrarXY
fixed_xy : ndarray, optional
(V, 2) array containing points in the fixed image that correspond
to those in the moving image
mask : ndarray, optional
Mask covering area to be registered.
"""
self.reg_obj = reg_obj
self.image = image
self.name = name
self.stack_idx = stack_idx
self.moving_xy = moving_xy
self.fixed_xy = fixed_xy
self.registered_img = None
self.warped_grid = None
self.bk_dxdy = None
self.fwd_dxdy = None
self.is_vips = isinstance(image, pyvips.Image)
self.shape = self.get_shape(image)
mask_shape = self.get_shape(mask)
if self.is_vips and not self.check_if_vips(mask):
mask = warp_tools.numpy2vips(mask)
if np.all(mask_shape == self.shape):
mask = warp_tools.resize_img(mask, self.shape)
self.mask = mask
def get_shape(self, img):
if isinstance(img, pyvips.Image):
shape = np.array([img.height, img.width])
else:
shape = img.shape[0:2]
return shape
def check_if_vips(self, img):
return isinstance(img, pyvips.Image)
def mask_img(self, img, mask):
if isinstance(img, pyvips.Image):
if isinstance(mask, np.ndarray):
vips_mask = warp_tools.numpy2vips(mask)
else:
vips_mask = mask
masked_img = (vips_mask == 0).ifthenelse(0, img)
else:
masked_img = img.copy()
masked_img[mask == 0] = 0
return masked_img
def mask_dxdy(self, dxdy, mask):
if isinstance(dxdy, pyvips.Image):
masked_dxdy = self.mask_img(dxdy, mask)
else:
masked_dxdy = [self.mask_img(dxdy[0], mask), self.mask_img(dxdy[1], mask)]
return masked_dxdy
def split_params(self, params, non_rigid_reg_class):
if params is not None:
init_arg_list = inspect.getfullargspec(non_rigid_reg_class.__init__).args
reg_arg_list = inspect.getfullargspec(non_rigid_reg_class.register).args
init_kwargs = {k:v for k, v in params.items() if k in init_arg_list}
reg_kwargs = {k:v for k, v in params.items() if k in reg_arg_list}
else:
init_kwargs = {}
reg_kwargs = {}
return init_kwargs, reg_kwargs
def calc_deformation(self, registered_fixed_image, non_rigid_reg_class,
bk_dxdy=None, params=None, mask=None):
"""
Finds the non-rigid deformation fields that align this ("moving") image
to the "fixed" image
Parameters
----------
registered_fixed_image : ndarray
Adjacent, aligned image in the stack that this image is being
aligned to. Has shape (P, Q)
non_rigid_reg_class : NonRigidRegistrar
Uninstantiated NonRigidRegistrar class that will be used to
calculate the deformation fields between images
bk_dxdy : ndarray, optional
(2, P, Q) numpy array of pixel displacements in
the x and y directions. dx = dxdy[0], and dy=dxdy[1].
Used to warp the registered_img before finding deformation fields.
params : dictionary, optional
Keyword: value dictionary of parameters to be used in reigstration.
Passed to the non_rigid_reg_class' init() method.
In the case where simple ITK will be used, params should be
a SimpleITK.ParameterMap. Note that numeric values needd to be
converted to strings.
mask : ndarray, optional
2D array with shape (P,Q) where non-zero pixel values are foreground,
and 0 is background, which is ignnored during registration. If None,
then all non-zero pixels in images will be used to create the mask.
"""
if self.reg_obj.from_rigid_reg:
rigid_img_obj = self.reg_obj.src.img_obj_dict[self.name]
M = rigid_img_obj.M
unwarped_shape = rigid_img_obj.image.shape[0:2]
og_reg_shape_rc = rigid_img_obj.registered_shape_rc
if mask is not None:
if isinstance(mask, pyvips.Image):
reg_mask = warp_tools.vips2numpy(mask)
else:
reg_mask = mask.copy()
else:
reg_mask = None
if bk_dxdy is not None:
if isinstance(bk_dxdy, list):
bk_dxdy = np.array(bk_dxdy)
if reg_mask is not None:
for_reg_dxdy = self.mask_dxdy(bk_dxdy, reg_mask)
else:
for_reg_dxdy = bk_dxdy
if self.reg_obj.from_rigid_reg:
for_reg_dxdy = warp_tools.remove_invasive_displacements(for_reg_dxdy,
M=M,
src_shape_rc=unwarped_shape,
out_shape_rc=og_reg_shape_rc
)
moving_img = warp_tools.warp_img(self.image, bk_dxdy=for_reg_dxdy)
if reg_mask is not None:
# Update mask too
reg_mask = warp_tools.warp_img(reg_mask, bk_dxdy=for_reg_dxdy)
else:
moving_img = self.image.copy()
for_reg_dxdy = None
if self.is_vips:
bk_dxdy = pyvips.Image.black(self.shape[1], self.shape[0], bands=2)
else:
bk_dxdy = np.array([np.zeros(self.shape[0:2]), np.zeros(self.shape[0:2])])
init_kwargs, reg_kwargs = self.split_params(params, non_rigid_reg_class)
non_rigid_reg = non_rigid_reg_class(params=init_kwargs)
if self.moving_xy is not None and self.fixed_xy is not None and \
issubclass(non_rigid_reg_class, non_rigid_registrars.NonRigidRegistrarXY):
if for_reg_dxdy is not None:
# Update positions #
fwd_dxdy = warp_tools.get_inverse_field(for_reg_dxdy)
fixed_xy = warp_tools.warp_xy(self.fixed_xy, M=None, fwd_dxdy=fwd_dxdy)
moving_xy = warp_tools.warp_xy(self.moving_xy, M=None, fwd_dxdy=fwd_dxdy)
else:
fixed_xy = self.fixed_xy
moving_xy = self.moving_xy
else:
fixed_xy = None
moving_xy = None
xy_args = {"moving_xy": moving_xy, "fixed_xy": fixed_xy}
reg_kwargs.update(xy_args)
warped_moving, moving_grid_img, moving_bk_dxdy = \
non_rigid_reg.register(moving_img=moving_img,
fixed_img=registered_fixed_image,
mask=reg_mask,
**reg_kwargs)
if self.reg_obj.from_rigid_reg:
moving_bk_dxdy = warp_tools.remove_invasive_displacements(moving_bk_dxdy,
M=M,
src_shape_rc=unwarped_shape,
out_shape_rc=og_reg_shape_rc
)
if not self.check_if_vips(moving_bk_dxdy):
if reg_mask is not None:
# Only add new transformations
moving_bk_dxdy = self.mask_dxdy(moving_bk_dxdy, reg_mask)
bk_dxdy_from_ref = np.array([bk_dxdy[0] + moving_bk_dxdy[0],
bk_dxdy[1] + moving_bk_dxdy[1]])
else:
if reg_mask is not None:
moving_bk_dxdy = self.mask_dxdy(moving_bk_dxdy, reg_mask)
bk_dxdy_from_ref = bk_dxdy + moving_bk_dxdy
img_bk_dxdy = bk_dxdy_from_ref.copy()
if reg_mask is not None:
img_bk_dxdy = self.mask_dxdy(img_bk_dxdy, reg_mask)
if self.reg_obj.from_rigid_reg:
img_bk_dxdy = warp_tools.remove_invasive_displacements(img_bk_dxdy,
M=M,
src_shape_rc=unwarped_shape,
out_shape_rc=og_reg_shape_rc
)
self.bk_dxdy = img_bk_dxdy
if hasattr(non_rigid_reg, "fwd_dxdy"):
# Already calculated
self.fwd_dxdy = non_rigid_reg.fwd_dxdy
else:
self.fwd_dxdy = warp_tools.get_inverse_field(self.bk_dxdy)
if not self.is_vips:
# If dxdy is a pyvips.Image, it's likely the displacement is too large to draw
self.warped_grid = viz.color_displacement_grid(*self.bk_dxdy)
self.registered_img = warp_tools.warp_img(self.image,
bk_dxdy=self.bk_dxdy,
out_shape_rc=self.shape)
return bk_dxdy_from_ref
[docs]
class SerialNonRigidRegistrar(object):
"""Class that performs serial non-rigid registration, based on results SerialRigidRegistrar
A SerialNonRigidRegistrar finds the deformation fields that will non-rigidly align
a series of images, using the rigid registration parameters found by a
SerialRigidRegistrar object. There are two types of non-rigid registration
methods:
#. Images are aligned towards a reference image, which may or may not
be at the center of the stack. In this case, the image directly "above" the
reference image is aligned to the reference image, after which the image 2 steps
above the reference image is aligned to the 1st (now aligned) image above
the reference image, and so on. The process is similar when aligning images
"below" the reference image.
#. All images are aligned simultaneously, and so a reference image is not
# required. An example is the SimpleElastix groupwise registration.
Similar to SerialRigidRegistrar, SerialNonRigidRegistrar creates a list
and dictionary of NonRigidZImage objects each of which contains information
related to the non-rigid registration, including the original rigid
transformation matrices, and the calculated deformation fields.
Attributes
----------
name : str, optional
Optional name of this SerialNonRigidRegistrar
from_rigid_reg : bool
Whether or not the images are from a SerialRigidRegistrar
ref_image_name : str
Name of mage that is being treated as the "center" of the stack.
For example, this may be associated with an H+E image that is
the 2nd image in a stack of 7 images.
size : int
Number of images to align
shape : tuple of int
Shape of each image to register. Must be the same for all images
non_rigid_obj_dict : dict
Dictionary, where each key is the name of a NonRigidZImage, and
the value is the assocatiated NonRigidZImage
non_rigid_reg_params: dictionary
Dictionary containing parameters {name: value} to be used to initialize
the NonRigidRegistrar.
In the case where simple ITK is used by the, params should be
a SimpleITK.ParameterMap. Note that numeric values nedd to be
converted to strings.
mask : ndarray
Mask used in non-rigid alignments, with shape (P, Q).
mask_bbox_xywh : ndarray
Bounding box of `mask` (top left x, top left y, width, height)
summary : Dataframe
Pandas dataframe containing the median distance between matched
features before and after registration.
"""
[docs]
def __init__(self, src, reference_img_f=None, moving_to_fixed_xy=None,
mask=None, name=None, align_to_reference=False, compose_transforms=True):
"""
Parameters
----------
src : SerialRigidRegistrar, str, dict
A SerialRigidRegistrar object that was used to optimally
align a series of images.
If a string, it should indicating where the images
to be aligned are located. If src is a string, the images should be
named such that they are read in the correct order, i.e. each
starting with a number.
If a dictionary, it should contain the following key, value pairs:
"img_list" : list of images to register
"img_f_list" : list of filenames of each image
"name_list" : list of image names. If not provided, will come from file names
"mask_list" list of masks for each image
reference_img_f : str, optional
Filename of image that will be treated as the center of the stack.
If None, the index of the middle image will be returned.
moving_to_fixed_xy : dict of list, or bool
If `moving_to_fixed_xy` is a dict of list, then
Key = image name, value = list of matched keypoints between
each moving image and the fixed image.
Each element in the list contains 2 arrays:
#. Rigid registered xy in moving/current/from image
#. Rigid registered xy in fixed/next/to image
To deterime which pairs of images will be aligned, use
`get_alignment_indices`. Can use `get_imgs_from_dir`
to see the order inwhich the images will be read, which will correspond
to the indices retuned by `get_alignment_indices`.
If `src` is a SerialRigidRegistrar and `moving_to_fixed_xy` is
True, then the matching features in the SerialRigidRegistrar will
be used. If False, then matching features will not be used.
mask : ndarray, bool, optional
Mask used for all non-rigid alignments.
If an ndarray, it must have the same size as the other images.
If True, then the `overlap_mask` in the SerialRigidRegistrar
will be used.
If False or None, no mask will be used.
name : optional
Optional name for this SerialNonRigidRegistrar
align_to_reference : bool, optional
Whether or not images should be aligned to a reference image
specified by `reference_img_f`.
img_params : dict, optional
Dictionary of parameters to be used for each particular image.
Useful if images to be registered haven't been processed.
Will be passed to `non_rigid_reg_class` init and register functions.
key = file name, value= dictionary of keyword arguments and values
"""
self.src = src
if isinstance(src, serial_rigid.SerialRigidRegistrar):
self.from_rigid_reg = True
elif isinstance(src, str):
self.from_rigid_reg = False
elif isinstance(src, dict):
self.from_rigid_reg = False
else:
valtils.print_warning(f"src must be either a SerialRigidRegistrar, string, or dictionary")
return None
self.name = name
self.size = 0
self.shape = None
self.non_rigid_obj_dict = {}
self.non_rigid_obj_list = None
self.non_rigid_reg_params = None
self.summary = None
self.mask = mask
self.reference_img_f = None
self.ref_img_name = None
self.ref_img_idx = None
self.compose_transforms = compose_transforms
self.align_to_reference = align_to_reference
self.generate_non_rigid_obj_list(reference_img_f, moving_to_fixed_xy)
if self.align_to_reference is False and reference_img_f is not None:
og_ref_name = valtils.get_name(reference_img_f)
msg = (f"The reference was specified as {og_ref_name} ",
f"but `align_to_reference` is `False`, and so images will be aligned serially. ",
f"If you would like all images to be directly aligned to {og_ref_name}, "
f"then set `align_to_reference` to `True`")
valtils.print_warning(msg)
def get_shape(self, img):
if isinstance(img, pyvips.Image):
shape = np.array([img.height, img.width])
else:
shape = img.shape[0:2]
return shape
def create_mask(self):
temp_mask = np.zeros(self.shape, dtype=np.uint8)
for nr_img_obj in self.non_rigid_obj_list:
temp_mask[nr_img_obj.image > 0] = 255
mask = warp_tools.bbox2mask(*warp_tools.xy2bbox(
warp_tools.mask2xy(temp_mask)),
temp_mask.shape)
return mask
def set_mask(self, mask):
"""Set mask and get its bounding box
"""
if mask is not None:
if isinstance(mask, bool) and self.from_rigid_reg:
mask = self.src.overlap_mask
mask = np.clip(mask.astype(int)*255, 0, 255).astype(np.uint8)
else:
mask = self.create_mask()
mask_bbox_xywh = warp_tools.xy2bbox(warp_tools.mask2xy(mask))
self.mask = mask
self.mask_bbox_xywh = mask_bbox_xywh
def generate_non_rigid_obj_list(self, reference_img_f=None, moving_to_fixed_xy=None):
"""Create non_rigid_obj_list
"""
if self.from_rigid_reg:
img_list, img_f_list, img_names, mask_list = \
get_imgs_rigid_reg(self.src)
else:
if isinstance(self.src, str):
img_list, img_f_list, img_names, mask_list = \
get_imgs_from_dir(self.src)
# overwrite `src` because all info now in NonRigidZImages
self.src = "dictionary"
elif isinstance(self.src, dict):
img_list, img_f_list, img_names, mask_list = \
get_imgs_from_dict(self.src)
self.size = len(img_list)
self.shape = self.get_shape(img_list[0])
if reference_img_f is not None:
reference_name = valtils.get_name(reference_img_f)
else:
reference_name = None
ref_img_idx = warp_tools.get_ref_img_idx(img_f_list, reference_name)
if reference_img_f is None:
reference_img_f = img_f_list[ref_img_idx]
self.reference_img_f = reference_img_f
self.ref_img_idx = ref_img_idx
self.ref_img_name = reference_name
if self.from_rigid_reg and isinstance(moving_to_fixed_xy, bool):
if moving_to_fixed_xy:
moving_to_fixed_xy = \
get_matching_xy_from_rigid_registrar(self.src, reference_name)
else:
moving_to_fixed_xy = None
self.non_rigid_obj_list = [None] * self.size
for i, img in enumerate(img_list):
img_shape = self.get_shape(img)
assert np.all(img_shape == self.shape), \
valtils.print_warning("Images must all have the shape")
img_name = img_names[i]
mask = mask_list[i]
moving_xy = None
fixed_xy = None
if moving_to_fixed_xy is not None and img_name != reference_img_f:
if isinstance(moving_to_fixed_xy, dict):
xy_coords = moving_to_fixed_xy[img_name]
moving_xy = xy_coords[0]
fixed_xy = xy_coords[1]
else:
msg = "moving_to_fixed_xy is not a dictionary. Will be ignored"
valtils.print_warning(msg)
nr_obj = NonRigidZImage(self, img, img_name, stack_idx=i,
moving_xy=moving_xy,
fixed_xy=fixed_xy,
mask=mask)
if i == ref_img_idx:
# Set reference image attributes #
zero_displacement = np.zeros(self.shape)
if not nr_obj.is_vips:
nr_obj.bk_dxdy = [zero_displacement, zero_displacement]
nr_obj.fwd_dxdy = [zero_displacement, zero_displacement]
nr_obj.warped_grid = viz.color_displacement_grid(*nr_obj.bk_dxdy)
else:
nr_obj.bk_dxdy = pyvips.Image.black(nr_obj.shape[1], nr_obj.shape[0], bands=2)
nr_obj.fwd_dxdy = pyvips.Image.black(nr_obj.shape[1], nr_obj.shape[0], bands=2)
nr_obj.registered_img = img.copy()
self.non_rigid_obj_list[i] = nr_obj
def update_img_params(self, non_rigid_reg_params=None, img_params=None, moving_name=None, fixed_name=None, is_tiler=False):
"""
Update img params for non-rigid-registration
"""
if img_params is not None and moving_name is not None:
if len(img_params) == 0:
indv_img_params = None
else:
indv_img_params = img_params[moving_name]
else:
indv_img_params = img_params
if is_tiler:
#Tiler needs processor arguments for moving and fixed images
assert moving_name in img_params and fixed_name in img_params, "Tiled registration requires image processors for each image"
moving_dict = img_params[moving_name]
indv_img_params[non_rigid_registrars.NR_TILE_MOVING_P_KEY] = moving_dict[non_rigid_registrars.NR_PROCESSING_CLASS_KEY]
indv_img_params[non_rigid_registrars.NR_TILE_MOVING_P_INIT_KW_KEY] = moving_dict[non_rigid_registrars.NR_PROCESSING_INIT_KW_KEY]
indv_img_params[non_rigid_registrars.NR_TILE_MOVING_P_KW_KEY] = moving_dict[non_rigid_registrars.NR_PROCESSING_KW_KEY]
fixed_dict = img_params[fixed_name]
indv_img_params[non_rigid_registrars.NR_TILE_FIXED_P_KEY] = fixed_dict[non_rigid_registrars.NR_PROCESSING_CLASS_KEY]
indv_img_params[non_rigid_registrars.NR_TILE_FIXED_P_INIT_KW_KEY] = fixed_dict[non_rigid_registrars.NR_PROCESSING_INIT_KW_KEY]
indv_img_params[non_rigid_registrars.NR_TILE_FIXED_P_KW_KEY] = fixed_dict[non_rigid_registrars.NR_PROCESSING_KW_KEY]
if non_rigid_reg_params is not None and indv_img_params is not None:
updated_params = indv_img_params.copy()
updated_params[non_rigid_registrars.NR_PARAMS_KEY] = non_rigid_reg_params
elif non_rigid_reg_params is not None and indv_img_params is None:
updated_params = non_rigid_reg_params
elif non_rigid_reg_params is None and indv_img_params is not None:
updated_params = indv_img_params
else:
updated_params = None
return updated_params
def register_serial(self, non_rigid_reg_class, non_rigid_reg_params=None, img_params=None):
"""Non-rigidly align images in serial
Parameters
----------
non_rigid_reg_class : NonRigidRegistrar
Uninstantiated NonRigidRegistrar class that will be used to
calculate the deformation fields between images
non_rigid_reg_params: dictionary, optional
Dictionary containing parameters {name: value} to be used to initialize
`non_rigid_reg_class`.
In the case where simple ITK is used by the, params should be
a SimpleITK.ParameterMap. Note that numeric values nedd to be
converted to strings.
"""
current_dxdy = None
self.non_rigid_reg_params = non_rigid_reg_params
iter_order = warp_tools.get_alignment_indices(self.size, self.ref_img_idx)
is_tiler = non_rigid_reg_class.__name__ == non_rigid_registrars.NonRigidTileRegistrar.__name__
for moving_idx, fixed_idx in tqdm(iter_order, desc="Finding non-rigid transforms", unit="image"):
moving_obj = self.non_rigid_obj_list[moving_idx]
fixed_obj = self.non_rigid_obj_list[fixed_idx]
if self.compose_transforms:
if fixed_obj.stack_idx == self.ref_img_idx:
current_dxdy = None
else:
current_dxdy = updated_dxdy
if moving_obj.mask is not None:
if self.mask is not None:
reg_mask = preprocessing.combine_masks(self.mask, moving_obj.mask, op="and")
else:
reg_mask = moving_obj.mask
elif self.mask is not None:
reg_mask = self.mask
else:
reg_mask is None
nr_reg_params = self.update_img_params(non_rigid_reg_params, img_params, moving_name=moving_obj.name, fixed_name=fixed_obj.name, is_tiler=is_tiler)
updated_dxdy = moving_obj.calc_deformation(registered_fixed_image=fixed_obj.registered_img,
non_rigid_reg_class=non_rigid_reg_class,
bk_dxdy=current_dxdy,
params=nr_reg_params,
mask=reg_mask
)
def register_to_ref(self, non_rigid_reg_class, non_rigid_reg_params=None, img_params=None):
"""Non-rigidly align images to a reference image
Parameters
----------
non_rigid_reg_class : NonRigidRegistrar
Uninstantiated NonRigidRegistrar class that will be used to
calculate the deformation fields between images
non_rigid_reg_params: dictionary, optional
Dictionary containing parameters {name: value} to be used to initialize
the NonRigidRegistrar.
In the case where simple ITK is used by the, params should be
a SimpleITK.ParameterMap. Note that numeric values nedd to be
converted to strings.
"""
self.non_rigid_reg_params = non_rigid_reg_params
ref_nr_obj = self.non_rigid_obj_list[self.ref_img_idx]
ref_img = ref_nr_obj.image
is_tiler = non_rigid_reg_class.__name__ == non_rigid_registrars.NonRigidTileRegistrar.__name__
for moving_idx in tqdm(range(self.size), desc="Finding non-rigid transforms", unit="image"):
moving_obj = self.non_rigid_obj_list[moving_idx]
if moving_obj.stack_idx == self.ref_img_idx:
continue
overlap_mask = None
nr_reg_params = self.update_img_params(non_rigid_reg_params, img_params, moving_name=moving_obj.name, fixed_name=ref_nr_obj.name, is_tiler=is_tiler)
moving_obj.calc_deformation(ref_img,
non_rigid_reg_class,
params=nr_reg_params,
mask=overlap_mask)
def register_groupwise(self, non_rigid_reg_class, non_rigid_reg_params=None):
"""Non-rigidly align images as a group
Parameters
----------
non_rigid_reg_class : NonRigidRegistrarGroupwise
Uninstantiated NonRigidRegistrar class that will be used to
calculate the deformation fields between images
non_rigid_reg_params: dictionary, optional
Dictionary containing parameters {name: value} to be used to initialize
the NonRigidRegistrar.
In the case where simple ITK is used by the, params should be
a SimpleITK.ParameterMap. Note that numeric values nedd to be
converted to strings.
"""
img_list = [nr_img_obj.image for nr_img_obj in self.non_rigid_obj_list]
non_rigid_reg = non_rigid_reg_class(params=non_rigid_reg_params)
print("\n======== Registering images (non-rigid)\n")
warped_imgs, warped_grids, backward_deformations = non_rigid_reg.register(img_list, self.mask)
for i, nr_img_obj in tqdm(enumerate(self.non_rigid_obj_list), desc="Aligning images", unit="annotation"):
nr_img_obj.registered_img = warped_imgs[i]
nr_img_obj.bk_dxdy = backward_deformations[i]
nr_img_obj.warped_grid = viz.color_displacement_grid(*nr_img_obj.bk_dxdy)
nr_img_obj.fwd_dxdy = warp_tools.get_inverse_field(nr_img_obj.bk_dxdy)
[docs]
def register(self, non_rigid_reg_class, non_rigid_reg_params, img_params=None):
"""Non-rigidly align images, either as a group or serially
Images will be registered serially if `non_rigid_reg_class` is a
subclass of NonRigidRegistrarGroupwise, then groupwise registration
will be conductedd. If `non_rigid_reg_class` is a subclass of
NonRigidRegistrar then images will be aligned serially.
Parameters
----------
non_rigid_reg_class : NonRigidRegistrar, NonRigidRegistrarGroupwise
Uninstantiated NonRigidRegistrar or NonRigidRegistrarGroupwise class
that will be used to calculate the deformation fields between images
non_rigid_reg_params: dictionary, optional
Dictionary containing parameters {name: value} to be used to initialize
the NonRigidRegistrar.
In the case where simple ITK is used by the, params should be
a SimpleITK.ParameterMap. Note that numeric values nedd to be
converted to strings.
img_params : dict, optional
Dictionary of parameters to be used for each particular image.
Useful if images to be registered haven't been processed.
Will be passed to `non_rigid_reg_class` init and register functions.
key = file name, value= dictionary of keyword arguments and values
"""
if img_params is not None:
named_img_params = {valtils.get_name(k):v for k, v in img_params.items()}
else:
named_img_params = None
if issubclass(non_rigid_reg_class, non_rigid_registrars.NonRigidRegistrarGroupwise):
self.register_groupwise(non_rigid_reg_class, non_rigid_reg_params)
elif self.align_to_reference:
self.register_to_ref(non_rigid_reg_class, non_rigid_reg_params, img_params=named_img_params)
else:
self.register_serial(non_rigid_reg_class, non_rigid_reg_params, img_params=named_img_params)
self.non_rigid_obj_dict = {img_obj.name: img_obj for img_obj
in self.non_rigid_obj_list}
[docs]
def summarize(self):
"""Summarize alignment error
Returns
-------
summary_df: Dataframe
Pandas dataframe containin the registration error of the
alignment between each image and the previous one in the stack.
"""
src_img_names = [None] * self.size
dst_img_names = [None] * self.size
shape_list = [None] * self.size
og_med_d_list = [None] * self.size
og_tre_list = [None] * self.size
med_d_list = [None] * self.size
tre_list = [None] * self.size
src_img_names[self.ref_img_idx] = self.ref_img_name
shape_list[self.ref_img_idx] = self.non_rigid_obj_list[self.ref_img_idx].image.shape
iter_order = warp_tools.get_alignment_indices(self.size, self.ref_img_idx)
print("\n======== Summarizing registration\n")
for moving_idx, fixed_idx in tqdm(iter_order):
moving_obj = self.non_rigid_obj_list[moving_idx]
fixed_obj = self.non_rigid_obj_list[fixed_idx]
src_img_names[moving_idx] = moving_obj.name
dst_img_names[moving_idx] = fixed_obj.name
shape_list[moving_idx] = moving_obj.image.shape
og_tre_list[moving_idx], og_med_d_list[moving_idx] = \
warp_tools.measure_error(moving_obj.moving_xy,
moving_obj.fixed_xy,
moving_obj.image.shape)
warped_moving_xy = warp_tools.warp_xy(moving_obj.moving_xy,
M=None,
fwd_dxdy=moving_obj.fwd_dxdy)
warped_fixed_xy = warp_tools.warp_xy(moving_obj.fixed_xy,
M=None,
fwd_dxdy=moving_obj.fwd_dxdy)
tre_list[moving_idx], med_d_list[moving_idx] = \
warp_tools.measure_error(warped_moving_xy,
warped_fixed_xy,
moving_obj.image.shape)
summary_df = pd.DataFrame({
"from": src_img_names,
"to": dst_img_names,
"original_D": og_med_d_list,
"D": med_d_list,
"original_TRE": og_tre_list,
"TRE": tre_list,
"shape": shape_list,
})
to_summarize_idx = [i for i in range(self.size) if i != self.ref_img_idx]
summary_df["series_d"] = warp_tools.calc_total_error(np.array(med_d_list)[to_summarize_idx])
summary_df["series_tre"] = warp_tools.calc_total_error(np.array(tre_list)[to_summarize_idx])
summary_df["name"] = self.name
self.summary_df = summary_df
return summary_df
[docs]
def register_images(src, non_rigid_reg_class=non_rigid_registrars.OpticalFlowWarper,
non_rigid_reg_params=None, dst_dir=None,
reference_img_f=None, moving_to_fixed_xy=None,
mask=None, name=None, align_to_reference=False,
img_params=None, compose_transforms=True, qt_emitter=None):
"""
Parameters
----------
src : SerialRigidRegistrar, str
Either a SerialRigidRegistrar object that was used to optimally
align a series of images, or a string indicating where the images
to be aligned are located. If src is a string, the images should be
named such that they are read in the correct order, i.e. each
starting with a number.
non_rigid_reg_class : NonRigidRegistrar
Uninstantiated NonRigidRegistrar class that will be used to
calculate the deformation fields between images.
By default this is an OpticalFlowWarper that uses the OpenCV
implementation of DeepFlow.
non_rigid_reg_params: dictionary, optional
Dictionary containing parameters {name: value} to be used to initialize
the NonRigidRegistrar.
In the case where simple ITK is used by the, params should be
a SimpleITK.ParameterMap. Note that numeric values nedd to be
converted to strings.
dst_dir : str, optional
Top directory where aliged images should be save. SerialNonRigidRegistrar will
be in this folder, and aligned images in the "registered_images"
sub-directory. If None, the images will not be written to file
reference_img_f : str, optional
Filename of image that will be treated as the center of the stack.
If None, the index of the middle image will be returned.
moving_to_fixed_xy : dict of list, or bool
If `moving_to_fixed_xy` is a dict of list, then
Key = image name, value = list of matched keypoints between
each moving image and the fixed image.
Each element in the list contains 2 arrays:
#. Rigid registered xy in moving/current/from image
#. Rigid registered xy in fixed/next/to image
To deterime which pairs of images will be aligned, use
`warp_tools.get_alignment_indices`. Can use `get_imgs_from_dir`
to see the order inwhich the images will be read, which will correspond
to the indices retuned by `warp_tools.get_alignment_indices`.
If `src` is a SerialRigidRegistrar and `moving_to_fixed_xy` is
True, then the matching features in the SerialRigidRegistrar will
be used. If False, then matching features will not be used.
mask : ndarray, bool, optional
Mask used in non-rigid alignments.
If an ndarray, it must have the same size as the other images.
If True, then the `overlap_mask` in the SerialRigidRegistrar
will be used.
If False or None, no mask will be used.
name : optional
Optional name for this SerialNonRigidRegistrar
align_to_reference : bool, optional
Whether or not images should be aligne to a reference image
specified by `reference_img_f`. Will be set to True if
`reference_img_f` is provided.
img_params : dict, optional
Dictionary of parameters to be used for each particular image.
Useful if images to be registered haven't been processed.
Will be passed to `non_rigid_reg_class` init and register functions.
key = file name, value= dictionary of keyword arguments and values
qt_emitter : PySide2.QtCore.Signal, optional
Used to emit signals that update the GUI's progress bars
Returns
-------
nr_reg : SerialNonRigidRegistrar
SerialNonRigidRegistrar that has registeredt the images in `src`
"""
tic = time()
nr_reg = SerialNonRigidRegistrar(src=src, reference_img_f=reference_img_f,
moving_to_fixed_xy=moving_to_fixed_xy,
mask=mask, name=name,
align_to_reference=align_to_reference,
compose_transforms=compose_transforms)
nr_reg.register(non_rigid_reg_class, non_rigid_reg_params, img_params=img_params)
if dst_dir is not None:
registered_img_dir = os.path.join(dst_dir, "non_rigid_registered_images")
registered_data_dir = os.path.join(dst_dir, "data")
registered_grids_dir = os.path.join(dst_dir, "deformation_grids")
for d in [registered_img_dir, registered_data_dir, registered_grids_dir]:
pathlib.Path(d).mkdir(exist_ok=True, parents=True)
print("\n======== Saving results\n")
if moving_to_fixed_xy is not None:
summary_df = nr_reg.summarize()
summary_file = os.path.join(registered_data_dir, name + "_results.csv")
summary_df.to_csv(summary_file, index=False)
pickle_file = os.path.join(registered_data_dir, name + "_non_rigid_registrar.pickle")
pickle.dump(nr_reg, open(pickle_file, 'wb'))
for img_obj in nr_reg.non_rigid_obj_list:
f_out = f"{img_obj.name}.png"
io.imsave(os.path.join(registered_img_dir, f_out),
img_obj.registered_img.astype(np.uint8))
colord_tri_grid = viz.color_displacement_tri_grid(img_obj.bk_dxdy[0],
img_obj.bk_dxdy[1])
io.imsave(os.path.join(registered_grids_dir, f_out), colord_tri_grid)
toc = time()
elapsed = toc - tic
time_string, time_units = valtils.get_elapsed_time_string(elapsed)
print(f"\n======== Non-rigid registration complete in {time_string} {time_units}\n")
return nr_reg