Source code for jwst.cube_build.blot_cube_build

import logging

import numpy as np
from gwcs import wcstools

from jwst.assign_wcs import nirspec
from jwst.assign_wcs.util import in_ifu_slice
from jwst.cube_build import instrument_defaults
from jwst.cube_build.blot_median import blot_wrapper  # c extension
from jwst.datamodels import ModelContainer

# Module-level logger; all progress messages in this module go through it.
log = logging.getLogger(__name__)

# Public API of this module: only the CubeBlot class is exported.
__all__ = ["CubeBlot"]


class CubeBlot:
    """
    Main module for blotting a sky cube back to detector space.

    Information is pulled out of the median sky cube created by a previous run
    of ``cube_build`` in single mode and stored in the class. These variables
    include the WCS of median sky cube, the weighting parameters used to create
    this median sky image, and basic information of the input data (instrument,
    channel, band, grating, or filter).

    Parameters
    ----------
    median_model : `~stdatamodels.jwst.datamodels.IFUCubeModel`
        The median input sky cube is created from a median stack of all the
        individual input models mapped to the full IFU cube imprint on the sky.
    input_models : `~jwst.datamodels.container.ModelContainer`
        The input models used to create the median sky cube.
    """

    def __init__(self, median_model, input_models):
        # Pull out the needed information from the Median IFUCube
        self.median_skycube = median_model
        self.instrument = median_model.meta.instrument.name

        # basic information about the type of data
        self.grating = None
        self.filter = None
        self.subchannel = None
        self.channel = None
        self.par_median_select1 = None
        self.par_median_select2 = None

        if self.instrument == "MIRI":
            self.channel = median_model.meta.instrument.channel
            self.subchannel = median_model.meta.instrument.band.lower()
            self.par_median_select1 = self.channel
            self.par_median_select2 = self.subchannel
        elif self.instrument == "NIRSPEC":
            self.grating = median_model.meta.instrument.grating
            self.filter = median_model.meta.instrument.filter
            self.par_median_select1 = self.grating

        # Set up x, y, z of the median cube (median cube should have linear wavelength)
        xcube, ycube, zcube = wcstools.grid_from_bounding_box(
            self.median_skycube.meta.wcs.bounding_box, step=(1, 1, 1)
        )
        # Use the WCS of the IFU cube to determine ra, dec, lambda.
        # NOTE(review): the +1 presumably converts the 0-based grid to the
        # cube WCS pixel convention - kept exactly as in the original code.
        self.cube_ra, self.cube_dec, self.cube_wave = self.median_skycube.meta.wcs(
            xcube + 1, ycube + 1, zcube + 1
        )
        # Flux from the median sky cube that matches cube_ra, dec, wave
        self.cube_flux = self.median_skycube.data

        # Remove all the NaN sky positions - just in case
        good_data = np.where(~np.isnan(self.cube_ra) & ~np.isnan(self.cube_dec))
        self.cube_ra = self.cube_ra[good_data]
        self.cube_dec = self.cube_dec[good_data]
        self.cube_wave = self.cube_wave[good_data]
        self.cube_flux = self.cube_flux[good_data]

        # Only blot back to the input models covered by the median cube's
        # channel/band (MIRI) or grating (NIRSpec); remember their indices.
        self.input_models = []
        self.input_list_number = []
        for icount, model in enumerate(input_models):
            if self._model_matches(model):
                self.input_models.append(model)
                self.input_list_number.append(icount)

    def _model_matches(self, model):
        """
        Return True if ``model`` falls within the range covered by the median cube.

        For MIRI the median channel must be contained in the model channel and
        the median band in the (lower-cased) model band. For NIRSpec the median
        grating must be contained in the model grating. Any other instrument
        never matches.

        Parameters
        ----------
        model : `~stdatamodels.jwst.datamodels.IFUImageModel`
            An input model to test.

        Returns
        -------
        bool
            Whether the model should be blotted.
        """
        if self.instrument == "MIRI":
            channel = model.meta.instrument.channel
            band = model.meta.instrument.band.lower()
            return self.par_median_select2 in band and self.par_median_select1 in channel
        if self.instrument == "NIRSPEC":
            return self.par_median_select1 in model.meta.instrument.grating
        return False

    # **********************************************************************

    def blot_info(self):
        """Print the basic parameters of the blot image and median sky cube."""
        log.info("Information on Blotting")
        log.info("Working with instrument %s", self.instrument)
        shape = self.median_skycube.data.shape
        # Dimensions are integers: log them as such (x, y, z order).
        log.info("Shape of sky cube %d %d %d", shape[2], shape[1], shape[0])
        if self.instrument == "MIRI":
            log.info("Channel %s", self.channel)
            log.info("Sub-channel %s", self.subchannel)
        elif self.instrument == "NIRSPEC":
            log.info("Grating %s", self.grating)
            log.info("Filter %s", self.filter)

    # ***********************************************************************

    def blot_images(self):
        """
        Call the instrument specific blotting code.

        Returns
        -------
        blotmodels : `~jwst.datamodels.container.ModelContainer`
            Blotted IFU image models
        input_list_number : list of int
            List containing index of blot model in input models

        Raises
        ------
        ValueError
            If the instrument is neither MIRI nor NIRSPEC.
        """
        if self.instrument == "MIRI":
            blotmodels = self.blot_images_miri()
        elif self.instrument == "NIRSPEC":
            blotmodels = self.blot_images_nirspec()
        else:
            raise ValueError(f"Blotting is not supported for instrument {self.instrument!r}")
        return blotmodels, self.input_list_number

    # ************************************************************************

    @staticmethod
    def _normalize_blot(blot_flux, blot_weight, shape):
        """
        Turn accumulated weighted flux into the blotted detector image.

        Pixels with positive weight are divided by their weight (in place);
        pixels with no weight keep their accumulated value. The result is
        reshaped to ``shape``.

        Parameters
        ----------
        blot_flux, blot_weight : ndarray
            Flat accumulated flux and weight arrays.
        shape : tuple of int
            Output image shape (ysize, xsize).

        Returns
        -------
        ndarray
            The blotted image with the requested shape.
        """
        igood = blot_weight > 0
        blot_flux[igood] = blot_flux[igood] / blot_weight[igood]
        return blot_flux.reshape(shape)

    def blot_images_miri(self):
        """
        Core blotting routine for MIRI.

        This is the main routine for blotting the MIRI median sky cube back to
        the detector space and creating a blotting image for each input model:

        1. Find the detector column range of the channel being blotted.
        2. Loop over every input model and using the inverse (backwards)
           transform convert the median sky cube values RA, Dec, lambda to the
           blotted x, y detector value ``(x_cube, y_cube)``.
        3. For each input model loop over the blotted x, y values and find the
           x, y detector values that fall within the ROI. The blotted flux is
           the weighted flux, where the weight is based on distance between the
           center of the blotted pixel and the detector pixel.

        Returns
        -------
        blot_models : `~jwst.datamodels.container.ModelContainer`
            Container of blotted `~stdatamodels.jwst.datamodels.IFUImageModel`
        """
        blot_models = ModelContainer()
        instrument_info = instrument_defaults.InstrumentInfo()
        # The median cube flux is the same for every input model.
        cube_flux = np.ndarray.flatten(self.cube_flux)
        roi_det = 1.0  # Just large enough that we don't get holes

        for model in self.input_models:
            blot = model.copy()
            blot.err = None
            blot.dq = None

            # For MIRI we only work on one channel at a time: get the detector
            # column range covered by that channel.
            xstart, xend = instrument_info.GetMIRISliceEndPts(self.channel)
            ysize, xsize = model.data.shape
            # Full-detector pixel grids. Not used inside this class; kept as
            # instance attributes in case other code reads them.
            self.ycenter_grid, self.xcenter_grid = np.mgrid[0:ysize, 0:xsize]

            # Detector pixel centers of the channel columns and all rows
            xcenter = np.arange(xend - xstart + 1) + xstart
            ycenter = np.arange(ysize)
            xsize2 = xcenter.shape[0]

            # cube spaxel ra, dec, wave --> x, y on detector
            x_cube, y_cube = model.meta.wcs.backward_transform(
                self.cube_ra, self.cube_dec, self.cube_wave
            )
            x_cube = np.ndarray.flatten(x_cube)
            y_cube = np.ndarray.flatten(y_cube)

            # Keep spaxels that land on this channel's columns and carry flux
            fuse = (
                ~np.isnan(y_cube)
                & (x_cube >= xstart)
                & (x_cube <= xend)
                & (cube_flux != 0)
            )

            log.info("Blotting back to %s", model.meta.filename)
            # blot_wrapper (C extension) flags median spaxels that fall within
            # the ROI of each detector pixel center and accumulates their
            # weighted flux and weights.
            blot_flux, blot_weight = blot_wrapper(
                roi_det,
                xsize,
                ysize,
                xstart,
                xsize2,
                xcenter,
                ycenter,
                x_cube[fuse],
                y_cube[fuse],
                cube_flux[fuse],
            )
            blot.data = self._normalize_blot(blot_flux, blot_weight, (ysize, xsize))
            blot_models.append(blot)
        return blot_models

    # ************************************************************************

    def _blot_slice_values(self, slice_wcs):
        """
        Map the median-cube spaxels that fall on one NIRSpec slice to the detector.

        Parameters
        ----------
        slice_wcs : `~gwcs.wcs.WCS`
            WCS of a single IFU slice for the model being blotted.

        Returns
        -------
        xslice, yslice, flux : ndarray
            Detector x, y positions (inside the slice bounding box) and the
            corresponding median-cube flux values.
        """
        slicer2world = slice_wcs.get_transform("slicer", "world")
        detector2slicer = slice_wcs.get_transform("detector", "slicer")

        # Find rough limits on ra, dec, lambda using the x, y -> ra, dec, lambda
        x, y = wcstools.grid_from_bounding_box(slice_wcs.bounding_box)
        ra, dec, lam = slice_wcs(x, y)

        # Pad the slice on the sky: the slice is very small and the median
        # cube is a coarse grid, so without padding few or no cube values
        # would fall inside the min/max limits.
        wcsinfo = self.median_skycube.meta.wcsinfo
        ramin = np.nanmin(ra) - wcsinfo.cdelt1 * 4
        ramax = np.nanmax(ra) + wcsinfo.cdelt1 * 4
        decmin = np.nanmin(dec) - wcsinfo.cdelt2 * 4
        decmax = np.nanmax(dec) + wcsinfo.cdelt2 * 4
        lam_min = np.nanmin(lam)
        lam_max = np.nanmax(lam)
        if ramin < 0:
            ramin = 0
        if ramax > 360:
            ramax = 360

        use = (
            (self.cube_ra >= ramin)
            & (self.cube_ra <= ramax)
            & (self.cube_dec >= decmin)
            & (self.cube_dec <= decmax)
            & (self.cube_wave >= lam_min)
            & (self.cube_wave <= lam_max)
        )
        ra_use = self.cube_ra[use]
        dec_use = self.cube_dec[use]
        wave_use = self.cube_wave[use]
        flux_use = self.cube_flux[use]

        # Indices of the candidate spaxels that are actually on the slice
        onslice_ind = in_ifu_slice(slice_wcs, ra_use, dec_use, wave_use)
        slx, sly, sllam = slicer2world.inverse(ra_use, dec_use, wave_use)
        xslice, yslice = detector2slicer.inverse(
            slx[onslice_ind], sly[onslice_ind], sllam[onslice_ind]
        )
        fluxslice = flux_use[onslice_ind]

        # One more limit: only keep values inside the slice bounding box
        xlimit, ylimit = slice_wcs.bounding_box
        inside = (
            (xslice >= xlimit[0])
            & (xslice <= xlimit[1])
            & (yslice >= ylimit[0])
            & (yslice <= ylimit[1])
        )
        return xslice[inside], yslice[inside], fluxslice[inside]

    def blot_images_nirspec(self):
        """
        Core blotting routine for NIRSPEC.

        This is the main routine for blotting the NIRSPEC median sky cube back
        to the detector space and creating a blotting image for each input
        model. This routine was split from the MIRI routine because the blotting
        for NIRSpec needs to be done slice by slice and an error in the inverse
        mapping (sky to detector) mapped too many values back to the detector.
        This routine adds a check and first pulls out the min and max RA and Dec
        values in the slice and only inverts the slice values back to the
        detector.

        For each data model loop over the 30 slices and find:

        a. the x, y bounding box of slice
        b. the RA, Dec, lambda values for the x, y pixels in the slice
        c. from step b, determine the min and max RA and Dec for slice values
        d. pull out the valid RA, Dec and lambda values from the median sky cube
           that fall within the min and max RA and Dec determined in step c
        e. invert the valid RA, Dec, and lambda values for the slice determined
           in step d to the detector
        f. blot the inverted x, y values to the detector plane. This step
           determines the overlap of the blotted x, y values with a regular grid
           setup in the detector plane which is the blotted image.

        Returns
        -------
        blot_models : `~jwst.datamodels.container.ModelContainer`
            Container of blotted `~stdatamodels.jwst.datamodels.IFUImageModel`.
        """
        blot_models = ModelContainer()
        nslices = 30
        roi_det = 1.0  # Just large enough that we don't get holes

        for model in self.input_models:
            blot_ysize, blot_xsize = model.shape
            blot = model.copy()
            blot.err = None
            blot.dq = None

            ycenter = np.arange(blot_ysize)
            xcenter = np.arange(blot_xsize)

            # For NIRSPEC the WCS information is accessed separately per slice
            log.info("Blotting 30 slices on NIRSPEC detector")
            x_parts, y_parts, flux_parts = [], [], []
            for ii in range(nslices):
                slice_wcs = nirspec.nrs_wcs_set_input(model, ii)
                x_slice, y_slice, flux_slice = self._blot_slice_values(slice_wcs)
                x_parts.append(x_slice)
                y_parts.append(y_slice)
                flux_parts.append(flux_slice)
            # Concatenate once instead of growing the arrays slice by slice
            x_total = np.concatenate(x_parts)
            y_total = np.concatenate(y_parts)
            flux_total = np.concatenate(flux_parts)

            # The whole detector width is blotted for NIRSpec
            xstart = 0
            xsize2 = blot_xsize
            blot_flux, blot_weight = blot_wrapper(
                roi_det,
                blot_xsize,
                blot_ysize,
                xstart,
                xsize2,
                xcenter,
                ycenter,
                x_total,
                y_total,
                flux_total,
            )
            blot.data = self._normalize_blot(
                blot_flux, blot_weight, (blot_ysize, blot_xsize)
            )
            blot_models.append(blot)
        return blot_models