From 94d00ddb79a369605fc9240e658acf9f13b53e2c Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Fri, 1 Sep 2023 11:42:02 -0400 Subject: [PATCH] Remove record node with Assignment of variables. (#499) * in assign, if record variable is in the assign, remove record node in AST * adding specific visitor for waveforms. * fixing bugs adding unit test. * fixing test * removing assignment scan inside of schema code gen. --- .../codegen/common/assign_variables.py | 5 +- src/bloqade/codegen/common/assignment_scan.py | 95 +++++++++++++------ src/bloqade/codegen/hardware/quera.py | 4 - src/bloqade/ir/routine/braket.py | 16 ++-- src/bloqade/ir/routine/quera.py | 10 +- src/bloqade/ir/scalar.py | 2 +- tests/test_assign.py | 56 ++++++++++- tests/test_scalar.py | 2 +- 8 files changed, 139 insertions(+), 51 deletions(-) diff --git a/src/bloqade/codegen/common/assign_variables.py b/src/bloqade/codegen/common/assign_variables.py index 9c73b996b..15618682e 100644 --- a/src/bloqade/codegen/common/assign_variables.py +++ b/src/bloqade/codegen/common/assign_variables.py @@ -114,7 +114,10 @@ def visit_negative(self, ast: waveform.Negative) -> Any: return waveform.Negative(self.visit(ast.waveform)) def visit_record(self, ast: waveform.Record) -> Any: - return waveform.Record(self.visit(ast.waveform), ast.var) + if ast.var.name in self.mapping: + return self.visit(ast.waveform) + else: + return waveform.Record(self.visit(ast.waveform), ast.var) def visit_sample(self, ast: waveform.Sample) -> Any: dt = self.scalar_visitor.emit(ast.dt) diff --git a/src/bloqade/codegen/common/assignment_scan.py b/src/bloqade/codegen/common/assignment_scan.py index a91808f21..891a81079 100644 --- a/src/bloqade/codegen/common/assignment_scan.py +++ b/src/bloqade/codegen/common/assignment_scan.py @@ -1,19 +1,81 @@ -from bloqade.ir.analog_circuit import AnalogCircuit +from bloqade.ir.control.waveform import ( + AlignedWaveform, + Constant, + Linear, + Poly, + PythonFn, + Sample, +) from bloqade.ir.visitor.analog_circuit import AnalogCircuitVisitor +from bloqade.ir.visitor.waveform import WaveformVisitor +from bloqade.ir.analog_circuit import AnalogCircuit import bloqade.ir.control.sequence as sequence import bloqade.ir.control.pulse as pulse import bloqade.ir.control.field as field import bloqade.ir.control.waveform as waveform -import bloqade.ir.scalar as scalar import bloqade.ir.analog_circuit as analog_circuit import numbers from typing import Any, Dict +class AssignmentScanRecord(WaveformVisitor): + def __init__(self, assignments: Dict[str, numbers.Real] = {}): + self.assignments = dict(assignments) + + def visit_record(self, ast: waveform.Record): + duration = ast.waveform.duration(**self.assignments) + var = ast.var + value = ast.waveform(duration, **self.assignments) + self.assignments[var.name] = value + self.visit(ast.waveform) + + def visit_append(self, ast: waveform.Append): + list(map(self.visit, ast.waveforms)) + + def visit_slice(self, ast: waveform.Slice): + self.visit(ast.waveform) + + def visit_add(self, ast: waveform.Add): + self.visit(ast.left) + self.visit(ast.right) + + def visit_negative(self, ast: waveform.Negative): + self.visit(ast.waveform) + + def visit_scale(self, ast: waveform.Scale): + self.visit(ast.waveform) + + def visit_smooth(self, ast: waveform.Smooth): + self.visit(ast.waveform) + + def visit_sample(self, ast: Sample) -> Any: + return self.visit(ast.waveform) + + def visit_alligned(self, ast: AlignedWaveform) -> Any: + return super().visit_alligned(ast) + + def visit_constant(self, ast: Constant) -> Any: + pass + + def visit_linear(self, ast: Linear) -> Any: + pass + + def visit_poly(self, ast: Poly) -> Any: + pass + + def visit_python_fn(self, ast: PythonFn) -> Any: + pass + + def emit(self, ast: waveform.Waveform) -> Dict[str, numbers.Real]: + self.visit(ast) + return self.assignments + + class AssignmentScan(AnalogCircuitVisitor): def __init__(self, assignments: Dict[str, numbers.Real] = {}): self.assignments = dict(assignments) + self.waveform_visitor = AssignmentScanRecord(self.assignments) def visit_analog_circuit(self, ast: AnalogCircuit) -> Any: self.visit(ast.sequence) @@ -50,34 +112,7 @@ def visit_spatial_modulation(self, ast: field.SpatialModulation): pass def visit_waveform(self, ast: waveform.Waveform): - match ast: - case waveform.Record(sub_waveform, scalar.Variable(name)): - duration = sub_waveform.duration(**self.assignments) - value = sub_waveform.eval_decimal(duration, **self.assignments) - self.assignments[name] = value - self.visit(sub_waveform) - - case waveform.Append(waveforms): - list(map(self.visit, waveforms)) - - case waveform.Slice(sub_waveform, _): - self.visit(sub_waveform) - - case waveform.Add(lhs, rhs): - self.visit(lhs) - self.visit(rhs) - - case waveform.Negative(sub_waveform): - self.visit(sub_waveform) - - case waveform.Scale(_, sub_waveform): - self.visit(sub_waveform) - - case waveform.AlignedWaveform(waveform=sub_waveform): - self.visit(sub_waveform) - - case waveform.Smooth(_, sub_waveform): - self.visit(sub_waveform) + self.assignments.update(self.waveform_visitor.emit(ast)) def emit(self, ast: analog_circuit.AnalogCircuit) -> Dict[str, numbers.Real]: self.visit(ast) diff --git a/src/bloqade/codegen/hardware/quera.py b/src/bloqade/codegen/hardware/quera.py index 23f45e546..ce49ff491 100644 --- a/src/bloqade/codegen/hardware/quera.py +++ b/src/bloqade/codegen/hardware/quera.py @@ -30,7 +30,6 @@ from bloqade.ir.visitor.analog_circuit import AnalogCircuitVisitor from bloqade.ir.visitor.waveform import WaveformVisitor -from bloqade.codegen.common.assignment_scan import AssignmentScan import bloqade.submission.ir.task_specification as task_spec from bloqade.submission.ir.parallel import ParallelDecoder, ClusterLocationInfo @@ -727,9 +726,6 @@ def visit_analog_circuit(self, ast: AnalogCircuit) -> Any: def emit( self, nshots: int, analog_circuit: AnalogCircuit ) -> Tuple[task_spec.QuEraTaskSpecification, Optional[ParallelDecoder]]: - self.assignments = AssignmentScan(self.assignments).emit( - analog_circuit.sequence - ) self.visit(analog_circuit) task_ir = task_spec.QuEraTaskSpecification( diff --git a/src/bloqade/ir/routine/braket.py b/src/bloqade/ir/routine/braket.py index d4be3f06b..a7db3674b 100644 --- a/src/bloqade/ir/routine/braket.py +++ b/src/bloqade/ir/routine/braket.py @@ -59,11 +59,11 @@ def compile( tasks = OrderedDict() for task_number, batch_params in enumerate(params.batch_assignments(*args)): - final_circuit = AssignAnalogCircuit(batch_params).visit(circuit) - record_params = AssignmentScan().emit(final_circuit) - task_ir, parallel_decoder = QuEraCodeGen( - record_params, capabilities=capabilities - ).emit(shots, final_circuit) + record_params = AssignmentScan(batch_params).emit(circuit) + final_circuit = AssignAnalogCircuit(record_params).visit(circuit) + task_ir, parallel_decoder = QuEraCodeGen(capabilities=capabilities).emit( + shots, final_circuit + ) task_ir = task_ir.discretize(capabilities) tasks[task_number] = BraketTask( @@ -207,9 +207,9 @@ def compile( tasks = OrderedDict() for task_number, batch_params in enumerate(params.batch_assignments(*args)): - final_circuit = AssignAnalogCircuit(batch_params).visit(circuit) - record_params = AssignmentScan().emit(final_circuit) - quera_task_ir, _ = QuEraCodeGen(record_params).emit(shots, final_circuit) + record_params = AssignmentScan(batch_params).emit(circuit) + final_circuit = AssignAnalogCircuit(record_params).visit(circuit) + quera_task_ir, _ = QuEraCodeGen().emit(shots, final_circuit) task_ir = to_braket_task_ir(quera_task_ir) diff --git a/src/bloqade/ir/routine/quera.py b/src/bloqade/ir/routine/quera.py index e9f01931f..2ffc2da94 100644 --- a/src/bloqade/ir/routine/quera.py +++ b/src/bloqade/ir/routine/quera.py @@ -72,11 +72,11 @@ def compile( tasks = OrderedDict() for task_number, batch_params in enumerate(params.batch_assignments(*args)): - final_circuit = AssignAnalogCircuit(batch_params).visit(circuit) - record_params = AssignmentScan().emit(final_circuit) - task_ir, parallel_decoder = QuEraCodeGen( - record_params, capabilities=capabilities - ).emit(shots, final_circuit) + record_params = AssignmentScan(batch_params).emit(circuit) + final_circuit = AssignAnalogCircuit(record_params).visit(circuit) + task_ir, parallel_decoder = QuEraCodeGen(capabilities=capabilities).emit( + shots, final_circuit + ) task_ir = task_ir.discretize(capabilities) tasks[task_number] = QuEraTask( diff --git a/src/bloqade/ir/scalar.py b/src/bloqade/ir/scalar.py index 03441d719..7960e7426 100644 --- a/src/bloqade/ir/scalar.py +++ b/src/bloqade/ir/scalar.py @@ -406,7 +406,7 @@ def children(self): return [] def print_node(self): - return f"DefaultVariable: {self.name} = {self.value}" + return f"AssignedVariable: {self.name} = {self.value}" @validator("name") def name_validator(cls, v): diff --git a/tests/test_assign.py b/tests/test_assign.py index 5b0ee79df..a6faed00b 100644 --- a/tests/test_assign.py +++ b/tests/test_assign.py @@ -1,4 +1,4 @@ -from bloqade import piecewise_linear +from bloqade import piecewise_linear, start, var, cast from bloqade.atom_arrangement import Chain from bloqade.ir import ( rydberg, @@ -8,8 +8,12 @@ Pulse, Field, AssignedRunTimeVector, + Uniform, ) +import bloqade.ir.control.waveform as waveform +import bloqade.ir.scalar as scalar from bloqade.codegen.common.assign_variables import AssignAnalogCircuit +from bloqade.codegen.common.assignment_scan import AssignmentScan from decimal import Decimal import pytest @@ -59,3 +63,53 @@ def test_assignment_error(): circuit = AssignAnalogCircuit(dict(amp=amp)).visit(circuit) with pytest.raises(ValueError): circuit = AssignAnalogCircuit(dict(amp=amp)).visit(circuit) + + +def test_scan(): + t = var("t") + circuit = ( + start.rydberg.detuning.uniform.constant("max", 1.0) + .slice(0, t) + .record("detuning") + .linear("detuning", 0, 1.0 - t) + .parse_sequence() + ) + + params = dict(max=10, t=0.1) + + completed_params = AssignmentScan(params).emit(circuit) + completed_circuit = AssignAnalogCircuit(completed_params).visit(circuit) + + t_assigned = scalar.AssignedVariable("t", 0.1) + max_assigned = scalar.AssignedVariable("max", 10) + detuning_assigned = scalar.AssignedVariable("detuning", 10) + dur_assigned = 1 - t_assigned + + interval = waveform.Interval(cast(0), t_assigned) + + target_circuit = Sequence( + { + rydberg: Pulse( + { + detuning: Field( + value={ + Uniform: waveform.Append( + [ + waveform.Slice( + waveform.Constant(max_assigned, cast(1.0)), + interval, + ), + waveform.Linear(detuning_assigned, 0, dur_assigned), + ] + ) + } + ) + } + ) + } + ) + + print(repr(completed_circuit)) + print(repr(target_circuit)) + + assert completed_circuit == target_circuit diff --git a/tests/test_scalar.py b/tests/test_scalar.py index 0bfe459ca..dd434ae19 100644 --- a/tests/test_scalar.py +++ b/tests/test_scalar.py @@ -407,7 +407,7 @@ def test_assigned_variable_methods(): assigned_var = scalar.AssignedVariable("a", Decimal("1.0")) assert assigned_var.children() == [] - assert assigned_var.print_node() == "DefaultVariable: a = 1.0" + assert assigned_var.print_node() == "AssignedVariable: a = 1.0" assert str(assigned_var) == "a"