Source code for emd.support

#!/usr/bin/python

# vim: set expandtab ts=4 sw=4:

"""
Helper functions for interacting with an EMD install and ensuring array sizes.

Main Routines:
  get_install_dir
  get_installed_version
  run_tests
  run_parallel

Ensurance Routines:
  ensure_equal_dims
  ensure_vector
  ensure_1d_with_singleton
  ensure_2d

Errors:
  EMDSiftCovergeError

"""

import logging
import os
import pathlib

import numpy as np

# Housekeeping for logging
logger = logging.getLogger(__name__)


def get_install_dir():
    """Get directory path of currently installed & imported emd.

    Returns
    -------
    str
        Resolved path of the directory containing this file.

    """
    return str(pathlib.Path(__file__).parent.resolve())
def get_installed_version():
    """Read version of currently installed & imported emd.

    The version is read from the '_version.py' file found in the 'emd'
    directory next to this file's parent directory. If a user has made local
    changes this version may not be exactly the same as the online package.

    Returns
    -------
    str
        The text following the first ' = ' in '_version.py', with surrounding
        whitespace and quote characters removed.

    """
    # Two levels up from this file, then back down into the 'emd' directory
    root_dir = pathlib.Path(__file__).parent.parent
    version_text = (root_dir / 'emd' / '_version.py').read_text()

    # Expected content is a single assignment of the form: name = '<version>'
    assigned_value = version_text.split(' = ')[1]

    # Tolerate trailing whitespace/newlines and either style of quote
    return assigned_value.strip().strip('\'"')
def run_tests():
    """Run the EMD package tests directly from python.

    Useful for people without a dev-install to run tests perhaps.
    https://docs.pytest.org/en/latest/usage.html#calling-pytest-from-python-code

    The outcome is reported through the module logger and nothing is
    returned. If the install directory does not contain a 'tests' entry this
    is logged and no tests are run.

    """
    inst_dir = get_install_dir()

    if not os.path.exists(os.path.join(inst_dir, 'tests')):
        logger.info('Test directory not found in: {0}'.format(inst_dir))
        logger.info('(this is normal for PyPI/pip EMD installs)')
        return

    # Imported lazily so pytest is only needed when there are tests to run
    import pytest

    logger.info('Running EMD package tests from: {0}'.format(inst_dir))
    out = pytest.main(['-x', inst_dir])

    # pytest.main may return a plain int or an ExitCode IntEnum - int() copes
    # with both whereas '.value' only exists on the enum.
    if int(out) != 0:
        logger.warning('EMD package tests FAILED - EMD may not behave as expected')
    else:
        logger.info('EMD package tests passed')


# Parallel processing

def run_parallel(pfunc, args, nprocesses=1):
    """Run set of processes in serial or parallel.

    Parameters
    ----------
    pfunc : callable
        Function to call once per entry in args
    args : iterable of sequences
        Each entry is unpacked as the positional arguments of one call
    nprocesses : int
        Number of joblib jobs to use. Values of 1 or less run the calls
        serially in the current process (Default value = 1)

    Returns
    -------
    list
        Results of each call, in the order of args

    """
    if nprocesses > 1:
        # joblib is only required when parallel execution is requested
        from joblib import Parallel, delayed

        with Parallel(n_jobs=nprocesses) as parallel:
            return parallel(delayed(pfunc)(*aa) for aa in args)

    return [pfunc(*aa) for aa in args]


# Ensurance Department

def _single_or_list(out_args):
    """Return the only item of a one-element list, otherwise the list itself."""
    if len(out_args) == 1:
        return out_args[0]
    return out_args


def ensure_equal_dims(to_check, names, func_name, dim=None):
    """Check that a set of arrays all have the same dimension.

    Raises an error with details if not.

    Parameters
    ----------
    to_check : list of arrays
        List of arrays to check for equal dimensions
    names : list
        List of variable names for arrays in to_check
    func_name : str
        Name of function calling ensure_equal_dims
    dim : int
        Integer index of specific axes to ensure shape of, default is to
        compare all dims of the first array in to_check

    Raises
    ------
    ValueError
        If any of the inputs in to_check have differing shapes

    """
    if dim is None:
        dim = np.arange(to_check[0].ndim)
    else:
        dim = [dim]

    all_dims = [tuple(np.array(x.shape)[dim]) for x in to_check]

    # Builtin all() rather than np.alltrue, which was removed in NumPy 2.0
    if all(dims == all_dims[0] for dims in all_dims[1:]):
        return

    msg = 'Checking {0} inputs - Input dim mismatch'.format(func_name)
    logger.error(msg)

    msg = "Mismatch between inputs: "
    msg += "".join("'{0}': {1}, ".format(name, arr.shape)
                   for name, arr in zip(names, to_check))
    logger.error(msg)
    raise ValueError(msg)


def ensure_vector(to_check, names, func_name):
    """Check that a set of arrays are all vectors with only 1-dimension.

    Arrays with singleton second dimensions will be trimmed and an error will
    be raised for non-singleton 2d or greater than 2d inputs.

    Parameters
    ----------
    to_check : list of arrays
        List of arrays to check for equal dimensions
    names : list
        List of variable names for arrays in to_check
    func_name : str
        Name of function calling ensure_equal_dims

    Returns
    -------
    out
        Copy of arrays in to_check with 1d shape. A single array is returned
        directly rather than in a list.

    Raises
    ------
    ValueError
        If any input is a non-singleton 2d array or has more than 2 dimensions

    """
    out_args = list(to_check)
    for idx, xx in enumerate(to_check):
        # More than 2d is checked first: otherwise a 3d+ input with a
        # singleton second axis would be sliced rather than rejected.
        if xx.ndim > 2:
            msg = "Checking {0} inputs - Shape of input '{1}' {2} must be a vector."
            msg = msg.format(func_name, names[idx], xx.shape)
            logger.error(msg)
            raise ValueError(msg)

        if xx.ndim == 2:
            if xx.shape[1] != 1:
                msg = "Checking {0} inputs - Input '{1}' {2} must be a vector or 2d with singleton second dim"
                msg = msg.format(func_name, names[idx], xx.shape)
                logger.error(msg)
                raise ValueError(msg)

            msg = "Checking {0} inputs - trimming singleton from input '{1}'"
            msg = msg.format(func_name, names[idx])
            out_args[idx] = out_args[idx][:, 0]
            logger.warning(msg)

    return _single_or_list(out_args)


def ensure_1d_with_singleton(to_check, names, func_name):
    """Check that a set of arrays are all vectors with singleton second dimensions.

    1d arrays will have a singleton second dimension added and an error will
    be raised for non-singleton 2d or greater than 2d inputs.

    Parameters
    ----------
    to_check : list of arrays
        List of arrays to check for equal dimensions
    names : list
        List of variable names for arrays in to_check
    func_name : str
        Name of function calling ensure_equal_dims

    Returns
    -------
    out
        Copy of arrays in to_check with '1d with singleton' shape. A single
        array is returned directly rather than in a list.

    Raises
    ------
    ValueError
        If any input has a trailing dimension whose length is not 1

    """
    out_args = list(to_check)
    for idx, xx in enumerate(to_check):
        if xx.ndim >= 2:
            if all(size == 1 for size in xx.shape[1:]):
                # nd input where all trailing are ones
                msg = "Checking {0} inputs - Trimming trailing singletons from input '{1}' (input size {2})"
                logger.debug(msg.format(func_name, names[idx], xx.shape))
                # reshape rather than squeeze so a length-1 first axis survives
                out_args[idx] = xx.reshape(xx.shape[0], 1)
            else:
                # nd input where some trailing are not one
                msg = "Checking {0} inputs - trailing dims of input '{1}' {2} must be singletons (length=1)"
                msg = msg.format(func_name, names[idx], xx.shape)
                logger.error(msg)
                raise ValueError(msg)
        elif xx.ndim == 1:
            # Vector input - add a dummy dimension
            msg = "Checking {0} inputs - Adding dummy dimension to input '{1}'"
            logger.debug(msg.format(func_name, names[idx]))
            out_args[idx] = out_args[idx][:, np.newaxis]

    return _single_or_list(out_args)


def ensure_2d(to_check, names, func_name):
    """Check that a set of arrays are all arrays with 2 dimensions.

    1d arrays will have a singleton second dimension added. Inputs with any
    other number of dimensions are passed through unchanged.

    Parameters
    ----------
    to_check : list of arrays
        List of arrays to check for equal dimensions
    names : list
        List of variable names for arrays in to_check
    func_name : str
        Name of function calling ensure_equal_dims

    Returns
    -------
    out
        Copy of arrays in to_check with 2d shape. A single array is returned
        directly rather than in a list.

    """
    out_args = list(to_check)
    for idx, xx in enumerate(to_check):
        if xx.ndim == 1:
            msg = "Checking {0} inputs - Adding dummy dimension to input '{1}'"
            logger.debug(msg.format(func_name, names[idx]))
            out_args[idx] = xx[:, np.newaxis]

    return _single_or_list(out_args)


# Exceptions & Errors

class EMDSiftCovergeError(Exception):
    """Exception raised when the sift has failed to converge.

    Attributes
    ----------
    message : str
        Explanation of the error

    """

    def __init__(self, message):
        """Store and log the message describing the convergence failure."""
        super().__init__(message)
        self.message = message
        # logger.error rather than logger.exception: the error is normally
        # constructed outside of an 'except' block so there is no active
        # traceback for logger.exception to attach.
        logger.error(self.message)