Source code for wbia.algo.detect.lightnet

# -*- coding: utf-8 -*-
"""Interface to Lightnet object proposals."""
import logging
import utool as ut
import numpy as np
from os.path import abspath, dirname, expanduser, join, exists, splitext  # NOQA
from tqdm import tqdm
import cv2

(print, rrr, profile) = ut.inject2(__name__, '[lightnet]')
logger = logging.getLogger('wbia')


if not ut.get_argflag('--no-lightnet'):
    try:
        import torch
        from torchvision import transforms as tf
        import lightnet as ln
    except ImportError:
        logger.info(
            'WARNING Failed to import lightnet. '
            'Lightnet YOLO detection is unavailable'
        )
        if ut.SUPER_STRICT:
            raise


VERBOSE_LN = ut.get_argflag('--verbln') or ut.VERBOSE


CONFIG_URL_DICT = {
    'hammerhead': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.shark_hammerhead.py',
    'lynx': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.lynx.py',
    'manta': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.manta_ray_giant.py',
    'seaturtle': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.sea_turtle.py',
    'rightwhale': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.rightwhale.v1.py',
    'rightwhale_v1': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.rightwhale.v1.py',
    'rightwhale_v2': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.rightwhale.v2.py',
    'rightwhale_v3': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.rightwhale.v3.py',
    'rightwhale_v4': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.rightwhale.v4.py',
    'rightwhale_v5': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.rightwhale.v5.py',
    'jaguar_v1': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.jaguar.v1.py',
    'jaguar_v2': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.jaguar.v2.py',
    'jaguar': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.jaguar.v2.py',
    'giraffe_v1': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.giraffe.v1.py',
    'zebra_mountain_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.zebra_mountain.v0.py',
    'hendrik_elephant': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.hendrik.elephant.py',
    'hendrik_elephant_ears': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.hendrik.elephant.ears.py',
    'hendrik_elephant_ears_left': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.hendrik.elephant.ears.left.py',
    'hendrik_dorsal': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.hendrik.dorsal.py',
    'humpback_dorsal': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.whale_humpback.dorsal.v0.py',
    'orca_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.whale_orca.v0.py',
    'whale_sperm_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.whale_sperm.v0.py',
    'fins_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.fins.v0.py',
    'fins_v1_fluke': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.fins.v1.py',
    'fins_v1_dorsal': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.fins.v1.1.py',
    'fins_v1': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.fins.v1.1.py',
    'nassau_grouper_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.grouper_nassau.v0.py',
    'nassau_grouper_v1': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.grouper_nassau.v1.py',
    'nassau_grouper_v2': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.grouper_nassau.v2.py',
    'nassau_grouper_v3': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.grouper_nassau.v3.py',
    'salanader_fire_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.salanader_fire.v0.py',
    'spotted_dolphin_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.dolphin_spotted.v0.py',
    'spotted_skunk_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.skunk_spotted.v0.py',
    'spotted_skunk_v1': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.skunk_spotted.v1.py',
    'spotted_dolphin_v1': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.dolphin_spotted.v1.py',
    'seadragon_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.seadragon.v0.py',
    'seadragon_v1': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.seadragon.v1.py',
    'iot_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.iot.v0.py',
    'wilddog_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.wild_dog.v0.py',
    'leopard_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.leopard.v0.py',
    'cheetah_v1': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.cheetah.v1.py',
    'cheetah_v2': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.cheetah.v2.py',
    'hyaena_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.hyaena.v0.py',
    'wild_horse_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.wild_horse.v0.py',
    'kitsci_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.kitsci.v0.py',
    'monk_seal_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.mediterranean_monk_seal.v0.py',
    'candidacy': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.candidacy.py',
    'ggr2': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.ggr2.py',
    'snow_leopard_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.snow_leopard.v0.py',
    'megan_argentina_v1': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.megan.argentina.v1.py',
    'megan_kenya_v1': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.megan.kenya.v1.py',
    'megan_argentina_v2': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.megan.argentina.v2.py',
    'megan_kenya_v2': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.megan.kenya.v2.py',
    'grey_whale_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.whale_grey.v0.py',
    'beluga_whale_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.whale_beluga.v0.py',
    'beluga_whale_v1': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.whale_beluga.v1.py',
    'seals_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.seals.v0.py',
    'sea_turtle_v4': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.sea_turtle.v4.py',
    'spotted_eagle_ray_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.spotted_eagle_ray.v0.py',
    'yellow_bellied_toad_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.yellow_bellied_toad.v0.py',
    None: 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.candidacy.py',
    'training_kit': 'https://wildbookiarepository.azureedge.net/data/lightnet-training-kit.zip',
}
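
# NOTE: the `None` key doubles as the default model: detect() resolves a
# missing config_filepath through this dict, which points it at the candidacy
# multi-species detector.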


def _download_training_kit():
    """Download and unpack the Lightnet training kit, returning its local path."""
    training_kit_url = CONFIG_URL_DICT['training_kit']
    training_kit_path = ut.grab_zipped_url(training_kit_url, appname='lightnet')
    return training_kit_path


def _parse_weights_from_cfg(url):
    """Derive the weights URL from a config URL by swapping the extension."""
    return url.replace('.py', '.weights')
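
# Illustrative example: the weights file is assumed to be hosted alongside the
# config at the same URL, differing only in extension.
#
#   >>> _parse_weights_from_cfg(CONFIG_URL_DICT['jaguar'])
#   'https://wildbookiarepository.azureedge.net/models/detect.lightnet.jaguar.v2.weights'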


def _parse_class_list(config_filepath):
    # Load classes from file into the class list
    params = ln.engine.HyperParameters.from_file(config_filepath)
    class_list = params.class_label_map
    return class_list


def detect_gid_list(ibs, gid_list, verbose=VERBOSE_LN, **kwargs):
    """Detect gid_list with lightnet.

    Args:
        ibs (wbia.IBEISController): image analysis api
        gid_list (list of int): the list of IBEIS image_rowids that need detection

    Kwargs (optional):
        refer to the Lightnet documentation for configuration settings
        (detector, config_filepath, weight_filepath, verbose)

    Yields:
        tuple: (gid, gpath, result_list)
    """
    # Get the image paths for the given rowids
    gpath_list = ibs.get_image_paths(gid_list)

    # Run detection
    results_iter = detect(gpath_list, verbose=verbose, **kwargs)

    # Pair each result with its image rowid; the coordinates returned by
    # detect() are already in the original image space
    _iter = zip(gid_list, results_iter)
    for gid, (gpath, result_list) in _iter:
        yield (gid, gpath, result_list)
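

# Usage sketch (assumes an open wbia controller `ibs` with ingested images;
# the 'jaguar' shorthand is one of the CONFIG_URL_DICT keys above):
#
#   >>> gid_list = ibs.get_valid_gids()
#   >>> detections = detect_gid_list(ibs, gid_list, config_filepath='jaguar',
#   ...                              weight_filepath='jaguar')
#   >>> for gid, gpath, result_list in detections:
#   ...     for result in result_list:
#   ...         print(gid, result['class'], result['confidence'])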


def _create_network(
    config_filepath, weight_filepath, conf_thresh, nms_thresh, multi=False
):
    """Create the lightnet network."""
    device = torch.device('cpu')
    if torch.cuda.is_available():
        logger.info('[lightnet] CUDA enabled')
        device = torch.device('cuda')
    else:
        logger.info('[lightnet] CUDA not available')

    params = ln.engine.HyperParameters.from_file(config_filepath)
    params.load(weight_filepath)
    params.device = device

    # Update conf_thresh and nms_thresh in postprocess
    params.network.postprocess[0].conf_thresh = conf_thresh
    params.network.postprocess[1].nms_thresh = nms_thresh

    if multi:
        import torch.nn as nn
        import lightnet.data as lnd

        # Add serialization to Brambox Detections for DataParallel
        postprocess_list = list(params.network.postprocess)
        postprocess_list.append(lnd.transform.SerializeBrambox())
        params.network.postprocess = lnd.transform.Compose(postprocess_list)
        # Make multi-GPU
        params.network = nn.DataParallel(params.network)

    params.network.eval()
    params.network.to(params.device)

    return params


def _detect(params, gpath_list, flip=False):
    """Perform a detection."""
    # Load images
    imgs = []
    img_sizes = []
    for gpath in gpath_list:
        img = cv2.imread(gpath)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if flip:
            img = cv2.flip(img, 1)
        img_h, img_w = img.shape[:2]
        img_size = (img_w, img_h)
        img_sizes.append(img_size)
        img = ln.data.transform.Letterbox.apply(img, dimension=params.input_dimension)
        img = tf.ToTensor()(img)
        imgs.append(img)

    imgs = torch.stack(imgs)
    if len(imgs.shape) != 4:
        imgs.unsqueeze_(0)
    if torch.cuda.is_available():
        imgs = imgs.cuda()

    # Run detector
    if torch.__version__.startswith('0.3'):
        imgs_tf = torch.autograd.Variable(imgs, volatile=True)
        out = params.network(imgs_tf)
    else:
        with torch.no_grad():
            out = params.network(imgs)

    result_list = []
    for result, img_size in zip(out, img_sizes):
        result = ln.data.transform.ReverseLetterbox.apply(
            [result], params.input_dimension, img_size
        )
        result = result[0]
        result_list.append(result)
    return result_list, img_sizes
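

# Coordinate round-trip note: Letterbox scales and pads each image to the
# network's input_dimension before inference, and ReverseLetterbox maps the
# detections back into the original pixel space, which is why _detect() must
# carry the per-image (width, height) alongside each result.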


def detect(
    gpath_list,
    config_filepath=None,
    weight_filepath=None,
    classes_filepath=None,
    sensitivity=0.0,
    verbose=VERBOSE_LN,
    flip=False,
    batch_size=192,
    **kwargs,
):
    """Detect image filepaths with lightnet.

    Args:
        gpath_list (list of str): the list of image paths that need proposal candidates

    Kwargs (optional):
        refer to the Lightnet documentation for configuration settings

    Returns:
        iter
    """
    # Get correct config if specified with shorthand
    config_url = None
    if config_filepath in CONFIG_URL_DICT:
        config_url = CONFIG_URL_DICT[config_filepath]
        config_filepath = ut.grab_file_url(
            config_url, appname='lightnet', check_hash=True
        )

    # Get correct weights if specified with shorthand
    if weight_filepath in CONFIG_URL_DICT:
        if weight_filepath is None and config_url is not None:
            config_url_ = config_url
        else:
            config_url_ = CONFIG_URL_DICT[weight_filepath]
        weight_url = _parse_weights_from_cfg(config_url_)
        weight_filepath = ut.grab_file_url(
            weight_url, appname='lightnet', check_hash=True
        )

    assert exists(config_filepath)
    config_filepath = ut.truepath(config_filepath)
    assert exists(weight_filepath)
    weight_filepath = ut.truepath(weight_filepath)

    conf_thresh = sensitivity
    nms_thresh = 1.0  # Turn off NMS

    params = _create_network(config_filepath, weight_filepath, conf_thresh, nms_thresh)

    # Execute detector for each image
    results_list_ = []
    for gpath_batch_list in tqdm(list(ut.ichunks(gpath_list, batch_size))):
        try:
            result_list, img_sizes = _detect(params, gpath_batch_list, flip=flip)
        except cv2.error:
            result_list, img_sizes = [], []
        for result, img_size in zip(result_list, img_sizes):
            img_w, img_h = img_size
            result_list_ = []
            for output in list(result):
                xtl = int(np.around(float(output.x_top_left)))
                ytl = int(np.around(float(output.y_top_left)))
                xbr = int(np.around(float(output.x_top_left + output.width)))
                ybr = int(np.around(float(output.y_top_left + output.height)))
                width = xbr - xtl
                height = ybr - ytl
                class_ = output.class_label
                conf = float(output.confidence)
                if flip:
                    # Mirror the box back into the unflipped image frame
                    xtl = img_w - xbr
                result_dict = {
                    'xtl': xtl,
                    'ytl': ytl,
                    'width': width,
                    'height': height,
                    'class': class_,
                    'confidence': conf,
                }
                result_list_.append(result_dict)
            results_list_.append(result_list_)

    if len(results_list_) != len(gpath_list):
        raise ValueError('Lightnet did not return valid data')

    results_list = zip(gpath_list, results_list_)
    return results_list
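

# Usage sketch (paths are illustrative; shorthand keys trigger a one-time
# model download via ut.grab_file_url):
#
#   >>> gpath_list = ['/path/to/image1.jpg', '/path/to/image2.jpg']
#   >>> results = detect(gpath_list, config_filepath='jaguar',
#   ...                  weight_filepath='jaguar', sensitivity=0.4)
#   >>> for gpath, result_list in results:
#   ...     print(gpath, len(result_list))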