diff --git a/packages/aws-durable-execution-sdk-python-otel/tests/test_log_filter.py b/packages/aws-durable-execution-sdk-python-otel/tests/test_log_filter.py index a3a5bcd..13c969f 100644 --- a/packages/aws-durable-execution-sdk-python-otel/tests/test_log_filter.py +++ b/packages/aws-durable-execution-sdk-python-otel/tests/test_log_filter.py @@ -60,6 +60,7 @@ def _user_function_start_info(operation_id: str) -> UserFunctionStartInfo: name="fetch-user", parent_id=None, start_time=START_TIME, + is_replayed=False, is_replay_children=False, attempt=1, ) diff --git a/packages/aws-durable-execution-sdk-python-otel/tests/test_plugin.py b/packages/aws-durable-execution-sdk-python-otel/tests/test_plugin.py index 771547f..8286088 100644 --- a/packages/aws-durable-execution-sdk-python-otel/tests/test_plugin.py +++ b/packages/aws-durable-execution-sdk-python-otel/tests/test_plugin.py @@ -101,6 +101,7 @@ def _user_function_start_info( name=f"step-{operation_id}", parent_id=parent_id, start_time=START_TIME, + is_replayed=False, is_replay_children=False, attempt=attempt, ) @@ -121,6 +122,7 @@ def _user_function_end_info( name=f"step-{operation_id}", parent_id=parent_id, start_time=START_TIME, + is_replayed=False, is_replay_children=False, attempt=attempt, outcome=outcome, @@ -158,6 +160,7 @@ def test_operation_callbacks_emit_child_span_with_deterministic_span_id(): name="wait-for-signal", parent_id=None, start_time=START_TIME, + is_replayed=False, ) ) plugin.on_operation_end( @@ -168,6 +171,7 @@ def test_operation_callbacks_emit_child_span_with_deterministic_span_id(): name="wait-for-signal", parent_id=None, start_time=START_TIME, + is_replayed=False, status=OperationStatus.SUCCEEDED, end_time=END_TIME, error=None, @@ -202,6 +206,7 @@ def test_operation_end_without_start_emits_continuation_span_with_link(): name="existing-wait", parent_id=None, start_time=START_TIME, + is_replayed=False, status=OperationStatus.SUCCEEDED, end_time=END_TIME, error=None, @@ -227,6 +232,7 @@ def test_user_function_callbacks_emit_attempt_span_attributes(): name="fetch-user", parent_id=None, start_time=START_TIME, + is_replayed=False, is_replay_children=False, attempt=1, ) @@ -247,6 +253,7 @@ def test_user_function_callbacks_emit_attempt_span_attributes(): name="fetch-user", parent_id=None, start_time=START_TIME, + is_replayed=False, is_replay_children=False, attempt=1, outcome=UserFunctionOutcome.SUCCEEDED, @@ -288,6 +295,7 @@ def test_context_span_omits_attempt_attributes(): name="book-trip", parent_id=None, start_time=START_TIME, + is_replayed=False, is_replay_children=False, attempt=1, ) @@ -300,6 +308,7 @@ def test_context_span_omits_attempt_attributes(): name="book-trip", parent_id=None, start_time=START_TIME, + is_replayed=False, is_replay_children=False, attempt=1, outcome=UserFunctionOutcome.SUCCEEDED, diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/scheduler_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/scheduler_test.py index 65db932..0f86726 100644 --- a/packages/aws-durable-execution-sdk-python-testing/tests/scheduler_test.py +++ b/packages/aws-durable-execution-sdk-python-testing/tests/scheduler_test.py @@ -498,16 +498,17 @@ def test_event_timeout_handling(): scheduler = Scheduler() scheduler.start() - event = scheduler.create_event() - - start_time = time.time() - result = event.wait(timeout=0.05) - end_time = time.time() + try: + event = scheduler.create_event() - assert result is False - assert 0.04 <= (end_time - start_time) <= 0.1 + start_time = time.monotonic() + result = event.wait(timeout=0.05) + end_time = time.monotonic() - scheduler.stop() + assert result is False + assert 0.04 <= (end_time - start_time) <= 0.25 + finally: + scheduler.stop() def test_scheduler_call_later_zero_delay(): diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/plugin.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/plugin.py index 34b7529..f786053 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/plugin.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/plugin.py @@ -35,6 +35,7 @@ class OperationInfo: name: str | None parent_id: str | None start_time: datetime.datetime | None + is_replayed: bool @dataclass(frozen=True) @@ -91,6 +92,7 @@ def from_start_info( name=start_info.name, parent_id=start_info.parent_id, start_time=start_info.start_time, + is_replayed=start_info.is_replayed, is_replay_children=start_info.is_replay_children, attempt=start_info.attempt, outcome=UserFunctionOutcome.from_error(error), @@ -156,7 +158,8 @@ def on_invocation_end(self, info: InvocationEndInfo) -> None: def on_operation_start(self, info: OperationStartInfo) -> None: """ - Called when an operation checkpoints STARTED status. This is called NOT within the thread that runs operation. + Called when an operation checkpoints STARTED status, or when a prior + operation is replayed. This is called NOT within the thread that runs operation. Args: info: Information about the operation. @@ -166,7 +169,8 @@ def on_operation_start(self, info: OperationStartInfo) -> None: def on_operation_end(self, info: OperationEndInfo) -> None: """ - Called when an operation checkpoints a terminal status. This is called NOT within the thread that runs operation. + Called when an operation checkpoints a terminal status, or when a prior + terminal operation is replayed. This is called NOT within the thread that runs operation. Args: info: Information about the operation. @@ -295,6 +299,7 @@ def on_user_function_start( name=operation_identifier.name, parent_id=operation_identifier.parent_id, start_time=datetime.datetime.now(datetime.UTC), + is_replayed=False, is_replay_children=is_replay_children, attempt=attempt, ) @@ -325,6 +330,37 @@ def on_operation_action(self, update: OperationUpdate): name=update.name, parent_id=update.parent_id, start_time=datetime.datetime.now(datetime.UTC), + is_replayed=False, + ), + sync=True, + ) + + def on_operation_replay(self, operation: Operation) -> None: + """Execute plugins for a checkpointed operation observed during replay.""" + start_info = OperationStartInfo( + operation_id=operation.operation_id, + operation_type=operation.operation_type, + sub_type=operation.sub_type, + name=operation.name, + parent_id=operation.parent_id, + start_time=operation.start_timestamp, + is_replayed=True, + ) + self.execute_plugins(start_info, sync=True) + + if self._is_terminal_status(operation.status): + self.execute_plugins( + OperationEndInfo( + operation_id=operation.operation_id, + operation_type=operation.operation_type, + sub_type=operation.sub_type, + name=operation.name, + parent_id=operation.parent_id, + start_time=operation.start_timestamp, + end_time=operation.end_timestamp, + status=operation.status, + error=self._extract_error(operation), + is_replayed=True, ), sync=True, ) @@ -352,6 +388,7 @@ def on_operation_update(self, operation: Operation | None): end_time=operation.end_timestamp, status=operation.status, error=self._extract_error(operation), + is_replayed=False, ), sync=True, ) diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py index 9996591..0f146e3 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py @@ -278,6 +278,7 @@ def __init__( self._replay_status: ReplayStatus = replay_status self._replay_status_lock: Lock = Lock() self._visited_operations: set[str] = set() + self._replayed_operation_hooks: set[str] = set() @property def operations(self) -> dict[str, Operation]: @@ -444,11 +445,30 @@ def get_checkpoint_result(self, checkpoint_id: str) -> CheckpointedResult: """ # checking status are deliberately under a lighter non-serialized lock with self._operations_lock: - if checkpoint := self._operations.get(checkpoint_id): - return CheckpointedResult.create_from_operation(checkpoint) + checkpoint = self._operations.get(checkpoint_id) + + if checkpoint: + self._emit_operation_replay_hooks(checkpoint) + return CheckpointedResult.create_from_operation(checkpoint) return CHECKPOINT_NOT_FOUND + def _emit_operation_replay_hooks(self, operation: Operation) -> None: + """Emit operation hooks once for each checkpointed operation during replay.""" + if operation.operation_type is OperationType.EXECUTION: + return + if operation.status is OperationStatus.READY: + return + + with self._replay_status_lock: + if self._replay_status is not ReplayStatus.REPLAY: + return + if operation.operation_id in self._replayed_operation_hooks: + return + self._replayed_operation_hooks.add(operation.operation_id) + + self._plugin_executor.on_operation_replay(operation) + def create_checkpoint( self, operation_update: OperationUpdate | None = None, diff --git a/packages/aws-durable-execution-sdk-python/tests/plugin_test.py b/packages/aws-durable-execution-sdk-python/tests/plugin_test.py index 86379c0..479997f 100644 --- a/packages/aws-durable-execution-sdk-python/tests/plugin_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/plugin_test.py @@ -40,6 +40,7 @@ name="my-op", parent_id="parent-1", start_time=START_TS, + is_replayed=False, ) OPERATION_END_INFO = OperationEndInfo( operation_id="op-1", @@ -48,6 +49,7 @@ name="my-op", parent_id="parent-1", start_time=START_TS, + is_replayed=False, status=OperationStatus.FAILED, end_time=END_TS, error=ERROR, @@ -76,6 +78,7 @@ name="func", parent_id="parent-1", start_time=START_TS, + is_replayed=False, ) USER_FUNCTION_END_INFO = UserFunctionEndInfo( @@ -85,6 +88,7 @@ name="func", parent_id="parent-1", start_time=START_TS, + is_replayed=False, is_replay_children=False, attempt=1, outcome=UserFunctionOutcome.FAILED, @@ -99,6 +103,7 @@ def test_operation_start_info(self): self.assertEqual(OPERATION_START_INFO.name, "my-op") self.assertEqual(OPERATION_START_INFO.parent_id, "parent-1") self.assertEqual(OPERATION_START_INFO.start_time, START_TS) + self.assertFalse(OPERATION_START_INFO.is_replayed) def test_operation_end_info(self): self.assertEqual(OPERATION_END_INFO.status, OperationStatus.FAILED) @@ -111,6 +116,7 @@ def test_operation_end_info(self): self.assertEqual(OPERATION_END_INFO.operation_id, "op-1") self.assertEqual(OPERATION_END_INFO.status, OperationStatus.FAILED) self.assertEqual(OPERATION_END_INFO.operation_id, "op-1") + self.assertFalse(OPERATION_END_INFO.is_replayed) def test_invocation_start_info(self): self.assertEqual(INVOCATION_START_INFO.request_id, "req-1") diff --git a/packages/aws-durable-execution-sdk-python/tests/state_test.py b/packages/aws-durable-execution-sdk-python/tests/state_test.py index 9362f7c..7523ddd 100644 --- a/packages/aws-durable-execution-sdk-python/tests/state_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/state_test.py @@ -4301,6 +4301,84 @@ def test_plugin_executor_not_called_for_pending_operations(): assert len(operation_end_calls) == 0 +def test_plugin_executor_emits_start_and_end_for_replayed_terminal_operation(): + """Replay reads emit operation lifecycle hooks with is_replayed=True.""" + start_time = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + end_time = datetime.datetime(2025, 1, 2, tzinfo=datetime.UTC) + operation = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + parent_id="parent-1", + name="my-step", + start_timestamp=start_time, + end_timestamp=end_time, + sub_type=OperationSubType.STEP, + step_details=StepDetails(attempt=1, result='"done"'), + ) + captured: list[tuple[str, str, bool]] = [] + + class _CapturingPlugin(DurableInstrumentationPlugin): + def on_operation_start(self, info): + captured.append(("start", info.operation_id, info.is_replayed)) + + def on_operation_end(self, info): + captured.append(("end", info.operation_id, info.is_replayed)) + + plugin_executor = PluginExecutor(plugins=[_CapturingPlugin()]) + with plugin_executor.run(): + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={"step-1": operation}, + service_client=create_autospec(spec=LambdaClient), + plugin_executor=plugin_executor, + replay_status=ReplayStatus.REPLAY, + ) + + assert state.get_checkpoint_result("step-1").is_succeeded() + assert state.get_checkpoint_result("step-1").is_succeeded() + + assert captured == [ + ("start", "step-1", True), + ("end", "step-1", True), + ] + + +def test_plugin_executor_emits_only_start_for_replayed_non_terminal_operation(): + """Replay reads for in-flight operations emit start but not end.""" + operation = Operation( + operation_id="wait-1", + operation_type=OperationType.WAIT, + status=OperationStatus.STARTED, + name="my-wait", + sub_type=OperationSubType.WAIT, + ) + captured: list[tuple[str, str, bool]] = [] + + class _CapturingPlugin(DurableInstrumentationPlugin): + def on_operation_start(self, info): + captured.append(("start", info.operation_id, info.is_replayed)) + + def on_operation_end(self, info): + captured.append(("end", info.operation_id, info.is_replayed)) + + plugin_executor = PluginExecutor(plugins=[_CapturingPlugin()]) + with plugin_executor.run(): + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={"wait-1": operation}, + service_client=create_autospec(spec=LambdaClient), + plugin_executor=plugin_executor, + replay_status=ReplayStatus.REPLAY, + ) + + assert state.get_checkpoint_result("wait-1").is_started() + + assert captured == [("start", "wait-1", True)] + + # endregion Plugin Executor Integration Tests