Skip to content

Commit

Permalink
Fixes failing quantum_rel_entr tests (cvxpy#2621)
Browse files Browse the repository at this point in the history
* Fixes failing tests

* Applying fix to marimo NB && tests
  • Loading branch information
aryamanjeendgar authored Nov 9, 2024
1 parent f3de1e8 commit c756be3
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 9 deletions.
5 changes: 3 additions & 2 deletions cvxpy/atoms/quantum_cond_entr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from cvxpy.atoms.affine.kron import kron
from cvxpy.atoms.affine.partial_trace import partial_trace
from cvxpy.atoms.affine.wraps import hermitian_wrap
from cvxpy.atoms.quantum_rel_entr import quantum_rel_entr
from cvxpy.expressions.expression import Expression

Expand All @@ -12,8 +13,8 @@ def quantum_cond_entr(rho: Expression , dim: list[int], sys: Optional[int]=0):
if sys == 0:
composite_arg = kron(np.eye(dim[0]),
partial_trace(rho, dim, sys))
return -quantum_rel_entr(rho, composite_arg)
return -quantum_rel_entr(rho, hermitian_wrap(composite_arg))
elif sys == 1:
composite_arg = kron(partial_trace(rho, dim, sys),
np.eye(dim[1]))
return -quantum_rel_entr(rho, composite_arg)
return -quantum_rel_entr(rho, hermitian_wrap(composite_arg))
3 changes: 2 additions & 1 deletion cvxpy/tests/test_quantum_rel_entr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import cvxpy as cp
from cvxpy.atoms.affine.kron import kron
from cvxpy.atoms.affine.partial_trace import partial_trace
from cvxpy.atoms.affine.wraps import hermitian_wrap
from cvxpy.tests import solver_test_helpers as STH


Expand Down Expand Up @@ -125,7 +126,7 @@ def AD(gamma: float):

def Ic(rho: cp.Variable):
return cp.quantum_cond_entr(
W @ applychan(U, rho, 'isom', (na, nb)) @ W.conj().T,
hermitian_wrap(W @ applychan(U, rho, 'isom', (na, nb)) @ W.conj().T),
[ne, nf], 1
)/np.log(2)

Expand Down
13 changes: 7 additions & 6 deletions examples/QI_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def __():
import numpy as np

from cvxpy.atoms.affine.partial_trace import partial_trace
return cp, np, partial_trace
from cvxpy.atoms.affine.wraps import hermitian_wrap
return cp, hermitian_wrap, np, partial_trace


@app.cell
Expand Down Expand Up @@ -181,14 +182,14 @@ def __(mo):


@app.cell
def __(cp, np):
def __(cp, hermitian_wrap, np):
na_en, nb_en, ne_en = (2, 2, 2)
AD_en = lambda gamma: np.array([[1, 0], [0, np.sqrt(gamma)], [0, np.sqrt(1-gamma)], [0, 0]])
U_en = AD_en(0.2)

rho_en = cp.Variable(shape=(na_en, na_en), hermitian=True)
obj_en = cp.Maximize((cp.quantum_cond_entr(U_en @ rho_en @ U_en.conj().T, [nb_en, ne_en]) +
cp.von_neumann_entr(cp.partial_trace(U_en @ rho_en @ U_en.conj().T, [nb_en, ne_en], 1)))/np.log(2))
obj_en = cp.Maximize(cp.quantum_cond_entr(hermitian_wrap(U_en @ rho_en @ U_en.conj().T), [nb_en, ne_en]) +
cp.von_neumann_entr(cp.partial_trace(U_en @ rho_en @ U_en.conj().T, [nb_en, ne_en], 1)))/np.log(2)
cons_en = [
rho_en >> 0,
cp.trace(rho_en) == 1
Expand Down Expand Up @@ -274,7 +275,7 @@ def applychan(chan: np.array, rho: cp.Variable, rep: str, dim: tuple[int, int]):


@app.cell
def __(applychan, cp, np):
def __(applychan, cp, hermitian_wrap, np):
na_cc, nb_cc, ne_cc, nf_cc = (2, 2, 2, 2)
AD_cc = lambda gamma: np.array([[1, 0],[0, np.sqrt(gamma)],[0, np.sqrt(1-gamma)],[0, 0]])
gamma = 0.2
Expand All @@ -283,7 +284,7 @@ def __(applychan, cp, np):
W_cc = AD_cc((1-2*gamma)/(1-gamma))

Ic_cc = lambda rho: cp.quantum_cond_entr(
W_cc @ applychan(U_cc, rho, 'isom', (na_cc, nb_cc)) @ W_cc.conj().T,
hermitian_wrap(W_cc @ applychan(U_cc, rho, 'isom', (na_cc, nb_cc)) @ W_cc.conj().T),
[ne_cc, nf_cc], 1
)/np.log(2)

Expand Down

0 comments on commit c756be3

Please sign in to comment.