import os.path
import logging
import numpy as np
from astropy.modeling import models
from astropy import coordinates as coord
from astropy import units as u
from astropy.io import fits
import gwcs.coordinate_frames as cf
from gwcs import selector
from gwcs.utils import _toindex
from . import pointing
from ..transforms import models as jwmodels
from .util import (not_implemented_mode, subarray_transform,
velocity_correction)
from ..datamodels import (DistortionModel, FilteroffsetModel,
DistortionMRSModel, WavelengthrangeModel,
RegionsModel, SpecwcsModel)
log = logging.getLogger(__name__)
# NOTE(review): a library module forcing DEBUG on its own logger overrides the
# level it would otherwise inherit from the application's logging setup;
# confirm this is intended.
log.setLevel(logging.DEBUG)

# Public WCS-pipeline builders exported by this module.
__all__ = [
    "create_pipeline",
    "imaging",
    "lrs",
    "ifu",
]
def create_pipeline(input_model, reference_files):
    """
    Build the WCS pipeline for a MIRI exposure.

    The lower-cased ``meta.exposure.type`` of ``input_model`` selects the
    builder function from the module-level ``exp_type2transform`` table.

    Parameters
    ----------
    input_model : `jwst.datamodels.ImagingModel`, `~jwst.datamodels.IFUImageModel`, `~jwst.datamodels.CubeModel`
        Data model.
    reference_files : dict
        {reftype: reference file name} mapping.

    Returns
    -------
    pipeline
        Whatever the selected builder returns (a list of
        ``(frame, transform)`` pairs for implemented modes). A message is
        logged only when the result is truthy.

    Raises
    ------
    KeyError
        If the exposure type has no entry in ``exp_type2transform``.
    """
    mode = input_model.meta.exposure.type.lower()
    builder = exp_type2transform[mode]
    pipeline = builder(input_model, reference_files)
    if pipeline:
        log.info("Created a MIRI {0} pipeline with references {1}".format(
            mode, reference_files))
    return pipeline
def imaging(input_model, reference_files):
    """
    The MIRI Imaging WCS pipeline.

    It includes three coordinate frames -
    "detector", "v2v3" and "world".

    Parameters
    ----------
    input_model : `jwst.datamodels.ImagingModel`
        Data model. If the distortion model has no bounding box, a default
        one is derived from the last two axes of ``input_model.data``, so
        both 2-D images and cube-shaped ``(nints, ny, nx)`` data work.
    reference_files : dict
        Dictionary {reftype: reference file name}.
        Uses "distortion" and "filteroffset" reference files.

    Returns
    -------
    pipeline : list
        ``[(detector, distortion), (v2v3, tel2sky), (world, None)]``.
    """
    # Create the Frames
    detector = cf.Frame2D(name='detector', axes_order=(0, 1), unit=(u.pix, u.pix))
    v2v3 = cf.Frame2D(name='v2v3', axes_order=(0, 1), unit=(u.arcsec, u.arcsec))
    world = cf.CelestialFrame(reference_frame=coord.ICRS(), name='world')

    # Create the transforms: subarray -> full-frame pixels, then the imager
    # distortion (detector -> v2v3), then v2v3 -> sky.
    subarray2full = subarray_transform(input_model)
    imdistortion = imaging_distortion(input_model, reference_files)
    distortion = subarray2full | imdistortion
    tel2sky = pointing.v23tosky(input_model)

    # TODO: remove setting the bounding box when it is set in the new ref file.
    try:
        bb = distortion.bounding_box
    except NotImplementedError:
        # Take the image-plane (last two) axes so cube data is handled too;
        # indexing shape[0]/shape[1] would pick (nints, ny) for a cube.
        ny, nx = input_model.data.shape[-2:]
        # Bounding boxes attached to models are in reverse (y, x) order.
        # The x range leaves out 4 columns on each side of the detector
        # (presumably the reference pixels).
        bb = ((-0.5, ny - 0.5), (3.5, nx - 4.5))
    distortion.bounding_box = bb

    # Create the pipeline
    pipeline = [(detector, distortion),
                (v2v3, tel2sky),
                (world, None)
                ]
    return pipeline
def imaging_distortion(input_model, reference_files):
    """
    Create the "detector" to "v2v3" transform for the MIRI Imager.

    The result is the model stored in the "distortion" reference file. When
    the "filteroffset" reference file has an entry for the observed filter
    with both offsets set, that model is preceded by a
    ``Shift(column_offset) & Shift(row_offset)``.

    Documented processing chain (NOTE(review): steps 2-5 are assumed to be
    encoded in the "distortion" reference model; this function only loads
    it):

    1. Filter dependent shift in (x, y) (uses the "filteroffset" ref file).
       The original notes say this is applied with the opposite sign to that
       delivered by the IT; the code applies the stored offsets unchanged.
    2. Apply MI (uses "distortion" ref file).
    3. Apply Ai and Bi matrices (uses "distortion" ref file).
    4. Apply the Ti matrix, giving Xan/Yan coordinates (uses "distortion" ref file).
    5. Apply the XanYan --> V2V3 transform (uses "distortion" ref file).

    The V2V3 --> sky step is not part of this transform; ``imaging`` adds it
    separately.

    Parameters
    ----------
    input_model : `jwst.datamodels.ImagingModel`
        Data model; ``meta.instrument.filter`` selects the filter offset.
    reference_files : dict
        Dictionary {reftype: reference file name}; uses "distortion" and
        "filteroffset".

    Returns
    -------
    transform : `~astropy.modeling.Model`
        The detector -> v2v3 model.
    """
    with DistortionModel(reference_files['distortion']) as dist_ref:
        transform = dist_ref.model

    obsfilter = input_model.meta.instrument.filter

    with FilteroffsetModel(reference_files['filteroffset']) as offsets_ref:
        # First entry whose name matches the observed filter, if any.
        match = next((entry for entry in offsets_ref.filters
                      if entry.name == obsfilter), None)

    if match is None:
        return transform

    col_offset = match.column_offset
    row_offset = match.row_offset
    if col_offset is None or row_offset is None:
        return transform

    return models.Shift(col_offset) & models.Shift(row_offset) | transform
def lrs(input_model, reference_files):
    """
    The LRS-FIXEDSLIT and LRS-SLITLESS WCS pipeline.

    It has two coordinate frames: "detector" and "world".
    Uses the "specwcs" and "distortion" reference files.

    Parameters
    ----------
    input_model : `~jwst.datamodels.DataModel`
        Data model whose ``meta.exposure.type`` is ``mir_lrs-fixedslit`` or
        ``mir_lrs-slitless``. ``data`` may be 2-D or cube-shaped; only its
        last axis (columns) is used.
    reference_files : dict
        Dictionary {reftype: reference file name}.

    Returns
    -------
    pipeline : list
        ``[(detector, det2world), (world, None)]``.

    Raises
    ------
    ValueError
        If the exposure type is not one of the two LRS modes.
    """
    # Setup the frames.
    detector = cf.Frame2D(name='detector', axes_order=(0, 1), unit=(u.pix, u.pix))
    spec = cf.SpectralFrame(name='wavelength', axes_order=(2,), unit=(u.micron,),
                            axes_names=('lambda',))
    sky = cf.CelestialFrame(reference_frame=coord.ICRS(), name='sky')
    world = cf.CompositeFrame(name="world", frames=[sky, spec])

    # Determine the distortion model.
    subarray2full = subarray_transform(input_model)
    with DistortionModel(reference_files['distortion']) as dist:
        distortion = dist.model
    full_distortion = subarray2full | distortion

    exp_type = input_model.meta.exposure.type.lower()

    # Load and process the reference data.
    with fits.open(reference_files['specwcs']) as ref:
        lrsdata = np.array(list(ref[1].data))
        # Get the zero point from the reference data.
        # The zero_point is X, Y (which should be COLUMN, ROW)
        # TODO: Are imx, imy 0- or 1-indexed? We are treating them here as
        # 0-indexed. Since they are FITS, they are probably 1-indexed.
        if exp_type == 'mir_lrs-fixedslit':
            zero_point = ref[1].header['imx'], ref[1].header['imy']
        elif exp_type == 'mir_lrs-slitless':
            # TODO: the header keywords imxsltl/imysltl are not used; a
            # hard-coded zero point (already accounting for the subarray)
            # is used instead.
            zero_point = [35, 442]
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on ``zero_point``.
            raise ValueError("Unsupported MIRI LRS exposure type: {0}".format(exp_type))

    # Create the bounding_box, in (x, y) order here; reversed when attached.
    x0 = lrsdata[:, 3]
    y0 = lrsdata[:, 4]
    x1 = lrsdata[:, 5]
    bb = ((x0.min() - 0.5 + zero_point[0], x1.max() + 0.5 + zero_point[0]),
          (y0.min() - 0.5 + zero_point[1], y0.max() + 0.5 + zero_point[1]))

    # Find the ROW of the zero point which should be the [1] of zero_point
    row_zero_point = zero_point[1]

    # Compute the v2v3 to sky.
    tel2sky = pointing.v23tosky(input_model)

    # Compute RA/Dec for each pixel in the zero-point row. The last data
    # axis is the column axis for both 2-D and cube (nints, ny, nx) data.
    # x.shape will be something like (1, ncols).
    y, x = np.mgrid[row_zero_point:row_zero_point + 1, 0:input_model.data.shape[-1]]
    spatial_transform = full_distortion | tel2sky
    radec = np.array(spatial_transform(x, y))[:, 0, :]

    # Replicate that row over every row of the bounding box.
    # np.tile gives the same (nrows, ncols) array as numpy.matlib.repmat did,
    # without depending on the (not auto-imported, pending-deprecation)
    # numpy.matlib module.
    nrows = int(_toindex(bb[1][1]) + 1 - _toindex(bb[1][0]))
    ra_full = np.tile(radec[0], (nrows, 1))
    dec_full = np.tile(radec[1], (nrows, 1))

    ra_t2d = models.Tabular2D(lookup_table=ra_full, name='xtable',
                              bounds_error=False, fill_value=np.nan)
    dec_t2d = models.Tabular2D(lookup_table=dec_full, name='ytable',
                               bounds_error=False, fill_value=np.nan)

    # Create the model transforms.
    lrs_wav_model = jwmodels.LRSWavelength(lrsdata, zero_point)

    try:
        velosys = input_model.meta.wcsinfo.velosys
    except AttributeError:
        pass
    else:
        if velosys is not None:
            velocity_corr = velocity_correction(velosys)
            lrs_wav_model = lrs_wav_model | velocity_corr
            log.info("Applied Barycentric velocity correction : {}".format(velocity_corr[1].amplitude.value))

    # Incorporate the small rotation
    angle = np.arctan(0.00421924)
    rot = models.Rotation2D(angle)
    radec_t2d = ra_t2d & dec_t2d | rot

    # Account for the subarray when computing spatial coordinates: the
    # lookup tables are indexed relative to the bounding-box origin.
    xshift = -bb[0][0]
    yshift = -bb[1][0]
    # (x, y) -> (y, x, y, x, x, y): the first two shifted pairs feed the
    # RA and Dec tables, the trailing (x, y) feeds the wavelength model.
    det2world = (models.Mapping((1, 0, 1, 0, 0, 1)) |
                 models.Shift(yshift, name='yshift1') &
                 models.Shift(xshift, name='xshift1') &
                 models.Shift(yshift, name='yshift2') &
                 models.Shift(xshift, name='xshift2') &
                 models.Identity(2) |
                 radec_t2d & lrs_wav_model)
    # Models take the bounding box in (y, x) order.
    det2world.bounding_box = bb[::-1]

    # Now the actual pipeline.
    pipeline = [(detector, det2world),
                (world, None)
                ]
    return pipeline
def ifu(input_model, reference_files):
    """
    The MIRI MRS WCS pipeline.

    It has the following coordinate frames:
    "detector", "alpha_beta", "v2v3", "world".

    It uses the "distortion", "regions", "specwcs"
    and "wavelengthrange" reference files.

    Parameters
    ----------
    input_model : `~jwst.datamodels.IFUImageModel`
        Data model. The detector bounding box is taken from the last two
        axes of ``data``, so 2-D and cube-shaped data are both handled.
    reference_files : dict
        Dictionary {reftype: reference file name}.

    Returns
    -------
    pipeline : list
        ``[(detector, det2abl), (miri_focal, abl2v2v3l), (v2v3, tel2sky),
        (world, None)]``.
    """
    # Define coordinate frames.
    detector = cf.Frame2D(name='detector', axes_order=(0, 1), unit=(u.pix, u.pix))
    alpha_beta = cf.Frame2D(name='alpha_beta_spatial', axes_order=(0, 1),
                            unit=(u.arcsec, u.arcsec), axes_names=('alpha', 'beta'))
    spec_local = cf.SpectralFrame(name='alpha_beta_spectral', axes_order=(2,),
                                  unit=(u.micron,), axes_names=('lambda',))
    miri_focal = cf.CompositeFrame([alpha_beta, spec_local], name='alpha_beta')
    v23_spatial = cf.Frame2D(name='V2_V3_spatial', axes_order=(0, 1),
                             unit=(u.arcsec, u.arcsec), axes_names=('v2', 'v3'))
    spec = cf.SpectralFrame(name='spectral', axes_order=(2,), unit=(u.micron,),
                            axes_names=('lambda',))
    v2v3 = cf.CompositeFrame([v23_spatial, spec], name='v2v3')
    icrs = cf.CelestialFrame(name='icrs', reference_frame=coord.ICRS(),
                             axes_order=(0, 1), unit=(u.deg, u.deg),
                             axes_names=('RA', 'DEC'))
    world = cf.CompositeFrame([icrs, spec], name='world')

    # Define the actual transforms
    det2abl = detector_to_abl(input_model, reference_files).rename("detector_to_abl")
    abl2v2v3l = abl_to_v2v3l(input_model, reference_files).rename("abl_to_v2v3l")
    # The wavelength passes through the sky transform unchanged.
    tel2sky = pointing.v23tosky(input_model) & models.Identity(1)

    # Bounding box in (y, x) order over the image-plane (last two) axes;
    # indexing shape[0]/shape[1] would pick (nints, ny) for cube data.
    ny, nx = input_model.data.shape[-2:]
    det2abl.bounding_box = ((-0.5, ny - 0.5), (-0.5, nx - 0.5))

    # Put the transforms together into a single pipeline
    pipeline = [(detector, det2abl),
                (miri_focal, abl2v2v3l),
                (v2v3, tel2sky),
                (world, None)]
    return pipeline
def detector_to_abl(input_model, reference_files):
    """
    Create the transform from "detector" to "alpha_beta" frame.

    Transform description:
    forward transform
      RegionsSelector with inputs ('x', 'y') and outputs ('alpha', 'beta', 'lam')
        label_mapper is a LabelMapperArray built from the "regions" array
        selector is {slice_number: Mapping([1, 0, 0, 1, 0]) |
                     alpha_model & beta_model & lambda_model}
    backward transform
      label_mapper.inverse is a LabelMapperRange keyed on 'lam':
        {channel_wave_range (): LabelMapper}
        where each LabelMapper maps 'beta' to a slice number via
        MIRI_AB2Slice(beta_zero, delta_beta, channel)
      each selector model carries an inverse
        {slice_number: Mapping([2, 0, 2, 0]) | x_model & y_model}

    Parameters
    ----------
    input_model : `~jwst.datamodels.IFUImageModel`
        Data model. ``meta.instrument.channel`` is iterated and each item is
        concatenated with ``meta.instrument.band`` to form the keys used to
        look up per-channel reference data. ``meta.wcsinfo.velosys``, when
        present and not None, adds a velocity correction.
    reference_files : dict
        Dictionary {reftype: reference file name}; uses "distortion",
        "specwcs", "regions" and "wavelengthrange".

    Returns
    -------
    det2alpha_beta : `gwcs.selector.RegionsSelector`
        The detector -> (alpha, beta, lam) transform.
    """
    band = input_model.meta.instrument.band
    channel = input_model.meta.instrument.channel
    # band/channel form the "<channel><band>" keys, e.g. for the wavelength range
    with DistortionMRSModel(reference_files['distortion']) as dist:
        # Per-slice models are indexed in the same order as dist.slices.
        alpha_model = dist.alpha_model
        beta_model = dist.beta_model
        x_model = dist.x_model
        y_model = dist.y_model
        # {channel+band: beta zero point} and {channel+band: beta step}
        bzero = dict(zip(dist.bzero.channel_band, dist.bzero.beta_zero))
        bdel = dict(zip(dist.bdel.channel_band, dist.bdel.delta_beta))
        slices = dist.slices
    with SpecwcsModel(reference_files['specwcs']) as f:
        lambda_model = f.model
    # Optionally append the barycentric velocity correction to every
    # per-slice wavelength model.
    try:
        velosys = input_model.meta.wcsinfo.velosys
    except AttributeError:
        pass
    else:
        if velosys is not None:
            velocity_corr = velocity_correction(input_model.meta.wcsinfo.velosys)
            lambda_model = [m | velocity_corr for m in lambda_model]
            log.info("Applied Barycentric velocity correction : {}".format(velocity_corr[1].amplitude.value))
    with RegionsModel(reference_files['regions']) as f:
        # Copy so the array stays valid after the model is closed.
        regions = f.regions.copy()
    # Forward direction: the regions array maps each (x, y) pixel to a label,
    # presumably the slice number used as a key in ``transforms`` - confirm.
    label_mapper = selector.LabelMapperArray(regions)
    transforms = {}
    for i, sl in enumerate(slices):
        # (x, y) -> (y, x, x, y, x), fed to alpha & beta & lambda models.
        # NOTE(review): assumes alpha and beta take 2 inputs and lambda 1.
        forward = models.Mapping([1, 0, 0, 1, 0]) | \
            alpha_model[i] & beta_model[i] & lambda_model[i]
        # (alpha, beta, lam) -> (lam, alpha, lam, alpha) -> x_model & y_model
        inv = models.Mapping([2, 0, 2, 0]) | x_model[i] & y_model[i]
        forward.inverse = inv
        transforms[sl] = forward
    with WavelengthrangeModel(reference_files['wavelengthrange']) as f:
        # {channel+band: wavelength range}
        wr = dict(zip(f.waverange_selector, f.wavelengthrange))
    # Backward direction: pick the channel by its wavelength range, then map
    # beta (input index 1) to a slice number within that channel.
    ch_dict = {}
    for c in channel:
        cb = c + band
        mapper = jwmodels.MIRI_AB2Slice(bzero[cb], bdel[cb], c)
        lm = selector.LabelMapper(inputs=('alpha', 'beta', 'lam'),
                                  mapper=mapper, inputs_mapping=models.Mapping((1,), n_inputs=3))
        ch_dict[tuple(wr[cb])] = lm
    # Range lookup is done on 'lam' (input index 2).
    alpha_beta_mapper = selector.LabelMapperRange(('alpha', 'beta', 'lam'), ch_dict,
                                                  models.Mapping((2,)))
    label_mapper.inverse = alpha_beta_mapper
    det2alpha_beta = selector.RegionsSelector(('x', 'y'), ('alpha', 'beta', 'lam'),
                                              label_mapper=label_mapper, selector=transforms)
    return det2alpha_beta
def abl_to_v2v3l(input_model, reference_files):
    """
    Create the transform from "alpha_beta" to "v2v3" frame.

    Transform description:
    forward transform
      RegionsSelector with inputs ('alpha', 'beta', 'lam') and outputs
      ('v2', 'v3', 'lam')
        label_mapper is a LabelMapperRange keyed on 'lam':
          {channel_wave_range (): Mapping((2,)) | Const1D(channel_number)}
        selector is {channel_number: abv2v3 model & Identity(1)}
    backward transform
      label_mapper.inverse is a copy of the same LabelMapperRange; 'lam'
      passes through unchanged, so it is still input index 2.
      Each selector model keeps the inverse of its abv2v3 reference model.

    Parameters
    ----------
    input_model : `~jwst.datamodels.IFUImageModel`
        Data model. Each item of ``meta.instrument.channel`` concatenated
        with ``meta.instrument.band`` forms a "<channel><band>" key; the
        first character of that key is parsed as the integer channel number.
    reference_files : dict
        Dictionary {reftype: reference file name}; uses "distortion" and
        "wavelengthrange".

    Returns
    -------
    abl2v2v3l : `gwcs.selector.RegionsSelector`
        The (alpha, beta, lam) -> (v2, v3, lam) transform.
    """
    band = input_model.meta.instrument.band
    channel = input_model.meta.instrument.channel
    # "<channel><band>" keys, used to read the wavelength range and models
    channels = [c + band for c in channel]
    with DistortionMRSModel(reference_files['distortion']) as dist:
        # {channel+band: (alpha, beta) -> (v2, v3) model}
        v23 = dict(zip(dist.abv2v3_model.channel_band, dist.abv2v3_model.model))
    with WavelengthrangeModel(reference_files['wavelengthrange']) as f:
        # {channel+band: wavelength range}
        wr = dict(zip(f.waverange_selector, f.wavelengthrange))
    dict_mapper = {}
    sel = {}
    # Loop over every channel in this exposure (a reference file covers
    # more than one channel).
    for c in channels:
        ch = int(c[0])
        # Any lam inside this channel's wavelength range maps to label ``ch``.
        dict_mapper[tuple(wr[c])] = models.Mapping((2,), name="mapping_lam") | \
            models.Const1D(ch, name="channel #")
        ident1 = models.Identity(1, name='identity_lam')
        # Rename the identity's input (private attribute) to 'lam'.
        ident1._inputs = ('lam',)
        chan_v23 = v23[c]
        # Detach and re-attach the existing inverse.
        # NOTE(review): the net effect appears to be turning the inverse into
        # an explicitly user-set one; confirm why this is needed. The model
        # object from ``v23`` is modified in place (no copy is made).
        v23chan_backward = chan_v23.inverse
        del chan_v23.inverse
        v23_spatial = chan_v23
        v23_spatial.inverse = v23chan_backward
        # Tack on passing the third wavelength component
        v23c = v23_spatial & ident1
        sel[ch] = v23c
    wave_range_mapper = selector.LabelMapperRange(('alpha', 'beta', 'lam'), dict_mapper,
                                                  inputs_mapping=models.Mapping([2, ]))
    # The same wavelength-based selection applies in the inverse direction.
    wave_range_mapper.inverse = wave_range_mapper.copy()
    abl2v2v3l = selector.RegionsSelector(('alpha', 'beta', 'lam'), ('v2', 'v3', 'lam'),
                                         label_mapper=wave_range_mapper,
                                         selector=sel)
    return abl2v2v3l
# Lower-cased MIRI EXP_TYPE -> function building its WCS pipeline.
# Entries mapped to ``not_implemented_mode`` (from ``.util``) are modes this
# module does not build a WCS for.
exp_type2transform = {
    # Imaging-like modes all use the imager distortion.
    **dict.fromkeys(('mir_image', 'mir_tacq', 'mir_lyot', 'mir_4qpm',
                     'mir_coroncal'), imaging),
    # Low Resolution Spectrometer.
    **dict.fromkeys(('mir_lrs-fixedslit', 'mir_lrs-slitless'), lrs),
    # Medium Resolution Spectrometer (IFU).
    'mir_mrs': ifu,
    **dict.fromkeys(('mir_flatmrs', 'mir_flatimage', 'mir_flat-mrs',
                     'mir_flat-image', 'mir_dark'), not_implemented_mode),
}
def get_wavelength_range(input_model, path=None):
    """
    Return the wavelength range used for computing the WCS.

    Needs access to the "wavelengthrange" reference file recorded in
    ``input_model.meta.ref_file.wavelengthrange.name``; only the part after
    the last '/' of that name is used.

    Parameters
    ----------
    input_model : `~jwst.datamodels.IFUImageModel`
        Data model after assign_wcs has been run.
    path : str, optional
        Directory where the reference file is. If None, the file must be in
        the current working directory.

    Returns
    -------
    wave_range : dict
        {channel + band: wavelength range} for every item of
        ``meta.instrument.channel``.

    Raises
    ------
    IOError
        If ``path`` is None and the file is not in the current directory.
    """
    fname = input_model.meta.ref_file.wavelengthrange.name.split('/')[-1]
    if path is None:
        if not os.path.exists(fname):
            raise IOError("Reference file {0} not found. Please specify a path.".format(fname))
        # Otherwise use the file found in the current directory as-is;
        # previously this case fell into os.path.join(None, ...) -> TypeError.
    else:
        fname = os.path.join(path, fname)

    # Context manager so the file is closed even if reading fails.
    with WavelengthrangeModel(fname) as f:
        wave_range = f.tree['wavelengthrange'].copy()
        wave_channels = f.tree['channels']

    wr = dict(zip(wave_channels, wave_range))
    channel = input_model.meta.instrument.channel
    band = input_model.meta.instrument.band
    return {ch + band: wr[ch + band] for ch in channel}