"""Functions and classes to detect and describe image features
Bundles OpenCV feature detectors and descriptors into the FeatureDD class
Also makes it easier to mix and match feature detectors and descriptors
from different packages (e.g. skimage and OpenCV). See CensureVggFD for
an example
"""
import cv2
from skimage import feature, exposure
import numpy as np
import torch
import traceback
from . import valtils
from .superglue_models import superpoint
# Module-level BRISK instance shared as the fallback detector whenever a
# descriptor cannot detect features on its own.
DEFAULT_FEATURE_DETECTOR = cv2.BRISK_create()
"""The default OpenCV feature detector"""
MAX_FEATURES = 20000
"""Maximum number of image features that will be recorded. If the number
of features exceeds this value, the MAX_FEATURES features with the
highest response will be returned."""


def filter_features(kp, desc, n_keep=MAX_FEATURES):
    """Get keypoints with the highest response

    Parameters
    ----------
    kp : list
        List of cv2.KeyPoint detected by an OpenCV feature detector.

    desc : ndarray
        2D numpy array of keypoint descriptors, where each row is a keypoint
        and each column a feature.

    n_keep : int
        Maximum number of features that are retained.

    Returns
    -------
    Keypoints and corresponding descriptors that have the n_keep highest
    responses.

    """
    response = np.array([x.response for x in kp])
    # Sort by response, descending, and keep at most n_keep indices
    keep_idx = np.argsort(response)[::-1][0:n_keep]

    return [kp[i] for i in keep_idx], desc[keep_idx, :]
class FeatureDD(object):
    """Abstract class for feature detection and description.

    User can create other feature detectors as subclasses, but each must
    return keypoint positions in xy coordinates along with the descriptors
    for each keypoint.

    Note that in some cases, such as KAZE, kp_detector can also detect
    features. However, in other cases, there may need to be a separate feature
    detector (like BRISK or ORB) and feature descriptor (like VGG).

    Attributes
    ----------
    kp_detector : object
        Keypoint detector, by default from OpenCV

    kp_descriptor : object
        Keypoint descriptor, by default from OpenCV

    kp_detector_name : str
        Name of keypoint detector

    kp_descriptor_name : str
        Name of keypoint descriptor

    Methods
    -------
    detect_and_compute(image, mask=None)
        Detects and describes keypoints in image

    """

    def __init__(self, kp_detector=None, kp_descriptor=None):
        """
        Parameters
        ----------
        kp_detector : object
            Keypoint detector, by default from OpenCV

        kp_descriptor : object
            Keypoint descriptor, by default from OpenCV

        """

        self.kp_detector = kp_detector
        self.kp_descriptor = kp_descriptor

        if kp_descriptor is not None and kp_detector is not None:
            # User provides both a detector and descriptor #
            self.kp_descriptor_name = kp_descriptor.__class__.__name__
            self.kp_detector_name = kp_detector.__class__.__name__

        if kp_descriptor is None and kp_detector is not None:
            # Will be using kp_descriptor for detectAndCompute #
            kp_descriptor = kp_detector
            kp_detector = None
            # Fix: the original only swapped the local variables, leaving
            # self.kp_descriptor set to None. detect_and_compute would then
            # call .compute on None. Keep the instance attributes in sync.
            self.kp_descriptor = kp_descriptor
            self.kp_detector = None

        if kp_descriptor is not None and kp_detector is None:
            # User provides a descriptor, which must also be able to detect #
            self.kp_descriptor_name = kp_descriptor.__class__.__name__
            self.kp_detector_name = self.kp_descriptor_name

            try:
                # Probe with a tiny image to check detectAndCompute support
                _img = np.zeros((10, 10), dtype=np.uint8)
                kp_descriptor.detectAndCompute(_img, mask=None)

            except Exception:
                # Descriptor cannot detect on its own; fall back to the
                # module default detector and warn (best-effort, not fatal).
                traceback_msg = traceback.format_exc()
                msg = f"{self.kp_descriptor_name} unable to both detect and compute features. Setting to {DEFAULT_FEATURE_DETECTOR.__class__.__name__}"
                valtils.print_warning(msg, traceback_msg=traceback_msg)
                self.kp_detector = DEFAULT_FEATURE_DETECTOR

    def detect_and_compute(self, image, mask=None):
        """Detect the features in the image

        Detect the features in the image using the defined kp_detector, then
        describe the features using the kp_descriptor. The user can override
        this method so they don't have to use OpenCV's Keypoint class.

        Parameters
        ----------
        image : ndarray
            Image in which the features will be detected. Should be a 2D uint8
            image if using OpenCV

        mask : ndarray, optional
            Binary image with same shape as image, where foreground > 0,
            and background = 0. If provided, feature detection will only be
            performed on the foreground.

        Returns
        -------
        kp : ndarray
            (N, 2) array positions of keypoints in xy coordinates for N
            keypoints

        desc : ndarray
            (N, M) array containing M features for each of the N keypoints

        """

        image = exposure.rescale_intensity(image, out_range=(0, 255)).astype(np.uint8)
        if self.kp_detector is not None:
            # Fix: forward the mask so detection honors the foreground
            # restriction documented above (original ignored it here).
            detected_kp = self.kp_detector.detect(image, mask)
            kp, desc = self.kp_descriptor.compute(image, detected_kp)

        else:
            kp, desc = self.kp_descriptor.detectAndCompute(image, mask=mask)

        # Guard against None (no features found) before checking the count
        if desc is not None and desc.shape[0] > MAX_FEATURES:
            kp, desc = filter_features(kp, desc)

        kp_pos_xy = np.array([k.pt for k in kp])

        return kp_pos_xy, desc
# Thin wrappers around OpenCV detectors and descriptors #
class OrbFD(FeatureDD):
    """Uses ORB for feature detection and description"""
    def __init__(self, kp_descriptor=None):
        """
        Parameters
        ----------
        kp_descriptor : optional
            Feature detector/descriptor. If None, a fresh ORB instance is
            created per object, instead of one shared instance being built
            at import time (mutable default argument antipattern).
        """
        if kp_descriptor is None:
            kp_descriptor = cv2.ORB_create(MAX_FEATURES)
        super().__init__(kp_descriptor=kp_descriptor)
class BriskFD(FeatureDD):
    """Uses BRISK for feature detection and description"""
    def __init__(self, kp_descriptor=None):
        """
        Parameters
        ----------
        kp_descriptor : optional
            Feature detector/descriptor. If None, a fresh BRISK instance is
            created per object, avoiding a shared instance built at import
            time (mutable default argument antipattern).
        """
        if kp_descriptor is None:
            kp_descriptor = cv2.BRISK_create()
        super().__init__(kp_descriptor=kp_descriptor)
class KazeFD(FeatureDD):
    """Uses KAZE for feature detection and description"""
    def __init__(self, kp_descriptor=None):
        """
        Parameters
        ----------
        kp_descriptor : optional
            Feature detector/descriptor. If None, a fresh KAZE instance is
            created per object, avoiding a shared instance built at import
            time (mutable default argument antipattern).
        """
        if kp_descriptor is None:
            kp_descriptor = cv2.KAZE_create(extended=False)
        super().__init__(kp_descriptor=kp_descriptor)
class AkazeFD(FeatureDD):
    """Uses AKAZE for feature detection and description"""
    def __init__(self, kp_descriptor=None):
        """
        Parameters
        ----------
        kp_descriptor : optional
            Feature detector/descriptor. If None, a fresh AKAZE instance is
            created per object, avoiding a shared instance built at import
            time (mutable default argument antipattern).
        """
        if kp_descriptor is None:
            kp_descriptor = cv2.AKAZE_create()
        super().__init__(kp_descriptor=kp_descriptor)
class DaisyFD(FeatureDD):
    """Uses BRISK for feature detection and DAISY for feature description"""
    def __init__(self, kp_detector=DEFAULT_FEATURE_DETECTOR,
                 kp_descriptor=None):
        """
        Parameters
        ----------
        kp_detector : optional
            Feature detector, by default the module-level BRISK detector.

        kp_descriptor : optional
            Feature descriptor. If None, a fresh DAISY instance is created
            per object, avoiding a shared instance built at import time
            (mutable default argument antipattern).
        """
        if kp_descriptor is None:
            kp_descriptor = cv2.xfeatures2d.DAISY_create()
        super().__init__(kp_detector=kp_detector, kp_descriptor=kp_descriptor)
class LatchFD(FeatureDD):
    """Uses BRISK for feature detection and LATCH for feature description"""
    def __init__(self, kp_detector=DEFAULT_FEATURE_DETECTOR,
                 kp_descriptor=None):
        """
        Parameters
        ----------
        kp_detector : optional
            Feature detector, by default the module-level BRISK detector.

        kp_descriptor : optional
            Feature descriptor. If None, a fresh LATCH instance is created
            per object, avoiding a shared instance built at import time
            (mutable default argument antipattern).
        """
        if kp_descriptor is None:
            kp_descriptor = cv2.xfeatures2d.LATCH_create(rotationInvariance=True)
        super().__init__(kp_detector=kp_detector, kp_descriptor=kp_descriptor)
class BoostFD(FeatureDD):
    """Uses BRISK for feature detection and Boost for feature description"""
    def __init__(self, kp_detector=DEFAULT_FEATURE_DETECTOR,
                 kp_descriptor=None):
        """
        Parameters
        ----------
        kp_detector : optional
            Feature detector, by default the module-level BRISK detector.

        kp_descriptor : optional
            Feature descriptor. If None, a fresh BoostDesc instance is
            created per object, avoiding a shared instance built at import
            time (mutable default argument antipattern).
        """
        if kp_descriptor is None:
            kp_descriptor = cv2.xfeatures2d.BoostDesc_create()
        super().__init__(kp_detector=kp_detector, kp_descriptor=kp_descriptor)
class VggFD(FeatureDD):
    """Uses BRISK for feature detection and VGG for feature description"""
    def __init__(self, kp_detector=DEFAULT_FEATURE_DETECTOR,
                 kp_descriptor=None):
        """
        Parameters
        ----------
        kp_detector : optional
            Feature detector, by default the module-level BRISK detector.

        kp_descriptor : optional
            Feature descriptor. If None, a fresh VGG instance is created
            per object, avoiding a shared instance built at import time
            (mutable default argument antipattern).
        """
        if kp_descriptor is None:
            kp_descriptor = cv2.xfeatures2d.VGG_create(scale_factor=5.0)
        super().__init__(kp_detector=kp_detector, kp_descriptor=kp_descriptor)
class OrbVggFD(FeatureDD):
    """Uses ORB for feature detection and VGG for feature description"""
    def __init__(self, kp_detector=None, kp_descriptor=None):
        """
        Parameters
        ----------
        kp_detector : optional
            Feature detector. If None, a fresh ORB detector is created per
            object (fastThreshold=0 to keep more candidate keypoints),
            avoiding shared instances built at import time.

        kp_descriptor : optional
            Feature descriptor. If None, a fresh VGG descriptor is created
            per object.
        """
        if kp_detector is None:
            kp_detector = cv2.ORB_create(nfeatures=MAX_FEATURES, fastThreshold=0)
        if kp_descriptor is None:
            kp_descriptor = cv2.xfeatures2d.VGG_create(scale_factor=0.75)
        super().__init__(kp_detector=kp_detector, kp_descriptor=kp_descriptor)
# Example of a custom detector that uses the Censure feature detector
# from scikit-image along with the KAZE descriptor (OpenCV)
class FeatureDetector(object):
    """Abstract base for classes that only detect image features.

    Subclasses wrap a detector object and return the detected features
    as a list of OpenCV cv2.KeyPoint objects. This makes it possible to
    pair a non-OpenCV feature detector with an OpenCV descriptor.

    Attributes
    ----------
    detector : object
        Object that can detect image features.

    Methods
    -------
    detect(image)

    Interface
    ---------
    Required methods are: detect
    """

    def __init__(self):
        self.detector = None

    def detect(self, image):
        """Detect features in `image` and return keypoints as XY.

        Subclasses should run their detector here.

        Returns
        -------
        kp : KeyPoints
            List of OpenCV KeyPoint objects
        """
        pass
# Example of how to create a feature detector using OpenCV + skimage #
class SkCensureDetector(FeatureDetector):
    """CENSURE feature detector provided by scikit-image

    Wraps skimage.feature.CENSURE so that it can be combined with an
    OpenCV feature descriptor.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.detector = feature.CENSURE(**kwargs)

    def detect(self, image):
        """
        Detect keypoints in image using CENSURE.

        See https://scikit-image.org/docs/dev/api/skimage.feature.html#skimage.feature.CENSURE

        The detected coordinates are converted to OpenCV KeyPoint objects.

        Parameters
        ----------
        image : ndarray
            image from keypoints will be detected

        Returns
        -------
        kp : KeyPoints
            List of OpenCV KeyPoint objects
        """
        self.detector.detect(image)

        # scikit-image stores keypoints as (row, col); flip to (x, y)
        xy_coords = self.detector.keypoints[:, ::-1].astype(float)

        # Build OpenCV KeyPoint objects from the xy coordinates
        return cv2.KeyPoint_convert(xy_coords.tolist())
class CensureVggFD(FeatureDD):
    """Uses scikit-image CENSURE for detection and OpenCV VGG for description"""
    def __init__(self, kp_detector=None, kp_descriptor=None):
        """
        Parameters
        ----------
        kp_detector : optional
            Feature detector. If None, a fresh SkCensureDetector is created
            per object, instead of one shared instance being built at import
            time (mutable default argument antipattern).

        kp_descriptor : optional
            Feature descriptor. If None, a fresh VGG descriptor is created
            per object.
        """
        if kp_detector is None:
            kp_detector = SkCensureDetector(mode="Octagon", max_scale=8,
                                            non_max_threshold=0.02)
        if kp_descriptor is None:
            kp_descriptor = cv2.xfeatures2d.VGG_create(scale_factor=6.25)
        super().__init__(kp_detector=kp_detector, kp_descriptor=kp_descriptor)

        # Report this combined class as both detector and descriptor name
        self.kp_descriptor_name = self.__class__.__name__
        self.kp_detector_name = self.__class__.__name__
# Example of a custom detector and descriptor using scikit-image #
class SkDaisy(FeatureDD):
    """Dense DAISY feature detector and descriptor from scikit-image"""

    def __init__(self, dasiy_arg_dict=None):
        """
        Create FeatureDD that uses scikit-image's dense DAISY

        https://scikit-image.org/docs/dev/auto_examples/features_detection/plot_daisy.html#sphx-glr-auto-examples-features-detection-plot-daisy-py

        Parameters
        ----------
        dasiy_arg_dict : dict, optional
            Keyword arguments for skimage.feature.daisy. Values provided
            here override the defaults below.
        """
        # Fix: initialize base attributes (kp_detector, kp_descriptor) so
        # this subclass exposes the same attribute set as other FeatureDD
        # objects (the original never called super().__init__()).
        super().__init__()

        self.dasiy_arg_dict = {"step": 4,
                               "radius": 15,
                               "rings": 3,
                               "histograms": 8,
                               "orientations": 8,
                               "normalization": "l1",
                               "sigmas": None,
                               "ring_radii": None,
                               "visualize": False
                               }

        if dasiy_arg_dict is not None:
            self.dasiy_arg_dict.update(dasiy_arg_dict)

        self.kp_descriptor_name = self.__class__.__name__
        self.kp_detector_name = self.__class__.__name__

    def detect_and_compute(self, image, mask=None):
        """Compute dense DAISY descriptors over a regular grid.

        Parameters
        ----------
        image : ndarray
            Image in which to compute the descriptors.

        mask : ndarray, optional
            Unused; kept for interface compatibility with FeatureDD.

        Returns
        -------
        kp_xy : ndarray
            (N, 2) array of grid keypoint positions in xy coordinates

        desc2d : ndarray
            (N, M) array of DAISY descriptors
        """
        descs = feature.daisy(image, **self.dasiy_arg_dict)

        # Keypoints lie on a regular grid; each grid cell has a feature array
        rows = np.arange(0, descs.shape[0])
        cols = np.arange(0, descs.shape[1])
        all_rows, all_cols = np.meshgrid(rows, cols)
        all_rows = all_rows.reshape(-1)
        all_cols = all_cols.reshape(-1)

        # Vectorized gather of each grid point's descriptor (replaces the
        # original per-point Python loop + vstack)
        desc2d = descs[all_rows, all_cols]

        # Convert grid indices back to xy pixel positions
        step = self.dasiy_arg_dict["step"]
        radius = self.dasiy_arg_dict["radius"]
        feature_x = all_cols * step + radius
        feature_y = all_rows * step + radius
        kp_xy = np.dstack([feature_x, feature_y])[0]

        return kp_xy, desc2d
class SuperPointFD(FeatureDD):
    """SuperPoint `FeatureDD`

    Use SuperPoint to detect and describe features (`detect_and_compute`)
    Adapted from https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/match_pairs.py

    References
    -----------
    Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew
    Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural
    Networks. In CVPR, 2020. https://arxiv.org/abs/1911.11763

    """

    def __init__(self, keypoint_threshold=0.005, nms_radius=4, force_cpu=False, kp_descriptor=None, kp_detector=None):
        """
        Parameters
        ----------
        keypoint_threshold : float
            SuperPoint keypoint detector confidence threshold

        nms_radius : int
            SuperPoint Non Maximum Suppression (NMS) radius (must be positive)

        force_cpu : bool
            Force pytorch to run in CPU mode

        kp_descriptor : optional, OpenCV feature descriptor
            If None, SuperPoint is used to describe the keypoints

        kp_detector : optional, OpenCV feature detector
            If None, SuperPoint is used to detect keypoints
        """

        super().__init__(kp_detector=kp_detector, kp_descriptor=kp_descriptor)
        self.keypoint_threshold = keypoint_threshold
        self.nms_radius = nms_radius
        # Run on GPU when available, unless the caller forces CPU
        self.device = 'cuda' if torch.cuda.is_available() and not force_cpu else "cpu"

        if kp_detector is None:
            self.kp_detector_name = "SuperPoint"
            self.kp_detector = None
        else:
            self.kp_detector_name = kp_detector.__class__.__name__

        if kp_descriptor is None:
            self.kp_descriptor_name = "SuperPoint"
            self.kp_descriptor = None
        else:
            self.kp_descriptor_name = kp_descriptor.__class__.__name__

        # Settings forwarded to superpoint.SuperPoint
        self.config = {
            'superpoint': {
                'nms_radius': self.nms_radius,
                'keypoint_threshold': self.keypoint_threshold,
                'max_keypoints': MAX_FEATURES
            }}

    def frame2tensor(self, img):
        """Convert `img` to a (1, 1, H, W) float tensor scaled by 1/255 on self.device."""
        tensor = torch.from_numpy(img/255.).float()[None, None].to(self.device)
        return tensor

    def detect(self, img):
        """Detect keypoints with SuperPoint, or with kp_detector if one was provided.

        Returns
        -------
        kp_pos_xy : ndarray
            (N, 2) array of keypoint positions in xy coordinates
        """
        if self.kp_detector is None:
            # SuperPoint detects and describes together; discard descriptors
            kp_pos_xy, _ = self.detect_and_compute_sg(img)
        else:
            kp = self.kp_detector.detect(img)
            kp_pos_xy = np.array([k.pt for k in kp])

        return kp_pos_xy

    def compute(self, img, kp_pos_xy):
        """Compute descriptors for the keypoints at positions `kp_pos_xy`.

        When no kp_descriptor was provided, the SuperPoint encoder and
        descriptor head are run manually and the resulting dense descriptor
        map is sampled at the given keypoint positions. Otherwise, the
        positions are converted to cv2.KeyPoint objects and described by
        the OpenCV kp_descriptor.

        NOTE(review): in the OpenCV branch, when more than MAX_FEATURES
        descriptors are found, the keypoints are filtered locally but only
        the descriptors are returned, so the rows may no longer align with
        the caller's kp_pos_xy — confirm against detect_and_compute.

        Returns
        -------
        descriptors : ndarray
            (N, M) array of descriptors for the N keypoints
        """
        if self.kp_descriptor is None:
            # Replicate SuperPoint's forward pass: shared encoder followed
            # by the descriptor head (skipping the keypoint-detection head)
            sp = superpoint.SuperPoint(self.config["superpoint"])
            x = sp.relu(sp.conv1a(self.frame2tensor(img)))
            x = sp.relu(sp.conv1b(x))
            x = sp.pool(x)
            x = sp.relu(sp.conv2a(x))
            x = sp.relu(sp.conv2b(x))
            x = sp.pool(x)
            x = sp.relu(sp.conv3a(x))
            x = sp.relu(sp.conv3b(x))
            x = sp.pool(x)
            x = sp.relu(sp.conv4a(x))
            x = sp.relu(sp.conv4b(x))
            cDa = sp.relu(sp.convDa(x))
            descriptors = sp.convDb(cDa)
            # L2-normalize descriptors along the channel dimension
            descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)
            # Sample the dense descriptor map at the keypoint locations
            # (8 = encoder's total downsampling factor)
            descriptors = [superpoint.sample_descriptors(k[None], d[None], 8)[0]
                           for k, d in zip([torch.from_numpy(kp_pos_xy.astype(np.float32))], descriptors)]
            descriptors = descriptors[0].detach().numpy().T
        else:
            kp = cv2.KeyPoint_convert(kp_pos_xy.tolist())
            kp, descriptors = self.kp_descriptor.compute(img, kp)
            if descriptors.shape[0] > MAX_FEATURES:
                kp, descriptors = filter_features(kp, descriptors)
                kp_pos_xy = np.array([k.pt for k in kp])

        return descriptors

    def detect_and_compute_sg(self, img):
        """Run the full SuperPoint network to both detect and describe keypoints.

        Returns
        -------
        kp_pos_xy : ndarray
            (N, 2) array of keypoint positions in xy coordinates

        desc : ndarray
            (N, M) array of keypoint descriptors
        """
        inp = self.frame2tensor(img)
        superpoint_obj = superpoint.SuperPoint(self.config.get('superpoint', {}))
        pred = superpoint_obj({'image': inp})
        # Keep original keys and add '0'-suffixed copies of each entry
        pred = {**pred, **{k+'0': v for k, v in pred.items()}}
        kp_pos_xy = pred['keypoints'][0].detach().numpy()
        desc = pred['descriptors'][0].detach().numpy().T

        return kp_pos_xy, desc

    def detect_and_compute(self, img):
        """Detect and describe features in `img`.

        Uses the full SuperPoint network when neither a detector nor a
        descriptor was provided; otherwise detects and computes separately.

        NOTE(review): unlike FeatureDD.detect_and_compute, this override
        takes no `mask` argument — confirm callers never pass one.

        Returns
        -------
        kp_pos_xy : ndarray
            (N, 2) array of keypoint positions in xy coordinates

        desc : ndarray
            (N, M) array of keypoint descriptors
        """
        if self.kp_detector is None and self.kp_descriptor is None:
            kp_pos_xy, desc = self.detect_and_compute_sg(img)
        else:
            kp_pos_xy = self.detect(img)
            desc = self.compute(img, kp_pos_xy)

        return kp_pos_xy, desc