Skip to content

Commit

Permalink
Merge pull request #256 from hamidelmaazouz/fix/he/analysis_and_codegen
Browse files Browse the repository at this point in the history
[EXPERIMENTAL] Skeptical static `DeviceUpdate` instructions
  • Loading branch information
hamidelmaazouz authored Nov 19, 2024
2 parents 11349b3 + d10d438 commit 2fcde80
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 116 deletions.
6 changes: 3 additions & 3 deletions src/qat/backend/analysis_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def run(self, builder: InstructionBuilder, res_mgr: ResultManager, *args, **kwar
else:
targets.update(inst.quantum_targets)

if isinstance(inst, DeviceUpdate):
if isinstance(inst, DeviceUpdate) and isinstance(inst.value, Variable):
reads[inst.target].add(inst.value.name)

result = TriageResult()
Expand Down Expand Up @@ -303,7 +303,7 @@ def run(self, builder: InstructionBuilder, res_mgr: ResultManager, *args, **kwar
]
elif isinstance(inst, Acquire):
rw_result.writes[inst.output_variable].append(inst)
elif isinstance(inst, DeviceUpdate):
elif isinstance(inst, DeviceUpdate) and isinstance(inst.value, Variable):
if not (
inst.value.name in scoping_result.symbol2scopes
and [
Expand Down Expand Up @@ -337,7 +337,7 @@ def decompose_freq(frequency: float, target: PulseChannel):
return lo_freq, nco_freq

def _legalise_bound(self, name: str, bound: IterBound, inst: Instruction):
if isinstance(inst, DeviceUpdate):
if isinstance(inst, DeviceUpdate) and isinstance(inst.value, Variable):
if inst.attribute == "frequency":
if inst.target.fixed_if:
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion src/qat/purr/backends/qblox/analysis_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def freq_as_steps(freq_hz: float) -> int:
return steps

def _legalise_bound(self, name: str, bound: IterBound, inst: Instruction):
if isinstance(inst, DeviceUpdate):
if isinstance(inst, DeviceUpdate) and isinstance(inst.value, Variable):
if inst.attribute == "frequency":
legal_bound = IterBound(
start=self.freq_as_steps(bound.start),
Expand Down
139 changes: 82 additions & 57 deletions src/qat/purr/backends/qblox/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,23 @@
from qat.purr.backends.qblox.codegen import NewQbloxEmitter, QbloxEmitter
from qat.purr.backends.utilities import get_axis_map
from qat.purr.compiler.emitter import QatFile
from qat.purr.compiler.execution import SweepIterator, _binary_average, _numpy_array_to_list
from qat.purr.compiler.instructions import AcquireMode, IndexAccessor, Instruction, Variable
from qat.purr.compiler.execution import (
DeviceInjectors,
SweepIterator,
_binary_average,
_numpy_array_to_list,
)
from qat.purr.compiler.instructions import (
AcquireMode,
DeviceUpdate,
IndexAccessor,
Instruction,
Variable,
)
from qat.purr.compiler.interrupt import Interrupt, NullInterrupt
from qat.purr.compiler.runtime import NewQuantumRuntime
from qat.purr.utils.logging_utils import log_duration
from qat.utils.algorithm import stable_partition


class QbloxLiveHardwareModel(LiveHardwareModel):
Expand Down Expand Up @@ -167,62 +179,75 @@ def build_pass_pipeline(self, *args, **kwargs):
def _common_execute(self, builder, interrupt: Interrupt = NullInterrupt()):
self._model_exists()

with log_duration("Codegen run in {} seconds."):
res_mgr = ResultManager()
self.run_pass_pipeline(builder, res_mgr, self.model)
packages = NewQbloxEmitter().emit_packages(builder, res_mgr, self.model)

with log_duration("QPU returned results in {} seconds."):
self.model.control_hardware.set_data(packages)
playback_results: Dict[str, np.ndarray] = (
self.model.control_hardware.start_playback(None, None)
)
# TODO - A skeptical usage of DeviceInjectors on static device updates
# TODO - Figure out what they mean w/r to scopes and control flow
static_dus, builder.instructions = stable_partition(
builder.instructions,
lambda inst: isinstance(inst, DeviceUpdate)
and not isinstance(inst.value, Variable),
)
injectors = DeviceInjectors(static_dus)

try:
injectors.inject()
with log_duration("Codegen run in {} seconds."):
res_mgr = ResultManager()
self.run_pass_pipeline(builder, res_mgr, self.model)
packages = NewQbloxEmitter().emit_packages(builder, res_mgr, self.model)

with log_duration("QPU returned results in {} seconds."):
self.model.control_hardware.set_data(packages)
playback_results: Dict[str, np.ndarray] = (
self.model.control_hardware.start_playback(None, None)
)

# Post execution step needs a lot of work
# TODO - Robust batching analysis (as a pass !)
# TODO - Lowerability analysis pass

triage_result: TriageResult = res_mgr.lookup_by_type(TriageResult)
acquire_map = triage_result.acquire_map
pp_map = triage_result.pp_map
sweeps = triage_result.sweeps

def create_sweep_iterator():
switerator = SweepIterator()
for sweep in sweeps:
switerator.add_sweep(sweep)
return switerator

results = {}
for t, acquires in acquire_map.items():
switerator = create_sweep_iterator()
for acq in acquires:
big_response = playback_results[acq.output_variable]
sweep_splits = np.split(big_response, switerator.length)
switerator.reset_iteration()
while not switerator.is_finished():
# just to advance iteration, no need for injection
# TODO - A generic loop nest model. Sth similar to SweepIterator but does not mixin injection stuff
switerator.do_sweep([])
response = sweep_splits[switerator.current_iteration]
response_axis = get_axis_map(acq.mode, response)
for pp in pp_map[acq.output_variable]:
response, response_axis = self.run_post_processing(
pp, response, response_axis
)
handle = results.setdefault(
acq.output_variable,
np.empty(
switerator.get_results_shape(response.shape),
response.dtype,
),
)
switerator.insert_result_at_sweep_position(handle, response)

results = self._process_results(results, triage_result)
results = self._process_assigns(results, triage_result)

return results
# Post execution step needs a lot of work
# TODO - Robust batching analysis (as a pass !)
# TODO - Lowerability analysis pass

triage_result: TriageResult = res_mgr.lookup_by_type(TriageResult)
acquire_map = triage_result.acquire_map
pp_map = triage_result.pp_map
sweeps = triage_result.sweeps

def create_sweep_iterator():
switerator = SweepIterator()
for sweep in sweeps:
switerator.add_sweep(sweep)
return switerator

results = {}
for t, acquires in acquire_map.items():
switerator = create_sweep_iterator()
for acq in acquires:
big_response = playback_results[acq.output_variable]
sweep_splits = np.split(big_response, switerator.length)
switerator.reset_iteration()
while not switerator.is_finished():
# just to advance iteration, no need for injection
# TODO - A generic loop nest model. Sth similar to SweepIterator but does not mixin injection stuff
switerator.do_sweep([])
response = sweep_splits[switerator.current_iteration]
response_axis = get_axis_map(acq.mode, response)
for pp in pp_map[acq.output_variable]:
response, response_axis = self.run_post_processing(
pp, response, response_axis
)
handle = results.setdefault(
acq.output_variable,
np.empty(
switerator.get_results_shape(response.shape),
response.dtype,
),
)
switerator.insert_result_at_sweep_position(handle, response)

results = self._process_results(results, triage_result)
results = self._process_assigns(results, triage_result)

return results
finally:
injectors.revert()

def _process_results(self, results, triage_result: TriageResult):
"""
Expand Down
127 changes: 76 additions & 51 deletions tests/qat/backend/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,17 @@
from qat.purr.backends.qblox.device import QbloxPhysicalBaseband, QbloxPhysicalChannel
from qat.purr.compiler.devices import PulseShapeType
from qat.purr.compiler.emitter import InstructionEmitter
from qat.purr.compiler.instructions import Acquire, MeasurePulse, Pulse
from qat.purr.compiler.execution import DeviceInjectors
from qat.purr.compiler.instructions import (
Acquire,
DeviceUpdate,
MeasurePulse,
Pulse,
Variable,
)
from qat.purr.compiler.runtime import get_builder
from qat.purr.utils.logger import get_default_logger
from qat.utils.algorithm import stable_partition

from tests.qat.utils.builder_nuggets import qubit_spect, resonator_spect

Expand Down Expand Up @@ -391,58 +399,75 @@ def test_qubit_spect(self, model, num_points, qubit_indices):
runtime = model.create_runtime()
runtime.run_pass_pipeline(builder, res_mgr, model, engine)

self.run_pass_pipeline(builder, res_mgr, model)
packages = NewQbloxEmitter().emit_packages(builder, res_mgr, model)
assert len(packages) == 2 * len(qubit_indices)

for index in qubit_indices:
qubit = model.get_qubit(index)
drive_channel = qubit.get_drive_channel()
acquire_channel = qubit.get_acquire_channel()
# TODO - A skeptical usage of DeviceInjectors on static device updates
# TODO - Figure out what they mean w/r to scopes and control flow
static_dus, builder.instructions = stable_partition(
builder.instructions,
lambda inst: isinstance(inst, DeviceUpdate)
and not isinstance(inst.value, Variable),
)

# Drive
drive_pkg = next((pkg for pkg in packages if pkg.target == drive_channel))
drive_pulse = next(
(
inst
for inst in builder.instructions
if isinstance(inst, Pulse) and drive_channel in inst.quantum_targets
assert len(static_dus) == len(qubit_indices)

injectors = DeviceInjectors(static_dus)
try:
injectors.inject()
self.run_pass_pipeline(builder, res_mgr, model)
packages = NewQbloxEmitter().emit_packages(builder, res_mgr, model)
assert len(packages) == 2 * len(qubit_indices)

for index in qubit_indices:
qubit = model.get_qubit(index)
drive_channel = qubit.get_drive_channel()
acquire_channel = qubit.get_acquire_channel()

# Drive
drive_pkg = next((pkg for pkg in packages if pkg.target == drive_channel))
drive_pulse = next(
(
inst
for inst in builder.instructions
if isinstance(inst, Pulse) and drive_channel in inst.quantum_targets
)
)
)

assert not drive_pkg.sequence.acquisitions

if drive_pulse.shape == PulseShapeType.SQUARE:
assert not drive_pkg.sequence.waveforms
assert "play" not in drive_pkg.sequence.program
assert "set_awg_offs" in drive_pkg.sequence.program
assert "upd_param" in drive_pkg.sequence.program
else:
assert drive_pkg.sequence.waveforms
assert "play" in drive_pkg.sequence.program
assert "set_awg_offs" not in drive_pkg.sequence.program
assert "upd_param" not in drive_pkg.sequence.program

# Readout
acquire_pkg = next((pkg for pkg in packages if pkg.target == acquire_channel))
measure_pulse = next(
(
inst
for inst in builder.instructions
if isinstance(inst, MeasurePulse)
and acquire_channel in inst.quantum_targets
assert not drive_pkg.sequence.acquisitions

if drive_pulse.shape == PulseShapeType.SQUARE:
assert not drive_pkg.sequence.waveforms
assert "play" not in drive_pkg.sequence.program
assert "set_awg_offs" in drive_pkg.sequence.program
assert "upd_param" in drive_pkg.sequence.program
else:
assert drive_pkg.sequence.waveforms
assert "play" in drive_pkg.sequence.program
assert "set_awg_offs" not in drive_pkg.sequence.program
assert "upd_param" not in drive_pkg.sequence.program

# Readout
acquire_pkg = next(
(pkg for pkg in packages if pkg.target == acquire_channel)
)
measure_pulse = next(
(
inst
for inst in builder.instructions
if isinstance(inst, MeasurePulse)
and acquire_channel in inst.quantum_targets
)
)
)

assert acquire_pkg.sequence.acquisitions

if measure_pulse.shape == PulseShapeType.SQUARE:
assert not acquire_pkg.sequence.waveforms
assert "play" not in acquire_pkg.sequence.program
assert "set_awg_offs" in acquire_pkg.sequence.program
assert "upd_param" in acquire_pkg.sequence.program
else:
assert acquire_pkg.sequence.waveforms
assert "play" in acquire_pkg.sequence.program
assert "set_awg_offs" not in acquire_pkg.sequence.program
assert "upd_param" not in acquire_pkg.sequence.program
assert acquire_pkg.sequence.acquisitions

if measure_pulse.shape == PulseShapeType.SQUARE:
assert not acquire_pkg.sequence.waveforms
assert "play" not in acquire_pkg.sequence.program
assert "set_awg_offs" in acquire_pkg.sequence.program
assert "upd_param" in acquire_pkg.sequence.program
else:
assert acquire_pkg.sequence.waveforms
assert "play" in acquire_pkg.sequence.program
assert "set_awg_offs" not in acquire_pkg.sequence.program
assert "upd_param" not in acquire_pkg.sequence.program
finally:
injectors.revert()
11 changes: 10 additions & 1 deletion tests/qat/backend/test_live.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,22 @@ def test_qubit_spect(self, model):

@pytest.mark.parametrize("model", [None], indirect=True)
class TestQbloxLiveEngineAdapter:
def test_resonator_spect(self, model):
def test_engine_adapter(self, model):
runtime = model.create_runtime()
assert isinstance(runtime.engine, QbloxLiveEngineAdapter)
assert isinstance(runtime.engine._legacy_engine, QbloxLiveEngine)
assert isinstance(runtime.engine._new_engine, NewQbloxLiveEngine)

def test_resonator_spect(self, model):
runtime = model.create_runtime()
runtime.engine.enable_hax = True
builder = resonator_spect(model)
results = runtime.execute(builder)
assert results is not None

def test_qubit_spect(self, model):
runtime = model.create_runtime()
runtime.engine.enable_hax = True
builder = qubit_spect(model)
results = runtime.execute(builder)
assert results is not None
6 changes: 3 additions & 3 deletions tests/qat/utils/builder_nuggets.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ def qubit_spect(model, qubit_indices=None, num_points=None):
builder = get_builder(model)
builder.synchronize([model.get_qubit(index) for index in qubit_indices])
for index in qubit_indices:
# TODO - Provide better processing of static DeviceUpdates
model.get_qubit(index).get_drive_channel().scale = 1
# builder.device_assign(model.get_qubit(index).get_drive_channel(), "scale", 1)
# TODO - A skeptical usage of DeviceInjectors on static device updates
# TODO - Figure out what they mean w/r to scopes and control flow
builder.device_assign(model.get_qubit(index).get_drive_channel(), "scale", 1)
builder.sweep(
[SweepValue(f"freq{index}", scan_freqs[f"Q{index}"]) for index in qubit_indices]
)
Expand Down

0 comments on commit 2fcde80

Please sign in to comment.