# Source code for wbia.algo.smk.vocab_indexer

# -*- coding: utf-8 -*-
import logging
from wbia import dtool
import utool as ut
import vtool as vt
from vtool._pyflann_backend import pyflann as pyflann
from wbia.algo.smk import pickle_flann
import numpy as np
import warnings
from wbia.control.controller_inject import register_preprocs

# utool module injection: the three values returned by ``ut.inject2`` are bound
# to ``print``, ``rrr`` and ``profile``. Note that this rebinds ``print`` for
# the whole module (it shadows the builtin from here on).
(print, rrr, profile) = ut.inject2(__name__)
# All log output in this module goes through the logger named 'wbia'.
logger = logging.getLogger('wbia')


# Registration decorator looked up under the 'annot' key of
# ``register_preprocs``; it is applied to ``compute_vocab`` below to declare
# the 'vocab' table.
derived_attribute = register_preprocs['annot']


class VocabConfig(dtool.Config):
    """
    Configuration for the visual vocabulary computation.

    This class is passed as ``configclass`` to the ``compute_vocab`` table
    registration in this module, which reads the ``algorithm``,
    ``random_seed``, ``num_words`` and ``n_init`` entries.

    Each spec below is forwarded positionally to ``ut.ParamInfo`` as
    ``(name, default)`` plus an optional third item (presumably a short
    alias for the parameter -- TODO confirm against utool).
    """

    _param_info_list = [
        ut.ParamInfo(*spec)
        for spec in (
            ('algorithm', 'minibatch', 'alg'),
            ('random_seed', 42, 'seed'),
            ('num_words', 1000, 'n'),
            ('version', 2),
            ('n_init', 1),
        )
    ]
@ut.reloadable_class
class VisualVocab(ut.NiceRepr):
    """
    Class that maintains a list of visual words (cluster centers)
    Also maintains a nearest neighbor index structure for finding words.
    This class is built using the depcache

    Attributes are documented here rather than inline:
        wx_to_word: the words (cluster centers) given to the constructor;
            ``None`` until words are supplied.
        wordflann: a ``pickle_flann.PickleFLANN`` index over ``wx_to_word``;
            ``None`` until ``build`` (or unpickling) creates it.
        flann_params: parameters forwarded to ``build_index``; ``checks`` is
            also the default used by ``nn_index``.
    """

    def __init__(vocab, words=None):
        """
        Args:
            words: the visual words to index, or None to create an empty,
                unbuilt vocabulary.
        """
        vocab.wx_to_word = words
        vocab.wordflann = None
        vocab.flann_params = vt.get_flann_params(random_seed=42)
        vocab.flann_params['checks'] = 1024
        vocab.flann_params['trees'] = 8
        # TODO: grab the depcache rowid and maybe config?
        # make a dtool.Computable

    def __nice__(vocab):
        """Short description used in the repr: the number of words."""
        return 'nW=%r' % (ut.safelen(vocab.wx_to_word))

    def __len__(vocab):
        """Number of words; 0 when no words have been assigned yet."""
        # ``words`` defaults to None in __init__; report that as empty instead
        # of raising TypeError (which also broke ``bool(vocab)``).
        if vocab.wx_to_word is None:
            return 0
        return len(vocab.wx_to_word)

    @property
    def shape(vocab):
        """The ``shape`` attribute of the underlying word array."""
        return vocab.wx_to_word.shape

    def __getstate__(vocab):
        """
        Build the picklable state.

        The live index object is replaced by its serialized form under the
        ``wordindex_bytes`` key. A vocabulary that was never built stores
        ``None`` there instead of failing on ``None.dumps()``.
        """
        state = vocab.__dict__.copy()
        # Legacy attribute name; normalize to the current one when dumping.
        if 'wx2_word' in state:
            state['wx_to_word'] = state.pop('wx2_word')
        wordflann = getattr(vocab, 'wordflann', None)
        state['wordindex_bytes'] = None if wordflann is None else wordflann.dumps()
        state.pop('wordflann', None)
        return state

    def __setstate__(vocab, state):
        """
        Restore from the state produced by ``__getstate__``.

        When serialized index bytes are present they are loaded into a fresh
        ``PickleFLANN``; if loading fails for any reason the index is rebuilt
        from the words. When no bytes are present (the vocabulary was pickled
        before being built) the index is left as ``None``.
        """
        wordindex_bytes = state.pop('wordindex_bytes', None)
        vocab.__dict__.update(state)
        if wordindex_bytes is None:
            # Round-trip of an unbuilt vocab: there is no index to restore.
            vocab.wordflann = None
            return
        flannclass = pickle_flann.PickleFLANN
        vocab.wordflann = flannclass()
        try:
            vocab.wordflann.loads(wordindex_bytes, vocab.wx_to_word)
        except Exception:
            # Deliberate best effort: fall back to rebuilding the index.
            logger.info('Fixing vocab problem')
            vocab.build()

    def build(vocab, verbose=True):
        """
        Build the nearest neighbor index over the current words.

        Creates the ``PickleFLANN`` object if needed and calls
        ``build_index`` with ``vocab.flann_params``. When there are no words
        (empty or ``None``) nothing is indexed and a warning is logged.

        Args:
            verbose (bool): if True, log progress and time the build.
        """
        num_vecs = len(vocab)
        if vocab.wordflann is None:
            flannclass = pickle_flann.PickleFLANN
            vocab.wordflann = flannclass()
        if verbose:
            logger.info(' ...build kdtree with %d points (may take a sec).' % num_vecs)
            tt = ut.tic(msg='Building vocab index')
        if num_vecs == 0:
            logger.info('WARNING: CANNOT BUILD FLANN INDEX OVER 0 POINTS.')
            logger.info('THIS MAY BE A SIGN OF A DEEPER ISSUE')
        else:
            vocab.wordflann.build_index(vocab.wx_to_word, **vocab.flann_params)
        if verbose:
            ut.toc(tt)

    def nn_index(vocab, idx_to_vec, nAssign, checks=None):
        """
        Find the ``nAssign`` nearest words for each query vector.

        Args:
            idx_to_vec: query vectors; cast to the dtype of the indexed data
                before searching.
            nAssign (int): number of neighbors per vector; must be positive.
            checks: FLANN ``checks`` value; defaults to
                ``vocab.flann_params['checks']``.

        Returns:
            tuple: ``(_idx_to_wx, _idx_to_wdist)`` as returned by the index,
            each promoted to at least 2 dimensions (presumably word indices
            and distances, one row per query vector -- confirm against
            pyflann).

        Raises:
            pyflann.FLANNException: re-raised after printing diagnostics.

        >>> idx_to_vec = depc.d.get_feat_vecs(aid_list)[0]
        >>> vocab = vocab
        >>> nAssign = 1
        """
        # Assign each vector to the nearest visual words
        assert nAssign > 0, 'cannot assign to 0 neighbors'
        if checks is None:
            checks = vocab.flann_params['checks']
        try:
            idx_to_vec = idx_to_vec.astype(vocab.wordflann._FLANN__curindex_data.dtype)
            _idx_to_wx, _idx_to_wdist = vocab.wordflann.nn_index(
                idx_to_vec, nAssign, checks=checks
            )
        except pyflann.FLANNException as ex:
            ut.printex(
                ex,
                'probably misread the cached flann_fpath=%r'
                % (getattr(vocab.wordflann, 'flann_fpath', None),),
            )
            raise
        else:
            _idx_to_wx = vt.atleast_nd(_idx_to_wx, 2)
            _idx_to_wdist = vt.atleast_nd(_idx_to_wdist, 2)
            return _idx_to_wx, _idx_to_wdist

    def render_vocab(vocab):
        """
        Renders the average patch of each word.
        This is a quick visualization of the entire vocabulary.

        CommandLine:
            python -m wbia.algo.smk.vocab_indexer render_vocab --show

        Example:
            >>> # DISABLE_DOCTEST
            >>> from wbia.algo.smk.vocab_indexer import *  # NOQA
            >>> vocab = testdata_vocab('PZ_MTEST', num_words=64)
            >>> all_words = vocab.render_vocab()
            >>> ut.quit_if_noshow()
            >>> import wbia.plottool as pt
            >>> pt.qt4ensure()
            >>> pt.imshow(all_words)
            >>> ut.show_if_requested()
        """
        import wbia.plottool as pt

        # Only a strided sample of the word indices is rendered.
        wx_list = list(range(len(vocab)))
        wx_list = ut.strided_sample(wx_list, 64)

        word_patch_list = []
        for wx in ut.ProgIter(wx_list, bs=True, lbl='building patches'):
            word = vocab.wx_to_word[wx]
            word_patch = vt.inverted_sift_patch(word, 64)
            word_patch = pt.render_sift_on_patch(word_patch, word)
            word_patch_list.append(word_patch)

        all_words = vt.stack_square_images(word_patch_list)
        return all_words
@derived_attribute(
    tablename='vocab',
    parents=['feat*'],
    colnames=['words'],
    coltypes=[VisualVocab],
    configclass=VocabConfig,
    chunksize=1,
    fname='visual_vocab',
    taggable=True,
    vectorized=False,
)
def compute_vocab(depc, fid_list, config):
    r"""
    Depcache method for computing a new visual vocab

    Stacks the 'vecs' of every feature rowid into one float32 training
    matrix, clusters it into ``config['num_words']`` words with the algorithm
    named by ``config['algorithm']`` and returns a built ``VisualVocab``.

    Args:
        depc: dependency cache used to fetch the native 'feat' columns.
        fid_list (list): feature rowids whose descriptors form the training set.
        config: mapping providing 'algorithm', 'num_words', 'random_seed' and
            'n_init' (see ``VocabConfig``).

    Returns:
        tuple: a one-element tuple ``(vocab,)`` holding the ``VisualVocab``.

    Raises:
        ValueError: if ``config['algorithm']`` is neither 'kdtree' nor
            'minibatch'.

    CommandLine:
        python -m wbia.core_annots --exec-compute_neighbor_index --show
        python -m wbia show_depc_annot_table_input --show --tablename=neighbor_index

        python -m wbia.algo.smk.vocab_indexer --exec-compute_vocab:0
        python -m wbia.algo.smk.vocab_indexer --exec-compute_vocab:1

        # FIXME make util_tests register
        python -m wbia.algo.smk.vocab_indexer compute_vocab:0

    Ignore:
        >>> # Lev Oxford Debug Example
        >>> import wbia
        >>> ibs = wbia.opendb('Oxford')
        >>> depc = ibs.depc
        >>> table = depc['vocab']
        >>> # Check what currently exists in vocab table
        >>> table.print_configs()
        >>> table.print_table()
        >>> table.print_internal_info()
        >>> # Grab aids used to compute vocab
        >>> from wbia.expt.experiment_helpers import get_annotcfg_list
        >>> expanded_aids_list = get_annotcfg_list(ibs, ['oxford'])[1]
        >>> qaids, daids = expanded_aids_list[0]
        >>> vocab_aids = daids
        >>> config = {'num_words': 64000}
        >>> exists = depc.check_rowids('vocab', [vocab_aids], config=config)
        >>> print('exists = %r' % (exists,))
        >>> vocab_rowid = depc.get_rowids('vocab', [vocab_aids], config=config)[0]
        >>> print('vocab_rowid = %r' % (vocab_rowid,))
        >>> vocab = table.get_row_data([vocab_rowid], 'words')[0]
        >>> print('vocab = %r' % (vocab,))

    Example:
        >>> # DISABLE_DOCTEST
        >>> from wbia.algo.smk.vocab_indexer import *  # NOQA
        >>> # Test depcache access
        >>> import wbia
        >>> ibs, aid_list = wbia.testdata_aids('testdb1')
        >>> depc = ibs.depc_annot
        >>> input_tuple = [aid_list]
        >>> rowid_kw = {}
        >>> tablename = 'vocab'
        >>> vocabid_list = depc.get_rowids(tablename, input_tuple, **rowid_kw)
        >>> vocab = depc.get(tablename, input_tuple, 'words')[0]
        >>> assert vocab.wordflann is not None
        >>> assert vocab.wordflann._FLANN__curindex_data is not None
        >>> assert vocab.wordflann._FLANN__curindex_data is vocab.wx_to_word

    Example:
        >>> # DISABLE_DOCTEST
        >>> from wbia.algo.smk.vocab_indexer import *  # NOQA
        >>> import wbia
        >>> ibs, aid_list = wbia.testdata_aids('testdb1')
        >>> depc = ibs.depc_annot
        >>> fid_list = depc.get_rowids('feat', aid_list)
        >>> config = VocabConfig()
        >>> vocab, train_vecs = ut.exec_func_src(compute_vocab, keys=['vocab', 'train_vecs'])
        >>> idx_to_vec = depc.d.get_feat_vecs(aid_list)[0]
        >>> self = vocab
        >>> ut.quit_if_noshow()
        >>> data = train_vecs
        >>> centroids = vocab.wx_to_word
        >>> import wbia.plottool as pt
        >>> vt.plot_centroids(data, centroids, num_pca_dims=2)
        >>> ut.show_if_requested()
        >>> #config = ibs.depc_annot['vocab'].configclass()
    """
    logger.info('[IBEIS] COMPUTE_VOCAB:')
    # Fail fast on an unsupported algorithm, before any descriptors are
    # loaded. Previously this fell through both branches and surfaced later
    # as an UnboundLocalError on ``words``.
    algorithm = config['algorithm']
    if algorithm not in ('kdtree', 'minibatch'):
        raise ValueError(
            'Unknown vocab algorithm %r; expected one of: kdtree, minibatch'
            % (algorithm,)
        )

    # NOTE: the local names ``train_vecs`` and ``vocab`` are extracted by
    # name in the docstring example (ut.exec_func_src); do not rename them.
    vecs_list = depc.get_native('feat', fid_list, 'vecs')
    train_vecs = np.vstack(vecs_list).astype(np.float32)
    num_words = config['num_words']
    logger.info(
        '[smk_index] Train Vocab(nWords=%d) using %d annots and %d descriptors'
        % (num_words, len(fid_list), len(train_vecs))
    )
    if algorithm == 'kdtree':
        # NOTE(review): the seed is fixed at 42 here and does not follow
        # config['random_seed'] -- confirm whether that is intended.
        flann_params = vt.get_flann_params(random_seed=42)
        kwds = dict(max_iters=20, flann_params=flann_params)
        words = vt.akmeans(train_vecs, num_words, **kwds)
    else:
        # algorithm == 'minibatch' (guaranteed by the check above)
        logger.info('Using minibatch kmeans')
        import sklearn.cluster

        rng = np.random.RandomState(config['random_seed'])
        n_init = config['n_init']
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            init_size = int(num_words * 4)
            batch_size = 1000
            n_batches = ut.get_num_chunks(train_vecs.shape[0], batch_size)
            # Budget of roughly 30000 batch updates in total. Clamp so that an
            # empty training set cannot divide by zero and a very large one
            # (> 30000 batches) cannot request zero iterations.
            max_iter = max(1, 30000 // max(1, n_batches))
            minibatch_params = dict(
                n_clusters=num_words,
                init='k-means++',
                init_size=init_size,
                n_init=n_init,
                max_iter=max_iter,
                batch_size=batch_size,
                tol=0.0,
                max_no_improvement=10,
                reassignment_ratio=0.01,
            )
            logger.info('minibatch_params = %s' % (ut.repr4(minibatch_params),))
            clusterer = sklearn.cluster.MiniBatchKMeans(
                compute_labels=False, random_state=rng, verbose=2, **minibatch_params
            )
            try:
                clusterer.fit(train_vecs)
            except (Exception, KeyboardInterrupt) as ex:
                # Developers get an interactive shell to inspect the failure
                # (execution continues afterwards); everyone else sees the
                # original error.
                ut.printex(ex, tb=True)
                if ut.is_developer():
                    ut.embed()
                else:
                    raise
        words = clusterer.cluster_centers_
        logger.info('Finished clustering')

    logger.info('Constructing vocab')
    vocab = VisualVocab(words)
    logger.info('Building vocab index')
    vocab.build()
    logger.info('Returning vocab')
    return (vocab,)
def testdata_vocab(defaultdb='testdb1', **kwargs):
    """
    Load a vocabulary for the test annotations of ``defaultdb``.

    The keyword arguments are used, unchanged, as the 'vocab' table config.
    The returned vocab additionally carries its depcache bookkeeping as the
    attributes ``rowid``, ``config_history`` and ``config``.

    >>> from wbia.algo.smk.vocab_indexer import *  # NOQA
    >>> defaultdb='testdb1'
    >>> kwargs = {'num_words': 1000}
    """
    import wbia

    ibs, aids = wbia.testdata_aids(defaultdb=defaultdb)

    # Resolve the vocab row for this set of annotations, then fetch its words
    vocab_rowid = ibs.depc.get_rowids('vocab', [aids], config=kwargs)[0]
    vocab_table = ibs.depc['vocab']
    vocab = vocab_table.get_row_data([vocab_rowid], 'words')[0]

    # Hack the depcache info onto the loaded vocab object
    # (maybe this becomes part of the depcache)
    vocab.rowid = vocab_rowid
    vocab.config_history = vocab_table.get_config_history([vocab.rowid])[0]
    vocab.config = vocab_table.get_row_configs([vocab.rowid])[0]
    return vocab