Skip to content

qml.execute no longer accepts mcm_config argument #6807

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 33 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
4ec4b38
fix: Initial code addition w/ test fixing
andrijapau Jan 10, 2025
23bdaaf
Merge branch 'master' into fix-execute-signature
andrijapau Jan 10, 2025
5d7565f
fix: Move mcm_config kwarg to the bottom
andrijapau Jan 10, 2025
a187285
doc: Update changelog-dev.md
andrijapau Jan 10, 2025
2d5ef5e
fix: Update capture to use new execute_kwargs
andrijapau Jan 10, 2025
0bda24f
doc: Update deprecations.rst
andrijapau Jan 13, 2025
549931b
doc: Update changelog-dev.md
andrijapau Jan 13, 2025
266438f
Update doc/development/deprecations.rst
andrijapau Jan 13, 2025
6ac3995
Update doc/releases/changelog-dev.md
andrijapau Jan 13, 2025
fe8e295
doc: Update execute doc string
andrijapau Jan 13, 2025
432faaf
Merge branch 'fix-execute-signature' of github.com:PennyLaneAI/pennyl…
andrijapau Jan 13, 2025
231e3c0
Merge branch 'master' into fix-execute-signature
andrijapau Jan 13, 2025
7d9376e
Update doc/releases/changelog-dev.md
andrijapau Jan 14, 2025
a63e01f
Merge branch 'master' into fix-execute-signature
andrijapau Jan 14, 2025
391cbba
fix: Update _capture_qnode.py
andrijapau Jan 14, 2025
8d61d2e
fix: Add xfails
andrijapau Jan 14, 2025
df93fdc
Revert "fix: Add xfails"
andrijapau Jan 14, 2025
ea4a20c
fix: Add xfails
andrijapau Jan 14, 2025
d717c0e
fix: Add remaining xfails
andrijapau Jan 14, 2025
c9a1f26
fix: Move xfail in TestTransformDispatcher
andrijapau Jan 14, 2025
b02f0f0
fix: Update test_execute.py to test for deprecation of mcm_config
andrijapau Jan 14, 2025
154d03b
Update doc/releases/changelog-dev.md
andrijapau Jan 15, 2025
34761e1
Merge branch 'master' into fix-execute-signature
andrijapau Jan 22, 2025
b305840
Trigger CI
andrijapau Jan 22, 2025
e4ee93b
Merge branch 'master' into fix-execute-signature
andrijapau Jan 22, 2025
069c4ad
Trigger CI
andrijapau Jan 22, 2025
fef5a4d
fix: Update test_capture_qnode.py for unnecessary code
andrijapau Jan 22, 2025
dc5ba37
Merge branch 'master' into fix-execute-signature
andrijapau Jan 22, 2025
a7a73b7
Merge branch 'master' into fix-execute-signature
andrijapau Jan 23, 2025
f242d76
Merge branch 'master' into fix-execute-signature
andrijapau Jan 23, 2025
361efcf
Merge branch 'master' into fix-execute-signature
andrijapau Jan 24, 2025
040d068
fix: Promote to valueerror to get all deprecated code
andrijapau Jan 24, 2025
6760c3e
Revert "fix: Promote to valueerror to get all deprecated code"
andrijapau Jan 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/development/deprecations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ deprecations are listed below.
Pending deprecations
--------------------

* The ``mcm_config`` argument to ``qml.execute`` has been deprecated.
Instead, use the ``mcm_method`` and ``postselect_mode`` arguments.

- Deprecated in v0.41
- Will be removed in v0.42

* Specifying gradient keyword arguments as any additional keyword argument to the qnode is deprecated
and will be removed in v0.42. The gradient keyword arguments should be passed to the new
keyword argument ``gradient_kwargs`` via an explicit dictionary, like ``gradient_kwargs={"h": 1e-4}``.
Expand Down
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@

<h3>Deprecations 👋</h3>

* The `mcm_method` keyword in `qml.execute` is deprecated. Instead, use the ``mcm_method`` and ``postselect_mode`` arguments.
[(#6807)](https://github.com/PennyLaneAI/pennylane/pull/6807)

* Specifying gradient keyword arguments as any additional keyword argument to the qnode is deprecated
and will be removed in v0.42. The gradient keyword arguments should be passed to the new
keyword argument `gradient_kwargs` via an explicit dictionary. This change will improve qnode argument
Expand Down
4 changes: 1 addition & 3 deletions pennylane/workflow/_capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@

"""
from copy import copy
from dataclasses import asdict
from functools import partial
from numbers import Number
from warnings import warn
Expand Down Expand Up @@ -404,8 +403,7 @@ def f(x):
qfunc_jaxpr = jax.make_jaxpr(flat_fn)(*args)

execute_kwargs = copy(qnode.execute_kwargs)
mcm_config = asdict(execute_kwargs.pop("mcm_config"))
qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs, **mcm_config}
qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs}

flat_args = jax.tree_util.tree_leaves(args)

Expand Down
28 changes: 23 additions & 5 deletions pennylane/workflow/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,16 @@ def execute(
diff_method: Optional[Union[Callable, SupportedDiffMethods, TransformDispatcher]] = None,
interface: Optional[InterfaceLike] = Interface.AUTO,
*,
transform_program: TransformProgram = None,
grad_on_execution: Literal[bool, "best"] = "best",
cache: Union[None, bool, dict, Cache] = True,
cachesize: int = 10000,
max_diff: int = 1,
device_vjp: Union[bool, None] = False,
postselect_mode=None,
mcm_method=None,
gradient_kwargs: dict = None,
transform_program: TransformProgram = None,
mcm_config: "qml.devices.MCMConfig" = None,
mcm_config: "qml.devices.MCMConfig" = "unset",
config="unset",
inner_transform="unset",
) -> ResultBatch:
Expand Down Expand Up @@ -86,10 +88,18 @@ def execute(
(classical) computational overhead during the backward pass.
device_vjp=False (Optional[bool]): whether or not to use the device-provided Jacobian
product if it is available.
mcm_config (dict): Dictionary containing configuration options for handling
mid-circuit measurements.
postselect_mode (str): Configuration for handling shots with mid-circuit measurement
postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to
keep the same number of shots. Default is ``None``.
mcm_method (str): Strategy to use when executing circuits with mid-circuit measurements.
``"deferred"`` is ignored. If mid-circuit measurements are found in the circuit,
the device will use ``"tree-traversal"`` if specified and the ``"one-shot"`` method
otherwise. For usage details, please refer to the
:doc:`dynamic quantum circuits page </introduction/dynamic_quantum_circuits>`.
gradient_kwargs (dict): dictionary of keyword arguments to pass when
determining the gradients of tapes.
mcm_config="unset": **DEPRECATED**. This keyword argument has been replaced by ``postselect_mode``
and ``mcm_method`` and will be removed in v0.42.
config="unset": **DEPRECATED**. This keyword argument has been deprecated and
will be removed in v0.42.
inner_transform="unset": **DEPRECATED**. This keyword argument has been deprecated
Expand Down Expand Up @@ -173,6 +183,14 @@ def cost_fn(params, x):
qml.PennyLaneDeprecationWarning,
)

if mcm_config != "unset":
warn(
"The mcm_config argument is deprecated and will be removed in v0.42, use mcm_method and postselect_mode instead.",
qml.PennyLaneDeprecationWarning,
)
mcm_method = mcm_config.mcm_method
postselect_mode = mcm_config.postselect_mode

if logger.isEnabledFor(logging.DEBUG):
logger.debug(
(
Expand Down Expand Up @@ -209,7 +227,7 @@ def cost_fn(params, x):
gradient_method=diff_method,
grad_on_execution=None if grad_on_execution == "best" else grad_on_execution,
use_device_jacobian_product=device_vjp,
mcm_config=mcm_config or {},
mcm_config=qml.devices.MCMConfig(postselect_mode=postselect_mode, mcm_method=mcm_method),
gradient_keyword_arguments=gradient_kwargs or {},
derivative_order=max_diff,
)
Expand Down
9 changes: 7 additions & 2 deletions pennylane/workflow/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ def _to_qfunc_output_type(
return qml.pytrees.unflatten(results, qfunc_output_structure)


def _validate_mcm_config(postselect_mode: str, mcm_method: str) -> None:
qml.devices.MCMConfig(postselect_mode=postselect_mode, mcm_method=mcm_method)


def _validate_gradient_kwargs(gradient_kwargs: dict) -> None:
for kwarg in gradient_kwargs:
if kwarg == "expansion_strategy":
Expand Down Expand Up @@ -563,17 +567,18 @@ def __init__(
self.device = device
self._interface = get_canonical_interface_name(interface)
self.diff_method = diff_method
mcm_config = qml.devices.MCMConfig(mcm_method=mcm_method, postselect_mode=postselect_mode)
cache = (max_diff > 1) if cache == "auto" else cache

# execution keyword arguments
_validate_mcm_config(postselect_mode, mcm_method)
self.execute_kwargs = {
"grad_on_execution": grad_on_execution,
"cache": cache,
"cachesize": cachesize,
"max_diff": max_diff,
"device_vjp": device_vjp,
"mcm_config": mcm_config,
"postselect_mode": postselect_mode,
"mcm_method": mcm_method,
}

# internal data attributes
Expand Down
2 changes: 0 additions & 2 deletions tests/capture/test_capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""
Tests for capturing a qnode into jaxpr.
"""
from dataclasses import asdict
from functools import partial

# pylint: disable=protected-access
Expand Down Expand Up @@ -130,7 +129,6 @@ def circuit(x):
assert eqn0.params["shots"] == qml.measurements.Shots(None)
expected_kwargs = {"diff_method": "best"}
expected_kwargs.update(circuit.execute_kwargs)
expected_kwargs.update(asdict(expected_kwargs.pop("mcm_config")))
assert eqn0.params["qnode_kwargs"] == expected_kwargs

qfunc_jaxpr = eqn0.params["qfunc_jaxpr"]
Expand Down
1 change: 1 addition & 0 deletions tests/ops/functions/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,7 @@ def circuit(theta):

assert np.allclose(matrix, expected_matrix)

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
@pytest.mark.catalyst
@pytest.mark.external
def test_catalyst(self):
Expand Down
1 change: 1 addition & 0 deletions tests/ops/op_math/test_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,7 @@ def circ(phi):
grad = jax.grad(circ)(phi)
assert qml.math.allclose(grad, -jnp.sin(phi))

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
@pytest.mark.catalyst
@pytest.mark.external
def test_catalyst_qnode(self):
Expand Down
28 changes: 28 additions & 0 deletions tests/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def test_compiler(self):
assert qml.compiler.available("catalyst")
assert qml.compiler.available_compilers() == ["catalyst", "cuda_quantum"]

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_active_compiler(self):
"""Test `qml.compiler.active_compiler` inside a simple circuit"""
dev = qml.device("lightning.qubit", wires=2)
Expand All @@ -91,6 +92,7 @@ def circuit(phi, theta):
assert jnp.allclose(circuit(jnp.pi, jnp.pi / 2), 1.0)
assert jnp.allclose(qml.qjit(circuit)(jnp.pi, jnp.pi / 2), -1.0)

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_active(self):
"""Test `qml.compiler.active` inside a simple circuit"""
dev = qml.device("lightning.qubit", wires=2)
Expand All @@ -114,6 +116,7 @@ def test_jax_enable_x64(self, jax_enable_x64):
qml.compiler.active()
assert jax.config.jax_enable_x64 is jax_enable_x64

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_qjit_circuit(self):
"""Test JIT compilation of a circuit with 2-qubit"""
dev = qml.device("lightning.qubit", wires=2)
Expand All @@ -128,6 +131,7 @@ def circuit(theta):

assert jnp.allclose(circuit(0.5), 0.0)

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_qjit_aot(self):
"""Test AOT compilation of a circuit with 2-qubit"""

Expand All @@ -152,6 +156,7 @@ def circuit(x: complex, z: ShapedArray(shape=(3,), dtype=jnp.float64)):
)
assert jnp.allclose(result, expected)

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
@pytest.mark.parametrize(
"_in,_out",
[
Expand Down Expand Up @@ -196,6 +201,7 @@ def workflow1(params1, params2):
result = workflow1(params1, params2)
assert jnp.allclose(result, expected)

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_return_value_dict(self):
"""Test pytree return values."""
dev = qml.device("lightning.qubit", wires=2)
Expand All @@ -218,6 +224,7 @@ def circuit1(params):
assert jnp.allclose(result["w0"], expected["w0"])
assert jnp.allclose(result["w1"], expected["w1"])

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_qjit_python_if(self):
"""Test JIT compilation with the autograph support"""
dev = qml.device("lightning.qubit", wires=2)
Expand All @@ -235,6 +242,7 @@ def circuit(x: int):
assert jnp.allclose(circuit(3), 0.0)
assert jnp.allclose(circuit(5), 1.0)

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_compilation_opt(self):
"""Test user-configurable compilation options"""
dev = qml.device("lightning.qubit", wires=2)
Expand All @@ -250,6 +258,7 @@ def circuit(x: float):
result_header = "func.func public @circuit(%arg0: tensor<f64>) -> tensor<f64>"
assert result_header in mlir_str

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_qjit_adjoint(self):
"""Test JIT compilation with adjoint support"""
dev = qml.device("lightning.qubit", wires=2)
Expand All @@ -273,6 +282,7 @@ def func():

assert jnp.allclose(workflow_cl(0.1, [1]), workflow_pl(0.1, [1]))

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_qjit_adjoint_lazy(self):
"""Test that the lazy kwarg is supported."""
dev = qml.device("lightning.qubit", wires=2)
Expand All @@ -287,6 +297,7 @@ def workflow_pl(theta, wires):

assert jnp.allclose(workflow_cl(0.1, [1]), workflow_pl(0.1, [1]))

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_control(self):
"""Test that control works with qjit."""
dev = qml.device("lightning.qubit", wires=2)
Expand Down Expand Up @@ -317,6 +328,7 @@ def cond_fn():
class TestCatalystControlFlow:
"""Test ``qml.qjit`` with Catalyst's control-flow operations"""

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_alternating_while_loop(self):
"""Test simple while loop."""
dev = qml.device("lightning.qubit", wires=1)
Expand All @@ -334,6 +346,7 @@ def loop(v):

assert jnp.allclose(circuit(1), -1.0)

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_nested_while_loops(self):
"""Test nested while loops."""
dev = qml.device("lightning.qubit", wires=1)
Expand Down Expand Up @@ -393,6 +406,7 @@ def loop(v):
expected = [qml.PauliX(0) for i in range(4)]
_ = [qml.assert_equal(i, j) for i, j in zip(tape.operations, expected)]

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_dynamic_wires_for_loops(self):
"""Test for loops with iteration index-dependant wires."""
dev = qml.device("lightning.qubit", wires=6)
Expand All @@ -414,6 +428,7 @@ def loop_fn(i):

assert jnp.allclose(circuit(6), expected)

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_nested_for_loops(self):
"""Test nested for loops."""
dev = qml.device("lightning.qubit", wires=4)
Expand Down Expand Up @@ -445,6 +460,7 @@ def inner(j):

assert jnp.allclose(circuit(4), jnp.eye(2**4)[0])

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_for_loop_python_fallback(self):
"""Test that qml.for_loop fallsback to Python
interpretation if Catalyst is not available"""
Expand Down Expand Up @@ -496,6 +512,7 @@ def inner(j):

_ = [qml.assert_equal(i, j) for i, j in zip(res, expected)]

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_cond(self):
"""Test condition with simple true_fn"""
dev = qml.device("lightning.qubit", wires=1)
Expand All @@ -514,6 +531,7 @@ def ansatz_true():
assert jnp.allclose(circuit(1.4), 1.0)
assert jnp.allclose(circuit(1.6), 0.0)

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_cond_with_else(self):
"""Test condition with simple true_fn and false_fn"""
dev = qml.device("lightning.qubit", wires=1)
Expand All @@ -535,6 +553,7 @@ def ansatz_false():
assert jnp.allclose(circuit(1.4), 0.16996714)
assert jnp.allclose(circuit(1.6), 0.0)

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_cond_with_elif(self):
"""Test condition with a simple elif branch"""
dev = qml.device("lightning.qubit", wires=1)
Expand All @@ -558,6 +577,7 @@ def false_fn():
assert jnp.allclose(circuit(1.2), 0.13042371)
assert jnp.allclose(circuit(jnp.pi), -1.0)

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_cond_with_elifs(self):
"""Test condition with multiple elif branches"""
dev = qml.device("lightning.qubit", wires=1)
Expand Down Expand Up @@ -630,6 +650,7 @@ def conditional_false_fn(): # pylint: disable=unused-variable
class TestCatalystGrad:
"""Test ``qml.qjit`` with Catalyst's grad operations"""

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_grad_classical_preprocessing(self):
"""Test the grad transformation with classical preprocessing."""

Expand All @@ -647,6 +668,7 @@ def circuit(x):

assert jnp.allclose(workflow(2.0), -jnp.pi)

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_grad_with_postprocessing(self):
"""Test the grad transformation with classical preprocessing and postprocessing."""
dev = qml.device("lightning.qubit", wires=1)
Expand All @@ -665,6 +687,7 @@ def loss(theta):

assert jnp.allclose(workflow(1.0), 5.04324559)

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_grad_with_multiple_qnodes(self):
"""Test the grad transformation with multiple QNodes with their own differentiation methods."""
dev = qml.device("lightning.qubit", wires=1)
Expand Down Expand Up @@ -703,6 +726,7 @@ def dsquare(x: float):

assert jnp.allclose(dsquare(2.3), 4.6)

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_jacobian_diff_method(self):
"""Test the Jacobian transformation with the device diff_method."""
dev = qml.device("lightning.qubit", wires=1)
Expand All @@ -721,6 +745,7 @@ def workflow(p: float):

assert jnp.allclose(result, reference)

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_jacobian_auto(self):
"""Test the Jacobian transformation with 'auto'."""
dev = qml.device("lightning.qubit", wires=1)
Expand All @@ -740,6 +765,7 @@ def circuit(x):

assert jnp.allclose(result, reference)

@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_jacobian_fd(self):
"""Test the Jacobian transformation with 'fd'."""
dev = qml.device("lightning.qubit", wires=1)
Expand Down Expand Up @@ -838,6 +864,7 @@ def f(x):
vjp(x, dy)


@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
class TestCatalystSample:
"""Test qml.sample with Catalyst."""

Expand All @@ -858,6 +885,7 @@ def circuit(x):
assert circuit(jnp.pi) == 1


@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
class TestCatalystMCMs:
"""Test dynamic_one_shot with Catalyst."""

Expand Down
Loading
Loading