Source code for rydopt.gates.four_qubit_gate_asym

from __future__ import annotations

import warnings
from copy import deepcopy
from functools import partial
from math import isinf

import jax
import jax.numpy as jnp
from typing_extensions import Self

from rydopt.gates.subsystem_hamiltonians_general import (
    H_1_atom_general,
    H_2_atoms_general,
    H_3_atoms_general,
    H_4_atoms_general,  # must exist in your package, analogous to H_3_atoms_general
)
from rydopt.protocols import PulseAnsatzLike
from rydopt.simulation.fidelity import average_gate_fidelity, process_fidelity
from rydopt.types import FidelityType, HamiltonianFunction, ParamsFloatLike


[docs] class FourQubitGateAsym: r"""Class that describes a gate on four atoms in an asymmetric setup. The physical setting is described by the interaction strengths between atoms, :math:`V_{12}, V_{13}, V_{14}, V_{23}, V_{24}, V_{34}`, and the decay strength from Rydberg states, :math:`\gamma`. In addition, each atom can optionally have a different Rabi frequency scaling factor. The target gate is specified by phases :math:`\phi_1, \phi_2, \phi_3, \phi_4`, :math:`\theta_{12}, \theta_{13}, \theta_{14}, \theta_{23}, \theta_{24}, \theta_{34}`, :math:`\lambda_{123}, \lambda_{124}, \lambda_{134}, \lambda_{234}`, and :math:`\mu`. Some phases can remain unspecified if they may take on arbitrary values. Args: phi1: target phase of the single-qubit gate contribution on atom 1. phi2: target phase of the single-qubit gate contribution on atom 2. phi3: target phase of the single-qubit gate contribution on atom 3. phi4: target phase of the single-qubit gate contribution on atom 4. theta12: target phase of the two-qubit gate contribution on atoms 1, 2. theta13: target phase of the two-qubit gate contribution on atoms 1, 3. theta14: target phase of the two-qubit gate contribution on atoms 1, 4. theta23: target phase of the two-qubit gate contribution on atoms 2, 3. theta24: target phase of the two-qubit gate contribution on atoms 2, 4. theta34: target phase of the two-qubit gate contribution on atoms 3, 4. lamb123: target phase of the three-qubit gate contribution on atoms 1, 2, 3. lamb124: target phase of the three-qubit gate contribution on atoms 1, 2, 4. lamb134: target phase of the three-qubit gate contribution on atoms 1, 3, 4. lamb234: target phase of the three-qubit gate contribution on atoms 2, 3, 4. mu: target phase of the four-qubit gate contribution. V12: interaction strength between atoms 1 and 2, :math:`V_{12}/(\hbar\Omega_0)`. V13: interaction strength between atoms 1 and 3, :math:`V_{13}/(\hbar\Omega_0)`. V14: interaction strength between atoms 1 and 4, :math:`V_{14}/(\hbar\Omega_0)`. V23: interaction strength between atoms 2 and 3, :math:`V_{23}/(\hbar\Omega_0)`. V24: interaction strength between atoms 2 and 4, :math:`V_{24}/(\hbar\Omega_0)`. V34: interaction strength between atoms 3 and 4, :math:`V_{34}/(\hbar\Omega_0)`. decay: Rydberg decay strength :math:`\gamma/\Omega_0`, default is 0. s1: Rabi frequency scaling factor for atom 1, default is 1. s2: Rabi frequency scaling factor for atom 2, default is 1. s3: Rabi frequency scaling factor for atom 3, default is 1. s4: Rabi frequency scaling factor for atom 4, default is 1. """ def __init__( self, phi1: float | None, phi2: float | None, phi3: float | None, phi4: float | None, theta12: float | None, theta13: float | None, theta14: float | None, theta23: float | None, theta24: float | None, theta34: float | None, lamb123: float | None, lamb124: float | None, lamb134: float | None, lamb234: float | None, mu: float | None, V12: float, V13: float, V14: float, V23: float, V24: float, V34: float, decay: float = 0.0, s1: float = 1.0, s2: float = 1.0, s3: float = 1.0, s4: float = 1.0, fidelity_type: FidelityType = "process", ) -> None: for name, val in [("V12", V12), ("V13", V13), ("V14", V14), ("V23", V23), ("V24", V24), ("V34", V34)]: if isinf(float(val)): raise ValueError( f"{name} must be finite. If the setup is symmetric, use `FourQubitGatePyramidal` " "for infinite interaction strengths." ) warnings.warn( "This gate implementation does not use any symmetries. If your setup is a pyramidal arrangement of atoms, " "consider using `FourQubitGatePyramidal` for better performance.", stacklevel=2, ) self._phi1 = phi1 self._phi2 = phi2 self._phi3 = phi3 self._phi4 = phi4 self._theta12 = theta12 self._theta13 = theta13 self._theta14 = theta14 self._theta23 = theta23 self._theta24 = theta24 self._theta34 = theta34 self._lamb123 = lamb123 self._lamb124 = lamb124 self._lamb134 = lamb134 self._lamb234 = lamb234 self._mu = mu self._V12 = V12 self._V13 = V13 self._V14 = V14 self._V23 = V23 self._V24 = V24 self._V34 = V34 self._decay = decay self._s1 = s1 self._s2 = s2 self._s3 = s3 self._s4 = s4 self._fidelity_type = fidelity_type
[docs] def with_decay(self, decay: float) -> Self: r"""Creates a copy of the gate with a new decay strength. Args: decay: New decay strength :math:`\gamma/\Omega_0`. Returns: A copy of the gate object with the new decay strength. """ new = deepcopy(self) new._decay = decay return new
[docs] def dim(self) -> int: r"""Hilbert space dimension. Returns: 16 """ return 16
[docs] def hamiltonian_functions_for_basis_states(self) -> tuple[HamiltonianFunction, ...]: r"""The full gate Hamiltonian can be split into distinct blocks that describe the time evolution of basis states. Returns: Tuple of Hamiltonian functions. """ return ( # |0001> partial(H_1_atom_general, decay=self._decay, s1=self._s4), # |0010> partial(H_1_atom_general, decay=self._decay, s1=self._s3), # |0011> partial(H_2_atoms_general, decay=self._decay, V12=self._V34, s1=self._s3, s2=self._s4), # |0100> partial(H_1_atom_general, decay=self._decay, s1=self._s2), # |0101> partial(H_2_atoms_general, decay=self._decay, V12=self._V24, s1=self._s2, s2=self._s4), # |0110> partial(H_2_atoms_general, decay=self._decay, V12=self._V23, s1=self._s2, s2=self._s3), # |0111> partial( H_3_atoms_general, decay=self._decay, V12=self._V23, V13=self._V24, V23=self._V34, s1=self._s2, s2=self._s3, s3=self._s4, ), # |1000> partial(H_1_atom_general, decay=self._decay, s1=self._s1), # |1001> partial(H_2_atoms_general, decay=self._decay, V12=self._V14, s1=self._s1, s2=self._s4), # |1010> partial(H_2_atoms_general, decay=self._decay, V12=self._V13, s1=self._s1, s2=self._s3), # |1011> partial( H_3_atoms_general, decay=self._decay, V12=self._V13, V13=self._V14, V23=self._V34, s1=self._s1, s2=self._s3, s3=self._s4, ), # |1100> partial(H_2_atoms_general, decay=self._decay, V12=self._V12, s1=self._s1, s2=self._s2), # |1101> partial( H_3_atoms_general, decay=self._decay, V12=self._V12, V13=self._V14, V23=self._V24, s1=self._s1, s2=self._s2, s3=self._s4, ), # |1110> partial( H_3_atoms_general, decay=self._decay, V12=self._V12, V13=self._V13, V23=self._V23, s1=self._s1, s2=self._s2, s3=self._s3, ), # |1111> partial( H_4_atoms_general, decay=self._decay, V12=self._V12, V13=self._V13, V14=self._V14, V23=self._V23, V24=self._V24, V34=self._V34, s1=self._s1, s2=self._s2, s3=self._s3, s4=self._s4, ), )
[docs] def rydberg_population_operators_for_basis_states(self) -> tuple[jax.Array, ...]: r"""For each basis state, the Rydberg population operators count the number of Rydberg excitations on the diagonal. Returns: Tuple of operators. """ return ( H_1_atom_general(Delta_1=0.0, Delta_r=-1.0, Xi=0.0, Omega=0.0, decay=0.0), H_1_atom_general(Delta_1=0.0, Delta_r=-1.0, Xi=0.0, Omega=0.0, decay=0.0), H_2_atoms_general(Delta_1=0.0, Delta_r=-1.0, Xi=0.0, Omega=0.0, decay=0.0, V12=0.0), H_1_atom_general(Delta_1=0.0, Delta_r=-1.0, Xi=0.0, Omega=0.0, decay=0.0), H_2_atoms_general(Delta_1=0.0, Delta_r=-1.0, Xi=0.0, Omega=0.0, decay=0.0, V12=0.0), H_2_atoms_general(Delta_1=0.0, Delta_r=-1.0, Xi=0.0, Omega=0.0, decay=0.0, V12=0.0), H_3_atoms_general(Delta_1=0.0, Delta_r=-1.0, Xi=0.0, Omega=0.0, decay=0.0, V12=0.0, V13=0.0, V23=0.0), H_1_atom_general(Delta_1=0.0, Delta_r=-1.0, Xi=0.0, Omega=0.0, decay=0.0), H_2_atoms_general(Delta_1=0.0, Delta_r=-1.0, Xi=0.0, Omega=0.0, decay=0.0, V12=0.0), H_2_atoms_general(Delta_1=0.0, Delta_r=-1.0, Xi=0.0, Omega=0.0, decay=0.0, V12=0.0), H_3_atoms_general(Delta_1=0.0, Delta_r=-1.0, Xi=0.0, Omega=0.0, decay=0.0, V12=0.0, V13=0.0, V23=0.0), H_2_atoms_general(Delta_1=0.0, Delta_r=-1.0, Xi=0.0, Omega=0.0, decay=0.0, V12=0.0), H_3_atoms_general(Delta_1=0.0, Delta_r=-1.0, Xi=0.0, Omega=0.0, decay=0.0, V12=0.0, V13=0.0, V23=0.0), H_3_atoms_general(Delta_1=0.0, Delta_r=-1.0, Xi=0.0, Omega=0.0, decay=0.0, V12=0.0, V13=0.0, V23=0.0), H_4_atoms_general( Delta_1=0.0, Delta_r=-1.0, Xi=0.0, Omega=0.0, decay=0.0, V12=0.0, V13=0.0, V14=0.0, V23=0.0, V24=0.0, V34=0.0, ), )
[docs] def initial_basis_states(self) -> tuple[jax.Array, ...]: r"""The initial basis states :math:`(1, 0, ...)` of appropriate dimension are provided. Returns: Tuple of arrays. """ z2 = jnp.array([1.0 + 0.0j, 0.0 + 0.0j]) z4 = jnp.array([1.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j]) z8 = jnp.array([1.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j]) z16 = jnp.array( [ 1.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, ] ) return (z2, z2, z4, z2, z4, z4, z8, z2, z4, z4, z8, z4, z8, z8, z16)
def process_fidelity_helper(self, final_basis_states: tuple[jax.Array, ...]) -> jax.Array: r"""Given the basis states evolved under the pulse, this function calculates the fidelity with respect to the gate's target state, specified by the gate angles :math:`\phi, \, \theta, \, \ldots` Args: final_basis_states: Time-evolved basis states. Returns: Fidelity with respect to the target state. """ # Obtained diagonal gate matrix obtained_gate = jnp.array( [ 1, # 0: |0000> final_basis_states[0][0], # 1: |0001> final_basis_states[1][0], # 2: |0010> final_basis_states[2][0], # 3: |0011> final_basis_states[3][0], # 4: |0100> final_basis_states[4][0], # 5: |0101> final_basis_states[5][0], # 6: |0110> final_basis_states[6][0], # 7: |0111> final_basis_states[7][0], # 8: |1000> final_basis_states[8][0], # 9: |1001> final_basis_states[9][0], # 10: |1010> final_basis_states[10][0], # 11: |1011> final_basis_states[11][0], # 12: |1100> final_basis_states[12][0], # 13: |1101> final_basis_states[13][0], # 14: |1110> final_basis_states[14][0], # 15: |1111> ] ) # Single-qubit phase (averaged) p1 = jnp.angle(obtained_gate[8]) if self._phi1 is None else self._phi1 p2 = jnp.angle(obtained_gate[4]) if self._phi2 is None else self._phi2 p3 = jnp.angle(obtained_gate[2]) if self._phi3 is None else self._phi3 p4 = jnp.angle(obtained_gate[1]) if self._phi4 is None else self._phi4 # Two-qubit phases t12 = jnp.angle(obtained_gate[12]) - p1 - p2 if self._theta12 is None else self._theta12 t13 = jnp.angle(obtained_gate[10]) - p1 - p3 if self._theta13 is None else self._theta13 t14 = jnp.angle(obtained_gate[9]) - p1 - p4 if self._theta14 is None else self._theta14 t23 = jnp.angle(obtained_gate[6]) - p2 - p3 if self._theta23 is None else self._theta23 t24 = jnp.angle(obtained_gate[5]) - p2 - p4 if self._theta24 is None else self._theta24 t34 = jnp.angle(obtained_gate[3]) - p3 - p4 if self._theta34 is None else self._theta34 # Three-qubit phases l234 = jnp.angle(obtained_gate[7]) - p2 - p3 - p4 - t23 - t24 - t34 if self._lamb234 is None else self._lamb234 l134 = jnp.angle(obtained_gate[11]) - p1 - p3 - p4 - t13 - t14 - t34 if self._lamb134 is None else self._lamb134 l124 = jnp.angle(obtained_gate[13]) - p1 - p2 - p4 - t12 - t14 - t24 if self._lamb124 is None else self._lamb124 l123 = jnp.angle(obtained_gate[14]) - p1 - p2 - p3 - t12 - t13 - t23 if self._lamb123 is None else self._lamb123 # Four-qubit phase mu = ( jnp.angle(obtained_gate[15]) - p1 - p2 - p3 - p4 - t12 - t13 - t14 - t23 - t24 - t34 - l123 - l124 - l134 - l234 if self._mu is None else self._mu ) # Targeted diagonal gate matrix targeted_gate = jnp.stack( [ 1, jnp.exp(1j * p4), jnp.exp(1j * p3), jnp.exp(1j * (p3 + p4 + t34)), jnp.exp(1j * p2), jnp.exp(1j * (p2 + p4 + t24)), jnp.exp(1j * (p2 + p3 + t23)), jnp.exp(1j * (p2 + p3 + p4 + t23 + t24 + t34 + l234)), jnp.exp(1j * p1), jnp.exp(1j * (p1 + p4 + t14)), jnp.exp(1j * (p1 + p3 + t13)), jnp.exp(1j * (p1 + p3 + p4 + t13 + t14 + t34 + l134)), jnp.exp(1j * (p1 + p2 + t12)), jnp.exp(1j * (p1 + p2 + p4 + t12 + t14 + t24 + l124)), jnp.exp(1j * (p1 + p2 + p3 + t12 + t13 + t23 + l123)), jnp.exp(1j * (p1 + p2 + p3 + p4 + t12 + t13 + t14 + t23 + t24 + t34 + l123 + l124 + l134 + l234 + mu)), ] ) return jnp.abs(jnp.vdot(targeted_gate, obtained_gate)) ** 2 / len(targeted_gate) ** 2 def cost(self, pulse: PulseAnsatzLike, params: ParamsFloatLike, tol: float = 1e-7) -> jax.Array: """Evaluate the cost function from the configured fidelity metric.""" if self._fidelity_type == "process": return jnp.abs(1 - process_fidelity(self, pulse, params, tol)) if self._fidelity_type == "average_gate": return jnp.abs(1 - average_gate_fidelity(self, pulse, params, tol)) raise ValueError(f"Unsupported fidelity type: {self._fidelity_type}")
[docs] def rydberg_time(self, expectation_values_of_basis_states: tuple[jax.Array, ...]) -> jax.Array: r"""Given the expectation values of Rydberg populations for each basis state, integrated over the full pulse, this function calculates the average time spent in Rydberg states during the gate. Args: expectation_values_of_basis_states: Expected Rydberg times for each basis state. Returns: Averaged Rydberg time :math:`T_R`. """ return (1 / 16) * jnp.squeeze( expectation_values_of_basis_states[0] + expectation_values_of_basis_states[1] + expectation_values_of_basis_states[2] + expectation_values_of_basis_states[3] + expectation_values_of_basis_states[4] + expectation_values_of_basis_states[5] + expectation_values_of_basis_states[6] + expectation_values_of_basis_states[7] + expectation_values_of_basis_states[8] + expectation_values_of_basis_states[9] + expectation_values_of_basis_states[10] + expectation_values_of_basis_states[11] + expectation_values_of_basis_states[12] + expectation_values_of_basis_states[13] + expectation_values_of_basis_states[14] )