Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def _user_function_start_info(operation_id: str) -> UserFunctionStartInfo:
name="fetch-user",
parent_id=None,
start_time=START_TIME,
is_replayed=False,
is_replay_children=False,
attempt=1,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def _user_function_start_info(
name=f"step-{operation_id}",
parent_id=parent_id,
start_time=START_TIME,
is_replayed=False,
is_replay_children=False,
attempt=attempt,
)
Expand All @@ -121,6 +122,7 @@ def _user_function_end_info(
name=f"step-{operation_id}",
parent_id=parent_id,
start_time=START_TIME,
is_replayed=False,
is_replay_children=False,
attempt=attempt,
outcome=outcome,
Expand Down Expand Up @@ -158,6 +160,7 @@ def test_operation_callbacks_emit_child_span_with_deterministic_span_id():
name="wait-for-signal",
parent_id=None,
start_time=START_TIME,
is_replayed=False,
)
)
plugin.on_operation_end(
Expand All @@ -168,6 +171,7 @@ def test_operation_callbacks_emit_child_span_with_deterministic_span_id():
name="wait-for-signal",
parent_id=None,
start_time=START_TIME,
is_replayed=False,
status=OperationStatus.SUCCEEDED,
end_time=END_TIME,
error=None,
Expand Down Expand Up @@ -202,6 +206,7 @@ def test_operation_end_without_start_emits_continuation_span_with_link():
name="existing-wait",
parent_id=None,
start_time=START_TIME,
is_replayed=False,
status=OperationStatus.SUCCEEDED,
end_time=END_TIME,
error=None,
Expand All @@ -227,6 +232,7 @@ def test_user_function_callbacks_emit_attempt_span_attributes():
name="fetch-user",
parent_id=None,
start_time=START_TIME,
is_replayed=False,
is_replay_children=False,
attempt=1,
)
Expand All @@ -247,6 +253,7 @@ def test_user_function_callbacks_emit_attempt_span_attributes():
name="fetch-user",
parent_id=None,
start_time=START_TIME,
is_replayed=False,
is_replay_children=False,
attempt=1,
outcome=UserFunctionOutcome.SUCCEEDED,
Expand Down Expand Up @@ -288,6 +295,7 @@ def test_context_span_omits_attempt_attributes():
name="book-trip",
parent_id=None,
start_time=START_TIME,
is_replayed=False,
is_replay_children=False,
attempt=1,
)
Expand All @@ -300,6 +308,7 @@ def test_context_span_omits_attempt_attributes():
name="book-trip",
parent_id=None,
start_time=START_TIME,
is_replayed=False,
is_replay_children=False,
attempt=1,
outcome=UserFunctionOutcome.SUCCEEDED,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -498,16 +498,17 @@ def test_event_timeout_handling():
scheduler = Scheduler()
scheduler.start()

event = scheduler.create_event()

start_time = time.time()
result = event.wait(timeout=0.05)
end_time = time.time()
try:
event = scheduler.create_event()

assert result is False
assert 0.04 <= (end_time - start_time) <= 0.1
start_time = time.monotonic()
result = event.wait(timeout=0.05)
end_time = time.monotonic()

scheduler.stop()
assert result is False
assert 0.04 <= (end_time - start_time) <= 0.25
finally:
scheduler.stop()


def test_scheduler_call_later_zero_delay():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class OperationInfo:
name: str | None
parent_id: str | None
start_time: datetime.datetime | None
is_replayed: bool


@dataclass(frozen=True)
Expand Down Expand Up @@ -91,6 +92,7 @@ def from_start_info(
name=start_info.name,
parent_id=start_info.parent_id,
start_time=start_info.start_time,
is_replayed=start_info.is_replayed,
is_replay_children=start_info.is_replay_children,
attempt=start_info.attempt,
outcome=UserFunctionOutcome.from_error(error),
Expand Down Expand Up @@ -156,7 +158,8 @@ def on_invocation_end(self, info: InvocationEndInfo) -> None:

def on_operation_start(self, info: OperationStartInfo) -> None:
"""
Called when an operation checkpoints STARTED status. This is called NOT within the thread that runs operation.
Called when an operation checkpoints STARTED status, or when a prior
operation is replayed. This is called NOT within the thread that runs operation.

Args:
info: Information about the operation.
Expand All @@ -166,7 +169,8 @@ def on_operation_start(self, info: OperationStartInfo) -> None:

def on_operation_end(self, info: OperationEndInfo) -> None:
"""
Called when an operation checkpoints a terminal status. This is called NOT within the thread that runs operation.
Called when an operation checkpoints a terminal status, or when a prior
terminal operation is replayed. This is called NOT within the thread that runs operation.

Args:
info: Information about the operation.
Expand Down Expand Up @@ -295,6 +299,7 @@ def on_user_function_start(
name=operation_identifier.name,
parent_id=operation_identifier.parent_id,
start_time=datetime.datetime.now(datetime.UTC),
is_replayed=False,
is_replay_children=is_replay_children,
attempt=attempt,
)
Expand Down Expand Up @@ -325,6 +330,37 @@ def on_operation_action(self, update: OperationUpdate):
name=update.name,
parent_id=update.parent_id,
start_time=datetime.datetime.now(datetime.UTC),
is_replayed=False,
),
sync=True,
)

def on_operation_replay(self, operation: Operation) -> None:
"""Execute plugins for a checkpointed operation observed during replay."""
start_info = OperationStartInfo(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currently, is this being executed on every single replay on every invocation?

@zhongkechen zhongkechen Jun 30, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

called once per operation per invocation if the operation's context is replayed. If the context is completed and the result is cached, its child operations will not be replayed and thus this will not be called.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does this affect the Otel Plugin behaviour?

operation_id=operation.operation_id,
operation_type=operation.operation_type,
sub_type=operation.sub_type,
name=operation.name,
parent_id=operation.parent_id,
start_time=operation.start_timestamp,
is_replayed=True,
)
self.execute_plugins(start_info, sync=True)

if self._is_terminal_status(operation.status):
self.execute_plugins(
Comment thread
SilanHe marked this conversation as resolved.
OperationEndInfo(
operation_id=operation.operation_id,
operation_type=operation.operation_type,
sub_type=operation.sub_type,
name=operation.name,
parent_id=operation.parent_id,
start_time=operation.start_timestamp,
end_time=operation.end_timestamp,
status=operation.status,
error=self._extract_error(operation),
is_replayed=True,
),
sync=True,
)
Expand Down Expand Up @@ -352,6 +388,7 @@ def on_operation_update(self, operation: Operation | None):
end_time=operation.end_timestamp,
status=operation.status,
error=self._extract_error(operation),
is_replayed=False,
),
sync=True,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ def __init__(
self._replay_status: ReplayStatus = replay_status
self._replay_status_lock: Lock = Lock()
self._visited_operations: set[str] = set()
self._replayed_operation_hooks: set[str] = set()

@property
def operations(self) -> dict[str, Operation]:
Expand Down Expand Up @@ -444,11 +445,30 @@ def get_checkpoint_result(self, checkpoint_id: str) -> CheckpointedResult:
"""
# checking status are deliberately under a lighter non-serialized lock
with self._operations_lock:
if checkpoint := self._operations.get(checkpoint_id):
return CheckpointedResult.create_from_operation(checkpoint)
checkpoint = self._operations.get(checkpoint_id)

if checkpoint:
self._emit_operation_replay_hooks(checkpoint)
return CheckpointedResult.create_from_operation(checkpoint)

return CHECKPOINT_NOT_FOUND

def _emit_operation_replay_hooks(self, operation: Operation) -> None:
"""Emit operation hooks once for each checkpointed operation during replay."""
if operation.operation_type is OperationType.EXECUTION:
return
if operation.status is OperationStatus.READY:
return

with self._replay_status_lock:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is still using the state level replay status, which will be removed in #488, do we want to keep this ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be updated to use the new per context replay status

if self._replay_status is not ReplayStatus.REPLAY:
return
if operation.operation_id in self._replayed_operation_hooks:
return
self._replayed_operation_hooks.add(operation.operation_id)
Comment on lines +466 to +468

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little confused. What is this _replayed_operation_hooks variable used for? Wouldn't this reset on each invocation?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

on_operation_replay method has been called for operations in this dict. It's reset per invocation.


self._plugin_executor.on_operation_replay(operation)

def create_checkpoint(
self,
operation_update: OperationUpdate | None = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
name="my-op",
parent_id="parent-1",
start_time=START_TS,
is_replayed=False,
)
OPERATION_END_INFO = OperationEndInfo(
operation_id="op-1",
Expand All @@ -48,6 +49,7 @@
name="my-op",
parent_id="parent-1",
start_time=START_TS,
is_replayed=False,
status=OperationStatus.FAILED,
end_time=END_TS,
error=ERROR,
Expand Down Expand Up @@ -76,6 +78,7 @@
name="func",
parent_id="parent-1",
start_time=START_TS,
is_replayed=False,
)

USER_FUNCTION_END_INFO = UserFunctionEndInfo(
Expand All @@ -85,6 +88,7 @@
name="func",
parent_id="parent-1",
start_time=START_TS,
is_replayed=False,
is_replay_children=False,
attempt=1,
outcome=UserFunctionOutcome.FAILED,
Expand All @@ -99,6 +103,7 @@ def test_operation_start_info(self):
self.assertEqual(OPERATION_START_INFO.name, "my-op")
self.assertEqual(OPERATION_START_INFO.parent_id, "parent-1")
self.assertEqual(OPERATION_START_INFO.start_time, START_TS)
self.assertFalse(OPERATION_START_INFO.is_replayed)

def test_operation_end_info(self):
self.assertEqual(OPERATION_END_INFO.status, OperationStatus.FAILED)
Expand All @@ -111,6 +116,7 @@ def test_operation_end_info(self):
self.assertEqual(OPERATION_END_INFO.operation_id, "op-1")
self.assertEqual(OPERATION_END_INFO.status, OperationStatus.FAILED)
self.assertEqual(OPERATION_END_INFO.operation_id, "op-1")
self.assertFalse(OPERATION_END_INFO.is_replayed)

def test_invocation_start_info(self):
self.assertEqual(INVOCATION_START_INFO.request_id, "req-1")
Expand Down
78 changes: 78 additions & 0 deletions packages/aws-durable-execution-sdk-python/tests/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4301,6 +4301,84 @@ def test_plugin_executor_not_called_for_pending_operations():
assert len(operation_end_calls) == 0


def test_plugin_executor_emits_start_and_end_for_replayed_terminal_operation():
"""Replay reads emit operation lifecycle hooks with is_replayed=True."""
start_time = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC)
end_time = datetime.datetime(2025, 1, 2, tzinfo=datetime.UTC)
operation = Operation(
operation_id="step-1",
operation_type=OperationType.STEP,
status=OperationStatus.SUCCEEDED,
parent_id="parent-1",
name="my-step",
start_timestamp=start_time,
end_timestamp=end_time,
sub_type=OperationSubType.STEP,
step_details=StepDetails(attempt=1, result='"done"'),
)
captured: list[tuple[str, str, bool]] = []

class _CapturingPlugin(DurableInstrumentationPlugin):
def on_operation_start(self, info):
captured.append(("start", info.operation_id, info.is_replayed))

def on_operation_end(self, info):
captured.append(("end", info.operation_id, info.is_replayed))

plugin_executor = PluginExecutor(plugins=[_CapturingPlugin()])
with plugin_executor.run():
state = ExecutionState(
durable_execution_arn="test_arn",
initial_checkpoint_token="token123", # noqa: S106
operations={"step-1": operation},
service_client=create_autospec(spec=LambdaClient),
plugin_executor=plugin_executor,
replay_status=ReplayStatus.REPLAY,
)

assert state.get_checkpoint_result("step-1").is_succeeded()
assert state.get_checkpoint_result("step-1").is_succeeded()

assert captured == [
("start", "step-1", True),
("end", "step-1", True),
]


def test_plugin_executor_emits_only_start_for_replayed_non_terminal_operation():
"""Replay reads for in-flight operations emit start but not end."""
operation = Operation(
operation_id="wait-1",
operation_type=OperationType.WAIT,
status=OperationStatus.STARTED,
name="my-wait",
sub_type=OperationSubType.WAIT,
)
captured: list[tuple[str, str, bool]] = []

class _CapturingPlugin(DurableInstrumentationPlugin):
def on_operation_start(self, info):
captured.append(("start", info.operation_id, info.is_replayed))

def on_operation_end(self, info):
captured.append(("end", info.operation_id, info.is_replayed))

plugin_executor = PluginExecutor(plugins=[_CapturingPlugin()])
with plugin_executor.run():
state = ExecutionState(
durable_execution_arn="test_arn",
initial_checkpoint_token="token123", # noqa: S106
operations={"wait-1": operation},
service_client=create_autospec(spec=LambdaClient),
plugin_executor=plugin_executor,
replay_status=ReplayStatus.REPLAY,
)

assert state.get_checkpoint_result("wait-1").is_started()

assert captured == [("start", "wait-1", True)]


# endregion Plugin Executor Integration Tests


Expand Down