Source code for gwcs.coordinate_frames._base

from __future__ import annotations

import warnings
from abc import abstractmethod
from collections.abc import Callable
from itertools import zip_longest
from numbers import Number
from typing import Any, NamedTuple, Protocol, Self, TypeAlias, runtime_checkable

import numpy as np
from astropy import units as u
from astropy.coordinates import BaseCoordinateFrame as _AstropyBaseCoordinateFrame
from astropy.time import Time, TimeDelta
from astropy.wcs.wcsapi.high_level_api import (
    high_level_objects_to_values,
    values_to_high_level_objects,
)
from numpy import typing as npt

from ._axis import AxesType

# Public names exported by this module.
# NOTE(review): ``WorldAxisObjectClasses`` (the alias defined further down in this
# module) is not listed here -- confirm whether that omission is intentional.
__all__ = [
    "AstropyBuiltInFrame",
    "BaseCoordinateFrame",
    "CoordinateFrameProtocol",
    "LowLevelArray",
    "LowLevelInput",
    "WorldAxisObjectClass",
    "WorldAxisObjectClassConverter",
    "WorldAxisObjectComponent",
]

# Astropy objects a frame may report as its ``reference_frame``: either a ``Time``
# or an astropy coordinate frame.
AstropyBuiltInFrame: TypeAlias = Time | _AstropyBaseCoordinateFrame
# A plain numpy array of any numpy scalar type (no units attached).
LowLevelArray: TypeAlias = npt.NDArray[np.generic]
# A low-level value as accepted by ``add_units`` / ``remove_units``: either a plain
# array or an astropy ``Quantity``.
LowLevelInput: TypeAlias = LowLevelArray | u.Quantity


class WorldAxisObjectClass(NamedTuple):
    """
    A tuple to document the individual elements of the ``world_axis_object_classes``
    of the WCS API.

    Notes
    -----
    - ``world_axis_object_classes`` will return a dictionary with the key being
        the name from ``world_axis_object_components`` and the value being an
        instance of this class.

    - To stay consistent with the APE 14 API, users should not access the elements
        of this tuple via their names, but instead should access them via their
        position in the tuple.

    Attributes
    ----------
    class_object : type | str
        The High-Level Object class for the axis or a string that is the fully
        qualified name of the class.
    arguments : tuple
        The positional arguments to be passed to the class when instantiating an
        object of this class. Note if ``world_axis_object_components`` specifies that
        the world coordinates should be passed as a positional argument, then this
        tuple will include `None` as a place holder for each of the world coordinates.
    keyword_arguments : dict
        The keyword arguments to be passed to the class when instantiating an object of
        this class.
    """

    class_object: type | str
    arguments: tuple[Any, ...]
    keyword_arguments: dict[str, Any]


class WorldAxisObjectClassConverter(NamedTuple):
    """
    Same as the `WorldAxisObjectClass` but with an additional converter field.

    Attributes
    ----------
    converter : Callable[..., Any]
        A callable that will convert the input values into the desired output
    """

    class_object: type | str
    arguments: tuple[Any, ...]
    keyword_arguments: dict[str, Any]
    converter: Callable[..., Any]


# Return type of ``world_axis_object_classes``: a mapping from object name to its
# class description. The values may all be plain entries, all be entries carrying a
# converter, or a mix of the two.
# NOTE(review): the three alternatives are presumably spelled out separately because
# ``dict`` is invariant in its value type for static type checkers -- confirm.
WorldAxisObjectClasses: TypeAlias = (
    dict[str, WorldAxisObjectClass]
    | dict[str, WorldAxisObjectClassConverter]
    | dict[str, WorldAxisObjectClass | WorldAxisObjectClassConverter]
)


[docs] class WorldAxisObjectComponent(NamedTuple): """ A tuple to document the individual elements of the ``world_axis_object_components`` of the WCS API. Notes ----- - ``world_axis_object_components`` will return a list of tuples with each tuple being an instance of this class. - To stay consistent with the APE 14 API, users should not access the elements of this tuple via their names, but instead should access them via their position in the tuple. Attributes ---------- name : str Name for the world object this world array corresponds to, which *must* match the string names used in ``world_axis_object_classes``. Note that names might appear twice because two world arrays might correspond to a single world object (e.g. a celestial coordinate might have both “ra” and “dec” arrays, which correspond to a single sky coordinate object. position : str | int This is either a string keyword argument name or a positional index for the corresponding class from ``world_axis_object_classes``. property: str | Callable[[Any], str] This is a string giving the name of the property to access on the corresponding class from ``world_axis_object_classes`` in order to get numerical values. """ name: str position: str | int property: str | Callable[[Any], str]
[docs] @classmethod def from_tuple(cls, tup: tuple[str, str | int, str]) -> Self: return cls(*tup)
@runtime_checkable
class CoordinateFrameProtocol(Protocol):
    """
    The interface every coordinate frame has to provide.

    Implementations supply the abstract properties below. The unit handling
    and high-level conversion helpers are implemented here in terms of those
    properties.
    """

    @property
    @abstractmethod
    def naxes(self) -> int:
        """
        How many axes this frame describes.
        """

    @property
    @abstractmethod
    def name(self) -> str:
        """
        Name given to this coordinate frame.
        """

    @property
    @abstractmethod
    def unit(self) -> tuple[u.Unit | None, ...]:
        """
        Unit of each axis of the frame; `None` marks an axis without a unit.
        """

    @property
    @abstractmethod
    def axes_names(self) -> tuple[str, ...]:
        """
        Descriptive name of each axis of the frame.
        """

    @property
    @abstractmethod
    def axes_order(self) -> tuple[int, ...]:
        """
        Where each axis of the frame sits in the transform.
        """

    @property
    @abstractmethod
    def reference_frame(self) -> AstropyBuiltInFrame | None:
        """
        Reference frame of the coordinates this frame describes.

        Usually an Astropy object such as ``SkyCoord`` or ``Time``.
        """

    @property
    @abstractmethod
    def axes_type(self) -> AxesType:
        """
        Upper-case string (or tuple of strings) giving the type of the axis.

        The known values are listed in AxisType, but a custom value may be
        used as well.
        """

    @property
    @abstractmethod
    def axis_physical_types(self) -> tuple[str | None, ...]:
        """
        UCD 1+ physical type of each axis, given in frame order.
        """

    @property
    @abstractmethod
    def world_axis_object_classes(self) -> WorldAxisObjectClasses:
        """
        APE 14 object classes of this frame.

        See Also
        --------
        astropy.wcs.wcsapi.BaseLowLevelWCS.world_axis_object_classes
        """

    @property
    @abstractmethod
    def world_axis_object_components(self) -> list[WorldAxisObjectComponent]:
        """
        APE 14 object components of this frame.

        See Also
        --------
        astropy.wcs.wcsapi.BaseLowLevelWCS.world_axis_object_components
        """

    def add_units(
        self, arrays: tuple[LowLevelInput, ...] | LowLevelInput
    ) -> tuple[LowLevelInput, ...]:
        """
        Attach the frame's units to the given arrays.

        Parameters
        ----------
        arrays
            One value per axis. For a one-axis frame a bare scalar or a single
            ``Quantity`` is accepted in place of a length-1 tuple.

        Returns
        -------
        tuple
            One entry per input. Entries whose axis has no unit, and entries
            that are `None`, are passed through untouched.
        """
        # A one-axis frame may be handed its single value directly instead of
        # wrapped in a tuple of length 1.
        # NOTE(review): a bare (non-Quantity) ndarray on a one-axis frame does not
        # take this branch and is iterated element by element below -- confirm
        # that callers always wrap such an array in a tuple.
        if self.naxes == 1 and (np.isscalar(arrays) or isinstance(arrays, u.Quantity)):
            axis_unit = self.unit[0]
            if axis_unit is None:
                return (arrays,)
            return (u.Quantity(arrays, axis_unit),)

        with_units = []
        # zip_longest (rather than zip) keeps "non-coordinate" inputs/outputs,
        # which are assumed to be appended after the coordinate ones; they are
        # paired with a unit of None and therefore passed through.
        for array, unit in zip_longest(arrays, self.unit):
            if unit is None or array is None:
                # Nothing to attach, or nothing to attach it to.
                with_units.append(array)
            elif isinstance(array, TimeDelta):
                with_units.append(array.to(unit))
            else:
                with_units.append(u.Quantity(array, unit=unit))
        return tuple(with_units)

    def remove_units(
        self, arrays: tuple[LowLevelInput, ...] | LowLevelInput
    ) -> tuple[LowLevelArray, ...]:
        """
        Strip the units from the given arrays.

        Parameters
        ----------
        arrays
            The same kind of input that `add_units` accepts.

        Returns
        -------
        tuple
            One entry per input; every ``Quantity`` is replaced by its
            ``.value``, anything else is returned as it is.
        """
        # Run everything through add_units first:
        #   * a Quantity gets converted to the unit of the frame;
        #   * a bare value on an axis that has a unit is taken to be the right
        #     magnitude that merely lacks its unit, so it becomes a Quantity in
        #     the frame's unit;
        #   * a value on an axis without a unit is passed along unchanged.
        # After that, taking .value yields the magnitude in the frame's unit.
        in_frame_units = self.add_units(arrays)
        return tuple(
            item if not isinstance(item, u.Quantity) else item.value
            for item in in_frame_units
        )

    def to_high_level_coordinates(self, *values):
        """
        Turn "values" into the high level coordinate objects of this frame.

        "values" are coordinates given as arrays or scalars; high level objects
        are things such as ``SkyCoord`` or ``Quantity``.

        See :ref:`wcsapi` for details.

        Parameters
        ----------
        values : `numbers.Number`, `numpy.ndarray`, or `~astropy.units.Quantity`
            ``naxis`` number of coordinates as scalars or arrays.

        Returns
        -------
        high_level_coordinates
            One (or more) high level object describing the coordinate.

        Raises
        ------
        TypeError
            If, once units are removed, a value is neither a number nor exactly
            a `numpy.ndarray`.
        """
        # Quantity-like input is allowed here even though
        # values_to_high_level_objects does not accept it, hence the stripping.
        unitless = self.remove_units(values)

        if any(
            not isinstance(v, Number) and type(v) is not np.ndarray for v in unitless
        ):
            msg = "All values should be a scalar number or a numpy array."
            raise TypeError(msg)

        objects = values_to_high_level_objects(*unitless, low_level_wcs=self)
        # A single object is handed back on its own rather than in a sequence.
        return objects[0] if len(objects) == 1 else objects

    def from_high_level_coordinates(self, *high_level_coords):
        """
        Turn high level coordinate objects into the "values" of this frame.

        "values" are coordinates given as arrays or scalars; high level objects
        are things such as ``SkyCoord`` or ``Quantity``.

        See :ref:`wcsapi` for details.

        Parameters
        ----------
        high_level_coords
            One (or more) high level object describing the coordinate.

        Returns
        -------
        values : `numbers.Number` or `numpy.ndarray`
            ``naxis`` number of coordinates as scalars or arrays.
        """
        low_level = high_level_objects_to_values(*high_level_coords, low_level_wcs=self)
        # A single value is handed back on its own rather than in a sequence.
        return low_level[0] if len(low_level) == 1 else low_level
class BaseCoordinateFrame(CoordinateFrameProtocol):
    """
    Legacy base class for coordinate frames.

    Deprecated: defining a subclass emits a `DeprecationWarning`. Implement or
    inherit from ``CoordinateFrameProtocol`` instead.

    Subclasses provide ``_native_world_axis_object_components``; this class
    derives ``world_axis_object_components`` from it using ``axes_order``.
    """

    def __init_subclass__(cls, *args, **kwargs):
        """
        Warn that the class being defined derives from a deprecated base.
        """
        # Each fragment must end with a space: adjacent string literals are
        # concatenated without any separator.
        msg = (
            "BaseCoordinateFrame has been deprecated and will be removed in a "
            "future release. Please implement or inherit from CoordinateFrameProtocol "
            "instead."
        )
        warnings.warn(msg, DeprecationWarning, stacklevel=2)
        super().__init_subclass__(*args, **kwargs)

    @property
    def world_axis_object_components(self):
        """
        The APE 14 object components for this frame.

        For a one-axis frame the native components are returned as they are.
        Otherwise a new list of plain tuples is returned, reordered with
        ``np.argsort(self.axes_order)`` so that the component whose
        ``axes_order`` entry is smallest comes first.

        See Also
        --------
        astropy.wcs.wcsapi.BaseLowLevelWCS.world_axis_object_components
        """
        if self.naxes == 1:
            return self._native_world_axis_object_components

        # If we have more than one axis then we should sort the native
        # components by the axes_order.
        ordered = np.array(self._native_world_axis_object_components, dtype=object)[
            np.argsort(self.axes_order)
        ]
        # Indexing produced rows of an object array; hand back plain tuples.
        return list(map(tuple, ordered))

    @property
    @abstractmethod
    def _native_world_axis_object_components(self):
        """
        This property holds the "native" frame order of the components.

        The native order of the components is the order the frame assumes the
        input arrays are in. This is not necessarily the same as the order of
        the axes in the frame, which is given by axes_order.
        """