"""Top-level module for WFSS contamination correction."""
import logging
import multiprocessing
import numpy as np
from stcal.multiprocessing import compute_num_cores
from stdatamodels.jwst import datamodels
from stdatamodels.jwst.transforms.models import (
NIRCAMBackwardGrismDispersion,
NIRISSBackwardGrismDispersion,
)
from jwst.lib.catalog_utils import read_source_catalog
from jwst.wfss_contam.observations import Observation
from jwst.wfss_contam.sens1d import get_photom_data
log = logging.getLogger(__name__)
__all__ = ["contam_corr"]
class UnmatchedSlitIDError(Exception):
"""Exception raised when a slit ID is not found in the list of simulated slits."""
pass
def _find_matching_simul_slit(slit, simul_slit_sids, simul_slit_orders):
"""
Find the index of the matching simulated slit in the list of simulated slits.
Parameters
----------
slit : `~stdatamodels.jwst.datamodels.SlitModel`
Source slit model
simul_slit_sids : list
List of source IDs for simulated slits
simul_slit_orders : list
List of spectral orders for simulated slits
Returns
-------
good_idx : int
Index of the matching simulated slit in the list of simulated slits
"""
sid = slit.source_id
order = slit.meta.wcsinfo.spectral_order
good = (np.array(simul_slit_sids) == sid) * (np.array(simul_slit_orders) == order)
if not any(good):
raise UnmatchedSlitIDError(
f"Source ID {sid} order {order} requested by input slit model "
"but not found in simulated slits. "
"Setting contamination correction to zero for that slit."
)
return np.where(good)[0][0]
def _cut_frame_to_match_slit(contam, slit):
"""
Cut out the contamination image to match the extent of the source slit.
Parameters
----------
contam : 2D array
Contamination image for the full grism exposure
slit : `~stdatamodels.jwst.datamodels.SlitModel`
Source slit model
Returns
-------
cutout : 2D array
Contamination image cutout that matches the extent of the source slit
"""
x1 = slit.xstart
y1 = slit.ystart
xf = x1 + slit.xsize
yf = y1 + slit.ysize
# zero-pad the contamination image if the slit extends beyond the contamination image
# fixes an off-by-one bug when sources extend to the edge of the contamination image
if xf > contam.shape[1]:
contam = np.pad(contam, ((0, 0), (0, xf - contam.shape[1])), mode="constant")
if yf > contam.shape[0]:
contam = np.pad(contam, ((0, yf - contam.shape[0]), (0, 0)), mode="constant")
return contam[y1 : y1 + slit.ysize, x1 : x1 + slit.xsize]
class SlitOverlapError(Exception):
"""Exception raised when there is no overlap between data and model for a slit."""
pass
def match_backplane_prefer_first(slit0, slit1):
"""
Reshape slit1 to the backplane of slit0.
Parameters
----------
slit0 : `~stdatamodels.jwst.datamodels.SlitModel`
Slit model for the first slit, which is used as reference.
slit1 : `~stdatamodels.jwst.datamodels.SlitModel`
Slit model for the second slit, which is reshaped to match slit0.
Returns
-------
slit0, slit1 : `~stdatamodels.jwst.datamodels.SlitModel`
Reshaped slit models slit0, slit1.
"""
data0 = slit0.data
data1 = slit1.data
x1 = slit1.xstart - slit0.xstart
y1 = slit1.ystart - slit0.ystart
backplane1 = np.zeros_like(data0)
i0 = max([y1, 0])
i1 = min([y1 + data1.shape[0], data0.shape[0], data1.shape[0]])
j0 = max([x1, 0])
j1 = min([x1 + data1.shape[1], data0.shape[1], data1.shape[1]])
if i0 >= i1 or j0 >= j1:
raise SlitOverlapError(
f"No overlap region between data and model for slit {slit0.source_id}, "
f"order {slit0.meta.wcsinfo.spectral_order}. "
"setting contamination correction to zero for that slit."
)
backplane1[i0:i1, j0:j1] = data1[i0:i1, j0:j1]
slit1.data = backplane1
slit1.xstart = slit0.xstart
slit1.ystart = slit0.ystart
slit1.xsize = slit0.xsize
slit1.ysize = slit0.ysize
return slit0, slit1
def match_backplane_encompass_both(slit0, slit1):
"""
Put data from the two slits into a common backplane, encompassing both.
Slits are zero-padded where their new extent does not overlap with the original data.
Parameters
----------
slit0, slit1 : `~stdatamodels.jwst.datamodels.SlitModel`
Slit model for the first and second slit.
Returns
-------
slit0, slit1 : `~stdatamodels.jwst.datamodels.SlitModel`
Reshaped slit models slit0, slit1.
"""
data0 = slit0.data
data1 = slit1.data
shape = (max(data0.shape[0], data1.shape[0]), max(data0.shape[1], data1.shape[1]))
xmin = min(slit0.xstart, slit1.xstart)
ymin = min(slit0.ystart, slit1.ystart)
shape = (
max(
slit0.xsize + slit0.xstart - xmin,
slit1.xsize + slit1.xstart - xmin,
),
max(
slit0.ysize + slit0.ystart - ymin,
slit1.ysize + slit1.ystart - ymin,
),
)
x0 = slit0.xstart - xmin
y0 = slit0.ystart - ymin
x1 = slit1.xstart - xmin
y1 = slit1.ystart - ymin
backplane0 = np.zeros(shape).T
backplane0[y0 : y0 + data0.shape[0], x0 : x0 + data0.shape[1]] = data0
backplane1 = np.zeros(shape).T
backplane1[y1 : y1 + data1.shape[0], x1 : x1 + data1.shape[1]] = data1
slit0.data = backplane0
slit1.data = backplane1
for slit in [slit0, slit1]:
slit.xstart = xmin
slit.ystart = ymin
slit.xsize = shape[0]
slit.ysize = shape[1]
return slit0, slit1
def _validate_orders_against_reference(orders, spec_orders):
"""
Compare user-requested spectral orders with the orders defined in the reference file.
Parameters
----------
orders : list[int]
List of user-requested spectral orders.
spec_orders : list[int]
List of spectral orders defined in the reference file.
Returns
-------
np.ndarray[int]
List of spectral orders constrained to the user-specified ones
that are also defined in the reference file.
"""
spec_orders = np.array(spec_orders, dtype=int)
if orders is None:
return spec_orders
orders = np.array(orders, dtype=int)
good_orders = np.isin(orders, spec_orders, assume_unique=True)
if (len(good_orders) == 0) or (not np.any(good_orders)):
log.error(
f"None of the requested spectral orders {orders} are defined "
"in the wavelength range reference file. "
f"Expected orders are: {spec_orders}. "
)
return []
if not np.all(good_orders):
log.warning(
f"Not all requested spectral orders {orders} are defined in the "
f"wavelength range reference file. Defined orders are: {spec_orders}. "
"Skipping undefined orders."
)
return orders[good_orders]
def _validate_orders_against_transform(wcs, spec_orders):
"""
Ensure the requested spectral orders are defined in the WCS transforms.
Parameters
----------
wcs : gwcs.wcs.WCS
The input MultiSlitModel's WCS object.
spec_orders : list[int]
The list of requested spectral orders.
Returns
-------
list
List of spectral orders that are defined in the WCS transform.
"""
sky_to_grism = wcs.backward_transform
good_orders = spec_orders.copy()
for model in sky_to_grism:
if isinstance(model, (NIRCAMBackwardGrismDispersion, NIRISSBackwardGrismDispersion)):
# Get the orders defined in the transform
orders = np.sort(model.orders)
is_good_order = [order in orders for order in spec_orders]
if not any(is_good_order):
log.error(
f"None of the requested spectral orders {spec_orders} are defined "
"in the WCS transform. "
f"Defined orders are: {orders}. "
)
return []
if not all(is_good_order):
log.warning(
f"Not all requested spectral orders {spec_orders} are "
f"defined in the WCS transform. Defined orders are: {orders}. "
"Skipping undefined orders."
)
good_orders = [order for order in spec_orders if order in orders]
# There will be only one transform of this type in the wcs
break
return np.sort(good_orders)
def _find_min_relresp(sens_waves, sens_response):
"""
Find the minimum relative response in the sensitivity response.
Helper function is necessary instead of just nanmin because reference file
sometimes has zero-valued wavelength/response pairs.
Parameters
----------
sens_waves : np.ndarray
Wavelengths corresponding to the sensitivity response.
sens_response : np.ndarray
Sensitivity response values.
Returns
-------
float
Minimum relative response value.
"""
good = (sens_waves > 0) & (sens_response > 0) & np.isfinite(sens_waves)
return np.nanmin(sens_response[good])
def _apply_magnitude_limit(
order, source_catalog, sens_wave, sens_response, magnitude_limit, min_relresp_order1
):
"""
Rescale the magnitude limit based on the sensitivity response for a given spectral order.
Parameters
----------
order : int
Spectral order for which the magnitude limit is applied.
source_catalog : astropy.table.Table
The source catalog containing source IDs and isophotal AB magnitudes.
sens_wave : np.ndarray
The wavelengths corresponding to the sensitivity response.
sens_response : np.ndarray
The sensitivity response for the order.
magnitude_limit : float
The isophotal AB magnitude limit for sources to be included in the contamination correction.
min_relresp_order1 : float
Minimum relative response for order 1, used to scale the magnitude limit.
Returns
-------
list
List of source IDs that meet the magnitude limit criteria.
"""
if order in [0, 1]:
# Magnitude limit is set according to order 1 sensitivity response
# and order 0 is a special case because it's not dispersed
order_mag_limit = magnitude_limit
else:
# Scale the magnitude limit according to the order sensitivity response
order_sens_factor = min_relresp_order1 / _find_min_relresp(sens_wave, sens_response)
order_mag_diff = -2.5 * np.log10(order_sens_factor)
order_mag_limit = magnitude_limit - order_mag_diff
# Select sources that are brighter than the magnitude limit
good_sources = source_catalog[source_catalog["isophotal_abmag"] < order_mag_limit]
if len(good_sources) == 0:
return None
log.info(
f"Applying magnitude limit of {order_mag_limit:.1f} to order {order}. "
f"Sources selected: {len(good_sources)}"
)
return good_sources["label"].tolist()
[docs]
def contam_corr(
input_model,
waverange,
photom,
max_cores,
orders=None,
magnitude_limit=None,
max_pixels_per_chunk=5e4,
oversample_factor=2,
):
"""
Correct contamination in WFSS spectral cutouts.
Parameters
----------
input_model : `~stdatamodels.jwst.datamodels.MultiSlitModel`
Input data model containing 2D spectral cutouts. May be modified by processing:
make a copy before calling this function, if needed.
waverange : `~stdatamodels.jwst.datamodels.WavelengthrangeModel`
Wavelength range reference file model
photom : `~stdatamodels.jwst.datamodels.NrcWfssPhotomModel` or \
`~stdatamodels.jwst.datamodels.NisWfssPhotomModel`
Photom (flux cal) reference file model
max_cores : str or int
Number of cores to use for multiprocessing. If set to 'none'
(the default), then no multiprocessing will be done. The other
allowable string values are 'quarter', 'half', and 'all', which indicate
the fraction of cores to use for multi-proc. The total number of
cores includes the SMT cores (Hyper Threading for Intel).
If an integer is provided, it will be the exact number of cores used.
orders : list, optional
List of spectral orders to process.
If None, all orders defined in the wavelengthrange file will be processed.
magnitude_limit : float, optional
Isophotal AB magnitude limit for sources to be included in the contamination correction.
The magnitude limit is applied per spectral order, where the orders are scaled relative
to order 0 based on their photometric response as read from the photom reference file.
This means that generally fewer sources will be dispersed in higher orders.
If None, no magnitude limit is applied and all sources are included.
max_pixels_per_chunk : int, optional
Maximum number of pixels to disperse simultaneously.
oversample_factor : int, optional
Wavelength oversampling factor.
Returns
-------
output_model : `~stdatamodels.jwst.datamodels.MultiSlitModel`
A copy of the input_model that has been corrected
simul_model : `~stdatamodels.jwst.datamodels.ImageModel`
Full-frame simulated image of the grism exposure
contam_model : `~stdatamodels.jwst.datamodels.MultiSlitModel`
Contamination estimate images for each source slit
"""
max_available_cores = multiprocessing.cpu_count()
# don't worry about case where nchunks < ncpus; just set nchunks large for now
ncpus = compute_num_cores(max_cores, 1e10, max_available_cores)
# Get the segmentation map and direct image for this grism exposure
seg_model = datamodels.open(input_model.meta.segmentation_map)
direct_file = input_model.meta.direct_image
log.debug(f"Direct image ={direct_file}")
with datamodels.open(direct_file) as direct_model:
direct_image = direct_model.data
direct_image_wcs = direct_model.meta.wcs
# Get the grism WCS object from the first cutout in the input model.
# This WCS is used to transform from direct image to grism frame for all sources
# in the segmentation map.
# The "detector" to "grism_detector" and "world" to "detector" transforms are identical
# for all slits, so just use the first one. The "grism_detector" to "grism_slit"
# transform is not used by the step.
grism_wcs = input_model.slits[0].meta.wcs
# Find out how many spectral orders are defined based on the
# array of order values in the Wavelengthrange ref file,
# then constrain the orders to the user-specified ones
spec_orders = np.asarray(waverange.order)
spec_orders = _validate_orders_against_reference(orders, spec_orders)
spec_orders = _validate_orders_against_transform(grism_wcs, spec_orders)
if len(spec_orders) == 0:
log.error("No valid spectral orders found. Step will be SKIPPED.")
return input_model, None, None, None
log.info(f"Spectral orders requested = {[int(x) for x in spec_orders]}")
# Get the FILTER and PUPIL wheel positions, for use later
filter_kwd = input_model.meta.instrument.filter
pupil_kwd = input_model.meta.instrument.pupil
# NOTE: The NIRCam WFSS mode uses filters that are in the FILTER wheel
# with gratings in the PUPIL wheel. NIRISS WFSS mode, however, is just
# the opposite. It has gratings in the FILTER wheel and filters in the
# PUPIL wheel. So when processing NIRISS grism exposures the name of
# filter needs to come from the PUPIL keyword value.
if input_model.meta.instrument.name == "NIRISS":
filter_name = pupil_kwd
else:
filter_name = filter_kwd
# Read the source catalog to perform magnitude-based source selection later
# mag limit will be scaled according to order 1 sensitivity
if magnitude_limit is not None:
source_catalog = read_source_catalog(input_model.meta.source_catalog)
order1_wave_response, order1_sens_response = get_photom_data(
photom, filter_kwd, pupil_kwd, order=1
)
min_relresp_order1 = _find_min_relresp(order1_wave_response, order1_sens_response)
# set up observation object to disperse
obs = Observation(
direct_image,
seg_model.data,
grism_wcs,
direct_image_wcs,
boundaries=[0, 2047, 0, 2047],
max_cpu=ncpus,
max_pixels_per_chunk=max_pixels_per_chunk,
oversample_factor=oversample_factor,
)
no_sources = True
for order in spec_orders:
# Load lists of wavelength ranges and flux cal info
wavelength_range = waverange.get_wfss_wavelength_range(filter_name, [order])
wmin = wavelength_range[order][0]
wmax = wavelength_range[order][1]
log.debug(f"wmin={wmin}, wmax={wmax} for order {order}")
sens_waves, sens_response = get_photom_data(photom, filter_kwd, pupil_kwd, order)
# Constrain the source IDs to those that are below the magnitude limit
selected_ids = None
if magnitude_limit is not None:
good_ids = _apply_magnitude_limit(
order,
source_catalog,
sens_waves,
sens_response,
magnitude_limit,
min_relresp_order1,
)
if good_ids is None:
log.info(
f"No sources meet the magnitude limit of {magnitude_limit} for order {order}. "
"Skipping contamination correction for this order."
)
continue
selected_ids = good_ids
no_sources = False
# Compute the dispersion for all sources in this order
log.info(f"Creating full simulated grism image for order {order}")
obs.disperse_order(order, wmin, wmax, sens_waves, sens_response, selected_ids)
if no_sources:
log.error(
f"No sources found that met the magnitude limit {magnitude_limit}. Step will be SKIPPED"
)
return input_model, None, None, None
# Initialize the full-frame simulated grism image
simul_model = datamodels.ImageModel(data=obs.simulated_image)
simul_model.update(input_model, only="PRIMARY")
simul_slit_sids = [slit.source_id for slit in obs.simulated_slits.slits]
simul_slit_orders = [slit.meta.wcsinfo.spectral_order for slit in obs.simulated_slits.slits]
# Initialize output multislitmodel
output_model = datamodels.MultiSlitModel()
# Copy over matching slits.
# Note that this makes a reference to input slits, not a deep copy,
# so the input data may be modified by this function. The input data is
# copied in the calling step, as needed.
good_slits = [slit for slit in input_model.slits if slit.source_id in obs.source_ids]
output_model.slits.extend(good_slits)
# Loop over all slits/sources to subtract contaminating spectra
log.info("Creating contamination image for each individual source")
contam_model = datamodels.MultiSlitModel()
contam_model.update(input_model, only="PRIMARY")
simul_slits = datamodels.MultiSlitModel()
simul_slits.update(input_model, only="PRIMARY")
for slit in output_model.slits:
try:
good_idx = _find_matching_simul_slit(slit, simul_slit_sids, simul_slit_orders)
this_simul = obs.simulated_slits.slits[good_idx]
slit, this_simul = match_backplane_prefer_first(slit, this_simul)
simul_all_cut = _cut_frame_to_match_slit(obs.simulated_image, slit)
contam_cut = simul_all_cut - this_simul.data
simul_slits.slits.append(this_simul)
except (UnmatchedSlitIDError, SlitOverlapError) as e:
log.warning(e)
contam_cut = np.zeros_like(slit.data)
contam_slit = datamodels.SlitModel()
contam_slit.data = contam_cut
contam_model.slits.append(contam_slit)
# Subtract the contamination from the source slit
slit.data -= contam_cut
output_model.update(input_model, only="PRIMARY")
output_model.meta.cal_step.wfss_contam = "COMPLETE"
seg_model.close()
return output_model, simul_model, contam_model, simul_slits