"""
Utility function for assign_wcs.
"""
import warnings
import logging
import functools
import numpy as np
from astropy.utils.misc import isiterable
from astropy.io import fits
from astropy.modeling import models as astmodels
from astropy.table import QTable
from astropy.constants import c
from gwcs import WCS
from gwcs.wcstools import wcs_from_fiducial, grid_from_bounding_box
from gwcs import utils as gwutils
from . import pointing
from ..lib.catalog_utils import SkyObject
from ..transforms.models import GrismObject
from ..datamodels import WavelengthrangeModel, DataModel, CubeModel, IFUCubeModel
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)
__all__ = ["reproject", "wcs_from_footprints", "velocity_correction"]
class MissingMSAFileError(Exception):
def __init__(self, message):
super(MissingMSAFileError, self).__init__(message)
class NoDataOnDetectorError(Exception):
"""WCS solution indicates no data on detector
When WCS solutions are available, the solutions indicate that no data
will be present, raise this exception.
Specific example is for NIRSpec and the NRS2 detector. For various
configurations of the MSA, it is possible that no dispersed spectra will
appear on NRS2. This is not a failure of calibration, but needs to be
called out in order for the calling architecture to be aware of this.
"""
def __init__(self, message=None):
if message is None:
message = 'WCS solution indicate that no science is in the data.'
super(NoDataOnDetectorError, self).__init__(message)
def _domain_to_bounding_box(domain):
# TODO: remove this when domain is completely removed
bb = tuple([(item['lower'], item['upper']) for item in domain])
if len(bb) == 1:
bb = bb[0]
return bb
[docs]def reproject(wcs1, wcs2, origin=0):
"""
Given two WCSs return a function which takes pixel coordinates in
the first WCS and computes their location in the second one.
It performs the forward transformation of ``wcs1`` followed by the
inverse of ``wcs2``.
Parameters
----------
wcs1, wcs2 : `~gwcs.wcs.WCS`
WCS objects.
Returns
-------
_reproject : func
Function to compute the transformations. It takes x, y
positions in ``wcs1`` and returns x, y positions in ``wcs2``.
"""
def _reproject(x, y):
sky = wcs1.forward_transform(x, y)
return wcs2.backward_transform(*sky)
return _reproject
def compute_fiducial(wcslist, bounding_box=None, domain=None):
"""
For a celestial footprint this is the center.
For a spectral footprint, it is the beginning of the range.
This function assumes all WCSs have the same output coordinate frame.
"""
if domain is not None:
warnings.warning("'domain' was deprecated in 0.8 and will be removed from next"
"version. Use 'bounding_box' instead.")
axes_types = wcslist[0].output_frame.axes_type
spatial_axes = np.array(axes_types) == 'SPATIAL'
spectral_axes = np.array(axes_types) == 'SPECTRAL'
footprints = np.hstack([w.footprint(bounding_box=bounding_box).T for w in wcslist])
spatial_footprint = footprints[spatial_axes]
spectral_footprint = footprints[spectral_axes]
fiducial = np.empty(len(axes_types))
if (spatial_footprint).any():
lon, lat = spatial_footprint
lon, lat = np.deg2rad(lon), np.deg2rad(lat)
x_mean = np.mean(np.cos(lat) * np.cos(lon))
y_mean = np.mean(np.cos(lat) * np.sin(lon))
z_mean = np.mean(np.sin(lat))
lon_fiducial = np.rad2deg(np.arctan2(y_mean, x_mean)) % 360.0
lat_fiducial = np.rad2deg(np.arctan2(z_mean, np.sqrt(x_mean ** 2 +
y_mean ** 2)))
fiducial[spatial_axes] = lon_fiducial, lat_fiducial
if (spectral_footprint).any():
fiducial[spectral_axes] = spectral_footprint.min()
return fiducial
def is_fits(input):
"""
Returns
--------
isFits: tuple
An ``(isfits, fitstype)`` tuple. The values of ``isfits`` and
``fitstype`` are specified as:
- ``isfits``: True|False
- ``fitstype``: if True, one of 'waiver', 'mef', 'simple'; if False, None
Notes
-----
Input images which do not have a valid FITS filename will automatically
result in a return of (False, None).
In the case that the input has a valid FITS filename but runs into some
error upon opening, this routine will raise that exception for the calling
routine/user to handle.
"""
isfits = False
fitstype = None
names = ['fits', 'fit', 'FITS', 'FIT']
#determine if input is a fits file based on extension
# Only check type of FITS file if filename ends in valid FITS string
f = None
fileclose = False
if isinstance(input, fits.HDUList):
isfits = True
f = input
else:
isfits = True in [input.endswith(l) for l in names]
# if input is a fits file determine what kind of fits it is
# waiver fits len(shape) == 3
if isfits:
if not f:
try:
f = fits.open(input, mode='readonly')
fileclose = True
except Exception:
if f is not None:
f.close()
raise
data0 = f[0].data
if data0 is not None:
try:
if isinstance(f[1], fits.TableHDU):
fitstype = 'waiver'
except IndexError:
fitstype = 'simple'
else:
fitstype = 'mef'
if fileclose:
f.close()
return isfits, fitstype
def subarray_transform(input_model):
"""
Inputs are in full frame coordinates.
If a subarray observation - shift the inputs.
"""
xstart = input_model.meta.subarray.xstart
ystart = input_model.meta.subarray.ystart
if xstart is None:
xstart = 1
if ystart is None:
ystart = 1
subarray2full = astmodels.Shift(xstart - 1) & astmodels.Shift(ystart - 1)
return subarray2full
def not_implemented_mode(input_model, ref):
"""
Return ``None`` if assign_wcs has not been implemented for a mode.
"""
exp_type = input_model.meta.exposure.type
message = "WCS for EXP_TYPE of {0} is not implemented.".format(exp_type)
log.critical(message)
return None
def get_object_info(catalog_name=None):
"""Return a list of SkyObjects from the direct image
the source_catalog step are read into a list of SkyObjects
which can be referenced by catalog id.
Parameters
----------
catalog_name : str
The name of the photutils catalog
Returns
-------
objects : list[jwst.transforms.models.SkyObject]
A list of SkyObject tuples
Notes
-----
"""
if catalog_name is None:
raise TypeError("Expected name of the catalog file")
objects = []
catalog = QTable.read(catalog_name, format='ascii.ecsv')
# validate that the expected columns are there
# id is just a bad name for a param, but it's used in the catalog
required_fields = list(SkyObject()._fields)
if "sid" in required_fields:
required_fields[required_fields.index("sid")] = "id"
try:
if not set(required_fields).issubset(set(catalog.colnames)):
difference = set(catalog.colnames).difference(required_fields)
raise KeyError("Missing required columns in source catalog ({}): {}"
.format(catalog_name, difference))
except AttributeError as e:
print("Problem validating object catalog columns {0:s}: {1}"
.format(catalog_name, e))
# The columns are named sky_bbox_ll, sky_bbox_ul, sky_bbox_lr, and sky_bbox_ur, each of
# which is a SkyCoord (i.e. RA & Dec & frame) at one corner of the minimal bounding box.
# There will also be a sky_bbox property as a 4-tuple of SkyCoord, but that is not
# serializable (hence, the four separate columns). This is not yet merged in photutils
# -- I discovered some bugs with SkyCoord and serialization, some of which have just been
# fixed in astropy.
for row in catalog:
objects.append(SkyObject(sid=row['id'],
xcentroid=row['xcentroid'],
ycentroid=row['ycentroid'],
sky_centroid=row['sky_centroid'],
abmag=row['abmag'],
abmag_error=row['abmag_error'],
sky_bbox_ll=row['sky_bbox_ll'],
sky_bbox_lr=row['sky_bbox_lr'],
sky_bbox_ul=row['sky_bbox_ul'],
sky_bbox_ur=row['sky_bbox_ur'],
)
)
return objects
def create_grism_bbox(input_model, reference_files,
mmag_extract=99.0):
"""Create bounding boxes for each object in the catalog
The sky coordinates in the catalog image are first related
to the grism image. They need to go through the WCS object
in order to find the "direct image" pixel location, which is
also in a detector pixel coordinate frame. This "direct image"
location can then be sent through the trace polynomials to find
the spectral location on the grism image for that wavelength and order.
Parameters
----------
input_model : `jwst.datamodels.ImagingModel`
Data model which holds the grism image
reference_files : dict
Dictionary of reference files
mmag_extract : float
The faintest magnitude to extract from the catalog
Returns
-------
A list of GrismObject(s) for every source in the catalog
Each grism object contains information about it's
spectral extent
Notes
-----
The wavelengthrange reference file is used to govern the extent of the
bounding box for each object. The name of the catalog has been stored
in the input models meta information under the source_catalog key.
It's left to the calling routine to cut the bounding boxes at the extent
of the detector (for example, extract 2d would only extract the on-detector
portion of the bounding box)
Bounding box dispersion direction is dependent on the filter and module for
NIRCAM and changes for GRISMR, but is consistent for GRISMC, see
https://jwst-docs.stsci.edu/display/JTI/NIRCam+Wide+Field+Slitless+Spectroscopy
NIRISS only has one detector, but GRISMC disperses along rows and GRISMR disperses
along columns.
"""
# figure out the dispersion direction, shouldn't this be in the polynomials already?
disperse_row_right = True # disperse to increasing x
disperse_column = False # column always disperses to increasing y
instr_name = input_model.meta.instrument.name
if instr_name == "NIRCAM":
module, grism, filter_name = (input_model.meta.instrument.module,
input_model.meta.instrument.pupil,
input_model.meta.instrument.filter)
if ((module == "B") and (grism == "GRISMR")):
disperse_row_right = False
elif (grism == "GRISMC"):
disperse_column = True
disperse_row_right = False
elif instr_name == "NIRISS":
grism, filter_name = (input_model.meta.instrument.filter,
input_model.meta.instrument.pupil)
if "R" == grism[-1]:
disperse_column = True
disperse_row_right = False
else:
raise ValueError("Input model is from unexpected instrument")
# get the array extent to exclude boxes not contained on the detector
xsize = input_model.meta.subarray.xsize
ysize = input_model.meta.subarray.ysize
skyobject_list = get_object_info(input_model.meta.source_catalog.filename)
# get the imaging transform to record the center of the object in the image
# here, image is in the imaging reference frame, before going through the
# dispersion coefficients
sky_to_detector = input_model.meta.wcs.get_transform('world', 'detector')
sky_to_grism = input_model.meta.wcs.get_transform('world', 'grism_detector')
# Get the disperser parameters which have the wave limits
with WavelengthrangeModel(reference_files['wavelengthrange']) as f:
if (f.meta.exposure.type == "NRC_TSGRISM"):
raise ValueError("Wavelengthrange reference file not meant for WFSS mode")
waverange_selector = f.waverange_selector
orders = f.order
# All objects in the catalog will use the same filter for translation
# that filter is the one that was used in front of the grism
fselect = waverange_selector.index(filter_name)
grism_objects = [] # the return list of GrismObjects
for obj in skyobject_list:
if obj.abmag < mmag_extract:
# could add logic to ignore object if too far off image,
# save the image frame center of the object
# takes in ra, dec, wavelength, order but wave and order
# don't get used until the detector->grism_detector transform
xcenter, ycenter, _, _ = sky_to_detector(obj.sky_centroid.icrs.ra.value,
obj.sky_centroid.icrs.dec.value,
1, 1)
order_bounding = {}
waverange = {}
for oidx, order in enumerate(orders):
# The orders of the bounding box in the non-dispersed image
# drive the extraction extent. The location of the min and
# max wavelengths for each order are used to get the location
# of the +/- sides of the bounding box in the grism image
lmin, lmax = waverange[oidx][fselect]
# we need to be specific with dispersion direction here?
# I think this should be taken care of in the trace polys
wave_min = lmax
wave_max = lmin
if (disperse_row_right or disperse_column):
wave_min = lmin
wave_max = lmax
xmin, ymin, _, _, _ = sky_to_grism(obj.sky_bbox_ll.ra.value, obj.sky_bbox_ll.dec.value, lmin, order)
xmax, ymax, _, _, _ = sky_to_grism(obj.sky_bbox_ur.ra.value, obj.sky_bbox_ur.dec.value, lmax, order)
# convert to integer pixels, making use of python3 round to integer, 2.7 rounds to float
# if disperse_column:
# cdisp = abs(round(bxmax)-round(bxmin)) // 2
# xmin, ymin, xmax, ymax = map(round,[xmin, ymin-cdisp, xmax, ymax+cdisp])
# else:
# cdisp = abs(round(bymax)-round(bymin)) // 2
# xmin, ymin, xmax, ymax = map(round,[xmin-cdisp, ymin, xmax+cdisp, ymax])
# don't add objects and orders which are entirely off the detector
# this could also live in extract_2d
# partial_order marks partial off-detector objects which are near enough to cause
# spectra to be observed on the detector. This is usefull because the catalog often is
# created from a resampled direct image that is bigger than the detector FOV for a single
# grism exposure.
exclude = False
partial_order = False
if ((ymin < 0) or (ymax > ysize)):
partial_order = True
if ((ymin < 0) and (ymax < 0)):
exclude = True
if (ymin > ysize):
exclude = True
if ((xmin < 0) or (xmax > xsize)):
partial_order = True
if ((xmin < 0) and (xmax < 0)):
exclude = True
if (xmin > xsize):
exclude = True
if partial_order:
log.info("Partial order on detector for obj: {} order: {}".format(obj.sid, order))
if exclude:
log.info("Excluding off-image object: {}, order {}".format(obj.sid, order))
else:
order_bounding[order] = ((round(ymin), round(ymax)), (round(xmin), round(xmax)))
waverange[order] = ((lmin, lmax))
# add lmin and lmax used for the orders here?
# input_model.meta.wcsinfo.waverange_start keys covers the
# full range of all the orders
if len(order_bounding) > 0:
grism_objects.append(GrismObject(sid=obj.sid,
order_bounding=order_bounding,
sky_centroid=obj.sky_centroid,
partial_order=partial_order,
waverange=waverange,
sky_bbox_ll=obj.sky_bbox_ll,
sky_bbox_lr=obj.sky_bbox_lr,
sky_bbox_ul=obj.sky_bbox_ul,
sky_bbox_ur=obj.sky_bbox_ur,
xcentroid=xcenter,
ycentroid=ycenter))
return grism_objects
def get_num_msa_open_shutters(shutter_state):
"""
Return the number of open shutters in a slitlet.
Parameters
----------
shutter_state : str
``Slit.shutter_state`` attribute - a combination of
``1`` - open shutter, ``0`` - closed shutter, ``x`` - main shutter.
"""
num = shutter_state.count('1')
if 'x' in shutter_state:
num += 1
return num
def bounding_box_from_shape(model):
"""Create a bounding box from the shape of the data.
Note: The bounding box of a ``CubeModel`` is the bounding_box of one
of the stacked images.
"""
if isinstance(model, (CubeModel, IFUCubeModel)):
shape = model.data[0].shape
else:
shape = model.data.shape
bbox = ((-0.5, shape[1] - 0.5),
(-0.5, shape[0] - 0.5))
return bbox
def update_s_region_imaging(model):
"""
Update the ``S_REGION`` keyword using ``WCS.footprint``.
"""
bbox = model.meta.wcs.bounding_box
if bbox is None:
bbox = bounding_box_from_shape(model)
# footprint is an array of shape (2, 4) as we
# are interested only in the footprint on the sky
footprint = model.meta.wcs.footprint(bbox, center=True, axis_type="spatial").T
# take only imaging footprint
footprint = footprint[:2, :]
# Make sure RA values are all positive
negative_ind = footprint[0] < 0
if negative_ind.any():
footprint[0][negative_ind] = 360 + footprint[0][negative_ind]
footprint = footprint.T
update_s_region_keyword(model, footprint)
def update_s_region_spectral(model):
swcs = model.meta.wcs
bbox = swcs.bounding_box
if bbox is None:
bbox = bounding_box_from_shape(model)
x, y = grid_from_bounding_box(bbox)
ra, dec, lam = swcs(x, y)
footprint = np.array([[np.nanmin(ra), np.nanmin(dec)],
[np.nanmax(ra), np.nanmin(dec)],
[np.nanmax(ra), np.nanmax(dec)],
[np.nanmin(ra), np.nanmax(dec)]])
update_s_region_keyword(model, footprint)
def update_s_region_keyword(model, footprint):
""" Update the S_REGION keyword.
"""
s_region = (
"POLYGON ICRS "
" {0:.9f} {1:.9f}"
" {2:.9f} {3:.9f}"
" {4:.9f} {5:.9f}"
" {6:.9f} {7:.9f}".format(*footprint.flatten()))
if "nan" in s_region:
# do not update s_region if there are NaNs.
log.info("There are NaNs in s_region, S_REGION not updated.")
else:
model.meta.wcsinfo.s_region = s_region
log.info("Update S_REGION to {}".format(model.meta.wcsinfo.s_region))
def _nanminmax(wcsobj):
x, y = grid_from_bounding_box(wcsobj.bounding_box)
ra, dec, lam = wcsobj(x, y)
return np.nanmin(ra), np.nanmax(ra), np.nanmin(dec), np.nanmax(dec)
def update_s_region_nrs_ifu(output_model, mod):
"""
Update S_REGION for NRS_IFU observations using the instrument model.
Parameters
----------
output_model : `~jwst.datamodels.IFUImageModel`
The output of assign_wcs.
mod : module
The imported ``nirspec`` module.
"""
wcs_list = mod.nrs_ifu_wcs(output_model)
ra_total = []
dec_total = []
for wcsobj in wcs_list:
rmin, rmax, dmin, dmax = _nanminmax(wcsobj)
ra_total.append((rmin, rmax))
dec_total.append((dmin, dmax))
ra_max = np.asarray(ra_total)[:, 1].max()
ra_min = np.asarray(ra_total)[:, 0].min()
dec_max = np.asarray(dec_total)[:, 1].max()
dec_min = np.asarray(dec_total)[:, 0].min()
footprint = np.array([ra_min, dec_min, ra_max, dec_min, ra_max, dec_max, ra_min, dec_max])
update_s_region_keyword(output_model, footprint)
def update_s_region_mrs(output_model):
"""
Update S_REGION for MIRI_MRS observations using the WCS transforms.
Parameters
----------
output_model : `~jwst.datamodels.IFUImageModel`
The output of assign_wcs.
"""
rmin, rmax, dmin, dmax = _nanminmax(output_model.meta.wcs)
footprint = np.array([rmin, dmin, rmax, dmin, rmax, dmax, rmin, dmax])
update_s_region_keyword(output_model, footprint)
[docs]def velocity_correction(velosys):
"""
Compute wavelength correction to Barycentric reference frame.
Parameters
----------
velosys : float
Radial velocity wrt Barycenter [m / s].
"""
correction = (1 / (1 + velosys / c.value))
model = astmodels.Identity(1) * astmodels.Const1D(correction, name="velocity_correction")
model.inverse = astmodels.Identity(1) / astmodels.Const1D(correction, name="inv_vel_correciton")
return model