import logging
import multiprocessing as mp
import warnings
import numpy as np
from astropy.modeling.mappings import Mapping
from scipy import sparse
from jwst.lib.winclip import get_clipped_pixels
from jwst.wfss_contam.sens1d import create_1d_sens
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
# Only `disperse` is exported; every other function in this module is a private helper.
__all__ = ["disperse"]
def _determine_native_wl_spacing(
    x0_sky,
    y0_sky,
    sky_to_imgxy,
    imgxy_to_grismxy,
    order,
    wmin,
    wmax,
    oversample_factor=2,
):
    """
    Build the wavelength grid on which dispersed pixel values are evaluated.

    The grid step is the wavelength change per grism pixel along the spectral
    trace, divided by ``oversample_factor``.

    Parameters
    ----------
    x0_sky : float or ndarray
        RA of the input pixel position(s) in the direct image / segmentation map.
    y0_sky : float or ndarray
        Dec of the input pixel position(s) in the direct image / segmentation map.
    sky_to_imgxy : astropy model
        Transform from sky to image coordinates. Called as
        ``sky_to_imgxy(ra, dec, wavelength, order)``; only its first two
        outputs are used.
    imgxy_to_grismxy : astropy model
        Transform from image to grism coordinates. Called as
        ``imgxy_to_grismxy(x, y, wavelength, order)`` and expected to return
        exactly ``(x, y)``.
    order : int
        Spectral order number.
    wmin : float
        Minimum wavelength for dispersed spectra.
    wmax : float
        Maximum wavelength for dispersed spectra.
    oversample_factor : int, optional
        Factor by which to oversample the wavelength grid.

    Returns
    -------
    lambdas : ndarray
        1-D array of wavelengths starting at ``wmin`` with a constant step,
        generated by ``np.arange(wmin, wmax + step, step)``.

    Notes
    -----
    The native wavelength spacing was found to vary by a few percent or less
    across the detector for both NIRCam and NIRISS. Many positions may be
    passed at once, in which case the median spacing is used, but a single
    position is typically sufficient.
    """
    # Sky -> direct-image pixel position. A dummy wavelength of 1 is supplied;
    # the wavelength and order outputs of the transform are discarded.
    img_x, img_y, _, _ = sky_to_imgxy(x0_sky, y0_sky, 1, order)

    # Grism-frame positions of the two ends of the spectral trace.
    x_at_wmin, y_at_wmin = imgxy_to_grismxy(img_x, img_y, wmin, order)
    x_at_wmax, y_at_wmax = imgxy_to_grismxy(img_x, img_y, wmax, order)
    trace_dx = x_at_wmax - x_at_wmin
    trace_dy = y_at_wmax - y_at_wmin

    # Wavelength change per pixel along the trace.
    # NOTE(review): the trace length is taken as |dy - dx|, which equals the
    # true extent only when one of the two components is zero (dispersion along
    # a single detector axis) -- confirm this is intended for tilted traces.
    wl_per_pixel = np.abs((wmax - wmin) / (trace_dy - trace_dx))

    # Median collapses multiple input positions to one representative step.
    step = np.median(wl_per_pixel / oversample_factor)
    return np.arange(wmin, wmax + step, step)
def _disperse_onto_grism(x0_sky, y0_sky, sky_to_imgxy, imgxy_to_grismxy, lambdas, order):
    """
    Map every input sky position to grism-image x/y at every wavelength.

    Parameters
    ----------
    x0_sky : ndarray
        RA of the input pixel positions in the direct image / segmentation map.
    y0_sky : ndarray
        Dec of the input pixel positions in the direct image / segmentation map.
    sky_to_imgxy : astropy model
        Transform from sky to image coordinates; only its first two outputs
        are used.
    imgxy_to_grismxy : astropy model
        Transform from image to grism coordinates; expected to return
        exactly ``(x, y)``.
    lambdas : ndarray
        1-D array of wavelengths at which to compute dispersed pixel positions.
    order : int
        Spectral order number.

    Returns
    -------
    x0s : ndarray
        X coordinates of dispersed pixels in the grism image,
        shape ``(n_lam, n_pixels)``.
    y0s : ndarray
        Y coordinates of dispersed pixels in the grism image,
        shape ``(n_lam, n_pixels)``.
    lambdas : ndarray
        Wavelength of each dispersed pixel, i.e. the input wavelengths
        replicated along the pixel axis to shape ``(n_lam, n_pixels)``.
    """
    n_lam = len(lambdas)

    # Replicate the sky coordinates once per wavelength: one row per wavelength.
    ra_grid = np.repeat(x0_sky[np.newaxis, :], n_lam, axis=0)
    dec_grid = np.repeat(y0_sky[np.newaxis, :], n_lam, axis=0)

    # Sky -> image frame. The 1-D wavelength array is passed as-is here.
    img_x, img_y, _, _ = sky_to_imgxy(ra_grid, dec_grid, lambdas, order)
    # Drop the replicated sky grids before allocating the wavelength grid.
    del ra_grid, dec_grid

    # Image -> grism frame, with the wavelength replicated once per pixel column.
    n_pix = img_x.shape[1]
    lam_grid = np.repeat(lambdas[:, np.newaxis], n_pix, axis=1)
    grism_x, grism_y = imgxy_to_grismxy(img_x, img_y, lam_grid, order)

    return grism_x, grism_y, lam_grid
def _collect_outputs_by_source(xs, ys, counts, source_ids_per_pixel):
    """
    Group dispersed pixels by source and build one cutout image per source.

    Parameters
    ----------
    xs : ndarray
        X coordinates of dispersed pixels.
    ys : ndarray
        Y coordinates of dispersed pixels.
    counts : ndarray
        Count rates of dispersed pixels.
    source_ids_per_pixel : int array
        Source ID of each dispersed pixel.

    Returns
    -------
    outputs_by_source : dict
        Maps each source ID to a dict with two entries: ``"bounds"``, the list
        ``[minx, maxx, miny, maxy]`` of that source's pixels, and ``"image"``,
        the 2-D dispersed image covering exactly those bounds.
    """
    # Bring all pixels of a given source together; the inputs are not
    # guaranteed to be ordered by source after get_clipped_pixels.
    by_source = np.argsort(source_ids_per_pixel)
    ids = source_ids_per_pixel[by_source]
    px = xs[by_source]
    py = ys[by_source]
    values = counts[by_source]

    # `starts[k]` is the first sorted index belonging to `unique_ids[k]`;
    # each run ends where the next begins, and the last one at the array end.
    unique_ids, starts = np.unique(ids, return_index=True)
    stops = np.append(starts[1:], len(px))

    # Per-source bounding boxes, computed for all sources in one pass each.
    x_lo = np.minimum.reduceat(px, starts)
    x_hi = np.maximum.reduceat(px, starts)
    y_lo = np.minimum.reduceat(py, starts)
    y_hi = np.maximum.reduceat(py, starts)

    # The bounds are stored next to each cutout so the full dispersed image
    # can be reassembled later.
    outputs_by_source = {}
    per_source = zip(unique_ids, starts, stops, x_lo, x_hi, y_lo, y_hi)
    for sid, lo, hi, xmin, xmax, ymin, ymax in per_source:
        bounds = [int(xmin), int(xmax), int(ymin), int(ymax)]
        cutout = _build_dispersed_image_of_source(px[lo:hi], py[lo:hi], values[lo:hi], bounds)
        outputs_by_source[sid] = {
            "bounds": bounds,
            "image": cutout,
        }
    return outputs_by_source
def _build_dispersed_image_of_source(x, y, flux, bounds):
    """
    Rasterize a flat list of pixels into a dense 2-D cutout of one source.

    Parameters
    ----------
    x : ndarray
        X coordinates of pixels in the grism image.
    y : ndarray
        Y coordinates of pixels in the grism image.
    flux : ndarray
        Flux of each pixel in the grism image.
    bounds : list
        Pre-computed ``[minx, maxx, miny, maxy]`` bounds for the source.

    Returns
    -------
    a : ndarray
        2-D dispersed image of the source with shape
        ``(maxy - miny + 1, maxx - minx + 1)``. Entries sharing the same
        (x, y) are summed; pixels with no entry are zero.
    """
    minx, maxx, miny, maxy = bounds

    # Shift coordinates so the cutout's origin sits at (minx, miny).
    rows = y - miny
    cols = x - minx
    cutout_shape = (maxy - miny + 1, maxx - minx + 1)

    # Going through COO format accumulates duplicate (row, col) entries when
    # the matrix is densified.
    stamp = sparse.coo_matrix((flux, (rows, cols)), shape=cutout_shape)
    return stamp.toarray()
# Public entry point of this module (the only name listed in ``__all__``).
def disperse(
    xs,
    ys,
    fluxes,
    source_ids_per_pixel,
    order,
    wmin,
    wmax,
    sens_waves,
    sens_resp,
    direct_image_wcs,
    grism_wcs,
    naxis,
    oversample_factor=2,
):
    """
    Compute the dispersed image pixel values from the direct image.

    Parameters
    ----------
    xs : ndarray
        Flat array of X coordinates of pixels in the direct image.
    ys : ndarray
        Flat array of Y coordinates of pixels in the direct image.
    fluxes : ndarray
        Fluxes of the pixels in the direct image corresponding to xs, ys.
        These should have units of MJy/sr.
    source_ids_per_pixel : int array
        Source IDs of the input pixels in the segmentation map,
        corresponding to xs, ys.
    order : int
        Spectral order number.
    wmin : float
        Minimum wavelength for dispersed spectra.
    wmax : float
        Maximum wavelength for dispersed spectra.
    sens_waves : float array
        Wavelength array from photom reference file. Expected unit is micron.
    sens_resp : float array
        Response (flux calibration) array from photom reference file.
        Expected units are (micron) * (MJy / sr) / (ADU/s).
    direct_image_wcs : WCS object
        WCS object for the direct image and segmentation map.
    grism_wcs : WCS object
        WCS object for the grism image. Must provide the "world", "detector"
        and "grism_detector" frames.
    naxis : tuple
        Dimensions of the grism image (naxis[0], naxis[1]); naxis[0] bounds
        the x coordinates and naxis[1] bounds the y coordinates.
    oversample_factor : int, optional
        Factor by which to oversample the wavelength grid.

    Returns
    -------
    outputs_by_source : dict or None
        Dictionary containing dispersed images and bounds for each source ID
        in the specified spectral order. None is returned when there are no
        input pixels, or when none of the dispersed pixels fall within the
        grism image frame.
    """
    n_input_pixels = len(xs)
    n_input_sources = np.unique(source_ids_per_pixel).size
    log.debug(
        "%s dispersing %s sources in order %s with total number of pixels: %s",
        mp.current_process(),
        n_input_sources,
        order,
        n_input_pixels,
    )

    # Nothing to disperse: same null result as the "off-detector" case below,
    # instead of failing when the first sky position is looked up.
    if n_input_pixels == 0:
        return None

    # Offset the input coordinates by half the pixel size.
    # NOTE(review): presumably moves from pixel corner to pixel centre -- confirm convention.
    width = 1.0
    height = 1.0
    x0 = xs + 0.5 * width
    y0 = ys + 0.5 * height
    del xs, ys

    # Set up the transforms we need from the input WCS objects
    sky_to_imgxy = grism_wcs.get_transform("world", "detector")
    imgxy_to_grismxy = grism_wcs.get_transform("detector", "grism_detector")

    # We only need the x,y outputs of imgxy_to_grismxy.
    # Making the number of outputs dynamic handles legacy WCS objects that did not pass
    # the x0, y0, and order through the transform unmodified like the current version does.
    n_outputs = len(imgxy_to_grismxy.outputs)
    imgxy_to_grismxy = imgxy_to_grismxy | Mapping((0, 1), n_inputs=n_outputs)

    # Find RA/Dec of the input pixel position in direct image
    x0_sky, y0_sky = direct_image_wcs(x0, y0, with_bounding_box=False)
    del x0, y0

    # Native spacing does not change much over the detector, so just put in one x0, y0
    lambdas = _determine_native_wl_spacing(
        x0_sky[0],
        y0_sky[0],
        sky_to_imgxy,
        imgxy_to_grismxy,
        order,
        wmin,
        wmax,
        oversample_factor=oversample_factor,
    )
    # The grid is uniformly spaced, so the first interval gives the step.
    dlam = lambdas[1] - lambdas[0]

    # x0s, y0s and lambdas all come back with shape (n_lam, n_pixels)
    x0s, y0s, lambdas = _disperse_onto_grism(
        x0_sky,
        y0_sky,
        sky_to_imgxy,
        imgxy_to_grismxy,
        lambdas,
        order,
    )
    del x0_sky, y0_sky

    # If none of the dispersed pixel indexes are within the image frame,
    # return a null result without wasting time doing other computations
    if x0s.min() >= naxis[0] or x0s.max() < 0 or y0s.min() >= naxis[1] or y0s.max() < 0:
        return None

    # Number of columns of the (n_lam, n_pixels) grid, needed to decode `index` below.
    n_pix = x0s.shape[1]

    # Discretize x and y coordinates to integer pixel values, keeping track of the fractional area
    # that each pixel contributes to the final grism image.
    # The resulting x, y coordinate pairs are non-unique: there are multiple wavelengths
    # that contribute to each pixel.
    padding = 1
    xs, ys, areas, index = get_clipped_pixels(x0s, y0s, padding, naxis[0], naxis[1], width, height)
    del x0s, y0s

    # `index` addresses the flattened (n_lam, n_pixels) grid. The wavelength
    # genuinely varies over that grid, but flux and source ID depend only on the
    # pixel column, which is `index % n_pix`. Indexing the 1-D inputs directly
    # avoids materializing two more (n_lam, n_pixels) arrays.
    lambdas = np.take(lambdas, index)
    pixel_index = index % n_pix
    del index
    fluxes = fluxes[pixel_index]
    source_ids_per_pixel = source_ids_per_pixel[pixel_index]
    del pixel_index

    # Compute 1D sensitivity array corresponding to list of wavelengths
    sens, no_cal = create_1d_sens(lambdas, sens_waves, sens_resp)
    del lambdas

    # Compute countrates for dispersed pixels.
    # The input direct image data is already photometrically calibrated,
    # so we need to basically apply a reverse flux calibration here.
    # Divide out the response values to convert from MJy/sr to DN/s.
    # Note that the photom reference files are constructed with per-wavelength units,
    # so oversampling is accounted for by the spacing of dlam.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=RuntimeWarning, message="divide by zero")
        counts = fluxes * areas * dlam / sens
    counts[no_cal] = 0.0  # set to zero where no flux cal info available
    del fluxes, areas, sens, dlam, no_cal

    outputs_by_source = _collect_outputs_by_source(xs, ys, counts, source_ids_per_pixel)
    del xs, ys, counts, source_ids_per_pixel

    log.debug(
        "%s finished order %s with %s sources that overlap with the output frame "
        "(out of %s input sources)",
        mp.current_process(),
        order,
        len(outputs_by_source),
        n_input_sources,
    )
    return outputs_by_source