Skip to content

Commit 277b97e

Browse files
antalszavarmoyard
andauthored
Fix an error in MottonenStatePreparation with qml.math.cast (#1400)
* qml.math.cast * Add test and update changelog. * Typo. * Import torch change. * Update pennylane/templates/state_preparations/mottonen.py Co-authored-by: antalszava <antalszava@gmail.com> * Update tests/templates/test_state_preparations/test_mottonen_state_prep.py Co-authored-by: antalszava <antalszava@gmail.com> * Update tests/templates/test_state_preparations/test_mottonen_state_prep.py Co-authored-by: antalszava <antalszava@gmail.com> * Update tests/templates/test_state_preparations/test_mottonen_state_prep.py Co-authored-by: antalszava <antalszava@gmail.com> * Update .github/CHANGELOG.md Co-authored-by: antalszava <antalszava@gmail.com> * Update tests/templates/test_state_preparations/test_mottonen_state_prep.py Co-authored-by: antalszava <antalszava@gmail.com> * Tensor to array. Co-authored-by: Romain Moyard <rmoyard@gmail.com>
1 parent 08ff413 commit 277b97e

File tree

3 files changed

+42
-4
lines changed

3 files changed

+42
-4
lines changed

.github/CHANGELOG.md

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ random_mat2 = rng.standard_normal(3, requires_grad=False)
505505
```
506506

507507
```pycon
508-
>>> tape = tape.expand(depth=2)
508+
>>> tape = tape.expand(depth=1)
509509
>>> print(tape.draw(wire_order=Wires(all_wires)))
510510
c0: ──────────────╭C──────────────────────╭C──────────┤
511511
c1: ──────────────├C──────────────────────├C──────────┤
@@ -577,14 +577,18 @@ random_mat2 = rng.standard_normal(3, requires_grad=False)
577577

578578
<h3>Bug fixes</h3>
579579

580+
* Fixes a bug with `qml.math.cast` where the `MottonenStatePreparation` operation expected
581+
a float type instead of double.
582+
[(#1400)](https://github.com/XanaduAI/pennylane/pull/1400)
583+
580584
* Fixes a bug where a copy of `qml.ControlledQubitUnitary` was non-functional as it did not have all the necessary information.
581585
[(#1411)](https://github.com/PennyLaneAI/pennylane/pull/1411)
582586

583587
* Warns when adjoint or reversible differentiation specified or called on a device with finite shots.
584-
[(#1406)](https://github.com/PennyLaneAI/pennylane/pull/1406)
588+
[(#1406)](https://github.com/PennyLaneAI/pennylane/pull/1406)
585589

586590
* Fixes the differentiability of the operations `IsingXX` and `IsingZZ` for Autograd, Jax and Tensorflow.
587-
[(#1390)](https://github.com/PennyLaneAI/pennylane/pull/1390)
591+
[(#1390)](https://github.com/PennyLaneAI/pennylane/pull/1390)
588592

589593
* Fixes a bug where multiple identical Hamiltonian terms will produce a
590594
different result with ``optimize=True`` using ``ExpvalCost``.
@@ -595,7 +599,7 @@ random_mat2 = rng.standard_normal(3, requires_grad=False)
595599
[(#1392)](https://github.com/XanaduAI/pennylane/pull/1392)
596600

597601
* Fixes floating point errors with `diff_method="finite-diff"` and `order=1` when parameters are `float32`.
598-
[(#1381)](https://github.com/PennyLaneAI/pennylane/pull/1381)
602+
[(#1381)](https://github.com/PennyLaneAI/pennylane/pull/1381)
599603

600604
* Fixes a bug where `qml.ctrl` would fail to transform gates that had no
601605
control defined and no decomposition defined.

pennylane/templates/state_preparations/mottonen.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,10 @@ def _get_alpha_y(a, n, k):
203203
with np.errstate(divide="ignore", invalid="ignore"):
204204
division = numerator / denominator
205205

206+
# Cast the numerator and denominator to ensure compatibility with interfaces
207+
division = qml.math.cast(division, np.float64)
208+
denominator = qml.math.cast(denominator, np.float64)
209+
206210
division = qml.math.where(denominator != 0.0, division, 0.0)
207211

208212
return 2 * qml.math.arcsin(qml.math.sqrt(division))

tests/templates/test_state_preparations/test_mottonen_state_prep.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import numpy as np
1919
import pennylane as qml
2020
from pennylane import numpy as pnp
21+
22+
torch = pytest.importorskip("torch", minversion="1.3")
2123
from pennylane.templates.state_preparations.mottonen import gray_code, _get_alpha_y
2224

2325

@@ -342,3 +344,31 @@ def circuit(state_vector):
342344
return qml.expval(qml.PauliZ(0))
343345

344346
qml.grad(circuit)(state_vector)
347+
348+
349+
class TestCasting:
350+
"""Test that the Mottonen state preparation ensures the compatibility with
351+
interfaces by using casting'"""
352+
353+
@pytest.mark.parametrize(
354+
"inputs, expected",
355+
[
356+
(
357+
torch.tensor([0.0, 0.7, 0.7, 0.0], requires_grad=True),
358+
[0.0, 0.5, 0.5, 0.0],
359+
),
360+
(torch.tensor([0.1, 0.0, 0.0, 0.1], requires_grad=True), [0.5, 0.0, 0.0, 0.5]),
361+
],
362+
)
363+
def test_scalar_torch(self, inputs, expected):
364+
"""Test that MottonenStatePreparation can be correctly used with the Torch interface."""
365+
dev = qml.device("default.qubit", wires=2)
366+
367+
@qml.qnode(dev, interface="torch")
368+
def circuit(inputs):
369+
qml.templates.MottonenStatePreparation(inputs, wires=[0, 1])
370+
return qml.probs(wires=[0, 1])
371+
372+
inputs = inputs / torch.linalg.norm(inputs)
373+
res = circuit(inputs)
374+
assert np.allclose(res.detach().numpy(), expected, atol=1e-6, rtol=0)

0 commit comments

Comments
 (0)