"""Source code for jwst.wfss_contam.observations."""

import logging
import multiprocessing as mp
import time
import warnings

import numpy as np
from astropy.stats import SigmaClip
from astropy.utils.exceptions import AstropyUserWarning
from photutils.background import Background2D, MedianBackground
from stdatamodels.jwst import datamodels

from jwst.wfss_contam.disperse import disperse

log = logging.getLogger(__name__)

__all__ = ["background_subtract", "Observation"]


def background_subtract(
    data,
    box_size=None,
    filter_size=(3, 3),
    sigma=3.0,
    exclude_percentile=30.0,
):
    """
    Apply a simple astropy background subtraction.

    Parameters
    ----------
    data : ndarray
        2D array of pixel values
    box_size : tuple
        Size of box in pixels to use for background estimation.
        If not set, defaults to 1/5 of the image size.
    filter_size : tuple
        Size of filter to use for background estimation
    sigma : float
        Sigma threshold for background clipping
    exclude_percentile : float
        Percentage of masked pixels above which box is excluded from
        background estimation

    Returns
    -------
    data : ndarray
        2D array of pixel values with background subtracted

    Notes
    -----
    Improper background subtraction in input _i2d image leads to extra flux
    in the simulated dispersed image, and was one cause of flux scaling
    issues in a previous version.
    """
    if box_size is None:
        # Default to 1/5 of the image size along each axis, clamped to at
        # least 1 pixel so Background2D receives a valid box size even for
        # images smaller than 5 pixels on a side.
        box_size = (max(1, data.shape[0] // 5), max(1, data.shape[1] // 5))
    sigma_clip = SigmaClip(sigma=sigma)
    bkg_estimator = MedianBackground()
    with warnings.catch_warnings():
        # there can be multiple different AstropyUserWarning messages here
        # about NaN and Inf values
        warnings.filterwarnings("ignore", category=AstropyUserWarning)
        bkg = Background2D(
            data,
            box_size,
            filter_size=filter_size,
            sigma_clip=sigma_clip,
            bkg_estimator=bkg_estimator,
            exclude_percentile=exclude_percentile,
        )
    return data - bkg.background
def _select_ids(source_id, all_ids):
    """
    Select the source IDs to be processed based on the input ID parameter.

    Parameters
    ----------
    source_id : int or list-like
        ID(s) of source to process. If None, all sources processed.
    all_ids : ndarray
        Array of all source IDs in the segmentation map

    Returns
    -------
    selected_IDs : list
        List of selected source IDs

    Raises
    ------
    ValueError
        If ``source_id`` is not None, an integer, or a list-like of integers.
    """
    if source_id is None:
        log.info(f"Loading all {len(all_ids)} sources from segmentation map")
        return all_ids
    # Accept NumPy integer scalars as well as plain ints: IDs extracted from
    # a segmentation map (e.g. via np.ravel) are typically NumPy integers,
    # and previously they incorrectly fell through to the ValueError branch.
    if isinstance(source_id, (int, np.integer)):
        log.info(f"Loading single source {source_id} from segmentation map")
        return [source_id]
    if isinstance(source_id, (list, tuple, np.ndarray)):
        log.info(
            f"Loading {len(source_id)} of {len(all_ids)} selected sources from segmentation map"
        )
        return list(source_id)
    raise ValueError("ID must be an integer or a list of integers")
class Observation:
    """
    Define an observation leading to a single grism image.

    The Observation class is responsible for calling the various WCS
    transforms that convert a direct image and a segmentation image into a
    simulation of the grism image, making assumptions about the spectral
    properties of the direct image sources.

    When the `disperse_order` method is called one or more times, two
    products are created: the simulated dispersed image (``simulated_image``
    attribute) and the simulated
    `~stdatamodels.jwst.datamodels.MultiSlitModel` (``simulated_slits``
    attribute).
    """

    def __init__(
        self,
        direct_image,
        segmentation_map,
        grism_wcs,
        direct_image_wcs,
        boundaries=None,
        max_cpu=1,
        max_pixels_per_chunk=5e4,
        oversample_factor=2,
    ):
        """
        Initialize all data and metadata for a given observation.

        Parameters
        ----------
        direct_image : np.ndarray
            Direct imaging data.
        segmentation_map : np.ndarray
            Segmentation map data.
        grism_wcs : `~gwcs.wcs.WCS`
            WCS object from grism image
        direct_image_wcs : `~gwcs.wcs.WCS`
            WCS object from direct image
        boundaries : list, optional
            Start/Stop coordinates of the FOV within the larger seed image.
        max_cpu : int, optional
            Max number of cpu's to use when multiprocessing
        max_pixels_per_chunk : int, optional
            Maximum number of pixels per chunk when dispersing sources
        oversample_factor : int, optional
            Factor by which to oversample the wavelength grid
        """
        if boundaries is None:
            boundaries = []

        # Load all the info for this grism mode
        self.direct_image_wcs = direct_image_wcs
        self.grism_wcs = grism_wcs
        self.seg = segmentation_map

        unique_ids = list(set(np.ravel(self.seg)))
        unique_ids.remove(0)  # drop the background label
        self.source_ids = unique_ids

        self.max_cpu = max_cpu
        self.max_pixels_per_chunk = max_pixels_per_chunk
        self.oversample_factor = oversample_factor

        # ensure the direct image has background subtracted
        self.dimage = background_subtract(direct_image)

        # Set the limits of the dispersed image to be simulated
        if len(boundaries) == 0:
            log.debug("No boundaries passed.")
            self.xstart = 0
            # NOTE(review): xend is derived from seg.shape[0] (rows) and
            # yend from seg.shape[1] (cols); for non-square images this
            # looks transposed -- confirm intended behavior upstream.
            self.xend = self.xstart + self.seg.shape[0] - 1
            self.ystart = 0
            self.yend = self.ystart + self.seg.shape[1] - 1
        else:
            self.xstart, self.xend, self.ystart, self.yend = boundaries
        self.dims = (self.yend - self.ystart + 1, self.xend - self.xstart + 1)
        log.debug(f"Using simulated image size of ({self.dims[1]}, {self.dims[0]}).")
        self.naxis = self.dims[::-1]

        # Create lists of pixels labeled in segmentation map
        self._create_pixel_list()

        # Initialize the output MultiSlitModel
        self.simulated_slits = datamodels.MultiSlitModel()

        # Initialize the simulated dispersed image
        self.simulated_image = np.zeros(self.dims, float)

    def _create_pixel_list(self):
        """Create flat lists of pixels to be dispersed."""
        self.ys, self.xs = np.nonzero(self.seg)
        self.source_ids_per_pixel = self.seg[self.ys, self.xs]
        self.fluxes = self.dimage[self.ys, self.xs]

    def chunk_sources(
        self, order, wmin, wmax, sens_waves, sens_response, selected_ids=None, max_pixels=1e5
    ):
        """
        Chunk the sources into groups of max_pixels.

        Parameters
        ----------
        order : int
            Spectral order to process
        wmin : float
            Minimum wavelength for dispersed spectra
        wmax : float
            Maximum wavelength for dispersed spectra
        sens_waves : ndarray
            Wavelength array from photom reference file
        sens_response : ndarray
            Response (flux calibration) array from photom reference file
        selected_ids : list, optional
            List of source IDs to process. If None, all sources are
            processed.
        max_pixels : int, optional
            Maximum number of pixels per chunk.

        Returns
        -------
        disperse_args : list[list]
            Outer list has length number of groups, and each inner list
            contains the arguments to disperse() for that group in the
            format that multiprocessing starmap expects.
        """
        ids_to_process = _select_ids(selected_ids, self.source_ids)
        max_pixels = int(max_pixels)

        # Keep only the pixels that belong to the requested sources.
        keep = np.isin(self.source_ids_per_pixel, ids_to_process)

        # Order the surviving pixels by source ID so each source's pixels
        # stay contiguous. This reduces the number of times we have to call
        # build_dispersed_image_of_source within disperse().
        order_idx = np.argsort(self.source_ids_per_pixel[keep])
        xs = self.xs[keep][order_idx]
        ys = self.ys[keep][order_idx]
        fluxes = self.fluxes[keep][order_idx]
        pixel_ids = self.source_ids_per_pixel[keep][order_idx]

        # Split into chunks of at most max_pixels pixels each
        total_pixels = len(xs)
        n_chunks = int(np.ceil(total_pixels / max_pixels))
        log.info(
            f"Splitting {total_pixels} pixels from {len(ids_to_process)} sources into {n_chunks} chunks"
        )

        disperse_args = []
        for start in range(0, total_pixels, max_pixels):
            stop = min(start + max_pixels, total_pixels)
            disperse_args.append(
                [
                    xs[start:stop],
                    ys[start:stop],
                    fluxes[start:stop],
                    pixel_ids[start:stop],
                    order,
                    wmin,
                    wmax,
                    sens_waves,
                    sens_response,
                    self.direct_image_wcs,
                    self.grism_wcs,
                    self.naxis,
                    self.oversample_factor,
                ]
            )
        return disperse_args

    def disperse_order(self, order, wmin, wmax, sens_waves, sens_response, selected_ids=None):
        """
        Disperse the sources for a given spectral order, with multiprocessing.

        The ``simulated_slits`` and ``simulated_image`` attributes are
        updated in place.

        Parameters
        ----------
        order : int
            Spectral order to process
        wmin : float
            Minimum wavelength for dispersed spectra
        wmax : float
            Maximum wavelength for dispersed spectra
        sens_waves : ndarray
            Wavelength array from photom reference file
        sens_response : ndarray
            Response (flux calibration) array from photom reference file
        selected_ids : list, optional
            List of source IDs to process. If None, all sources are
            processed.
        """
        # Generate one argument list per chunk for the disperse function
        disperse_args = self.chunk_sources(
            order,
            wmin,
            wmax,
            sens_waves,
            sens_response,
            selected_ids=selected_ids,
            max_pixels=self.max_pixels_per_chunk,
        )

        t_start = time.time()
        if self.max_cpu > 1:
            # Use multiprocessing to disperse the sources
            log.info(
                f"Using {self.max_cpu} CPU cores for multiprocessing "
                f"{len(self.source_ids)} sources in {len(disperse_args)} chunks."
            )
            ctx = mp.get_context("spawn")
            pool = ctx.Pool(self.max_cpu)
            try:
                all_res = pool.starmap(disperse, disperse_args)
            except Exception as e:
                log.error(f"Error during parallel processing: {e}")
                raise
            finally:
                pool.close()
                pool.join()
        else:
            all_res = [disperse(*args) for args in disperse_args]
        t_stop = time.time()
        log.info(f"Wall clock time for disperse_chunk order {order}: {(t_stop - t_start):.1f} sec")

        # Merge per-chunk results so each source ends up with a single
        # image and bounding box
        per_source = {}
        for chunk_result in all_res:
            if chunk_result is None:
                # None of the sources in this chunk for this order had
                # pixels on the detector
                continue
            for sid in chunk_result:
                _aggregate_by_source(chunk_result, sid, per_source)

        # Fold the merged results into the simulation products
        for sid, result in per_source.items():
            xmin, xmax, ymin, ymax = result["bounds"]
            img = result["image"]
            slit = _construct_slitmodel(img, result["bounds"], sid, order)
            self.simulated_image[ymin : ymax + 1, xmin : xmax + 1] += img
            self.simulated_slits.slits.append(slit)
def _aggregate_by_source(results, sid, source_results): """ Combine results from different chunks into a single image and bounds for each source ID. Parameters ---------- results : dict Dictionary containing the results for each source ID in the current chunk, in the format: {source_id: {"bounds": [xmin, xmax, ymin, ymax], "image": 2D array}} sid : int Source ID source_results : dict Dictionary to store simulated image and bounds for each source ID, in the format: {source_id: {"bounds": [xmin, xmax, ymin, ymax], "image": 2D array}} Updated in place. """ if sid not in source_results: source_results[sid] = { "bounds": results[sid]["bounds"], "image": results[sid]["image"], } return # Combine bounds old_bounds = source_results[sid]["bounds"] new_bounds = results[sid]["bounds"] combined_bounds = [ min(old_bounds[0], new_bounds[0]), max(old_bounds[1], new_bounds[1]), min(old_bounds[2], new_bounds[2]), max(old_bounds[3], new_bounds[3]), ] # Create combined image with the union of bounds combined_shape = ( combined_bounds[3] - combined_bounds[2] + 1, combined_bounds[1] - combined_bounds[0] + 1, ) combined_image = np.zeros(combined_shape, dtype=float) # Add existing image to combined image old_y_start = old_bounds[2] - combined_bounds[2] old_y_end = old_y_start + source_results[sid]["image"].shape[0] old_x_start = old_bounds[0] - combined_bounds[0] old_x_end = old_x_start + source_results[sid]["image"].shape[1] combined_image[old_y_start:old_y_end, old_x_start:old_x_end] += source_results[sid]["image"] # Add new image to combined image new_y_start = new_bounds[2] - combined_bounds[2] new_y_end = new_y_start + results[sid]["image"].shape[0] new_x_start = new_bounds[0] - combined_bounds[0] new_x_end = new_x_start + results[sid]["image"].shape[1] combined_image[new_y_start:new_y_end, new_x_start:new_x_end] += results[sid]["image"] # Update source results source_results[sid] = {"bounds": combined_bounds, "image": combined_image} def _construct_slitmodel( img, bounds, sid, 
order, ): """ Turn an output image from a single source/order into a SlitModel. Parameters ---------- img : ndarray Dispersed model image of segmentation map source bounds : list The bounds of the object in relation to the full-frame image. sid : int The source ID order : int The spectral order Returns ------- slit : `jwst.datamodels.SlitModel` Slit model containing the dispersed pixel values """ [thisobj_minx, thisobj_maxx, thisobj_miny, thisobj_maxy] = bounds slit = datamodels.SlitModel() slit.source_id = sid slit.name = f"source_{sid}" slit.xstart = thisobj_minx slit.xsize = thisobj_maxx - thisobj_minx + 1 slit.ystart = thisobj_miny slit.ysize = thisobj_maxy - thisobj_miny + 1 slit.meta.wcsinfo.spectral_order = order slit.data = img return slit