from __future__ import annotations
from copy import deepcopy
from functools import partial
from math import isinf
import jax
import jax.numpy as jnp
from typing_extensions import Self
from rydopt.gates.subsystem_hamiltonians import (
H_2_atoms,
H_3_atoms,
H_3_atoms_inf_V,
H_3_atoms_symmetric,
H_k_atoms_perfect_blockade,
)
from rydopt.types import HamiltonianFunction
[docs]
class ThreeQubitGateIsosceles:
    r"""Three-atom gate for atoms placed on the corners of an isosceles triangle.

    The physical setting is characterised by two interaction strengths between
    the atoms, :math:`V_{\mathrm{nn}}` (nearest neighbours) and
    :math:`V_{\mathrm{nnn}}` (next-nearest neighbours), together with the decay
    strength :math:`\gamma` of the Rydberg states.

    The gate to be realised is fixed by the phases
    :math:`\phi, \theta, \theta', \lambda`. A phase that is allowed to take an
    arbitrary value can be left unspecified by passing ``None``.

    The figure uses the notation
    :math:`\mathrm{C}_n\mathrm{Z}(\alpha) = \mathrm{diag}(1, ..., 1, e^{i\alpha})`
    on :math:`n+1` qubits, and
    :math:`\mathrm{Z}(\alpha) = \mathrm{C}_0\mathrm{Z}(\alpha) = \mathrm{diag}(1, e^{i\alpha})`.

    .. image:: ../_static/ThreeQubitGateIsosceles.png

    Example:
        >>> import rydopt as ro
        >>> import numpy as np
        >>> gate = ro.gates.ThreeQubitGateIsosceles(
        ...     phi=None,
        ...     theta=np.pi,
        ...     theta_prime=None,
        ...     lamb=0.0,
        ...     Vnn=float("inf"),
        ...     Vnnn=1.0,
        ...     decay=0.0001,
        ... )

    Args:
        phi: target phase :math:`\phi` of single-qubit gate contribution.
        theta: target phase :math:`\theta` of nearest-neighbour two-qubit gate contribution.
        theta_prime: target phase :math:`\theta'` of next-nearest-neighbour two-qubit gate contribution.
        lamb: target phase :math:`\lambda` of three-qubit gate contribution.
        Vnn: nearest-neighbour interaction strength :math:`V_{\mathrm{nn}}/(\hbar\Omega_0)`.
        Vnnn: next-nearest-neighbour interaction strength :math:`V_{\mathrm{nnn}}/(\hbar\Omega_0)`.
        decay: Rydberg decay strength :math:`\gamma/\Omega_0`, default is 0.

    Raises:
        ValueError: If ``Vnn == Vnnn`` while ``theta`` differs from ``theta_prime``,
            or if ``Vnnn == 0`` while ``theta_prime`` is not ``0``.
    """

    def __init__(
        self,
        phi: float | None,
        theta: float | None,
        theta_prime: float | None,
        lamb: float | None,
        Vnn: float,
        Vnnn: float,
        decay: float = 0.0,
    ) -> None:
        # With identical interaction strengths both pair phases must coincide.
        equal_interactions = Vnn == Vnnn
        if equal_interactions and theta != theta_prime:
            raise ValueError("For Vnn=Vnnn, theta=theta_prime is required")

        # Without a next-nearest-neighbour interaction only theta_prime = 0 is
        # accepted; note that leaving it unspecified (None) is rejected as well.
        no_nnn_interaction = Vnnn == 0
        if no_nnn_interaction and theta_prime != 0.0:
            raise ValueError("For Vnnn=0, theta_prime=0 is required")

        # Target phases (None means "unspecified").
        self._phi, self._theta, self._theta_prime, self._lamb = phi, theta, theta_prime, lamb
        # Physical parameters.
        self._Vnn, self._Vnnn = Vnn, Vnnn
        self._decay = decay
[docs]
def with_decay(self, decay: float) -> Self:
    r"""Return a duplicate of this gate whose decay strength has been replaced.

    The gate the method is called on is left unmodified; the duplicate is
    produced with :func:`copy.deepcopy`.

    Args:
        decay: New decay strength :math:`\gamma/\Omega_0`.

    Returns:
        A copy of the gate object carrying the new decay strength.
    """
    clone = deepcopy(self)
    clone._decay = decay
    return clone
[docs]
def dim(self) -> int:
    r"""Hilbert space dimension of the three-qubit gate.

    Returns:
        8
    """
    n_qubits = 3
    return 2**n_qubits
[docs]
def hamiltonian_functions_for_basis_states(self) -> tuple[HamiltonianFunction, ...]:
    r"""Hamiltonian builders for the independent blocks of the gate Hamiltonian.

    The full gate Hamiltonian splits into distinct blocks that describe the time
    evolution of basis states. How many blocks there are, and which builder is
    used for each, depends on the interaction strengths:

    * ``Vnn`` and ``Vnnn`` both infinite: three perfect-blockade blocks
      (``k = 1, 2, 3``).
    * ``Vnn == Vnnn`` (finite): one-atom block, two-atom block and symmetric
      three-atom block.
    * ``Vnn`` infinite and ``Vnnn == 0``: perfect-blockade blocks for
      ``k = 1, 2`` and the infinite-``V`` three-atom block.
    * ``Vnn`` infinite, any other ``Vnnn``: as above, with an additional
      two-atom block for the ``Vnnn`` pair inserted before the three-atom block.
    * otherwise: one-atom block, one two-atom block for each of ``Vnn`` and
      ``Vnnn``, and the general three-atom block.

    Every builder is returned with the gate's decay strength (and the relevant
    interaction strengths) already bound via :func:`functools.partial`.

    Returns:
        Tuple of Hamiltonian functions.
    """
    gamma = self._decay
    nn_strength, nnn_strength = float(self._Vnn), float(self._Vnnn)
    nn_blockaded = isinf(nn_strength)

    def blockade_block(k):
        # Perfect-blockade Hamiltonian for k atoms with the decay bound.
        return partial(H_k_atoms_perfect_blockade, decay=gamma, k=k)

    one_atom = blockade_block(1)

    if nn_blockaded and isinf(nnn_strength):
        blocks = (one_atom, blockade_block(2), blockade_block(3))
    elif nn_strength == nnn_strength:
        blocks = (
            one_atom,
            partial(H_2_atoms, decay=gamma, V=self._Vnn),
            partial(H_3_atoms_symmetric, decay=gamma, V=self._Vnn),
        )
    elif nn_blockaded and nnn_strength == 0.0:
        blocks = (
            one_atom,
            blockade_block(2),
            partial(H_3_atoms_inf_V, decay=gamma, V=self._Vnnn),
        )
    elif nn_blockaded:
        blocks = (
            one_atom,
            blockade_block(2),
            partial(H_2_atoms, decay=gamma, V=self._Vnnn),
            partial(H_3_atoms_inf_V, decay=gamma, V=self._Vnnn),
        )
    else:
        blocks = (
            one_atom,
            partial(H_2_atoms, decay=gamma, V=self._Vnn),
            partial(H_2_atoms, decay=gamma, V=self._Vnnn),
            partial(H_3_atoms, decay=gamma, Vnn=self._Vnn, Vnnn=self._Vnnn),
        )
    return blocks
[docs]
def rydberg_population_operators_for_basis_states(self) -> tuple[jax.Array, ...]:
    r"""Rydberg population operators, one per Hamiltonian block.

    For each basis state, the Rydberg population operators count the number of
    Rydberg excitations on the diagonal. They are obtained by evaluating the
    same Hamiltonian builders that are selected for the current interaction
    regime, with ``Delta_1=0``, ``Delta_r=-1``, ``Xi=0``, ``Omega=0``,
    ``decay=0`` and all interaction strengths set to zero (perfect-blockade
    builders are always evaluated with ``k=1``).

    Returns:
        Tuple of operators.
    """
    nn_strength, nnn_strength = float(self._Vnn), float(self._Vnnn)
    nn_blockaded = isinf(nn_strength)

    # Arguments shared by every builder evaluation below.
    # NOTE(review): presumably this leaves only the Rydberg detuning on the
    # diagonal - confirm against rydopt.gates.subsystem_hamiltonians.
    probe = {"Delta_1": 0.0, "Delta_r": -1.0, "Xi": 0.0, "Omega": 0.0, "decay": 0.0}
    two_level = partial(H_k_atoms_perfect_blockade, **probe, k=1)
    pair = partial(H_2_atoms, **probe, V=0.0)

    if nn_blockaded and isinf(nnn_strength):
        operators = (two_level(), two_level(), two_level())
    elif nn_strength == nnn_strength:
        operators = (
            two_level(),
            pair(),
            H_3_atoms_symmetric(**probe, V=0.0),
        )
    elif nn_blockaded and nnn_strength == 0.0:
        operators = (
            two_level(),
            two_level(),
            H_3_atoms_inf_V(**probe, V=0.0),
        )
    elif nn_blockaded:
        operators = (
            two_level(),
            two_level(),
            pair(),
            H_3_atoms_inf_V(**probe, V=0.0),
        )
    else:
        operators = (
            two_level(),
            pair(),
            pair(),
            H_3_atoms(**probe, Vnn=0.0, Vnnn=0.0),
        )
    return operators
[docs]
def initial_basis_states(self) -> tuple[jax.Array, ...]:
    r"""Initial states for the Hamiltonian blocks of the current regime.

    Every entry is the complex vector :math:`(1, 0, ...)`; only its length
    differs from block to block. The lengths per interaction regime are:

    * ``Vnn`` and ``Vnnn`` both infinite: ``(2, 2, 2)``
    * ``Vnn == Vnnn`` (finite): ``(2, 3, 4)``
    * ``Vnn`` infinite and ``Vnnn == 0``: ``(2, 2, 4)``
    * ``Vnn`` infinite, any other ``Vnnn``: ``(2, 2, 3, 4)``
    * otherwise: ``(2, 3, 3, 6)``

    Returns:
        Tuple of arrays.
    """

    def first_basis_vector(size):
        # (1, 0, ..., 0) with complex entries.
        return jnp.array([1.0 + 0.0j] + [0.0 + 0.0j] * (size - 1))

    nn_strength, nnn_strength = float(self._Vnn), float(self._Vnnn)
    nn_blockaded = isinf(nn_strength)

    if nn_blockaded and isinf(nnn_strength):
        block_sizes = (2, 2, 2)
    elif nn_strength == nnn_strength:
        block_sizes = (2, 3, 4)
    elif nn_blockaded and nnn_strength == 0.0:
        block_sizes = (2, 2, 4)
    elif nn_blockaded:
        block_sizes = (2, 2, 3, 4)
    else:
        block_sizes = (2, 3, 3, 6)

    return tuple(first_basis_vector(size) for size in block_sizes)
[docs]
def process_fidelity(self, final_basis_states: tuple[jax.Array, ...]) -> jax.Array:
    r"""Fidelity of the realised diagonal gate with respect to the target gate.

    The first component of every time-evolved basis state is used as the
    amplitude the gate imprints on the corresponding computational basis
    states. These amplitudes are arranged into an 8-entry diagonal
    (entry 0 is fixed to 1; entries 1, 2, 4 share the one-atom amplitude;
    entries 3, 6 share the nearest-neighbour pair amplitude; entry 5 is the
    next-nearest-neighbour pair amplitude; entry 7 is the three-atom amplitude)
    and compared with the target diagonal built from the gate angles
    :math:`\phi, \, \theta, \, \theta', \, \lambda`.

    An angle that was left unspecified (``None``) is taken from the phases of
    the obtained diagonal, so it never reduces the fidelity.

    Args:
        final_basis_states: Time-evolved basis states.

    Returns:
        Fidelity with respect to the target state,
        :math:`|\langle \mathrm{target} | \mathrm{obtained} \rangle|^2 / 8^2`.
    """
    equal_interactions = float(self._Vnn) == float(self._Vnnn)
    decoupled_outer_pair = (
        not equal_interactions and isinf(float(self._Vnn)) and float(self._Vnnn) == 0.0
    )

    # Amplitudes of the four kinds of basis states in the current regime.
    single = final_basis_states[0][0]
    nn_pair = final_basis_states[1][0]
    if equal_interactions:
        # Both pair types evolve under the same two-atom block.
        nnn_pair = final_basis_states[1][0]
        triple = final_basis_states[2][0]
    elif decoupled_outer_pair:
        # Non-interacting pair: product of two one-atom evolutions.
        nnn_pair = final_basis_states[0][0] ** 2
        triple = final_basis_states[2][0]
    else:
        nnn_pair = final_basis_states[2][0]
        triple = final_basis_states[3][0]

    # Obtained diagonal gate matrix
    obtained_gate = jnp.array([1, single, single, nn_pair, single, nnn_pair, nn_pair, triple])

    # Gate angles; unspecified ones are read off the obtained diagonal.
    phi = self._phi
    if phi is None:
        phi = jnp.angle(obtained_gate[1])
    theta = self._theta
    if theta is None:
        theta = jnp.angle(obtained_gate[3]) - 2 * phi
    theta_prime = self._theta_prime
    if theta_prime is None:
        theta_prime = jnp.angle(obtained_gate[5]) - 2 * phi
    lamb = self._lamb
    if lamb is None:
        lamb = jnp.angle(obtained_gate[7]) - 3 * phi - 2 * theta - theta_prime

    # Targeted diagonal gate matrix (phases of entries 1..7; entry 0 is 1)
    target_phases = (
        phi,
        phi,
        2 * phi + theta,
        phi,
        2 * phi + theta_prime,
        2 * phi + theta,
        3 * phi + 2 * theta + theta_prime + lamb,
    )
    targeted_gate = jnp.stack([1, *(jnp.exp(1j * phase) for phase in target_phases)])

    overlap = jnp.vdot(targeted_gate, obtained_gate)
    return jnp.abs(overlap) ** 2 / len(targeted_gate) ** 2
[docs]
def rydberg_time(self, expectation_values_of_basis_states: tuple[jax.Array, ...]) -> jax.Array:
    r"""Average time spent in Rydberg states during the gate.

    The time-integrated Rydberg populations of the individual blocks are
    combined with integer weights and divided by 8. The weights depend on the
    interaction regime:

    * ``Vnn == Vnnn``: ``3, 3, 1``
    * ``Vnn`` infinite and ``Vnnn == 0``: ``5, 2, 1``
    * otherwise: ``3, 2, 1, 1``

    Args:
        expectation_values_of_basis_states: Expected Rydberg times for each basis state.

    Returns:
        Averaged Rydberg time :math:`T_R`.
    """
    times = expectation_values_of_basis_states
    nn_strength, nnn_strength = float(self._Vnn), float(self._Vnnn)

    if nn_strength == nnn_strength:
        weighted_sum = 3 * times[0] + 3 * times[1] + times[2]
    elif isinf(nn_strength) and nnn_strength == 0.0:
        weighted_sum = 5 * times[0] + 2 * times[1] + times[2]
    else:
        weighted_sum = 3 * times[0] + 2 * times[1] + times[2] + times[3]

    return (1 / 8) * jnp.squeeze(weighted_sum)