Skip to content
47 changes: 47 additions & 0 deletions dbos/_dbos.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ def __init__(
self.conductor_key: Optional[str] = conductor_key
if config.get("conductor_key"):
self.conductor_key = config.get("conductor_key")
self.enable_patching = config.get("enable_patching") == True
self.conductor_websocket: Optional[ConductorWebsocket] = None
self._background_event_loop: BackgroundEventLoop = BackgroundEventLoop()
self._active_workflows_set: set[str] = set()
Expand All @@ -350,6 +351,8 @@ def __init__(
# Globally set the application version and executor ID.
# In DBOS Cloud, instead use the values supplied through environment variables.
if not os.environ.get("DBOS__CLOUD") == "true":
if self.enable_patching:
GlobalParams.app_version = "PATCHING_ENABLED"
if (
"application_version" in config
and config["application_version"] is not None
Expand Down Expand Up @@ -1524,6 +1527,50 @@ async def read_stream_async(
await asyncio.sleep(1.0)
continue

@classmethod
def patch(cls, patch_name: str) -> bool:
if not _get_dbos_instance().enable_patching:
raise DBOSException("enable_patching must be True in DBOS configuration")
ctx = get_local_dbos_context()
if ctx is None or not ctx.is_workflow():
raise DBOSException("DBOS.patch must be called from a workflow")
workflow_id = ctx.workflow_id
function_id = ctx.function_id
patch_name = f"DBOS.patch-{patch_name}"
patched = _get_dbos_instance()._sys_db.patch(
workflow_id=workflow_id, function_id=function_id + 1, patch_name=patch_name
)
# If the patch was applied, increment function ID
if patched:
ctx.function_id += 1
return patched

@classmethod
def patch_async(cls, patch_name: str) -> Coroutine[Any, Any, bool]:
return asyncio.to_thread(cls.patch, patch_name)

@classmethod
def deprecate_patch(cls, patch_name: str) -> bool:
if not _get_dbos_instance().enable_patching:
raise DBOSException("enable_patching must be True in DBOS configuration")
ctx = get_local_dbos_context()
if ctx is None or not ctx.is_workflow():
raise DBOSException("DBOS.deprecate_patch must be called from a workflow")
workflow_id = ctx.workflow_id
function_id = ctx.function_id
patch_name = f"DBOS.patch-{patch_name}"
patch_exists = _get_dbos_instance()._sys_db.deprecate_patch(
workflow_id=workflow_id, function_id=function_id + 1, patch_name=patch_name
)
# If the patch is already in history, increment function ID
if patch_exists:
ctx.function_id += 1
return True

@classmethod
def deprecate_patch_async(cls, patch_name: str) -> Coroutine[Any, Any, bool]:
return asyncio.to_thread(cls.deprecate_patch, patch_name)

@classproperty
def tracer(self) -> DBOSTracer:
"""Return the DBOS OpenTelemetry tracer."""
Expand Down
1 change: 1 addition & 0 deletions dbos/_dbos_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class DBOSConfig(TypedDict, total=False):
conductor_key: Optional[str]
conductor_url: Optional[str]
serializer: Optional[Serializer]
enable_patching: Optional[bool]


class RuntimeConfig(TypedDict, total=False):
Expand Down
16 changes: 12 additions & 4 deletions dbos/_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def __init__(self, msg: str):
self.status_code = 403

def __reduce__(self) -> Any:
# Tell jsonpickle how to reconstruct this object
# Tell pickle how to reconstruct this object
return (self.__class__, (self.msg,))


Expand All @@ -162,7 +162,7 @@ def __init__(
)

def __reduce__(self) -> Any:
# Tell jsonpickle how to reconstruct this object
# Tell pickle how to reconstruct this object
return (self.__class__, (self.step_name, self.max_retries, self.errors))


Expand All @@ -182,11 +182,19 @@ class DBOSUnexpectedStepError(DBOSException):
def __init__(
self, workflow_id: str, step_id: int, expected_name: str, recorded_name: str
) -> None:
self.inputs = (workflow_id, step_id, expected_name, recorded_name)
super().__init__(
f"During execution of workflow {workflow_id} step {step_id}, function {recorded_name} was recorded when {expected_name} was expected. Check that your workflow is deterministic.",
dbos_error_code=DBOSErrorCode.UnexpectedStep.value,
)

def __reduce__(self) -> Any:
# Tell pickle how to reconstruct this object
return (
self.__class__,
self.inputs,
)


class DBOSQueueDeduplicatedError(DBOSException):
"""Exception raised when a workflow is deduplicated in the queue."""
Expand All @@ -203,7 +211,7 @@ def __init__(
)

def __reduce__(self) -> Any:
# Tell jsonpickle how to reconstruct this object
# Tell pickle how to reconstruct this object
return (
self.__class__,
(self.workflow_id, self.queue_name, self.deduplication_id),
Expand All @@ -219,7 +227,7 @@ def __init__(self, workflow_id: str):
)

def __reduce__(self) -> Any:
# Tell jsonpickle how to reconstruct this object
# Tell pickle how to reconstruct this object
return (self.__class__, (self.workflow_id,))


Expand Down
40 changes: 40 additions & 0 deletions dbos/_sys_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2361,3 +2361,43 @@ def get_metrics(self, start_time: str, end_time: str) -> List[MetricData]:
)

return metrics

@db_retry()
def patch(self, *, workflow_id: str, function_id: int, patch_name: str) -> bool:
"""If there is no checkpoint for this point in history,
insert a patch marker and return True.
Otherwise, return whether the checkpoint is this patch marker."""
with self.engine.begin() as c:
checkpoint_name: str | None = c.execute(
sa.select(SystemSchema.operation_outputs.c.function_name).where(
(SystemSchema.operation_outputs.c.workflow_uuid == workflow_id)
& (SystemSchema.operation_outputs.c.function_id == function_id)
)
).scalar()
if checkpoint_name is None:
result: OperationResultInternal = {
"workflow_uuid": workflow_id,
"function_id": function_id,
"function_name": patch_name,
"output": None,
"error": None,
"started_at_epoch_ms": int(time.time() * 1000),
}
self._record_operation_result_txn(result, c)
return True
else:
return checkpoint_name == patch_name

@db_retry()
def deprecate_patch(
self, *, workflow_id: str, function_id: int, patch_name: str
) -> bool:
"""Respect patch markers in history, but do not introduce new patch markers"""
with self.engine.begin() as c:
checkpoint_name: str | None = c.execute(
sa.select(SystemSchema.operation_outputs.c.function_name).where(
(SystemSchema.operation_outputs.c.workflow_uuid == workflow_id)
& (SystemSchema.operation_outputs.c.function_id == function_id)
)
).scalar()
return checkpoint_name == patch_name
2 changes: 1 addition & 1 deletion tests/test_dbos.py
Original file line number Diff line number Diff line change
Expand Up @@ -2078,7 +2078,7 @@ class JsonSerializer(Serializer):
def serialize(self, data: Any) -> str:
return json.dumps(data)

def deserialize(cls, serialized_data: str) -> Any:
def deserialize(self, serialized_data: str) -> Any:
return json.loads(serialized_data)

# Configure DBOS with a JSON-based custom serializer
Expand Down
Loading