# Source code for valis.serial_rigid

"""Classes and functions to perform serial rigid registration of a set of images

"""
import numpy as np
import os
import pickle
from fastcluster import linkage
from scipy.spatial.distance import squareform
from scipy.cluster.hierarchy import optimal_leaf_ordering, leaves_list
from skimage import transform, io
from skimage.transform import EuclideanTransform
import pandas as pd
import warnings
from tqdm import tqdm
import pathlib
import multiprocessing
from time import time
from pqdm.threads import pqdm

from . import valtils
from . import warp_tools
from . import slide_tools
from .feature_detectors import VggFD
from .feature_matcher import Matcher, convert_distance_to_similarity, GMS_NAME

# Status/progress messages displayed while registration runs.
DENOISE_MSG = "Denoising images"
FEATURE_MSG = "Detecting features"
MATCHING_MSG = "Matching images"
TRANSFORM_MSG = "Finding transforms"
OPTIMIZING_MSG = "Optimizing transforms"
FINALIZING_MSG = "Finalizing"

# The messages are passed through `valtils.pad_strings` (presumably padding
# them to a common width so progress bars line up) and rebound in the same
# order in which they are listed here.
msg_list = [
    DENOISE_MSG,
    FEATURE_MSG,
    MATCHING_MSG,
    TRANSFORM_MSG,
    FINALIZING_MSG,
    OPTIMIZING_MSG,
]
(
    DENOISE_MSG,
    FEATURE_MSG,
    MATCHING_MSG,
    TRANSFORM_MSG,
    FINALIZING_MSG,
    OPTIMIZING_MSG,
) = valtils.pad_strings(msg_list)

def get_image_files(img_dir, imgs_ordered=False):
    """List the image files found in `img_dir`.

    Only directory entries for which ``slide_tools.get_img_type`` returns
    something other than None are kept.

    Parameters
    ----------
    img_dir : str
        Path to directory containing the images.

    imgs_ordered : bool, optional
        Whether or not the order of images is already known. If True, the file
        names should start with ascending numbers, with the first image file
        having the smallest number, and the last image file having the largest
        number. If False (the default), the order of images will be determined
        by ordering a distance matrix.

    Returns
    -------
    img_list : list of str
        File names (not full paths) of the images in `img_dir`. If
        `imgs_ordered` is True, the list is sorted in place with
        ``valtils.sort_nicely`` (presumably a numeric-aware sort); otherwise
        it is sorted lexicographically.

    """

    def _is_image(fname):
        # Entry is an image if slide_tools recognises its type
        return slide_tools.get_img_type(os.path.join(img_dir, fname)) is not None

    img_list = list(filter(_is_image, os.listdir(img_dir)))

    if not imgs_ordered:
        img_list.sort()
    else:
        valtils.sort_nicely(img_list)

    return img_list


def get_max_image_dimensions(img_list):
    """Find the largest extent of all images along their first two axes.

    Parameters
    ----------
    img_list : list
        List of images (array-like objects whose ``shape`` has at least two
        dimensions).

    Returns
    -------
    max_wh : tuple
        ``(max(shape[0]), max(shape[1]))`` over all images. Despite the
        historical name, the values follow numpy's shape order, i.e.
        (number of rows, number of columns).

    Raises
    ------
    ValueError
        If `img_list` contains no images.

    """

    shapes = [img.shape[0:2] for img in img_list]
    if not shapes:
        raise ValueError("img_list must contain at least one image")

    # Transpose the (rows, cols) pairs so each axis can be reduced separately
    all_rows, all_cols = zip(*shapes)
    max_wh = (max(all_rows), max(all_cols))

    return max_wh


def order_Dmat(D):
    """Cluster a distance matrix and reorder it so similar samples are adjacent.

    Hierarchical clustering uses single linkage, after which the leaves of the
    dendrogram are reordered using optimal leaf ordering (Bar-Joseph 2001).

    Parameters
    ----------
    D : ndarray
        (N, N) symmetric distance matrix for N samples. A copy is used, so the
        caller's array is not modified.

    Returns
    -------
    sorted_D : ndarray
        (N, N) distance matrix with rows and columns permuted into the
        optimal leaf order.

    ordered_leaves : ndarray
        Leaf order of the dendrogram, as returned by ``leaves_list``; this is
        the permutation applied to the rows and columns of `D`.

    optimal_Z : ndarray
        Linkage matrix after optimal leaf ordering.

    """

    dist_mat = D.copy()
    condensed_dist = squareform(dist_mat)

    Z = linkage(condensed_dist, 'single', preserve_input=True)
    optimal_Z = optimal_leaf_ordering(Z, condensed_dist)
    ordered_leaves = leaves_list(optimal_Z)

    # np.ix_ permutes rows and columns in a single indexing step
    sorted_D = dist_mat[np.ix_(ordered_leaves, ordered_leaves)]

    return sorted_D, ordered_leaves, optimal_Z


class ZImage(object):
    """Class that stores info about an image, including the rigid registration parameters

    Attributes
    ----------
    image : ndarray
        Greyscale image that will be used for feature detection. These images
        should be greyscale and may need to have undergone preprocessing to
        make them look as similar as possible.

    full_img_f : str
        Full path to the image.

    id : int
        ID of the image, based on its ordering in the image source directory.

    name : str
        Name of the image. Usually `img_f` but with the extension removed.

    desc : ndarray
        (N, M) array of N descriptors, one for each keypoint, each of which
        has M features.

    kp_pos_xy : ndarray
        (N, 2) array of positions for each keypoint.

    match_dict : dict
        Dictionary of image matches. Key = img_obj this ZImage is being
        compared to, value = MatchInfo containing information about the
        comparison, such as the position of matches, features for each match,
        number of matches, etc... The MatchInfo objects in this dictionary
        contain only the info for matches that were considered "good".

    unfiltered_match_dict : dict
        Dictionary of image matches. Key = img_obj this ZImage is being
        compared to, value = MatchInfo containing information about the
        comparison, such as the position of matches, features for each match,
        number of matches, etc... The MatchInfo objects in this dictionary
        contain info for all matches that were cross-checked.

    stack_idx : int
        Position of image in sorted Z-stack.

    fixed_obj : ZImage
        ZImage to which this ZImage was aligned, i.e. this is the "moving"
        image, and `fixed_obj` is the "fixed" image. This is set during the
        `align_to_prev` method of the SerialRigidRegistrar. The `fixed_obj`
        will either be immediately above or immediately below this ZImage in
        the image stack.

    reflection_M : ndarray
        Transformation to reflect the image in the x and/or y axis, before
        padding. Will be the first transformation performed.

    T : ndarray
        Transformation matrix that translates the image such that it is in a
        padded image that has the same shape as all other images.

    to_prev_A : ndarray
        Transformation matrix that warps image to align with the previous image.

    optimal_M : ndarray
        Transformation matrix found by minimizing a cost function. Used as
        final optional step to refine alignment.

    crop_T : ndarray
        Transformation matrix used to crop image after registration.

    M : ndarray
        Final transformation matrix that aligns image in the Z-stack.

    M_inv : ndarray
        Inverse of final transformation matrix that aligns image in the Z-stack.

    registered_img : ndarray
        Image after being warped.

    padded_shape_rc : tuple
        Shape of padded image. All other images will have this shape.

    registered_shape_rc : tuple
        Shape of aligned image. All other aligned images will have this shape.

    """

    def __init__(self, image, img_f, img_id, name):
        """Class that stores information about an image

        Parameters
        ----------
        image : ndarray
            Greyscale image that will be used for feature detection. These
            images should be single channel uint8 images, and may need to have
            undergone preprocessing and/or normalization to make them look as
            similar as possible.

        img_f : str
            Full path to `image`.

        img_id : int
            ID of the image, based on its ordering in the image source directory.

        name : str
            Name of the image. Usually img_f but with the extension removed.

        """

        self.image = image
        self.full_img_f = img_f
        self.id = img_id
        self.name = name

        # Features and matches, filled in during detection/matching
        self.desc = None
        self.kp_pos_xy = None
        self.match_dict = {}
        self.unfiltered_match_dict = {}

        # Position in the sorted stack and the image this one is aligned to
        self.stack_idx = None
        self.fixed_obj = None

        # Each transform starts as its own identity matrix (no shared arrays)
        self.reflection_M = np.identity(3)
        self.T = np.identity(3)
        self.to_prev_A = np.identity(3)
        self.optimal_M = np.identity(3)
        self.crop_T = np.identity(3)
        self.M = np.identity(3)
        self.M_inv = np.identity(3)

        self.registered_img = None
        self.padded_shape_rc = None
        self.registered_shape_rc = None

    def reduce(self, prev_img_obj, next_img_obj):
        """Reduce amount of info stored, which can take up a lot of space.

        The descriptors are discarded, and every entry of `match_dict` whose
        key is not one of the given neighbors has its value set to None (the
        keys themselves are kept). `unfiltered_match_dict` is left untouched.

        Parameters
        ----------
        prev_img_obj : ZImage or None
            ZImage below this ZImage, or None if this is the first image.

        next_img_obj : ZImage or None
            ZImage above this ZImage, or None if this is the last image.

        """

        self.desc = None

        # Missing neighbors (first/last image in the stack) are simply skipped
        neighbors = [obj for obj in (prev_img_obj, next_img_obj) if obj is not None]
        for img_obj in self.match_dict:
            if all(img_obj != nbr for nbr in neighbors):
                self.match_dict[img_obj] = None
class SerialRigidRegistrar(object):
    """Class that performs serial rigid registration

    Registration is conducted by first detecting features in all images.
    Features are then matched between images, which are then used to construct
    a distance matrix, D. D is then sorted such that the most similar images
    are adjacent to one another. The rigid transformation matrices are then
    found to align each image with the previous image. Optionally, optimization
    can be performed to improve the alignments, although the "optimized" matrix
    will be discarded if it increases the distances between matched features.

    SerialRigidRegistrar creates a list and dictionary of ZImage objects, each
    of which contains information related to feature matching and the rigid
    registration matrices.

    Attributes
    ----------
    img_dir : str
        Path to directory containing the images that will be registered. The
        images in this folder should be single channel uint8 images. For the
        best registration results, they should have undergone some sort of
        pre-processing and normalization. The preprocessing module contains
        methods for this, but the user may want/need to use other methods.

    aleady_sorted : bool, optional
        Whether or not the order of images is already known. If True, the file
        names should start with ascending numbers, with the first image file
        having the smallest number, and the last image file having the largest
        number. If False (the default), the order of images will be determined
        by ordering a distance matrix.

    name : str
        Descriptive name of registrar, such as the sample's name.

    img_file_list : list
        List of full paths to single channel uint8 images.

    size : int
        Number of images to align.

    distance_metric_name : str
        Name of distance metric used to determine the dis/similarity between
        each pair of images.

    distance_metric_type : str
        Name of the type of metric used to determine the dis/similarity
        between each pair of images. Despite the name, it could be "similarity"
        if the Matcher object compares image features using a similarity
        metric. In that case, similarities are converted to distances.

    img_obj_list : list
        List of ZImage objects. Initially unordered, but eventually sorted.

    img_obj_dict : dict
        Dictionary of ZImage objects. Created to conveniently access ZImages.
        Key = ZImage.name, value = ZImage.

    optimal_Z : ndarray
        Ordered linkage matrix for `distance_mat`.

    unsorted_distance_mat : ndarray
        Distance matrix with shape (N, N), where each element is the
        dissimilarity between each pair of the N images. The order of rows and
        columns reflects the order in which the images were read. This matrix
        is used to order the images in the Z-stack.

    distance_mat : ndarray
        `unsorted_distance_mat` reordered such that the most similar images
        are adjacent to one another.

    unsorted_similarity_mat : ndarray
        Similar to `unsorted_distance_mat`, except the elements are image
        similarity.

    similarity_mat : ndarray
        Similar to `distance_mat`, except the elements are image similarity.

    features : str
        Name of feature detector and descriptor used.

    transform_type : str
        Name of scikit-image transformer class that was used.

    reference_img_f : str
        Filename of image that will be treated as the center of the stack.

    reference_img_idx : int
        Index of ZImage that corresponds to `reference_img_f`, after the
        `img_obj_list` has been sorted.

    align_to_reference : bool, optional
        Whether or not images should be aligned directly to the reference
        image specified by `reference_img_f`.

    iter_order : list of tuples
        Each element of `iter_order` contains a tuple of stack indices. The
        first value is the index of the moving/current/from image, while the
        second value is the index of the fixed/next/to image.

    summary_df : Dataframe
        Pandas dataframe containing the registration error of the alignment
        between each image and the previous one in the stack.

    """
[docs] def __init__(self, img_dir, imgs_ordered=False, reference_img_f=None, name=None, align_to_reference=False): """Class that performs serial rigid registration Parameters ---------- img_dir : str Path to directory containing the images that will be registered. The images in this folder should be single channel uint8 images. For the best registration results, they have undergone some sort of pre-processing and normalization. The preprocessing module contains methods for this, but the user may want/need to use other methods. imgs_ordered : bool Whether or not the order of images already known. If True, the file names should start with ascending numbers, with the first image file having the smallest number, and the last image file having the largest number. If False (the default), the order of images will be determined by sorting a distance matrix. reference_img_f : str, optional Filename of image that will be treated as the center of the stack. If None, the index of the middle image will be the reference. name : str, optional Descriptive name of registrar, such as the sample's name align_to_reference : bool, optional Whether or not images should be aligne to a reference image specified by `reference_img_f`. Will be set to True if `reference_img_f` is provided. 
""" self.img_dir = img_dir self.aleady_sorted = imgs_ordered self.name = name self.img_file_list = get_image_files(img_dir, imgs_ordered=imgs_ordered) self.size = len(self.img_file_list) self.distance_metric_name = None self.distance_metric_type = None self.img_obj_list = None self.img_obj_dict = {} self.optimal_z = None self.unsorted_distance_mat = None self.distance_mat = None self.unsorted_similarity_mat = None self.similarity_mat = None self.features = None self.transform_type = None self.reference_img_f = reference_img_f self.reference_img_idx = 0 self.align_to_reference = align_to_reference self.iter_order = None self.summary = None if self.align_to_reference is False and reference_img_f is not None: og_ref_name = valtils.get_name(reference_img_f) msg = (f"The reference was specified as {og_ref_name} ", f"but `align_to_reference` is `False`, and so images will be aligned serially *towards* the reference image. ", f"If you would like all images to be *directly* aligned to {og_ref_name}, " f"then set `align_to_reference` to `True`. Note that in both cases, {og_ref_name} will remain unwarped.") valtils.print_warning(msg)
def generate_img_obj_list(self, feature_detector, qt_emitter=None):
    """Create a list of ZImage objects

    Create a list of ZImage objects, each of which represents an image.
    This function also determines a padded (square) canvas size large
    enough that an image rotated by any angle is not cropped during
    warping. Finally, the features of each image are detected using the
    feature_detector.

    Parameters
    ----------
    feature_detector : FeatureDD
        FeatureDD object that detects and computes image features.

    qt_emitter : PySide2.QtCore.Signal, optional
        Used to emit signals that update the GUI's progress bars

    """

    # NOTE tried parallelizing with joblib, but it's actually slower
    sorted_img_list = [io.imread(os.path.join(self.img_dir, f), True) for f in self.img_file_list]
    out_w, out_h = get_max_image_dimensions(sorted_img_list)

    # For any rotation angle theta the rotated bounding box has extent
    # w*|cos(theta)| + h*|sin(theta)| <= hypot(w, h) (Cauchy-Schwarz), and
    # hypot(w, h) >= max(w, h). A square canvas of side hypot(w, h)
    # therefore never crops a rotated image. np.cos/np.sin take radians,
    # so passing the literal 45 did not compute a 45-degree rotation.
    max_dist = int(np.ceil(np.hypot(out_w, out_h)))
    out_shape = (max_dist, max_dist)

    img_obj_list = [None] * self.size
    for i in tqdm(range(self.size), desc=FEATURE_MSG, unit="image", leave=None):
        img_f = self.img_file_list[i]
        img = sorted_img_list[i]
        img_name = valtils.get_name(img_f)

        img_obj = ZImage(img, os.path.join(self.img_dir, img_f), i, name=img_name)
        img_obj.padded_shape_rc = out_shape
        img_obj.T = warp_tools.get_padding_matrix(img.shape, img_obj.padded_shape_rc)
        img_obj.kp_pos_xy, img_obj.desc = feature_detector.detect_and_compute(img)

        img_obj_list[i] = img_obj
        self.img_obj_dict[img_name] = img_obj
        if qt_emitter is not None:
            qt_emitter.emit(1)

    self.img_obj_list = img_obj_list
    self.features = feature_detector.__class__.__name__


def _match_img_pair(self, matcher_obj, img_obj_1, img_obj_2, store_unfiltered):
    """Match features between two ZImages and record the results in both match dicts.

    Parameters
    ----------
    matcher_obj : Matcher
        Object to match features between images.

    img_obj_1, img_obj_2 : ZImage
        Images whose features will be matched.

    store_unfiltered : bool
        If True, unfiltered matches are also stored in each ZImage's
        `unfiltered_match_dict`.

    """
    if matcher_obj.match_filter_method == GMS_NAME:
        filter_kwargs = {"img1_shape": img_obj_1.image.shape[0:2],
                         "img2_shape": img_obj_2.image.shape[0:2]}
    else:
        filter_kwargs = None

    unfiltered_match_info12, filtered_match_info12, unfiltered_match_info21, filtered_match_info21 = \
        matcher_obj.match_images(img1=img_obj_1.image, desc1=img_obj_1.desc, kp1_xy=img_obj_1.kp_pos_xy,
                                 img2=img_obj_2.image, desc2=img_obj_2.desc, kp2_xy=img_obj_2.kp_pos_xy,
                                 additional_filtering_kwargs=filter_kwargs)

    n_filtered = len(filtered_match_info12.matched_kp1_xy)
    if n_filtered == 0:
        warnings.warn(f"{n_filtered} matches between {img_obj_1.name} and {img_obj_2.name}")

    # Update match dictionaries
    if store_unfiltered:
        unfiltered_match_info12.set_names(img_obj_1.name, img_obj_2.name)
        img_obj_1.unfiltered_match_dict[img_obj_2] = unfiltered_match_info12

        unfiltered_match_info21.set_names(img_obj_2.name, img_obj_1.name)
        img_obj_2.unfiltered_match_dict[img_obj_1] = unfiltered_match_info21

    filtered_match_info12.set_names(img_obj_1.name, img_obj_2.name)
    img_obj_1.match_dict[img_obj_2] = filtered_match_info12

    filtered_match_info21.set_names(img_obj_2.name, img_obj_1.name)
    img_obj_2.match_dict[img_obj_1] = filtered_match_info21


def _run_over_stack(self, fxn, desc):
    """Call `fxn(i)` for every stack index using a pqdm thread pool.

    pqdm returns exceptions raised by workers as entries of its result list
    instead of raising them. The first such exception is re-raised here so
    failures are not silently dropped.
    """
    # cpu_count() - 1 is 0 on single-core machines, which is not a valid pool size
    n_cpu = max(1, multiprocessing.cpu_count() - 1)
    res = pqdm(range(self.size), fxn, n_jobs=n_cpu, desc=desc, unit="image", leave=None)
    errors = [r for r in res if isinstance(r, BaseException)]
    if errors:
        raise errors[0]


def match_sorted_imgs(self, matcher_obj, keep_unfiltered=False, qt_emitter=None):
    """Conduct feature matching between images that have already been sorted.

    Each image is matched only with the image before it in the stack.
    Results will be stored in each ZImage's match_dict.

    Parameters
    ----------
    matcher_obj : Matcher
        Object to match features between images.

    keep_unfiltered : bool
        Whether or not matcher_obj should store unfiltered matches

    qt_emitter : PySide2.QtCore.Signal, optional
        Used to emit signals that update the GUI's progress bars

    """

    def match_adj_img_obj(i):
        if i == 0:
            return None

        self._match_img_pair(matcher_obj, self.img_obj_list[i], self.img_obj_list[i - 1],
                             store_unfiltered=keep_unfiltered)
        if qt_emitter is not None:
            qt_emitter.emit(1)

    self._run_over_stack(match_adj_img_obj, MATCHING_MSG)


def match_imgs(self, matcher_obj, keep_unfiltered=False, qt_emitter=None):
    """Conduct feature matching between all pairs of images.

    Results will be stored in each ZImage's match_dict.

    Parameters
    ----------
    matcher_obj : Matcher
        Object to match features between images.

    keep_unfiltered : bool
        Whether or not matcher_obj should store unfiltered matches.
        NOTE(review): currently not honored here; unfiltered matches are
        always stored. Confirm nothing downstream relies on them before
        gating this on `keep_unfiltered`.

    qt_emitter : PySide2.QtCore.Signal, optional
        Used to emit signals that update the GUI's progress bars

    """

    def match_img_obj(i):
        img_obj_1 = self.img_obj_list[i]
        for j in range(i + 1, self.size):
            self._match_img_pair(matcher_obj, img_obj_1, self.img_obj_list[j], store_unfiltered=True)
            if qt_emitter is not None:
                qt_emitter.emit(1)

    self._run_over_stack(match_img_obj, MATCHING_MSG)


def get_neighbor_matches_idx(self, img_obj, prev_img_obj, next_img_obj):
    """Get indices of features found in both neighbors

    Returns
    -------
    nf_prev_idx : ndarray
        Indices into `img_obj.match_dict[prev_img_obj]` matches of keypoints
        that were also matched to `next_img_obj`.

    nf_next_idx : ndarray
        Indices into `img_obj.match_dict[next_img_obj]` matches of the same
        keypoints, in the same order as `nf_prev_idx`.

    """

    xy_to_prev = img_obj.match_dict[prev_img_obj].matched_kp1_xy
    xy_to_next = img_obj.match_dict[next_img_obj].matched_kp1_xy

    n_cols = img_obj.image.shape[1]
    xy_to_prev_idx = warp_tools.index2d_to_1d(xy_to_prev[:, 1], xy_to_prev[:, 0], n_cols)
    xy_to_next_idx = warp_tools.index2d_to_1d(xy_to_next[:, 1], xy_to_next[:, 0], n_cols)

    _, nf_prev_idx, nf_next_idx = np.intersect1d(xy_to_prev_idx, xy_to_next_idx, return_indices=True)

    # Remove pairs whose 1D indices matched but whose xy coordinates differ
    # (possible due to some very rare rounding errors?). Test the number of
    # mismatched rows rather than their truthiness, because row index 0 is falsy.
    mismatched_rows = np.unique(np.where(xy_to_prev[nf_prev_idx, :] != xy_to_next[nf_next_idx, :])[0])
    if mismatched_rows.size > 0:
        nf_prev_idx = np.delete(nf_prev_idx, mismatched_rows)
        nf_next_idx = np.delete(nf_next_idx, mismatched_rows)

    return nf_prev_idx, nf_next_idx


def get_common_desc(self, current_img_obj, neighbor_obj, nf_kp_idx):
    """Get descriptors that correspond to filtered neighbor points

    Parameters
    ----------
    current_img_obj : ZImage
        Image whose match info with `neighbor_obj` will be subset.

    neighbor_obj : ZImage
        Key into `current_img_obj.match_dict`.

    nf_kp_idx : ndarray
        Indices of already matched keypoints that were found after neighbor filtering

    Returns
    -------
    nf_desc : ndarray
        Rows of `matched_desc1` selected by `nf_kp_idx`.

    nf_kp : ndarray
        Rows of `matched_kp1_xy` selected by `nf_kp_idx`.

    """

    neighbor_match_info12 = current_img_obj.match_dict[neighbor_obj]
    nf_kp = neighbor_match_info12.matched_kp1_xy[nf_kp_idx]
    nf_desc = neighbor_match_info12.matched_desc1[nf_kp_idx]
    return nf_desc, nf_kp


def neighbor_match_filtering(self, img_obj, prev_img_obj, next_img_obj, tform, matcher_obj):
    """Remove poor matches by keeping only the matches found in neighbors

    Parameters
    ----------
    img_obj : ZImage
        current ZImage

    prev_img_obj : ZImage
        ZImage below `img_obj`

    next_img_obj : ZImage
        ZImage above `img_obj`

    tform : skimage.transform object
        The scikit-image transform object that estimates the parameter matrix

    matcher_obj : Matcher
        Object to match features between images.

    Returns
    -------
    improved: bool
        Whether or not neighbor filtering improved the alignment

    updated_prev_match_info12 : MatchInfo
        If improved is True, then `updated_prev_match_info12` includes only
        features, descriptors that were found in both neighbors. Otherwise,
        all of the original features will be maintained

    updated_next_match_info12 : MatchInfo
        If improved is True, then `updated_next_match_info12` includes only
        features, descriptors that were found in both neighbors. Otherwise,
        all of the original features will be maintained

    """

    def measure_d(src_xy, dst_xy, tform, M=None):
        """Measure median distance between warped corresponding points.

        If `M` is None it is estimated with `tform`. A copy of the parameters
        is returned so that later `tform.estimate` calls cannot alias it.
        """
        if M is None:
            tform.estimate(src=dst_xy, dst=src_xy)
            M = tform.params.copy()

        warped_xy = warp_tools.warp_xy(src_xy, M)
        d = np.median(warp_tools.calc_d(warped_xy, dst_xy))
        return d, M

    nf_prev_idx, nf_next_idx = self.get_neighbor_matches_idx(img_obj, prev_img_obj, next_img_obj)
    to_prev_match_info12 = img_obj.match_dict[prev_img_obj]
    to_next_match_info12 = img_obj.match_dict[next_img_obj]

    if len(nf_prev_idx) < 3:
        # Need at least 3 points for an affine transform
        return False, to_prev_match_info12, to_next_match_info12

    common_kp = to_prev_match_info12.matched_kp1_xy[nf_prev_idx]
    _common_kp = to_next_match_info12.matched_kp1_xy[nf_next_idx]
    assert np.all(common_kp == _common_kp)

    common_prev_kp = to_prev_match_info12.matched_kp2_xy[nf_prev_idx]
    common_next_kp = to_next_match_info12.matched_kp2_xy[nf_next_idx]

    common_matches_d, common_matches_M = measure_d(common_kp, common_prev_kp, tform)
    original_d, _ = measure_d(to_prev_match_info12.matched_kp1_xy, to_prev_match_info12.matched_kp2_xy, tform)
    original_with_neighbor_filter_d, _ = measure_d(to_prev_match_info12.matched_kp1_xy,
                                                   to_prev_match_info12.matched_kp2_xy,
                                                   tform, M=common_matches_M)

    improved = bool(common_matches_d < original_d and original_with_neighbor_filter_d <= original_d)
    if not improved:
        return False, to_prev_match_info12, to_next_match_info12

    # Neighbor filtering improved alignment, so re-match using only the shared keypoints
    filtered_desc, filtered_kp = self.get_common_desc(img_obj, prev_img_obj, nf_prev_idx)
    _filtered_desc, _filtered_kp = self.get_common_desc(img_obj, next_img_obj, nf_next_idx)

    filtered_prev_desc, filtered_prev_kp = self.get_common_desc(prev_img_obj, img_obj, nf_prev_idx)
    assert np.all(common_prev_kp == filtered_prev_kp)

    filtered_next_desc, filtered_next_kp = self.get_common_desc(next_img_obj, img_obj, nf_next_idx)
    assert np.all(common_next_kp == filtered_next_kp)

    updated_prev_match_info12, _, _, _ = \
        matcher_obj.match_images(desc1=filtered_desc, kp1_xy=filtered_kp,
                                 desc2=filtered_prev_desc, kp2_xy=filtered_prev_kp)

    updated_next_match_info12, _, _, _ = \
        matcher_obj.match_images(desc1=_filtered_desc, kp1_xy=_filtered_kp,
                                 desc2=filtered_next_desc, kp2_xy=filtered_next_kp)

    return True, updated_prev_match_info12, updated_next_match_info12


def update_match_dicts_with_neighbor_filter(self, tform, matcher_obj):
    """Remove poor matches by keeping only the matches found in neighbors

    Parameters
    ----------
    tform : skimage.transform object
        The scikit-image transform object that estimates the parameter matrix

    matcher_obj : Matcher
        Object to match features between images.

    """

    new_matches = {}
    for i in range(1, self.size - 1):
        img_obj = self.img_obj_list[i]
        prev_img_obj = self.img_obj_list[i - 1]
        next_img_obj = self.img_obj_list[i + 1]

        improved, updated_prev_match_info12, updated_next_match_info12 = \
            self.neighbor_match_filtering(img_obj, prev_img_obj, next_img_obj, tform, matcher_obj)

        if improved:
            new_matches[img_obj.name] = [updated_prev_match_info12, updated_next_match_info12]

    # Update matches only after all images were filtered. Filtering image i+1
    # reads image i's match_dict (via get_common_desc), so updating inside the
    # loop above would change the inputs of later iterations.
    for i, img_obj in enumerate(self.img_obj_list):
        if img_obj.name not in new_matches:
            continue

        prev_img_obj = self.img_obj_list[i - 1]
        next_img_obj = self.img_obj_list[i + 1]
        img_obj_new_matches = new_matches[img_obj.name]
        img_obj.match_dict[prev_img_obj] = img_obj_new_matches[0]
        img_obj.match_dict[next_img_obj] = img_obj_new_matches[1]


def build_metric_matrix(self, metric="n_matches"):
    """Create metric matrix based image similarity/distance

    Sets `unsorted_similarity_mat` (scaled to [0, 1], diagonal 1) and
    `unsorted_distance_mat` (diagonal 0).

    Parameters
    ----------
    metric: str
        Name of metric to use. If 'n_matches', then the number of matches
        will be used for similarity, and 1 - scaled similarity for distance.
        Otherwise the `similarity` and `distance` calculated during feature
        matching will be used.

    """

    use_n_matches = metric == "n_matches"
    distance_mat = np.zeros((self.size, self.size))
    similarity_mat = np.zeros_like(distance_mat)

    for i, obj1 in enumerate(self.img_obj_list):
        for j in range(i + 1, self.size):
            match_info = obj1.match_dict[self.img_obj_list[j]]
            if use_n_matches:
                s = match_info.n_matches
            else:
                s = match_info.similarity
                # Distances are only needed here: with 'n_matches' the
                # distance matrix is derived from the similarity matrix.
                d = match_info.distance
                distance_mat[i, j] = d
                distance_mat[j, i] = d

            similarity_mat[i, j] = s
            similarity_mat[j, i] = s

    # Scale metrics between 0 and 1. A zero range (all values equal, e.g. no
    # matches at all) would divide 0/0, so fall back to a unit range.
    min_s = similarity_mat.min()
    s_range = similarity_mat.max() - min_s
    if s_range == 0:
        s_range = 1
    similarity_mat = (similarity_mat - min_s) / s_range
    # Make sure that each image has highest similarity with itself
    similarity_mat[np.diag_indices_from(similarity_mat)] = 1

    if use_n_matches:
        distance_mat = 1 - similarity_mat
    else:
        min_d = distance_mat.min()
        d_range = distance_mat.max() - min_d
        if d_range == 0:
            d_range = 1
        distance_mat = (distance_mat - min_d) / d_range

    distance_mat[np.diag_indices_from(distance_mat)] = 0

    self.unsorted_similarity_mat = similarity_mat
    self.unsorted_distance_mat = distance_mat


def sort(self):
    """Order images such that most similar images are adjacent

    Order the images in the stack by optimally ordering the leaves of
    dendrogram created by clustering a matrix of image feature distances.
    `img_file_list` and `img_obj_list` are permuted identically, and each
    ZImage's `stack_idx` is updated to its new position.

    """

    sorted_D, sorted_idx, optimal_Z = order_Dmat(self.unsorted_distance_mat)
    self.optimal_z = optimal_Z
    self.distance_mat = sorted_D
    self.similarity_mat = self.unsorted_similarity_mat[np.ix_(sorted_idx, sorted_idx)]

    # Apply the permutation exactly once so the file list stays aligned with
    # the image object list.
    self.img_file_list = [self.img_file_list[i] for i in sorted_idx]
    self.img_obj_list = [self.img_obj_list[i] for i in sorted_idx]
    for z, img_obj in enumerate(self.img_obj_list):
        img_obj.stack_idx = z


def get_iter_order(self):
    """Get order in which to align images

    Will treat the reference image as the center of the stack. Sets
    `reference_img_idx`, `reference_img_f`, `iter_order`, and each moving
    image's `fixed_obj`.

    """

    if self.reference_img_f is not None:
        ref_img_name = valtils.get_name(self.reference_img_f)
    else:
        ref_img_name = None

    obj_names = [img_obj.name for img_obj in self.img_obj_list]
    ref_img_idx = warp_tools.get_ref_img_idx(obj_names, ref_img_name)

    self.reference_img_idx = ref_img_idx
    self.reference_img_f = self.img_obj_list[ref_img_idx].full_img_f
    self.iter_order = warp_tools.get_alignment_indices(self.size, ref_img_idx)
    for moving_idx, fixed_idx in self.iter_order:
        self.img_obj_list[moving_idx].fixed_obj = self.img_obj_list[fixed_idx]


def align_to_prev_check_reflections(self, transformer, feature_detector, matcher_obj,
                                    keep_unfiltered=False, qt_emitter=None):
    """Use key points to align current image to previous image in the stack,
    but checking if reflection improves alignment

    Parameters
    ---------
    transformer : skimage.transform object
        The scikit-image transform object that estimates the parameter matrix

    feature_detector : FeatureDD
        FeatureDD object that detects and computes image features.

    matcher_obj : Matcher
        Object to match features between images.

    keep_unfiltered : bool
        Whether or not matcher_obj should store unfiltered matches. Unfiltered
        matches are only updated for pairs that already have an entry in
        `unfiltered_match_dict`.

    qt_emitter : PySide2.QtCore.Signal, optional
        Used to emit signals that update the GUI's progress bars

    """

    ref_img_obj = self.img_obj_list[self.reference_img_idx]
    for moving_idx, fixed_idx in tqdm(self.iter_order, desc=TRANSFORM_MSG, unit="image", leave=None):
        img_obj = self.img_obj_list[moving_idx]
        prev_img_obj = self.img_obj_list[fixed_idx]
        # prev_M carries over from the previous iteration (the fixed image's
        # composed matrix) unless the fixed image is the reference.
        if fixed_idx == self.reference_img_idx:
            prev_M = ref_img_obj.T.copy()

        if matcher_obj.match_filter_method == GMS_NAME:
            filter_kwargs = {"img1_shape": img_obj.image.shape[0:2],
                             "img2_shape": prev_img_obj.image.shape[0:2]}
        else:
            filter_kwargs = None

        # Estimate current error without reflections. Don't need to re-detect and match features
        to_prev_match_info = img_obj.match_dict[prev_img_obj]
        transformer.estimate(to_prev_match_info.matched_kp2_xy, to_prev_match_info.matched_kp1_xy)
        unreflected_warped_src_xy = warp_tools.warp_xy(to_prev_match_info.matched_kp1_xy, transformer.params)
        _, unreflected_d = warp_tools.measure_error(to_prev_match_info.matched_kp2_xy,
                                                    unreflected_warped_src_xy,
                                                    prev_img_obj.image.shape)

        reflected_d_vals = [unreflected_d]
        reflection_M = [np.eye(3)]
        # Copy params so later estimate() calls cannot alias stored transforms
        transforms = [transformer.params.copy()]
        reflected_matches12 = [to_prev_match_info]
        reflected_matches21 = [prev_img_obj.match_dict[img_obj]]

        # Only track unfiltered matches when an unfiltered entry exists for
        # this pair; otherwise there is nothing to seed the candidate lists with.
        track_unfiltered = keep_unfiltered and prev_img_obj in img_obj.unfiltered_match_dict
        if track_unfiltered:
            unfiltered_reflected_matches12 = [img_obj.unfiltered_match_dict[prev_img_obj]]
            unfiltered_reflected_matches21 = [prev_img_obj.unfiltered_match_dict[img_obj]]

        # Estimate error with reflections
        dst_xy = warp_tools.warp_xy(prev_img_obj.kp_pos_xy, prev_M)
        prev_img_inv_M = np.linalg.inv(prev_M)
        for rx, ry in ((False, True), (True, False), (True, True)):
            rM = warp_tools.get_reflection_M(rx, ry, img_obj.image.shape)
            reflected_img = warp_tools.warp_img(img_obj.image, rM @ img_obj.T, out_shape_rc=img_obj.padded_shape_rc)

            reflected_src_xy, reflected_desc = feature_detector.detect_and_compute(reflected_img)
            unfiltered_match_info12, filtered_match_info12, unfiltered_match_info21, filtered_match_info21 = \
                matcher_obj.match_images(img1=reflected_img, desc1=reflected_desc, kp1_xy=reflected_src_xy,
                                         img2=prev_img_obj.image, desc2=prev_img_obj.desc, kp2_xy=dst_xy,
                                         additional_filtering_kwargs=filter_kwargs)

            # Record info
            transformer.estimate(filtered_match_info12.matched_kp2_xy, filtered_match_info12.matched_kp1_xy)
            reflected_warped_src_xy = warp_tools.warp_xy(filtered_match_info12.matched_kp1_xy, transformer.params)
            _, reflected_d = warp_tools.measure_error(filtered_match_info12.matched_kp2_xy,
                                                      reflected_warped_src_xy,
                                                      prev_img_obj.padded_shape_rc)

            reflected_d_vals.append(reflected_d)
            reflection_M.append(rM)
            transforms.append(transformer.params.copy())

            # Move matched features to position in original images
            img_inv_M = np.linalg.inv(rM @ img_obj.T)
            filtered_match_info12.matched_kp1_xy = warp_tools.warp_xy(filtered_match_info12.matched_kp1_xy, img_inv_M)
            filtered_match_info12.matched_kp2_xy = warp_tools.warp_xy(filtered_match_info12.matched_kp2_xy, prev_img_inv_M)
            filtered_match_info21.matched_kp1_xy = warp_tools.warp_xy(filtered_match_info21.matched_kp1_xy, prev_img_inv_M)
            filtered_match_info21.matched_kp2_xy = warp_tools.warp_xy(filtered_match_info21.matched_kp2_xy, img_inv_M)

            reflected_matches12.append(filtered_match_info12)
            reflected_matches21.append(filtered_match_info21)

            if track_unfiltered:
                unfiltered_match_info12.matched_kp1_xy = warp_tools.warp_xy(unfiltered_match_info12.matched_kp1_xy, img_inv_M)
                unfiltered_match_info12.matched_kp2_xy = warp_tools.warp_xy(unfiltered_match_info12.matched_kp2_xy, prev_img_inv_M)
                unfiltered_match_info21.matched_kp1_xy = warp_tools.warp_xy(unfiltered_match_info21.matched_kp1_xy, prev_img_inv_M)
                unfiltered_match_info21.matched_kp2_xy = warp_tools.warp_xy(unfiltered_match_info21.matched_kp2_xy, img_inv_M)

                unfiltered_reflected_matches12.append(unfiltered_match_info12)
                unfiltered_reflected_matches21.append(unfiltered_match_info21)

        best_idx = np.argmin(reflected_d_vals)
        best_reflect_M = reflection_M[best_idx]
        best_M = transforms[best_idx]
        img_obj.to_prev_A = best_M
        img_obj.reflection_M = best_reflect_M
        prev_M = img_obj.reflection_M @ img_obj.T @ img_obj.to_prev_A

        # A negative diagonal entry in the reflection matrix flags a flip on that axis
        ref_x, ref_y = best_reflect_M[[0, 1], [0, 1]] < 0
        if ref_x or ref_y:
            msg = f'detected relfections between {img_obj.name} and {prev_img_obj.name} along the'
            if ref_x and ref_y:
                msg = f'{msg} x and y axes'
            elif ref_x:
                msg = f'{msg} x axis'
            elif ref_y:
                msg = f'{msg} y axis'

            valtils.print_warning(msg)

        # Update matches
        img_obj.match_dict[prev_img_obj] = reflected_matches12[best_idx]
        prev_img_obj.match_dict[img_obj] = reflected_matches21[best_idx]

        if track_unfiltered:
            img_obj.unfiltered_match_dict[prev_img_obj] = unfiltered_reflected_matches12[best_idx]
            prev_img_obj.unfiltered_match_dict[img_obj] = unfiltered_reflected_matches21[best_idx]

        if qt_emitter is not None:
            qt_emitter.emit(1)


def align_to_prev(self, transformer, qt_emitter=None):
    """Use key points to align current image to previous image in the stack

    Parameters
    ---------
    transformer : skimage.transform object
        The scikit-image transform object that estimates the parameter matrix

    qt_emitter : PySide2.QtCore.Signal, optional
        Used to emit signals that update the GUI's progress bars

    """

    ref_img_obj = self.img_obj_list[self.reference_img_idx]
    if qt_emitter is not None:
        qt_emitter.emit(1)

    for moving_idx, fixed_idx in tqdm(self.iter_order, desc=TRANSFORM_MSG, unit="image", leave=None):
        img_obj = self.img_obj_list[moving_idx]
        prev_img_obj = self.img_obj_list[fixed_idx]
        img_obj.fixed_obj = prev_img_obj
        # prev_M carries over from the previous iteration unless the fixed image is the reference
        if fixed_idx == self.reference_img_idx:
            prev_M = ref_img_obj.T.copy()

        to_prev_match_info = img_obj.match_dict[prev_img_obj]
        src_xy = warp_tools.warp_xy(to_prev_match_info.matched_kp1_xy, img_obj.T)
        dst_xy = warp_tools.warp_xy(to_prev_match_info.matched_kp2_xy, prev_M)

        transformer.estimate(dst_xy, src_xy)
        # Copy so the stored matrix cannot alias later estimate() results
        img_obj.to_prev_A = transformer.params.copy()
        prev_M = img_obj.T @ img_obj.to_prev_A

        if qt_emitter is not None:
            qt_emitter.emit(1)


def optimize(self, affine_optimizer, qt_emitter=None):
    """Refine alignment by minimizing a metric

    Transformation will only be allowed if it both decreases the cost and
    median distance between keypoints.

    Parameters
    -----------
    affine_optimizer : AffineOptimizer
        Object that will minimize a cost function to find the optimal affine transformations

    qt_emitter : PySide2.QtCore.Signal, optional
        Used to emit signals that update the GUI's progress bars

    """

    ref_img_obj = self.img_obj_list[self.reference_img_idx]
    ref_warped = warp_tools.warp_img(ref_img_obj.image, M=ref_img_obj.T, out_shape_rc=ref_img_obj.padded_shape_rc)
    if qt_emitter is not None:
        qt_emitter.emit(1)

    for moving_idx, fixed_idx in tqdm(self.iter_order, desc=OPTIMIZING_MSG, unit="image", leave=None):
        img_obj = self.img_obj_list[moving_idx]
        prev_img_obj = self.img_obj_list[fixed_idx]
        # prev_img / prev_M carry over from the previous iteration unless the fixed image is the reference
        if prev_img_obj == ref_img_obj:
            prev_img = ref_warped
            prev_M = ref_img_obj.T

        M = img_obj.reflection_M @ img_obj.T @ img_obj.to_prev_A
        warped_img = warp_tools.warp_img(img_obj.image, M=M, out_shape_rc=img_obj.padded_shape_rc)

        to_prev_match_info = img_obj.match_dict[prev_img_obj]
        before_src_xy = warp_tools.warp_xy(to_prev_match_info.matched_kp1_xy, M)
        before_dst_xy = warp_tools.warp_xy(to_prev_match_info.matched_kp2_xy, prev_M)
        _, before_med_d = warp_tools.measure_error(before_src_xy, before_dst_xy, warped_img.shape)

        # Get mask of the region covered by both warped images
        img_mask = np.ones(img_obj.image.shape[0:2], dtype=np.uint8)
        warped_img_mask = warp_tools.warp_img(img_mask, M=M, out_shape_rc=img_obj.padded_shape_rc)

        prev_img_mask = np.ones(prev_img_obj.image.shape[0:2], dtype=np.uint8)
        warped_prev_img_mask = warp_tools.warp_img(prev_img_mask, M=prev_M, out_shape_rc=prev_img_obj.padded_shape_rc)

        mask = np.zeros(warped_img_mask.shape, dtype=np.uint8)
        mask[(warped_img_mask != 0) & (warped_prev_img_mask != 0)] = 255

        # Optimize area inside mask
        if affine_optimizer.accepts_xy:
            moving_xy = before_src_xy
            fixed_xy = before_dst_xy
        else:
            moving_xy = None
            fixed_xy = None

        with valtils.HiddenPrints():
            _, optimal_M, _ = affine_optimizer.align(moving=warped_img, fixed=prev_img, mask=mask, initial_M=None,
                                                     moving_xy=moving_xy, fixed_xy=fixed_xy)

        # Keep optimal M if it actually improved alignment
        initial_cst = affine_optimizer.cost_fxn(warped_img, prev_img, mask)

        after_src_xy = warp_tools.warp_xy(to_prev_match_info.matched_kp1_xy, M @ optimal_M)
        after_dst_xy = warp_tools.warp_xy(to_prev_match_info.matched_kp2_xy, prev_M)
        optimal_reg_img = warp_tools.warp_img(warped_img, M=optimal_M, out_shape_rc=img_obj.padded_shape_rc)
        after_cst = affine_optimizer.cost_fxn(optimal_reg_img, prev_img, mask)
        _, after_med_d = warp_tools.measure_error(after_src_xy, after_dst_xy, warped_img.shape)

        if after_cst is not None and initial_cst is not None:
            lower_cost = after_cst <= initial_cst
        else:
            lower_cost = True

        lower_d = after_med_d <= before_med_d

        if lower_cost and lower_d:
            prev_img = optimal_reg_img
            img_obj.optimal_M = optimal_M
        else:
            msg = (f"Somehow optimization made things worse. "
                   f"Cost was {initial_cst} but is now {after_cst}. "
                   f"KP medD was {before_med_d}, but is now {after_med_d}.")
            valtils.print_warning(msg)
            prev_img = warped_img

        prev_M = M @ img_obj.optimal_M

        if qt_emitter is not None:
            qt_emitter.emit(1)


def _get_rigid_M(self, img_obj):
    """Compose the reflection, padding, rigid, and optimization matrices of `img_obj`."""
    return img_obj.reflection_M @ img_obj.T @ img_obj.to_prev_A @ img_obj.optimal_M


def _get_warped_bbox(self, M_list):
    """Return (min_x, min_y, width, height) of the box enclosing every image's warped corners.

    `M_list[i]` is the matrix used to warp `self.img_obj_list[i]`.
    """
    min_x = min_y = np.inf
    # Start maxima at -inf (not 0) so the box is not forced to contain x=0 / y=0
    max_x = max_y = -np.inf
    for img_obj, M in zip(self.img_obj_list, M_list):
        img_corners_rc = warp_tools.get_corners_of_image(img_obj.image.shape)
        warped_corners_xy = warp_tools.warp_xy(img_corners_rc[:, ::-1], M)
        min_x = np.minimum(min_x, np.min(warped_corners_xy[:, 0]))
        max_x = np.maximum(max_x, np.max(warped_corners_xy[:, 0]))
        min_y = np.minimum(min_y, np.min(warped_corners_xy[:, 1]))
        max_y = np.maximum(max_y, np.max(warped_corners_xy[:, 1]))

    w = int(np.ceil(max_x - min_x))
    h = int(np.ceil(max_y - min_y))
    return min_x, min_y, w, h


def calc_warped_img_size(self):
    """Determine the shape of the registered images

    Returns
    -------
    ndarray
        [height, width] of the box enclosing all rigidly warped images.

    """
    M_list = [self._get_rigid_M(img_obj) for img_obj in self.img_obj_list]
    _, _, w, h = self._get_warped_bbox(M_list)
    return np.array([h, w])


def finalize(self):
    """Combine transformation matrices and get final shape of registered images

    Sets `crop_T`, `M`, `M_inv`, `registered_img`, and `registered_shape_rc`
    on every ZImage.
    """

    M_list = [self._get_rigid_M(img_obj)
              for img_obj in tqdm(self.img_obj_list, desc=FINALIZING_MSG, unit="image", leave=None)]
    min_x, min_y, w, h = self._get_warped_bbox(M_list)

    # Translation that moves the bounding box's top-left corner to the origin
    crop_T = np.identity(3)
    crop_T[0, 2] = min_x
    crop_T[1, 2] = min_y

    for img_obj, M in zip(self.img_obj_list, M_list):
        img_obj.crop_T = crop_T
        img_obj.M = M @ crop_T
        img_obj.M_inv = np.linalg.inv(img_obj.M)
        img_obj.registered_img = warp_tools.warp_img(img=img_obj.image, M=img_obj.M, out_shape_rc=(h, w))
        img_obj.registered_shape_rc = img_obj.registered_img.shape[0:2]


def wiggle_to_ref(self, transformer):
    """Compose rigid transforms to wiggle image to reference

    #. For each slide, get M that aligns its rigidly warped points to its
       fixed image's rigidly warped points. These will be `rolling_M`.

    #. Then, for each slide, compose its `M` with each neighbor's `rolling_M`
       until it gets to the reference slide. `M_inv` is kept consistent with
       the updated `M`.

    """

    ref_obj = self.img_obj_list[self.reference_img_idx]

    # Find inverse transforms that will align rigid image to rigid neighbor
    rolling_M_list = [None] * self.size
    for img_obj in self.img_obj_list:
        if img_obj == ref_obj:
            continue

        matches = img_obj.match_dict[img_obj.fixed_obj]
        rigid_reg_moving_xy = warp_tools.warp_xy(matches.matched_kp1_xy, M=img_obj.M)
        rigid_reg_fixed_xy = warp_tools.warp_xy(matches.matched_kp2_xy, M=img_obj.fixed_obj.M)
        transformer.estimate(src=rigid_reg_fixed_xy, dst=rigid_reg_moving_xy)
        # Copy so stored matrices cannot alias later estimate() results
        rolling_M_list[img_obj.stack_idx] = transformer.params.copy()

    # Compose rolling transforms (all rolling_M are computed before any M changes)
    wiggle_M_list = [None] * self.size
    for img_obj in self.img_obj_list:
        if img_obj == ref_obj:
            continue

        neighbor_slide = img_obj.fixed_obj
        wiggle_M = np.eye(3)
        while neighbor_slide != ref_obj:
            wiggle_M = wiggle_M @ rolling_M_list[neighbor_slide.stack_idx]
            neighbor_slide = neighbor_slide.fixed_obj

        wiggle_M_list[img_obj.stack_idx] = wiggle_M

    # Update M, keeping M_inv in sync (finalize sets M_inv = inv(M))
    for img_obj in self.img_obj_list:
        if img_obj == ref_obj:
            continue

        img_obj.M = img_obj.M @ wiggle_M_list[img_obj.stack_idx]
        img_obj.M_inv = np.linalg.inv(img_obj.M)


def clear_unused_matches(self):
    """Clear up space by removing unused matches between ZImages

    Calls `ZImage.reduce` on every image with its previous and next
    neighbors in the stack (None at either end).
    """

    last_idx = self.size - 1
    for i, img_obj in enumerate(self.img_obj_list):
        prev_img_obj = self.img_obj_list[i - 1] if i > 0 else None
        next_img_obj = self.img_obj_list[i + 1] if i < last_idx else None
        img_obj.reduce(prev_img_obj, next_img_obj)
def summarize(self):
    """Summarize alignment error

    For every image other than the reference, the matched keypoints between
    it and its fixed image (`fixed_obj`) are compared twice. The first
    comparison uses the raw keypoint positions ("original" columns). The
    second uses keypoints warped by each image's `M` ("D", "D_weighted",
    "TRE"). The weighted distance uses match similarities derived from the
    match descriptor distances. Series-wide totals are computed over the
    non-reference rows only.

    Returns
    -------
    summary_df: Dataframe
        Pandas dataframe containing the registration error of the alignment
        between each image and the previous one in the stack. The reference
        image's row has None for every error column.

    """
    n_imgs = self.size
    from_names = [None] * n_imgs
    to_names = [None] * n_imgs
    raw_med_d = [None] * n_imgs
    raw_tre = [None] * n_imgs
    reg_med_d = [None] * n_imgs
    reg_weighted_med_d = [None] * n_imgs
    reg_tre = [None] * n_imgs
    reg_shapes = [None] * n_imgs

    for idx, moving_obj in enumerate(self.img_obj_list):
        from_names[idx] = moving_obj.name
        reg_shapes[idx] = moving_obj.registered_img.shape
        if idx == self.reference_img_idx:
            # The reference is not aligned to anything
            continue

        fixed_obj = moving_obj.fixed_obj
        to_names[idx] = fixed_obj.name

        pair_matches = moving_obj.match_dict[fixed_obj]
        moving_xy = pair_matches.matched_kp1_xy
        fixed_xy = pair_matches.matched_kp2_xy
        moving_shape = moving_obj.image.shape

        raw_tre[idx], raw_med_d[idx] = warp_tools.measure_error(moving_xy, fixed_xy, moving_shape)

        warped_moving_xy = warp_tools.warp_xy(moving_xy, moving_obj.M)
        warped_fixed_xy = warp_tools.warp_xy(fixed_xy, fixed_obj.M)
        reg_tre[idx], reg_med_d[idx] = warp_tools.measure_error(warped_moving_xy, warped_fixed_xy, moving_shape)

        match_weights = convert_distance_to_similarity(pair_matches.match_distances,
                                                       pair_matches.matched_desc1.shape[0])
        weighted_result = warp_tools.measure_error(warped_moving_xy, warped_fixed_xy, moving_shape, match_weights)
        reg_weighted_med_d[idx] = weighted_result[1]

    summary_df = pd.DataFrame({
        "from": from_names,
        "to": to_names,
        "original_D": raw_med_d,
        "D": reg_med_d,
        "D_weighted": reg_weighted_med_d,
        "original_TRE": raw_tre,
        "TRE": reg_tre,
        "shape": reg_shapes,
    })

    non_ref_idx = list(range(n_imgs))
    non_ref_idx.remove(self.reference_img_idx)

    series_columns = (("series_d", "D"), ("series_tre", "TRE"), ("series_weighted_d", "D_weighted"))
    for series_col, per_img_col in series_columns:
        summary_df[series_col] = warp_tools.calc_total_error(summary_df[per_img_col].values[non_ref_idx])

    summary_df["name"] = self.name

    return summary_df
def register_images(img_dir, dst_dir=None, name="registrar",
                    feature_detector=None, matcher=None, transformer=None,
                    affine_optimizer=None, imgs_ordered=False,
                    reference_img_f=None, similarity_metric="n_matches",
                    check_for_reflections=False, max_scaling=3.0,
                    align_to_reference=False, qt_emitter=None,
                    valis_obj=None):
    """Rigidly align a collection of images

    Parameters
    ----------
    img_dir : str
        Path to directory containing the images that the user would like to
        be registered. These images need to be single channel, uint8 images.

    dst_dir : str, optional
        Top directory where aligned images and results should be saved. The
        CSV summary and the pickled SerialRigidRegistrar are written to the
        "data" sub-directory, and aligned images to the "registered_images"
        sub-directory. If None, nothing is written to file.

    name : str, optional
        Descriptive name of registrar, such as the sample's name. Also used
        as the prefix of the saved results files.

    feature_detector : FeatureDD, optional
        FeatureDD object that detects and computes image features. If None,
        a new VggFD is created for this call.

    matcher : Matcher, optional
        Matcher object that will be used to match image features. If None, a
        new Matcher is created for this call. Note that `matcher.scaling` is
        overwritten to agree with the type of `transformer`.

    transformer : scikit-image Transform object, optional
        Transformer used to find the transformation matrix that will warp
        each image to the target image. If None, a new EuclideanTransform is
        created for this call.

    affine_optimizer : AffineOptimizer object, optional
        Object that will minimize a cost function to find the optimal affine
        transformations. If its `transformation` attribute does not match the
        class name of `transformer`, a warning is issued and the attribute is
        overwritten with the transformer's class name.

    imgs_ordered : bool
        Boolean defining whether or not the images in img_dir are already in
        the correct order. If True, then each filename should begin with the
        number that indicates its position in the z-stack. If False, then the
        images will be sorted by ordering a feature distance matrix.

    reference_img_f : str, optional
        Filename of image that will be treated as the center of the stack.
        If None, the index of the middle image will be the reference.

    similarity_metric : str
        Metric used to calculate similarity between images, which is in turn
        used to build the distance matrix used to sort the images. Only used
        when the images are not already sorted.

    check_for_reflections : bool, optional
        Determine if alignments are improved by reflecting/mirroring/flipping
        images. Optional because it requires re-detecting features in each
        version of the images and then re-matching features, and so can be
        time consuming and not always necessary.

    max_scaling : float, optional
        Registration is considered to have failed (and False is returned) if
        the scale of any image's rigid transformation matrix is
        >= `max_scaling` or <= 1 / `max_scaling`.

    align_to_reference : bool, optional
        Whether or not images should be aligned to a reference image
        specified by `reference_img_f`.

    qt_emitter : PySide2.QtCore.Signal, optional
        Used to emit signals that update the GUI's progress bars

    valis_obj : optional
        If not None and `valis_obj.create_masks` is True, feature points
        lying outside of each slide's `rigid_reg_mask` (looked up with
        `valis_obj.get_slide(img_obj.name)`) are removed before matching.
        Presumably the parent Valis object -- confirm against callers.

    Returns
    -------
    registrar : SerialRigidRegistrar or bool
        SerialRigidRegistrar object that contains general information about
        the alignments, and also a list of image objects. Each image object
        contains the warp information for an image in the stack, including
        the transformation matrices calculated at each step, keypoint
        positions, image descriptors, and matches with other images. If
        `dst_dir` is not None, `registrar.summary` holds the DataFrame
        returned by `registrar.summarize()`.
        False is returned if the `max_scaling` check fails.

    """
    tic = time()

    # Build default collaborators per call rather than sharing single
    # instances created at import time; `matcher` is mutated below.
    if feature_detector is None:
        feature_detector = VggFD()
    if matcher is None:
        matcher = Matcher()
    if transformer is None:
        transformer = EuclideanTransform()

    transformer_name = type(transformer).__name__

    if affine_optimizer is not None and \
            transformer_name != affine_optimizer.transformation:
        warnings.warn(
            f"Transformer is of type {transformer_name}, but affine_optimizer "
            f"optimizes the {affine_optimizer.transformation}. "
            f"Setting {transformer_name} as the transform to be optimized")
        affine_optimizer.transformation = transformer_name

    # A Euclidean transform has no scale component, so the matcher's
    # `scaling` flag is only enabled for the other transform types.
    matcher.scaling = transformer_name != "EuclideanTransform"

    registrar = SerialRigidRegistrar(img_dir, imgs_ordered=imgs_ordered,
                                     reference_img_f=reference_img_f,
                                     name=name,
                                     align_to_reference=align_to_reference)

    # Detect features
    registrar.generate_img_obj_list(feature_detector, qt_emitter=qt_emitter)

    if valis_obj is not None and valis_obj.create_masks:
        # Remove feature points outside of each slide's rigid registration mask
        for img_obj in registrar.img_obj_dict.values():
            slide_obj = valis_obj.get_slide(img_obj.name)
            features_in_mask_idx = warp_tools.get_xy_inside_mask(
                xy=img_obj.kp_pos_xy, mask=slide_obj.rigid_reg_mask)

            # If no features fall inside the mask, keep all of them
            if len(features_in_mask_idx) > 0:
                img_obj.kp_pos_xy = img_obj.kp_pos_xy[features_in_mask_idx, :]
                img_obj.desc = img_obj.desc[features_in_mask_idx, :]

    # Match images. If their order is not already known, sort them using
    # the similarity metric computed from the matches.
    if registrar.aleady_sorted:
        registrar.match_sorted_imgs(matcher, keep_unfiltered=False,
                                    qt_emitter=qt_emitter)
        for i, img_obj in enumerate(registrar.img_obj_list):
            img_obj.stack_idx = i
    else:
        registrar.match_imgs(matcher, keep_unfiltered=False,
                             qt_emitter=qt_emitter)
        registrar.build_metric_matrix(metric=similarity_metric)
        registrar.sort()

    registrar.distance_metric_name = matcher.metric_name
    registrar.distance_metric_type = matcher.metric_type

    # Calculate transformations
    registrar.get_iter_order()
    if registrar.size > 2:
        registrar.update_match_dicts_with_neighbor_filter(transformer, matcher)

    if check_for_reflections:
        registrar.align_to_prev_check_reflections(
            transformer=transformer, feature_detector=feature_detector,
            matcher_obj=matcher, keep_unfiltered=False, qt_emitter=qt_emitter)
    else:
        registrar.align_to_prev(transformer=transformer, qt_emitter=qt_emitter)

    # If any image would be scaled too much, consider registration failed
    for img_obj in registrar.img_obj_list:
        s = transform.SimilarityTransform(img_obj.M).scale
        if s >= max_scaling or s <= 1 / max_scaling:
            warnings.warn(
                f"Max allowed scaling is {max_scaling}, "
                f"but was calculated as being {s}. "
                "Registration failed. Maybe try using the Euclidean transform.")
            return False

    if affine_optimizer is not None:
        registrar.optimize(affine_optimizer, qt_emitter=qt_emitter)

    registrar.finalize()
    if align_to_reference:
        registrar.wiggle_to_ref(transformer)

    if dst_dir is not None:
        registered_img_dir = os.path.join(dst_dir, "registered_images")
        registered_data_dir = os.path.join(dst_dir, "data")
        for d in (registered_img_dir, registered_data_dir):
            pathlib.Path(d).mkdir(exist_ok=True, parents=True)

        # Summarize alignments
        summary_df = registrar.summarize()
        summary_file = os.path.join(registered_data_dir, name + "_results.csv")
        summary_df.to_csv(summary_file, index=False)
        registrar.summary = summary_df

        # Save results; the context manager ensures the file is flushed/closed
        pickle_file = os.path.join(registered_data_dir,
                                   name + "_registrar.pickle")
        with open(pickle_file, "wb") as f:
            pickle.dump(registrar, f)

        n_digits = len(str(registrar.size))
        for img_obj in registrar.img_obj_list:
            # Zero-pad the stack index so files sort in stack order
            f_out = f"{str(img_obj.stack_idx).zfill(n_digits)}_{img_obj.name}.png"
            io.imsave(os.path.join(registered_img_dir, f_out),
                      img_obj.registered_img.astype(np.uint8))

    registrar.clear_unused_matches()

    elapsed = time() - tic
    time_string, time_units = valtils.get_elapsed_time_string(elapsed)
    print(f"\n======== Rigid registration complete in {time_string} {time_units}\n")

    return registrar