Source code for jdaviz.core.marks

import numpy as np

from astropy import units as u
from bqplot import LinearScale
from bqplot.marks import Lines, Label, Scatter
from copy import deepcopy
from glue.core import HubListener
from specutils import Spectrum1D

from jdaviz.core.events import GlobalDisplayUnitChanged
from jdaviz.core.events import (SliceToolStateMessage, LineIdentifyMessage,
                                SpectralMarksChangedMessage,
                                RedshiftMessage)

# public API of this module (marks that plugins/viewers may construct directly)
__all__ = [
    'OffscreenLinesMarks', 'BaseSpectrumVerticalLine', 'SpectralLine',
    'SliceIndicatorMarks', 'ShadowMixin', 'ShadowLine', 'ShadowLabelFixedY',
    'PluginMark', 'LinesAutoUnit', 'PluginLine', 'PluginScatter',
    'LineAnalysisContinuum', 'LineAnalysisContinuumCenter',
    'LineAnalysisContinuumLeft', 'LineAnalysisContinuumRight',
    'LineUncertainties', 'ScatterMask', 'SelectedSpaxel', 'MarkersMark',
    'FootprintOverlay', 'ApertureMark', 'SpectralExtractionPreview',
]

# default (orange) color used for plugin-created marks
accent_color = "#c75d2c"


class OffscreenLinesMarks(HubListener):
    """Pair of gray text labels counting ``SpectralLine`` marks outside the x-limits.

    The left label shows ``◀ N`` for lines left of ``x_min`` and the right label shows
    ``N ▶`` for lines right of ``x_max``; each label is blank when its count is zero.
    Counts are refreshed when the x-limits change or a redshift / spectral-marks message
    is broadcast.
    """
    def __init__(self, viewer):
        self.viewer = viewer

        state = viewer.state
        state.add_callback("x_min", lambda x_min: self._update_counts())
        state.add_callback("x_max", lambda x_max: self._update_counts())

        hub = viewer.session.hub
        hub.subscribe(self, RedshiftMessage, handler=self._update_counts)
        hub.subscribe(self, SpectralMarksChangedMessage, handler=self._update_counts)

        def _axes_scales():
            # fresh 0-1 (axes-fraction) scales for each label
            return {'x': LinearScale(min=0, max=1), 'y': LinearScale(min=0, max=1)}

        self.left = Label(text=[''], x=[0.02], y=[0.8], scales=_axes_scales(),
                          colors=['gray'], default_size=12, align='start')
        self.right = Label(text=[''], x=[0.98], y=[0.8], scales=_axes_scales(),
                           colors=['gray'], default_size=12, align='end')
        self._update_counts()

    @property
    def marks(self):
        return [self.left, self.right]

    def _update_counts(self, *args):
        state = self.viewer.state
        n_left = n_right = 0
        for mark in self.viewer.figure.marks:
            if not isinstance(mark, SpectralLine):
                continue
            position = mark.x[0]
            if position < state.x_min:
                n_left += 1
            elif position > state.x_max:
                n_right += 1

        self.left.text = [f'\u25c0 {n_left}'] if n_left else ['']
        self.right.text = [f'{n_right} \u25b6'] if n_right else ['']
class PluginMark:
    """Mixin for plugin-created marks that keeps ``x``/``y`` in the viewer's display units.

    The concrete class must set ``self.viewer`` before this ``__init__`` runs (``hub`` is
    read from it) and must expose ``x``/``y`` array attributes.  While
    ``auto_update_units`` is True, ``GlobalDisplayUnitChanged`` messages for supported
    profile viewers convert the stored arrays into the new unit.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.xunit = self.yunit = None
        # whether to update existing marks when global display units are changed
        self.auto_update_units = True
        self.hub.subscribe(self, GlobalDisplayUnitChanged,
                           handler=self._on_global_display_unit_changed)
        # adopt the viewer's current display units
        if self.xunit is None:
            self.set_x_unit()
        if self.yunit is None:
            self.set_y_unit()

    @property
    def hub(self):
        return self.viewer.hub

    def update_xy(self, x, y):
        """Replace the mark's coordinates with ``x`` and ``y`` (as numpy arrays)."""
        self.x = np.asarray(x)
        self.y = np.asarray(y)

    def append_xy(self, x, y):
        """Append ``x`` and ``y`` to the mark's existing coordinates."""
        self.x = np.append(self.x, x)
        self.y = np.append(self.y, y)

    def set_x_unit(self, unit=None):
        """Convert ``self.x`` to ``unit`` (default: the viewer's ``x_display_unit``).

        Conversion uses the astropy spectral equivalency and is only performed when a
        previous unit is known and ``x`` is non-empty; otherwise only the unit is stored.
        """
        if unit is None:
            if not hasattr(self.viewer.state, 'x_display_unit'):
                return
            unit = self.viewer.state.x_display_unit
        new_unit = u.Unit(unit)

        if self.xunit is not None and not all(dim == 0 for dim in self.x.shape):
            converted = (self.x * self.xunit).to_value(new_unit, u.spectral())
            self.xunit = new_unit
            self.x = converted
        self.xunit = new_unit

    def set_y_unit(self, unit=None):
        """Convert ``self.y`` to ``unit`` (default: the viewer's ``y_display_unit``).

        For ``Spectrum1D`` viewers the conversion uses the spectral-density equivalency
        built from the reference data's spectral axis.
        """
        if unit is None:
            if not hasattr(self.viewer.state, 'y_display_unit'):
                return
            unit = self.viewer.state.y_display_unit
        new_unit = u.Unit(unit)

        if self.yunit is not None and not all(dim == 0 for dim in self.y.shape):
            quantity = self.y * self.yunit
            if self.viewer.default_class is Spectrum1D:
                ref_spec = self.viewer.state.reference_data.get_object(cls=Spectrum1D)
                converted = quantity.to_value(
                    new_unit, equivalencies=u.spectral_density(ref_spec.spectral_axis))
            else:
                converted = quantity.to_value(new_unit)
            self.yunit = new_unit
            self.y = converted
        self.yunit = new_unit

    def _on_global_display_unit_changed(self, msg):
        if not self.auto_update_units:
            return
        viewer_cls_name = self.viewer.__class__.__name__
        if viewer_cls_name in ('SpecvizProfileView', 'CubevizProfileView'):
            axis_map = {'spectral': 'x', 'flux': 'y'}
        elif viewer_cls_name == 'MosvizProfile2DView':
            axis_map = {'spectral': 'x'}
        else:
            # other viewers do not auto-convert plugin marks
            return
        axis = axis_map.get(msg.axis)
        if axis is None:
            return
        setter = getattr(self, f'set_{axis}_unit')
        setter(msg.unit)

    def clear(self):
        """Remove all points from the mark."""
        self.update_xy([], [])
class BaseSpectrumVerticalLine(Lines, PluginMark, HubListener):
    """Vertical line drawn at a single x-position in a spectrum viewer.

    The x-coordinate uses the viewer's x-scale (data units, tracked in ``self.xunit``).
    The y-coordinate is ``[0, 1]`` on a fixed ``LinearScale(min=0, max=1)``, i.e. in
    normalized axes units, so the line always spans the full plot height.
    """
    def __init__(self, viewer, x, **kwargs):
        self.viewer = viewer
        # the location of the marker will need to update automatically if the
        # underlying data changes (through a unit conversion, for example)
        if hasattr(viewer.state, 'reference_data'):
            viewer.state.add_callback("reference_data", self._update_reference_data)

        scales = viewer.scales

        # Lines.__init__ will set self.x
        super().__init__(x=[x, x], y=[0, 1],
                         scales={'x': scales['x'], 'y': LinearScale(min=0, max=1)},
                         **kwargs)

    def set_y_unit(self, unit=None):
        """Record the y display unit WITHOUT rescaling ``self.y``.

        ``self.y`` is in normalized axes units (not flux), so applying a flux unit
        conversion (as ``PluginMark.set_y_unit`` does) would move or collapse the line.
        """
        if unit is None:
            if not hasattr(self.viewer.state, 'y_display_unit'):
                return
            unit = self.viewer.state.y_display_unit
        self.yunit = u.Unit(unit)

    def _update_reference_data(self, reference_data):
        if reference_data is None:
            return
        self._update_unit(reference_data.get_object(cls=Spectrum1D).spectral_axis.unit)

    def _update_unit(self, new_unit):
        # the x-units may have changed.  We want to convert the internal self.x
        # from self.xunit to the new units (x_all.unit)
        if self.xunit is None:
            self.xunit = new_unit
            return
        if new_unit == self.xunit:
            return

        old_quant = self.x[0]*self.xunit
        x = old_quant.to_value(new_unit, equivalencies=u.spectral())
        self.x = [x, x]
        self.xunit = new_unit
class SpectralLine(BaseSpectrumVerticalLine):
    """
    Subclass on bqplot Lines, mostly so that we can erase spectral lines
    by eliminating any SpectralLines objects from a figures marks list. Also
    lets us do wavelength redshifting here on mark creation.

    ``_rest_value`` is always expressed in ``self.xunit``; the drawn x-position
    (``obs_value``) is the rest value shifted by ``redshift``.
    """
    def __init__(self, viewer, rest_value, redshift=0, name=None, **kwargs):
        self._rest_value = rest_value
        self._identify = False
        self.name = name
        # table_index is same as name_rest elsewhere
        self.table_index = kwargs.pop("table_index", None)

        # setting redshift will set self.x and enable the obs_value property,
        # but to do that we need x_unit set first (would normally be assigned
        # in the super init)
        self.xunit = u.Unit(viewer.state.x_display_unit)
        self.redshift = redshift

        viewer.session.hub.subscribe(self, LineIdentifyMessage,
                                     handler=self._process_identify_change)

        super().__init__(viewer=viewer, x=self.obs_value, stroke_width=1,
                         fill='none', close_path=False, **kwargs)

    @property
    def name_rest(self):
        return self.table_index

    @property
    def rest_value(self):
        return self._rest_value

    @property
    def obs_value(self):
        return self.x[0]

    def set_x_unit(self, unit=None):
        """Convert the drawn position and ``_rest_value`` to ``unit``.

        ``unit=None`` means the viewer's ``x_display_unit`` (resolved by the parent).
        """
        prev_unit = self.xunit
        super().set_x_unit(unit=unit)
        # Convert to the unit the parent actually resolved (``unit`` may be None).  With
        # no previous unit (e.g. ``PluginMark.__init__`` just reset it), ``_rest_value`` is
        # already in the current display unit and must not be converted.
        if prev_unit is None or self.xunit is None:
            return
        self._rest_value = (self._rest_value * prev_unit).to_value(self.xunit, u.spectral())

    @property
    def redshift(self):
        return self._redshift

    @redshift.setter
    def redshift(self, redshift):
        self._redshift = redshift
        if str(self.xunit.physical_type) == 'length':
            obs_value = self._rest_value*(1+redshift)
        elif str(self.xunit.physical_type) == 'frequency':
            obs_value = self._rest_value/(1+redshift)
        else:
            # catch all for anything else (wavenumber, energy, etc)
            rest_angstrom = (self._rest_value*self.xunit).to_value(u.Angstrom,
                                                                   equivalencies=u.spectral())
            obs_angstrom = rest_angstrom*(1+redshift)
            obs_value = (obs_angstrom*u.Angstrom).to_value(self.xunit,
                                                           equivalencies=u.spectral())
        self.x = [obs_value, obs_value]

    @property
    def identify(self):
        return self._identify

    @identify.setter
    def identify(self, identify):
        if not isinstance(identify, bool):  # pragma: no cover
            raise TypeError("identify must be of type bool")

        self._identify = identify
        self.stroke_width = 3 if identify else 1

    def _process_identify_change(self, msg):
        self.identify = msg.name_rest == self.table_index

    def _update_unit(self, new_unit):
        if self.xunit is None:
            self.xunit = new_unit
            return
        if new_unit == self.xunit:
            return

        old_quant = self._rest_value*self.xunit
        self._rest_value = old_quant.to_value(new_unit, equivalencies=u.spectral())
        # update the unit BEFORE re-applying the redshift: the redshift setter branches on
        # ``self.xunit.physical_type`` and must see the unit ``_rest_value`` is now in
        self.xunit = new_unit
        # re-compute self.x from current redshift (instead of converting that as well)
        self.redshift = self._redshift
class SliceIndicatorMarks(BaseSpectrumVerticalLine, HubListener):
    """Subclass on bqplot Lines to handle slice/wavelength indicator.

    When the value lies outside the viewer's x-limits, the line is pinned near the
    corresponding edge (in axes units), drawn dashed, and its label gains an arrow
    pointing towards the off-screen value.
    """
    def __init__(self, viewer, value=0, **kwargs):
        self._viewer = viewer
        self._value = None
        self._oob = False  # out-of-bounds, either False, 'left', or 'right'
        self._active = False

        # TODO: new viewers need to respect plugin settings
        self._show_if_inactive = True
        self._show_value = True

        viewer.state.add_callback("x_min", lambda x_min: self._value_handle_oob(update_label=True))
        viewer.state.add_callback("x_max", lambda x_max: self._value_handle_oob(update_label=True))
        viewer.session.hub.subscribe(self, SliceToolStateMessage,
                                     handler=self._on_change_state)
        # BaseSpectrumVerticalLine expands a scalar ``x`` into ``[x, x]`` itself; passing
        # ``[value, value]`` here would create a nested (2D) x-array.
        super().__init__(viewer=viewer,
                         x=value,
                         stroke_width=2,
                         marker='diamond',
                         fill='none', close_path=False,
                         labels=['slice'], labels_visibility='none', **kwargs)
        self.value = value

        # instead of using the Lines label which is limited, we'll use a Label object which
        # will follow the x-coordinate of the slice indicator line, with a fixed y-value
        # (in axes-units) and will flip its alignment depending on whether the line is on the
        # left or right side of the axes.
        self.label = ShadowLabelFixedY(viewer, self, shadow_traits=[],
                                       default_size=12, y=0.95)

        # default to the initial state of the tool since we can't control if this will
        # happen before or after the initialization of the tool
        tool_active = self.viewer.toolbar.active_tool_id == 'jdaviz:selectslice'
        self._on_change_state({'active': tool_active})

    @property
    def marks(self):
        return [self, self.label]

    def _on_global_display_unit_changed(self, msg):
        # Updating the value is handled by the plugin itself, need to update unit string.
        if msg.axis in ["spectral", "x"]:
            self.xunit = msg.unit
            self._update_label()

    def _value_handle_oob(self, x=None, update_label=False):
        if x is None:
            x = self.value
        else:
            self._value = x
        if x is None:
            # no value set yet (e.g. the x-limits changed during initialization)
            return

        x_min, x_max = self._viewer.state.x_min, self._viewer.state.x_max
        if x_min is None or x_max is None:
            self.x = [x, x]
            return
        x_range = x_max - x_min
        padding_fig = 0.01
        padding = padding_fig * x_range
        x_min += padding
        x_max -= padding
        # ensure y-scale has been set (we'll only be overriding x, but scatter viewers complain
        # if y-scale is not set)
        self.scales.setdefault('y', LinearScale(min=0, max=1))
        if x < x_min:
            self.x = [padding_fig, padding_fig]
            self.scales = {**self.scales, 'x': LinearScale(min=0, max=1)}
            self.line_style = 'dashed'
            self._oob = 'left'
        elif x > x_max:
            self.x = [1-padding_fig, 1-padding_fig]
            self.scales = {**self.scales, 'x': LinearScale(min=0, max=1)}
            self.line_style = 'dashed'
            self._oob = 'right'
        else:
            self.x = [x, x]
            self.scales = {**self.scales, 'x': self._viewer.scales['x']}
            self.line_style = 'solid'
            self._oob = False
        if update_label:
            self._update_label()

    def _update_colors_opacities(self):
        # orange (accent) if active, import button blue otherwise (see css in main_styles.vue)
        if not self._show_if_inactive and not self._active:
            self.label.visible = False
            self.visible = False
            return

        self.visible = True
        self.label.visible = self._show_value
        self.colors = ["#c75109" if self._active else "#007BA1"]
        self.opacities = [1.0 if self._active else 0.9]

    def _on_change_state(self, msg=None):
        if msg is None:
            changes = {}
        elif isinstance(msg, dict):
            changes = msg
        else:
            if msg.viewer is not None and msg.viewer != self.viewer:
                return
            changes = msg.change

        for k, v in changes.items():
            if k == 'active':
                self._active = v
            elif k == 'show_indicator':
                self._show_if_inactive = v
            elif k == 'show_value':
                self._show_value = v

        self._update_colors_opacities()

    def _update_label(self):
        def _formatted_value(value):
            # zero / non-finite values have no order of magnitude (``np.log10(0)`` warns
            # and returns -inf), so show them in fixed-point
            if value == 0 or not np.isfinite(value):
                return f'{value:0.4f}'
            # use the magnitude so negative values are classified like positive ones
            power = abs(np.log10(abs(value)))
            if power >= 3:
                # use scientific notation
                return f'{value:0.4e}'
            return f'{value:0.4f}'

        if self.value is None:
            return
        valuestr = _formatted_value(self.value)
        xunit = str(self.xunit) if self.xunit is not None else ''
        # U+00A0 is a blank space, U+25C0 a left arrow triangle, and U+25B6 a right arrow triangle
        if self._oob == 'left':
            self.labels = [f'\u00A0 \u25c0 {valuestr} {xunit} \u00A0']  # noqa
        elif self._oob == 'right':
            self.labels = [f'{valuestr} {xunit} \u25b6 \u00A0']
        else:
            self.labels = [f'\u00A0 {valuestr} {xunit} \u00A0']

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, value):
        self._value_handle_oob(value, update_label=True)
class ShadowMixin:
    """Mixin class to propagate traits from one mark object to another.

    Anything in ``sync_traits`` will be mirrored directly from ``shadowing`` to the
    shadowed object.  Can manually override ``_on_shadowing_changed`` for more advanced
    logic cases.
    """
    def _get_id(self, mark):
        return getattr(mark, '_model_id', None)

    def _setup_shadowing(self, shadowing, sync_traits=None, other_traits=None):
        """Start shadowing ``shadowing``.

        sync_traits: traits to set now, and mirror any changes to shadowing in the future.
        other_traits: traits whose current values are passed to ``_on_shadowing_changed``
            once now but are not mirrored later.  The base ``_on_shadowing_changed`` only
            copies traits listed in ``sync_traits``, so these only take effect in
            subclasses that override it.
        """
        # copy into fresh lists: avoids shared mutable defaults and allows tuples
        sync_traits = list(sync_traits) if sync_traits is not None else []
        other_traits = list(other_traits) if other_traits is not None else []

        if not hasattr(self, '_shadowing'):
            self._shadowing = {}
            self._sync_traits = {}
        shadowing_id = self._get_id(shadowing)
        self._shadowing[shadowing_id] = shadowing
        self._sync_traits[shadowing_id] = sync_traits

        # sync initial values
        for attr in sync_traits + other_traits:
            self._on_shadowing_changed({'name': attr,
                                        'new': getattr(shadowing, attr),
                                        'owner': shadowing})

        # subscribe to future changes
        shadowing.observe(self._on_shadowing_changed)

    def _on_shadowing_changed(self, change):
        """Copy ``change['new']`` onto self if the trait is synced for its owner."""
        if change['name'] in self._sync_traits.get(self._get_id(change.get('owner')), []):
            setattr(self, change['name'], change['new'])
class ShadowLine(Lines, HubListener, ShadowMixin):
    """Line drawn behind another line (white by default) so the shadowed line
    stands out on top of other lines.

    Stroke width and marker size are the shadowed line's values plus ``shadow_width``
    (or 0 when the shadowed value is falsy).  They are set once at construction.
    Scales, data, visibility, line style and marker are mirrored continuously.
    """
    def __init__(self, shadowing, shadow_width=1, **kwargs):
        self._shadow_width = shadow_width
        shadow_color = kwargs.pop('color', 'white')
        base_width = shadowing.stroke_width
        base_marker_size = shadowing.marker_size

        super().__init__(
            scales=shadowing.scales,
            stroke_width=base_width + shadow_width if base_width else 0,
            marker_size=base_marker_size + shadow_width if base_marker_size else 0,
            colors=[shadow_color],
            **kwargs,
        )

        mirrored = ['scales', 'x', 'y', 'visible', 'line_style', 'marker']
        initial_only = ['stroke_width', 'marker_size']
        self._setup_shadowing(shadowing, mirrored, initial_only)
class ShadowSpatialSpectral(Lines, HubListener, ShadowMixin):
    """
    Shadow the mark of a spatial subset collapsed spectrum, with the mask from a spectral
    subset, and the styling from the spatial subset.
    """
    def __init__(self, spatial_spectrum_mark, spectral_subset_mark,
                 spatial_uuid, spectral_uuid, data_uuid):
        # spatial_spectrum_mark: Lines mark corresponding to the spatially-collapsed spectrum
        # from a spatial subset
        # spectral_subset_mark: Lines mark on the FULL cube corresponding to the glue-highlight
        # of the spectral subset
        super().__init__(scales=spatial_spectrum_mark.scales, marker=None)

        self._spatial_mark_id = self._get_id(spatial_spectrum_mark)
        self._setup_shadowing(spatial_spectrum_mark,
                              ['scales', 'y', 'visible', 'line_style'],
                              ['x'])
        self.spatial_uuid = spatial_uuid

        self._spectral_mark_id = self._get_id(spectral_subset_mark)
        self._setup_shadowing(spectral_subset_mark,
                              ['stroke_width', 'x', 'y', 'visible', 'opacities', 'colors'])
        self.spectral_uuid = spectral_uuid

        self.data_uuid = data_uuid

    @property
    def spatial_spectrum_mark(self):
        return self._shadowing[self._spatial_mark_id]

    @property
    def spectral_subset_mark(self):
        return self._shadowing[self._spectral_mark_id]

    def _on_shadowing_changed(self, change):
        # NOTE: the traitlets ``change`` object is shared with every other observer of the
        # shadowed mark, so never mutate it in place; build a new dict instead.
        if hasattr(self, '_spectral_mark_id'):
            if change['name'] == 'y':
                spatial_y = self.spatial_spectrum_mark.y
                spectral_y = self.spectral_subset_mark.y
                # at initial setup, the arrays may not be populated yet
                if spatial_y.shape == spectral_y.shape:
                    # force a copy or else we'll overwrite the mask to the spatial mark!
                    masked_y = deepcopy(spatial_y)
                    masked_y[np.isnan(spectral_y)] = np.nan
                    change = {**change, 'new': masked_y}
            elif change['name'] == 'visible':
                # only show if BOTH shadowing marks are set to visible
                both_visible = (self.spectral_subset_mark.visible
                                and self.spatial_spectrum_mark.visible)
                change = {**change, 'new': both_visible}

        return super()._on_shadowing_changed(change)
class ShadowLabelFixedY(Label, ShadowMixin):
    """Label whose position shadows that of a parent ``shadowing`` line and will
    flip alignment based on whether it is left or right of the center of the viewer.

    The y-position is fixed in axes units (``LinearScale(min=0, max=1)``).  The x-position,
    x-scale, text (from ``labels[point_index]``) and color follow the shadowed mark.
    """
    def __init__(self, viewer, shadowing, shadow_traits=None, y=0.95, point_index=0,
                 **kwargs):
        if shadow_traits is None:
            # default previously given as a (shared) mutable list literal
            shadow_traits = ['visible']
        super().__init__(**kwargs)
        self._viewer = viewer
        self.y = [y]
        self.scales['y'] = LinearScale(min=0, max=1)
        self._point_index = point_index
        self._setup_shadowing(shadowing,
                              shadow_traits,
                              ['x', 'scales', 'labels', 'colors'])
        viewer.state.add_callback("x_min", lambda x_min: self._update_align())
        viewer.state.add_callback("x_max", lambda x_max: self._update_align())

    def _force_redraw(self):
        # TODO: bug in bqplot that change in align/colors traitlet doesn't update immediately,
        # we'll get around it in the meantime by just forcing the Label to see a change to the
        # text traitlet
        text = self.text
        self.text = ['']
        self.text = text

    def _update_align(self):
        x_scale = self.scales.get('x')
        if not isinstance(x_scale, LinearScale):
            return
        if self.x is None or len(self.x) == 0:
            # no position yet, nothing to align
            return

        # determine alignment automatically
        if x_scale.min == 0 and x_scale.max == 1:
            # then we're in axes units, so just check position compared to 0.5
            is_to_right = self.x[0] > 0.5
        else:
            # then we're in data units, so check position compared to the median of the axes limits
            x_min, x_max = self._viewer.state.x_min, self._viewer.state.x_max
            if x_min is None or x_max is None:
                # axes limits not defined yet; keep the current alignment
                return
            is_to_right = self.x[0] > (x_min + x_max) / 2.

        if is_to_right and self.align != 'end':
            self.align = 'end'
            # force redraw by re-updating label
            self._force_redraw()
        if not is_to_right and self.align != 'start':
            self.align = 'start'
            # force redraw by re-updating label
            self._force_redraw()

    def _on_shadowing_changed(self, change):
        super()._on_shadowing_changed(change)
        if change['name'] == 'labels':
            self.text = [change['new'][self._point_index]]
        elif change['name'] in ('x', 'colors'):
            setattr(self, change['name'], [change['new'][self._point_index]])
            if change['name'] == 'colors':
                # bqplot bug that won't notice change to colors, manually force re-draw
                self._force_redraw()
        elif change['name'] == 'scales':
            self.scales = {**self.scales, 'x': change['new']['x']}

        if change['name'] in ('x', 'scales'):
            # then the position of the label on the plot has changed, so re-determine whether
            # it should be aligned to the left or right
            self._update_align()
class LinesAutoUnit(PluginMark, Lines, HubListener):
    """Lines mark whose units follow the viewer's global display units."""
    def __init__(self, viewer, *args, **kwargs):
        # must be assigned before PluginMark.__init__ runs, since it reads ``self.hub``
        # (which is ``self.viewer.hub``)
        self.viewer = viewer
        super(LinesAutoUnit, self).__init__(*args, **kwargs)
class PluginLine(Lines, PluginMark, HubListener):
    """Lines mark created by a plugin, drawn on the viewer's scales by default."""
    def __init__(self, viewer, x=[], y=[], **kwargs):
        self.viewer = viewer
        # default color is the orange accent color, unless the caller provides colors
        if 'colors' not in kwargs:
            kwargs['colors'] = [accent_color]
        line_scales = kwargs.pop('scales', viewer.scales)
        super().__init__(x=x, y=y, scales=line_scales, **kwargs)
class PluginScatter(Scatter, PluginMark, HubListener):
    """Scatter mark created by a plugin, drawn on the viewer's scales by default."""
    def __init__(self, viewer, x=[], y=[], **kwargs):
        self.viewer = viewer
        # default color is the orange accent color, unless the caller provides colors
        if 'colors' not in kwargs:
            kwargs['colors'] = [accent_color]
        scatter_scales = kwargs.pop('scales', viewer.scales)
        super().__init__(x=x, y=y, scales=scatter_scales, **kwargs)
class LineAnalysisContinuum(PluginLine):
    """Continuum mark used by the line analysis plugin (no automatic unit conversion)."""
    def __init__(self, *args, **kwargs):
        super(LineAnalysisContinuum, self).__init__(*args, **kwargs)
        # the plugin recomputes the continuum itself after a unit change and replaces the
        # arrays, so the mark must not convert them on its own
        self.auto_update_units = False
class LineAnalysisContinuumCenter(LineAnalysisContinuum):
    """Thin (width 1) continuum segment spanning the line region."""
    def __init__(self, viewer, x=[], y=[], **kwargs):
        super().__init__(viewer, x=x, y=y, **kwargs)
        self.stroke_width = 1
class LineAnalysisContinuumLeft(LineAnalysisContinuum):
    """Thick (width 5) continuum segment on the left side of the line region."""
    def __init__(self, viewer, x=[], y=[], **kwargs):
        super().__init__(viewer, x=x, y=y, **kwargs)
        self.stroke_width = 5
class LineAnalysisContinuumRight(LineAnalysisContinuumLeft):
    """Right-side continuum segment; identical in style to the left segment."""
class LineUncertainties(LinesAutoUnit):
    """Lines mark for uncertainty curves, following the viewer's display units."""
    def __init__(self, viewer, *args, **kwargs):
        super(LineUncertainties, self).__init__(viewer, *args, **kwargs)
class ScatterMask(Scatter):
    """Scatter mark type used to display masks (distinguishable by class)."""
    def __init__(self, **kwargs):
        super(ScatterMask, self).__init__(**kwargs)
class SelectedSpaxel(Lines):
    """Lines mark type for a selected spaxel (distinguishable by class)."""
    def __init__(self, **kwargs):
        super(SelectedSpaxel, self).__init__(**kwargs)
class MarkersMark(PluginScatter):
    """Plugin scatter mark whose marker shape defaults to ``'circle'``."""
    def __init__(self, viewer, **kwargs):
        # caller-provided ``marker`` takes precedence over the default
        kwargs = {'marker': 'circle', **kwargs}
        super().__init__(viewer, **kwargs)
class FootprintOverlay(PluginLine):
    """Closed, semi-transparent filled outline associated with a named ``overlay``."""
    def __init__(self, viewer, overlay, **kwargs):
        self._overlay = overlay
        # styling defaults; any of these may be overridden by the caller
        style_defaults = {
            'stroke_width': 2,
            'close_path': True,
            'opacities': [0.8],
            'fill': 'inside',
            'fill_opacities': [0.2],
        }
        for key, default in style_defaults.items():
            kwargs.setdefault(key, default)
        super().__init__(viewer, **kwargs)

    @property
    def overlay(self):
        return self._overlay
class ApertureMark(PluginLine):
    """Plugin line outlining an aperture, tagged with the identifier ``id``."""
    def __init__(self, viewer, id, **kwargs):
        # ``id`` is part of the public signature, so the builtin shadowing is kept
        self._id = id
        super(ApertureMark, self).__init__(viewer, **kwargs)
class SpectralExtractionPreview(PluginLine):
    """Plugin line previewing a spectral extraction result."""
    def __init__(self, viewer, **kwargs):
        super(SpectralExtractionPreview, self).__init__(viewer, **kwargs)
class HistogramMark(Lines):
    """Solid vertical line in the accent color, e.g. marking a min/max histogram bound.

    ``min_max_value`` is passed as the line's x-data; y spans ``[0, 1]`` on the given
    ``scales``.
    """
    def __init__(self, min_max_value, scales, **kwargs):
        super().__init__(
            x=min_max_value,
            y=[0, 1],
            scales=scales,
            colors=[accent_color],
            line_style="solid",
            **kwargs,
        )