Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unitary-Hack][Sub-ir-braket] Doc in submission/ir/braket.py #978

Merged
100 changes: 93 additions & 7 deletions src/bloqade/submission/ir/braket.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""Helper functions related to IR submission
co-ordinations between Bloqade and Braket"""

import braket.ir.ahs as braket_ir
from braket.ahs.pattern import Pattern
from braket.timings import TimeSeries
Expand All @@ -6,31 +9,49 @@
from braket.ahs.driving_field import DrivingField
from braket.ahs.shifting_field import ShiftingField
from braket.ahs.field import Field

weinbe58 marked this conversation as resolved.
Show resolved Hide resolved
from braket.task_result import AnalogHamiltonianSimulationTaskResult

import bloqade.submission.ir.capabilities as cp
weinbe58 marked this conversation as resolved.
Show resolved Hide resolved
from bloqade.submission.ir.task_results import (
QuEraTaskResults,
QuEraTaskStatusCode,
QuEraShotResult,
QuEraShotStatusCode,
)

from bloqade.submission.ir.task_specification import (
QuEraTaskSpecification,
GlobalField,
LocalField,
)

from typing import Tuple, Union, List
from pydantic.v1 import BaseModel
from decimal import Decimal


class BraketTaskSpecification(BaseModel):
"""Class representing geometry of an atom arrangement.

Attributes:
nshots (int): Number of shots
program (braket_ir.Program): IR(Intermediate Representation)
program suitable for braket
"""

nshots: int
program: braket_ir.Program


def to_braket_time_series(times: List[Decimal], values: List[Decimal]) -> TimeSeries:
weinbe58 marked this conversation as resolved.
Show resolved Hide resolved
"""Converts to `TimeSeries` object supported by Braket.

Args:
times (List[Decimal]): Times of the value.
values (List[Decimal]): Corresponding values to add to the time series

Returns:
An object of the type `braket.timings.TimeSeries`
"""
time_series = TimeSeries()
for time, value in zip(times, values):
time_series.put(time, value)
Expand All @@ -39,6 +60,18 @@ def to_braket_time_series(times: List[Decimal], values: List[Decimal]) -> TimeSe


def to_braket_field(quera_field: Union[GlobalField, LocalField]) -> Field:
"""Converts to `TimeSeries` object supported by Braket.

Args:
quera_field (Union[GlobalField, LocalField)]:
Field supported by Quera

Returns:
An object of the type `braket.ahs.field.Field`

Raises:
TypeError: If field is not of the type `GlobalField` or `LocalField`.
"""
if isinstance(quera_field, GlobalField):
times = quera_field.times
values = quera_field.values
Expand All @@ -56,6 +89,12 @@ def to_braket_field(quera_field: Union[GlobalField, LocalField]) -> Field:


def extract_braket_program(quera_task_ir: QuEraTaskSpecification):
"""Extracts the Braket program.

Args:
quera_task_ir (QuEraTaskSpecification):
Quera IR(Intermediate representation) of the task.
"""
lattice = quera_task_ir.lattice

rabi_amplitude = (
Expand Down Expand Up @@ -90,18 +129,47 @@ def extract_braket_program(quera_task_ir: QuEraTaskSpecification):
def to_braket_task(
quera_task_ir: QuEraTaskSpecification,
) -> Tuple[int, AnalogHamiltonianSimulation]:
"""Converts to `Tuple[int, AnalogHamiltonianSimulation]` object supported by Braket.

Args:
quera_task_ir (QuEraTaskSpecification):
Quera IR(Intermediate representation) of the task.

Returns:
An tuple of the type `Tuple[int, AnalogHamiltonianSimulation]`.
"""
braket_ahs_program = extract_braket_program(quera_task_ir)
return quera_task_ir.nshots, braket_ahs_program


def to_braket_task_ir(quera_task_ir: QuEraTaskSpecification) -> BraketTaskSpecification:
"""Converts quera IR(Intermendiate Representation) to
to `BraketTaskSpecification` object.

Args:
quera_task_ir (QuEraTaskSpecification):
Quera IR(Intermediate representation) of the task.

Returns:
An object of the type `BraketTaskSpecification` in Braket SDK

"""
nshots, braket_ahs_program = to_braket_task(quera_task_ir)
return BraketTaskSpecification(nshots=nshots, program=braket_ahs_program.to_ir())


def from_braket_task_results(
braket_task_results: AnalogHamiltonianSimulationTaskResult,
) -> QuEraTaskResults:
"""Get the `QuEraTaskResults` object for working with Bloqade SDK.

Args:
braket_task_results: AnalogHamiltonianSimulationTaskResult
Quantum task result of braket system

Returns:
An object of the type `Field` in Braket SDK.
"""
shot_outputs = []
for measurement in braket_task_results.measurements:
shot_outputs.append(
Expand All @@ -117,16 +185,34 @@ def from_braket_task_results(
)


def from_braket_status_codes(braket_message: str) -> QuEraTaskStatusCode:
if braket_message == str("QUEUED"):
def from_braket_status_codes(braket_status: str) -> QuEraTaskStatusCode:
weinbe58 marked this conversation as resolved.
Show resolved Hide resolved
"""Gets the `QuEraTaskStatusCode` object for working with Bloqade SDK.

Args:
braket_status: str
The value of status in metadata() in the Amazon Braket.
`GetQuantumTask` operation. If use_cached_value is True,
Manvi-Agrawal marked this conversation as resolved.
Show resolved Hide resolved
the value most recently returned from
`GetQuantumTask` operation is used

Returns:
An object of the type `Field` in Braket SDK
"""
if braket_status == str("QUEUED"):
return QuEraTaskStatusCode.Enqueued
else:
return QuEraTaskStatusCode(braket_message.lower().capitalize())
return QuEraTaskStatusCode(braket_status.lower().capitalize())


def to_quera_capabilities(paradigm) -> cp.QuEraCapabilities:
"""Converts to `QuEraCapabilities` object supported by Braket.

def to_quera_capabilities(paradigm):
import bloqade.submission.ir.capabilities as cp
Args:
paradigm: Bracket paradigm
Manvi-Agrawal marked this conversation as resolved.
Show resolved Hide resolved

Returns:
An object of the type `QuEraCapabilities` in Bloqade SDK.
"""
rydberg_global = paradigm.rydberg.rydbergGlobal

return cp.QuEraCapabilities(
Expand Down