Skip to content

Commit

Permalink
Fixing Builder API with multiple Spatial modulations. (#360)
Browse files Browse the repository at this point in the history
* adding termation function for spatial modulation.

* changing from terminate field build from appending waveforms to add waveforms.

* removing print statement.

* update, add testing cases

---------

Co-authored-by: Kai-Hsin Wu <khwu@KHWus-MacBook-Pro.local>
  • Loading branch information
weinbe58 and Kai-Hsin Wu authored Aug 2, 2023
1 parent 89b00d3 commit ebc745e
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 30 deletions.
51 changes: 21 additions & 30 deletions src/bloqade/builder/emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class BuildState(BaseModel):

waveform_slice: Optional[Tuple[Optional[ir.Scalar], Optional[ir.Scalar]]] = None
waveform_record: Optional[str] = None
spatial_modulation: Optional[ir.SpatialModulation] = None
scaled_locations: ir.ScaledLocations = ir.ScaledLocations({})
field: ir.Field = ir.Field({})
detuning: ir.Field = ir.Field({})
Expand Down Expand Up @@ -221,6 +222,14 @@ def __terminate_waveform_append(build_state):

build_state.waveform_build = None

@staticmethod
def __terminate_spatial_modulation(build_state: BuildState):
Emit.__terminate_waveform_append(build_state)
build_state.field = build_state.field.add(
ir.Field(value={build_state.spatial_modulation: build_state.waveform})
)
build_state.waveform = None

@staticmethod
def __build_ast(builder: Builder, build_state: BuildState):
import bloqade.builder.waveform as waveform
Expand All @@ -230,7 +239,6 @@ def __build_ast(builder: Builder, build_state: BuildState):
import bloqade.builder.coupling as coupling
import bloqade.builder.start as start

# print(type(build_state.waveform))
match builder:
case (
waveform.Linear()
Expand All @@ -257,14 +265,9 @@ def __build_ast(builder: Builder, build_state: BuildState):
scale = builder._scale
loc = ir.Location(builder.__parent__._label)
build_state.scaled_locations.value[loc] = scale

new_field = ir.Field(
{build_state.scaled_locations: build_state.waveform}
)
build_state.field = build_state.field.add(new_field)

build_state.spatial_modulation = build_state.scaled_locations
build_state.scaled_locations = ir.ScaledLocations({})
build_state.waveform = None
Emit.__terminate_spatial_modulation(build_state)

Emit.__build_ast(builder.__parent__.__parent__, build_state)

Expand Down Expand Up @@ -311,15 +314,14 @@ def __build_ast(builder: Builder, build_state: BuildState):
Emit.__terminate_waveform_append(build_state)
scale = ir.cast(1.0)
loc = ir.Location(builder._label)
# update current list of scaled locations
build_state.scaled_locations.value[loc] = scale

new_field = ir.Field(
{build_state.scaled_locations: build_state.waveform}
)
build_state.field = build_state.field.add(new_field)

# copy scaled locations to the current spatial modulation
build_state.spatial_modulation = build_state.scaled_locations
# reset scaled locations
build_state.scaled_locations = ir.ScaledLocations({})
build_state.waveform = None
# terminate building of a field
Emit.__terminate_spatial_modulation(build_state)

Emit.__build_ast(builder.__parent__, build_state)

Expand All @@ -338,28 +340,17 @@ def __build_ast(builder: Builder, build_state: BuildState):
Emit.__build_ast(builder.__parent__, build_state)

case location.Uniform():
Emit.__terminate_waveform_append(build_state)
new_field = ir.Field({ir.Uniform: build_state.waveform})
build_state.field = build_state.field.add(new_field)

# reset build_state values
build_state.waveform = None
build_state.spatial_modulation = ir.Uniform
Emit.__terminate_spatial_modulation(build_state)
Emit.__build_ast(builder.__parent__, build_state)

case location.Var():
Emit.__terminate_waveform_append(build_state)
new_field = ir.Field(
{ir.RunTimeVector(builder._name): build_state.waveform}
)
build_state.field = build_state.field.add(new_field)

# reset build_state values
build_state.waveform = None
build_state.spatial_modulation = ir.RunTimeVector(builder._name)
Emit.__terminate_spatial_modulation(build_state)
Emit.__build_ast(builder.__parent__, build_state)

case field.Detuning():
build_state.detuning = build_state.detuning.add(build_state.field)

# reset build_state values
build_state.field = ir.Field({})
Emit.__build_ast(builder.__parent__, build_state)
Expand Down
29 changes: 29 additions & 0 deletions tests/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,35 @@ def test_issue_150():
)


def test_303_replicate_channel_should_add():
prog = (
start.rydberg.detuning.uniform.linear(0, 1, 1)
.rabi.amplitude.uniform.linear(1, 2, 1)
.detuning.uniform.linear(0, 2, 3)
)

assert prog.sequence == ir.Sequence(
{
ir.rydberg: ir.Pulse(
{
ir.rabi.amplitude: ir.Field({ir.Uniform: ir.Linear(1, 2, 1)}),
ir.detuning: ir.Field(
{ir.Uniform: ir.Linear(0, 2, 3) + ir.Linear(0, 1, 1)}
),
}
)
}
)

prog1 = (
start.rydberg.detuning.uniform.linear(0, 1, 1)
.rabi.amplitude.uniform.linear(1, 2, 1)
.rydberg.detuning.uniform.linear(0, 2, 3)
)

assert prog1.sequence == prog.sequence


def test_record():
prog = start
prog = (
Expand Down

0 comments on commit ebc745e

Please sign in to comment.