Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unitary hack] Improve Docstrings task/quera.py #971

Closed
wants to merge 5 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 114 additions & 15 deletions src/bloqade/task/quera.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,42 @@
"""
This module contains the definition and serialization logic for the QuEraTask class,
which represents a task to be run on a quantum computing backend. The QuEraTask class
handles task submission, validation, status checking, and result fetching. The module
also includes serialization and deserialization functions for the QuEraTask class,
enabling tasks to be easily saved and loaded.
"""

import warnings
from dataclasses import dataclass, field
from bloqade.serialize import Serializer
from bloqade.submission.mock import MockBackend
from bloqade.task.base import Geometry
from bloqade.task.base import RemoteTask
from bloqade.task.base import Geometry, RemoteTask

from bloqade.submission.base import ValidationError
from bloqade.submission.ir.task_results import QuEraTaskResults, QuEraTaskStatusCode
from bloqade.submission.ir.task_specification import QuEraTaskSpecification
from bloqade.submission.ir.parallel import ParallelDecoder
from bloqade.submission.quera import QuEraBackend
from bloqade.builder.base import ParamType

from beartype.typing import Dict, Optional, Union, Any
from bloqade.builder.base import ParamType
from dataclasses import dataclass, field
import warnings


@dataclass
@Serializer.register
class QuEraTask(RemoteTask):
"""
Represents a task to be run on a quantum computing backend.

Attributes:
task_id (Optional[str]): The ID of the task.
backend (Union[QuEraBackend, MockBackend]): The backend where the task is executed.
task_ir (QuEraTaskSpecification): The task specification.
metadata (Dict[str, ParamType]): Metadata associated with the task.
parallel_decoder (Optional[ParallelDecoder]): Parallel decoder associated with the task.
task_result_ir (QuEraTaskResults): The task result status.
"""

task_id: Optional[str]
backend: Union[QuEraBackend, MockBackend]
task_ir: QuEraTaskSpecification
Expand All @@ -30,19 +49,34 @@ class QuEraTask(RemoteTask):
)

def submit(self, force: bool = False) -> "QuEraTask":
"""
Submits the task to the backend.

Args:
force (bool): If True, force submission even if the task is already submitted.

Returns:
QuEraTask: The submitted task.

Raises:
ValueError: If the task is already submitted and force is False.
"""
if not force:
if self.task_id is not None:
raise ValueError(
"the task is already submitted with %s" % (self.task_id)
)
raise ValueError(f"The task is already submitted with {self.task_id}")

self.task_id = self.backend.submit_task(self.task_ir)

self.task_result_ir = QuEraTaskResults(task_status=QuEraTaskStatusCode.Enqueued)

return self

def validate(self) -> str:
"""
Validates the task specification against the backend.

Returns:
str: An empty string if validation is successful,otherwise the validation error message.
"""
try:
self.backend.validate_task(self.task_ir)
except ValidationError as e:
Expand All @@ -51,7 +85,15 @@ def validate(self) -> str:
return ""

def fetch(self) -> "QuEraTask":
# non-blocking, pull only when its completed
"""
Fetches the task results if the task is completed.

Returns:
QuEraTask: The task with updated results.

Raises:
ValueError: If the task status is unsubmitted.
"""
if self.task_result_ir.task_status is QuEraTaskStatusCode.Unsubmitted:
raise ValueError("Task ID not found.")

Expand All @@ -73,7 +115,15 @@ def fetch(self) -> "QuEraTask":
return self

def pull(self) -> "QuEraTask":
# blocking, force pulling, even its completed
"""
Forcefully pulls the task results from the backend.
shubhusion marked this conversation as resolved.
Show resolved Hide resolved

Returns:
QuEraTask: The task with updated results.

Raises:
ValueError: If the task ID is not found.
"""
if self.task_id is None:
raise ValueError("Task ID not found.")

Expand All @@ -82,8 +132,12 @@ def pull(self) -> "QuEraTask":
return self

def result(self) -> QuEraTaskResults:
# blocking, caching
"""
Retrieves the task results, blocking if necessary.

Returns:
QuEraTaskResults: The results of the task.
"""
if self.task_result_ir is None:
pass
else:
Expand All @@ -96,12 +150,24 @@ def result(self) -> QuEraTaskResults:
return self.task_result_ir

def status(self) -> QuEraTaskStatusCode:
"""
Gets the current status of the task.

Returns:
QuEraTaskStatusCode: The status of the task.
"""
if self.task_id is None:
return QuEraTaskStatusCode.Unsubmitted

return self.backend.task_status(self.task_id)

def cancel(self) -> None:
"""
Cancels the task if it is already submitted.

Raises:
UserWarning: If the task ID is not found.
"""
if self.task_id is None:
warnings.warn("Cannot cancel task, missing task id.")
return
Expand All @@ -110,16 +176,34 @@ def cancel(self) -> None:

@property
def nshots(self):
"""
Returns the number of shots for the task.

Returns:
int: The number of shots.
"""
return self.task_ir.nshots

def _geometry(self) -> Geometry:
"""
Retrieves the geometry of the task.

Returns:
Geometry: The geometry of the task.
"""
return Geometry(
sites=self.task_ir.lattice.sites,
filling=self.task_ir.lattice.filling,
parallel_decoder=self.parallel_decoder,
)

def _result_exists(self) -> bool:
"""
Checks if the task result exists and is completed.

Returns:
bool: True if the task result exists and is completed, False otherwise.
"""
if self.task_result_ir is None:
return False
else:
Expand All @@ -131,12 +215,18 @@ def _result_exists(self) -> bool:
else:
return False

# def submit_no_task_id(self) -> "HardwareTaskShotResults":
# return HardwareTaskShotResults(hardware_task=self)


@QuEraTask.set_serializer
def _serialze(obj: QuEraTask) -> Dict[str, ParamType]:
"""
Serializes the QuEraTask object.

Args:
obj (QuEraTask): The QuEraTask object to serialize.

Returns:
Dict[str, ParamType]: The serialized QuEraTask object.
"""
shubhusion marked this conversation as resolved.
Show resolved Hide resolved
return {
"task_id": obj.task_id if obj.task_id is not None else None,
"task_ir": obj.task_ir.dict(by_alias=True, exclude_none=True),
Expand All @@ -155,6 +245,15 @@ def _serialze(obj: QuEraTask) -> Dict[str, ParamType]:

@QuEraTask.set_deserializer
def _deserializer(d: Dict[str, Any]) -> QuEraTask:
"""
Deserializes a dictionary into a QuEraTask object.

Args:
d (Dict[str, Any]): The dictionary to deserialize.

Returns:
QuEraTask: The deserialized QuEraTask object.
"""
shubhusion marked this conversation as resolved.
Show resolved Hide resolved
d["task_ir"] = QuEraTaskSpecification(**d["task_ir"])
d["task_result_ir"] = (
QuEraTaskResults(**d["task_result_ir"]) if d["task_result_ir"] else None
Expand Down
Loading