# Source code for rydopt.pulses.pulse_ansatz

from __future__ import annotations

from dataclasses import dataclass

import jax
import jax.numpy as jnp
from numpy.typing import ArrayLike

from rydopt.pulses.general_pulse_ansatz_functions import const
from rydopt.types import PulseAnsatzFunction, PulseParams


def _const_zero(t: jax.Array | float, _duration: float, _ansatz_params: jax.Array) -> jax.Array:
    """Default ansatz that evaluates ``const`` with the fixed parameter ``[0.0]``.

    The supplied ``_ansatz_params`` are ignored; ``t`` and ``_duration`` are forwarded unchanged.
    """
    zero_param = jnp.array([0.0])
    return const(t, _duration, zero_param)


def _const_one(t: jax.Array | float, _duration: float, _ansatz_params: jax.Array) -> jax.Array:
    """Default ansatz that evaluates ``const`` with the fixed parameter ``[1.0]``.

    The supplied ``_ansatz_params`` are ignored; ``t`` and ``_duration`` are forwarded unchanged.
    """
    unit_param = jnp.array([1.0])
    return const(t, _duration, unit_param)


@dataclass
class PulseAnsatz:
    r"""Collection of ansatz functions for the laser pulse that drives the transition from the
    qubit state :math:`|1\rangle` to the Rydberg state :math:`|r\rangle`.

    The atom-light coupling is described in the rotating frame within the rotating wave
    approximation. The driven two-level ladder system :math:`|1\rangle \leftrightarrow |r\rangle`
    is then governed by

    .. math::

        H_\mathrm{drive}(t)=\begin{pmatrix}
            0 & \frac{\Omega(t)}{2} e^{-i\xi(t)} \\
            \frac{\Omega(t)}{2} e^{i\xi(t)} & -\Delta(t)
        \end{pmatrix}.

    The ansatz functions available for the detuning :math:`\Delta(t)`, the phase :math:`\xi(t)`,
    and the Rabi frequency :math:`\Omega(t)` are listed below. Their parameters, together with the
    pulse duration, can be tuned by :func:`optimize <rydopt.optimization.optimize>` to maximize
    the gate fidelity. Starting values are passed as :class:`PulseParams`, i.e., as a tuple
    ``(duration, detuning_params, phase_params, rabi_params)``.

    Example:
        >>> import rydopt as ro
        >>> pulse = ro.pulses.PulseAnsatz(
        ...     detuning_ansatz=ro.pulses.const,
        ...     phase_ansatz=ro.pulses.sin_crab,
        ... )

    Attributes:
        detuning_ansatz: Detuning sweep :math:`\Delta(t)`, default is zero.
        phase_ansatz: Phase sweep :math:`\xi(t)`, default is zero.
        rabi_ansatz: Rabi frequency amplitude sweep :math:`\Omega(t)`, default is one.
    """

    detuning_ansatz: PulseAnsatzFunction = _const_zero
    phase_ansatz: PulseAnsatzFunction = _const_zero
    rabi_ansatz: PulseAnsatzFunction = _const_one

    def evaluate_pulse_functions(
        self, t: jax.Array | float, params: PulseParams
    ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
        r"""Sample the detuning, phase, and Rabi sweeps at the times ``t`` for fixed parameters.

        Args:
            t: Time samples at which the functions are evaluated
            params: Pulse parameters ``(duration, detuning_params, phase_params, rabi_params)``

        Returns:
            Tuple ``(detuning_1, detuning_r, phase, rabi)``. For this single-photon ansatz
            ``detuning_1`` is identically zero (``jnp.zeros_like(t)``).
        """
        duration, raw_detuning, raw_phase, raw_rabi = params
        detuning_params, phase_params, rabi_params = map(jnp.asarray, (raw_detuning, raw_phase, raw_rabi))

        # The qubit state |1> carries no energy shift in the single-photon picture.
        detuning_1 = jnp.zeros_like(t)
        detuning_r = self.detuning_ansatz(t, duration, detuning_params)
        phase = self.phase_ansatz(t, duration, phase_params)
        rabi = self.rabi_ansatz(t, duration, rabi_params)
        return detuning_1, detuning_r, phase, rabi
@dataclass
class TwoPhotonPulseAnsatz:
    r"""Data class that stores an effective two-photon pulse ansatz that couples the qubit state
    :math:`|1\rangle` to the Rydberg state :math:`|r\rangle` via the intermediate state
    :math:`|e\rangle`.

    RydOpt models the atom-light interaction in the rotating frame, using the rotating wave
    approximation. The Hamiltonian of the driven three-level ladder system
    :math:`|1\rangle \leftrightarrow |e\rangle \leftrightarrow |r\rangle` is taken as

    .. math::

        H_\mathrm{3lvl}(t)=
        \begin{pmatrix}
        0 & \frac{\Omega_\ell(t)}{2}\,e^{-i\xi_\ell(t)} & 0 \\[6pt]
        \frac{\Omega_\ell(t)}{2}\,e^{i\xi_\ell(t)} & -\Delta_\ell(t) - i \frac{\gamma}{2}
            & \frac{\Omega_u(t)}{2}\,e^{-i\xi_u(t)} \\[6pt]
        0 & \frac{\Omega_u(t)}{2}\,e^{i\xi_u(t)} & -\Delta_\ell(t)-\Delta_u(t)
        \end{pmatrix},

    where the lower/upper laser couples :math:`|1\rangle \leftrightarrow |e\rangle` /
    :math:`|e\rangle \leftrightarrow |r\rangle` with Rabi frequency amplitudes
    :math:`\Omega_{\ell/u}(t)`, phases :math:`\xi_{\ell/u}(t)`, detunings
    :math:`\Delta_{\ell/u}(t)`. :math:`\gamma` is the decay rate of the intermediate state.

    The implementation is restricted to the adiabatic-elimination regime
    (:math:`|\Delta_\ell| \gg |\Omega_\ell|, |\Omega_u|, |\delta|` and
    :math:`|\Delta_\ell|^2 \gg |\dot{\Omega}_\ell|, |\dot{\Omega}_u|, |\dot{\delta}|` with
    :math:`\delta = \Delta_\ell+\Delta_u`), where the system can be treated by an effective
    two-level Hamiltonian on the subspace :math:`\{|1\rangle,|r\rangle\}`:

    .. math::

        H_\mathrm{drive}(t)=
        \begin{pmatrix}
        -\Delta_{1,\mathrm{eff}}(t) & \frac{\Omega_\mathrm{eff}(t)}{2} e^{-i\xi_\mathrm{eff}(t)} \\
        \frac{\Omega_\mathrm{eff}(t)}{2} e^{i\xi_\mathrm{eff}(t)} & -\Delta_{r,\mathrm{eff}}(t)
        \end{pmatrix}.

    The effective controls are computed as

    .. math::

        \Omega_\mathrm{eff}(t)&=\frac{\Omega_\ell(t)\Omega_u(t)}{2(\Delta_\ell(t)+i\gamma/2)}, \\
        \xi_\mathrm{eff}(t)&=\xi_\ell(t)+\xi_u(t), \\
        \Delta_{1,\mathrm{eff}}(t)&=- \frac{\Omega_\ell(t)^2}{4(\Delta_\ell(t)+i\gamma/2)}, \\
        \Delta_{r,\mathrm{eff}}(t)&=\Delta_\ell(t)+\Delta_u(t)
            - \frac{\Omega_u(t)^2}{4(\Delta_\ell(t)+i\gamma/2)}.

    Because of the :math:`i\gamma/2` term, :math:`\Omega_\mathrm{eff}`,
    :math:`\Delta_{1,\mathrm{eff}}`, and :math:`\Delta_{r,\mathrm{eff}}` are returned as complex
    arrays. This holds even for :math:`\gamma = 0`.

    For available ansatz functions for the detuning, phase, and Rabi frequency sweeps, see below.

    The function :func:`optimize <rydopt.optimization.optimize>` allows optimizing the parameters
    of the ansatz functions and duration of the laser pulse to maximize the gate fidelity. Initial
    parameters can be provided to the function as :class:`PulseParams`, i.e., as a tuple
    ``(duration, detuning_params, phase_params, rabi_params)``. Each parameter array within the
    tuple is packed as ``[*lower_transition_params, *upper_transition_params]``. The split
    positions are set by ``lower_param_counts=(n_detuning, n_phase, n_rabi)``.

    Example:
        >>> import rydopt as ro
        >>> lower = ro.pulses.PulseAnsatz(
        ...     detuning_ansatz=ro.pulses.const,
        ...     phase_ansatz=ro.pulses.sin_crab,
        ... )
        >>> upper = ro.pulses.PulseAnsatz(
        ...     detuning_ansatz=ro.pulses.const,
        ...     rabi_ansatz=ro.pulses.const,
        ... )
        >>> pulse = ro.pulses.TwoPhotonPulseAnsatz(
        ...     lower_transition=lower,
        ...     upper_transition=upper,
        ...     lower_param_counts=(1, 4, 0)
        ... )

    Attributes:
        lower_transition: Ansatz for the lower transition :math:`|1\rangle \leftrightarrow |e\rangle`.
        upper_transition: Ansatz for the upper transition :math:`|e\rangle \leftrightarrow |r\rangle`.
        lower_param_counts: Tuple ``(n_detuning, n_phase, n_rabi)`` specifying how many entries per
            parameter array belong to the lower transition.
        decay: Decay rate of the intermediate state, default is zero.
    """

    lower_transition: PulseAnsatz
    upper_transition: PulseAnsatz
    lower_param_counts: tuple[int, int, int]
    decay: float = 0.0

    @staticmethod
    def _split_1d(packed_params: ArrayLike, lower_count: int) -> tuple[jax.Array, jax.Array]:
        """Split a packed parameter array into its lower- and upper-transition parts.

        Args:
            packed_params: Array laid out as ``[*lower_params, *upper_params]``.
            lower_count: Number of leading entries that belong to the lower transition.

        Returns:
            Tuple ``(lower_params, upper_params)``.

        Raises:
            ValueError: If ``lower_count`` is negative or larger than the number of packed entries.
        """
        packed_params = jnp.asarray(packed_params)
        # Array shapes are static under jit/vmap, so this check does not depend on traced values.
        available = packed_params.shape[0] if packed_params.ndim > 0 else 0
        if not 0 <= lower_count <= available:
            # Without this check, slicing would silently hand the wrong entries to each transition.
            raise ValueError(
                f"invalid lower-transition parameter count {lower_count} for a packed "
                f"parameter array with {available} entries"
            )
        return packed_params[:lower_count], packed_params[lower_count:]

    def _unpack_transition_params(self, params: PulseParams) -> tuple[PulseParams, PulseParams]:
        """Split packed pulse parameters into separate lower- and upper-transition parameters.

        The pulse duration is shared by both transitions. Each ansatz parameter array is split
        according to ``self.lower_param_counts``.
        """
        duration, detuning_params, phase_params, rabi_params = params
        lower_detuning_count, lower_phase_count, lower_rabi_count = self.lower_param_counts

        lower_detuning_params, upper_detuning_params = self._split_1d(detuning_params, lower_detuning_count)
        lower_phase_params, upper_phase_params = self._split_1d(phase_params, lower_phase_count)
        lower_rabi_params, upper_rabi_params = self._split_1d(rabi_params, lower_rabi_count)

        lower_params: PulseParams = (duration, lower_detuning_params, lower_phase_params, lower_rabi_params)
        upper_params: PulseParams = (duration, upper_detuning_params, upper_phase_params, upper_rabi_params)
        return lower_params, upper_params

    def evaluate_pulse_functions(
        self, t: jax.Array | float, params: PulseParams
    ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
        r"""Evaluate the effective two-photon detuning, phase, and the rabi sweeps for fixed
        parameters at the given times.

        Args:
            t: Time samples at which the functions are evaluated
            params: Pulse parameters, packed as described in the class docstring

        Returns:
            Tuple ``(detuning_1, detuning_r, phase, rabi)``. ``detuning_1``, ``detuning_r`` and
            ``rabi`` are complex-valued because of the ``+ i*decay/2`` term in the denominator.
            ``phase`` is the sum of the lower and upper phases.

        Raises:
            ValueError: If ``lower_param_counts`` does not fit the packed parameter arrays.

        Note:
            The lower detuning appears in a denominator. With ``decay == 0`` and a vanishing lower
            detuning, the result is non-finite. This lies outside the adiabatic-elimination regime.
        """
        lower_params, upper_params = self._unpack_transition_params(params)
        _, lower_detuning, lower_phase, lower_rabi = self.lower_transition.evaluate_pulse_functions(t, lower_params)
        _, upper_detuning, upper_phase, upper_rabi = self.upper_transition.evaluate_pulse_functions(t, upper_params)

        # Complex energy denominator of the adiabatically eliminated intermediate state.
        denominator = lower_detuning + 0.5j * self.decay

        effective_rabi = lower_rabi * upper_rabi / (2.0 * denominator)
        effective_phase = lower_phase + upper_phase
        effective_detuning_1 = -(lower_rabi**2) / (4.0 * denominator)
        effective_detuning_r = (lower_detuning + upper_detuning) - upper_rabi**2 / (4.0 * denominator)
        return effective_detuning_1, effective_detuning_r, effective_phase, effective_rabi