from __future__ import annotations
from functools import partial
import jax
import jax.numpy as jnp
from rydopt.protocols import PulseAnsatzLike, RydbergSystem
from rydopt.types import HamiltonianFunction, ParamsFloatLike
def rydberg_time(gate: RydbergSystem, pulse: PulseAnsatzLike, params: ParamsFloatLike, tol: float = 1e-7) -> jax.Array:
    r"""Compute the total time spent in Rydberg states during a gate pulse.

    The quantity evaluated is

    .. math::

        \Omega_0 T_R = \Omega_0 \int_0^T \sum_{i=1}^{N} \bra{+}^{\otimes N}U(t)^{\dagger}
        |r_i\rangle\!\langle r_i| U(t)\ket{+}^{\otimes N} dt .

    Each initial basis state of the gate is propagated from ``t = 0`` to
    ``t = params[0]``. The ODE state is augmented with a second component whose
    derivative is the instantaneous expectation value of the Rydberg population
    operator, so the solver accumulates the time integral alongside the state.
    The per-basis-state results are then handed to ``gate.rydberg_time``.

    Example:
        >>> import rydopt as ro
        >>> import numpy as np
        >>> gate = ro.gates.TwoQubitGate(
        ...     phi=None,
        ...     theta=np.pi,
        ...     Vnn=float("inf"),
        ...     decay=0,
        ... )
        >>> pulse = ro.pulses.PulseAnsatz(
        ...     detuning_ansatz=ro.pulses.Const(),
        ...     phase_ansatz=ro.pulses.SinCrab(2),
        ... )
        >>> params = ro.pulses.PulseParams(7.61140652, [0.07842706], [1.80300902, -0.61792703], [])
        >>> time_in_rydberg_state = ro.simulation.rydberg_time(gate, pulse, params)

    Args:
        gate: RydOpt Gate object.
        pulse: RydOpt PulseAnsatz object.
        params: Pulse parameters; ``params[0]`` is used as the final integration time.
        tol: Precision of the ODE solver, default is 1e-7. The step-size
            controller is run with ``rtol = atol = 0.1 * tol``.

    Returns:
        Total Rydberg time :math:`\Omega_0 T_R`.

    """
    # Importing diffrax allocates at least one jnp array (see optimistix/_misc.py, line 138).
    # If the default device is changed after that import, some memory ends up on the wrong
    # device, so the import is postponed until the function is actually called.
    import diffrax

    # Zero-pad every initial basis state to the largest dimension so they stack into one array.
    basis_states = gate.initial_basis_states()
    state_dims = tuple(len(state) for state in basis_states)
    padded_dim = max(state_dims)
    stacked_states = jnp.stack(
        [jnp.pad(state, (0, padded_dim - state_dim)) for state, state_dim in zip(basis_states, state_dims)]
    )

    def make_branch(hamiltonian: HamiltonianFunction, rydberg_operator: jax.Array, dim: int):
        """Build the right-hand side of the augmented ODE for one basis state."""

        def branch(
            t: float | jax.Array,
            pulse_params: ParamsFloatLike,
            y: tuple[jax.Array, jax.Array],
        ) -> tuple[jax.Array, jax.Array]:
            controls = pulse.evaluate_pulse_functions(t, pulse_params)
            padded_psi, _accumulated = y
            # Only the first `dim` entries are physical; the rest is padding.
            active_psi = padded_psi[:dim]
            # Schrödinger equation on the unpadded subspace.
            active_dpsi = (-1j * hamiltonian(*controls)) @ active_psi
            # Integrand of the Rydberg time: <psi| R |psi>.
            population = jnp.vdot(active_psi, rydberg_operator @ active_psi)
            padded_dpsi = jnp.pad(active_dpsi, (0, padded_psi.shape[0] - dim))
            return padded_dpsi, population

        return branch

    # One branch per basis state; lax.switch selects the matching Hamiltonian by index.
    branches = tuple(
        make_branch(hamiltonian, rydberg_operator, state_dim)
        for hamiltonian, rydberg_operator, state_dim in zip(
            gate.hamiltonian_functions_for_basis_states(),
            gate.rydberg_population_operators_for_basis_states(),
            state_dims,
        )
    )

    def schroedinger_eq(
        t: float | jax.Array,
        y: tuple[jax.Array, jax.Array],
        args: tuple[ParamsFloatLike, int],
    ) -> tuple[jax.Array, jax.Array]:
        pulse_params, branch_index = args
        return jax.lax.switch(branch_index, branches, t, pulse_params, y)

    # Solver configuration shared by all basis states.
    term = diffrax.ODETerm(schroedinger_eq)  # type: ignore[arg-type]
    solver = diffrax.Tsit5()
    solver_tol = 0.1 * tol
    stepsize_controller = diffrax.PIDController(rtol=solver_tol, atol=solver_tol)
    saveat = diffrax.SaveAt(t1=True)

    def propagate(state_and_index: tuple[jax.Array, int]) -> jax.Array:
        """Integrate one basis state and return its accumulated Rydberg population."""
        psi_start, branch_index = state_and_index
        # The accumulator starts at zero and shares the dtype of the state vector.
        accumulator_start = jnp.array(0.0, dtype=psi_start.dtype)
        solution = diffrax.diffeqsolve(
            term,
            solver,
            t0=0.0,
            t1=params[0],
            dt0=None,
            y0=(psi_start, accumulator_start),
            args=(params, branch_index),
            stepsize_controller=stepsize_controller,
            saveat=saveat,
            max_steps=200_000,
        )
        accumulated = solution.ys[1]
        return jnp.real(accumulated)

    # Propagate every basis state in turn.
    branch_indices = jnp.arange(len(branches))
    expectation_values = jax.lax.map(propagate, (stacked_states, branch_indices))

    return gate.rydberg_time(tuple(expectation_values))