Source code for jdaviz.configs.mosviz.plugins.viewers

import numpy as np

from astropy.coordinates import SkyCoord
from astropy.table import QTable
from functools import cached_property
from glue.core import BaseData
from glue_jupyter.bqplot.image import BqplotImageView
from glue_jupyter.table import TableViewer
from scipy.interpolate import interp1d
from specutils import Spectrum1D

from jdaviz.core.events import (AddDataToViewerMessage,
                                RemoveDataFromViewerMessage,
                                TableClickMessage)
from jdaviz.core.astrowidgets_api import AstrowidgetsImageViewerMixin
from jdaviz.core.registries import viewer_registry
from jdaviz.core.freezable_state import FreezableBqplotImageViewerState
from jdaviz.configs.default.plugins.viewers import JdavizViewerMixin
from jdaviz.configs.specviz.plugins.viewers import SpecvizProfileView

__all__ = ['MosvizImageView', 'MosvizProfile2DView',
           'MosvizProfileView', 'MosvizTableViewer']


[docs] @viewer_registry("mosviz-image-viewer", label="Image 2D (Mosviz)") class MosvizImageView(JdavizViewerMixin, BqplotImageView, AstrowidgetsImageViewerMixin): # categories: zoom resets, zoom, pan, subset, select tools, shortcuts tools_nested = [ ['jdaviz:homezoom', 'jdaviz:prevzoom'], ['jdaviz:boxzoom'], ['jdaviz:panzoom'], ['jdaviz:sidebar_plot', 'jdaviz:sidebar_export'] ] default_class = None _state_cls = FreezableBqplotImageViewerState def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.init_astrowidgets_api() self._subscribe_to_layers_update() self.state.show_axes = False # Axes are wrong anyway self.figure.fig_margin = {'left': 0, 'bottom': 0, 'top': 0, 'right': 0}
[docs] def data(self, cls=None): return [layer_state.layer.get_object(cls=cls or self.default_class) for layer_state in self.state.layers if hasattr(layer_state, 'layer') and isinstance(layer_state.layer, BaseData)]
# NOTE: This is currently only for debugging. It is not used in app. def _mark_targets(self): table_data = self.jdaviz_app.data_collection['MOS Table'] if ("R.A." not in table_data.component_ids() or "Dec." not in table_data.component_ids()): return ra = table_data["R.A."] dec = table_data["Dec."] sky = SkyCoord(ra, dec, unit='deg') t = QTable({'coord': sky}) self.add_markers(t, use_skycoord=True, marker_name='Targets')
[docs] @viewer_registry("mosviz-profile-2d-viewer", label="Spectrum 2D (Mosviz)") class MosvizProfile2DView(JdavizViewerMixin, BqplotImageView): # Due to limitations in CCDData and 2D data that has spectral and spatial # axes, the default conversion class must handle cubes default_class = Spectrum1D # categories: zoom resets, zoom, pan, subset, select tools, shortcuts tools_nested = [ ['mosviz:homezoom'], ['mosviz:boxzoom', 'mosviz:xrangezoom', 'jdaviz:yrangezoom'], ['mosviz:panzoom', 'mosviz:panzoom_x', 'jdaviz:panzoom_y'], ['bqplot:xrange'], ['jdaviz:sidebar_plot', 'jdaviz:sidebar_export'] ] _state_cls = FreezableBqplotImageViewerState def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._subscribe_to_layers_update() # Setup viewer option defaults self.state.aspect = 'auto' self.session.hub.subscribe(self, AddDataToViewerMessage, handler=self._on_viewer_data_changed) self.session.hub.subscribe(self, RemoveDataFromViewerMessage, handler=self._on_viewer_data_changed) for k in ('x_min', 'x_max'): self.state.add_callback(k, self._handle_x_axis_orientation) @cached_property def reference_spectral_axis(self): return self.state.reference_data.get_object().spectral_axis.value @cached_property def pixel_to_world_interp(self): pixels = range(len(self.reference_spectral_axis)) return interp1d(pixels, self.reference_spectral_axis)
[docs] def pixel_to_world_limits(self, limits): if not len(limits) == 2: raise ValueError("limits must be length 2") pixels = np.arange(0, len(self.reference_spectral_axis)) # we'll use interpolation when possible, but also want to fit a line between # the outermost edge of the data within the limits line_edges_pix = np.array([max((min(pixels), min(limits))), min((max(pixels), max(limits)))]) if line_edges_pix[0] > line_edges_pix[1]: # then the limits are entirely out of range, so use the whole range # when fitting the linear approximation line_edges_pix = np.array([min(pixels), max(pixels)]) line_edges_world = self.pixel_to_world_interp(line_edges_pix) line_coeffs = np.polyfit(line_edges_pix, line_edges_world, deg=1) def pixel_to_world_line(pixel): return line_coeffs[0] * pixel + line_coeffs[1] def map_pixel_to_world(pixel): if pixels[0] <= pixel <= pixels[-1]: # interpolate directly return float(self.pixel_to_world_interp(pixel)) else: # use the line model to extrapolate return pixel_to_world_line(pixel) invert = (-1) ** sum((self.inverted_x_axis, limits[0] > limits[1])) out_lims = list(map(map_pixel_to_world, limits))[::invert] return out_lims
@cached_property def world_to_pixel_interp(self): pixels = range(len(self.reference_spectral_axis)) return interp1d(self.reference_spectral_axis, pixels)
[docs] def world_to_pixel_limits(self, limits): if not len(limits) == 2: raise ValueError("limits must be length 2") # we'll use interpolation when possible, but also want to fit a line between # the outermost edge of the data within the limits line_edges_world = np.array([max((min(self.reference_spectral_axis), min(limits))), min((max(self.reference_spectral_axis), max(limits)))]) if line_edges_world[0] > line_edges_world[1]: # then the limits are entirely out of range, so use the whole range # when fitting the linear approximation line_edges_world = np.array([min(self.reference_spectral_axis), max(self.reference_spectral_axis)]) line_edges_pixels = self.world_to_pixel_interp(line_edges_world) line_coeffs = np.polyfit(line_edges_world, line_edges_pixels, deg=1) def world_to_pixel_line(world): return line_coeffs[0] * world + line_coeffs[1] def map_world_to_pixel(world): if min(self.reference_spectral_axis) <= world <= max(self.reference_spectral_axis): # interpolate directly return float(self.world_to_pixel_interp(world)) else: # use the line model to extrapolate return world_to_pixel_line(world) invert = (-1) ** sum((self.inverted_x_axis, limits[0] > limits[1])) out_lims = list(map(map_world_to_pixel, limits))[::invert] return out_lims
def _on_viewer_data_changed(self, msg): if msg.viewer_reference != self.reference: return # clear cached properties that are based on reference data - this is probably # overly-conservative and we might be able to limit the clearing for only when # reference data is changed (perhaps with a callback on the state for reference_data) for attr in ('reference_spectral_axis', 'inverted_x_axis', 'pixel_to_world_interp', 'world_to_pixel_interp'): if attr in self.__dict__: del self.__dict__[attr] if len(self.data()): self._handle_x_axis_orientation() @cached_property def inverted_x_axis(self): return self.reference_spectral_axis[0] > self.reference_spectral_axis[-1] def _handle_x_axis_orientation(self, *args): x_scales = self.scales['x'] limits = [x_scales.min, x_scales.max] limits_inverted = limits[0] > limits[1] if limits_inverted == self.inverted_x_axis: return with x_scales.hold_sync(): x_scales.min = max(limits) if self.inverted_x_axis else min(limits) x_scales.max = min(limits) if self.inverted_x_axis else max(limits)
[docs] def data(self, cls=None, statistic=None): return [layer_state.layer.get_object(statistic=statistic, cls=cls or self.default_class) for layer_state in self.state.layers if hasattr(layer_state, 'layer') and isinstance(layer_state.layer, BaseData)]
[docs] def set_plot_axes(self): self.figure.axes[0].label = "x: pixels" self.figure.axes[1].label = "y: pixels" self.figure.axes[1].tick_format = None self.figure.axes[1].label_location = "middle" # Sync with Spectrum1D viewer (that is also used by other viz). # Make it so y axis label is not covering tick numbers. self.figure.fig_margin["left"] = 95 self.figure.fig_margin["bottom"] = 60 self.figure.send_state('fig_margin') # Force update self.figure.axes[0].label_offset = "40" self.figure.axes[1].label_offset = "-70" # NOTE: with tick_style changed below, the default responsive ticks in bqplot result # in overlapping tick labels. For now we'll hardcode at 8, but this could be removed # (default to None) if/when bqplot auto ticks react to styling options. self.figure.axes[1].num_ticks = 8 for i in (0, 1): self.figure.axes[i].tick_style = {'font-size': 15, 'font-weight': 600}
[docs] @viewer_registry("mosviz-profile-viewer", label="Profile 1D") class MosvizProfileView(SpecvizProfileView): # categories: zoom resets, zoom, pan, subset, select tools, shortcuts tools_nested = [ ['mosviz:homezoom'], ['mosviz:boxzoom', 'mosviz:xrangezoom', 'jdaviz:yrangezoom'], # noqa ['mosviz:panzoom', 'mosviz:panzoom_x', 'jdaviz:panzoom_y'], # noqa ['bqplot:xrange'], ['jdaviz:selectline'], ['jdaviz:sidebar_plot', 'jdaviz:sidebar_export'] ]
[docs] def set_plot_axes(self): super().set_plot_axes() self.figure.axes[1].num_ticks = 5
[docs] @viewer_registry("mosviz-table-viewer", label="Table (Mosviz)") class MosvizTableViewer(TableViewer, JdavizViewerMixin): def __init__(self, session, *args, **kwargs): super().__init__(session, *args, **kwargs) self.figure_widget.observe(self._on_row_selected, names=['highlighted']) # enable scrolling: # https://github.com/glue-viz/glue-jupyter/pull/287 self.widget_table.scrollable = True self._selected_data = {} self._shared_image = False self.row_selection_in_progress = False self._on_row_selected_begin = None self._on_row_selected_end = None @property def _default_table_viewer_reference_name(self): return self.jdaviz_helper._default_table_viewer_reference_name @property def _default_spectrum_viewer_reference_name(self): return self.jdaviz_helper._default_spectrum_viewer_reference_name @property def _default_spectrum_2d_viewer_reference_name(self): return self.jdaviz_helper._default_spectrum_2d_viewer_reference_name @property def _default_image_viewer_reference_name(self): return self.jdaviz_helper._default_image_viewer_reference_name
[docs] def redraw(self): # Overload to hide components - we do this via overloading instead of # checking for changes in self.figure_widget.data because some components # might be added inplace to the dataset. if self.figure_widget.data is None: self.figure_widget.hidden_components = [] else: components_str = [cid.label for cid in self.figure_widget.data.main_components] hidden = [] for colname in ('Images', '1D Spectra', '2D Spectra'): if colname in components_str: hidden.append(self.figure_widget.data.id[colname]) self.figure_widget.hidden_components = hidden super().redraw()
@property def nrows(self): return self.widget_table.get_state()['total_length'] @property def current_row(self): return self.widget_table.highlighted
[docs] def select_row(self, n): if n < 0 or n >= self.nrows: raise ValueError("n must be between 0 and {}".format(self.nrows-1)) # compute and set the appropriate page # NOTE: traitlets won't detect internal changes to a dict options = self.widget_table.get_state()['options'] page = int(n / options['itemsPerPage']) + 1 if options['page'] != page: self.widget_table.set_state({'options': {**options, 'page': page}}) self.widget_table.send_state() # select and highlight the row self.widget_table.highlighted = n
[docs] def next_row(self): current_row = self.current_row new_row = 0 if current_row == self.nrows - 1 else current_row + 1 self.select_row(new_row)
[docs] def prev_row(self): current_row = self.current_row new_row = self.nrows - 1 if current_row == 0 else current_row - 1 self.select_row(new_row)
def _on_row_selected(self, event): if self._on_row_selected_begin: self._on_row_selected_begin(event) self.row_selection_in_progress = True # Grab the index of the latest selected row selected_index = event['new'] mos_data = self.session.data_collection['MOS Table'] # plugin data entries: select all in new row, deselect all others for data_item in self.jdaviz_app.data_collection: if data_item.meta.get('Plugin') is not None: if data_item.meta.get('mosviz_row') == selected_index: self.session.hub.broadcast(AddDataToViewerMessage( self._default_spectrum_viewer_reference_name, data_item.label, sender=self)) else: self.session.hub.broadcast(RemoveDataFromViewerMessage( self._default_spectrum_viewer_reference_name, data_item.label, sender=self)) for component in mos_data.components: comp_data = mos_data.get_component(component).data selected_data = comp_data[selected_index] if component.label == '1D Spectra': prev_data = self._selected_data.get(self._default_spectrum_viewer_reference_name) if prev_data != selected_data: if prev_data: # This covers the cases where data is unit converted # and the name is modified all_prev_data = [x for x in self.session.data_collection.labels if prev_data in x] for modified_prev_data in all_prev_data: if modified_prev_data: remove_data_from_viewer_message = RemoveDataFromViewerMessage( self._default_spectrum_viewer_reference_name, modified_prev_data, sender=self ) # reset the counter in the spectrum viewer's color cycler # so that the newly selected row is displayed in gray and # future additions will have other colors: spectrum_viewer = self.jdaviz_app.get_viewer( self._default_spectrum_viewer_reference_name ) spectrum_viewer.color_cycler.reset() self.session.hub.broadcast(remove_data_from_viewer_message) add_data_to_viewer_message = AddDataToViewerMessage( self._default_spectrum_viewer_reference_name, selected_data, sender=self ) self.session.hub.broadcast(add_data_to_viewer_message) self._selected_data[ self._default_spectrum_viewer_reference_name ] = selected_data if component.label == '2D Spectra': prev_data = self._selected_data.get(self._default_spectrum_2d_viewer_reference_name) if prev_data != selected_data: if prev_data: remove_data_from_viewer_message = RemoveDataFromViewerMessage( self._default_spectrum_2d_viewer_reference_name, prev_data, sender=self ) self.session.hub.broadcast(remove_data_from_viewer_message) add_data_to_viewer_message = AddDataToViewerMessage( self._default_spectrum_2d_viewer_reference_name, selected_data, sender=self ) self.session.hub.broadcast(add_data_to_viewer_message) self._selected_data[ self._default_spectrum_2d_viewer_reference_name ] = selected_data if component.label == 'Images': prev_data = self._selected_data.get(self._default_image_viewer_reference_name) if prev_data != selected_data: if prev_data: remove_data_from_viewer_message = RemoveDataFromViewerMessage( self._default_image_viewer_reference_name, prev_data, sender=self) self.session.hub.broadcast(remove_data_from_viewer_message) add_data_to_viewer_message = AddDataToViewerMessage( self._default_image_viewer_reference_name, selected_data, sender=self) self.session.hub.broadcast(add_data_to_viewer_message) self._selected_data[self._default_image_viewer_reference_name] = selected_data message = TableClickMessage(selected_index=selected_index, shared_image=self._shared_image, sender=self) self.session.hub.broadcast(message) self.row_selection_in_progress = False if self._on_row_selected_end: self._on_row_selected_end(event)
[docs] def set_plot_axes(self, *args, **kwargs): return