# Source code for rydopt.pulses.pulse_ansatz

from __future__ import annotations

from dataclasses import dataclass, field

import jax
import jax.numpy as jnp
from numpy.typing import ArrayLike

from rydopt.pulses.ansatz_functions import PulseAnsatzFunction
from rydopt.types import ParamsFloatLike


class _FixedConstant(PulseAnsatzFunction):
    """Parameter-free ansatz that yields the same constant at every time sample.

    Instances of this class are the default detuning (0.0), phase (0.0) and
    Rabi (1.0) sweeps of ``PulseAnsatz``.
    """

    def __init__(self, value: float) -> None:
        # NOTE(review): assumes the base-class argument is the parameter count
        # (``num_params``); a fixed constant exposes nothing to optimize.
        super().__init__(0)
        self._value = value

    def __call__(self, t: float | jax.Array, duration: float | jax.Array, ansatz_params: jax.Array) -> jax.Array:
        # Neither the pulse duration nor any parameters influence a constant.
        del duration, ansatz_params
        # Adding the constant to a zero array shaped like ``t`` broadcasts it
        # to one value per requested time sample.
        return jnp.zeros_like(t) + self._value


@dataclass
class PulseAnsatz:
    r"""Container for the ansatz functions of the laser pulse that couples :math:`|1\rangle` to :math:`|r\rangle`.

    RydOpt describes the atom-light interaction in the rotating frame, within the
    rotating wave approximation. For the driven two-level ladder system
    :math:`|1\rangle \leftrightarrow |r\rangle` the Hamiltonian reads

    .. math::

        H_\mathrm{drive}(t)=\begin{pmatrix}
        0 & \frac{\Omega(t)}{2} e^{-i\xi(t)} \\
        \frac{\Omega(t)}{2} e^{i\xi(t)} & -\Delta(t)
        \end{pmatrix}.

    The available ansatz functions for the detuning :math:`\Delta(t)`, the phase
    :math:`\xi(t)` and the Rabi frequency :math:`\Omega(t)` are listed below.
    The function :func:`optimize <rydopt.optimization.optimize>` optimizes the
    ansatz parameters together with the pulse duration so as to maximize the gate
    fidelity. Initial parameters can be passed to it as
    ``PulseParams(duration, detuning_params, phase_params, rabi_params)``.

    Example:
        >>> import rydopt as ro
        >>> pulse = ro.pulses.PulseAnsatz(
        ...     detuning_ansatz=ro.pulses.Const(),
        ...     phase_ansatz=ro.pulses.SinCrab(2),
        ... )

    Attributes:
        detuning_ansatz: Detuning sweep :math:`\Delta(t)`, default is zero.
        phase_ansatz: Phase sweep :math:`\xi(t)`, default is zero.
        rabi_ansatz: Rabi frequency amplitude sweep :math:`\Omega(t)`, default is one.
    """

    detuning_ansatz: PulseAnsatzFunction = field(default_factory=lambda: _FixedConstant(0.0))
    phase_ansatz: PulseAnsatzFunction = field(default_factory=lambda: _FixedConstant(0.0))
    rabi_ansatz: PulseAnsatzFunction = field(default_factory=lambda: _FixedConstant(1.0))

    @property
    def param_counts(self) -> tuple[int, int, int]:
        """Number of parameters of the detuning, phase and Rabi ansatz functions, in that order."""
        return tuple(sweep.num_params for sweep in (self.detuning_ansatz, self.phase_ansatz, self.rabi_ansatz))

    def _unpack_params(self, params: ParamsFloatLike) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
        """Split packed parameters into the duration and the per-ansatz parameter arrays.

        The last axis of ``params`` is laid out as
        ``[duration, *detuning_params, *phase_params, *rabi_params]``; any leading
        axes are carried through unchanged.

        Raises:
            ValueError: If the last axis does not hold exactly
                ``1 + sum(param_counts)`` entries.
        """
        packed = jnp.asarray(params, dtype=jnp.float64)
        n_detuning, n_phase, n_rabi = self.param_counts
        n_expected = 1 + n_detuning + n_phase + n_rabi
        n_given = int(packed.shape[-1])
        if n_given != n_expected:
            raise ValueError(f"PulseAnsatz expects {n_expected} packed parameters, got {n_given}")

        # Offsets along the last axis where the phase and Rabi sections begin.
        phase_start = 1 + n_detuning
        rabi_start = phase_start + n_phase
        return (
            packed[..., 0],
            packed[..., 1:phase_start],
            packed[..., phase_start:rabi_start],
            packed[..., rabi_start:],
        )

    def evaluate_pulse_functions(
        self, t: float | jax.Array, params: ParamsFloatLike
    ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
        r"""Sample the detuning, phase and Rabi sweeps at the given times for fixed parameters.

        Args:
            t: Time samples at which the functions are evaluated
            params: Pulse parameters

        Returns:
            Tuple ``(detuning_1, detuning_r, phase, rabi)``. ``detuning_1`` is an
            array of zeros shaped like ``t``; the other three entries are the outputs
            of the detuning, phase and Rabi ansatz functions.
        """
        duration, *per_ansatz_params = self._unpack_params(params)
        sweeps = (self.detuning_ansatz, self.phase_ansatz, self.rabi_ansatz)
        detuning_r, phase, rabi = (
            sweep(t, duration, sweep_params) for sweep, sweep_params in zip(sweeps, per_ansatz_params)
        )
        # The |1><1| entry of the two-level drive Hamiltonian is zero.
        detuning_1 = jnp.zeros_like(t)
        return detuning_1, detuning_r, phase, rabi


@dataclass
class TwoPhotonPulseAnsatz:
    r"""Effective two-photon pulse ansatz coupling :math:`|1\rangle` to :math:`|r\rangle` via :math:`|e\rangle`.

    RydOpt treats the atom-light interaction in the rotating frame, within the
    rotating wave approximation. For the driven three-level ladder system
    :math:`|1\rangle \leftrightarrow |e\rangle \leftrightarrow |r\rangle` the
    Hamiltonian is taken to be

    .. math::

        H_\mathrm{3lvl}(t)=
        \begin{pmatrix}
        0 & \frac{\Omega_\ell(t)}{2}\,e^{-i\xi_\ell(t)} & 0 \\[6pt]
        \frac{\Omega_\ell(t)}{2}\,e^{i\xi_\ell(t)} & -\Delta_\ell(t) - i \frac{\gamma}{2} & \frac{\Omega_u(t)}{2}\,e^{-i\xi_u(t)} \\[6pt]
        0 & \frac{\Omega_u(t)}{2}\,e^{i\xi_u(t)} & -\Delta_\ell(t)-\Delta_u(t)
        \end{pmatrix},

    where the lower/upper laser drives :math:`|1\rangle \leftrightarrow |e\rangle` /
    :math:`|e\rangle \leftrightarrow |r\rangle` with Rabi frequency amplitudes
    :math:`\Omega_{\ell/u}(t)`, phases :math:`\xi_{\ell/u}(t)` and detunings
    :math:`\Delta_{\ell/u}(t)`. :math:`\gamma` is the decay rate of the
    intermediate state.

    Only the adiabatic-elimination regime is supported
    (:math:`|\Delta_\ell| \gg |\Omega_\ell|, |\Omega_u|, |\delta|` and
    :math:`|\Delta_\ell|^2 \gg |\dot{\Omega}_\ell|, |\dot{\Omega}_u|, |\dot{\delta}|`
    with :math:`\delta = \Delta_\ell+\Delta_u`). In that regime the dynamics reduce
    to an effective two-level Hamiltonian on :math:`\{|1\rangle,|r\rangle\}`:

    .. math::

        H_\mathrm{drive}(t)=
        \begin{pmatrix}
        -\Delta_{1,\mathrm{eff}}(t) & \frac{\Omega_\mathrm{eff}(t)}{2} e^{-i\xi_\mathrm{eff}(t)} \\
        \frac{\Omega_\mathrm{eff}(t)}{2} e^{i\xi_\mathrm{eff}(t)} & -\Delta_{r,\mathrm{eff}}(t)
        \end{pmatrix}.

    The effective controls are computed as

    .. math::

        \Omega_\mathrm{eff}(t)&=\frac{\Omega_\ell(t)\Omega_u(t)}{2(\Delta_\ell(t)+i\gamma/2)}, \\
        \xi_\mathrm{eff}(t)&=\xi_\ell(t)+\xi_u(t), \\
        \Delta_{1,\mathrm{eff}}(t)&=- \frac{\Omega_\ell(t)^2}{4(\Delta_\ell(t)+i\gamma/2)}, \\
        \Delta_{r,\mathrm{eff}}(t)&=\Delta_\ell(t)+\Delta_u(t)- \frac{\Omega_u(t)^2}{4(\Delta_\ell(t)+i\gamma/2)}.

    The available ansatz functions for the detuning, phase and Rabi frequency
    sweeps are listed below. The function
    :func:`optimize <rydopt.optimization.optimize>` optimizes the ansatz
    parameters together with the pulse duration so as to maximize the gate
    fidelity. Initial parameters can be passed to it as
    ``PulseParams(duration, detuning_params, phase_params, rabi_params)``.
    Each parameter array in that tuple is packed as
    ``[*lower_transition_params, *upper_transition_params]``. The split positions
    follow from the ansatz parameter counts of ``lower_transition``.

    Example:
        >>> import rydopt as ro
        >>> lower = ro.pulses.PulseAnsatz(
        ...     detuning_ansatz=ro.pulses.Const(),
        ...     phase_ansatz=ro.pulses.SinCrab(4),
        ... )
        >>> upper = ro.pulses.PulseAnsatz(
        ...     detuning_ansatz=ro.pulses.Const(),
        ...     rabi_ansatz=ro.pulses.Const(),
        ... )
        >>> pulse = ro.pulses.TwoPhotonPulseAnsatz(
        ...     lower_transition=lower,
        ...     upper_transition=upper,
        ... )

    Attributes:
        lower_transition: Ansatz for the lower transition :math:`|1\rangle \leftrightarrow |e\rangle`.
        upper_transition: Ansatz for the upper transition :math:`|e\rangle \leftrightarrow |r\rangle`.
        decay: Decay rate of the intermediate state, default is zero.
    """

    lower_transition: PulseAnsatz
    upper_transition: PulseAnsatz
    decay: float = 0.0

    @property
    def lower_param_counts(self) -> tuple[int, int, int]:
        """Detuning, phase and Rabi parameter counts of the lower transition."""
        return self.lower_transition.param_counts

    @property
    def upper_param_counts(self) -> tuple[int, int, int]:
        """Detuning, phase and Rabi parameter counts of the upper transition."""
        return self.upper_transition.param_counts

    def _unpack_params(self, params: ParamsFloatLike) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
        """Split packed parameters into the duration and the combined per-sweep parameter arrays.

        The last axis of ``params`` is laid out as
        ``[duration, *detuning_params, *phase_params, *rabi_params]``, where each
        section holds the lower-transition parameters followed by the
        upper-transition parameters. Leading axes are carried through unchanged.

        Raises:
            ValueError: If the last axis does not hold exactly one duration plus
                all lower- and upper-transition parameters.
        """
        packed = jnp.asarray(params, dtype=jnp.float64)
        # Each section length is the lower plus the upper count for that sweep.
        n_detuning, n_phase, n_rabi = (
            n_lower + n_upper for n_lower, n_upper in zip(self.lower_param_counts, self.upper_param_counts)
        )
        n_expected = 1 + n_detuning + n_phase + n_rabi
        n_given = int(packed.shape[-1])
        if n_given != n_expected:
            raise ValueError(f"TwoPhotonPulseAnsatz expects {n_expected} packed parameters, got {n_given}")

        phase_start = 1 + n_detuning
        rabi_start = phase_start + n_phase
        return (
            packed[..., 0],
            packed[..., 1:phase_start],
            packed[..., phase_start:rabi_start],
            packed[..., rabi_start:],
        )

    @staticmethod
    def _split_1d(packed_params: ArrayLike, lower_count: int) -> tuple[jax.Array, jax.Array]:
        """Split ``packed_params`` along its last axis into the lower- and upper-transition parts."""
        as_array = jnp.asarray(packed_params)
        lower_part = as_array[..., :lower_count]
        upper_part = as_array[..., lower_count:]
        return lower_part, upper_part

    def evaluate_pulse_functions(
        self, t: float | jax.Array, params: ParamsFloatLike
    ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
        r"""Sample the effective two-photon detuning, phase and Rabi sweeps at the given times.

        Both transitions' ansatz functions are evaluated with the shared pulse
        duration and combined into the effective two-level controls from the class
        docstring.

        Args:
            t: Time samples at which the functions are evaluated
            params: Pulse parameters

        Returns:
            Tuple ``(detuning_1, detuning_r, phase, rabi)``. The effective detunings
            and Rabi frequency are complex-valued, because the denominator always
            carries the term ``0.5j * decay``, which is complex-typed even when
            ``decay`` is zero. Nothing guards against a vanishing denominator
            (zero lower detuning together with zero decay).
        """
        duration, detuning_params, phase_params, rabi_params = self._unpack_params(params)
        n_lower_detuning, n_lower_phase, n_lower_rabi = self.lower_param_counts

        lower_detuning_params, upper_detuning_params = self._split_1d(detuning_params, n_lower_detuning)
        lower_phase_params, upper_phase_params = self._split_1d(phase_params, n_lower_phase)
        lower_rabi_params, upper_rabi_params = self._split_1d(rabi_params, n_lower_rabi)

        lower = self.lower_transition
        upper = self.upper_transition
        delta_l = lower.detuning_ansatz(t, duration, lower_detuning_params)
        xi_l = lower.phase_ansatz(t, duration, lower_phase_params)
        omega_l = lower.rabi_ansatz(t, duration, lower_rabi_params)
        delta_u = upper.detuning_ansatz(t, duration, upper_detuning_params)
        xi_u = upper.phase_ansatz(t, duration, upper_phase_params)
        omega_u = upper.rabi_ansatz(t, duration, upper_rabi_params)

        # Delta_l + i*gamma/2 appears in every effective quantity.
        complex_detuning = delta_l + 0.5j * self.decay

        rabi_eff = omega_l * omega_u / (2.0 * complex_detuning)
        phase_eff = xi_l + xi_u
        detuning_1_eff = -(omega_l**2) / (4.0 * complex_detuning)
        detuning_r_eff = (delta_l + delta_u) - omega_u**2 / (4.0 * complex_detuning)
        return detuning_1_eff, detuning_r_eff, phase_eff, rabi_eff