Source code for mangadap.proc.util

# Licensed under a 3-clause BSD style license - see LICENSE.rst
# -*- coding: utf-8 -*-
"""
Provides a set of processing utility functions for the MaNGA DAP.

----

.. include license and copyright
.. include:: ../include/copy.rst

----

.. include common links, assuming primary doc root is up one directory
.. include:: ../include/links.rst
"""
import os
import glob
import warnings

import numpy

from scipy import interpolate, spatial
import astropy.constants

def HDUList_mask_wavelengths(hdu, bitmask, bitmask_flag, wave_limits, wave_ext='WAVE',
                             mask_ext='MASK', invert=False):
    """
    Mask pixels in a specified wavelength range by turning on the bit
    value in the specified extension of a provided HDUList object.

    The mask extension of ``hdu`` is modified in place.

    Args:
        hdu (`astropy.io.fits.HDUList`_):
            HDUList to alter.  It is indexed by ``wave_ext`` and
            ``mask_ext`` and the ``data`` attribute of each is used.
        bitmask (class:`BitMask`):
            Bit mask object used to turn on the named bit mask.  Its
            ``turn_on`` method is called with the selected mask values
            and ``bitmask_flag``.
        bitmask_flag (str):
            Name of the bit to turn on.
        wave_limits (list or numpy.ndarray):
            Two-element array with the low and high wavelength limits.
            The limits themselves are included in the interval.
        wave_ext (str):
            (Optional) Name of the wavelength extension in *hdu*.
        mask_ext (str):
            (Optional) Name of the mask extension in *hdu*.
        invert (bool):
            (Optional) Invert the sense of the masking.  Instead of
            masking pixels in the wavelength interval, mask all pixels
            *outside* it.

    Returns:
        `astropy.io.fits.HDUList`_: The modified HDUList object (the
        same object that was provided).

    Raises:
        ValueError:
            Raised if *wave_limits* does not have a length of two.
    """
    if len(wave_limits) != 2:
        raise ValueError('Wavelength limits must be a two-element vector.')

    wave = hdu[wave_ext].data
    low, high = wave_limits[0], wave_limits[1]
    # The two selections are written out explicitly (instead of negating
    # one of them) so that non-finite wavelengths are never selected in
    # either mode.
    if invert:
        selected = (wave < low) | (wave > high)
    else:
        selected = (wave >= low) & (wave <= high)

    hdu[mask_ext].data[selected] = bitmask.turn_on(hdu[mask_ext].data[selected], bitmask_flag)
    return hdu
def select_proc_method(method_key, method_type, method_list=None, available_func=None):
    r"""
    Select a method from a list.

    One of ``method_list`` or ``available_func`` must be provided.

    Args:
        method_key (:obj:`str`):
            Keyword used to select the method.
        method_type (object):
            Object type to check ``method_list`` against.
        method_list (:obj:`list`, optional):
            List of methods from which to find the selection keyword.
            A single object that is not a list is treated as a
            one-element list.  If None, ``available_func`` **must** be
            provided.
        available_func (callable, optional):
            Callable function that returns a list of default methods in
            place of ``method_list``.  For example, see
            :func:`mangadap.proc.templatelibrary.available_template_libraries`.

    Returns:
        object: The element of the method list whose ``'key'`` entry
        matches ``method_key``.

    Raises:
        KeyError:
            Raised if the selected keyword is not among the provided
            list or if more than one element of the list has the
            selected keyword.
        TypeError:
            Raised if an element of the method list is not an instance
            of ``method_type``, or if no list is provided and
            ``available_func`` is not callable.
    """
    # Get the default methods if no list provided
    if method_list is None:
        if not callable(available_func):
            raise TypeError('If not providing a list, must provide a callable function to '
                            'produce the default list of methods/databases/libraries.')
        method_list = available_func()

    # Make sure the methods have the right type
    if not isinstance(method_list, list):
        method_list = [method_list]
    for method in method_list:
        if not isinstance(method, method_type):
            raise TypeError('Input method/database/library must have type {0}'.format(
                            method_type))

    # Find the selected method via its keyword
    matches = [i for i, method in enumerate(method_list) if method['key'] == method_key]
    if len(matches) == 0:
        raise KeyError('{0} is not a valid method/database/library!'.format(method_key))
    if len(matches) > 1:
        # Report the offending keyword and all available keywords in the
        # exception itself instead of printing them to stdout.
        available = ', '.join(str(method['key']) for method in method_list)
        raise KeyError('Keywords are not all unique! "{0}" matches {1} entries among: {2}'.format(
                       method_key, len(matches), available))

    # Return the method selected via the input keyword
    return method_list[matches[0]]
def get_database_key(f):
    """
    Build a database keyword from a file name or file path.

    The directory part of the path is discarded, everything from the
    first ``'.'`` in the remaining file name onward is dropped, and the
    result is converted to upper case.

    Args:
        f (:obj:`str`):
            The file name or path.

    Returns:
        :obj:`str`: The keyword.

    Examples:
        >>> get_database_key('junk')
        'JUNK'
        >>> get_database_key('test.par')
        'TEST'
        >>> get_database_key('/path/to/test.par')
        'TEST'
    """
    file_name = os.path.basename(f)
    root, _, _ = file_name.partition('.')
    return root.upper()
def select_database(key, directory_path):
    r"""
    Find the database file in a directory that matches a keyword.

    Every file in ``directory_path`` with a ``.par`` extension is
    considered.  The keyword of each file is constructed using
    :func:`get_database_key`, and the first file whose keyword equals
    ``key`` is returned.

    Args:
        key (:obj:`str`):
            Keyword used to select the database.
        directory_path (:obj:`str`):
            Full path with the valid database files.

    Returns:
        :obj:`str`: The file with the selected database.

    Raises:
        NotADirectoryError:
            Raised if the provided directory path does not exist.
        KeyError:
            Raised if the selected keyword cannot be associated with a
            file in the provided directory.
    """
    if not os.path.isdir(directory_path):
        raise NotADirectoryError('{0} not found!'.format(directory_path))

    search_pattern = os.path.join(directory_path, '*.par')
    files = glob.glob(search_pattern)
    keys = [get_database_key(par_file) for par_file in files]
    if key not in keys:
        raise KeyError('No database found to associate with {0}.'.format(key))

    # Pick the first file with a matching keyword
    return files[keys.index(key)]
#def _fill_vector(v, length, missing, value): # if v.size == length: # v[missing] = value # return v # _v = numpy.full(length, value, dtype=v.dtype) # _v[ list(set( numpy.arange(length) ) - set(missing)) ] = v # return _v
def flux_to_fnu(wave, flambda, unit_norm=1e-17):
    r"""
    Convert a spectrum with flux per unit wavelength to flux per unit
    frequency; i.e., calculate

    .. math::

        F_{\nu} = F_{\lambda} \frac{d\lambda}{d\nu} = F_{\lambda}
        \frac{\lambda^2}{c},

    where the first two arguments of the function are :math:`\lambda`
    and :math:`F_{\lambda}`.  The input wavelength units are expected to
    be angstroms, and the input flux units are expected to be :math:`n\
    {\rm erg\ s}^{-1}\ {\rm cm}^{-2}\ {\rm A}^{-1}`, where :math:`n` is
    the value of *unit_norm*.  The output flux units are microjanskys,
    :math:`10^{-29} {\rm erg\ s}^{-1}\ {\rm cm}^{-2}\ {\rm Hz}^{-1}`.

    Args:
        wave (scalar, array-like):
            The wavelengths in angstroms.  Scalars of any numeric type
            and any array-like object (list, tuple, numpy.ndarray) are
            accepted.
        flambda (scalar, array-like):
            The flux per unit wavelength (angstroms).  Must have the
            same shape as ``wave``; a scalar is treated as a
            one-element vector.
        unit_norm (float):
            (**Optional**) The unit normalization of the flux.  For
            example, this is :math:`10^{-17}` when the flux units are
            :math:`10^{-17} {\rm erg\ s}^{-1}\ {\rm cm}^{-2}\ {\rm
            A}^{-1}`.

    Returns:
        float, numpy.ndarray: The flux in units of microjanskys.  A
        single value is returned if ``flambda`` is a scalar; otherwise
        an array is returned.

    Raises:
        ValueError:
            Raised if the arguments do not have the same shape.
    """
    # Scalars of any numeric type (not only python floats) are promoted
    # to one-element vectors.
    scalar_flux = numpy.isscalar(flambda)
    # asanyarray converts any sequence type while leaving array
    # subclasses (e.g., masked arrays) untouched.
    _wave = numpy.asanyarray([wave] if numpy.isscalar(wave) else wave)
    _flambda = numpy.asanyarray([flambda] if scalar_flux else flambda)
    if _wave.shape != _flambda.shape:
        raise ValueError('Wavelength and flux arrays must have the same shape.')

    # Speed of light in angstroms per second, to match the wavelength units
    c_ang = astropy.constants.c.to('angstrom/s').value
    # The factor of 1e29 converts erg/s/cm^2/Hz to microjanskys
    fnu = _flambda*numpy.square(_wave)*unit_norm*1e29/c_ang
    return fnu[0] if scalar_flux else fnu
# TODO: Requires a spectrum in each bin! # if any(numpy.diff(bin_indx[srt]) > 1): # rep = numpy.ones(bin_change.size, dtype=int) # i = 1 # while any(numpy.diff(bin_indx[srt]) > i): # rep[ numpy.where(numpy.diff(bin_indx[srt]) > i)[0]+1 == bin_change ] += 1 # i += 1 # bin_change = numpy.repeat(bin_change, rep) # OLD VERSION #def residual_growth(resid, growth_samples): # """ # Interpolate the growth curve at distinct fractions, bracketed by the # minimum and maximum. # """ # np = resid.size # grw = numpy.arange(np).astype(float)/np # resid_sort = numpy.sort(numpy.absolute(resid)) # interp = interpolate.interp1d(grw, resid_sort, fill_value='extrapolate') # return numpy.append(numpy.append(resid_sort[0], interp(growth_samples)), resid_sort[-1]) #def residual_growth(resid, growth_samples): # """ # Sample a set of residuals at specific the growth intervals. No # interpolation is performed. # """ # np = resid.size # resid_sort = numpy.sort(numpy.absolute(resid)) # i = numpy.zeros(len(growth_samples)+2, dtype=float) # i[1:-1] = np*numpy.asarray(growth_samples) # i[-1] = np-1 # i[i < 0] = 0 # i[i >= np] = np-1 # return resid_sort[i.astype(int)]
def sample_growth(a, samples, default=-9999., use_interpolate=True):
    """
    Sample the growth curve of a set of values at the requested growth
    fractions.

    The values are sorted and the i-th sorted value (counting from 1)
    is assigned a growth fraction of i/N, where N is the number of
    values.

    Args:
        a (array-like):
            Values for which to sample the growth curve.  For a masked
            array, only the unmasked values are used; anything else is
            flattened.
        samples (scalar, array-like):
            Growth fractions to sample; all must be between 0 and 1,
            inclusive.
        default (float):
            (**Optional**) Value returned when fewer than two values
            are available.
        use_interpolate (bool):
            (**Optional**) Linearly interpolate (and extrapolate) the
            growth curve at the requested fractions.  If False, the
            sorted value at index ``int(N*sample)`` is selected, with
            the index limited to the last element.

    Returns:
        float, list, tuple, numpy.ndarray: If fewer than two values are
        available, ``default`` is returned, repeated in a list when
        more than one sample is requested.  When interpolating, a tuple
        is returned for more than one sample, otherwise the output of
        the interpolator.  Without interpolation, the selected sorted
        values are returned.

    Raises:
        ValueError:
            Raised if any sample is less than 0 or greater than 1.
    """
    fractions = numpy.asarray(samples)
    if numpy.any((fractions < 0) | (fractions > 1)):
        raise ValueError('Growth samples must be between 0 and 1.')

    # Only use the unmasked values
    if isinstance(a, numpy.ma.MaskedArray):
        values = a.compressed()
    else:
        values = numpy.atleast_1d(a).ravel()

    nsamples = fractions.size
    if len(values) < 2:
        # Not enough data to construct a growth curve
        if nsamples > 1:
            return [default]*nsamples
        return default

    nvalues = values.size
    ordered = values[numpy.argsort(values)]
    growth = (numpy.arange(nvalues, dtype=float) + 1)/nvalues

    if not use_interpolate:
        # Select the nearest sorted value, never running off the end
        indx = (nvalues*fractions).astype(int)
        indx[indx > nvalues-1] = nvalues-1
        return ordered[indx]

    growth_curve = interpolate.interp1d(growth, ordered, fill_value='extrapolate')
    sampled = growth_curve(fractions)
    if nsamples > 1:
        return tuple(sampled)
    return sampled
def growth_lim(a, lim, fac=1.0, midpoint=None, default=None):
    """
    Set the plot limits of an array based on two growth limits.

    Args:
        a (array-like):
            Array for which to determine limits.  For a masked array,
            only the unmasked values are used.
        lim (float):
            Fraction of the array values to cover, centered on the
            middle of the growth curve.  Values larger than 1 are
            treated as 1.
        fac (float):
            (**Optional**) Factor to increase the range based on the
            growth limits.  Default is no increase.
        midpoint (float):
            (**Optional**) Force the midpoint of the range to be
            centered on this value.  Default is the middle of the
            growth range.
        default (list):
            (**Optional**) Default range to return if `a` has no data.
            If None, the default range is 0 to 1.

    Returns:
        list: Lower and upper limits for the range of a plot of the data
        in `a`.  A new list is always returned, so the result can be
        modified without affecting later calls.
    """
    # Get the values to plot
    if isinstance(a, numpy.ma.MaskedArray):
        values = a.compressed()
    else:
        values = numpy.asarray(a).ravel()

    nvalues = len(values)
    if nvalues == 0:
        # No data so return (a copy of) the default range
        return [0., 1.] if default is None else list(default)

    # Sort the values
    ordered = numpy.sort(values)

    # Set the starting and ending values based on a fraction of the
    # growth
    _lim = 1.0 if lim > 1.0 else lim
    start = int(nvalues*(1.0-_lim)/2)
    end = int(nvalues*(_lim + (1.0-_lim)/2))
    if end == nvalues:
        end -= 1

    # Set the full range and increase it by the provided factor
    full_range = (ordered[end] - ordered[start])*fac
    # Set the midpoint if not provided
    mid = (ordered[start] + ordered[end])/2 if midpoint is None else midpoint
    # Return the range for the plotted data
    return [mid - full_range/2, mid + full_range/2]
def optimal_scale(dat, mod, wgt=None):
    r"""
    Calculate the optimal scaling of an input model that minimizes the
    weighted root-mean-square difference between a set of data and a
    model.

    When defining the weighted RMS as:

    .. math::

        {\rm RMS}^2 = \frac{1}{N} \sum_i w_i^2(d_i - f m_i)^2

    The optimal renormalization factor that minimizes the RMS is:

    .. math::

        f = \frac{\mathbf{d}^\prime \cdot
        \mathbf{m}^\prime}{||\mathbf{m}^\prime||^2}

    where :math:`d^\prime_i = w_i d_i` and :math:`m^\prime_i = w_i m_i`.

    Args:
        dat (array-like):
            Array of data.
        mod (array-like):
            Model to renormalize.
        wgt (array-like):
            (**Optional**) Array of weights to apply to each residual.
            If None, all weights are unity.

    Returns:
        float : The optimal scaling that minimizes the weighted
        root-mean-square difference between the data and the model.  If
        the norm of the weighted model is 0, a warning is issued and 1.0
        is returned.

    Raises:
        ValueError:
            Raised if the array shapes do not match.
    """
    _dat = numpy.atleast_1d(dat)
    _mod = numpy.atleast_1d(mod)
    _wgt = numpy.ones(_dat.shape, dtype=float) if wgt is None else numpy.atleast_1d(wgt)
    if _mod.shape != _dat.shape or _wgt.shape != _dat.shape:
        raise ValueError('Shapes of all input arrays must match.')

    # Squared norm of the weighted model
    norm_mp = numpy.sum(numpy.square(_wgt*_mod))
    if norm_mp == 0:
        # The scale is undefined, so leave the model as it is
        warnings.warn('Scale not determined because model norm is 0.')
        return 1.0

    # Dot product of the weighted data and weighted model, normalized
    return numpy.sum(numpy.square(_wgt)*_dat*_mod)/norm_mp
def replace_with_data_from_nearest_coo(coo, data, replace):
    """
    Replace data in array with the spatially nearest neighbor.

    Args:
        coo (numpy.ndarray):
            A 2D array with the x and y coordinates of all the data.
            Shape must be (NDATA, 2).
        data (numpy.ndarray):
            An array with data to replace.  The length of the first (or
            only) axis must be NDATA.
        replace (numpy.ndarray):
            Boolean array that is True for elements that should be
            replaced on output.  Shape must be (NDATA,).

    Returns:
        numpy.ndarray: A copy of the data array with the selected rows
        replaced by the data at the nearest coordinate that was not
        selected for replacement.  The input is never modified.

    Raises:
        ValueError:
            Raised if the array sizes are inappropriate, or if every
            element is flagged for replacement such that there are no
            data to use as a replacement.
    """
    # Check the input
    _coo = numpy.asarray(coo)
    if _coo.ndim != 2:
        raise ValueError('Input coordinate array must be two-dimensional.')
    if _coo.shape[1] != 2:
        raise ValueError('Currently only works with two coordinates.')
    ndata = _coo.shape[0]

    _data = numpy.asarray(data)
    if _data.ndim == 0 or _data.shape[0] != ndata:
        raise ValueError('Coordinate and data arrays have a mismatched shape.')

    _replace = numpy.asarray(replace, dtype=bool)
    if _replace.ndim != 1 or _replace.shape[0] != ndata:
        raise ValueError('Input replacement selection array has an incorrect shape.')

    # Nothing flagged to replace, so just return a copy of the input.
    # The converted array is copied so that the output is always a
    # numpy.ndarray, even if the input is a list.
    if not numpy.any(_replace):
        return _data.copy()

    keep = numpy.logical_not(_replace)
    if not numpy.any(keep):
        raise ValueError('All data are flagged for replacement; no data available to use.')

    # Use the coordinates of the data that are *not* replaced to set
    # the KDTree reference grid
    kd = spatial.KDTree(_coo[keep])
    # Get the indices of the nearest data points; the distances are
    # not needed
    _, nearest = kd.query(_coo[_replace])

    # Replace the existing data with the nearest one and return it.
    # Indexing only the first axis works for data of any dimensionality.
    new_data = _data.copy()
    new_data[_replace] = _data[keep][nearest]
    return new_data
def inverse(d):
    """
    Compute the reciprocal of the input, with 0 in place of any
    division by zero.

    Args:
        d (scalar-like, array-like):
            Data values.  Python floats and integers are treated as
            scalars; anything else is converted to a float array with
            at least one dimension.

    Returns:
        float, `numpy.ndarray`_: The value of 1/d, except that elements
        with d == 0 are set to 0 instead of NaN or Inf.  A single value
        is returned for python scalar input, a `numpy.ndarray`_
        otherwise.
    """
    if isinstance(d, (float, int)):
        values = float(d)
    else:
        values = numpy.atleast_1d(d).astype(float)

    nonzero = values != 0.0
    # Adding the complement of the selection turns each zero denominator
    # into one; the numerator is zero (False) for those elements and
    # one (True) for all others.
    denominator = values + numpy.logical_not(nonzero)
    return nonzero/denominator