# Source code for jdaviz.configs.default.plugins.model_fitting.model_fitting

import os
import pickle
import re
import numpy as np

import astropy.modeling.models as models
import astropy.units as u
import numpy as np
from astropy.wcs import WCSSUB_SPECTRAL
from glue.core.message import (SubsetCreateMessage,
                               SubsetDeleteMessage,
                               SubsetUpdateMessage)
from specutils import Spectrum1D
from specutils.utils import QuantityModel
from traitlets import Bool, Int, List, Unicode
from glue.core.data import Data

from jdaviz.core.events import AddDataMessage, RemoveDataMessage, SnackbarMessage
from jdaviz.core.registries import tray_registry
from jdaviz.core.template_mixin import TemplateMixin
from jdaviz.utils import load_template
from .fitting_backend import fit_model_to_spectrum
from .initializers import initialize, model_parameters

__all__ = ['ModelFitting']

# Model names offered to the user, mapped to the astropy model classes that
# are instantiated for them. Insertion order is preserved, so this is also
# the order in which the names are presented.
MODELS = dict(
    Const1D=models.Const1D,
    Linear1D=models.Linear1D,
    Polynomial1D=models.Polynomial1D,
    Gaussian1D=models.Gaussian1D,
    Voigt1D=models.Voigt1D,
    Lorentz1D=models.Lorentz1D,
)


@tray_registry('g-model-fitting', label="Model Fitting")
class ModelFitting(TemplateMixin):
    """Tray plugin for fitting astropy models to spectra and spectral cubes.

    Component models are kept in ``component_models`` as plain,
    JSON-friendly dicts so they can be synced with the front-end; the
    corresponding initialized astropy model instances are kept privately in
    ``_initialized_models``. Components are combined using the user-supplied
    ``model_equation`` and fit with ``fit_model_to_spectrum``.
    """
    dialog = Bool(False).tag(sync=True)
    template = load_template("model_fitting.vue", __file__).tag(sync=True)
    dc_items = List([]).tag(sync=True)

    save_enabled = Bool(False).tag(sync=True)
    model_label = Unicode().tag(sync=True)
    model_save_path = Unicode().tag(sync=True)
    temp_name = Unicode().tag(sync=True)
    temp_model = Unicode().tag(sync=True)
    model_equation = Unicode().tag(sync=True)
    eq_error = Bool(False).tag(sync=True)
    component_models = List([]).tag(sync=True)
    display_order = Bool(False).tag(sync=True)
    poly_order = Int(0).tag(sync=True)

    available_models = List(list(MODELS.keys())).tag(sync=True)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._viewer_spectra = None
        self._spectrum1d = None
        self._units = {}
        self.n_models = 0
        self._fitted_model = None
        self._fitted_spectrum = None
        self.component_models = []
        self._initialized_models = {}
        self._display_order = False
        self.model_save_path = os.getcwd()
        self.model_label = "Model"
        self._selected_data_label = None

        self.hub.subscribe(self, AddDataMessage,
                           handler=self._on_viewer_data_changed)

        self.hub.subscribe(self, RemoveDataMessage,
                           handler=self._on_viewer_data_changed)

        # Subset messages carry no viewer id, so the handler is called
        # without the message (see ``_on_viewer_data_changed``).
        self.hub.subscribe(self, SubsetCreateMessage,
                           handler=lambda x: self._on_viewer_data_changed())

        self.hub.subscribe(self, SubsetDeleteMessage,
                           handler=lambda x: self._on_viewer_data_changed())

        self.hub.subscribe(self, SubsetUpdateMessage,
                           handler=lambda x: self._on_viewer_data_changed())

    def _on_viewer_data_changed(self, msg=None):
        """
        Callback method for when data is added or removed from a viewer, or
        when a subset is created, deleted, or updated. This method receives
        a glue message containing viewer information in the case of the
        former set of events, and updates the available data list displayed
        to the user.

        Notes
        -----
        We do not attempt to parse any data at this point, as it can cause
        visible lag in the application.

        Parameters
        ----------
        msg : `glue.core.Message`
            The glue message passed to this callback method.
        """
        self._viewer_id = self.app._viewer_item_by_reference(
            'spectrum-viewer').get('id')

        # Subsets are global and are not linked to specific viewer instances,
        # so it's not required that we match any specific ids for that case.
        # However, if the msg is not none, check to make sure that it's the
        # viewer we care about.
        if msg is not None and msg.viewer_id != self._viewer_id:
            return

        viewer = self.app.get_viewer('spectrum-viewer')

        self.dc_items = [layer_state.layer.label
                         for layer_state in viewer.state.layers]

    def _param_units(self, param, order=0):
        """Return the unit string for a model parameter.

        Units are derived from the spectral (``"x"``) and flux (``"y"``)
        units recorded in ``self._units`` when data was first selected.

        Parameters
        ----------
        param : str
            Parameter name, or ``"poly"`` for a polynomial coefficient.
        order : int, optional
            Order of the polynomial coefficient; only used when
            ``param == "poly"``.

        Returns
        -------
        str
            String representation of the parameter's unit.
        """
        y_params = ["amplitude", "amplitude_L", "intercept"]

        if param == "slope":
            return str(u.Unit(self._units["y"]) / u.Unit(self._units["x"]))
        elif param == "poly":
            return str(u.Unit(self._units["y"]) /
                       u.Unit(self._units["x"])**order)

        return self._units["y"] if param in y_params else self._units["x"]

    def _refresh_component_models(self):
        """Force traitlets to re-sync ``component_models`` to the front-end.

        The parameter dicts are mutated in place, which traitlets does not
        detect, so the list is reassigned through an empty list.
        """
        component_models = self.component_models
        self.component_models = []
        self.component_models = component_models

    def _update_parameters_from_fit(self):
        """Insert the results of the model fit into the component_models."""
        for m in self.component_models:
            name = m["id"]
            # A compound model is indexable by submodel name; a single
            # model is used directly.
            if len(self.component_models) > 1:
                m_fit = self._fitted_model[name]
            else:
                m_fit = self._fitted_model

            temp_params = []
            for param_name, value in zip(m_fit.param_names,
                                         m_fit.parameters):
                temp_param = [x for x in m["parameters"]
                              if x["name"] == param_name]
                temp_param[0]["value"] = value
                temp_params += temp_param
            m["parameters"] = temp_params

        self._refresh_component_models()

    def _update_parameters_from_QM(self):
        """
        Parse out result parameters from a QuantityModel, which isn't
        subscriptable with model name.
        """
        if hasattr(self._fitted_model, "submodel_names"):
            submodel_names = self._fitted_model.submodel_names
            submodels = True
        else:
            submodel_names = [self._fitted_model.name]
            submodels = False
        fit_params = self._fitted_model.parameters
        param_names = self._fitted_model.param_names

        for i, name in enumerate(submodel_names):
            m = [x for x in self.component_models if x["id"] == name][0]
            temp_params = []
            for idx, full_name in enumerate(param_names):
                if submodels:
                    # QuantityModel appends "_<submodel index>" to each
                    # parameter name. Split on the LAST underscore so that
                    # parameter names containing underscores (e.g.
                    # "amplitude_L") and indices >= 10 are handled.
                    base_name, _, suffix = full_name.rpartition("_")
                    if int(suffix) != i:
                        continue
                else:
                    base_name = full_name
                temp_param = [x for x in m["parameters"]
                              if x["name"] == base_name]
                temp_param[0]["value"] = fit_params[idx]
                temp_params += temp_param
            m["parameters"] = temp_params

        self._refresh_component_models()

    def _update_initialized_parameters(self):
        """Copy user-edited parameter values into the initialized models."""
        # If the user changes a parameter value, we need to change it in the
        # initialized model
        for m in self.component_models:
            name = m["id"]
            for param in m["parameters"]:
                quant_param = u.Quantity(param["value"], param["unit"])
                setattr(self._initialized_models[name], param["name"],
                        quant_param)

    def _warn_if_no_equation(self):
        """Broadcast an error and return True if no model equation is set.

        Returns
        -------
        bool
            True if the equation is empty/None (an error was broadcast),
            False otherwise.
        """
        if not self.model_equation:
            example = "+".join([m["id"] for m in self.component_models])
            snackbar_message = SnackbarMessage(
                f"Error: a model equation must be defined, e.g. {example}",
                color='error',
                sender=self)
            self.hub.broadcast(snackbar_message)
            return True
        return False

    def vue_data_selected(self, event):
        """
        Callback method for when the user has selected data from the drop
        down in the front-end. It is here that we actually parse and create
        a new data object from the selected data. From this data object,
        unit information is scraped, and the selected spectrum is stored
        for later use in fitting.

        Parameters
        ----------
        event : str
            IPyWidget callback event object. In this case, represents the
            data label of the data collection object selected by the user.
        """
        selected_spec = self.app.get_data_from_viewer("spectrum-viewer",
                                                      data_label=event)
        # Replace NaNs from collapsed SpectralCube in Cubeviz
        # (won't affect calculations because these locations are masked)
        selected_spec.flux[np.isnan(selected_spec.flux)] = 0.0
        self._selected_data_label = event

        # Units are only recorded for the first selected spectrum.
        if self._units == {}:
            self._units["x"] = str(selected_spec.spectral_axis.unit)
            self._units["y"] = str(selected_spec.flux.unit)

        self._spectrum1d = selected_spec

    def vue_model_selected(self, event):
        """Record the model type chosen in the front-end drop down."""
        self.temp_model = event
        # Only polynomials expose an order input.
        self.display_order = event == "Polynomial1D"

    def _initialize_polynomial(self, new_model):
        """Initialize a Polynomial1D of order ``poly_order`` for ``new_model``.

        Returns the ``new_model`` dict with its ``order`` and one entry per
        coefficient ``c0..cN`` appended to ``parameters``.
        """
        initialized_model = initialize(
            MODELS[self.temp_model](name=self.temp_name,
                                    degree=self.poly_order),
            self._spectrum1d.spectral_axis,
            self._spectrum1d.flux)

        self._initialized_models[self.temp_name] = initialized_model
        new_model["order"] = self.poly_order

        for i in range(self.poly_order + 1):
            param = "c{}".format(i)
            initial_val = getattr(initialized_model, param).value
            new_model["parameters"].append({"name": param,
                                            "value": initial_val,
                                            "unit": self._param_units("poly", i),
                                            "fixed": False})

        self._update_initialized_parameters()

        return new_model

    def _reinitialize_with_fixed(self):
        """
        Reinitialize all component models with current values and the
        specified parameters fixed (can't easily update fixed dictionary in
        an existing model).

        Returns
        -------
        list
            Freshly constructed astropy models, one per component.
        """
        temp_models = []
        for m in self.component_models:
            fixed = {p["name"]: p["fixed"] for p in m["parameters"]}

            # Have to initialize with fixed dictionary
            kwargs = {"name": m["id"], "fixed": fixed}
            if m["model_type"] == "Polynomial1D":
                kwargs["degree"] = m["order"]
            temp_model = MODELS[m["model_type"]](**kwargs)

            # Now we can set the parameter values
            for p in m["parameters"]:
                setattr(temp_model, p["name"], p["value"])

            temp_models.append(temp_model)

        return temp_models

    def vue_add_model(self, event):
        """Add the selected model and input string ID to the list of models."""
        new_model = {"id": self.temp_name, "model_type": self.temp_model,
                     "parameters": []}

        # Need to do things differently for polynomials, since the order
        # varies
        if self.temp_model == "Polynomial1D":
            new_model = self._initialize_polynomial(new_model)
        else:
            # Have a separate private dict with the initialized models, since
            # they don't play well with JSON for widget interaction
            initialized_model = initialize(
                MODELS[self.temp_model](name=self.temp_name),
                self._spectrum1d.spectral_axis,
                self._spectrum1d.flux)

            self._initialized_models[self.temp_name] = initialized_model

            for param in model_parameters[new_model["model_type"]]:
                initial_val = getattr(initialized_model, param).value
                new_model["parameters"].append({"name": param,
                                                "value": initial_val,
                                                "unit": self._param_units(param),
                                                "fixed": False})

        new_model["Initialized"] = True
        # Reassign (rather than append) so traitlets syncs the change.
        self.component_models = self.component_models + [new_model]

        self._update_initialized_parameters()

    def vue_remove_model(self, event):
        """Remove the component model whose id is ``event``."""
        self.component_models = [x for x in self.component_models
                                 if x["id"] != event]
        # Tolerate ids that were never initialized rather than raising.
        self._initialized_models.pop(event, None)

    def vue_save_model(self, event):
        """Pickle the last fitted model to ``<model_save_path>/<model_label>.pkl``."""
        # save_enabled is also set after registering a spectrum without a
        # fit, in which case there is nothing meaningful to save.
        if self._fitted_model is None:
            self.hub.broadcast(SnackbarMessage(
                "No fitted model to save: run a fit first",
                color='error', sender=self))
            return

        full_path = os.path.join(self.model_save_path,
                                 self.model_label + ".pkl")

        with open(full_path, 'wb') as f:
            pickle.dump(self._fitted_model, f)

    def vue_equation_changed(self, event):
        """Flag the model equation as erroneous in the front-end."""
        # Length is a dummy check to test the infrastructure. Assign the
        # result so the error clears once the equation is valid again.
        self.eq_error = len(self.model_equation) > 20

    def vue_model_fitting(self, *args, **kwargs):
        """
        Run fitting on the initialized models, fixing any parameters marked
        as such by the user, then update the displayed parameters with fit
        values.
        """
        if self._warn_if_no_equation():
            return
        models_to_fit = self._reinitialize_with_fixed()

        try:
            fitted_model, fitted_spectrum = fit_model_to_spectrum(
                self._spectrum1d,
                models_to_fit,
                self.model_equation,
                run_fitter=True)
        except AttributeError:
            msg = SnackbarMessage("Unable to fit: model equation may be invalid",
                                  color="error", sender=self)
            self.hub.broadcast(msg)
            return
        self._fitted_model = fitted_model
        self._fitted_spectrum = fitted_spectrum

        self.vue_register_spectrum({"spectrum": fitted_spectrum})
        if not hasattr(self.app, "_fitted_1d_models"):
            self.app._fitted_1d_models = {}
        self.app._fitted_1d_models[self.model_label] = fitted_model

        # Update component model parameters with fitted values
        if isinstance(self._fitted_model, QuantityModel):
            self._update_parameters_from_QM()
        else:
            self._update_parameters_from_fit()

        self.save_enabled = True

    def _next_label_index(self):
        """Return one more than the largest trailing integer of any data label.

        Labels without a trailing integer count as 0, so the result is at
        least 1.
        """
        matches = (re.search(r"(\d+)$", label)
                   for label in self.data_collection.labels)
        return max((int(mt.group(1)) for mt in matches if mt),
                   default=0) + 1

    def vue_fit_model_to_cube(self, *args, **kwargs):
        """Fit the current model equation to every spaxel of the selected cube.

        The fitted model is stored on ``self.app._fitted_3d_model`` and the
        fitted flux is added to the data collection as a new cube.
        """
        if self._warn_if_no_equation():
            return

        if self._selected_data_label is None:
            snackbar_message = SnackbarMessage(
                "No data selected: select a cube before fitting",
                color='error',
                sender=self)
            self.hub.broadcast(snackbar_message)
            return

        data = self.app.data_collection[self._selected_data_label]

        # First, ensure that the selected data is cube-like. It is possible
        # that the user has selected a pre-existing 1d data object.
        if data.ndim != 3:
            snackbar_message = SnackbarMessage(
                f"Selected data {self._selected_data_label} is not cube-like",
                color='error',
                sender=self)
            self.hub.broadcast(snackbar_message)
            return

        # Get the primary data component
        attribute = data.main_components[0]
        component = data.get_component(attribute)
        temp_values = data.get_data(attribute)

        # Transpose the axis order
        values = np.moveaxis(temp_values, 0, -1) * u.Unit(component.units)

        # We manually create a Spectrum1D object from the flux information
        # in the cube we select
        wcs = data.coords.sub([WCSSUB_SPECTRAL])
        spec = Spectrum1D(flux=values, wcs=wcs)

        # TODO: in vuetify >2.3, timeout should be set to -1 to keep open
        #  indefinitely
        snackbar_message = SnackbarMessage(
            "Fitting model to cube...",
            loading=True, timeout=0, sender=self)
        self.hub.broadcast(snackbar_message)

        # TODO: use self._reinitialize_with_fixed() here so "fixed"
        #  parameters are honoured; it was causing the parallel fitting to
        #  fail, so the initialized models are used directly for now.
        models_to_fit = self._initialized_models.values()
        fitted_model, fitted_spectrum = fit_model_to_spectrum(
            spec,
            models_to_fit,
            self.model_equation,
            run_fitter=True)

        # Save fitted 3D model in a way that the cubeviz
        # helper can access it.
        self.app._fitted_3d_model = fitted_model

        # Transpose the axis order back
        values = np.moveaxis(fitted_spectrum.flux.value, -1, 0)

        label = f"{self.model_label} [Cube] {self._next_label_index()}"

        # Create new glue data object
        output_cube = Data(label=label,
                           coords=data.coords)
        output_cube['flux'] = values
        output_cube.get_component('flux').units = \
            fitted_spectrum.flux.unit.to_string()

        # Add to data collection
        self.app.data_collection.append(output_cube)

        snackbar_message = SnackbarMessage(
            "Finished cube fitting",
            color='success',
            loading=False,
            sender=self)
        self.hub.broadcast(snackbar_message)

    def vue_register_spectrum(self, event):
        """
        Add a spectrum to the data collection based on the currently
        displayed parameters (these could be user input or fit values).

        Parameters
        ----------
        event : dict
            If it contains a ``"spectrum"`` key, that spectrum is registered;
            otherwise one is evaluated from the initialized models.
        """
        if self._warn_if_no_equation():
            return
        # Make sure the initialized models are updated with any
        # user-specified parameters
        self._update_initialized_parameters()

        # Need to run the model fitter with run_fitter=False to get spectrum
        if "spectrum" in event:
            spectrum = event["spectrum"]
        else:
            _, spectrum = fit_model_to_spectrum(
                self._spectrum1d,
                self._initialized_models.values(),
                self.model_equation)

        self.n_models += 1
        label = self.model_label
        if label in self.data_collection:
            self.app.remove_data_from_viewer('spectrum-viewer', label)
            # Remove the actual Glue data object from the data_collection
            self.data_collection.remove(self.data_collection[label])
        self.data_collection[label] = spectrum
        self.save_enabled = True
        # NOTE: the registered spectrum is not added to the spectrum viewer
        # here (a call to self.app.add_data_to_viewer was left disabled).