Source code for jdaviz.configs.imviz.wcs_utils

# This is adapted from Ginga (ginga.util.wcs, ginga.trcalc, and ginga.Bindings.ImageViewBindings).
# Please see the file licenses/GINGA_LICENSE.txt for details.
#
"""This module handles calculations based on world coordinate system (WCS)."""

import base64
import math
from io import BytesIO

import matplotlib.pyplot as plt
import numpy as np
from astropy import units as u
from astropy import coordinates as coord
from astropy.coordinates import SkyCoord
from astropy.modeling import models
from astropy.nddata import NDData
from astropy.wcs import WCS
from astropy.wcs.utils import proj_plane_pixel_scales

from gwcs import coordinate_frames as cf
from gwcs.wcs import WCS as GWCS

from matplotlib.patches import Polygon
from jdaviz.utils import _wcs_only_label

__all__ = ['get_compass_info', 'draw_compass_mpl']


def rotate_pt(x_arr, y_arr, theta_deg, xoff=0, yoff=0):
    """
    Rotate an array of points (x_arr, y_arr) by theta_deg offsetted
    from a center point by (xoff, yoff).
    """
    a_arr = x_arr - xoff
    b_arr = y_arr - yoff
    cos_t = np.cos(np.radians(theta_deg))
    sin_t = np.sin(np.radians(theta_deg))
    ap = (a_arr * cos_t) - (b_arr * sin_t)
    bp = (a_arr * sin_t) + (b_arr * cos_t)
    return np.asarray((ap + xoff, bp + yoff))


def add_offset_radec(ra_deg, dec_deg, delta_deg_ra, delta_deg_dec):
    """
    Algorithm to compute RA/Dec from RA/Dec base position plus tangent
    plane offsets.
    """
    # To radians
    x = math.radians(delta_deg_ra)
    y = math.radians(delta_deg_dec)
    raz = math.radians(ra_deg)
    decz = math.radians(dec_deg)

    sdecz = math.sin(decz)
    cdecz = math.cos(decz)

    d = cdecz - y * sdecz

    ra2 = math.atan2(x, d) + raz
    # Normalize ra into the range 0 to 2*pi
    twopi = math.pi * 2
    ra2 = math.fmod(ra2, twopi)
    if ra2 < 0.0:
        ra2 += twopi
    dec2 = math.atan2(sdecz + y * cdecz, math.sqrt(x * x + d * d))

    # back to degrees
    ra2_deg = math.degrees(ra2)
    dec2_deg = math.degrees(dec2)

    return (ra2_deg, dec2_deg)


def add_offset_xy(image_wcs, x, y, delta_deg_x, delta_deg_y):
    # calculate ra/dec of x,y pixel
    c = image_wcs.pixel_to_world(x, y)
    if isinstance(c, SkyCoord):
        ra_deg = c.ra.deg
        dec_deg = c.dec.deg
    else:  # list of Quantity (e.g., from FITS primary header)
        ra_deg = c[0].value
        dec_deg = c[1].value

    # add offsets
    ra2_deg, dec2_deg = add_offset_radec(ra_deg, dec_deg, delta_deg_x, delta_deg_y)

    # then back to new pixel coords
    return image_wcs.world_to_pixel_values(ra2_deg, dec2_deg)  # x2, y2


def calc_compass(image_wcs, x, y, len_deg_e, len_deg_n):

    # Get east and north coordinates
    xe, ye = list(map(float, add_offset_xy(image_wcs, x, y, len_deg_e, 0.0)))
    xn, yn = list(map(float, add_offset_xy(image_wcs, x, y, 0.0, len_deg_n)))

    return (x, y, xn, yn, xe, ye)


def calc_compass_radius(image_wcs, x, y, radius_px):
    xe, ye = add_offset_xy(image_wcs, x, y, 1.0, 0.0)
    xn, yn = add_offset_xy(image_wcs, x, y, 0.0, 1.0)

    # now calculate the length in pixels of those arcs
    # (planar geometry is good enough here)
    px_per_deg_e = math.sqrt(math.fabs(ye - y) ** 2 + math.fabs(xe - x) ** 2)
    px_per_deg_n = math.sqrt(math.fabs(yn - y) ** 2 + math.fabs(xn - x) ** 2)

    # now calculate the arm length in degrees for each arm
    # (this produces same-length arms)
    len_deg_e = radius_px / px_per_deg_e
    len_deg_n = radius_px / px_per_deg_n

    return calc_compass(image_wcs, x, y, len_deg_e, len_deg_n)


def calc_compass_center(image_wcs, image_shape, r_fac=0.5):
    # calculate center of data
    x = image_shape[1] * 0.5
    y = image_shape[0] * 0.5

    # radius we want the arms to be
    radius_px = min(image_shape) * r_fac

    return calc_compass_radius(image_wcs, x, y, radius_px)


[docs] def get_compass_info(image_wcs, image_shape, r_fac=0.4): """Calculate WCS compass parameters. North (N) is up and East (E) is left. Parameters ---------- image_wcs : obj WCS that is compatible with APE 14. image_shape : tuple of int Shape of the image in the form of ``(ny, nx)``. r_fac : float Scale factor for compass arrow length. Returns ------- x, y : float Pixel positions for the center of the compass. xn, yn : float Pixel positions for N of the compass. xe, ye : float Pixel positions for E of the compass. degn, dege : float Rotation angles for N and E, in degrees, for the compass, respectively. xflip : bool Should display flip on X? """ x, y, xn, yn, xe, ye = calc_compass_center(image_wcs, image_shape, r_fac=r_fac) degn = math.degrees(math.atan2(xn - x, yn - y)) # rotate east point also by degn xe2, ye2 = rotate_pt(xe, ye, degn, xoff=x, yoff=y) dege = math.degrees(math.atan2(xe2 - x, ye2 - y)) # if right-hand image, flip it to make left hand xflip = False if dege > 0.0: xflip = not xflip if xflip and not np.isclose(degn, 0): degn = -degn return x, y, xn, yn, xe, ye, degn, dege, xflip
[docs] def draw_compass_mpl(image, orig_shape=None, wcs=None, show=True, zoom_limits=None, **kwargs): """Visualize the compass using Matplotlib. Parameters ---------- image : ndarray 2D Numpy array (can be resampled). orig_shape : tuple of int or `None` The original (non-resampled) array shape in ``(ny, nx)``, if different. wcs : obj or `None` Associated original image WCS that is compatible with APE 14. If `None` given, compass is not drawn. show : bool Display the plot. zoom_limits : ndarray or None If not `None`, also draw a rectangle to represent the current zoom limits in the form of list of ``(x, y)`` representing the four corners of the zoom box. kwargs : dict Keywords for ``matplotlib.pyplot.imshow``. Returns ------- image_base64 : str Decoded buffer for Compass plugin. """ if orig_shape is None: orig_shape = image.shape if not show: plt.ioff() fig, ax = plt.subplots() ax.imshow(image, extent=[-0.5, orig_shape[1] - 0.5, -0.5, orig_shape[0] - 0.5], origin='lower', cmap='gray', **kwargs) if wcs is not None: try: x, y, xn, yn, xe, ye, degn, dege, xflip = get_compass_info(wcs, orig_shape) except Exception: wcs = None else: # TODO: Not sure what xflip really do, ask Eric Jeschke later. # if xflip: # plt.imshow(np.fliplr(image), origin='lower') # Positive here is counter-clockwise, hence the minus sign in comment. ax.plot(x, y, marker='o', color='cyan', markersize=5) ax.annotate('N', xy=(x, y), xytext=(xn, yn), arrowprops={'arrowstyle': '<-', 'color': 'cyan', 'lw': 1.5}, color='cyan', fontsize=16, va='center', ha='center') # rotation=-degn ax.annotate('E', xy=(x, y), xytext=(xe, ye), arrowprops={'arrowstyle': '<-', 'color': 'cyan', 'lw': 1.5}, color='cyan', fontsize=16, va='center', ha='center') # rotation=-dege if wcs is None: x = orig_shape[1] * 0.5 y = orig_shape[0] * 0.5 ax.plot(x, y, marker='o', color='yellow', markersize=5) # Also draw X/Y compass. r_xy = float(min(orig_shape)) * 0.25 ax.annotate('X', xy=(x, y), xytext=(x + r_xy, y), arrowprops={'arrowstyle': '<-', 'color': 'yellow', 'lw': 1.5}, color='yellow', fontsize=16, va='center', ha='center') ax.annotate('Y', xy=(x, y), xytext=(x, y + r_xy), arrowprops={'arrowstyle': '<-', 'color': 'yellow', 'lw': 1.5}, color='yellow', fontsize=16, va='center', ha='center') if zoom_limits is not None: ax.add_patch(Polygon( zoom_limits, closed=True, linewidth=1.5, edgecolor='r', facecolor='none')) if show: plt.draw() plt.show() buff = BytesIO() plt.savefig(buff) plt.style.use('default') plt.close() return base64.b64encode(buff.getvalue()).decode('utf-8')
def data_outside_gwcs_bounding_box(data, x, y): """This is for internal use by Imviz coordinates transformation only.""" outside_bounding_box = False if hasattr(data.coords, '_orig_bounding_box'): # then coords is a GWCS object and had its bounding box cleared # by the Imviz parser ints = data.coords._orig_bounding_box.intervals if isinstance(ints[0].lower, u.Quantity): bb_xmin = ints[0].lower.value bb_xmax = ints[0].upper.value bb_ymin = ints[1].lower.value bb_ymax = ints[1].upper.value else: # pragma: no cover bb_xmin = ints[0].lower bb_xmax = ints[0].upper bb_ymin = ints[1].lower bb_ymax = ints[1].upper if not (bb_xmin <= x <= bb_xmax and bb_ymin <= y <= bb_ymax): outside_bounding_box = True # Has to be Python bool, not Numpy bool_ return outside_bounding_box def _rotated_gwcs( center_world_coord, rotation_angle, pixel_scales, cdelt_signs, refdata_shape=(10, 10), image_shape=None ): # based on ``gwcs_simple_imaging_units`` in gwcs: # https://github.com/spacetelescope/gwcs/blob/ # eec9a2b6de8356495f405de3dc6531538589ce5d/gwcs/tests/conftest.py#L165 image_extent = u.Quantity(image_shape, u.pix) * u.Quantity(pixel_scales) refdata_extent = image_extent.max() pixel_scales = refdata_extent / u.Quantity(refdata_shape, u.pix) # multiplying by +/-1 can flip north/south or east/west: flip_direction = ( models.Multiply(cdelt_signs[0]) & models.Multiply(cdelt_signs[1]) ) # shift to compensate for the difference between the center and corner: shift = ( models.Shift((0.5 - refdata_shape[0])/2 * u.pix) & models.Shift((0.5 - refdata_shape[1])/2 * u.pix) ) # rotate field of view: rho = rotation_angle sin_rho = np.sin(rho.to_value(u.rad)) cos_rho = np.cos(rho.to_value(u.rad)) rotation_matrix = np.array([[cos_rho, -sin_rho], [sin_rho, cos_rho]]) rotation = models.AffineTransformation2D( rotation_matrix * u.deg, translation=[0, 0] * u.deg ) rotation.input_units_equivalencies = { "x": u.pixel_scale(pixel_scales[0]), "y": u.pixel_scale(pixel_scales[1]) } rotation.inverse = models.AffineTransformation2D( np.linalg.inv(rotation_matrix) * u.pix, translation=[0, 0] * u.pix ) rotation.inverse.input_units_equivalencies = { "x": u.pixel_scale(1 / pixel_scales[0]), "y": u.pixel_scale(1 / pixel_scales[1]) } tan = models.Pix2Sky_TAN() celestial_rotation = models.RotateNative2Celestial( center_world_coord.ra, center_world_coord.dec, 180 * u.deg ) det2sky = shift | flip_direction | rotation | tan | celestial_rotation det2sky.name = "linear_transform" detector_frame = cf.Frame2D( name="detector", axes_names=("x", "y"), unit=(u.pix, u.pix) ) sky_frame = cf.CelestialFrame( reference_frame=coord.ICRS(), name='icrs', unit=(u.deg, u.deg) ) pipeline = [ (detector_frame, det2sky), (sky_frame, None) ] return GWCS(pipeline) def _prepare_rotated_nddata(real_image_shape, wcs, rotation_angle, refdata_shape, wcs_only_key="_WCS_ONLY", data=None, cdelt_signs=None): cdelt = None # compute the x/y pixel scales from the WCS: if hasattr(wcs, 'pixel_scale_matrix'): pixel_scales = u.Quantity([ value * (unit / u.pix) for value, unit in zip( proj_plane_pixel_scales(wcs), wcs.wcs.cunit ) ]) if getattr(wcs.wcs, 'cd', None) is not None: cdelt = np.diag(wcs.wcs.cd) else: cdelt = wcs.wcs.cdelt elif data.meta.get(wcs_only_key, False): # WCS-only layers have pixel scales in meta: pixel_scales = u.Quantity(data.meta['_pixel_scales']) elif 'wcsinfo' in data.meta and 'wcs' in data.meta and 'ra_ref' in data.meta['wcsinfo']: # GWCS doesn't yet have a pixel scale attr, so approximate # its behavior using the pixel scale method from jwst: pixel_scales = (2 * [compute_scale( data.meta['wcs'], (data.meta['wcsinfo']['ra_ref'], data.meta['wcsinfo']['dec_ref']), 1 )]) * u.deg / u.pix else: # fall back on CRVAL cards if they're available wcsinfo = ( data.meta.get('_primary_header', None) or data.meta.get('wcsinfo', None) or data.meta.get('wcs', None) ) if wcsinfo is not None and not isinstance(wcsinfo, GWCS): crval1 = float(wcsinfo.get('CRVAL1', wcsinfo.get('crval1'))) crval2 = float(wcsinfo.get('CRVAL2', wcsinfo.get('crval2'))) cdelt = [ float(wcsinfo.get('CDELT1', wcsinfo.get('cdelt1'))), float(wcsinfo.get('CDELT2', wcsinfo.get('cdelt2'))) ] unit = u.Unit(wcsinfo.get('CUNIT1', wcsinfo.get('cunit1'))) fiducial = [crval1, crval2] * unit pixel_scales = (2 * [compute_scale( WCS(data.meta['_primary_header']) if 'wcs' not in data.meta else data.meta['wcs'], fiducial, None, 1 )]) * u.deg / u.pix else: # fall back on simple approximation: compare_pixel_coords = [[0, 0], [0, 1]] * u.pix compare_sky_coords = data.coords.pixel_to_world(*compare_pixel_coords) separation = compare_sky_coords[0].separation(compare_sky_coords[1]) pixel_scales = u.Quantity([separation, separation]) / u.pix # flip e.g. RA or Dec axes? if cdelt_signs is None and cdelt is not None: cdelt_signs = np.sign(cdelt) # get the world coordinates of the pixel origin center_pixel_coord = np.array(real_image_shape) / 2 * u.pix center_world_coord = wcs.pixel_to_world(*center_pixel_coord[::-1]) rotation_angle = coord.Angle(rotation_angle).wrap_at(360 * u.deg) # create a GWCS centered on ``filename``, # and rotated by ``rotation_angle``: new_rotated_gwcs = _rotated_gwcs( center_world_coord, rotation_angle, pixel_scales, cdelt_signs, refdata_shape=refdata_shape, image_shape=real_image_shape ) # create a fake NDData (we use arange so data boundaries show up in Imviz # if it ever is accidentally exposed) with the rotated GWCS: placeholder_data = np.nan * np.ones(refdata_shape) ndd = NDData( data=placeholder_data, wcs=new_rotated_gwcs, meta={wcs_only_key: True, '_pixel_scales': pixel_scales} ) return ndd def _get_rotated_nddata_from_label( app, data_label, rotation_angle, refdata_shape=(10, 10), cdelt_signs=None, target_wcs_east_left=True, target_wcs_north_up=True ): """ Create a synthetic NDData which stores GWCS that approximate the WCS in the coords attr of the Data object with label ``data_label`` loaded into ``app``. This method is useful for rotating pre-loaded datasets when combined with ``app._change_reference_data(data_label)``. Parameters ---------- app : `~jdaviz.Application` App instance containing ``data_label``. data_label : str Data label for the Data to rotate. rotation_angle : `~astropy.units.Quantity` Angle to rotate the image counter-clockwise from its original orientation. refdata_shape : tuple Shape of the reference data array. Returns ------- ndd : `~astropy.nddata.NDData` Contains rotated WCS and meaningless data. Raises ------ ValueError Data has no WCS. """ data = app.data_collection[data_label] if data.coords is None: raise ValueError(f"{data_label} has no WCS for rotation.") # transform WCS relative to the first loaded data entry: wcs = data.coords degn, dege, flip = get_compass_info(data.coords, data.shape)[-3:] has_east_left = flip has_north_up = True # assumed if isinstance(wcs, GWCS): lat_axis = wcs.world_axis_names.index("lat") lon_axis = wcs.world_axis_names.index("lon") else: # FITS WCS lat_axis = wcs.wcs.lat lon_axis = wcs.wcs.lng if ( not has_east_left and target_wcs_east_left and 'imviz-compass' in [item['name'] for item in app.state.tray_items] ): # if an east/west flip is necessary, pass that along to the compass: compass_plugin = app.get_tray_item_from_name('imviz-compass') compass_plugin.canvas_flip_horizontal = not compass_plugin.canvas_flip_horizontal if cdelt_signs is None: cdelt_signs = [None, None] cdelt_signs[lon_axis] = ( 1 if ((has_east_left and target_wcs_east_left) or (not has_east_left and not target_wcs_east_left)) else -1 ) cdelt_signs[lat_axis] = ( 1 if ((has_north_up and target_wcs_north_up) or (not has_north_up and not target_wcs_north_up)) else -1 ) else: if has_east_left != target_wcs_east_left: cdelt_signs[lon_axis] = -1 if has_north_up != target_wcs_north_up: cdelt_signs[lat_axis] = -1 return _prepare_rotated_nddata( data.shape, data.coords, rotation_angle, refdata_shape, wcs_only_key=_wcs_only_label, data=data, cdelt_signs=cdelt_signs ) # This method comes from the jwst package: # https://github.com/spacetelescope/jwst/blob/95467186aca9784ece9451b33d437d80d550a795/jwst/assign_wcs/util.py#L103 def compute_scale(wcs, fiducial, disp_axis, pscale_ratio=1): """Compute scaling transform. Parameters ---------- wcs : `~gwcs.wcs.WCS` Reference WCS object from which to compute a scaling factor. fiducial : tuple Input fiducial of (RA, DEC) or (RA, DEC, Wavelength) used in calculating reference points. disp_axis : int Dispersion axis integer. Assumes the same convention as `wcsinfo.dispersion_direction` pscale_ratio : int Ratio of input to output pixel scale Returns ------- scale : float Scaling factor for x and y or cross-dispersion direction. """ spectral = 'SPECTRAL' in wcs.output_frame.axes_type if spectral and disp_axis is None: # pragma: no cover raise ValueError('If input WCS is spectral, a disp_axis must be given') crpix = np.array(wcs.invert(*fiducial)) delta = np.zeros_like(crpix) spatial_idx = np.where(np.array(wcs.output_frame.axes_type) == 'SPATIAL')[0] delta[spatial_idx[0]] = 1 crpix_with_offsets = np.vstack((crpix, crpix + delta, crpix + np.roll(delta, 1))).T crval_with_offsets = wcs(*crpix_with_offsets, with_bounding_box=False) coords = SkyCoord( ra=crval_with_offsets[spatial_idx[0]], dec=crval_with_offsets[spatial_idx[1]], unit="deg" ) xscale = np.abs(coords[0].separation(coords[1]).value) yscale = np.abs(coords[0].separation(coords[2]).value) if pscale_ratio is not None: xscale *= pscale_ratio yscale *= pscale_ratio if spectral: # pragma: no cover # Assuming scale doesn't change with wavelength # Assuming disp_axis is consistent with DataModel.meta.wcsinfo.dispersion.direction return yscale if disp_axis == 1 else xscale return np.sqrt(xscale * yscale)