# Source code for jwst.datamodels.container

import copy
from collections import OrderedDict
import os.path as op

from asdf import AsdfFile

from ..associations import (
    AssociationError,
    AssociationNotValidError, load_asn)
from . import model_base
from .util import open as datamodel_open

import logging
# Module-level logger, named after this module.
# NOTE(review): pinning a library logger's level to DEBUG overrides whatever
# level the application configured for this logger's ancestors; consider
# leaving the level unset and letting the application decide.
log = logging.getLogger(name=__name__)
log.setLevel("DEBUG")

# Only the container class is part of the public API of this module.
__all__ = ['ModelContainer']


class ModelContainer(model_base.DataModel):
    """
    A container for holding DataModels.

    This functions like a list for holding DataModel objects.  It can be
    iterated through like a list, DataModels within the container can be
    addressed by index, and the datamodels can be grouped into a list of
    lists for grouped looping, useful for NIRCam where grouping together
    all detectors of a given exposure is useful for some pipeline steps.

    Parameters
    ----------
    init : file path, list of DataModels, ModelContainer, or None

        - file path: initialize from an association table

        - list: a list of DataModels of any type.  File names are also
          accepted; they are opened lazily when the entry is accessed.

        - ModelContainer: copy the metadata of another container.  The new
          container gets its own member list that refers to the same
          members.

        - None: initializes an empty `ModelContainer` instance, to which
          DataModels can be added via the ``append()`` method.

    persist : bool
        If True, do not close a model after opening it.  If False, members
        that were given as file names are closed again (and replaced by
        their file name) when iteration moves past them.

    Examples
    --------
    >>> container = datamodels.ModelContainer('example_asn.json')
    >>> for dm in container:
    ...     print(dm.meta.filename)

    Say the association was a NIRCam dithered dataset.  The `models_grouped`
    attribute is a list of lists, the first index giving the list of
    exposure groups, with the second giving the individual datamodels
    representing each detector in the exposure (2 or 8 in the case of
    NIRCam).

    >>> total_exposure_time = 0.0
    >>> for group in container.models_grouped:
    ...     total_exposure_time += group[0].meta.exposure.exposure_time

    >>> c = datamodels.ModelContainer()
    >>> m = datamodels.open('myfile.fits')
    >>> c.append(m)
    """

    # This schema merely extends the 'meta' part of the datamodel, and
    # does not describe the data contents of the container.
    schema_url = "container.schema.yaml"

    def __init__(self, init=None, persist=True, **kwargs):
        # The container itself never loads ``init`` through the base class;
        # members are handled below.
        super(ModelContainer, self).__init__(init=None, **kwargs)
        self._persist = persist

        if init is None:
            self._models = []
        elif isinstance(init, list):
            self._validate_model(init)
            self._models = init[:]
        elif isinstance(init, self.__class__):
            instance = copy.deepcopy(init._instance)
            self._schema = init._schema
            self._shape = init._shape
            self._asdf = AsdfFile(instance, extensions=self._extensions)
            self._instance = instance
            self._ctx = self
            self.__class__ = init.__class__
            # Copy the member list (as the list branch does) so adding or
            # removing members here does not silently mutate ``init``.
            self._models = init._models[:]
        elif isinstance(init, str):
            try:
                self.from_asn(init, **kwargs)
            except IOError as err:
                raise IOError('Cannot open files.') from err
            except AssociationError as err:
                raise AssociationError(
                    '{0} must be an ASN file'.format(init)
                ) from err
        else:
            raise TypeError('Input {0!r} is not a list of DataModels or '
                            'an ASN file'.format(init))

    def _open_model(self, index):
        """
        Return the member at ``index``, opening it first if it is a file name.

        An opened model replaces its file name in the member list, so later
        accesses return the same object.
        """
        model = self._models[index]
        if isinstance(model, str):
            model = datamodel_open(
                model,
                extensions=self._extensions,
                pass_invalid_values=self._pass_invalid_values
            )
            self._models[index] = model
        return model

    def _close_model(self, filename, index):
        """
        Close the member at ``index`` and store ``filename`` in its place.

        Does nothing when the container was created with ``persist=True``.
        """
        if not self._persist:
            self._models[index].close()
            self._models[index] = filename

    def _validate_model(self, models):
        """
        Check that ``models`` (a single entry or a list) may be stored.

        Raises
        ------
        ValueError
            If an entry is a ModelContainer, or is neither a str nor a
            DataModel.
        """
        if not isinstance(models, list):
            models = [models]
        for model in models:
            if isinstance(model, ModelContainer):
                raise ValueError(
                    "ModelContainer cannot contain ModelContainer"
                )
            if not isinstance(model, (str, model_base.DataModel)):
                raise ValueError('model must be string or DataModel')

    def __len__(self):
        return len(self._models)

    def __getitem__(self, index):
        # Indexing always yields an opened model, never a file name.
        return self._open_model(index)

    def __setitem__(self, index, model):
        self._validate_model(model)
        self._models[index] = model

    def __delitem__(self, index):
        del self._models[index]

    def __iter__(self):
        return ModelContainerIterator(self)

    def insert(self, index, model):
        """Insert ``model`` (a DataModel or file name) before ``index``."""
        self._validate_model(model)
        self._models.insert(index, model)

    def append(self, model):
        """Append ``model`` (a DataModel or file name) to the container."""
        self._validate_model(model)
        self._models.append(model)

    def extend(self, models):
        """Append every entry of the list ``models`` to the container."""
        self._validate_model(models)
        self._models.extend(models)

    def pop(self, index=-1):
        """Remove and return the member at ``index``, opened as a model."""
        # Open first so the caller never receives a bare file name.
        self._open_model(index)
        return self._models.pop(index)

    def copy(self, memo=None):
        """
        Returns a deep copy of the models in this model container.

        The container metadata is deep-copied.  DataModel members are copied
        with their own ``copy()`` method; members still stored as file names
        are carried over as file names.
        """
        result = self.__class__(init=None,
                                extensions=self._extensions,
                                pass_invalid_values=self._pass_invalid_values,
                                strict_validation=self._strict_validation)
        instance = copy.deepcopy(self._instance, memo=memo)
        result._asdf = AsdfFile(instance, extensions=self._extensions)
        result._instance = instance
        result._iscopy = self._iscopy
        # Carry over this container's schema along with its metadata.
        result._schema = self._schema
        result._ctx = result
        # The copy must close (or keep) members the same way the original
        # does.
        result._persist = self._persist
        for m in self._models:
            if isinstance(m, model_base.DataModel):
                result.append(m.copy())
            else:
                result.append(m)
        return result

    def from_asn(self, filepath, **kwargs):
        """
        Load fits files from a JWST association file.

        Member file names are taken from the first product of the
        association, resolved relative to the association file's directory,
        and stored unopened.  The whole association is also merged into
        ``meta.asn_table``.

        Parameters
        ----------
        filepath : str
            The path to an association file.

        Raises
        ------
        IOError
            If the association file is not a valid association.
        """
        filepath = op.abspath(op.expanduser(op.expandvars(filepath)))
        basedir = op.dirname(filepath)
        filename = op.basename(filepath)
        try:
            with open(filepath) as asn_file:
                asn_data = load_asn(asn_file)
        except AssociationNotValidError as err:
            raise IOError("Cannot read ASN file.") from err

        # make a list of all the input files
        infiles = [op.join(basedir, member['expname']) for member
                   in asn_data['products'][0]['members']]
        self._models = infiles

        # Pull the whole association table into meta.asn_table
        self.meta.asn_table = {}
        model_base.properties.merge_tree(
            self.meta.asn_table._instance, asn_data
        )

        self.meta.resample.output = asn_data['products'][0]['name']
        self.meta.table_name = filename
        self.meta.pool_name = asn_data['asn_pool']

    def save(self,
             path=None,
             dir_path=None,
             save_model_func=None,
             *args,
             **kwargs):
        """
        Write out models in container to FITS or ASDF.

        Parameters
        ----------
        path : str or func or None
            - If None, the `meta.filename` is used for each model.
            - If a string, an index is appended to each model's
              `meta.filename` (see `make_file_with_index`).
              NOTE(review): the string value itself is currently not used
              as a root name; confirm whether that is intended.
            - If a function, the function takes the two arguments:
              the value of model.meta.filename and the
              `idx` index, returning constructed file name.
            When the container holds at most one model, `idx` is None.

        dir_path : str
            Directory to write out files.  When not given, the directory
            part of the name produced by ``path`` is used (the current
            working directory for a bare file name).  Creating a missing
            directory is not done here; presumably ``model.save`` handles
            it (verify).  Ignored when ``save_model_func`` is given.

        save_model_func: func or None
            Alternate function to save each model instead of
            the model's `save` method. Takes one argument, the model,
            and keyword argument `idx` for an index.

        Returns
        -------
        output_paths: [str[, ...]]
            List of output file paths, as returned by each ``model.save``
            or ``save_model_func`` call.
        """
        if path is None:
            def path(filename, idx=None):
                return filename
        elif not callable(path):
            path = make_file_with_index

        # A lone model is saved without an index suffix.
        single_model = len(self) <= 1
        output_paths = []
        for idx, model in enumerate(self):
            if single_model:
                idx = None
            if save_model_func is None:
                outpath, filename = op.split(
                    path(model.meta.filename, idx=idx)
                )
                if dir_path:
                    outpath = dir_path
                save_path = op.join(outpath, filename)
                output_paths.append(model.save(save_path, *args, **kwargs))
            else:
                output_paths.append(save_model_func(model, idx=idx))
        return output_paths

    @property
    def models_grouped(self):
        """
        Returns a list of lists of datamodels grouped by exposure.

        Data from different detectors of the same exposure will have the
        same group id, which allows grouping by exposure.  The following
        metadata is used for grouping:

        meta.observation.program_number
        meta.observation.observation_number
        meta.observation.visit_number
        meta.observation.visit_group
        meta.observation.sequence_id
        meta.observation.activity_id
        meta.observation.exposure_number

        If the group id cannot be built (for example a field is missing),
        a warning is logged and the model is given the id
        ``exposureNNNN`` from its 1-based position in the container.

        Every member is opened, and ``meta.group_id`` is set on each one.
        Groups are returned in order of first appearance.
        """
        unique_exposure_parameters = [
            'program_number',
            'observation_number',
            'visit_number',
            'visit_group',
            'sequence_id',
            'activity_id',
            'exposure_number'
        ]
        group_dict = OrderedDict()
        for i in range(len(self)):
            model = self._open_model(i)
            params = [getattr(model.meta.observation, param)
                      for param in unique_exposure_parameters]
            try:
                # TypeError here means some field is not a string
                # (e.g. None).
                group_id = ('jw' + '_'.join([''.join(params[:3]),
                                             ''.join(params[3:6]),
                                             params[6]]))
                model.meta.group_id = group_id
            except TypeError:
                params_dict = dict(zip(unique_exposure_parameters, params))
                bad_params = {'meta.observation.' + k: v
                              for k, v in params_dict.items() if not v}
                log.warning(
                    'Cannot determine grouping of exposures: %s', bad_params
                )
                model.meta.group_id = 'exposure{0:04d}'.format(i + 1)

            group_id = model.meta.group_id
            group_dict.setdefault(group_id, []).append(model)

        return list(group_dict.values())

    @property
    def group_names(self):
        """
        Return list of names for the DataModel groups by exposure.
        """
        return [group[0].meta.group_id for group in self.models_grouped]

    def __get_recursively(self, field, search_dict):
        """
        Takes a dict with nested lists and dicts,
        and searches all dicts for a key of the field
        provided.

        Values of matching keys are not searched further; dicts nested
        inside lists are searched, other list items are skipped.
        """
        values_found = []
        for key, value in search_dict.items():
            if key == field:
                values_found.append(value)
            elif isinstance(value, dict):
                values_found.extend(self.__get_recursively(field, value))
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, dict):
                        values_found.extend(
                            self.__get_recursively(field, item)
                        )
        return values_found

    def get_recursively(self, field):
        """
        Returns a list of values of the specified field from meta.
        """
        return self.__get_recursively(field, self.meta._instance)
class ModelContainerIterator:
    """
    Iterate over a model container, opening one member at a time.

    Members stored as file names are opened on demand through the
    container's ``_open_model``.  When the iterator advances, a member it
    opened on the previous step is handed back to the container's
    ``_close_model`` together with that file name.
    """

    def __init__(self, container):
        self.container = container
        # Position of the most recently returned member (-1 = not started).
        self.index = -1
        # File name of the member opened on the previous step, if any.
        self.open_filename = None

    def __iter__(self):
        return self

    def __next__(self):
        previous = self.open_filename
        if previous is not None:
            self.container._close_model(previous, self.index)
            self.open_filename = None

        self.index += 1
        if self.index >= len(self.container._models):
            raise StopIteration

        entry = self.container._models[self.index]
        if not isinstance(entry, str):
            # Already a model object; nothing to open or close later.
            return entry

        opened = self.container._open_model(self.index)
        self.open_filename = entry
        return opened


# #########
# Utilities
# #########
def make_file_with_index(file_path, idx):
    """Append an index to a filename

    The index goes between the base name and the extension, for example
    ``dir/name.fits`` with ``idx=3`` becomes ``dir/name3.fits``.

    Parameters
    ----------
    file_path: str
        The file to append the index to.
    idx: int or None
        An index to append; None leaves the name unchanged.

    Returns
    -------
    file_path: str
        Path with index appended
    """
    directory, name = op.split(file_path)
    stem, extension = op.splitext(name)
    suffix = '' if idx is None else str(idx)
    return op.join(directory, stem + suffix + extension)