"""Time evolution of gate basis states under a pulse Hamiltonian (``rydopt.simulation.evolve``)."""

from __future__ import annotations

from functools import partial

import jax
import jax.numpy as jnp

from rydopt.protocols import Evolvable, PulseAnsatzLike
from rydopt.types import HamiltonianFunction, PulseParams


def evolve(gate: Evolvable, pulse: PulseAnsatzLike, params: PulseParams, tol: float = 1e-7) -> tuple[jax.Array, ...]:
    r"""The function performs the time evolution of all initial states :math:`|\psi_i(0)\rangle` (specified in the
    gate object), under the pulse Hamiltonian :math:`H`.

    .. math::

        |\psi_i(T)\rangle = U(T)|\psi_i(0)\rangle = \mathcal{T} e^{-\frac{i}{\hbar} \int_0^T H(t)dt} |\psi_i(0)\rangle

    The code integrates :math:`\dot\psi = -iH\psi`, i.e. the Hamiltonians are expected in units where
    :math:`\hbar = 1`. Each initial basis state is propagated with its own Hamiltonian (the i-th Hamiltonian
    returned by the gate acts on the i-th initial basis state). If the first JAX device is a GPU, all basis
    states are integrated together as one joint ODE; otherwise each basis state is integrated separately.

    Example:
        >>> import rydopt as ro
        >>> import numpy as np
        >>> gate = ro.gates.TwoQubitGate(
        ...     phi=None,
        ...     theta=np.pi,
        ...     Vnn=float("inf"),
        ...     decay=0,
        ... )
        >>> pulse = ro.pulses.PulseAnsatz(
        ...     detuning_ansatz=ro.pulses.const,
        ...     phase_ansatz=ro.pulses.sin_crab,
        ... )
        >>> params = (7.61140652, [0.07842706], [1.80300902, -0.61792703], [])
        >>> time_evolved_basis_states = ro.simulation.evolve(gate, pulse, params)

    Args:
        gate: RydOpt Gate object.
        pulse: RydOpt PulseAnsatz object.
        params: Pulse parameters. ``params[0]`` is used as the final time :math:`T` of the evolution.
        tol: Precision of the ODE solver, default is 1e-7. The adaptive step-size controller uses
            ``rtol = atol = 0.1 * tol``.

    Returns:
        Time-evolved basis states :math:`\{|\psi_i(T)\rangle\}`, in the order of the gate's initial basis states.

    Raises:
        ValueError: If the gate provides no initial basis states, or if the number of initial basis states
            differs from the number of basis-state Hamiltonians.
    """
    # When we import diffrax, at least one jnp array is allocated (see optimistix/_misc.py, line 138). Thus,
    # if we change the default device after we have imported diffrax, some memory is allocated on the
    # wrong device. Hence, we defer the import of diffrax to the latest time possible.
    import diffrax

    # Validate the gate before dispatching, so that both the GPU and the per-state code path fail with a
    # clear message. Without this, zip() silently truncates, and a mismatch surfaces as an opaque
    # shape/pytree error inside JAX or diffrax.
    initial_states = tuple(gate.initial_basis_states())
    hamiltonians = tuple(gate.hamiltonian_functions_for_basis_states())
    if not initial_states:
        raise ValueError("The gate does not provide any initial basis states to evolve.")
    if len(hamiltonians) != len(initial_states):
        raise ValueError(
            f"The gate provides {len(initial_states)} initial basis states but {len(hamiltonians)} "
            "basis-state Hamiltonians; exactly one Hamiltonian per initial basis state is required."
        )

    # If we are on a GPU, dispatch to a GPU-optimized evolve. On GPUs, it is more efficient to solve one
    # large differential equation instead of many small ones because it reduces kernel overheads.
    # NOTE(review): jax.devices() lists the devices of the default backend; it presumably does not reflect a
    # device selected via jax.config.update("jax_default_device", ...) — confirm this is the intended check.
    if jax.devices()[0].platform == "gpu":
        return _evolve_optimized_for_gpus(gate, pulse, params, tol)

    # Zero-pad all initial states to a common dimension so they can be stacked into one array and
    # processed by a single lax.map.
    dims = tuple(len(psi) for psi in initial_states)
    max_dim = max(dims)
    initial_states_padded = jnp.stack([jnp.pad(psi, (0, max_dim - dim)) for psi, dim in zip(initial_states, dims)])

    # Schrödinger equation for the basis states. The Hamiltonian is chosen via lax.switch based on the index
    # of the basis state. Each branch acts only on the leading `dim` entries of the padded state and re-pads
    # the resulting derivative to length max_dim. The padding entries thus have zero derivative and stay zero.
    def apply_hamiltonian(
        t: jax.Array | float,
        params: PulseParams,
        psi: jax.Array,
        hamiltonian: HamiltonianFunction,
        dim: int,
    ) -> jax.Array:
        values = pulse.evaluate_pulse_functions(t, params)
        dpsi_small = -1j * hamiltonian(*values) @ psi[:dim]
        return jnp.pad(dpsi_small, (0, psi.shape[0] - dim))

    branches = tuple(partial(apply_hamiltonian, hamiltonian=h, dim=d) for h, d in zip(hamiltonians, dims))

    def schroedinger_eq(t: jax.Array | float, psi: jax.Array, args: tuple[PulseParams, int]) -> jax.Array:
        params, idx = args
        return jax.lax.switch(idx, branches, t, params, psi)

    # Propagator
    term = diffrax.ODETerm(schroedinger_eq)  # type: ignore[arg-type]
    solver = diffrax.Tsit5()
    stepsize_controller = diffrax.PIDController(rtol=0.1 * tol, atol=0.1 * tol)
    saveat = diffrax.SaveAt(t1=True)

    def propagate(args: tuple[jax.Array, int]) -> jax.Array:
        psi_initial, idx = args
        sol = diffrax.diffeqsolve(
            term,
            solver,
            t0=0.0,
            t1=params[0],
            dt0=None,
            y0=psi_initial,
            args=(params, idx),
            stepsize_controller=stepsize_controller,
            saveat=saveat,
            max_steps=200_000,
        )
        # SaveAt(t1=True) stores a single time point; take it.
        return sol.ys[0]

    # Run the propagator for each basis state
    final_states_padded = jax.lax.map(
        propagate,
        (initial_states_padded, jnp.arange(len(branches))),
    )

    # Remove padding and return original state sizes
    return tuple(s[:d] for s, d in zip(final_states_padded, dims))
def _evolve_optimized_for_gpus(
    gate: Evolvable, pulse: PulseAnsatzLike, params: PulseParams, tol: float = 1e-7
) -> tuple[jax.Array, ...]:
    r"""Evolve all initial basis states of ``gate`` as a single joint ODE (the variant ``evolve`` uses on GPUs).

    Instead of integrating every basis state separately, the tuple of all basis states is treated as one pytree
    state. One adaptive solve then advances all of them at once with a shared step size. The Schrödinger
    equation is integrated as :math:`\dot\psi_i = -iH_i\psi_i` (units with :math:`\hbar = 1`) using diffrax's
    ``Dopri8`` solver.

    Args:
        gate: RydOpt Gate object.
        pulse: RydOpt PulseAnsatz object.
        params: Pulse parameters. ``params[0]`` is used as the final time :math:`T` of the evolution.
        tol: Precision of the ODE solver, default is 1e-7. The adaptive step-size controller uses
            ``rtol = atol = 0.1 * tol``.

    Returns:
        Time-evolved basis states, in the order of the gate's initial basis states.

    Raises:
        ValueError: If the gate provides no initial basis states, or if the number of initial basis states
            differs from the number of basis-state Hamiltonians.
    """
    # When we import diffrax, at least one jnp array is allocated (see optimistix/_misc.py, line 138). Thus,
    # if we change the default device after we have imported diffrax, some memory is allocated on the
    # wrong device. Hence, we defer the import of diffrax to the latest time possible.
    import diffrax

    # Query the gate once, outside the vector field. The vector field is traced for every solver stage, so
    # re-querying the gate there is redundant work. The vector field returns a tuple, so y0 is a tuple as
    # well; this keeps the state and derivative pytree structures identical even if the gate returns a list.
    initial_states = tuple(gate.initial_basis_states())
    hamiltonians = tuple(gate.hamiltonian_functions_for_basis_states())
    if not initial_states:
        raise ValueError("The gate does not provide any initial basis states to evolve.")
    if len(hamiltonians) != len(initial_states):
        raise ValueError(
            f"The gate provides {len(initial_states)} initial basis states but {len(hamiltonians)} "
            "basis-state Hamiltonians; exactly one Hamiltonian per initial basis state is required."
        )

    def schroedinger_eq(
        t: jax.Array | float,
        psi_tuple: tuple[jax.Array, ...],
        _: object,
    ) -> tuple[jax.Array, ...]:
        # The pulse functions are evaluated once per time point and shared by all basis-state Hamiltonians.
        values = pulse.evaluate_pulse_functions(t, params)
        return tuple(-1j * (h(*values) @ psi) for h, psi in zip(hamiltonians, psi_tuple))

    solver = diffrax.Dopri8()
    stepsize_controller = diffrax.PIDController(rtol=0.1 * tol, atol=0.1 * tol)
    saveat = diffrax.SaveAt(t1=True)
    term = diffrax.ODETerm(schroedinger_eq)  # type: ignore[arg-type]

    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0=0.0,
        t1=params[0],
        dt0=None,
        y0=initial_states,
        args=None,
        stepsize_controller=stepsize_controller,
        saveat=saveat,
        max_steps=200_000,
    )

    # SaveAt(t1=True) stores a single time point per leaf; take it for every basis state.
    return tuple(psi_t1[0] for psi_t1 in sol.ys)