Skip to content

Commit

Permalink
Fix tutorial_pulse_programming101.py to type cast expval (#1317)
Browse files Browse the repository at this point in the history
**Context:**

In PennyLaneAI/pennylane#6939, a fix was made to
`qml.expval` so that it no longer silently converts the result to a real
number. However, in `tutorial_pulse_programming101.py` we use a
Hamiltonian with complex-typed coefficients. This resulted in an
imaginary expectation value (`X+0j`) and broke the `jax.value_and_grad`
function.

**Description of change:**

Wrap the `qnode` function to get the expectation value and convert it to
a real number. This allows `jax.value_and_grad` to be used and fixes the
demo.
  • Loading branch information
andrijapau authored Feb 14, 2025
1 parent 68c6f08 commit 691322d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
2 changes: 1 addition & 1 deletion demonstrations/tutorial_pulse_programming101.metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
}
],
"dateOfPublication": "2023-03-08T00:00:00+00:00",
"dateOfLastModification": "2024-10-07T00:00:00+00:00",
"dateOfLastModification": "2025-02-14T00:00:00+00:00",
"categories": [
"Quantum Hardware",
"Quantum Computing"
Expand Down
14 changes: 10 additions & 4 deletions demonstrations/tutorial_pulse_programming101.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,15 +322,21 @@ def wrapped(p, t):

##############################################################################
# Now we define the ``qnode`` that computes the expectation value of the molecular Hamiltonian.
# We need to wrap the ``qnode`` in a function so that we can convert the expectation value to a real number.
# This will enable use to make use of gradient descent methods that require real-valued loss functions.

dev = qml.device("default.qubit", wires=range(n_wires))

@qml.qnode(dev, interface="jax")
def qnode(theta, t=duration):
qml.BasisState(jnp.array(data.tapered_hf_state), wires=H_obj.wires)
qml.evolve(H_pulse)(params=(*theta, *theta), t=t)
return qml.expval(H_obj)

@qml.qnode(dev)
def _qnode_inner(theta, t=duration):
qml.BasisState(jnp.array(data.tapered_hf_state), wires=H_obj.wires)
qml.evolve(H_pulse)(params=(*theta, *theta), t=t)
return qml.expval(H_obj)

expectation_value = _qnode_inner(theta, t) # Execute the qnode
return jnp.real(expectation_value) # Typecast to real number

value_and_grad = jax.jit(jax.value_and_grad(qnode))

Expand Down

0 comments on commit 691322d

Please sign in to comment.