Source code for jwst.clean_flicker_noise.tso_median_image
import logging
import warnings
import numpy as np
from stdatamodels.jwst import datamodels
from jwst.background.background_step import BackgroundStep
from jwst.clean_flicker_noise.background_level import background_level
from jwst.datamodels.utils.tso_multispec import make_tso_specmodel
from jwst.extract_1d.extract_1d_step import Extract1dStep
from jwst.extract_1d.soss_extract import pastasoss, soss_extract
from jwst.extract_2d.extract_2d_step import Extract2dStep
from jwst.lib.basic_utils import disable_logging
from jwst.lib.pipe_utils import is_tso
from jwst.tso_photometry.tso_photometry_step import TSOPhotometryStep
from jwst.white_light.white_light_step import WhiteLightStep
log = logging.getLogger(__name__)

# Only the median-image builder is exported; underscore helpers stay internal.
__all__ = ["make_median_image"]
def _soss_box_extract(rateints, soss_refmodel=None):
    """
    Extract spectra with a simple box around the SOSS trace.

    The spectra are intended for approximate scaling, so they do
    not need to be very accurate. Only order 1 is extracted from
    each integration, with a width of 15 pixels around the trace from
    the pastasoss model. Pixels flagged DO_NOT_USE are excluded from
    the flux sum.

    Parameters
    ----------
    rateints : `~stdatamodels.jwst.datamodels.CubeModel`
        Background subtracted rateints datamodel.
    soss_refmodel : `~stdatamodels.jwst.datamodels.PastasossModel`, optional
        Used to identify SOSS traces for box extraction.
        If not provided, a default model will be retrieved.

    Returns
    -------
    multi_spec : `~stdatamodels.jwst.datamodels.TSOMultiSpecModel`
        Extracted spectra, with only FLUX and WAVELENGTH arrays
        populated.
    """
    nints = rateints.data.shape[0]
    img_shape = rateints.data.shape[-2:]
    pwcpos = rateints.meta.instrument.pupil_position
    subarray = rateints.meta.subarray.name
    if soss_refmodel is None:
        soss_refmodel = pastasoss.retrieve_default_pastasoss_model()

    # Set a reasonable extraction width
    width = 15

    # Extract only order 1 for scaling purposes
    order = 1

    # Get trace x,y positions
    _, xtrace, ytrace, _ = pastasoss.get_soss_traces(pwcpos, order=order, refmodel=soss_refmodel)
    box_weights = soss_extract.get_box_weights(ytrace, width, img_shape, cols=xtrace.astype(int))

    # Get wavelengths
    wavemaps = pastasoss.get_soss_wavemaps(
        pwcpos, subarray=subarray, refmodel=soss_refmodel, orders_requested=[order]
    )
    wave_grid = wavemaps[0]

    # The wavelength of each column is the box-weighted average of the
    # wavelength map. It does not depend on the science data, so compute
    # it once. Columns whose total box weight is zero give a non-finite
    # value and are dropped via `valid`.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)
        npix = np.sum(box_weights, axis=0)
        wavelength = np.nansum(wave_grid * box_weights, axis=0) / npix
    valid = np.isfinite(wavelength)

    # Extract a spectrum from each integration
    do_not_use = datamodels.dqflags.pixel["DO_NOT_USE"]
    spec_list = []
    for i in range(nints):
        # Zero the box weights of DO_NOT_USE pixels so they do not
        # contribute to the summed flux
        weights = box_weights.copy()
        weights[(rateints.dq[i] & do_not_use) > 0] = 0
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            # Sum the flux
            flux = np.nansum(rateints.data[i] * weights, axis=0)

        # Store fluxes with valid wavelengths
        spec = datamodels.SpecModel()
        spec.spec_table = np.zeros((valid.sum(),), dtype=spec.get_dtype("spec_table"))
        spec.spec_table["FLUX"] = flux[valid]
        spec.spec_table["WAVELENGTH"] = wavelength[valid]
        spec.spectral_order = order
        spec_list.append(spec)

    # Make a multispec model to hold the spectra
    tso_spec = make_tso_specmodel(spec_list)

    # Populate the midtime array with unique placeholder values to
    # make sure every integration appears in the whitelight table, later
    tso_spec.spec_table["MJD-AVG"] = np.arange(nints)

    multi_spec = datamodels.TSOMultiSpecModel()
    multi_spec.update(rateints)
    multi_spec.spec.append(tso_spec)
    return multi_spec
def _make_background_ramp(input_ramp, background_rateints_data):
    """
    Extrapolate a background rate to an up-the-ramp sampling.

    Each integration's background rate image is multiplied by the read
    time of every group, producing a ramp with the same shape and dtype
    as the input ramp data.

    A group spans (nframes + groupgap) frames, and the gap comes after
    the last frame, so group ``g`` (zero-indexed) is read at:

        group_time = (nframes + (nframes + groupgap) * g) * frame_time

    Parameters
    ----------
    input_ramp : `~stdatamodels.jwst.datamodels.RampModel`
        The input ramp datamodel to match.
    background_rateints_data : ndarray
        A 3D background array, with one rate image per integration.

    Returns
    -------
    background_ramp : ndarray
        4D up-the-ramp background data, matching the array shape and
        dtype of the input ramp.
    """
    # Unpacking enforces a 4D (ints, groups, y, x) ramp
    _, n_groups, _, _ = input_ramp.data.shape
    exposure = input_ramp.meta.exposure
    first_read_frames = exposure.nframes
    frames_per_group = first_read_frames + exposure.groupgap
    frame_time = exposure.frame_time

    ramp = np.zeros_like(input_ramp.data)
    for group in range(n_groups):
        read_time = (first_read_frames + frames_per_group * group) * frame_time
        ramp[:, group, ...] = background_rateints_data * read_time
    return ramp
[docs]
def make_median_image(input_model, rateints_model, soss_refmodel=None):
    """
    Make a scaled median image across integrations to subtract before cleaning.

    The procedure is:

    1. Subtract a reference background from the input ramp and rate data
       (NIRISS SOSS only).
    2. Compute a representative flux for scaling for each integration
       from the rate data.

       a. For TSO spectral modes, extract a spectrum (a simple box for
          NIRISS SOSS, the extract_1d step otherwise) and compute the
          whitelight flux for each integration.
       b. For TSO imaging modes, sum the flux over an aperture at the
          expected source location for each integration.
       c. For any other mode, take the median of the flux in each
          integration.

    3. Median combine all data across integrations to make a median
       ramp or image.
    4. Scale the median image by the representative flux for each
       integration.
    5. Add the subtracted background level to the median image
       (NIRISS SOSS only).

    The data arrays of ``input_model`` and ``rateints_model`` are not
    modified by this function.

    Parameters
    ----------
    input_model : `~stdatamodels.jwst.datamodels.RampModel` or \
        `~stdatamodels.jwst.datamodels.CubeModel`
        Ramp or rateints model to be cleaned.
    rateints_model : `~stdatamodels.jwst.datamodels.CubeModel`
        Draft rateints model corresponding to the input model.
        May be the same as ``input_model``.
    soss_refmodel : `~stdatamodels.jwst.datamodels.PastasossModel`, optional
        Used to identify NIRISS SOSS traces for box extraction.
        If not provided, a default model will be retrieved.

    Returns
    -------
    scaled_median : ndarray
        The scaled median image to subtract, matching the dimensions of
        the input model.

    Raises
    ------
    ValueError
        If the input is 2D or does not have multiple integrations, if the
        SOSS background rate has no valid values, if no spectra could be
        extracted, or if the fluxes for scaling are all invalid.
    """
    ndim = input_model.data.ndim
    if ndim < 3:
        raise ValueError("Cannot make a median image for 2D data")
    nints = input_model.data.shape[0]
    if nints <= 1:
        raise ValueError("Cannot make a median image for <2 integrations")

    # Run background subtraction on the rateints data (NIRISS SOSS only)
    exp_type = input_model.meta.exposure.type
    if exp_type == "NIS_SOSS":
        log.info("Calling the bkg_subtract step on the rate file to subtract SOSS background")
        with disable_logging(level=logging.WARNING):
            step = BackgroundStep()
            bgsub_rateints = step.run(rateints_model)

        # Subtract rateints models to get the background by integration
        background_rate = rateints_model.data - bgsub_rateints.data

        # Replace any NaN values in the background rate with smoothed local values.
        # bg_data is a view into background_rate, so the replacement is in place.
        for i, bg_data in enumerate(background_rate):
            invalid = ~np.isfinite(bg_data)
            if np.all(invalid):
                raise ValueError("No valid values in background rate")
            if np.any(invalid):
                log.debug(f"Replacing {np.sum(invalid)} values in background integration {i}")
                smoothed_bg = background_level(bg_data, ~invalid, background_method="model")
                if np.isscalar(smoothed_bg):
                    # 2D model failed, median value returned instead
                    bg_data[invalid] = smoothed_bg
                else:
                    bg_data[invalid] = smoothed_bg[invalid]
    else:
        bgsub_rateints = rateints_model
        background_rate = 0.0

    # Box extraction for flux scaling
    if exp_type == "NIS_SOSS":
        # Simple direct box extraction for SOSS
        multi_spec = _soss_box_extract(bgsub_rateints, soss_refmodel=soss_refmodel)

    elif exp_type in ["NRS_BRIGHTOBJ", "NRC_TSGRISM"]:
        log.info("Calling the extract_2d and extract_1d steps to extract a representative spectrum")
        with disable_logging(level=logging.WARNING):
            # Run extract2d to assign a slit-appropriate WCS
            # (required for extract_1d)
            step = Extract2dStep()
            single_slit = step.run(bgsub_rateints)

            # Set the source type to POINT for TSO
            single_slit.source_type = "POINT"

            # Call extract_1d with latest CRDS parameters (via call)
            # since the recommended defaults may vary by mode
            multi_spec = Extract1dStep.call(single_slit, save_results=False)
            single_slit.close()

        if not isinstance(multi_spec, datamodels.TSOMultiSpecModel) or len(multi_spec.spec) < 1:
            raise ValueError("No valid spectra extracted for flux scaling")
    else:
        # Not a TSO spectral mode
        multi_spec = None

    # Sum the flux for each integration and normalize by the median across all integrations
    if multi_spec is not None:
        # For spectra, use the whitelight step to sum the flux, using a wavelengthrange
        # file as appropriate.
        log.info(
            "Calling the white_light step to compute an approximate whitelight curve for scaling"
        )
        with disable_logging(level=logging.WARNING):
            step = WhiteLightStep()
            whitelight_table = step.run(multi_spec)
        multi_spec.close()

        if exp_type == "NRS_BRIGHTOBJ":
            detector = input_model.meta.instrument.detector
            wlc_flux = whitelight_table[f"whitelight_flux_{detector}"]
        else:
            wlc_flux = whitelight_table["whitelight_flux"]
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            norm_flux = wlc_flux / np.nanmedian(wlc_flux)

    elif exp_type == "NRC_TSIMAGE" or (exp_type == "MIR_IMAGE" and is_tso(input_model)):
        # For imaging, call tso_photometry with latest CRDS parameters (via call),
        # to sum the flux. The recommended defaults for this step may vary by mode.
        log.info(
            "Calling the tso_photometry step to compute an approximate aperture flux for scaling"
        )
        with disable_logging(level=logging.WARNING):
            phot_table = TSOPhotometryStep.call(bgsub_rateints)

        # Use the aperture sum as the representative flux for scaling
        phot_flux = phot_table["aperture_sum"].value
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            norm_flux = phot_flux / np.nanmedian(phot_flux)

    else:
        # Not an expected TSO spectral or imaging mode.
        # Take the median flux of the image as the representative value.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            median_flux = np.nanmedian(bgsub_rateints.data, axis=(1, 2))
            norm_flux = median_flux / np.nanmedian(median_flux)

    # Check for bad values in the normalized flux
    invalid = ~np.isfinite(norm_flux)
    if np.all(invalid):
        # Raise an error if they are all bad
        raise ValueError("No valid flux for scaling")
    elif np.any(invalid):
        # Otherwise replace with a median value to avoid losing a whole integration
        log.warning(
            f"{np.sum(invalid)} integration(s) out of {nints} had non-finite "
            "extracted flux and will be scaled by the median flux instead."
        )
        norm_flux[invalid] = np.median(norm_flux[~invalid])

    # Make a background corrected ramp if needed.
    # `owns_bgsub` records whether bgsub_ramp is a fresh array created here.
    # Only then may np.nanmedian reorder it in place (overwrite_input=True).
    # Otherwise bgsub_ramp is the caller's input_model / rateints_model data,
    # which must stay untouched since it is the data to be cleaned.
    if ndim > 3:
        # Check for a real background image (SOSS only)
        if not np.isscalar(background_rate):
            # Extrapolate a background ramp from the rate
            background_ramp = _make_background_ramp(input_model, background_rate)

            # Subtract it from the data for the median computation below
            bgsub_ramp = input_model.data - background_ramp
            owns_bgsub = True
        else:
            # Otherwise, the background is zero everywhere and we will take
            # the median over the input data
            background_ramp = 0.0
            bgsub_ramp = input_model.data
            owns_bgsub = False
    else:
        background_ramp = background_rate  # may be 0.0
        bgsub_ramp = bgsub_rateints.data
        owns_bgsub = False

    # Make a median background-subtracted ramp
    log.info("Making a scaled median image")
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)
        # np.nanmedian allocates lots of memory; computing one plane
        # (group for ramps, row for 3D data) at a time keeps it bounded
        median_ramp = np.empty(bgsub_ramp.shape[1:], dtype=bgsub_ramp.dtype)
        for i in range(median_ramp.shape[0]):
            np.nanmedian(
                bgsub_ramp[:, i, ...],
                axis=0,
                overwrite_input=owns_bgsub,
                out=median_ramp[i, ...],
            )

    # Scale the median ramp by the normalized flux
    if ndim == 3:
        scaled_median = norm_flux[:, None, None] * median_ramp[None, ...]
    else:
        scaled_median = norm_flux[:, None, None, None] * median_ramp[None, ...]

    # Final data to subtract is background ramp + scaled median
    scaled_median += background_ramp

    # Release the intermediate background-subtracted model, if one was made
    if bgsub_rateints is not rateints_model:
        bgsub_rateints.close()

    return scaled_median