Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/packages/core/agent_framework/_workflows/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ async def _run_stream_impl(
"""
# Determine the event stream based on whether we have function responses
if bool(self.pending_requests):
# This is a continuation - use send_responses_streaming to send function responses back
# This is a continuation - use run_stream with responses to send function responses back
logger.info(f"Continuing workflow to address {len(self.pending_requests)} requests")

# Extract function responses from input messages, and ensure that
Expand All @@ -212,7 +212,7 @@ async def _run_stream_impl(
# NOTE: It is possible that some pending requests are not fulfilled,
# and we will let the workflow to handle this -- the agent does not
# have an opinion on this.
event_stream = self.workflow.send_responses_streaming(function_responses)
event_stream = self.workflow.run_stream(responses=function_responses)
else:
# Execute workflow with streaming (initial run or no function responses)
# Pass the new input messages directly to the workflow
Expand Down
38 changes: 0 additions & 38 deletions python/packages/core/agent_framework/_workflows/_magentic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2458,17 +2458,6 @@ async def _validate_checkpoint_participants(
f"Missing names: {missing}; unexpected names: {unexpected}."
)

async def run_stream_from_checkpoint(
self,
checkpoint_id: str,
checkpoint_storage: CheckpointStorage | None = None,
responses: dict[str, Any] | None = None,
) -> AsyncIterable[WorkflowEvent]:
"""Resume orchestration from a checkpoint and stream resulting events."""
await self._validate_checkpoint_participants(checkpoint_id, checkpoint_storage)
async for event in self._workflow.run_stream_from_checkpoint(checkpoint_id, checkpoint_storage, responses):
yield event

async def run_with_string(self, task_text: str) -> WorkflowRunResult:
"""Run the workflow with a task string and return all events.

Expand Down Expand Up @@ -2512,33 +2501,6 @@ async def run(self, message: Any | None = None) -> WorkflowRunResult:
events.append(event)
return WorkflowRunResult(events)

async def run_from_checkpoint(
self,
checkpoint_id: str,
checkpoint_storage: CheckpointStorage | None = None,
responses: dict[str, Any] | None = None,
) -> WorkflowRunResult:
"""Resume orchestration from a checkpoint and collect all resulting events."""
events: list[WorkflowEvent] = []
async for event in self.run_stream_from_checkpoint(checkpoint_id, checkpoint_storage, responses):
events.append(event)
return WorkflowRunResult(events)

async def send_responses_streaming(self, responses: dict[str, Any]) -> AsyncIterable[WorkflowEvent]:
"""Forward responses to pending requests and stream resulting events.

This delegates to the underlying Workflow implementation.
"""
async for event in self._workflow.send_responses_streaming(responses):
yield event

async def send_responses(self, responses: dict[str, Any]) -> WorkflowRunResult:
"""Forward responses to pending requests and return all resulting events.

This delegates to the underlying Workflow implementation.
"""
return await self._workflow.send_responses(responses)

def __getattr__(self, name: str) -> Any:
"""Delegate unknown attributes to the underlying workflow."""
return getattr(self._workflow, name)
Expand Down
41 changes: 36 additions & 5 deletions python/packages/core/agent_framework/_workflows/_runner_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,18 @@ def has_checkpointing(self) -> bool:
"""
...

def set_runtime_checkpoint_storage(self, storage: CheckpointStorage) -> None:
"""Set runtime checkpoint storage to override build-time configuration.

Args:
storage: The checkpoint storage to use for this run.
"""
...

def clear_runtime_checkpoint_storage(self) -> None:
"""Clear runtime checkpoint storage override."""
...

# Checkpointing APIs (optional, enabled by storage)
def set_workflow_id(self, workflow_id: str) -> None:
"""Set the workflow ID for the context."""
Expand Down Expand Up @@ -202,6 +214,7 @@ def __init__(self, checkpoint_storage: CheckpointStorage | None = None):

# Checkpointing configuration/state
self._checkpoint_storage = checkpoint_storage
self._runtime_checkpoint_storage: CheckpointStorage | None = None
self._workflow_id: str | None = None

# Streaming flag - set by workflow's run_stream() vs run()
Expand Down Expand Up @@ -252,16 +265,33 @@ async def next_event(self) -> WorkflowEvent:

# region Checkpointing

def _get_effective_checkpoint_storage(self) -> CheckpointStorage | None:
"""Get the effective checkpoint storage (runtime override or build-time)."""
return self._runtime_checkpoint_storage or self._checkpoint_storage

def set_runtime_checkpoint_storage(self, storage: CheckpointStorage) -> None:
"""Set runtime checkpoint storage to override build-time configuration.

Args:
storage: The checkpoint storage to use for this run.
"""
self._runtime_checkpoint_storage = storage

def clear_runtime_checkpoint_storage(self) -> None:
"""Clear runtime checkpoint storage override."""
self._runtime_checkpoint_storage = None

def has_checkpointing(self) -> bool:
return self._checkpoint_storage is not None
return self._get_effective_checkpoint_storage() is not None

async def create_checkpoint(
self,
shared_state: SharedState,
iteration_count: int,
metadata: dict[str, Any] | None = None,
) -> str:
if not self._checkpoint_storage:
storage = self._get_effective_checkpoint_storage()
if not storage:
raise ValueError("Checkpoint storage not configured")

self._workflow_id = self._workflow_id or str(uuid.uuid4())
Expand All @@ -274,14 +304,15 @@ async def create_checkpoint(
iteration_count=state["iteration_count"],
metadata=metadata or {},
)
checkpoint_id = await self._checkpoint_storage.save_checkpoint(checkpoint)
checkpoint_id = await storage.save_checkpoint(checkpoint)
logger.info(f"Created checkpoint {checkpoint_id} for workflow {self._workflow_id}")
return checkpoint_id

async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None:
if not self._checkpoint_storage:
storage = self._get_effective_checkpoint_storage()
if not storage:
raise ValueError("Checkpoint storage not configured")
return await self._checkpoint_storage.load_checkpoint(checkpoint_id)
return await storage.load_checkpoint(checkpoint_id)

def reset_for_new_run(self) -> None:
"""Reset the context for a new workflow run.
Expand Down
Loading