Source code for wbia.algo.smk.match_chips5

# -*- coding: utf-8 -*-
"""
TODO: semantic_uuids should be replaced with PCC-like hashes pertaining to
annotation clusters if any form of name scoring is used.
"""
import logging
from os.path import exists, join
from wbia.algo.hots import chip_match
import utool as ut
import numpy as np

# NOTE(review): ut.inject2 returns three objects that are bound here as the
# module-level ``print``, ``rrr`` and ``profile``; presumably a '[mc5]'-tagged
# print, a module reload helper and a profiling decorator -- confirm against
# the utool documentation.
(print, rrr, profile) = ut.inject2(__name__, '[mc5]')
# Log messages from this module are emitted through the shared 'wbia' logger.
logger = logging.getLogger('wbia')


class EstimatorRequest(ut.NiceRepr):
    """
    Base query request that executes a matching pipeline with result caching.

    The methods below read attributes that this class never assigns
    (``qaids``, ``daids``, ``ibs``, ``cachedir``, ``stack_config``,
    ``qparams``) and call ``execute_pipeline``, which is not defined here.
    NOTE(review): presumably concrete subclasses provide all of these --
    confirm against the subclasses before relying on it.
    """

    def __init__(qreq_):
        # Cache / batching knobs read by the module-level execute_* functions.
        qreq_.use_single_cache = True  # per-query chip match cache
        qreq_.use_bulk_cache = True  # whole-request cache
        qreq_.min_bulk_size = 64  # bulk cache only used above this many qaids
        qreq_.chunksize = 256  # number of queries executed per pipeline call
        qreq_.prog_hook = None  # optional progress hook given to ut.ProgPartial

    def __len__(qreq_):
        """The length of a request is its number of query annotations."""
        return len(qreq_.qaids)

    def execute(qreq_, qaids=None, prog_hook=None, use_cache=True):
        """
        Run the request and return its list of chip matches.

        Args:
            qaids: optional subset of ``qreq_.qaids``; when given, a shallow
                copy restricted to these queries is executed instead.
            prog_hook: accepted for interface compatibility; this method does
                not reference it.
            use_cache (bool): if True go through ``execute_bulk`` (cached),
                otherwise call ``execute_pipeline`` directly.

        Returns:
            the ``cm_list`` produced by ``execute_bulk`` or
            ``execute_pipeline``.
        """
        if qaids is not None:
            qreq_ = qreq_.shallowcopy(qaids)
        if use_cache:
            cm_list = execute_bulk(qreq_)
        else:
            cm_list = qreq_.execute_pipeline()
        return cm_list

    def shallowcopy(qreq_, qaids=None):
        """
        Creates a copy of qreq with the same qparams object and a subset of the
        qx and dx objects. used to generate chunks of vsone and vsmany queries

        Args:
            qaids: a single query annotation id or an iterable of them. Must
                be a subset of ``qreq_.qaids``. If None, the copy keeps the
                same query ids as the original request.

        Returns:
            a new request of the same class sharing every attribute value
            with ``qreq_`` except (possibly) ``qaids``.

        Example:
            >>> # ENABLE_DOCTEST
            >>> from wbia.algo.smk.match_chips5 import *  # NOQA
            >>> from wbia.algo.smk.smk_pipeline import testdata_smk
            >>> import wbia
            >>> wbia, smk, qreq_ = testdata_smk()
            >>> qreq2_ = qreq_.shallowcopy(qaids=1)
            >>> assert qreq_.daids is qreq2_.daids, 'should be the same'
            >>> assert len(qreq_.qaids) != len(qreq2_.qaids), 'should be diff'
            >>> #assert qreq_.metadata is not qreq2_.metadata
        """
        # Bypass __init__ so that no state is recomputed; the instance dict is
        # copied shallowly, so attribute values are shared with the original.
        cls = qreq_.__class__
        qreq2_ = cls.__new__(cls)
        qreq2_.__dict__.update(qreq_.__dict__)
        if qaids is None:
            # The default used to be wrapped into ``[None]`` and then failed
            # the subset check below; treat it as "keep the current queries".
            return qreq2_
        qaids = ut.ensure_iterable(qaids)
        assert ut.issubset(qaids, qreq_.qaids), 'not a subset'
        qreq2_.qaids = qaids
        return qreq2_

    def get_pipe_hashid(qreq_):
        """Return ``ut.hashstr27`` of the string form of ``stack_config``."""
        return ut.hashstr27(str(qreq_.stack_config))

    def get_pipe_cfgstr(qreq_):
        """Return the config string reported by ``stack_config.get_cfgstr()``."""
        pipe_cfgstr = qreq_.stack_config.get_cfgstr()
        return pipe_cfgstr

    def get_data_hashid(qreq_):
        """Return the semantic-uuid hashid of the database annots (prefix 'D')."""
        data_hashid = qreq_.ibs.get_annot_hashid_semantic_uuid(qreq_.daids, prefix='D')
        return data_hashid

    def get_query_hashid(qreq_):
        """Return the semantic-uuid hashid of the query annots (prefix 'Q')."""
        # TODO: SYSTEM : semantic should only be used if name scoring is on
        query_hashid = qreq_.ibs.get_annot_hashid_semantic_uuid(qreq_.qaids, prefix='Q')
        return query_hashid

    def get_cfgstr(
        qreq_, with_input=False, with_data=True, with_pipe=True, hash_pipe=False
    ):
        """
        Build the string that identifies this request in the caches.

        Args:
            with_input (bool): include the query hashid.
            with_data (bool): include the data hashid.
            with_pipe (bool): include the pipeline config string.
            hash_pipe (bool): replace the pipeline config string by its
                ``ut.hashstr27`` hash (only relevant if ``with_pipe``).

        Returns:
            str: the selected parts concatenated without a separator, in the
            order query, data, pipeline.
        """
        cfgstr_list = []
        if with_input:
            cfgstr_list.append(qreq_.get_query_hashid())
        if with_data:
            cfgstr_list.append(qreq_.get_data_hashid())
        if with_pipe:
            pipe_cfgstr = qreq_.get_pipe_cfgstr()
            if hash_pipe:
                pipe_cfgstr = ut.hashstr27(pipe_cfgstr)
            cfgstr_list.append(pipe_cfgstr)
        cfgstr = ''.join(cfgstr_list)
        return cfgstr

    def __nice__(qreq_):
        # NOTE(review): presumably consumed by ut.NiceRepr to build the repr.
        return ' '.join(qreq_.get_nice_parts())

    def get_chipmatch_fpaths(qreq_, qaid_list):
        r"""
        Efficient function to get a list of chipmatch paths

        Args:
            qaid_list: query annotation ids to compute cache paths for.

        Returns:
            list: one file path per entry of ``qaid_list``, located in the
            ``mc5_cms`` directory under ``qreq_.cachedir`` (the directory is
            passed through ``ut.ensuredir``).
        """
        # The query identity is carried by each file name (qaid / qauuid), so
        # the shared cfgstr only describes the data and the pipeline.
        cfgstr = qreq_.get_cfgstr(with_input=False, with_data=True, with_pipe=True)
        qauuid_list = qreq_.ibs.get_annot_semantic_uuids(qaid_list)
        fname_list = [
            chip_match.get_chipmatch_fname(qaid, qreq_, qauuid=qauuid, cfgstr=cfgstr)
            for qaid, qauuid in zip(qaid_list, qauuid_list)
        ]
        dpath = ut.ensuredir((qreq_.cachedir, 'mc5_cms'))
        fpath_list = [join(dpath, fname) for fname in fname_list]
        return fpath_list

    def get_nice_parts(qreq_):
        """
        Return the short descriptive tokens of this request: database name,
        number of query annots, number of data annots and pipeline hashid.
        """
        parts = []
        parts.append(qreq_.ibs.get_dbname())
        parts.append('nQ=%d' % len(qreq_.qaids))
        parts.append('nD=%d' % len(qreq_.daids))
        parts.append(qreq_.get_pipe_hashid())
        return parts

    # Hacked in functions

    def ensure_nids(qreq_):
        """
        Populate ``unique_aids``, ``unique_nids`` and ``aid_to_idx`` from the
        union of the query and data annotation ids.
        """
        # Hacked over from hotspotter, seriously hacky
        ibs = qreq_.ibs
        qreq_.unique_aids = np.union1d(qreq_.qaids, qreq_.daids)
        qreq_.unique_nids = ibs.get_annot_nids(qreq_.unique_aids)
        qreq_.aid_to_idx = ut.make_index_lookup(qreq_.unique_aids)

    @ut.accepts_numpy
    def get_qreq_annot_nids(qreq_, aids):
        """
        Look up the name rowids of ``aids`` from the tables built by
        ``ensure_nids`` (built on first use) instead of querying ``ibs``.
        """
        # uses own internal state to grab name rowids instead of using wbia.
        if not hasattr(qreq_, 'aid_to_idx'):
            qreq_.ensure_nids()
        idxs = ut.take(qreq_.aid_to_idx, aids)
        nids = ut.take(qreq_.unique_nids, idxs)
        return nids

    @ut.accepts_numpy
    def get_qreq_annot_gids(qreq_, aids):
        """Return ``ibs.get_annot_gids(aids)``; no request-local state is used."""
        return qreq_.ibs.get_annot_gids(aids)

    @property
    def dnids(qreq_):
        """TODO: save dnids in qreq_ state"""
        return qreq_.get_qreq_annot_nids(qreq_.daids)

    @property
    def qnids(qreq_):
        """TODO: save qnids in qreq_ state"""
        return qreq_.get_qreq_annot_nids(qreq_.qaids)

    @property
    def extern_query_config2(qreq_):
        """Alias of ``qparams``."""
        return qreq_.qparams

    @property
    def extern_data_config2(qreq_):
        """Alias of ``qparams`` (same object as ``extern_query_config2``)."""
        return qreq_.qparams
def execute_bulk(qreq_):
    """
    Execute a request, going through the whole-request ("bulk") cache when
    the request is large enough.

    Args:
        qreq_: request object providing ``use_bulk_cache``, ``min_bulk_size``,
            ``qaids``, ``cachedir``, ``get_nice_parts`` and ``get_cfgstr``.

    Returns:
        the ``cm_list`` loaded from the bulk cache, or computed by
        ``execute_singles`` (and then written to the bulk cache when bulk
        caching applies).
    """
    # Do not use bulk single queries
    wants_bulk = qreq_.use_bulk_cache and len(qreq_.qaids) > qreq_.min_bulk_size
    if not wants_bulk:
        # Fallback to smallcache
        return execute_singles(qreq_)

    # Try and load directly from a big cache
    cache_dpath = ut.ensuredir((qreq_.cachedir, 'bulk_mc5'))
    cache_fname = 'bulk_mc5_' + '_'.join(qreq_.get_nice_parts())
    cache_cfgstr = qreq_.get_cfgstr(with_input=True)
    try:
        cm_list = ut.load_cache(cache_dpath, cache_fname, cache_cfgstr)
        logger.info('... bulk cache hit %r/%r' % (len(qreq_), len(qreq_)))
    except (IOError, AttributeError):
        # Bulk entry missing or unreadable: compute through the per-query
        # cache and store the result for next time.
        cm_list = execute_singles(qreq_)
        ut.save_cache(cache_dpath, cache_fname, cache_cfgstr, cm_list)
    return cm_list
def _load_singles(qreq_):
    """
    Load every chip match of this request that already exists on disk.

    Args:
        qreq_: request object providing ``qaids`` and
            ``get_chipmatch_fpaths``.

    Returns:
        dict: maps query annotation id to its loaded chip match; queries
        without a cache file (or whose file needs recomputation) are absent.
    """
    # Find existing cached chip matches
    all_fpaths = qreq_.get_chipmatch_fpaths(qreq_.qaids)
    is_cached = [exists(fpath) for fpath in all_fpaths]
    cached_qaids = ut.compress(qreq_.qaids, is_cached)
    cached_fpaths = ut.compress(all_fpaths, is_cached)
    num_cached = len(cached_fpaths)

    # First, try a fast reload assuming no errors
    prog = ut.ProgIter(
        cached_fpaths,
        length=num_cached,
        enabled=num_cached > 1,
        label='loading cache hits',
        adjust=True,
        freq=1,
    )
    try:
        loaded = {}
        for qaid, fpath in zip(cached_qaids, prog):
            loaded[qaid] = chip_match.ChipMatch.load_from_fpath(fpath, verbose=False)
    except chip_match.NeedRecomputeError as ex:
        # Fallback to a slow reload that tolerates individual failures
        ut.printex(ex, 'Some cached results need to recompute', iswarning=True)
        loaded = _load_singles_fallback(cached_fpaths)
    return loaded


def _load_singles_fallback(fpaths_hit):
    """
    Load cached chip matches one at a time, skipping the ones that raise
    ``chip_match.NeedRecomputeError``.

    Args:
        fpaths_hit: paths of existing chip match cache files.

    Returns:
        dict: maps ``cm.qaid`` to each successfully loaded chip match.
    """
    prog = ut.ProgIter(
        fpaths_hit,
        enabled=len(fpaths_hit) > 1,
        label='checking chipmatch cache',
        adjust=True,
        freq=1,
    )
    loaded = {}
    for fpath in prog:
        try:
            cm = chip_match.ChipMatch.load_from_fpath(fpath, verbose=False)
        except chip_match.NeedRecomputeError:
            # Leave this query out so that the caller recomputes it
            continue
        loaded[cm.qaid] = cm
    num_files = len(fpaths_hit)
    logger.info(
        '%d / %d cached matches need to be recomputed'
        % (num_files - len(loaded), num_files)
    )
    return loaded
def execute_singles(qreq_):
    """
    Execute a request using the per-query chip match cache.

    Cached results are loaded first (if ``qreq_.use_single_cache``); the
    remaining queries are computed and saved by ``execute_and_save``.

    Args:
        qreq_: request object providing ``use_single_cache``, ``qaids``,
            ``shallowcopy`` and ``ensure_data``.

    Returns:
        list: chip matches in the same order as ``qreq_.qaids``.
    """
    cached = _load_singles(qreq_) if qreq_.use_single_cache else {}
    num_cached = len(cached)

    if num_cached == len(qreq_.qaids):
        # Everything was found in the cache; nothing to compute
        return ut.take(cached, qreq_.qaids)

    has_partial = num_cached > 0
    if has_partial:
        logger.info('... partial cm cache hit %d/%d' % (num_cached, len(qreq_)))
        cached_aids = list(cached.keys())
        missing_aids = ut.setdiff(qreq_.qaids, cached_aids)
        qreq_miss = qreq_.shallowcopy(missing_aids)
    else:
        qreq_miss = qreq_

    # Compute misses
    qreq_miss.ensure_data()
    qaid_to_cm = execute_and_save(qreq_miss)

    # Merge misses with hits
    if has_partial:
        qaid_to_cm.update(cached)
    return ut.take(qaid_to_cm, qreq_.qaids)
def execute_and_save(qreq_miss):
    """
    Run the pipeline for every query of ``qreq_miss`` in chunks and save each
    resulting chip match to its cache file.

    Args:
        qreq_miss: request object holding the queries that still have to be
            computed; provides ``qaids``, ``chunksize``, ``prog_hook``,
            ``shallowcopy`` and ``get_chipmatch_fpaths``.

    Returns:
        dict: maps ``cm.qaid`` to the computed chip match.
    """
    # Iterate over vsone queries in chunks.
    num_chunks = ut.get_num_chunks(len(qreq_miss.qaids), qreq_miss.chunksize)
    chunk_gen = ut.ichunks(qreq_miss.qaids, qreq_miss.chunksize)
    chunk_prog = ut.ProgPartial(
        length=num_chunks,
        freq=1,
        label='[mc5] query chunk: ',
        prog_hook=qreq_miss.prog_hook,
        bs=False,
    )

    qaid_to_cm = {}
    for qaids in iter(chunk_prog(chunk_gen)):
        chunk_qreq = qreq_miss.shallowcopy(qaids=qaids)
        cm_batch = chunk_qreq.execute_pipeline()
        # The pipeline must return one result per query, in query order
        assert len(cm_batch) == len(qaids), 'bad alignment'
        assert all([qaid == cm.qaid for qaid, cm in zip(qaids, cm_batch)])

        # TODO: we already computed the fpaths
        # should be able to pass them in
        fpath_list = chunk_qreq.get_chipmatch_fpaths(qaids)
        save_prog = ut.ProgPartial(
            length=len(cm_batch),
            adjust=True,
            freq=1,
            label='saving chip matches',
            bs=True,
        )
        for cm, fpath in save_prog(zip(cm_batch, fpath_list)):
            cm.save_to_fpath(fpath, verbose=False)

        for cm in cm_batch:
            qaid_to_cm[cm.qaid] = cm
    return qaid_to_cm