Source code for wbia.dtool.base

# -*- coding: utf-8 -*-
import logging
import re
import functools
import operator as op
import utool as ut
import numpy as np
import copy

(print, rrr, profile) = ut.inject2(__name__, '[depbase]')
logger = logging.getLogger('wbia.dtool')


class StackedConfig(ut.DictLike, ut.HashComparable):
    """Manages a list of configurations"""

    def __init__(self, config_list):
        self._orig_config_list = config_list
        # Cast all inputs to config classes
        self._new_config_list = [
            cfg if hasattr(cfg, 'get_cfgstr') else make_configclass(cfg, '')
            for cfg in self._orig_config_list
        ]
        # Parse out items
        self._items = ut.flatten(
            [
                list(cfg.parse_items())
                if hasattr(cfg, 'parse_items')
                else list(cfg.items())
                for cfg in self._orig_config_list
            ]
        )
        for key, val in self._items:
            setattr(self, key, val)
        # self.keys = ut.flatten(list(cfg.keys()) for cfg in self.config_list)

    def get_cfgstr(self):
        cfgstr_list = [cfg.get_cfgstr() for cfg in self._new_config_list]
        cfgstr = '_'.join(cfgstr_list)
        return cfgstr

    def keys(self):
        return ut.take_column(self._items, 0)

    def __hash__(cfg):
        """Needed for comparison operators"""
        return hash(cfg.get_cfgstr())

    def getitem(self, key):
        try:
            return getattr(self, key)
        except AttributeError as ex:
            raise KeyError(ex)
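
# Editor's sketch (not part of the original module): minimal StackedConfig
# usage. The key names ('dim_size', 'affine_invariance') are hypothetical;
# only the Config / StackedConfig API defined in this module is assumed.
def _example_stacked_config_usage():
    """
    Example:
        >>> from wbia.dtool.base import Config, StackedConfig
        >>> chip_cfg = Config.from_dict({'dim_size': 450})
        >>> feat_cfg = Config.from_dict({'affine_invariance': True})
        >>> stacked = StackedConfig([chip_cfg, feat_cfg])
        >>> # items from every member config are exposed as flat attributes
        >>> assert stacked['dim_size'] == 450
        >>> assert 'affine_invariance' in stacked.keys()
        >>> # the joint cfgstr is the '_'-joined cfgstr of each member config
        >>> cfgstr = stacked.get_cfgstr()
    """
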
@functools.total_ordering
class Config(ut.NiceRepr, ut.DictLike):
    r"""
    Base class for hierarchical configs.
    Subclasses need to overwrite get_param_info_list (or define _param_info_list).

    CommandLine:
        python -m dtool.base Config

    Example:
        >>> # ENABLE_DOCTEST
        >>> from wbia.dtool.base import *  # NOQA
        >>> cfg1 = Config.from_dict({'a': 1, 'b': 2})
        >>> cfg2 = Config.from_dict({'a': 2, 'b': 2})
        >>> # Must be hashable and orderable
        >>> hash(cfg1)
        >>> cfg1 > cfg2
    """

    def __init__(cfg, **kwargs):
        cfg._parent = None
        cfg.initialize_params(**kwargs)
    def deepcopy(cfg):
        cfg2 = copy.deepcopy(cfg)
        cfg2._subconfig_attrs = copy.deepcopy(cfg._subconfig_attrs)
        cfg2._subconfig_names = copy.deepcopy(cfg._subconfig_names)
        try:
            cfg2._param_info_list = copy.deepcopy(cfg._param_info_list)
        except AttributeError:
            pass
        return cfg2

    def __nice__(cfg):
        return cfg.get_cfgstr(with_name=False)

    def __lt__(self, other):
        """hash comparable broke in python3"""
        return ut.compare_instance(op.lt, self, other)

    def __eq__(self, other):
        """hash comparable broke in python3"""
        return ut.compare_instance(op.eq, self, other)

    def __hash__(cfg):
        """Needed for comparison operators"""
        return hash(cfg.get_cfgstr())

    def get_config_name(cfg, **kwargs):
        """the user might want to overwrite this function"""
        # VERY HACKY
        config_name = cfg.__class__.__name__.replace('Config', '')
        config_name = re.sub('_$', '', config_name)
        return config_name

    def get_varnames(cfg):
        return [pi.varname for pi in cfg.get_param_info_list()] + cfg._subconfig_attrs
    def update(cfg, **kwargs):
        """
        Overwrites the default DictLike update, touching only keys that already
        exist. Non-existing keys are ignored.

        Note:
            prefixed keys of the form <classname>_<key> are interpreted as <key>

        CommandLine:
            python -m dtool.base update --show

        Example:
            >>> # ENABLE_DOCTEST
            >>> from wbia.dtool.base import *  # NOQA
            >>> from wbia.dtool.example_depcache import DummyVsManyConfig
            >>> cfg = DummyVsManyConfig()
            >>> cfg.update(DummyAlgo_version=4)
            >>> print(cfg)
        """
        # FIXME: currently can't update subconfigs based on namespaces
        # and non-namespaced vars are in the context of the root level.
        # self_keys = set(cfg.__dict__.keys())
        # self_keys.append(cfg.get_varnames())
        _aliases = cfg._make_key_alias_checker()
        self_keys = set(cfg.keys())
        for key, val in kwargs.items():
            # update only existing keys or namespace prefixed keys
            for k in _aliases(key):
                if k in self_keys:
                    cfg.setitem(k, val)
                    break
    def pop_update(cfg, other):
        """
        Updates based on other, while popping off used arguments.
        (useful for testing if a parameter was unused or misspelled)

        Doctest:
            >>> from wbia.dtool.base import *  # NOQA
            >>> from wbia import dtool as dt
            >>> cfg = dt.Config.from_dict({'a': 1, 'b': 2, 'c': 3})
            >>> other = {'a': 5, 'e': 2}
            >>> cfg.pop_update(other)
            >>> assert cfg['a'] == 5
            >>> assert len(other) == 1 and 'a' not in other
        """
        _aliases = cfg._make_key_alias_checker()
        self_keys = set(cfg.keys())
        for key in list(other.keys()):
            # update only existing keys or namespace prefixed keys
            for k in _aliases(key):
                if k in self_keys:
                    val = other.pop(key)
                    cfg.setitem(k, val)
                    break

    def _make_key_alias_checker(cfg):
        prefixes = (cfg.get_config_name(), cfg.__class__.__name__)

        def _aliases(key):
            yield key
            for part in prefixes:
                prefix = part + '_'
                if key.startswith(prefix):
                    key_alias = key[len(prefix):]
                    yield key_alias

        return _aliases
    def update2(cfg, *args, **kwargs):
        """
        Overwrites the default DictLike update, touching only keys that already
        exist. Non-existing keys are ignored. Also updates nested configs.

        Note:
            prefixed keys of the form <classname>_<key> are interpreted as <key>

        CommandLine:
            python -m dtool.base update --show

        Example:
            >>> # ENABLE_DOCTEST
            >>> from wbia.dtool.base import *  # NOQA
            >>> from wbia import dtool as dt
            >>> cfg = dt.Config.from_dict({
            >>>     'a': 1,
            >>>     'b': 2,
            >>>     'c': 3,
            >>>     'sub1': dt.Config.from_dict({
            >>>         'x': 'x',
            >>>         'y': {'z', 'x'},
            >>>         'c': 33,
            >>>     }),
            >>>     'sub2': dt.Config.from_dict({
            >>>         's': [1, 2, 3],
            >>>         't': (1, 2, 3),
            >>>         'c': 42,
            >>>         'sub3': dt.Config.from_dict({
            >>>             'b': 99,
            >>>             'c': 88,
            >>>         }),
            >>>     }),
            >>> })
            >>> kwargs = {'c': 10}
            >>> cfg.update2(c=10, y={1, 2})
            >>> assert cfg.c == 10
            >>> assert cfg.sub1.c == 10
            >>> assert cfg.sub2.c == 10
            >>> assert cfg.sub2.sub3.c == 10
            >>> assert cfg.sub1.y == {1, 2}
        """
        if len(args) > 1:
            raise ValueError('only specify one arg')
        elif len(args) == 1:
            kwargs.update(args[0])
        return list(cfg._update2(kwargs))
    def _update2(cfg, kwargs):
        # yields a list of keys updated as they happen
        _aliases = cfg._make_key_alias_checker()
        for key, val in cfg.native_items():
            for k in _aliases(key):
                if k in kwargs:
                    cfg.setitem(k, kwargs[k])
                    yield k
                    break
        for key, val in cfg.nested_items():
            val = cfg[key]
            if isinstance(val, Config):
                for k in val._update2(kwargs):
                    yield k

    def nested_items(cfg):
        for key in cfg.keys():
            val = cfg[key]
            if isinstance(val, Config):
                yield key, val

    def native_items(cfg):
        for key in cfg.keys():
            val = cfg[key]
            if not isinstance(val, Config):
                yield key, val

    def initialize_params(cfg, **kwargs):
        """Initializes config class attributes based on params info list"""
        # logger.info("INIT PARAMS")
        for pi in cfg.get_param_info_list():
            setattr(cfg, pi.varname, pi.default)

        # SO HACKY
        # Hacks in implicit edges from nodes to the algorithm
        # using their subconfigurations
        cfg._subconfig_attrs = []
        cfg._subconfig_names = []
        _sub_config_list = cfg.get_sub_config_list()
        if _sub_config_list:
            for subclass in _sub_config_list:
                # subclass.static_config_name()
                subcfg = subclass()
                subcfg_name = subcfg.get_config_name()
                subcfg_attr = ut.to_underscore_case(subcfg_name) + '_cfg'
                setattr(cfg, subcfg_attr, subcfg)
                cfg._subconfig_names.append(subcfg_name)
                cfg._subconfig_attrs.append(subcfg_attr)
                subcfg.update(**kwargs)
        cfg.update(**kwargs)

    def get_sub_config_list(cfg):
        if hasattr(cfg, '_sub_config_list'):
            return cfg._sub_config_list
        else:
            return []
    def parse_namespace_config_items(cfg):
        """
        Recursively extracts key, val pairs from Config objects
        into a flat list. (there must not be name conflicts)
        """
        param_list = []
        seen = set([])
        for item in cfg.items():
            key, val = item
            if hasattr(val, 'parse_namespace_config_items'):
                child_cfg = val
                child_params = child_cfg.parse_namespace_config_items()
                param_list.extend(child_params)
            elif hasattr(val, 'parse_items'):
                # hack for ut.Pref configs
                name = val.get_config_name()
                for key, val in val.parse_items():
                    if key in seen:
                        logger.info(
                            '[Config] WARNING: key=%r appears more than once' % (key,)
                        )
                    seen.add(key)
                    # Incorporate namespace
                    param_list.append((name, key, val))
            elif key.startswith('_'):
                pass
            else:
                if key in seen:
                    logger.info(
                        '[Config] WARNING: key=%r appears more than once' % (key,)
                    )
                seen.add(key)
                # Incorporate namespace
                name = cfg.get_config_name()
                param_list.append((name, key, val))
        return param_list
    def parse_items(cfg):
        r"""
        Returns:
            list: param_list

        CommandLine:
            python -m dtool.base --exec-parse_items

        Example:
            >>> # ENABLE_DOCTEST
            >>> from wbia.dtool.base import *  # NOQA
            >>> from wbia.dtool.example_depcache import DummyVsManyConfig
            >>> cfg = DummyVsManyConfig()
            >>> param_list = cfg.parse_items()
            >>> result = ('param_list = %s' % (ut.repr2(param_list, nl=1),))
            >>> print(result)
        """
        namespace_param_list = cfg.parse_namespace_config_items()
        param_names = ut.get_list_column(namespace_param_list, 1)
        needs_namespace_keys = ut.find_duplicate_items(param_names)
        param_list = ut.get_list_column(namespace_param_list, [1, 2])
        # prepend namespaces to variables that need it
        for idx in ut.flatten(needs_namespace_keys.values()):
            name = namespace_param_list[idx][0]
            param_list[idx][0] = name + '_' + param_list[idx][0]
        duplicate_keys = ut.find_duplicate_items(ut.get_list_column(param_list, 0))
        # hack to let version through
        # import utool
        # with utool.embed_on_exception_context:
        assert len(duplicate_keys) == 0, (
            'Configs have duplicate names: %r' % duplicate_keys
        )
        return param_list
    def get_cfgstr_list(cfg, ignore_keys=None, with_name=True, **kwargs):
        """default get_cfgstr_list, can be overridden by a config object"""
        if ignore_keys is not None:
            itemstr_list = [
                pi.get_itemstr(cfg)
                for pi in cfg.get_param_info_list()
                if pi.varname not in ignore_keys
            ]
        else:
            itemstr_list = [pi.get_itemstr(cfg) for pi in cfg.get_param_info_list()]
        filtered_itemstr_list = list(filter(len, itemstr_list))
        if with_name:
            config_name = cfg.get_config_name()
        else:
            config_name = ''
        body = ','.join(filtered_itemstr_list)
        cfgstr = ''.join([config_name, '(', body, ')'])
        return cfgstr
    def get_cfgstr(cfg, **kwargs):
        str_ = ''.join(cfg.get_cfgstr_list(**kwargs))
        return '_'.join(
            [str_]
            + [cfg[subcfg_attr].get_cfgstr() for subcfg_attr in cfg._subconfig_attrs]
        )

    def get_param_info_dict(cfg):
        param_info_list = cfg.get_param_info_list()
        param_info_dict = {pi.varname: pi for pi in param_info_list}
        return param_info_dict

    def assert_self_types(cfg, verbose=True):
        if verbose:
            logger.info('Assert self types of cfg=%r' % (cfg,))
        pi_dict = cfg.get_param_info_dict()
        for key in cfg.keys():
            pi = pi_dict[key]
            value = cfg[key]
            pi.error_if_invalid_value(value)
        if verbose:
            logger.info('... checks passed')

    def getinfo(cfg, key):
        pass

    def get_hashid(cfg):
        return ut.hashstr27(cfg.get_cfgstr())

    def keys(cfg):
        """Required for DictLike interface"""
        return cfg.get_varnames()

    def getitem(cfg, key):
        """Required for DictLike interface"""
        try:
            return getattr(cfg, key)
        except AttributeError as ex:
            raise KeyError(ex)
    def get(qparams, key, *d):
        """get a parameter value by string"""
        ERROR_ON_DEFAULT = False
        if ERROR_ON_DEFAULT:
            return getattr(qparams, key)
        else:
            return getattr(qparams, key, *d)

    def setitem(cfg, key, value):
        """Required for DictLike interface"""
        # TODO: check for valid config setting
        pi_dict = cfg.get_param_info_dict()
        pi = pi_dict[key]
        pi.error_if_invalid_value(value)
        return setattr(cfg, key, value)
    def get_param_info_list(cfg):
        try:
            return cfg._param_info_list
        except AttributeError:
            raise NotImplementedError(
                'Need to define _param_info_list or get_param_info_list'
            )

    @classmethod
    def from_argv_dict(cls, **kwargs):
        """
        handy command line tool
        ut.parse_argv_cfg
        """
        cfg = cls(**kwargs)
        new_vals = ut.parse_dict_from_argv(cfg)
        cfg.update(**new_vals)
        return cfg

    @classmethod
    def from_argv_cfgs(cls):
        """handy command line tool"""
        cfg = cls()
        name = cfg.get_config_name()
        # name = cls.static_config_name()
        argname = '--' + name
        if hasattr(cfg, '_alias'):
            argname = (argname, '--' + cfg._alias)
        # if hasattr(cls, '_alias'):
        #     argname = (argname, '--' + cls._alias)
        new_vals_list = ut.parse_argv_cfg(argname)
        self_list = [cls(**new_vals) for new_vals in new_vals_list]
        return self_list
    @classmethod
    def from_dict(cls, dict_, tablename=None):
        r"""
        Args:
            dict_ (dict): a dictionary
            tablename (None): (default = None)

        Returns:
            Config: an instance of the dynamically generated config class

        CommandLine:
            python -m dtool.base Config.from_dict --show

        Example:
            >>> # DISABLE_DOCTEST
            >>> from wbia.dtool.base import *  # NOQA
            >>> cls = Config
            >>> dict_ = {'K': 1, 'Knorm': 5, 'min_pername': 1, 'max_pername': 1,}
            >>> tablename = None
            >>> config = cls.from_dict(dict_, tablename)
            >>> print(config)
            >>> # xdoctest: +REQUIRES(--show)
            >>> ut.quit_if_noshow()
            >>> dlg = config.make_qt_dialog(
            >>>     title='Confirm Merge Query',
            >>>     msg='Confirm')
            >>> dlg.resize(700, 500)
            >>> dlg.show()
            >>> import wbia.plottool as pt
            >>> self = dlg.widget
            >>> guitool.qtapp_loop(qwin=dlg)
            >>> updated_config = self.config  # NOQA
            >>> print('updated_config = %r' % (updated_config,))
        """
        UnnamedConfig = cls.class_from_dict(dict_, tablename)
        config = UnnamedConfig()
        return config
    @classmethod
    def class_from_dict(cls, dict_, tablename=None):
        if tablename is None:
            tablename = 'Unnamed'
        UnnamedConfig = make_configclass(dict_, tablename)
        return UnnamedConfig
    def make_qt_dialog(cfg, parent=None, title='Edit Config', msg='Confirm'):
        import wbia.guitool as gt

        gt.ensure_qapp()  # must be ensured before any embedding
        dlg = gt.ConfigConfirmWidget.as_dialog(title=title, msg=msg, config=cfg)
        dlg.resize(700, 500)
        dlg.show()
        return dlg
    def getstate_todict_recursive(cfg):
        from wbia import dtool

        _dict = cfg.asdict()
        _dict2 = {}
        for key, val in _dict.items():
            if isinstance(val, dtool.Config):
                # val = val.asdict()
                try:
                    val = val.getstate_todict_recursive()
                except Exception:
                    val = getstate_todict_recursive(val)  # NOQA
                _dict2[key] = val
            else:
                _dict2[key] = val
        return _dict2

    def __getstate__(cfg):
        """
        FIXME

        Example:
            >>> # ENABLE_DOCTEST
            >>> from wbia.dtool.base import *  # NOQA
            >>> from wbia.dtool.example_depcache import DummyKptsConfig
            >>> import pickle
            >>> cfg = DummyKptsConfig()
            >>> ser = pickle.dumps(cfg)
            >>> cfg2 = pickle.loads(ser)
            >>> assert cfg == cfg2
            >>> assert cfg is not cfg2

        Example:
            >>> # ENABLE_DOCTEST
            >>> from wbia.dtool.base import *  # NOQA
            >>> from wbia.dtool.example_depcache import DummyVsManyConfig
            >>> import pickle
            >>> cfg = DummyVsManyConfig()
            >>> state = cfg.__getstate__()
            >>> cfg2 = DummyVsManyConfig()
            >>> serialized = pickle.dumps(cfg)
            >>> unserialized = pickle.loads(serialized)
            >>> assert cfg == unserialized
            >>> assert cfg is not unserialized
        """
        # from wbia import dtool
        # _dict = cfg.asdict()
        # _dict2 = {}
        # for key, val in _dict.items():
        #     if isinstance(val, dtool.Config):
        #         val = val.asdict()
        #     _dict2[key] = val
        # return {'dtool.Config': _dict2}
        return cfg.__dict__

    def __setstate__(cfg, state):
        cfg.__dict__.update(**state)
        # cfg.initialize_params()
        # cfg.update(**state)

    # @classmethod
    # def static_config_name(cls):
    #     class_str = str(cls)
    #     full_class_str = class_str.replace('<class \'', '').replace('\'>', '')
    #     config_name = splitext(full_class_str)[1][1:].replace('Config', '')
    #     return config_name
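
# Editor's sketch (not part of the original module): how class-prefixed keys
# are aliased by Config.update and how a config reduces to the cfgstr used for
# caching. 'Demo' and its keys are hypothetical names for illustration only.
def _example_config_key_aliases():
    """
    Example:
        >>> from wbia.dtool.base import Config
        >>> DemoConfig = Config.class_from_dict({'K': 4, 'ratio': 0.8}, 'Demo')
        >>> cfg = DemoConfig()
        >>> # 'Demo_K' and 'DemoConfig_K' both resolve to the native key 'K'
        >>> cfg.update(Demo_K=2)
        >>> assert cfg['K'] == 2
        >>> cfg.update(DemoConfig_K=3)
        >>> assert cfg.K == 3
        >>> # unknown keys are silently ignored rather than raising
        >>> cfg.update(not_a_key=1)
        >>> assert 'not_a_key' not in cfg.keys()
        >>> # the cfgstr is what the dependency cache keys on
        >>> assert cfg.get_cfgstr().startswith('Demo(')
    """
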
def make_configclass(dict_, tablename):
    """Creates a custom config class from a dict"""

    def rectify_item(key, val):
        if val is None:
            return ut.ParamInfo(key, val)
        elif isinstance(val, ut.ParamInfo):
            if val.varname is None:
                # Copy and assign a new varname
                pi = copy.deepcopy(val)
                pi.varname = key
            else:
                pi = val
                assert pi.varname == key, 'Given varname=%r does not match key=%r' % (
                    pi.varname,
                    key,
                )
            return pi
        else:
            if isinstance(val, Config):
                # Set table name from key when doing nested from dicts
                if val.__class__.__name__ == 'UnnamedConfig':
                    val.__class__.__name__ = str(key + 'Config')
            return ut.ParamInfo(key, val, type_=type(val))

    param_info_list = [rectify_item(key, val) for key, val in dict_.items()]
    return from_param_info_list(param_info_list, tablename)


def from_param_info_list(param_info_list, tablename='Unnamed'):
    from wbia import dtool

    class UnnamedConfig(dtool.Config):
        _param_info_list = param_info_list

    UnnamedConfig.__name__ = str(tablename + 'Config')
    return UnnamedConfig
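
# Editor's sketch (not part of the original module): the two equivalent ways
# to build a config class, from explicit ut.ParamInfo objects or from a plain
# dict. 'Preproc' and its keys are hypothetical names.
def _example_make_configclass_usage():
    """
    Example:
        >>> import utool as ut
        >>> from wbia.dtool.base import from_param_info_list, make_configclass
        >>> pi_list = [ut.ParamInfo('medianblur', True), ut.ParamInfo('adapteq', False)]
        >>> PreprocConfig1 = from_param_info_list(pi_list, 'Preproc')
        >>> cfg1 = PreprocConfig1(adapteq=True)
        >>> assert cfg1['adapteq'] is True
        >>> # the same class built through make_configclass
        >>> PreprocConfig2 = make_configclass({'medianblur': True, 'adapteq': False}, 'Preproc')
        >>> cfg2 = PreprocConfig2()
        >>> assert cfg2['adapteq'] is False
    """
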
class IBEISRequestHacks(object):
    _isnewreq = True

    @property
    def ibs(request):
        """HACK specific to wbia"""
        if request.depc is None:
            return None
        return request.depc.controller

    @property
    def qannots(self):
        return self.ibs.annots(self.qaids, self.params)

    @property
    def dannots(self):
        return self.ibs.annots(self.daids, self.params)
    def get_qreq_annot_nids(self, aids):
        # VERY HACKY. To be just hacky it should store
        # the nids as a state, but whatever...
        # development time constraints and whatnot
        return self.ibs.get_annot_nids(aids)
        # return self.ibs.annots(self.daids, self.params)

    @property
    def extern_query_config2(request):
        return request.params

    @property
    def extern_data_config2(request):
        return request.params

    # def get_external_data_config2(request):
    #     # HACK
    #     # return None
    #     # logger.info('[d] request.params = %r' % (request.params,))
    #     return request.params

    # def get_external_query_config2(request):
    #     # HACK
    #     # return None
    #     # logger.info('[q] request.params = %r' % (request.params,))
    #     return request.params
def config_graph_subattrs(cfg, depc):
    # TODO: if this hack is fully completed need a way of getting the
    # full config belonging to both chip + feat
    # cfg = request.config.feat_cfg
    import networkx as netx

    tablename = ut.invert_dict(depc.configclass_dict)[cfg.__class__]
    # tablename = cfg.get_config_name()
    ancestors = netx.dag.ancestors(depc.graph, tablename)
    subconfigs_ = ut.dict_take(depc.configclass_dict, ancestors, None)
    subconfigs = ut.filter_Nones(subconfigs_)  # NOQA


@ut.reloadable_class
class BaseRequest(IBEISRequestHacks, ut.NiceRepr):
    r"""
    Class that maintains both an algorithm, inputs, and a config.
    """
    @staticmethod
    def static_new(cls, depc, parent_rowids, cfgdict=None, tablename=None):
        """hack for autoreload"""
        request = cls()
        if tablename is None:
            try:
                if hasattr(cls, '_tablename'):
                    tablename = cls._tablename
                else:
                    tablename = ut.invert_dict(depc.requestclass_dict)[cls]
            except Exception as ex:
                ut.printex(ex, 'tablename must be given')
                raise
        request.tablename = tablename
        request.parent_rowids = parent_rowids
        request.depc = depc
        if cfgdict is None:
            cfgdict = {}
        configclass = depc.configclass_dict[tablename]
        config = configclass(**cfgdict)
        request.config = config
        # HACK FOR IBEIS
        request.params = dict(config.parse_items())
        # HACK-ier FOR BACKWARDS COMPATIBILITY
        if True:
            # params.featweight_cfgstr = query_cfg._featweight_cfg.get_cfgstr()
            # TODO: if this hack is fully completed need a way of getting the
            # full config belonging to both chip + feat
            try:
                request.params['chip_cfgstr'] = config.chip_cfg.get_cfgstr()
                request.params['chip_cfg_dict'] = config.chip_cfg.asdict()
                request.params['feat_cfgstr'] = config.feat_cfg.get_cfgstr()
                request.params['hesaff_params'] = config.feat_cfg.get_hesaff_params()
                request.params['featweight_cfgstr'] = config.feat_weight_cfg.get_cfgstr()
            except AttributeError:
                pass
        request.qparams = ut.DynStruct()
        for key, val in request.params.items():
            setattr(request.qparams, key, val)
        return request
    @classmethod
    def new(cls, depc, parent_rowids, cfgdict=None, tablename=None):
        return cls.static_new(cls, depc, parent_rowids, cfgdict, tablename)

    def _get_rootset_hashid(request, root_rowids, prefix):
        uuid_type = 'V'
        label = ''.join((prefix, uuid_type, 'UUIDS'))
        # Hack: allow general specification of uuid types
        uuid_list = request.depc.get_root_uuid(root_rowids)
        # uuid_hashid = ut.hashstr_arr27(uuid_list, label, pathsafe=True)
        uuid_hashid = ut.hashstr_arr27(uuid_list, label, pathsafe=False)
        # TODO: uuid_hashid = ut.hashid_arr(uuid_list, label=label)
        return uuid_hashid

    def get_cfgstr(request, with_input=False, with_pipe=True, **kwargs):
        r"""
        main cfgstring used to identify the 'querytype'
        """
        cfgstr_list = []
        if with_input:
            cfgstr_list.append(request.get_input_hashid())
        if with_pipe:
            cfgstr_list.append(request.get_pipe_cfgstr())
        cfgstr = '_'.join(cfgstr_list)
        return cfgstr

    def get_input_hashid(request):
        raise NotImplementedError('abstract class method')

    def get_pipe_cfgstr(request):
        return request.config.get_cfgstr()

    def get_pipe_hashid(request):
        return ut.hashstr27(request.get_pipe_cfgstr())
    def ensure_dependencies(request):
        r"""
        CommandLine:
            python -m dtool.base --exec-BaseRequest.ensure_dependencies

        Example:
            >>> # ENABLE_DOCTEST
            >>> from wbia.dtool.base import *  # NOQA
            >>> from wbia.dtool.example_depcache import testdata_depc
            >>> depc = testdata_depc()
            >>> request = depc.new_request('vsmany', [1, 2], [2, 3, 4])
            >>> request.ensure_dependencies()
        """
        import networkx as nx

        depc = request.depc
        if False:
            dependencies = nx.ancestors(depc.graph, request.tablename)
            subgraph = depc.graph.subgraph(set.union(dependencies, {request.tablename}))
            dependency_order = nx.topological_sort(subgraph)
            root = dependency_order[0]
            [
                nx.algorithms.dijkstra_path(subgraph, root, start)[:-1]
                + nx.algorithms.dijkstra_path(subgraph, start, request.tablename)
                for start in dependency_order
            ]
        graph = depc.graph
        root = list(nx.topological_sort(graph))[0]
        edges = graph.edges()
        # parent_to_children = ut.edges_to_adjacency_list(edges)
        child_to_parents = ut.edges_to_adjacency_list([t[::-1] for t in edges])
        to_root = {
            request.tablename: ut.paths_to_root(request.tablename, root, child_to_parents)
        }
        from_root = ut.reverse_path(to_root, root, child_to_parents)
        dependency_levels_ = ut.get_levels(from_root)
        dependency_levels = ut.longest_levels(dependency_levels_)
        true_order = ut.flatten(dependency_levels)[1:-1]
        # logger.info('[req] Ensuring %s request dependencies: %r' % (request, true_order,))
        ut.colorprint(
            '[req] Ensuring request %s dependencies: %r' % (request, true_order),
            'yellow',
        )
        for tablename in true_order:
            table = depc[tablename]
            if table.ismulti:
                pass
            else:
                # HACK FOR IBEIS
                all_aids = ut.flat_unique(request.qaids, request.daids)
                depc.get_rowids(tablename, all_aids)
        # zip(depc.get_implicit_edges())
        # raise NotImplementedError('todo')
        # depc = request.depc
        # parent_rowids = request.parent_rowids
        # config = request.config
        # rowid_dict = depc.get_all_descendant_rowids(
        #     request.tablename, root_rowids, config=config)
    def execute(request, parent_rowids=None, use_cache=None, postprocess=True):
        ut.colorprint('[req] Executing request %s' % (request,), 'yellow')
        table = request.depc[request.tablename]
        if use_cache is None:
            use_cache = not ut.get_argflag('--nocache')
        if parent_rowids is None:
            parent_rowids = request.parent_rowids
        # Compute and cache any uncomputed results
        rowids = table.get_rowid(parent_rowids, config=request, recompute=not use_cache)
        # Load all results
        result_list = table.get_row_data(rowids)
        if postprocess and hasattr(request, 'postprocess_execute'):
            logger.info('Converting results')
            result_list = request.postprocess_execute(table, parent_rowids, result_list)
        return result_list

    def __getstate__(request):
        state_dict = request.__dict__.copy()
        # SUPER HACK
        state_dict['dbdir'] = request.depc.controller.get_dbdir()
        del state_dict['depc']
        del state_dict['config']
        return state_dict

    def __setstate__(request, state_dict):
        import wbia

        dbdir = state_dict['dbdir']
        del state_dict['dbdir']
        params = state_dict['params']
        depc = wbia.opendb(dbdir=dbdir, web=False).depc
        configclass = depc.configclass_dict[state_dict['tablename']]
        config = configclass(**params)
        state_dict['depc'] = depc
        state_dict['config'] = config
        request.__dict__.update(state_dict)
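
# Editor's sketch (not part of the original module): the request lifecycle
# against the example dependency cache used by the doctests in this module.
# It only exercises API shown above ('vsmany' is one of the example tables).
def _example_request_lifecycle():
    """
    Example:
        >>> from wbia.dtool.example_depcache import testdata_depc
        >>> depc = testdata_depc()
        >>> request = depc.new_request('vsmany', [1, 2], [2, 3, 4])
        >>> # the pipeline cfgstr identifies the configuration of the request ...
        >>> pipe_cfgstr = request.get_pipe_cfgstr()
        >>> # ... and execute() computes (or loads cached) results, one per query
        >>> request.ensure_dependencies()
        >>> results = request.execute()
        >>> assert len(results) == 2
    """
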
class AnnotSimiliarity(object):
    def get_query_hashid(request):
        return request._get_rootset_hashid(request.qaids, 'Q')

    def get_data_hashid(request):
        return request._get_rootset_hashid(request.daids, 'D')
@ut.reloadable_class
class VsOneSimilarityRequest(BaseRequest, AnnotSimiliarity):
    r"""
    Similarity request for pairwise scores

    References:
        https://thingspython.wordpress.com/2010/09/27/
        another-super-wrinkle-raising-typeerror/

    CommandLine:
        python -m dtool.base --exec-VsOneSimilarityRequest

    Example:
        >>> # ENABLE_DOCTEST
        >>> from wbia.dtool.base import *  # NOQA
        >>> from wbia.dtool.example_depcache import testdata_depc
        >>> qaid_list = [1, 2, 3, 5]
        >>> daid_list = [2, 3, 4]
        >>> depc = testdata_depc()
        >>> request = depc.new_request('vsone', qaid_list, daid_list)
        >>> results = request.execute()
        >>> # Test that adding a query / data id only recomputes necessary items
        >>> request2 = depc.new_request('vsone', qaid_list + [4], daid_list + [5])
        >>> results2 = request2.execute()
        >>> print('results = %r' % (results,))
        >>> print('results2 = %r' % (results2,))
        >>> ut.assert_eq(len(results), 10, 'incorrect num output')
        >>> ut.assert_eq(len(results2), 16, 'incorrect num output')
    """

    _symmetric = False
    @classmethod
    def new(cls, depc, qaid_list, daid_list, cfgdict=None, tablename=None):
        parent_rowids = cls.make_parent_rowids(qaid_list, daid_list)
        parent_rowids = list(ut.product_nonsame(qaid_list, daid_list))
        request = cls.static_new(cls, depc, parent_rowids, cfgdict, tablename)
        request.qaids = safeop(np.array, qaid_list)
        request.daids = safeop(np.array, daid_list)
        return request

    @staticmethod
    def make_parent_rowids(qaid_list, daid_list):
        return list(ut.product_nonsame(qaid_list, daid_list))

    @property
    def parent_rowids_T(request):
        return ut.list_transpose(request.parent_rowids)
    def execute(request, parent_rowids=None, use_cache=None, postprocess=True, **kwargs):
        """HACKY REIMPLEMENTATION"""
        ut.colorprint('[req] Executing request %s' % (request,), 'yellow')
        table = request.depc[request.tablename]
        if use_cache is None:
            use_cache = not ut.get_argflag('--nocache')
        if parent_rowids is None:
            parent_rowids = request.parent_rowids
        else:
            # previously defined in execute subset
            # subparent_rowids = request.make_parent_rowids(
            #     qaids, request.daids)
            logger.info('given %d specific parent_rowids' % (len(parent_rowids),))

        # vsone hack: (i, j) same as (j, i)
        if request._symmetric:
            import vtool as vt

            directed_edges = np.array(parent_rowids)
            undirected_edges = vt.to_undirected_edges(directed_edges)
            edge_ids = vt.compute_unique_data_ids(undirected_edges)
            unique_rows, unique_rowx, inverse_idx = np.unique(
                edge_ids, return_index=True, return_inverse=True
            )
            parent_rowids_ = ut.take(parent_rowids, unique_rowx)
        else:
            parent_rowids_ = parent_rowids

        # Compute and cache any uncomputed results
        rowids = table.get_rowid(parent_rowids_, config=request, recompute=not use_cache)
        # Load all results
        result_list = table.get_row_data(rowids)

        if request._symmetric:
            result_list = ut.take(result_list, inverse_idx)

        if postprocess and hasattr(request, 'postprocess_execute'):
            logger.info('Converting results')
            result_list = request.postprocess_execute(
                table, parent_rowids, rowids, result_list
            )
        return result_list
    def get_input_hashid(request):
        return '_'.join([request.get_query_hashid(), request.get_data_hashid()])

    def __nice__(request):
        dbname = (
            None
            if request.depc is None or request.depc.controller is None
            else request.depc.controller.get_dbname()
        )
        infostr_ = 'nQ=%s, nD=%s, nP=%d %s' % (
            len(request.qaids),
            len(request.daids),
            len(request.parent_rowids),
            request.get_pipe_hashid(),
        )
        return '(%s) %s' % (dbname, infostr_)
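
# Editor's sketch (not part of the original module): VsOne requests enumerate
# candidate (qaid, daid) pairs explicitly and never pair an annotation with
# itself, which is exactly what make_parent_rowids does via ut.product_nonsame.
def _example_vsone_parent_rowids():
    """
    Example:
        >>> from wbia.dtool.base import VsOneSimilarityRequest
        >>> pairs = VsOneSimilarityRequest.make_parent_rowids([1, 2], [1, 2, 3])
        >>> # 2 queries x 3 data ids, minus the two self-pairs
        >>> assert len(pairs) == 4
    """
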
@ut.reloadable_class
class VsManySimilarityRequest(BaseRequest, AnnotSimiliarity):
    r"""
    Request for one-vs-many similarity

    CommandLine:
        python -m dtool.base --exec-VsManySimilarityRequest

    Example:
        >>> # ENABLE_DOCTEST
        >>> from wbia.dtool.base import *  # NOQA
        >>> from wbia.dtool.example_depcache import testdata_depc
        >>> qaid_list = [1, 2]
        >>> daid_list = [2, 3, 4]
        >>> depc = testdata_depc()
        >>> request = depc.new_request('vsmany', qaid_list, daid_list)
        >>> request.ensure_dependencies()
        >>> results = request.execute()
        >>> # Test dependence on data
        >>> request2 = depc.new_request('vsmany', qaid_list + [3], daid_list + [5])
        >>> results2 = request2.execute()
        >>> print('results = %r' % (results,))
        >>> print('results2 = %r' % (results2,))
        >>> assert len(results) == 2, 'incorrect num output'
        >>> assert len(results2) == 3, 'incorrect num output'
    """
    @classmethod
    def new(cls, depc, qaid_list, daid_list, cfgdict=None, tablename=None):
        parent_rowids = list(zip(qaid_list))
        request = cls.static_new(cls, depc, parent_rowids, cfgdict, tablename)
        request.qaids = safeop(np.array, qaid_list)
        request.daids = safeop(np.array, daid_list)
        # HACK
        request.config.daids = request.daids
        return request

    def get_input_hashid(request):
        # return '_'.join([request.get_query_hashid(), request.get_data_hashid()])
        return '_'.join([request.get_query_hashid()])

    def get_cfgstr(
        request, with_input=False, with_data=True, with_pipe=True, hash_pipe=False
    ):
        r"""
        Override default get_cfgstr to show reliance on data
        """
        cfgstr_list = []
        if with_input:
            cfgstr_list.append(request.get_query_hashid())
        if with_data:
            cfgstr_list.append(request.get_data_hashid())
        if with_pipe:
            if hash_pipe:
                cfgstr_list.append(request.get_pipe_hashid())
            else:
                cfgstr_list.append(request.get_pipe_cfgstr())
        cfgstr = '_'.join(cfgstr_list)
        return cfgstr

    def __nice__(request):
        dbname = (
            None
            if request.depc is None or request.depc.controller is None
            else request.depc.controller.get_dbname()
        )
        infostr_ = 'nQ=%s, nD=%s %s' % (
            len(request.qaids),
            len(request.daids),
            request.get_pipe_hashid(),
        )
        return '(%s) %s' % (dbname, infostr_)
class ClassVsClassSimilarityRequest(BaseRequest):
    pass


class AlgoResult(object):
    """Base class for algo result objects"""

    @classmethod
    def load_from_fpath(cls, fpath, verbose=ut.VERBOSE):
        state_dict = ut.load_cPkl(fpath, verbose=verbose)
        self = cls()
        self.__setstate__(state_dict)
        return self

    def save_to_fpath(cm, fpath, verbose=ut.VERBOSE):
        ut.save_cPkl(fpath, cm.__getstate__(), verbose=verbose, n=2)

    def __getstate__(self):
        state_dict = self.__dict__
        return state_dict

    def __setstate__(self, state_dict):
        self.__dict__.update(state_dict)

    def copy(self):
        cls = self.__class__
        out = cls()
        state_dict = copy.deepcopy(self.__getstate__())
        out.__setstate__(state_dict)
        return out
def safeop(op_, xs, *args, **kwargs):
    return None if xs is None else op_(xs, *args, **kwargs)
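
# Editor's note (illustrative, not in the original module): safeop forwards to
# the wrapped callable unless the input is None, which is how the request and
# result classes tolerate missing qaid/daid lists.
def _example_safeop_usage():
    """
    Example:
        >>> import numpy as np
        >>> from wbia.dtool.base import safeop
        >>> assert safeop(np.array, None) is None
        >>> assert safeop(np.array, [1, 2, 3]).tolist() == [1, 2, 3]
    """
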
class MatchResult(AlgoResult, ut.NiceRepr):
    def __init__(
        self,
        qaid=None,
        daids=None,
        qnid=None,
        dnid_list=None,
        annot_score_list=None,
        unique_nids=None,
        name_score_list=None,
    ):
        self.qaid = qaid
        self.daid_list = safeop(np.array, daids)
        self.dnid_list = safeop(np.array, dnid_list)
        self.annot_score_list = safeop(np.array, annot_score_list)
        self.name_score_list = safeop(np.array, name_score_list)

    @property
    def num_daids(cm):
        return None if cm.daid_list is None else len(cm.daid_list)

    @property
    def daids(cm):
        return cm.daid_list

    @property
    def qaids(cm):
        return cm.qaid

    def __nice__(cm):
        return ' qaid=%s nD=%s' % (cm.qaid, cm.num_daids)
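
# Editor's sketch (not part of the original module): AlgoResult subclasses
# round-trip through save_to_fpath / load_from_fpath. The temp-file handling
# below is illustrative only.
def _example_algo_result_roundtrip():
    """
    Example:
        >>> import tempfile
        >>> from os.path import join
        >>> from wbia.dtool.base import MatchResult
        >>> cm = MatchResult(qaid=1, daids=[2, 3], annot_score_list=[0.9, 0.1])
        >>> fpath = join(tempfile.mkdtemp(), 'result.cPkl')
        >>> cm.save_to_fpath(fpath, verbose=False)
        >>> cm2 = MatchResult.load_from_fpath(fpath, verbose=False)
        >>> assert cm2.qaid == cm.qaid
        >>> assert cm2.num_daids == 2
    """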