# Source code for wbia.algo.hots.requery_knn

# -*- coding: utf-8 -*-
import logging
import numpy as np
import utool as ut
import vtool as vt
import itertools as it

# utool's injection hook hands back this module's print / rrr / profile
# helpers (presumably module-aware wrappers; see utool.inject2 to confirm).
print, rrr, profile = ut.inject2(__name__)
logger = logging.getLogger('wbia')


# Set to True to turn on the internal consistency asserts and the extra
# logging performed inside the requery loop and TempResults.done_part.
DEBUG_REQUERY = False


class FinalResults(ut.NiceRepr):
    """Preallocated output buffers for the final nearest-neighbor results.

    Rows correspond to query features and columns to neighbor slots. Slots
    that are never assigned keep their sentinel fill: -1 in ``qfx2_idx`` and
    ``qfx2_truek``, NaN in ``qfx2_dist``.
    """

    def __init__(final, shape):
        # Sentinel fills let callers detect rows that were never completed.
        final.qfx2_idx = np.full(shape, -1, dtype=np.int32)
        final.qfx2_dist = np.full(shape, np.nan, dtype=np.float64)
        final.qfx2_truek = np.full(shape, -1, dtype=np.int32)

    def assign(final, index, idxs, dists, trueks):
        """Write finished neighbor rows into the buffers.

        Args:
            index: row positions (in the full query array) to overwrite.
            idxs: neighbor indices, one row per entry of ``index``.
            dists: neighbor distances, same layout as ``idxs``.
            trueks: column positions the neighbors had in the raw search.
        """
        targets = (
            (final.qfx2_idx, idxs),
            (final.qfx2_dist, dists),
            (final.qfx2_truek, trueks),
        )
        for buffer_, values in targets:
            buffer_[index, :] = values

    def __nice__(final):
        return str(final.qfx2_idx)
class TempQuery(ut.NiceRepr):
    """Query features that still lack enough valid neighbors.

    ``index`` and ``vecs`` shrink (via ``compress_inplace``) as queries are
    completed; ``index`` always records each remaining query's position in
    the original query array.
    """

    def __init__(query, vecs, invalid_axs, get_neighbors, get_axs):
        # Fixed for the lifetime of the query
        query.invalid_axs = invalid_axs
        query.get_neighbors = get_neighbors
        query.get_axs = get_axs
        # Shrunk as queries finish
        query.index = np.arange(len(vecs))
        query.vecs = vecs

    def __nice__(query):
        return str(query.index)

    def neighbors(query, temp_K):
        """Search ``temp_K`` neighbors for the remaining vectors.

        Returns:
            TempResults: the candidates, where ``validflags`` is False for
            any neighbor whose ``get_axs`` value appears in ``invalid_axs``.
        """
        raw_idxs, raw_dists = query.get_neighbors(query.vecs, temp_K)
        # Force 2-D results even when a single vector or K=1 was searched
        idxs, dists = (vt.atleast_nd(arr, 2) for arr in (raw_idxs, raw_dists))
        invalid = in1d_shape(query.get_axs(idxs), query.invalid_axs)
        return TempResults(query.index, idxs, dists, ~invalid)

    def compress_inplace(query, flags):
        """Keep only the rows of ``index`` / ``vecs`` where ``flags`` is True."""
        for attr in ('index', 'vecs'):
            setattr(query, attr, getattr(query, attr).compress(flags, axis=0))
class TempResults(ut.NiceRepr):
    """Candidate neighbors returned for a batch of incomplete queries.

    Attributes:
        index: positions of these queries in the full query array.
        idxs: candidate neighbor indices, one row per query.
        dists: candidate neighbor distances, same layout as ``idxs``.
        validflags: mask, True where a candidate neighbor is allowed.
    """

    def __init__(cand, index, idxs, dists, validflags):
        cand.index = index
        cand.idxs = idxs
        cand.dists = dists
        cand.validflags = validflags

    def __nice__(cand):
        return str(cand.index)

    def compress(cand, flags):
        """Return a new TempResults holding only the rows where ``flags`` is True."""
        qfx = cand.index.compress(flags, axis=0)
        idx_ = cand.idxs.compress(flags, axis=0)
        dist_ = cand.dists.compress(flags, axis=0)
        valid_ = cand.validflags.compress(flags, axis=0)
        return TempResults(qfx, idx_, dist_, valid_)

    def done_flags(cand, num_neighbs):
        """Return a per-row mask: True where at least ``num_neighbs`` candidates are valid."""
        return cand.validflags.sum(axis=1) >= num_neighbs

    def done_part(cand, num_neighbs):
        """Select the first ``num_neighbs`` valid candidates of every row.

        Args:
            num_neighbs (int): number of neighbors to keep per row.

        Returns:
            tuple: ``(idxs, dists, trueks)``, each of shape
            ``(n_rows, num_neighbs)``. ``trueks`` holds the column position
            each chosen candidate occupied in the raw search results.

        Raises:
            ValueError: if some row has fewer than ``num_neighbs`` valid
                candidates (the result could not be rectangular).
        """
        valid = np.asarray(cand.validflags, dtype=bool)
        # Running count of valid entries per row; an entry is chosen when it
        # is valid and among the first ``num_neighbs`` valid entries.
        rank = np.cumsum(valid, axis=1)
        chosen = valid & (rank <= num_neighbs)
        if not np.all(chosen.sum(axis=1) == num_neighbs):
            raise ValueError(
                'done_part requires at least %d valid neighbors per row'
                % (num_neighbs,)
            )
        _shape = (valid.shape[0], num_neighbs)
        # Boolean-mask indexing and np.nonzero both traverse in row-major
        # order, so each output row lists its chosen columns left to right.
        idxs = cand.idxs[chosen].reshape(_shape)
        dists = cand.dists[chosen].reshape(_shape)
        trueks = np.nonzero(chosen)[1].reshape(_shape)
        if DEBUG_REQUERY:
            assert np.all(np.diff(dists, axis=1) >= 0)
        return idxs, dists, trueks
def in1d_shape(arr1, arr2):
    """Return a boolean mask shaped like ``arr1`` marking elements present in ``arr2``.

    Uses ``np.isin``, which preserves the input shape directly; the older
    ``np.in1d`` + ``reshape`` idiom is deprecated as of NumPy 2.0.

    Args:
        arr1: array (any shape) whose elements are tested.
        arr2: array-like of values to test membership against.

    Returns:
        ndarray of bool with the same shape as ``arr1``.
    """
    return np.isin(arr1, arr2)
def requery_knn(
    get_neighbors,
    get_axs,
    qfx2_vec,
    num_neighbs,
    invalid_axs=None,
    pad=2,
    limit=4,
    recover=True,
):
    """
    Searches for `num_neighbs`, while ignoring certain matches. K is increased
    until enough valid neighbors are found or a limit is reached.

    Args:
        get_neighbors (callable): ``get_neighbors(vecs, K) -> (idxs, dists)``
            returning the K nearest neighbors of each row of ``vecs``.
        get_axs (callable): maps a neighbor-index array to the values that
            are compared against ``invalid_axs``.
        qfx2_vec (ndarray): query vectors, one per row.
        num_neighbs (int): number of valid neighbors wanted per query.
        invalid_axs (array-like, optional): neighbors whose ``get_axs`` value
            appears here are ignored. Defaults to no exclusions.
        pad (int): extra neighbors requested on the first search.
        limit (int or None): stop once the iteration counter reaches
            ``limit`` (at most ``limit + 1`` searches are issued). ``None``
            means no limit; the search then continues until every query is
            complete.
        recover (bool): when the limit is hit, fill unfinished queries by
            treating the last ``num_neighbs`` candidates as valid. Otherwise
            those rows keep -1 indices and NaN distances.

    Returns:
        tuple: ``(qfx2_idx, qfx2_dist)`` of shape ``(len(qfx2_vec), num_neighbs)``.

    Raises:
        ValueError: if ``limit`` is not None and not positive.

    Example:
        >>> # ENABLE_DOCTEST
        >>> from wbia.algo.hots.neighbor_index import *  # NOQA
        >>> import wbia
        >>> qreq_ = wbia.testdata_qreq_(defaultdb='testdb1', a='default')
        >>> qreq_.load_indexer()
        >>> indexer = qreq_.indexer
        >>> qannot = qreq_.internal_qannots[1]
        >>> qfx2_vec = qannot.vecs
        >>> ibs = qreq_.ibs
        >>> qaid = qannot.aid
        >>> impossible_aids = ibs.get_annot_groundtruth(qaid, noself=False)
        >>> invalid_axs = np.array(ut.take(indexer.aid2_ax, impossible_aids))
        >>> pad = 0
        >>> limit = 1
        >>> num_neighbs = 3
        >>> def get_neighbors(vecs, temp_K):
        >>>     return indexer.flann.nn_index(vecs, temp_K, checks=indexer.checks,
        >>>                                   cores=indexer.cores)
        >>> get_axs = indexer.get_nn_axs
        >>> res = requery_knn(
        >>>     get_neighbors, get_axs, qfx2_vec, num_neighbs, invalid_axs, pad,
        >>>     limit, recover=True)
        >>> qfx2_idx, qfx2_dist = res
        >>> assert np.all(np.diff(qfx2_dist, axis=1) >= 0)

    Ignore:
        >>> from wbia.algo.hots.neighbor_index import *  # NOQA
        >>> from wbia.algo.hots.requery_knn import *  # NOQA
        >>> max_k = 9
        >>> n_pts = 5
        >>> num_neighbs = 3
        >>> temp_K = num_neighbs * 2
        >>> #
        >>> # Create dummy data
        >>> rng = np.random.RandomState(0)
        >>> tx2_idx_full = rng.randint(0, 10, size=(n_pts, max_k))
        >>> tx2_idx_full[:, 0] = 0
        >>> tx2_dist_full = np.meshgrid(np.arange(max_k), np.arange(n_pts))[0] / 10
        >>> tx2_dist_full += (rng.rand(n_pts, max_k) * 10).astype(int) / 100
        >>> qfx2_vec = np.arange(n_pts)[:, None]
        >>> vecs = qfx2_vec
        >>> #
        >>> pad = 0
        >>> limit = 1
        >>> recover = True
        >>> #
        >>> invalid_axs = np.array([0, 1, 2, 5, 7, 9])
        >>> get_axs = ut.identity
        >>> #
        >>> def get_neighbors(vecs, temp_K):
        >>>     # simulates finding k nearest neighbors
        >>>     idxs = tx2_idx_full[vecs.ravel(), 0:temp_K]
        >>>     dists = tx2_dist_full[vecs.ravel(), 0:temp_K]
        >>>     return idxs, dists
        >>> #
        >>> res = requery_knn(
        >>>     get_neighbors, get_axs, qfx2_vec, num_neighbs, invalid_axs, pad,
        >>>     limit, recover=True)
        >>> qfx2_idx, qfx2_dist = res
    """
    # Avoid a shared mutable default; an empty list excludes nothing.
    if invalid_axs is None:
        invalid_axs = []
    if limit is not None and limit <= 0:
        raise ValueError('must have at least one iteration')

    # Alloc space for final results
    shape = (len(qfx2_vec), num_neighbs)
    final = FinalResults(shape)
    query = TempQuery(qfx2_vec, invalid_axs, get_neighbors, get_axs)
    temp_K = num_neighbs + pad
    at_limit = False

    for count in it.count():
        cand = query.neighbors(temp_K)
        # Find which query features have found enough neighbors
        done_flags = cand.done_flags(num_neighbs)
        if DEBUG_REQUERY:
            logger.info('count = %r' % (count,))
            assert np.all(np.diff(cand.dists, axis=1) >= 0)
            logger.info('done_flags = %r' % (done_flags,))

        # Move any done queries into results
        if np.any(done_flags):
            done = cand.compress(done_flags)
            idxs, dists, trueks = done.done_part(num_neighbs)
            final.assign(done.index, idxs, dists, trueks)
            if DEBUG_REQUERY:
                assert np.all(np.diff(dists, axis=1) >= 0)
                blocks = final.qfx2_dist
                nanelem_flags = np.isnan(blocks)
                nanrow_flags = np.any(nanelem_flags, axis=1)
                assert np.all(nanelem_flags.sum(axis=1)[nanrow_flags] == num_neighbs)
                assert np.all(np.diff(blocks[~nanrow_flags], axis=1) >= 0)
                logger.info('final.qfx2_dist')
                logger.info(final.qfx2_dist)

        if np.all(done_flags):
            # Everything was found, so we are done
            break

        at_limit = limit is not None and count >= limit
        if at_limit:
            logger.info(
                '[knn] Hit limit=%r and found %d/%d'
                % (limit, done_flags.sum(), len(done_flags))
            )
            break

        # Continue only with the unfinished queries over a doubled search space
        query.compress_inplace(~done_flags)
        temp_K *= 2

    if at_limit and recover:
        # Over the limit: do the best with what we have instead of returning
        # NaN rows. ``cand`` / ``done_flags`` are from the final iteration.
        best = cand.compress(~done_flags)
        logger.info('[knn] Recover for %d features' % (len(best.index)))
        # Simply override the last indices to be valid and use those
        # (compress returned copies, so this does not touch ``cand``).
        best.validflags[:, -num_neighbs:] = True
        idxs, dists, trueks = best.done_part(num_neighbs)
        final.assign(best.index, idxs, dists, trueks)
        if DEBUG_REQUERY:
            logger.info('final.qfx2_dist')
            logger.info(final.qfx2_dist)
    return final.qfx2_idx, final.qfx2_dist