# Source code for rydopt.pulses.pulse_params

from __future__ import annotations

from collections.abc import Sequence
from typing import Any, Generic, Literal, TypeVar, cast, overload

import jax
import jax.numpy as jnp
import numpy as np
import numpy.typing as npt

ParamScalar = TypeVar("ParamScalar", float, bool)


@jax.tree_util.register_pytree_node_class
class PulseParams(Sequence[Any], Generic[ParamScalar]):
    r"""Pulse-parameter container.

    The container stores the pulse-parameter components
    ``(duration, detuning_params, phase_params, rabi_params)`` and behaves as
    a fixed-length sequence of these four components.

    * The duration is kept internally as a length-1 array; indexing position
      0 (with any integer-like index, positive or negative) yields its single
      element.
    * The remaining components are kept as flattened 1-D arrays.
    * ``np.asarray`` / ``jnp.asarray`` see the concatenation of all four
      components, in the order listed above.
    * The class is registered as a JAX pytree node whose children are the
      four stored components.
    """

    __slots__ = ("_detuning_params", "_duration", "_phase_params", "_rabi_params")

    def __init__(
        self,
        duration: ParamScalar,
        detuning_params: npt.ArrayLike = (),
        phase_params: npt.ArrayLike = (),
        rabi_params: npt.ArrayLike = (),
    ) -> None:
        """Store the components as NumPy arrays.

        Args:
            duration: Single value; stored as an array of shape ``(1,)``.
            detuning_params: Array-like, flattened to 1-D. Defaults to empty.
            phase_params: Array-like, flattened to 1-D. Defaults to empty.
            rabi_params: Array-like, flattened to 1-D. Defaults to empty.

        Raises:
            ValueError: If ``duration`` does not hold exactly one element
                (raised by the reshape to ``(1,)``).
        """
        self._duration = np.asarray(duration).reshape(1)
        self._detuning_params = np.asarray(detuning_params).reshape(-1)
        self._phase_params = np.asarray(phase_params).reshape(-1)
        self._rabi_params = np.asarray(rabi_params).reshape(-1)

    def _components(self) -> tuple[Any, Any, Any, Any]:
        """Return the four stored components in their canonical order."""
        return (
            self._duration,
            self._detuning_params,
            self._phase_params,
            self._rabi_params,
        )

    def __len__(self) -> int:
        """Return the number of parameter components."""
        return 4

    @overload
    def __getitem__(self, index: Literal[0]) -> ParamScalar: ...

    @overload
    def __getitem__(self, index: Literal[1, 2, 3]) -> npt.NDArray[Any]: ...

    @overload
    def __getitem__(self, index: int) -> ParamScalar | npt.NDArray[Any]: ...

    @overload
    def __getitem__(self, index: slice) -> tuple[npt.NDArray[Any], ...]: ...

    def __getitem__(self, index: int | slice) -> ParamScalar | npt.NDArray[Any] | tuple[npt.NDArray[Any], ...]:
        """Return one parameter component or a sliced tuple of components.

        Args:
            index: Integer-like position (negative values count from the
                end) or a slice.

        Returns:
            For a position addressing the duration (``0`` or ``-4``), the
            single duration element. For any other position, the stored
            array. For a slice, a tuple of the stored arrays; the duration
            then appears as its length-1 array.

        Raises:
            IndexError: If an integer position is out of range.
            TypeError: If ``index`` is neither integer-like nor a slice.
        """
        components = self._components()
        # Index the tuple first: this validates the index (raising the usual
        # IndexError/TypeError), so the comparison below only ever sees
        # integer-like values.
        selected = components[index]
        if isinstance(index, slice):
            return selected
        # The duration is unwrapped no matter how position 0 is spelled
        # (0, -4, numpy integers, ...), so all spellings agree.
        if index in (0, -len(components)):
            return self._duration[0]
        return selected

    def __array__(
        self,
        dtype: npt.DTypeLike | None = None,
        copy: bool | None = None,
    ) -> npt.NDArray[Any]:
        """Return the flattened representation used by ``np.asarray``.

        Args:
            dtype: Optional dtype for the result. When ``None``, the dtype
                is whatever NumPy promotion across the components yields.
            copy: Accepted for protocol compatibility. The concatenation
                always allocates a new array, so the result is a fresh copy
                regardless of this flag.

        Returns:
            A new 1-D array holding duration, detuning, phase and Rabi
            parameters, in that order.
        """
        # NOTE(review): omitted components are stored as empty float64
        # arrays, so a boolean container built with defaults promotes to
        # float64 here -- confirm that this is intended.
        del copy  # np.concatenate never returns a view of its inputs.
        flat = np.concatenate(self._components())
        if dtype is None:
            return flat
        # copy=False avoids a second allocation when the dtype already matches.
        return flat.astype(dtype, copy=False)

    def __jax_array__(self) -> jax.Array:
        """Return the flattened representation used by ``jnp.asarray``."""
        return jnp.concatenate(list(self._components()), axis=-1)

    def tree_flatten(self) -> tuple[tuple[Any, Any, Any, Any], None]:
        """Return a flattened representation for JAX tree utilities.

        Returns:
            The four stored components as children, and ``None`` as
            auxiliary data.
        """
        return self._components(), None

    @classmethod
    def tree_unflatten(cls, aux_data: None, children: tuple[Any, Any, Any, Any]) -> PulseParams[Any]:
        """Reconstruct a PulseParams instance from a flattened representation for JAX tree utilities.

        ``__init__`` is bypassed on purpose: the children are stored exactly
        as given, without conversion to NumPy arrays or reshaping.
        """
        del aux_data
        self = cast(PulseParams[Any], object.__new__(cls))
        self._duration, self._detuning_params, self._phase_params, self._rabi_params = children
        return self