Skip to content

Commit

Permalink
Remove record node with Assignment of variables. (#499)
Browse files Browse the repository at this point in the history
* in assign, if record variable is in the assign, remove record node in AST

* adding specific visitor for waveforms.

* fixing bugs adding unit test.

* fixing test

* removing assignment scan inside of schema code gen.
  • Loading branch information
weinbe58 committed Sep 1, 2023
1 parent a68fd8c commit 94d00dd
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 51 deletions.
5 changes: 4 additions & 1 deletion src/bloqade/codegen/common/assign_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,10 @@ def visit_negative(self, ast: waveform.Negative) -> Any:
return waveform.Negative(self.visit(ast.waveform))

def visit_record(self, ast: waveform.Record) -> Any:
return waveform.Record(self.visit(ast.waveform), ast.var)
if ast.var.name in self.mapping:
return self.visit(ast.waveform)
else:
return waveform.Record(self.visit(ast.waveform), ast.var)

def visit_sample(self, ast: waveform.Sample) -> Any:
dt = self.scalar_visitor.emit(ast.dt)
Expand Down
95 changes: 65 additions & 30 deletions src/bloqade/codegen/common/assignment_scan.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,81 @@
from bloqade.ir.analog_circuit import AnalogCircuit
from bloqade.ir.control.waveform import (
AlignedWaveform,
Constant,
Linear,
Poly,
PythonFn,
Sample,
)
from bloqade.ir.visitor.analog_circuit import AnalogCircuitVisitor
from bloqade.ir.visitor.waveform import WaveformVisitor
from bloqade.ir.analog_circuit import AnalogCircuit
import bloqade.ir.control.sequence as sequence
import bloqade.ir.control.pulse as pulse
import bloqade.ir.control.field as field
import bloqade.ir.control.waveform as waveform
import bloqade.ir.scalar as scalar
import bloqade.ir.analog_circuit as analog_circuit

import numbers
from typing import Any, Dict


class AssignmentScanRecord(WaveformVisitor):
def __init__(self, assignments: Dict[str, numbers.Real] = {}):
self.assignments = dict(assignments)

def visit_record(self, ast: waveform.Record):
duration = ast.waveform.duration(**self.assignments)
var = ast.var
value = ast.waveform(duration, **self.assignments)
self.assignments[var.name] = value
self.visit(ast.waveform)

def visit_append(self, ast: waveform.Append):
list(map(self.visit, ast.waveforms))

def visit_slice(self, ast: waveform.Slice):
self.visit(ast.waveform)

def visit_add(self, ast: waveform.Add):
self.visit(ast.left)
self.visit(ast.right)

def visit_negative(self, ast: waveform.Negative):
self.visit(ast.waveform)

def visit_scale(self, ast: waveform.Scale):
self.visit(ast.waveform)

def visit_smooth(self, ast: waveform.Smooth):
self.visit(ast.waveform)

def visit_sample(self, ast: Sample) -> Any:
return self.visit(ast.waveform)

def visit_alligned(self, ast: AlignedWaveform) -> Any:
return super().visit_alligned(ast)

def visit_constant(self, ast: Constant) -> Any:
pass

def visit_linear(self, ast: Linear) -> Any:
pass

def visit_poly(self, ast: Poly) -> Any:
pass

def visit_python_fn(self, ast: PythonFn) -> Any:
pass

def emit(self, ast: waveform.Waveform) -> Dict[str, numbers.Real]:
self.visit(ast)
return self.assignments


class AssignmentScan(AnalogCircuitVisitor):
def __init__(self, assignments: Dict[str, numbers.Real] = {}):
self.assignments = dict(assignments)
self.waveform_visitor = AssignmentScanRecord(self.assignments)

def visit_analog_circuit(self, ast: AnalogCircuit) -> Any:
self.visit(ast.sequence)
Expand Down Expand Up @@ -50,34 +112,7 @@ def visit_spatial_modulation(self, ast: field.SpatialModulation):
pass

def visit_waveform(self, ast: waveform.Waveform):
match ast:
case waveform.Record(sub_waveform, scalar.Variable(name)):
duration = sub_waveform.duration(**self.assignments)
value = sub_waveform.eval_decimal(duration, **self.assignments)
self.assignments[name] = value
self.visit(sub_waveform)

case waveform.Append(waveforms):
list(map(self.visit, waveforms))

case waveform.Slice(sub_waveform, _):
self.visit(sub_waveform)

case waveform.Add(lhs, rhs):
self.visit(lhs)
self.visit(rhs)

case waveform.Negative(sub_waveform):
self.visit(sub_waveform)

case waveform.Scale(_, sub_waveform):
self.visit(sub_waveform)

case waveform.AlignedWaveform(waveform=sub_waveform):
self.visit(sub_waveform)

case waveform.Smooth(_, sub_waveform):
self.visit(sub_waveform)
self.assignments.update(self.waveform_visitor.emit(ast))

def emit(self, ast: analog_circuit.AnalogCircuit) -> Dict[str, numbers.Real]:
self.visit(ast)
Expand Down
4 changes: 0 additions & 4 deletions src/bloqade/codegen/hardware/quera.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@

from bloqade.ir.visitor.analog_circuit import AnalogCircuitVisitor
from bloqade.ir.visitor.waveform import WaveformVisitor
from bloqade.codegen.common.assignment_scan import AssignmentScan

import bloqade.submission.ir.task_specification as task_spec
from bloqade.submission.ir.parallel import ParallelDecoder, ClusterLocationInfo
Expand Down Expand Up @@ -727,9 +726,6 @@ def visit_analog_circuit(self, ast: AnalogCircuit) -> Any:
def emit(
self, nshots: int, analog_circuit: AnalogCircuit
) -> Tuple[task_spec.QuEraTaskSpecification, Optional[ParallelDecoder]]:
self.assignments = AssignmentScan(self.assignments).emit(
analog_circuit.sequence
)
self.visit(analog_circuit)

task_ir = task_spec.QuEraTaskSpecification(
Expand Down
16 changes: 8 additions & 8 deletions src/bloqade/ir/routine/braket.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ def compile(
tasks = OrderedDict()

for task_number, batch_params in enumerate(params.batch_assignments(*args)):
final_circuit = AssignAnalogCircuit(batch_params).visit(circuit)
record_params = AssignmentScan().emit(final_circuit)
task_ir, parallel_decoder = QuEraCodeGen(
record_params, capabilities=capabilities
).emit(shots, final_circuit)
record_params = AssignmentScan(batch_params).emit(circuit)
final_circuit = AssignAnalogCircuit(record_params).visit(circuit)
task_ir, parallel_decoder = QuEraCodeGen(capabilities=capabilities).emit(
shots, final_circuit
)

task_ir = task_ir.discretize(capabilities)
tasks[task_number] = BraketTask(
Expand Down Expand Up @@ -207,9 +207,9 @@ def compile(
tasks = OrderedDict()

for task_number, batch_params in enumerate(params.batch_assignments(*args)):
final_circuit = AssignAnalogCircuit(batch_params).visit(circuit)
record_params = AssignmentScan().emit(final_circuit)
quera_task_ir, _ = QuEraCodeGen(record_params).emit(shots, final_circuit)
record_params = AssignmentScan(batch_params).emit(circuit)
final_circuit = AssignAnalogCircuit(record_params).visit(circuit)
quera_task_ir, _ = QuEraCodeGen().emit(shots, final_circuit)

task_ir = to_braket_task_ir(quera_task_ir)

Expand Down
10 changes: 5 additions & 5 deletions src/bloqade/ir/routine/quera.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,11 @@ def compile(
tasks = OrderedDict()

for task_number, batch_params in enumerate(params.batch_assignments(*args)):
final_circuit = AssignAnalogCircuit(batch_params).visit(circuit)
record_params = AssignmentScan().emit(final_circuit)
task_ir, parallel_decoder = QuEraCodeGen(
record_params, capabilities=capabilities
).emit(shots, final_circuit)
record_params = AssignmentScan(batch_params).emit(circuit)
final_circuit = AssignAnalogCircuit(record_params).visit(circuit)
task_ir, parallel_decoder = QuEraCodeGen(capabilities=capabilities).emit(
shots, final_circuit
)

task_ir = task_ir.discretize(capabilities)
tasks[task_number] = QuEraTask(
Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/ir/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def children(self):
return []

def print_node(self):
return f"DefaultVariable: {self.name} = {self.value}"
return f"AssignedVariable: {self.name} = {self.value}"

@validator("name")
def name_validator(cls, v):
Expand Down
56 changes: 55 additions & 1 deletion tests/test_assign.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from bloqade import piecewise_linear
from bloqade import piecewise_linear, start, var, cast
from bloqade.atom_arrangement import Chain
from bloqade.ir import (
rydberg,
Expand All @@ -8,8 +8,12 @@
Pulse,
Field,
AssignedRunTimeVector,
Uniform,
)
import bloqade.ir.control.waveform as waveform
import bloqade.ir.scalar as scalar
from bloqade.codegen.common.assign_variables import AssignAnalogCircuit
from bloqade.codegen.common.assignment_scan import AssignmentScan
from decimal import Decimal
import pytest

Expand Down Expand Up @@ -59,3 +63,53 @@ def test_assignment_error():
circuit = AssignAnalogCircuit(dict(amp=amp)).visit(circuit)
with pytest.raises(ValueError):
circuit = AssignAnalogCircuit(dict(amp=amp)).visit(circuit)


def test_scan():
t = var("t")
circuit = (
start.rydberg.detuning.uniform.constant("max", 1.0)
.slice(0, t)
.record("detuning")
.linear("detuning", 0, 1.0 - t)
.parse_sequence()
)

params = dict(max=10, t=0.1)

completed_params = AssignmentScan(params).emit(circuit)
completed_circuit = AssignAnalogCircuit(completed_params).visit(circuit)

t_assigned = scalar.AssignedVariable("t", 0.1)
max_assigned = scalar.AssignedVariable("max", 10)
detuning_assigned = scalar.AssignedVariable("detuning", 10)
dur_assigned = 1 - t_assigned

interval = waveform.Interval(cast(0), t_assigned)

target_circuit = Sequence(
{
rydberg: Pulse(
{
detuning: Field(
value={
Uniform: waveform.Append(
[
waveform.Slice(
waveform.Constant(max_assigned, cast(1.0)),
interval,
),
waveform.Linear(detuning_assigned, 0, dur_assigned),
]
)
}
)
}
)
}
)

print(repr(completed_circuit))
print(repr(target_circuit))

assert completed_circuit == target_circuit
2 changes: 1 addition & 1 deletion tests/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def test_assigned_variable_methods():
assigned_var = scalar.AssignedVariable("a", Decimal("1.0"))

assert assigned_var.children() == []
assert assigned_var.print_node() == "DefaultVariable: a = 1.0"
assert assigned_var.print_node() == "AssignedVariable: a = 1.0"
assert str(assigned_var) == "a"


Expand Down

0 comments on commit 94d00dd

Please sign in to comment.