import warnings
import numpy as np
from astropy import table
from astropy import units as u
from astropy.nddata import StdDevUncertainty, VarianceUncertainty, InverseVariance
from echo import delay_callback
from glue.config import data_translator
from glue.core import BaseData
from glue.core.exceptions import IncompatibleAttribute
from glue.core.units import UnitConverter
from glue.core.subset import Subset
from glue.core.subset_group import GroupedSubset
from glue_astronomy.spectral_coordinates import SpectralCoordinates
from glue_jupyter.bqplot.profile import BqplotProfileView
from matplotlib.colors import cnames
from specutils import Spectrum1D
from jdaviz.core.events import SpectralMarksChangedMessage, LineIdentifyMessage, SnackbarMessage
from jdaviz.core.registries import viewer_registry
from jdaviz.core.marks import SpectralLine, LineUncertainties, ScatterMask, OffscreenLinesMarks
from jdaviz.core.linelists import load_preset_linelist, get_available_linelists
from jdaviz.core.freezable_state import FreezableProfileViewerState
from jdaviz.configs.default.plugins.viewers import JdavizViewerMixin
from jdaviz.utils import get_subset_type
__all__ = ['SpecvizProfileView']

# Module-level unit converter; `SpecvizProfileView.add_data` uses it to check
# that a newly added layer's flux unit converts to the current y display unit.
uc = UnitConverter()

# Maps the uncertainty-type label stored in a layer's ``meta['uncertainty_type']``
# to the astropy nddata uncertainty class used to convert it to a standard deviation.
uncertainty_str_to_cls_mapping = dict(
    std=StdDevUncertainty,
    var=VarianceUncertainty,
    ivar=InverseVariance,
)
@viewer_registry("specviz-profile-viewer", label="Profile 1D (Specviz)")
class SpecvizProfileView(JdavizViewerMixin, BqplotProfileView):
    """
    Profile (1D spectrum) viewer registered as "Profile 1D (Specviz)".

    On top of the glue-jupyter ``BqplotProfileView`` this viewer manages
    spectral line lists and their marks, uncertainty bands, mask markers,
    and the spectral/flux axis labels.
    """
    # categories: zoom resets, zoom, pan, subset, select tools, shortcuts
    tools_nested = [
        ['jdaviz:homezoom', 'jdaviz:prevzoom'],
        ['jdaviz:boxzoom', 'jdaviz:xrangezoom', 'jdaviz:yrangezoom'],
        ['jdaviz:panzoom', 'jdaviz:panzoom_x', 'jdaviz:panzoom_y'],
        ['bqplot:xrange'],
        ['jdaviz:selectline'],
        ['jdaviz:sidebar_plot', 'jdaviz:sidebar_export']
    ]

    default_class = Spectrum1D
    # Astropy table of all loaded spectral lines; None until load_line_list runs.
    spectral_lines = None
    _state_cls = FreezableProfileViewerState

    def __init__(self, *args, **kwargs):
        # Pop our own option before the remaining kwargs reach the base classes.
        default_tool_priority = kwargs.pop('default_tool_priority', [])
        super().__init__(*args, **kwargs)

        self._subscribe_to_layers_update()
        self.initialize_toolbar(default_tool_priority=default_tool_priority)
        self._offscreen_lines_marks = OffscreenLinesMarks(self)
        self.figure.marks = self.figure.marks + self._offscreen_lines_marks.marks

        # Uncertainty display is implemented here (not in glue), so react to
        # the state toggle ourselves.
        self.state.add_callback('show_uncertainty', self._show_uncertainty_changed)

        self.display_mask = False

        # Change collapse function to sum
        self.state.function = 'sum'

    def _expected_subset_layer_default(self, layer_state):
        # Apply the base-class defaults, then override the subset line width.
        super()._expected_subset_layer_default(layer_state)
        layer_state.linewidth = 3

    def data(self, cls=None):
        """
        Return translated objects for the data and subset layers in this viewer.

        Parameters
        ----------
        cls : type, optional
            Class to translate each layer into. Defaults to ``default_class``.

        Returns
        -------
        list
            One object per data/subset layer. Subsets whose attributes are
            incompatible with the requested class are skipped.
        """
        # Grab the user's chosen statistic for collapsing data
        statistic = getattr(self.state, 'function', None)
        # Resolve the target class once: the subset branch needs it too, even
        # when no data layer precedes a subset in ``self.state.layers``.
        _class = cls or self.default_class

        data = []
        for layer_state in self.state.layers:
            if not hasattr(layer_state, 'layer'):
                continue
            lyr = layer_state.layer

            if isinstance(lyr, BaseData):
                # For raw data, translate the data itself.
                if _class is None:
                    continue
                # The app cache is keyed on (label, statistic) only, so only use
                # it for the default class; otherwise a request for another class
                # could be served a cached object of the wrong type.
                use_cache = _class is self.default_class
                cache = self.jdaviz_app._get_object_cache
                cache_key = (lyr.label, statistic)
                if use_cache and cache_key in cache:
                    layer_data = cache[cache_key]
                else:
                    # If spectrum, collapse via the defined statistic
                    if _class == Spectrum1D:
                        layer_data = lyr.get_object(cls=_class, statistic=statistic)
                    else:
                        layer_data = lyr.get_object(cls=_class)
                    if use_cache:
                        cache[cache_key] = layer_data
                data.append(layer_data)

            elif isinstance(lyr, Subset):
                # For subsets, make sure to apply the subset mask to the layer data first
                layer_data = lyr
                if _class is not None:
                    handler, _ = data_translator.get_handler_for(_class)
                    try:
                        layer_data = handler.to_object(layer_data, statistic=statistic)
                    except IncompatibleAttribute:
                        continue
                data.append(layer_data)

        return data

    @property
    def redshift(self):
        """Redshift stored on the jdaviz helper (``jdaviz_helper._redshift``)."""
        return self.jdaviz_helper._redshift

    @staticmethod
    def _color_to_hexa(color):
        """Convert a line color to the hex-with-alpha string used by line marks."""
        color = str(color)
        if color.startswith("#"):
            # "#RRGGBB" gets an opaque alpha channel; other hex forms are kept as-is.
            return color + "FF" if len(color) == 7 else color
        # Named matplotlib color; looked up case-insensitively.
        return cnames[color.lower()] + "FF"

    def load_line_list(self, line_table, replace=False, return_table=False, show=True):
        """
        Load a list of spectral lines into the viewer.

        Parameters
        ----------
        line_table : str or `~astropy.table.QTable`
            Either the name of a preset line list (loaded with ``show=False``,
            since presets may contain many lines) or a QTable with at least
            ``'linename'`` and ``'rest'`` columns. Optional ``'listname'`` and
            ``'colors'`` columns are used when present. The table is modified
            in place (columns are added/normalized).
        replace : bool
            If True, replace previously loaded lines instead of appending.
        return_table : bool
            If True, return the processed table.
        show : bool
            Whether the newly loaded lines are flagged to be shown.

        Returns
        -------
        `~astropy.table.QTable` or None
            The processed table when ``return_table`` is True.

        Raises
        ------
        TypeError
            If ``line_table`` is neither a string nor a QTable.
        ValueError
            If a required column is missing or a rest value is not positive.
        """
        if isinstance(line_table, str):
            # Load the named preset and don't show it by default since there
            # might be too many lines. Propagate the result so that
            # ``return_table=True`` also works for presets.
            return self.load_line_list(load_preset_linelist(line_table),
                                       replace=replace, return_table=return_table,
                                       show=False)
        if not isinstance(line_table, table.QTable):
            raise TypeError("Line list must be an astropy QTable with "
                            "(minimally) 'linename' and 'rest' columns")

        if "linename" not in line_table.colnames:
            raise ValueError("Line table must have a 'linename' column")
        if "rest" not in line_table.colnames:
            raise ValueError("Line table must have a 'rest' column")
        if np.any(line_table['rest'] <= 0):
            raise ValueError("all rest values must be positive")

        # Per-line redshifts are ignored; lines are plotted using the viewer's
        # redshift unless a global redshift is passed when plotting.
        if "redshift" in line_table.colnames:
            warnings.warn("per line/list redshifts not supported, use viz.set_redshift")

        # Set whether to show all of the lines on the plot by default on load
        # We convert bool to int to work around ipywidgets json serialization
        line_table["show"] = int(show)

        # TODO: when a table is already loaded, convert the new rest values to
        # its units, with rounding that does not collapse values to 0.

        # Combine name and rest value for indexing
        if "name_rest" not in line_table.colnames:
            line_table["name_rest"] = None
            for row in line_table:
                row["name_rest"] = "{} {}".format(row["linename"], row["rest"].value)

        # If no name was given to this list, consider it part of the "Custom" list
        if "listname" not in line_table.colnames:
            line_table["listname"] = "Custom"
        else:
            for row in line_table:
                if row["listname"] is None:
                    row["listname"] = "Custom"

        # Convert colors to hexa values, or set to default (red)
        if "colors" not in line_table.colnames:
            line_table["colors"] = "#FF0000FF"
        else:
            # Rebuild the whole column: writing longer strings row by row into
            # a fixed-width string column would silently truncate them.
            line_table["colors"] = np.array(
                [self._color_to_hexa(c) for c in line_table["colors"]], dtype=str)

        # Create or update the main spectral_lines astropy table
        if self.spectral_lines is None or replace:
            self.spectral_lines = line_table
        else:
            self.spectral_lines = table.vstack([self.spectral_lines, line_table])
            self.spectral_lines = table.unique(self.spectral_lines, keys='name_rest')

        # It seems that we need to recreate this index after v-stacking.
        # "name_rest" is added first, so it is the primary index used by ``.loc``.
        self.spectral_lines.add_index("name_rest")
        self.spectral_lines.add_index("linename")
        self.spectral_lines.add_index("listname")

        self._broadcast_plotted_lines()

        if return_table:
            return line_table

    def _broadcast_plotted_lines(self, marks=None):
        """
        Broadcast the current spectral-line marks and, if none of them is
        identified, clear the identified line.
        """
        if marks is None:
            marks = [x for x in self.figure.marks if isinstance(x, SpectralLine)]
        msg = SpectralMarksChangedMessage(marks, sender=self)
        self.session.hub.broadcast(msg)

        if not any(mark.identify for mark in marks):
            # then clear the identified entry
            msg = LineIdentifyMessage(name_rest='', sender=self)
            self.session.hub.broadcast(msg)

    def erase_spectral_lines(self, name=None, name_rest=None, show_none=True):
        """
        Erase either all spectral lines, all spectral lines sharing the same
        name (e.g. 'He II') or specific name-rest value combinations (e.g.
        'HE II 1640.5', stored in SpectralLine as 'table_index').

        Parameters
        ----------
        name : str, optional
            Line name; all marks with this name are removed.
        name_rest : str or iterable of str, optional
            ``name_rest`` key(s) to erase. When ``name`` is also given, marks
            are filtered by ``name`` only, but both sets of rows are flagged
            as not shown.
        show_none : bool
            When erasing everything, also flag every loaded line as not shown.
        """
        fig = self.figure
        if name is None and name_rest is None:
            fig.marks = [x for x in fig.marks if not isinstance(x, SpectralLine)]
            if show_none and self.spectral_lines is not None:
                self.spectral_lines["show"] = False
            self._broadcast_plotted_lines([])
            return

        if name_rest is None:
            name_rests = []
        elif isinstance(name_rest, str):
            name_rests = [name_rest]
        else:
            name_rests = list(name_rest)

        # Toggle "show" value in main astropy table.
        if self.spectral_lines is not None:
            # "name_rest" is the primary index, so ``loc`` returns a single row.
            for nr in name_rests:
                self.spectral_lines.loc[nr]["show"] = False
            if name is not None:
                # ``loc[name]`` would search the primary ("name_rest") index, so
                # select the rows by line name explicitly; this also handles
                # names shared by several lines.
                is_named = self.spectral_lines["linename"] == name
                self.spectral_lines["show"][is_named] = False

        def _keep(mark):
            # Keep everything that is not a spectral line or not being erased.
            if not isinstance(mark, SpectralLine):
                return True
            if name is not None:
                return mark.name != name
            return mark.table_index not in name_rests

        # Get rid of the marks we no longer want
        fig.marks = [x for x in fig.marks if _keep(x)]
        self._broadcast_plotted_lines()

    def get_scales(self):
        """
        Return the figure's interaction x/y scales, after deselecting any
        active toolbar tool so it does not interfere.
        """
        fig = self.figure
        # Deselect any pan/zoom or subsetting tools so they don't interfere
        # with the scale retrieval
        if self.toolbar.active_tool is not None:
            self.toolbar.active_tool = None
        return {'x': fig.interaction.x_scale, 'y': fig.interaction.y_scale}

    def plot_spectral_line(self, line, global_redshift=None, plot_units=None, **kwargs):
        """
        Plot a single spectral line, replacing any existing mark for it.

        Parameters
        ----------
        line : str or table row
            A row of ``spectral_lines``, or a string looked up first as a
            ``name_rest`` key and then as a ``linename``.
        global_redshift : float, optional
            Redshift to apply; defaults to ``self.redshift``.
        plot_units : unit, optional
            Spectral-axis unit of the mark; defaults to that of the first
            data layer.
        **kwargs
            Passed to ``SpectralLine``.
        """
        if isinstance(line, str):
            # Try the full index first (for backend calls), otherwise name only
            try:
                line = self.spectral_lines.loc[line]
            except KeyError:
                line = self.spectral_lines.loc["linename", line]

        if plot_units is None:
            plot_units = self.data()[0].spectral_axis.unit

        redshift = self.redshift if global_redshift is None else global_redshift

        line_mark = SpectralLine(self,
                                 line['rest'].to_value(plot_units),
                                 redshift,
                                 name=line["linename"],
                                 table_index=line["name_rest"],
                                 colors=[line["colors"]], **kwargs)

        # Erase this line if it already existed, to avoid duplication
        self.erase_spectral_lines(name_rest=line["name_rest"])

        self.figure.marks = self.figure.marks + [line_mark]
        line["show"] = True

        self._broadcast_plotted_lines()

    def plot_spectral_lines(self, colors=None, global_redshift=None, **kwargs):
        """
        Plots a user-provided astropy table of spectral lines in the viewer.

        Only rows of ``spectral_lines`` flagged as shown are plotted.

        Parameters
        ----------
        colors : list, optional
            Colors used when ``spectral_lines`` has no ``'colors'`` column;
            defaults to ``["blue"]``.
        global_redshift : float, optional
            Redshift to apply; defaults to ``self.redshift``.
        **kwargs
            Passed to ``SpectralLine``.
        """
        if colors is None:
            colors = ["blue"]

        fig = self.figure
        self.erase_spectral_lines(show_none=False)

        # Check to see if colors were defined for each line
        if "colors" in self.spectral_lines.columns:
            colors = self.spectral_lines["colors"]
        elif len(colors) != len(self.spectral_lines):
            colors = colors*len(self.spectral_lines)

        lines = self.spectral_lines
        plot_units = self.data()[0].spectral_axis.unit

        redshift = self.redshift if global_redshift is None else global_redshift

        marks = []
        for line, color in zip(lines, colors):
            if not line["show"]:
                continue
            line_mark = SpectralLine(self,
                                     line['rest'].to_value(plot_units),
                                     redshift,
                                     name=line["linename"],
                                     table_index=line["name_rest"],
                                     colors=[color], **kwargs)
            marks.append(line_mark)

        fig.marks = fig.marks + marks

        self._broadcast_plotted_lines()

    def available_linelists(self):
        """Return the preset line lists from ``get_available_linelists``."""
        return get_available_linelists()

    def _show_uncertainty_changed(self, msg=None):
        # this is subscribed in init to watch for changes to the state
        # object since uncertainty handling is in jdaviz instead of glue/glue-jupyter
        if self.state.show_uncertainty:
            self._plot_uncertainties()
        else:
            self._clean_error()

    def show_mask(self):
        """Enable mask display and draw markers for masked data points."""
        self.display_mask = True
        self._plot_mask()

    def clean(self):
        """Remove mask markers and uncertainty bands from the viewer."""
        # Remove extra traces, in case they exist.
        self.display_mask = False
        self._clean_mask()

        # this will automatically call _clean_error via _show_uncertainty_changed
        self.state.show_uncertainty = False

    def _clean_mask(self):
        # Drop all ScatterMask marks from the figure.
        fig = self.figure
        fig.marks = [x for x in fig.marks if not isinstance(x, ScatterMask)]

    def _clean_error(self):
        # Drop all LineUncertainties marks from the figure.
        fig = self.figure
        fig.marks = [x for x in fig.marks if not isinstance(x, LineUncertainties)]

    def add_data(self, data, color=None, alpha=None, **layer_state):
        """
        Overrides the base class to add markers for plotting
        uncertainties and data quality flags.

        If the new data's flux unit cannot be converted to the current y
        display unit, an error snackbar message is broadcast, the data is
        removed from the data collection and `False` is returned.

        Parameters
        ----------
        data : :class:`glue.core.data.Data`
            Data object with the spectrum.
        color : obj
            Color value for plotting.
        alpha : float
            Alpha value for plotting.
        **layer_state
            Passed on to the base-class ``add_data``.

        Returns
        -------
        result : bool
            `True` if successful, `False` otherwise.
        """
        # If this is the first loaded data, set things up for unit conversion.
        if len(self.layers) == 0:
            reset_plot_axes = True
        else:
            # Check if the new data flux unit is actually compatible since flux not linked.
            try:
                uc.to_unit(data, data.find_component_id("flux"), [1, 1],
                           u.Unit(self.state.y_display_unit))  # Error if incompatible
            except Exception as err:
                # Raising exception here introduces a dirty state that messes up next load_data
                # but not raising exception also causes weird behavior unless we remove the data
                # completely.
                self.session.hub.broadcast(SnackbarMessage(
                    f"Failed to load {data.label}, so removed it: {repr(err)}",
                    sender=self, color='error'))
                self.jdaviz_app.data_collection.remove(data)
                return False
            reset_plot_axes = False

        # The base class handles the plotting of the main
        # trace representing the spectrum itself.
        result = super().add_data(data, color, alpha, **layer_state)

        if reset_plot_axes:
            x_units = data.get_component(self.state.x_att.label).units
            y_units = data.get_component("flux").units
            with delay_callback(self.state, "x_display_unit", "y_display_unit"):
                self.state.x_display_unit = x_units if len(x_units) else None
                self.state.y_display_unit = y_units if len(y_units) else None
            self.set_plot_axes()

        self._plot_uncertainties()

        self._plot_mask()

        # Set default linewidth on any created spectral subset layers
        # NOTE: this logic will need updating if we add support for multiple cubes as this assumes
        # that new data entries (from model fitting or gaussian smooth, etc) will only be spectra
        # and all subsets affected will be spectral
        for layer in self.state.layers:
            if (isinstance(layer.layer, GroupedSubset)
                    and get_subset_type(layer.layer) == 'spectral'
                    and layer.layer.data.label == data.label):
                layer.linewidth = 3

        return result

    def _plot_mask(self):
        """
        Draw cross markers at the masked samples of every data layer that has
        a ``mask`` component, when mask display is enabled.
        """
        if not self.display_mask:
            return

        # Remove existing mask marks
        self._clean_mask()

        # Loop through all active data in the viewer
        for layer_state in self.state.layers:
            lyr = layer_state.layer

            # Skip subsets
            if hasattr(lyr, "subset_state"):
                continue

            comps = [str(component) for component in lyr.components]
            # Ignore data that does not have a mask component
            if "mask" not in comps:
                continue

            mask = np.array(lyr['mask'].data)
            data_obj = lyr.data.get_object()
            data_x = data_obj.spectral_axis.value
            data_y = data_obj.flux.value

            # For plotting markers only for the masked data
            # points, erase un-masked data from trace.
            y = np.where(mask == 0, np.nan, data_y)

            # A subclass of the bqplot Scatter object, ScatterMask places
            # 'X' marks where there is masked data in the viewer.
            color = layer_state.color
            alpha_shade = layer_state.alpha / 3
            mask_line_mark = ScatterMask(scales=self.scales,
                                         marker='cross',
                                         x=data_x,
                                         y=y,
                                         stroke_width=0.5,
                                         colors=[color],
                                         default_size=25,
                                         default_opacities=[alpha_shade]
                                         )
            # Add mask marks to viewer
            self.figure.marks = list(self.figure.marks) + [mask_line_mark]

    def _plot_uncertainties(self):
        """
        Draw a shaded band of +/- one standard deviation around every data
        layer that has an ``uncertainty`` component, when uncertainty display
        is enabled.
        """
        if not self.state.show_uncertainty:
            return

        # Remove existing error bars
        self._clean_error()

        # Loop through all active data in the viewer
        for layer_state in self.state.layers:
            lyr = layer_state.layer

            # Skip subsets
            if hasattr(lyr, "subset_state"):
                continue

            comps = [str(component) for component in lyr.components]
            # Ignore data that does not have an uncertainty component
            if "uncertainty" not in comps:
                continue

            error = np.array(lyr['uncertainty'].data)

            # Ensure that the uncertainties are represented as stddev. Layers
            # without an explicit type are treated as standard deviations; the
            # mapping key for that is 'std' (a 'stddev' default raised KeyError).
            uncertainty_type_str = lyr.meta.get('uncertainty_type', 'std')
            uncert_cls = uncertainty_str_to_cls_mapping[uncertainty_type_str]
            error = uncert_cls(error).represent_as(StdDevUncertainty).array

            # Then we assume that last axis is always wavelength.
            # This may need adjustment after the following
            # specutils PR is merged: https://github.com/astropy/specutils/pull/1033
            spectral_axis = -1
            data_obj = lyr.data.get_object(cls=Spectrum1D, statistic=None)
            pixels = np.arange(lyr.data.shape[spectral_axis])

            if isinstance(lyr.data.coords, SpectralCoordinates):
                data_x = lyr.data.coords.pixel_to_world_values(pixels)
                if isinstance(data_x, tuple):
                    data_x = data_x[0]
            else:
                if hasattr(lyr.data.coords, 'spectral_wcs'):
                    spectral_wcs = lyr.data.coords.spectral_wcs
                elif hasattr(lyr.data.coords, 'spectral'):
                    spectral_wcs = lyr.data.coords.spectral
                else:
                    # No spectral WCS available: skip this layer instead of
                    # reusing a WCS left over from a previous layer.
                    continue
                data_x = spectral_wcs.pixel_to_world(pixels)

            data_y = data_obj.data

            # The shaded band around the spectrum trace is bounded by
            # two lines, above and below the spectrum trace itself.
            # NOTE(review): the unbound ``np.ndarray.tolist`` call presumably
            # lets Quantity-like x values convert to plain numbers.
            data_x_list = np.ndarray.tolist(data_x)
            x = [data_x_list, data_x_list]
            y = [np.ndarray.tolist(data_y - error),
                 np.ndarray.tolist(data_y + error)]

            if layer_state.as_steps:
                # Convert both boundary lines to step form: bin edges at the
                # midpoints between samples, each y value repeated per edge pair.
                for i in (0, 1):
                    a = np.insert(x[i], 0, 2*x[i][0] - x[i][1])
                    b = np.append(x[i], 2*x[i][-1] - x[i][-2])
                    edges = (a + b) / 2
                    x[i] = np.concatenate((edges[:1], np.repeat(edges[1:-1], 2), edges[-1:]))
                    y[i] = np.repeat(y[i], 2)
            x, y = np.asarray(x), np.asarray(y)

            # A subclass of the bqplot Lines object, LineUncertainties keeps
            # track of uncertainties plotted in the viewer. LineUncertainties
            # appear with two lines and shaded area in between.
            color = layer_state.color
            alpha_shade = layer_state.alpha / 3
            error_line_mark = LineUncertainties(viewer=self,
                                                x=[x],
                                                y=[y],
                                                scales=self.scales,
                                                stroke_width=1,
                                                colors=[color, color],
                                                fill_colors=[color, color],
                                                opacities=[0.0, 0.0],
                                                fill_opacities=[alpha_shade,
                                                                alpha_shade],
                                                fill='between',
                                                close_path=False
                                                )

            # Add error lines to viewer
            self.figure.marks = list(self.figure.marks) + [error_line_mark]

    def set_plot_axes(self):
        """
        Label the spectral (x) and flux (y) axes from the current display
        units and apply the viewer's axis margins and tick styling.
        """
        # Set y axes labels for the spectrum viewer
        y_display_unit = self.state.y_display_unit
        y_unit = u.Unit(y_display_unit) if y_display_unit else u.dimensionless_unscaled

        if y_unit.is_equivalent(u.Jy / u.sr):
            flux_unit_type = "Surface brightness"
        elif y_unit.is_equivalent(u.erg / (u.s * u.cm**2)):
            flux_unit_type = 'Flux'
        elif y_unit.is_equivalent(u.electron / u.s) or y_unit.physical_type == 'dimensionless':
            # electron / s or 'dimensionless_unscaled' should be labeled counts
            flux_unit_type = "Counts"
        elif y_unit.is_equivalent(u.W):
            flux_unit_type = "Luminosity"
        else:
            # default to Flux Density for flux density or uncaught types
            flux_unit_type = "Flux density"

        # Set x axes labels for the spectrum viewer
        x_disp_unit = self.state.x_display_unit
        x_unit = u.Unit(x_disp_unit) if x_disp_unit else u.dimensionless_unscaled

        if x_unit.is_equivalent(u.m):
            spectral_axis_unit_type = "Wavelength"
        elif x_unit.is_equivalent(u.Hz):
            spectral_axis_unit_type = "Frequency"
        elif x_unit.is_equivalent(u.pixel):
            spectral_axis_unit_type = "Pixel"
        else:
            spectral_axis_unit_type = str(x_unit.physical_type).title()

        with self.figure.hold_sync():
            self.figure.axes[0].label = f"{spectral_axis_unit_type} [{self.state.x_display_unit}]"
            self.figure.axes[1].label = f"{flux_unit_type} [{self.state.y_display_unit}]"

            # Make it so axis labels are not covering tick numbers.
            self.figure.fig_margin["left"] = 95
            self.figure.fig_margin["bottom"] = 60
            self.figure.send_state('fig_margin')  # Force update
            self.figure.axes[0].label_offset = "40"
            self.figure.axes[1].label_offset = "-70"
            # NOTE: with tick_style changed below, the default responsive ticks in bqplot result
            # in overlapping tick labels.  For now we'll hardcode at 8, but this could be removed
            # (default to None) if/when bqplot auto ticks react to styling options.
            self.figure.axes[1].num_ticks = 8

            # Set Y-axis to scientific notation
            self.figure.axes[1].tick_format = '0.1e'

            for i in (0, 1):
                self.figure.axes[i].tick_style = {'font-size': 15, 'font-weight': 600}