Skip to content

Commit

Permalink
Merge pull request #250 from hamidelmaazouz/fix/he/analysis_and_codegen
Browse files Browse the repository at this point in the history
[EXPERIMENTAL] Fixes to Analysis and Codegen
  • Loading branch information
hamidelmaazouz authored Nov 18, 2024
2 parents 09429f9 + 305e385 commit e9d2b16
Show file tree
Hide file tree
Showing 10 changed files with 637 additions and 523 deletions.
330 changes: 165 additions & 165 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ scipy = ">=1.13.1"
pyqir = ">=0.8.0a1"
regex = ">=2022.6.2"
jsonpickle = ">=2.2.0"
qblox-instruments = "^0.14.0"
qblox-instruments = "0.14.1"
lark-parser = "^0.12.0"
pydantic-settings = ">=2.5.2"
compiler-config = "0.1.0"
Expand Down
301 changes: 171 additions & 130 deletions src/qat/backend/analysis_passes.py

Large diffs are not rendered by default.

117 changes: 63 additions & 54 deletions src/qat/purr/backends/qblox/analysis_passes.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,19 @@
from collections import defaultdict
from copy import deepcopy
from typing import Dict, Set

import numpy as np

from qat.backend.analysis_passes import BindingResult, IterBound
from qat.backend.analysis_passes import BindingResult, IterBound, TriageResult
from qat.ir.pass_base import AnalysisPass
from qat.ir.result_base import ResultManager
from qat.purr.backends.qblox.constants import Constants
from qat.purr.compiler.builders import InstructionBuilder
from qat.purr.compiler.devices import PulseChannel, PulseShapeType
from qat.purr.compiler.devices import PulseShapeType
from qat.purr.compiler.instructions import DeviceUpdate, Instruction, Pulse, Variable


class QbloxLegalisationPass(AnalysisPass):
"""
Performs target-dependent legalisation for QBlox.
A) A repeat instruction with a very high repetition count is illegal because acquisition memory
on a QBlox sequencer is limited. This requires optimal batching of the repeat instruction into maximally
supported batches of smaller repeat counts.
This pass does not do any batching. More features and adjustments will follow in future iterations.
B) Previously processed variables such as frequencies, phases, and amplitudes still need digital conversion
to a representation that's required by the QBlox ISA.
+ NCO's 1GHz frequency range by 4e9 steps:
+ [-500, 500] Mhz <=> [-2e9, 2e9] steps
+ 1 Hz <=> 4 steps
+ NCO's 360° phase range by 1e9 steps:
+ 1e9 steps <=> 2*pi rad
+ 125e6 steps <=> pi/4 rad
+ Time and samples are quantified in nanoseconds
+ Amplitude depends on the type of the module:
+ [-1, 1] <=> [-2.5, 2.5] V (for QCM)
+ [-1, 1] <=> [-0.5, 0.5] V (for QRM)
+ AWG offset:
+ [-1, 1] <=> [-32 768, 32 767]
The last point is interesting as it requires knowledge of physical configuration of qubits and the modules
they are wired to. This knowledge is typically found during execution and involving it early on would upset
the rest of the compilation flow. In fact, it complicates this pass in particular, involves allocation
concepts that should not be treated here, and promotes a monolithic compilation style. A temporary workaround
is to simply assume the legality of amplitudes from the start whereby users are required to convert
the desired voltages to the equivalent ratio AOT.
This pass performs target-dependent conversion as described in part (B). More features and adjustments
will follow in future iterations.
"""

@staticmethod
def phase_as_steps(phase_rad: float) -> int:
phase_deg = np.rad2deg(phase_rad)
Expand Down Expand Up @@ -130,7 +95,7 @@ def _legalise_bound(self, name: str, bound: IterBound, inst: Instruction):
legal_bound.start,
legal_bound.step,
legal_bound.end,
legal_bound.end,
legal_bound.count,
]
):
raise ValueError(
Expand All @@ -146,20 +111,64 @@ def _legalise_bound(self, name: str, bound: IterBound, inst: Instruction):
)

def run(self, builder: InstructionBuilder, res_mgr: ResultManager, *args, **kwargs):
"""
Performs target-dependent legalisation for QBlox.
A) A repeat instruction with a very high repetition count is illegal because acquisition memory
on a QBlox sequencer is limited. This requires optimal batching of the repeat instruction into maximally
supported batches of smaller repeat counts.
This pass does not do any batching. More features and adjustments will follow in future iterations.
B) Previously processed variables such as frequencies, phases, and amplitudes still need digital conversion
to a representation that's required by the QBlox ISA.
+ NCO's 1GHz frequency range by 4e9 steps:
+ [-500, 500] Mhz <=> [-2e9, 2e9] steps
+ 1 Hz <=> 4 steps
+ NCO's 360° phase range by 1e9 steps:
+ 1e9 steps <=> 2*pi rad
+ 125e6 steps <=> pi/4 rad
+ Time and samples are quantified in nanoseconds
+ Amplitude depends on the type of the module:
+ [-1, 1] <=> [-2.5, 2.5] V (for QCM)
+ [-1, 1] <=> [-0.5, 0.5] V (for QRM)
+ AWG offset:
+ [-1, 1] <=> [-32 768, 32 767]
The last point is interesting as it requires knowledge of physical configuration of qubits and the modules
they are wired to. This knowledge is typically found during execution and involving it early on would upset
the rest of the compilation flow. In fact, it complicates this pass in particular, involves allocation
concepts that should not be treated here, and promotes a monolithic compilation style. A temporary workaround
is to simply assume the legality of amplitudes from the start whereby users are required to convert
the desired voltages to the equivalent ratio AOT.
This pass performs target-dependent conversion as described in part (B). More features and adjustments
will follow in future iterations.
"""

triage_result: TriageResult = res_mgr.lookup_by_type(TriageResult)
binding_result: BindingResult = res_mgr.lookup_by_type(BindingResult)

qblox_iter_bounds: Dict[PulseChannel, Dict[str, Set[IterBound]]] = defaultdict(
lambda: defaultdict(set)
)
for name, instructions in binding_result.reads.items():
for inst in instructions:
for target in inst.quantum_targets:
bound = binding_result.iter_bounds[target][name]
legal_bound = self._legalise_bound(name, bound, inst)
qblox_iter_bounds[target][name].add(legal_bound)

for target, symbol2iter_bounds in qblox_iter_bounds.items():
for name, iter_bounds in symbol2iter_bounds.items():
if len(iter_bounds) > 1:
raise ValueError(f"Ambiguous Qblox bounds for variable {name}")
binding_result.iter_bounds[target][name] = next(iter(iter_bounds))
for target in triage_result.target_map:
rw_result = binding_result.rw_results[target]
bound_result = binding_result.iter_bound_results[target]
legal_bound_result: Dict[str, IterBound] = deepcopy(bound_result)

qblox_bounds: Dict[str, Set[IterBound]] = defaultdict(set)
for name, instructions in rw_result.reads.items():
for inst in instructions:
legal_bound = self._legalise_bound(name, bound_result[name], inst)
qblox_bounds[name].add(legal_bound)

for name, bound in bound_result.items():
if name in qblox_bounds:
bound_set = qblox_bounds[name]
if len(bound_set) > 1:
raise ValueError(
f"Ambiguous Qblox bounds for variable {name} in target {target}"
)
legal_bound_result[name] = next(iter(bound_set))

# TODO: the proper way is to produce a new result and invalidate the old one
binding_result.iter_bound_results[target] = legal_bound_result
76 changes: 53 additions & 23 deletions src/qat/purr/backends/qblox/codegen.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
from collections import defaultdict
from contextlib import ExitStack, contextmanager
from dataclasses import dataclass, field
from itertools import chain
from typing import Dict, List

import numpy as np

from qat.backend.analysis_passes import BindingResult, CFGResult, IterBound, TriageResult
from qat.backend.analysis_passes import (
BindingResult,
CFGPass,
CFGResult,
IterBound,
ReadWriteResult,
TriageResult,
)
from qat.backend.codegen_base import DfsTraversal
from qat.backend.graph import ControlFlowGraph
from qat.ir.pass_base import AnalysisPass, InvokerMixin, PassManager
Expand Down Expand Up @@ -632,9 +640,15 @@ def gen_label(self, name: str):


class NewQbloxContext:
def __init__(self, alloc_mgr: AllocationManager, iter_bounds: Dict[str, IterBound]):
self.alloc_mgr = alloc_mgr
def __init__(
self,
rw_result: ReadWriteResult,
iter_bounds: Dict[str, IterBound],
alloc_mgr: AllocationManager,
):
self.rw_result = rw_result
self.iter_bounds = iter_bounds
self.alloc_mgr = alloc_mgr

self.sequence_builder = SequenceBuilder()
self.sequencer_config = SequencerConfig()
Expand Down Expand Up @@ -1020,14 +1034,18 @@ def exit_repeat(inst: Repeat, contexts: Dict):

context._wait_seconds(inst.repetition_period)
context.sequence_builder.add(register, bound.step, register)
context.sequence_builder.nop()
context.sequence_builder.jlt(
register, bound.end + bound.step, label
) # extra step for jlt

@staticmethod
def enter_sweep(inst: Sweep, contexts: Dict):
for name in inst.variables.keys():
for context in contexts.values():
for context in contexts.values():
names = [n for n in inst.variables if n in context.rw_result.writes]

# TODO - multiple variable definition in a single target
for name in names:
register = context.alloc_mgr.registers[name]
label = context.alloc_mgr.labels[name]
bound = context.iter_bounds[name]
Expand All @@ -1037,8 +1055,11 @@ def enter_sweep(inst: Sweep, contexts: Dict):

@staticmethod
def exit_sweep(inst: Sweep, contexts: Dict):
for name in inst.variables.keys():
for context in contexts.values():
for context in contexts.values():
names = [n for n in inst.variables if n in context.rw_result.writes]

# TODO - multiple variable definition in a single target
for name in names:
register = context.alloc_mgr.registers[name]
label = context.alloc_mgr.labels[name]
bound = context.iter_bounds[name]
Expand Down Expand Up @@ -1082,45 +1103,51 @@ def run(self, builder: InstructionBuilder, res_mgr: ResultManager, *args, **kwar
binding_result: BindingResult = res_mgr.lookup_by_type(BindingResult)
result = PreCodegenResult()

for name, instructions in binding_result.writes.items():
for inst in instructions:
targets = (
inst.quantum_targets
if isinstance(inst, QuantumInstruction)
else triage_result.target_map.keys()
)
for target in targets:
result.alloc_mgrs[target].reg_alloc(name)
result.alloc_mgrs[target].gen_label(name)
for target in triage_result.target_map:
alloc_mgr = result.alloc_mgrs[target]
iter_bound_result = binding_result.iter_bound_results[target]
reads = binding_result.rw_results[target].reads
writes = binding_result.rw_results[target].writes

names = set(chain(*[iter_bound_result.keys(), reads.keys(), writes.keys()]))
for name in names:
alloc_mgr.reg_alloc(name)
alloc_mgr.gen_label(name)

res_mgr.add(result)


class NewQbloxEmitter(InvokerMixin):
def build_pass_pipeline(self, *args, **kwargs):
return PassManager() | PreCodegenPass()
return PassManager() | PreCodegenPass() | CFGPass()

def emit_packages(
self, builder: InstructionBuilder, res_mgr: ResultManager, *args, **kwargs
) -> List[QbloxPackage]:
self.run_pass_pipeline(builder, res_mgr, *args, **kwargs)

cfg_result: CFGResult = res_mgr.lookup_by_type(CFGResult)
triage_result: TriageResult = res_mgr.lookup_by_type(TriageResult)
binding_result: BindingResult = res_mgr.lookup_by_type(BindingResult)
precodegen_result: PreCodegenResult = res_mgr.lookup_by_type(PreCodegenResult)

rw_results: Dict[PulseChannel, ReadWriteResult] = binding_result.rw_results
iter_bounds: Dict[PulseChannel, Dict[str, IterBound]] = (
binding_result.iter_bound_results
)
alloc_mgrs: Dict[PulseChannel, AllocationManager] = precodegen_result.alloc_mgrs
iter_bounds: Dict[PulseChannel, Dict[str, IterBound]] = binding_result.iter_bounds

contexts = {
t: NewQbloxContext(alloc_mgr=alloc_mgrs[t], iter_bounds=iter_bounds[t])
t: NewQbloxContext(
rw_result=rw_results[t], iter_bounds=iter_bounds[t], alloc_mgr=alloc_mgrs[t]
)
for t in triage_result.target_map
}

cfg_result: CFGResult = res_mgr.lookup_by_type(CFGResult)
cfg_walker = QbloxCFGWalker(contexts)
cfg_walker.walk(cfg_result.cfg)

# Digard empty contexts
# Discard empty contexts
return [
context.create_package(target)
for target, context in contexts.items()
Expand Down Expand Up @@ -1166,7 +1193,10 @@ def enter(self, block):
for block in self._entered:
head = block.head()
if isinstance(head, Sweep):
name = next(iter(head.variables.keys()))
name = next(
(n for n in context.iter_bounds if n in head.variables),
f"sweep_{hash(head)}",
)
iter_bound = context.iter_bounds[name]
num_bins *= iter_bound.count
elif isinstance(head, Repeat):
Expand Down
2 changes: 0 additions & 2 deletions src/qat/purr/backends/qblox/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from qat.backend.analysis_passes import (
BindingPass,
CFGPass,
TILegalisationPass,
TriagePass,
TriageResult,
Expand Down Expand Up @@ -161,7 +160,6 @@ def build_pass_pipeline(self, *args, **kwargs):
| ReturnSanitisation()
| TriagePass()
| BindingPass()
| CFGPass()
| TILegalisationPass()
| QbloxLegalisationPass()
)
Expand Down
35 changes: 22 additions & 13 deletions tests/qat/backend/qblox/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,24 +62,33 @@ def test_qblox_legalisation_pass(self):
| TILegalisationPass()
).run(builder, res_mgr)

binding_result: BindingResult = res_mgr.lookup_by_type(BindingResult)
bounds = deepcopy(binding_result.iter_bounds)
triage_result: TriageResult = res_mgr.lookup_by_type(TriageResult)
binding_result: BindingResult = deepcopy(res_mgr.lookup_by_type(BindingResult))

QbloxLegalisationPass().run(builder, res_mgr)
legal_iter_bounds = binding_result.iter_bounds

assert set(legal_iter_bounds.keys()) == set(bounds.keys())
for target, symbol2bound in legal_iter_bounds.items():
for name in binding_result.symbol2scopes:
assert set(symbol2bound.keys()) == set(bounds[target].keys())
bound = bounds[target][name]
legal_bound = symbol2bound[name]
if name in binding_result.reads:

legal_binding_result: BindingResult = res_mgr.lookup_by_type(BindingResult)

for target, instructions in triage_result.target_map.items():
scoping_result = binding_result.scoping_results[target]
rw_result = binding_result.rw_results[target]

bounds = binding_result.iter_bound_results[target]
legal_bounds = legal_binding_result.iter_bound_results[target]

assert set(legal_bounds.keys()) == set(bounds.keys())

for name in scoping_result.symbol2scopes:
bound = bounds[name]
legal_bound = legal_bounds[name]
if name in rw_result.reads:
device_updates = [
inst
for inst in binding_result.reads[name]
if isinstance(inst, DeviceUpdate) and inst.target == target
for inst in rw_result.reads[name]
if isinstance(inst, DeviceUpdate)
]
for du in device_updates:
assert du.target == target
if du.attribute == "frequency":
assert legal_bound != bound
assert legal_bound == IterBound(
Expand Down
Loading

0 comments on commit e9d2b16

Please sign in to comment.