diff --git a/packages/aws-durable-execution-sdk-python-examples/test/with_retry/test_with_retry_callback.py b/packages/aws-durable-execution-sdk-python-examples/test/with_retry/test_with_retry_callback.py index f3b2206..f08bbee 100644 --- a/packages/aws-durable-execution-sdk-python-examples/test/with_retry/test_with_retry_callback.py +++ b/packages/aws-durable-execution-sdk-python-examples/test/with_retry/test_with_retry_callback.py @@ -10,7 +10,15 @@ from test.conftest import deserialize_operation_payload from aws_durable_execution_sdk_python.execution import InvocationStatus -from aws_durable_execution_sdk_python.lambda_service import ErrorObject +from aws_durable_execution_sdk_python.lambda_service import ErrorObject, OperationStatus + + +def get_callback_status(result, name: str) -> OperationStatus: + callbacks = [ + operation for operation in result.get_all_operations() if operation.name == name + ] + assert len(callbacks) == 1 + return callbacks[0].status @pytest.mark.example @@ -67,3 +75,16 @@ def test_with_retry_callback_fails_twice_then_succeeds(durable_runner): "success": True, "result": "approval granted", } + + assert ( + get_callback_status(result, "external-call-attempt-1 create callback id") + is OperationStatus.FAILED + ) + assert ( + get_callback_status(result, "external-call-attempt-2 create callback id") + is OperationStatus.FAILED + ) + assert ( + get_callback_status(result, "external-call-attempt-3 create callback id") + is OperationStatus.SUCCEEDED + ) diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/executor.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/executor.py index 02a1504..5c23e24 100644 --- a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/executor.py +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/executor.py @@ -59,7 +59,7 @@ if TYPE_CHECKING: - from collections.abc import Awaitable, Callable + from collections.abc import Callable from concurrent.futures import Future from aws_durable_execution_sdk_python_testing.checkpoint.processor import ( @@ -779,10 +779,10 @@ def _validate_invocation_response_and_store( ) raise IllegalStateException(msg_unexpected_status) - def _invoke_handler(self, execution_arn: str) -> Callable[[], Awaitable[None]]: + def _invoke_handler(self, execution_arn: str) -> Callable[[], None]: """Create a parameterless callable that captures execution arn for the scheduler.""" - async def invoke() -> None: + def invoke() -> None: execution: Execution = self._store.load(execution_arn) # Early exit if execution is already completed - like Java's COMPLETED check diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/runner.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/runner.py index d60e774..c263e8b 100644 --- a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/runner.py +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/runner.py @@ -398,6 +398,16 @@ def create_operation( return operation_class.from_svc_operation(svc_operation, all_operations) +def _is_callback_ready_for_response( + events: list[Event], callback_started_event: Event +) -> bool: + """Return True once the invocation that created the callback has settled.""" + for event in events[events.index(callback_started_event) + 1 :]: + if event.event_type == "InvocationCompleted": + return True + return False + + def _get_callback_id_from_events( events: list[Event], name: str | None = None ) -> str | None: @@ -409,7 +419,8 @@ def _get_callback_id_from_events( name: Optional callback name to search for. If not provided, returns the latest callback. Returns: - The callback ID string for a non-completed callback, or None if not found. + The callback ID string for a non-completed callback whose creating + invocation has completed, or None if not found. Raises: DurableFunctionsTestError: If the named callback has already succeeded/failed/timed out. @@ -436,6 +447,8 @@ def _get_callback_id_from_events( raise DurableFunctionsTestError( f"Callback {name} has already completed (succeeded/failed/timed out)" ) + if not _is_callback_ready_for_response(events, event): + return None return ( event.callback_started_details.callback_id if event.callback_started_details @@ -448,6 +461,7 @@ def _get_callback_id_from_events( event for event in callback_started_events if event.event_id not in completed_callback_ids + and _is_callback_ready_for_response(events, event) ] if not active_callbacks: diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/scheduler.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/scheduler.py index a45b942..04fa712 100644 --- a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/scheduler.py +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/scheduler.py @@ -3,9 +3,12 @@ from __future__ import annotations import asyncio +import inspect import itertools import logging import threading +from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import TimeoutError as FutureTimeoutError from typing import TYPE_CHECKING, Any @@ -64,6 +67,10 @@ class Scheduler: def __init__(self) -> None: self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() + self._thread_executor = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="durable-scheduler" + ) + self._loop.set_default_executor(self._thread_executor) self._ready_event: threading.Event = threading.Event() self._thread: threading.Thread = threading.Thread( target=self._start_loop, daemon=True @@ -99,7 +106,8 @@ def stop(self): return self._running = False - self._loop.call_soon_threadsafe(self._cleanup_and_stop) + future = asyncio.run_coroutine_threadsafe(self._cleanup_and_stop(), self._loop) + future.result() self._thread.join() def is_started(self) -> bool: @@ -116,16 +124,24 @@ def task_count(self) -> int: return 0 return len(asyncio.all_tasks(self._loop)) - def _cleanup_and_stop(self): - """Cancel all tasks and clear all events. Stop the event-loop.""" - # Cancel all tasks - for task in asyncio.all_tasks(self._loop): + async def _cleanup_and_stop(self): + """Cancel all tasks, clear events, and stop the event loop.""" + current_task = asyncio.current_task(self._loop) + tasks = [ + task for task in asyncio.all_tasks(self._loop) if task is not current_task + ] + + for task in tasks: task.cancel() - # Clear events (don't set them) + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + await self._loop.shutdown_default_executor() + self._events.clear() - self._loop.stop() + self._loop.call_soon(self._loop.stop) def _start_loop(self): """Initialize the event-loop. The ready event notifies that the loop is started.""" @@ -146,8 +162,8 @@ def call_later( ) -> Future[Any]: """Call func after the delay. - If func is async it runs inside a thread-safe coroutine. If func is sync it runs in its own - threadpool, so it won't block the event loop. + If func is async it runs inside a thread-safe coroutine. If func is sync it runs in the + scheduler worker thread, so it won't block the event loop. Args: func (Callable[[], Any]): The function to call later. This can be an async or a standard @@ -170,7 +186,7 @@ async def delayed_func() -> Any: await asyncio.sleep(delay) try: - if asyncio.iscoroutinefunction(func): + if inspect.iscoroutinefunction(func): result = await func() else: result = await asyncio.to_thread(func) @@ -215,13 +231,22 @@ def wait_for_event( if event not in self._events: return False + async def wait_for_event_with_timeout() -> bool: + return await asyncio.wait_for(event.wait(), timeout) + future: Future[bool] = asyncio.run_coroutine_threadsafe( - asyncio.wait_for(event.wait(), timeout), self._loop + wait_for_event_with_timeout(), self._loop ) try: - return future.result() - except TimeoutError: + if timeout is None: + return future.result() + # Enforce the timeout from the waiting thread too. If the scheduler + # loop is blocked, the inner asyncio.wait_for timeout cannot fire. + margin = min(1.0, max(0.01, timeout * 0.1)) + return future.result(timeout=timeout + margin) + except (TimeoutError, FutureTimeoutError): + future.cancel() return False def set_event(self, event: asyncio.Event): diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/sqlite.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/sqlite.py index fac1ca4..ce89e83 100644 --- a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/sqlite.py +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/sqlite.py @@ -4,6 +4,8 @@ import json import sqlite3 +from collections.abc import Iterator +from contextlib import contextmanager from datetime import datetime from pathlib import Path from typing import Any, cast @@ -43,10 +45,20 @@ def _get_connection(self) -> sqlite3.Connection: conn.execute("PRAGMA synchronous=NORMAL;") return conn + @contextmanager + def _connection(self) -> Iterator[sqlite3.Connection]: + """Open a connection, preserve transaction handling, and always close it.""" + conn = self._get_connection() + try: + with conn: + yield conn + finally: + conn.close() + def _init_db(self) -> None: """Initialize database schema.""" try: - with self._get_connection() as conn: + with self._connection() as conn: conn.execute(""" CREATE TABLE IF NOT EXISTS executions ( durable_execution_arn TEXT PRIMARY KEY, @@ -80,7 +92,7 @@ def save(self, execution: Execution) -> None: execution_op = execution.get_operation_execution_started() status: str = execution.current_status().value - with self._get_connection() as conn: + with self._connection() as conn: conn.execute( """ INSERT OR REPLACE INTO executions @@ -111,7 +123,7 @@ def save(self, execution: Execution) -> None: def load(self, execution_arn: str) -> Execution: """Load execution from SQLite.""" try: - with self._get_connection() as conn: + with self._connection() as conn: cursor: sqlite3.Cursor = conn.execute( "SELECT data FROM executions WHERE durable_execution_arn = ?", (execution_arn,), @@ -204,7 +216,7 @@ def query( ) params_with_limit = params - with self._get_connection() as conn: + with self._connection() as conn: # Get total count for pagination total_count: int = int(conn.execute(count_query, params).fetchone()[0]) @@ -247,7 +259,7 @@ def list_all(self) -> list[Execution]: def get_execution_metadata(self, execution_arn: str) -> dict[str, Any] | None: """Get just the metadata without full deserialization for performance.""" try: - with self._get_connection() as conn: + with self._connection() as conn: cursor: sqlite3.Cursor = conn.execute( "SELECT function_name, execution_name, status, start_timestamp, end_timestamp FROM executions WHERE durable_execution_arn = ?", (execution_arn,), diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/executor_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/executor_test.py index e228a0d..787e494 100644 --- a/packages/aws-durable-execution-sdk-python-testing/tests/executor_test.py +++ b/packages/aws-durable-execution-sdk-python-testing/tests/executor_test.py @@ -1,6 +1,5 @@ """Unit tests for executor module.""" -import asyncio from datetime import UTC, datetime from unittest.mock import Mock, patch @@ -315,9 +314,8 @@ def test_should_complete_workflow_with_error_when_invocation_fails( handler = mock_scheduler.call_later.call_args_list[-1][0][0] # Execute the handler to trigger the invocation logic - import asyncio - asyncio.run(handler()) + handler() # Assert - verify workflow was completed with error mock_fail.assert_called_once_with("test-arn", failed_response.error) @@ -361,9 +359,8 @@ def test_should_complete_workflow_with_result_when_invocation_succeeds( handler = mock_scheduler.call_later.call_args_list[-1][0][0] # Execute the handler to trigger the invocation logic - import asyncio - asyncio.run(handler()) + handler() # Assert - verify workflow was completed with result mock_complete.assert_called_once_with("test-arn", "success result") @@ -404,9 +401,8 @@ def test_should_handle_pending_status_when_operations_exist( handler = mock_scheduler.call_later.call_args_list[-1][0][0] # Execute the handler to trigger the invocation logic - import asyncio - asyncio.run(handler()) + handler() # Assert - verify pending operations were checked mock_execution.has_pending_operations.assert_called_once_with(mock_execution) @@ -444,9 +440,8 @@ def test_should_ignore_response_when_execution_already_complete( handler = mock_scheduler.call_later.call_args_list[-1][0][0] # Execute the handler to trigger the invocation logic - import asyncio - asyncio.run(handler()) + handler() # Assert - verify invoker was not called since execution was already complete mock_invoker.create_invocation_input.assert_not_called() @@ -487,7 +482,7 @@ def test_should_retry_when_response_has_no_status( handler = mock_scheduler.call_later.call_args_list[-1][0][0] # Execute the handler to trigger the invocation logic - asyncio.run(handler()) + handler() # Assert - verify retry was triggered due to validation error assert mock_execution.consecutive_failed_invocation_attempts == 1 @@ -532,7 +527,7 @@ def test_should_retry_when_failed_response_has_result( handler = mock_scheduler.call_later.call_args_list[-1][0][0] # Execute the handler to trigger the invocation logic - asyncio.run(handler()) + handler() # Assert - verify retry was triggered due to validation error assert mock_execution.consecutive_failed_invocation_attempts == 1 @@ -578,7 +573,7 @@ def test_should_retry_when_success_response_has_error( handler = mock_scheduler.call_later.call_args_list[-1][0][0] # Execute the handler to trigger the invocation logic - asyncio.run(handler()) + handler() # Assert - verify retry was triggered due to validation error assert mock_execution.consecutive_failed_invocation_attempts == 1 @@ -622,7 +617,7 @@ def test_should_retry_when_pending_response_has_no_operations( handler = mock_scheduler.call_later.call_args_list[-1][0][0] # Execute the handler to trigger the invocation logic - asyncio.run(handler()) + handler() # Assert - verify retry was triggered due to validation error assert mock_execution.consecutive_failed_invocation_attempts == 1 @@ -665,7 +660,7 @@ def test_invoke_handler_success( handler = mock_scheduler.call_later.call_args_list[-1][0][0] # Execute the handler to trigger the invocation logic - asyncio.run(handler()) + handler() # Verify the invocation process was executed mock_invoker.create_invocation_input.assert_called_once_with( @@ -701,7 +696,7 @@ def test_invoke_handler_execution_already_complete( handler = mock_scheduler.call_later.call_args_list[-1][0][0] # Execute the handler to trigger the invocation logic - asyncio.run(handler()) + handler() # Verify store was accessed to check execution status mock_store.load.assert_called_with("test-arn") @@ -747,7 +742,7 @@ def test_invoke_handler_execution_completed_during_invocation( handler = mock_scheduler.call_later.call_args_list[-1][0][0] # Execute the handler to trigger the invocation logic - asyncio.run(handler()) + handler() # Verify the execution was checked for completion assert mock_store.load.call_count >= 2 @@ -784,7 +779,7 @@ def test_invoke_handler_resource_not_found( handler = mock_scheduler.call_later.call_args_list[-1][0][0] # Execute the handler to trigger the invocation logic - asyncio.run(handler()) + handler() # Assert - verify workflow failure was triggered through public API mock_fail.assert_called_once() @@ -825,7 +820,7 @@ def test_invoke_handler_general_exception( handler = mock_scheduler.call_later.call_args_list[-1][0][0] # Execute the handler to trigger the invocation logic - asyncio.run(handler()) + handler() # Assert - verify retry was scheduled through observable behavior assert mock_execution.consecutive_failed_invocation_attempts == 1 @@ -948,9 +943,8 @@ def test_should_fail_execution_when_function_not_found( handler = mock_scheduler.call_later.call_args_list[-1][0][0] # Execute the handler to trigger the invocation logic - import asyncio - asyncio.run(handler()) + handler() # Assert - verify failure was triggered with correct error mock_fail.assert_called_once() @@ -992,9 +986,8 @@ def test_should_fail_execution_when_retries_exhausted( handler = mock_scheduler.call_later.call_args_list[-1][0][0] # Execute the handler to trigger the invocation logic - import asyncio - asyncio.run(handler()) + handler() # Assert - verify failure was triggered when retries exhausted mock_fail.assert_called_once() @@ -1038,7 +1031,7 @@ def test_should_prevent_multiple_workflow_failures_on_complete_execution( with pytest.raises( IllegalStateException, match="Cannot make multiple close workflow decisions" ): - asyncio.run(handler()) + handler() def test_should_retry_invocation_when_under_limit_through_public_api( @@ -1086,9 +1079,8 @@ def test_should_retry_invocation_when_under_limit_through_public_api( # Simulate scheduler executing the initial invocation handler initial_handler = mock_scheduler.call_later.call_args_list[-1][0][0] - import asyncio - asyncio.run(initial_handler()) + initial_handler() # Verify retry was scheduled due to validation error assert mock_scheduler.call_later.call_count == 3 # timeout + initial + retry @@ -1099,7 +1091,7 @@ def test_should_retry_invocation_when_under_limit_through_public_api( retry_delay = retry_call[1]["delay"] # Execute the retry handler to complete the scenario - asyncio.run(retry_handler()) + retry_handler() # Assert - verify final outcome after retry sequence assert ( @@ -1141,7 +1133,7 @@ def test_should_fail_workflow_when_retry_limit_exceeded( handler = mock_scheduler.call_later.call_args_list[-1][0][0] # Execute the handler to trigger the invocation logic - asyncio.run(handler()) + handler() # Assert - verify workflow failed due to retry limit exceeded mock_fail.assert_called_once() @@ -1501,7 +1493,7 @@ def test_should_retry_when_response_has_unexpected_status( handler = mock_scheduler.call_later.call_args_list[-1][0][0] # Execute the handler to trigger the invocation logic - asyncio.run(handler()) + handler() # Assert - verify retry was triggered due to validation error assert mock_execution.consecutive_failed_invocation_attempts == 1 @@ -1547,7 +1539,7 @@ def test_invoke_handler_execution_completed_during_invocation_async( handler = mock_scheduler.call_later.call_args_list[-1][0][0] # Execute the handler to trigger the invocation logic - asyncio.run(handler()) + handler() # Verify the execution was loaded multiple times (before and after invocation) assert mock_store.load.call_count >= 2 @@ -1584,7 +1576,7 @@ def test_invoke_handler_resource_not_found_async( handler = mock_scheduler.call_later.call_args_list[-1][0][0] # Execute the handler to trigger the invocation logic - asyncio.run(handler()) + handler() # Assert - verify workflow failure was triggered through public API mock_fail.assert_called_once() @@ -1636,7 +1628,7 @@ def test_invoke_handler_general_exception_async( handler = mock_scheduler.call_later.call_args_list[-1][0][0] # Execute the handler to trigger the invocation logic - asyncio.run(handler()) + handler() # Assert - verify retry was scheduled through observable behavior assert mock_execution.consecutive_failed_invocation_attempts == 1 diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/runner_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/runner_test.py index 3b81269..b243ae9 100644 --- a/packages/aws-durable-execution-sdk-python-testing/tests/runner_test.py +++ b/packages/aws-durable-execution-sdk-python-testing/tests/runner_test.py @@ -1792,7 +1792,16 @@ def test_cloud_runner_wait_for_callback_success(mock_boto3): "Id": "callback-event-1", "Name": "test-callback", "CallbackStartedDetails": {"CallbackId": "callback-123"}, - } + }, + { + "EventType": "InvocationCompleted", + "EventTimestamp": "2023-01-01T00:00:01Z", + "InvocationCompletedDetails": { + "StartTimestamp": "2023-01-01T00:00:00Z", + "EndTimestamp": "2023-01-01T00:00:01Z", + "RequestId": "request-1", + }, + }, ] } @@ -1834,6 +1843,46 @@ def test_cloud_runner_wait_for_callback_none(mock_boto3): runner.wait_for_callback("test-arn", name="test-callback1", timeout=2) +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_wait_for_callback_waits_for_invocation(mock_boto3): + """Test wait_for_callback waits until the callback creator has settled.""" + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + + callback_started = { + "EventType": "CallbackStarted", + "EventTimestamp": "2023-01-01T00:00:00Z", + "Id": "callback-event-1", + "Name": "test-callback", + "CallbackStartedDetails": {"CallbackId": "callback-123"}, + } + invocation_completed = { + "EventType": "InvocationCompleted", + "EventTimestamp": "2023-01-01T00:00:01Z", + "InvocationCompletedDetails": { + "StartTimestamp": "2023-01-01T00:00:00Z", + "EndTimestamp": "2023-01-01T00:00:01Z", + "RequestId": "request-1", + }, + } + mock_client.get_durable_execution_history.side_effect = [ + {"Events": [callback_started]}, + {"Events": [callback_started, invocation_completed]}, + ] + + runner = DurableFunctionCloudTestRunner( + function_name="test-function", poll_interval=0.01 + ) + callback_id = runner.wait_for_callback("test-arn", name="test-callback", timeout=10) + + assert callback_id == "callback-123" + assert mock_client.get_durable_execution_history.call_count == 2 + + @patch("aws_durable_execution_sdk_python_testing.runner.boto3") def test_cloud_runner_wait_for_callback_success_without_name(mock_boto3): """Test DurableFunctionCloudTestRunner.wait_for_callback success.""" @@ -1852,7 +1901,16 @@ def test_cloud_runner_wait_for_callback_success_without_name(mock_boto3): "Id": "callback-event-1", "Name": "test-callback", "CallbackStartedDetails": {"CallbackId": "callback-123"}, - } + }, + { + "EventType": "InvocationCompleted", + "EventTimestamp": "2023-01-01T00:00:01Z", + "InvocationCompletedDetails": { + "StartTimestamp": "2023-01-01T00:00:00Z", + "EndTimestamp": "2023-01-01T00:00:01Z", + "RequestId": "request-1", + }, + }, ] } @@ -2047,7 +2105,16 @@ def test_cloud_runner_wait_for_callback_client_error_retryable(mock_boto3): "Id": "callback-event-1", "Name": "test-callback", "CallbackStartedDetails": {"CallbackId": "callback-123"}, - } + }, + { + "EventType": "InvocationCompleted", + "EventTimestamp": "2023-01-01T00:00:01Z", + "InvocationCompletedDetails": { + "StartTimestamp": "2023-01-01T00:00:00Z", + "EndTimestamp": "2023-01-01T00:00:01Z", + "RequestId": "request-1", + }, + }, ] }, ] diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/scheduler_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/scheduler_test.py index 0f86726..3e687c0 100644 --- a/packages/aws-durable-execution-sdk-python-testing/tests/scheduler_test.py +++ b/packages/aws-durable-execution-sdk-python-testing/tests/scheduler_test.py @@ -456,24 +456,49 @@ def test_scheduler_cleanup_on_stop(): assert not scheduler.is_started() +def test_scheduler_stop_shuts_down_default_executor_threads(): + """Scheduler stop joins threads created for sync scheduled functions.""" + existing_thread_ids = {thread.ident for thread in threading.enumerate()} + scheduler = Scheduler() + scheduler.start() + + future = scheduler.call_later(lambda: None, delay=0) + future.result(timeout=1.0) + + def new_scheduler_threads(): + return [ + thread + for thread in threading.enumerate() + if thread.ident not in existing_thread_ids + and thread.name.startswith("durable-scheduler") + ] + + assert wait_for_condition(lambda: len(new_scheduler_threads()) >= 1) + + scheduler.stop() + + assert wait_for_condition(lambda: len(new_scheduler_threads()) == 0) + + def test_scheduler_multiple_events(): """Test scheduler with multiple events.""" scheduler = Scheduler() scheduler.start() - event1 = scheduler.create_event() - event2 = scheduler.create_event() - - assert scheduler.event_count() == 2 + try: + event1 = scheduler.create_event() + event2 = scheduler.create_event() - event1.set() - result1 = event1.wait(timeout=0.01) - assert result1 is True + assert scheduler.event_count() == 2 - result2 = event2.wait(timeout=0.01) - assert result2 is False + event1.set() + result1 = event1.wait(timeout=1.0) + assert result1 is True - scheduler.stop() + result2 = event2.wait(timeout=0.01) + assert result2 is False + finally: + scheduler.stop() def test_task_properties_after_scheduler_stop(): @@ -511,6 +536,31 @@ def test_event_timeout_handling(): scheduler.stop() +def test_event_timeout_when_scheduler_loop_is_blocked(): + """Event wait timeout is enforced even if the scheduler loop is blocked.""" + scheduler = Scheduler() + scheduler.start() + + event = scheduler.create_event() + blocker_started = threading.Event() + + async def block_scheduler_loop(): + blocker_started.set() + time.sleep(0.2) + + scheduler.call_later(block_scheduler_loop, delay=0) + assert blocker_started.wait(timeout=1.0) + + start_time = time.time() + result = event.wait(timeout=0.02) + elapsed = time.time() - start_time + + assert result is False + assert elapsed < 0.15 + + scheduler.stop() + + def test_scheduler_call_later_zero_delay(): """Test call_later with zero delay.""" scheduler = Scheduler() diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/stores/sqlite_store_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/stores/sqlite_store_test.py index 7c7feb4..800a1f3 100644 --- a/packages/aws-durable-execution-sdk-python-testing/tests/stores/sqlite_store_test.py +++ b/packages/aws-durable-execution-sdk-python-testing/tests/stores/sqlite_store_test.py @@ -2,6 +2,7 @@ import tempfile import time +from contextlib import closing from datetime import datetime, UTC from pathlib import Path @@ -664,23 +665,24 @@ def test_sqlite_execution_store_corrupted_data_handling(store, temp_db_path): import sqlite3 # Insert corrupted JSON data directly - with sqlite3.connect(temp_db_path) as conn: - conn.execute( - """ - INSERT INTO executions - (durable_execution_arn, function_name, execution_name, status, start_timestamp, end_timestamp, data) - VALUES (?, ?, ?, ?, ?, ?, ?) - """, - ( - "corrupted-arn", - "test-function", - "test-execution", - "RUNNING", - 1234567890.0, - None, - "invalid json data {{{", - ), - ) + with closing(sqlite3.connect(temp_db_path)) as conn: + with conn: + conn.execute( + """ + INSERT INTO executions + (durable_execution_arn, function_name, execution_name, status, start_timestamp, end_timestamp, data) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + "corrupted-arn", + "test-function", + "test-execution", + "RUNNING", + 1234567890.0, + None, + "invalid json data {{{", + ), + ) # Loading corrupted data should raise ValueError with pytest.raises(ValueError, match="Corrupted execution data"): @@ -830,23 +832,24 @@ def test_sqlite_execution_store_query_with_corrupted_data_warning( store.save(execution) # Insert corrupted JSON data directly - with sqlite3.connect(temp_db_path) as conn: - conn.execute( - """ - INSERT INTO executions - (durable_execution_arn, function_name, execution_name, status, start_timestamp, end_timestamp, data) - VALUES (?, ?, ?, ?, ?, ?, ?) - """, - ( - "corrupted-arn-2", - "test-function", - "test-execution", - "RUNNING", - 1234567890.0, - None, - "invalid json data {{{", - ), - ) + with closing(sqlite3.connect(temp_db_path)) as conn: + with conn: + conn.execute( + """ + INSERT INTO executions + (durable_execution_arn, function_name, execution_name, status, start_timestamp, end_timestamp, data) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + "corrupted-arn-2", + "test-function", + "test-execution", + "RUNNING", + 1234567890.0, + None, + "invalid json data {{{", + ), + ) # Query should skip corrupted records and print warning executions, _ = store.query()