import os
import pickle
import re
import astropy.modeling.models as models
import astropy.units as u
import numpy as np
from astropy.wcs import WCSSUB_SPECTRAL
from glue.core.message import (SubsetCreateMessage,
SubsetDeleteMessage,
SubsetUpdateMessage)
from specutils import Spectrum1D
from traitlets import Bool, Int, List, Unicode
from glue.core.data import Data
from jdaviz.core.events import AddDataMessage, RemoveDataMessage, SnackbarMessage
from jdaviz.core.registries import tray_registry
from jdaviz.core.template_mixin import TemplateMixin
from jdaviz.utils import load_template
from .fitting_backend import fit_model_to_spectrum
from .initializers import initialize, model_parameters
__all__ = ['ModelFitting']
MODELS = {
'Const1D': models.Const1D,
'Linear1D': models.Linear1D,
'Polynomial1D': models.Polynomial1D,
'Gaussian1D': models.Gaussian1D,
'Voigt1D': models.Voigt1D,
'Lorentz1D': models.Lorentz1D
}
[docs]@tray_registry('g-model-fitting', label="Model Fitting")
class ModelFitting(TemplateMixin):
dialog = Bool(False).tag(sync=True)
template = load_template("model_fitting.vue", __file__).tag(sync=True)
dc_items = List([]).tag(sync=True)
save_enabled = Bool(False).tag(sync=True)
model_label = Unicode().tag(sync=True)
model_save_path = Unicode().tag(sync=True)
temp_name = Unicode().tag(sync=True)
temp_model = Unicode().tag(sync=True)
model_equation = Unicode().tag(sync=True)
eq_error = Bool(False).tag(sync=True)
component_models = List([]).tag(sync=True)
display_order = Bool(False).tag(sync=True)
poly_order = Int(0).tag(sync=True)
available_models = List(list(MODELS.keys())).tag(sync=True)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._viewer_spectra = None
self._spectrum1d = None
self._units = {}
self.n_models = 0
self._fitted_model = None
self._fitted_spectrum = None
self.component_models = []
self._initialized_models = {}
self._display_order = False
self.model_save_path = os.getcwd()
self.model_label = "Model"
self._selected_data_label = None
self.hub.subscribe(self, AddDataMessage,
handler=self._on_viewer_data_changed)
self.hub.subscribe(self, RemoveDataMessage,
handler=self._on_viewer_data_changed)
self.hub.subscribe(self, SubsetCreateMessage,
handler=lambda x: self._on_viewer_data_changed())
self.hub.subscribe(self, SubsetDeleteMessage,
handler=lambda x: self._on_viewer_data_changed())
self.hub.subscribe(self, SubsetUpdateMessage,
handler=lambda x: self._on_viewer_data_changed())
def _on_viewer_data_changed(self, msg=None):
"""
Callback method for when data is added or removed from a viewer, or
when a subset is created, deleted, or updated. This method receives
a glue message containing viewer information in the case of the former
set of events, and updates the available data list displayed to the
user.
Notes
-----
We do not attempt to parse any data at this point, at it can cause
visible lag in the application.
Parameters
----------
msg : `glue.core.Message`
The glue message passed to this callback method.
"""
self._viewer_id = self.app._viewer_item_by_reference(
'spectrum-viewer').get('id')
# Subsets are global and are not linked to specific viewer instances,
# so it's not required that we match any specific ids for that case.
# However, if the msg is not none, check to make sure that it's the
# viewer we care about.
if msg is not None and msg.viewer_id != self._viewer_id:
return
viewer = self.app.get_viewer('spectrum-viewer')
self.dc_items = [layer_state.layer.label
for layer_state in viewer.state.layers]
def _param_units(self, param, order = 0):
"""Helper function to handle units that depend on x and y"""
y_params = ["amplitude", "amplitude_L", "intercept"]
if param == "slope":
return str(u.Unit(self._units["y"]) / u.Unit(self._units["x"]))
elif param == "poly":
return str(u.Unit(self._units["y"]) / u.Unit(self._units["x"])**order)
return self._units["y"] if param in y_params else self._units["x"]
def _update_parameters_from_fit(self):
"""Insert the results of the model fit into the component_models"""
for m in self.component_models:
name = m["id"]
if len(self.component_models) > 1:
m_fit = self._fitted_model[name]
else:
m_fit = self._fitted_model
temp_params = []
for i in range(0, len(m_fit.parameters)):
temp_param = [x for x in m["parameters"] if x["name"] ==
m_fit.param_names[i]]
temp_param[0]["value"] = m_fit.parameters[i]
temp_params += temp_param
m["parameters"] = temp_params
# Trick traitlets into updating the displayed values
component_models = self.component_models
self.component_models = []
self.component_models = component_models
def _update_initialized_parameters(self):
# If the user changes a parameter value, we need to change it in the
# initialized model
for m in self.component_models:
name = m["id"]
for param in m["parameters"]:
quant_param = u.Quantity(param["value"], param["unit"])
setattr(self._initialized_models[name], param["name"],
quant_param)
[docs] def vue_data_selected(self, event):
"""
Callback method for when the user has selected data from the drop down
in the front-end. It is here that we actually parse and create a new
data object from the selected data. From this data object, unit
information is scraped, and the selected spectrum is stored for later
use in fitting.
Parameters
----------
event : str
IPyWidget callback event object. In this case, represents the data
label of the data collection object selected by the user.
"""
selected_spec = self.app.get_data_from_viewer("spectrum-viewer",
data_label=event)
self._selected_data_label = event
if self._units == {}:
self._units["x"] = str(
selected_spec.spectral_axis.unit)
self._units["y"] = str(
selected_spec.flux.unit)
self._spectrum1d = selected_spec
[docs] def vue_model_selected(self, event):
# Add the model selected to the list of models
self.temp_model = event
if event == "Polynomial1D":
self.display_order = True
else:
self.display_order = False
def _initialize_polynomial(self, new_model):
initialized_model = initialize(
MODELS[self.temp_model](name=self.temp_name, degree=self.poly_order),
self._spectrum1d.spectral_axis,
self._spectrum1d.flux)
self._initialized_models[self.temp_name] = initialized_model
for i in range(self.poly_order + 1):
param = "c{}".format(i)
initial_val = getattr(initialized_model, param).value
new_model["parameters"].append({"name": param,
"value": initial_val,
"unit": self._param_units("poly", i),
"fixed": False})
self._update_initialized_parameters()
return new_model
[docs] def vue_add_model(self, event):
"""Add the selected model and input string ID to the list of models"""
new_model = {"id": self.temp_name, "model_type": self.temp_model,
"parameters": []}
# Need to do things differently for polynomials, since the order varies
if self.temp_model == "Polynomial1D":
new_model = self._initialize_polynomial(new_model)
else:
# Have a separate private dict with the initialized models, since
# they don't play well with JSON for widget interaction
initialized_model = initialize(
MODELS[self.temp_model](name=self.temp_name),
self._spectrum1d.spectral_axis,
self._spectrum1d.flux)
self._initialized_models[self.temp_name] = initialized_model
for param in model_parameters[new_model["model_type"]]:
initial_val = getattr(initialized_model, param).value
new_model["parameters"].append({"name": param,
"value": initial_val,
"unit": self._param_units(param),
"fixed": False})
new_model["Initialized"] = True
self.component_models = self.component_models + [new_model]
self._update_initialized_parameters()
[docs] def vue_remove_model(self, event):
self.component_models = [x for x in self.component_models
if x["id"] != event]
del(self._initialized_models[event])
[docs] def vue_save_model(self, event):
if self.model_save_path[-1] == "/":
connector = ""
else:
connector = "/"
full_path = self.model_save_path + connector + self.model_label + ".pkl"
with open(full_path, 'wb') as f:
pickle.dump(self._fitted_model, f)
[docs] def vue_equation_changed(self, event):
# Length is a dummy check to test the infrastructure
if len(self.model_equation) > 20:
self.eq_error = True
[docs] def vue_model_fitting(self, *args, **kwargs):
"""
Run fitting on the initialized models, fixing any parameters marked
as such by the user, then update the displayed parameters with fit
values
"""
fitted_model, fitted_spectrum = fit_model_to_spectrum(
self._spectrum1d,
self._initialized_models.values(),
self.model_equation,
run_fitter=True)
self._fitted_model = fitted_model
self._fitted_spectrum = fitted_spectrum
# Update component model parameters with fitted values
self._update_parameters_from_fit()
self.save_enabled = True
[docs] def vue_fit_model_to_cube(self, *args, **kwargs):
data = self.app.data_collection[self._selected_data_label]
# First, ensure that the selected data is cube-like. It is possible
# that the user has selected a pre-existing 1d data object.
if data.ndim != 3:
snackbar_message = SnackbarMessage(
f"Selected data {self._selected_data_label} is not cube-like",
color='error',
sender=self)
self.hub.broadcast(snackbar_message)
return
# Get the primary data component
attribute = data.main_components[0]
component = data.get_component(attribute)
temp_values = data.get_data(attribute)
# Transpose the axis order
values = np.moveaxis(temp_values, 0, -1) * u.Unit(component.units)
# We manually create a Spectrum1D object from the flux information
# in the cube we select
wcs = data.coords.sub([WCSSUB_SPECTRAL])
spec = Spectrum1D(flux=values, wcs=wcs)
# TODO: in vuetify >2.3, timeout should be set to -1 to keep open
# indefinitely
snackbar_message = SnackbarMessage(
"Fitting model to cube...",
loading=True, timeout=0, sender=self)
self.hub.broadcast(snackbar_message)
fitted_model, fitted_spectrum = fit_model_to_spectrum(
spec,
self._initialized_models.values(),
self.model_equation,
run_fitter=True)
# Transpose the axis order back
values = np.moveaxis(fitted_spectrum.flux.value, -1, 0)
count = max(map(lambda s: int(next(iter(re.findall("\d$", s)), 0)),
self.data_collection.labels)) + 1
label = f"{self.model_label} [Cube] {count}"
# Create new glue data object
output_cube = Data(label=label,
coords=data.coords)
output_cube['flux'] = values
output_cube.get_component('flux').units = \
fitted_spectrum.flux.unit.to_string()
# Add to data collection
self.app.data_collection.append(output_cube)
snackbar_message = SnackbarMessage(
"Finished cube fitting",
color='success', loading=False, sender=self)
self.hub.broadcast(snackbar_message)
[docs] def vue_register_spectrum(self, event):
"""
Add a spectrum to the data collection based on the currently displayed
parameters (these could be user input or fit values).
"""
# Make sure the initialized models are updated with any user-specified
# parameters
self._update_initialized_parameters()
# Need to run the model fitter with run_fitter=False to get spectrum
model, spectrum = fit_model_to_spectrum(self._spectrum1d,
self._initialized_models.values(),
self.model_equation)
self.n_models += 1
label = self.model_label
if label in self.data_collection:
self.app.remove_data_from_viewer('spectrum-viewer', label)
# Some hacky code to remove the label from the data dropdown
temp_items = []
for data_item in self.app.state.data_items:
if data_item['name'] != label:
temp_items.append(data_item)
self.app.state.data_items = temp_items
# Remove the actual Glue data object from the data_collection
self.data_collection.remove(self.data_collection[label])
self.data_collection[label] = spectrum
self.save_enabled = True
#sleep(1)
#self.app.add_data_to_viewer('spectrum-viewer', label)