Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,15 @@
from test.conftest import deserialize_operation_payload

from aws_durable_execution_sdk_python.execution import InvocationStatus
from aws_durable_execution_sdk_python.lambda_service import ErrorObject
from aws_durable_execution_sdk_python.lambda_service import ErrorObject, OperationStatus


def get_callback_status(result, name: str) -> OperationStatus:
    """Return the status of the single operation in ``result`` named ``name``.

    Args:
        result: Test run result exposing ``get_all_operations()``; each returned
            operation is read for its ``name`` and ``status`` attributes.
        name: Exact operation name to look up.

    Returns:
        The ``status`` attribute of the one matching operation.

    Raises:
        AssertionError: If the number of operations named ``name`` is not
            exactly one.
    """
    callbacks = [
        operation for operation in result.get_all_operations() if operation.name == name
    ]
    # Include the name and the count so a failing test says what was wrong
    # (missing operation vs. duplicated name) instead of a bare AssertionError.
    assert len(callbacks) == 1, (
        f"expected exactly one operation named {name!r}, found {len(callbacks)}"
    )
    return callbacks[0].status


@pytest.mark.example
Expand Down Expand Up @@ -67,3 +75,16 @@ def test_with_retry_callback_fails_twice_then_succeeds(durable_runner):
"success": True,
"result": "approval granted",
}

assert (
get_callback_status(result, "external-call-attempt-1 create callback id")
is OperationStatus.FAILED
)
assert (
get_callback_status(result, "external-call-attempt-2 create callback id")
is OperationStatus.FAILED
)
assert (
get_callback_status(result, "external-call-attempt-3 create callback id")
is OperationStatus.SUCCEEDED
)
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@


if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from collections.abc import Callable
from concurrent.futures import Future

from aws_durable_execution_sdk_python_testing.checkpoint.processor import (
Expand Down Expand Up @@ -779,10 +779,10 @@ def _validate_invocation_response_and_store(
)
raise IllegalStateException(msg_unexpected_status)

def _invoke_handler(self, execution_arn: str) -> Callable[[], Awaitable[None]]:
def _invoke_handler(self, execution_arn: str) -> Callable[[], None]:
"""Create a parameterless callable that captures execution arn for the scheduler."""

async def invoke() -> None:
def invoke() -> None:
execution: Execution = self._store.load(execution_arn)

# Early exit if execution is already completed - like Java's COMPLETED check
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,16 @@ def create_operation(
return operation_class.from_svc_operation(svc_operation, all_operations)


def _is_callback_ready_for_response(
events: list[Event], callback_started_event: Event
) -> bool:
"""Return True once the invocation that created the callback has settled."""
for event in events[events.index(callback_started_event) + 1 :]:
if event.event_type == "InvocationCompleted":
return True
return False


def _get_callback_id_from_events(
events: list[Event], name: str | None = None
) -> str | None:
Expand All @@ -409,7 +419,8 @@ def _get_callback_id_from_events(
name: Optional callback name to search for. If not provided, returns the latest callback.

Returns:
The callback ID string for a non-completed callback, or None if not found.
The callback ID string for a non-completed callback whose creating
invocation has completed, or None if not found.

Raises:
DurableFunctionsTestError: If the named callback has already succeeded/failed/timed out.
Expand All @@ -436,6 +447,8 @@ def _get_callback_id_from_events(
raise DurableFunctionsTestError(
f"Callback {name} has already completed (succeeded/failed/timed out)"
)
if not _is_callback_ready_for_response(events, event):
return None
return (
event.callback_started_details.callback_id
if event.callback_started_details
Expand All @@ -448,6 +461,7 @@ def _get_callback_id_from_events(
event
for event in callback_started_events
if event.event_id not in completed_callback_ids
and _is_callback_ready_for_response(events, event)
]

if not active_callbacks:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
from __future__ import annotations

import asyncio
import inspect
import itertools
import logging
import threading
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FutureTimeoutError
from typing import TYPE_CHECKING, Any


Expand Down Expand Up @@ -64,6 +67,10 @@ class Scheduler:

def __init__(self) -> None:
self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
self._thread_executor = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="durable-scheduler"
)
self._loop.set_default_executor(self._thread_executor)
self._ready_event: threading.Event = threading.Event()
self._thread: threading.Thread = threading.Thread(
target=self._start_loop, daemon=True
Expand Down Expand Up @@ -99,7 +106,8 @@ def stop(self):
return

self._running = False
self._loop.call_soon_threadsafe(self._cleanup_and_stop)
future = asyncio.run_coroutine_threadsafe(self._cleanup_and_stop(), self._loop)
future.result()
self._thread.join()

def is_started(self) -> bool:
Expand All @@ -116,16 +124,24 @@ def task_count(self) -> int:
return 0
return len(asyncio.all_tasks(self._loop))

def _cleanup_and_stop(self):
"""Cancel all tasks and clear all events. Stop the event-loop."""
# Cancel all tasks
for task in asyncio.all_tasks(self._loop):
async def _cleanup_and_stop(self):
"""Cancel all tasks, clear events, and stop the event loop."""
current_task = asyncio.current_task(self._loop)
tasks = [
task for task in asyncio.all_tasks(self._loop) if task is not current_task
]

for task in tasks:
task.cancel()

# Clear events (don't set them)
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)

await self._loop.shutdown_default_executor()

self._events.clear()

self._loop.stop()
self._loop.call_soon(self._loop.stop)

def _start_loop(self):
"""Initialize the event-loop. The ready event notifies that the loop is started."""
Expand All @@ -146,8 +162,8 @@ def call_later(
) -> Future[Any]:
"""Call func after the delay.

If func is async it runs inside a thread-safe coroutine. If func is sync it runs in its own
threadpool, so it won't block the event loop.
If func is async it runs inside a thread-safe coroutine. If func is sync it runs in the
scheduler worker thread, so it won't block the event loop.

Args:
func (Callable[[], Any]): The function to call later. This can be an async or a standard
Expand All @@ -170,7 +186,7 @@ async def delayed_func() -> Any:
await asyncio.sleep(delay)

try:
if asyncio.iscoroutinefunction(func):
if inspect.iscoroutinefunction(func):
result = await func()
else:
result = await asyncio.to_thread(func)
Expand Down Expand Up @@ -215,13 +231,22 @@ def wait_for_event(
if event not in self._events:
return False

async def wait_for_event_with_timeout() -> bool:
return await asyncio.wait_for(event.wait(), timeout)

future: Future[bool] = asyncio.run_coroutine_threadsafe(
asyncio.wait_for(event.wait(), timeout), self._loop
wait_for_event_with_timeout(), self._loop
)

try:
return future.result()
except TimeoutError:
if timeout is None:
return future.result()
# Enforce the timeout from the waiting thread too. If the scheduler
# loop is blocked, the inner asyncio.wait_for timeout cannot fire.
margin = min(1.0, max(0.01, timeout * 0.1))
return future.result(timeout=timeout + margin)
except (TimeoutError, FutureTimeoutError):
future.cancel()
return False

def set_event(self, event: asyncio.Event):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import json
import sqlite3
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Any, cast
Expand Down Expand Up @@ -43,10 +45,20 @@ def _get_connection(self) -> sqlite3.Connection:
conn.execute("PRAGMA synchronous=NORMAL;")
return conn

@contextmanager
def _connection(self) -> Iterator[sqlite3.Connection]:
"""Open a connection, preserve transaction handling, and always close it."""
conn = self._get_connection()
try:
with conn:
yield conn
finally:
conn.close()

def _init_db(self) -> None:
"""Initialize database schema."""
try:
with self._get_connection() as conn:
with self._connection() as conn:
conn.execute("""
CREATE TABLE IF NOT EXISTS executions (
durable_execution_arn TEXT PRIMARY KEY,
Expand Down Expand Up @@ -80,7 +92,7 @@ def save(self, execution: Execution) -> None:
execution_op = execution.get_operation_execution_started()
status: str = execution.current_status().value

with self._get_connection() as conn:
with self._connection() as conn:
conn.execute(
"""
INSERT OR REPLACE INTO executions
Expand Down Expand Up @@ -111,7 +123,7 @@ def save(self, execution: Execution) -> None:
def load(self, execution_arn: str) -> Execution:
"""Load execution from SQLite."""
try:
with self._get_connection() as conn:
with self._connection() as conn:
cursor: sqlite3.Cursor = conn.execute(
"SELECT data FROM executions WHERE durable_execution_arn = ?",
(execution_arn,),
Expand Down Expand Up @@ -204,7 +216,7 @@ def query(
)
params_with_limit = params

with self._get_connection() as conn:
with self._connection() as conn:
# Get total count for pagination
total_count: int = int(conn.execute(count_query, params).fetchone()[0])

Expand Down Expand Up @@ -247,7 +259,7 @@ def list_all(self) -> list[Execution]:
def get_execution_metadata(self, execution_arn: str) -> dict[str, Any] | None:
"""Get just the metadata without full deserialization for performance."""
try:
with self._get_connection() as conn:
with self._connection() as conn:
cursor: sqlite3.Cursor = conn.execute(
"SELECT function_name, execution_name, status, start_timestamp, end_timestamp FROM executions WHERE durable_execution_arn = ?",
(execution_arn,),
Expand Down
Loading