Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -453,11 +453,8 @@ def on_user_function_end(self, info: UserFunctionEndInfo) -> None:
else "Unknown error"
)
)
elif info.outcome is UserFunctionOutcome.SUCCEEDED:
span.set_status(StatusCode.OK)
else:
# PENDING
span.set_status(StatusCode.UNSET, "PENDING")
span.set_status(StatusCode.OK)

end_timestamp = info.end_time
if end_timestamp is not None and end_timestamp == info.start_time:
Expand Down Expand Up @@ -493,9 +490,13 @@ def _extract_attributes(self, info: Any) -> dict[str, str]:
attributes["durable.operation.type"] = info.operation_type.value
if hasattr(info, "name") and info.name is not None:
attributes["durable.operation.name"] = info.name
if hasattr(info, "attempt") and info.attempt is not None:
attributes["durable.attempt.number"] = info.attempt
if hasattr(info, "outcome") and info.outcome is not None:
attributes["durable.attempt.outcome"] = info.outcome.value
# Per-attempt fields are meaningful for STEP (each attempt is retried)
# but not for CONTEXT (a context is entered once per invocation, not
# retried). Omit them on CONTEXT spans for cross-SDK consistency.
if getattr(info, "operation_type", None) is not OperationType.CONTEXT:
if hasattr(info, "attempt") and info.attempt is not None:
attributes["durable.attempt.number"] = info.attempt
if hasattr(info, "outcome") and info.outcome is not None:
attributes["durable.attempt.outcome"] = info.outcome.value

return attributes
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,51 @@ def test_user_function_callbacks_emit_attempt_span_attributes():
)


def test_context_span_omits_attempt_attributes():
"""CONTEXT operations do not carry per-attempt attributes.

durable.attempt.number and durable.attempt.outcome are meaningful for
STEP operations (each retry is an attempt) but not for CONTEXT, so the
plugin omits them on CONTEXT spans for cross-SDK consistency.
"""
plugin, exporter = _create_plugin()
plugin.on_invocation_start(_invocation_start_info())
operation_id = "ctx-1"

plugin.on_user_function_start(
UserFunctionStartInfo(
operation_id=operation_id,
operation_type=OperationType.CONTEXT,
sub_type=None,
name="book-trip",
parent_id=None,
start_time=START_TIME,
is_replay_children=False,
attempt=1,
)
)
plugin.on_user_function_end(
UserFunctionEndInfo(
operation_id=operation_id,
operation_type=OperationType.CONTEXT,
sub_type=None,
name="book-trip",
parent_id=None,
start_time=START_TIME,
is_replay_children=False,
attempt=1,
outcome=UserFunctionOutcome.SUCCEEDED,
end_time=END_TIME,
error=None,
)
)

span = exporter.get_finished_spans()[0]
assert span.attributes["durable.operation.type"] == OperationType.CONTEXT.value
assert "durable.attempt.number" not in span.attributes
assert "durable.attempt.outcome" not in span.attributes


def test_span_registry_helpers_can_be_called_from_multiple_threads():
"""Verify active span registry helpers are safe under concurrent access."""
plugin, _ = _create_plugin()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import contextlib
import datetime
import functools
Expand All @@ -7,7 +9,6 @@
from enum import Enum
from typing import Any, Callable, MutableMapping

from aws_durable_execution_sdk_python.exceptions import SuspendExecution
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
from aws_durable_execution_sdk_python.lambda_service import (
DurableExecutionInvocationOutput,
Expand Down Expand Up @@ -51,16 +52,12 @@ class OperationEndInfo(OperationInfo):
class UserFunctionOutcome(Enum):
SUCCEEDED = "SUCCEEDED"
FAILED = "FAILED"
PENDING = "PENDING"

@classmethod
def from_error(cls, error: ErrorObject | None) -> "UserFunctionOutcome":
def from_error(cls, error: ErrorObject | None) -> UserFunctionOutcome:
if error is None:
return cls(cls.SUCCEEDED)
elif error.type == SuspendExecution.__name__:
return cls(cls.PENDING)
else:
return cls(cls.FAILED)
return cls(cls.FAILED)


@dataclass(frozen=True)
Expand All @@ -86,7 +83,7 @@ class UserFunctionEndInfo(OperationInfo):
@classmethod
def from_start_info(
cls, start_info: UserFunctionStartInfo, error: ErrorObject | None
) -> "UserFunctionEndInfo":
) -> UserFunctionEndInfo:
return UserFunctionEndInfo(
operation_id=start_info.operation_id,
operation_type=start_info.operation_type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -952,13 +952,7 @@ def wrapper(*args, **kwargs):
result = user_function(*args, **kwargs)
self._plugin_executor.on_user_function_end(start_info, None)
return result
except SuspendExecution as e:
self._plugin_executor.on_user_function_end(
start_info,
ErrorObject(
type=type(e).__name__, message=None, data=None, stack_trace=None
),
)
except SuspendExecution:
raise
except Exception as e:
self._plugin_executor.on_user_function_end(
Expand Down
24 changes: 24 additions & 0 deletions packages/aws-durable-execution-sdk-python/tests/plugin_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,5 +783,29 @@ def on_operation_attempt_end(self, info):
# endregion Helper Classes


# region Suspend Outcome Tests
class TestUserFunctionOutcomeValues(unittest.TestCase):
def test_outcome_values(self):
self.assertEqual(
{o.value for o in UserFunctionOutcome},
{"SUCCEEDED", "FAILED"},
)


class TestUserFunctionOutcomeFromError(unittest.TestCase):
def test_none_error_is_succeeded(self):
self.assertEqual(
UserFunctionOutcome.from_error(None), UserFunctionOutcome.SUCCEEDED
)

def test_error_is_failed(self):
self.assertEqual(
UserFunctionOutcome.from_error(ERROR), UserFunctionOutcome.FAILED
)


# endregion Suspend Outcome Tests


if __name__ == "__main__":
unittest.main()
42 changes: 42 additions & 0 deletions packages/aws-durable-execution-sdk-python/tests/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
DurableApiErrorCategory,
GetExecutionStateError,
OrphanedChildException,
TimedSuspendExecution,
)
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
from aws_durable_execution_sdk_python.lambda_service import (
Expand All @@ -41,6 +42,7 @@
from aws_durable_execution_sdk_python.plugin import (
DurableInstrumentationPlugin,
PluginExecutor,
UserFunctionEndInfo,
)
from aws_durable_execution_sdk_python.state import (
CheckpointBatcherConfig,
Expand Down Expand Up @@ -4199,6 +4201,46 @@ def on_operation_end(self, info):
executor.shutdown(wait=True)


def test_wrap_user_function_suspend_does_not_fire_end_hook():
"""A user function that suspends does not fire the end hook.

Regression: a timed suspend (TimedSuspendExecution) raised inside a wrapped
user function (e.g. a child context that waits) must not be surfaced to
plugins as a FAILED outcome. The suspend is normal durable control flow,
and the plugin observes it by absence (no end hook fires), with the
instrumentation plugin's own per-invocation span sweep closing any open
spans cleanly at invocation end.
"""
captured: list[UserFunctionEndInfo] = []

class _CapturingPlugin(DurableInstrumentationPlugin):
def on_user_function_end(self, info: UserFunctionEndInfo) -> None:
captured.append(info)

plugin_executor = PluginExecutor(plugins=[_CapturingPlugin()])
with plugin_executor.run():
state = ExecutionState(
durable_execution_arn="test_arn",
initial_checkpoint_token="token123", # noqa: S106
operations={},
service_client=create_autospec(spec=LambdaClient),
plugin_executor=plugin_executor,
)

def suspends(_: object) -> None:
raise TimedSuspendExecution.from_delay("waiting", 5)

op_id = OperationIdentifier(
operation_id="op-1", sub_type=OperationSubType.STEP, name="step"
)
wrapped = state.wrap_user_function(suspends, op_id, attempt=1)

with pytest.raises(TimedSuspendExecution):
wrapped(None)

assert captured == []


def test_plugin_executor_not_called_for_pending_operations():
"""Test that plugin_executor.on_operation_update fires on_user_function_end for PENDING operations."""
mock_client = create_autospec(spec=LambdaClient)
Expand Down
Loading