Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class TimerScheduler:
"""Manage timed suspend tasks with a background timer thread."""

def __init__(
self, resubmit_callback: Callable[[ExecutableWithState], None]
self, resubmit_callback: Callable[[list[ExecutableWithState]], None]
) -> None:
self.resubmit_callback = resubmit_callback
self._pending_resumes: list[tuple[float, int, ExecutableWithState]] = []
Expand Down Expand Up @@ -114,18 +114,31 @@ def _timer_loop(self) -> None:

current_time = time.time()
if current_time >= next_resume_time:
# Time to resume
# Drain every due resume under the lock, transitioning each to
# PENDING atomically with the pop. Keeping pop+reset_to_pending
# together is required: should_execution_suspend reads branch
# status without this lock, so an item that is removed from the
# heap but still SUSPENDED_WITH_TIMEOUT could trigger a spurious
# parent suspend.
ready: list[ExecutableWithState] = []
with self._lock:
# no branch cover because hard to test reliably - this is a double-safety check if heap mutated
# since the first peek on next_resume_time further up
if ( # pragma: no branch
while (
self._pending_resumes
and self._pending_resumes[0][0] <= current_time
):
_, _, exe_state = heapq.heappop(self._pending_resumes)
if exe_state.can_resume:
exe_state.reset_to_pending()
self.resubmit_callback(exe_state)
ready.append(exe_state)
# Resubmit outside the lock. Only the heap pop and the PENDING
# transition need the lock. The checkpoint refresh is a blocking
# network call and the submit hands work to the pool, so running
# them off the lock keeps timed resumes from serializing behind
# the network round trip and keeps the timer thread from
# re-entering this non-reentrant lock when a submitted future
# completes inline and its done-callback calls schedule_resume.
if ready:
self.resubmit_callback(ready)
else:
# Wait until next resume time
wait_time = min(next_resume_time - current_time, 0.1)
Expand Down Expand Up @@ -169,6 +182,7 @@ def __init__(
# Event-driven state tracking for when the executor is done
self._completion_event = threading.Event()
self._suspend_exception: SuspendExecution | None = None
self._resume_error: Exception | None = None

# ExecutionCounters will keep track of completion criteria and on-going counters
min_successful = self.completion_config.min_successful or len(self.executables)
Expand Down Expand Up @@ -222,11 +236,32 @@ def execute(
]
self._completion_event.clear()
self._suspend_exception = None

def resubmitter(executable_with_state: ExecutableWithState) -> None:
"""Resubmit a timed suspended task."""
execution_state.create_checkpoint()
submit_task(executable_with_state)
self._resume_error = None

def resubmitter(ready: list[ExecutableWithState]) -> None:
"""Resubmit a wave of timed-suspended tasks.

One checkpoint refresh serves the whole due wave: the fetch returns
all operations, so every resumed branch reads fresh state. The
refresh only raises when the background checkpoint subsystem has
failed, which is terminal for the whole execution, so record the
error and wake the parent to re-raise it. Catching here keeps the
single timer thread alive so a failure does not strand the other
pending resumes.
"""
try:
execution_state.create_checkpoint()
except Exception as exc: # noqa: BLE001
# resubmitter runs only on the single timer thread, so this
# check-then-set needs no lock. First error wins: keep the
# earliest failure if several waves fail before execute() reads
# it (they are the same terminal checkpoint failure anyway).
if self._resume_error is None: # pragma: no branch
self._resume_error = exc
self._completion_event.set()
return
for executable_with_state in ready:
submit_task(executable_with_state)

thread_executor = ThreadPoolExecutor(max_workers=max_workers)
try:
Expand Down Expand Up @@ -259,6 +294,12 @@ def on_done(future: Future) -> None:
for future in futures:
future.cancel()

# A timed resume failed to refresh state (terminal checkpoint
# subsystem failure). Re-raise so the invocation fails and the
# backend retries from the last durable checkpoint.
if self._resume_error is not None:
raise self._resume_error

# Suspend execution if everything done and at least one of the tasks raised a suspend exception.
if self._suspend_exception:
raise self._suspend_exception
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def __init__(
):
self.durable_execution_arn: str = durable_execution_arn
self._current_checkpoint_token: str = initial_checkpoint_token
self.operations: MutableMapping[str, Operation] = operations
self._operations: dict[str, Operation] = dict(operations)
self._service_client: DurableServiceClient = service_client
self._plugin_executor: PluginExecutor = plugin_executor
self._ordered_checkpoint_lock: OrderedLock = OrderedLock()
Expand Down Expand Up @@ -279,6 +279,16 @@ def __init__(
self._replay_status_lock: Lock = Lock()
self._visited_operations: set[str] = set()

@property
def operations(self) -> dict[str, Operation]:
"""Return a point-in-time snapshot copy of the operations map.

The returned dict is a copy, so mutating it does not affect execution
state and iterating it is safe against concurrent updates.
"""
with self._operations_lock:
return dict(self._operations)

def fetch_paginated_operations(
self,
initial_operations: list[Operation],
Expand Down Expand Up @@ -324,7 +334,7 @@ def fetch_paginated_operations(
# Always store whatever operations we successfully fetched
if all_operations:
with self._operations_lock:
self.operations.update(
self._operations.update(
{op.operation_id: op for op in all_operations}
)
return all_operations
Expand All @@ -341,7 +351,8 @@ def get_input_payload(self) -> str | None:
def get_execution_operation(self) -> Operation | None:
# invocation id is id of execution operation
invocation_id = self.durable_execution_arn.split("/")[-1]
candidate = self.operations.get(invocation_id)
with self._operations_lock:
candidate = self._operations.get(invocation_id)
if not candidate:
# Due to payload size limitations we may have an empty operations list.
# This will only happen when loading the initial page of results and is
Expand Down Expand Up @@ -370,19 +381,21 @@ def track_replay(self, operation_id: str) -> None:
with self._replay_status_lock:
if self._replay_status == ReplayStatus.REPLAY:
self._visited_operations.add(operation_id)
completed_ops = {
op_id
for op_id, op in self.operations.items()
if op.operation_type != OperationType.EXECUTION
and op.status
in {
OperationStatus.SUCCEEDED,
OperationStatus.FAILED,
OperationStatus.CANCELLED,
OperationStatus.STOPPED,
OperationStatus.TIMED_OUT,
# Lock order: _replay_status_lock then _operations_lock.
with self._operations_lock:
completed_ops = {
op_id
for op_id, op in self._operations.items()
if op.operation_type != OperationType.EXECUTION
and op.status
in {
OperationStatus.SUCCEEDED,
OperationStatus.FAILED,
OperationStatus.CANCELLED,
OperationStatus.STOPPED,
OperationStatus.TIMED_OUT,
}
}
}
if completed_ops.issubset(self._visited_operations):
logger.debug(
"Transitioning from REPLAY to NEW status at operation %s",
Expand All @@ -404,7 +417,7 @@ def mark_replaying_if_prior_operations_exist(self) -> None:
with self._operations_lock:
has_prior_operations: bool = any(
op.operation_type is not OperationType.EXECUTION
for op in self.operations.values()
for op in self._operations.values()
)

if has_prior_operations:
Expand All @@ -431,7 +444,7 @@ def get_checkpoint_result(self, checkpoint_id: str) -> CheckpointedResult:
"""
# checking status are deliberately under a lighter non-serialized lock
with self._operations_lock:
if checkpoint := self.operations.get(checkpoint_id):
if checkpoint := self._operations.get(checkpoint_id):
return CheckpointedResult.create_from_operation(checkpoint)

return CHECKPOINT_NOT_FOUND
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1370,6 +1370,74 @@ def execute_item(self, child_context, executable):
executor.execute(execution_state, executor_context)


def test_concurrent_executor_resume_checkpoint_failure_propagates():
"""A resume-time checkpoint refresh failure propagates out of execute().

Regression guard: the timer resubmit does a blocking checkpoint refresh.
That refresh only raises when the checkpoint subsystem has failed, which
is terminal. execute() must re-raise it (so the invocation fails and the
backend retries from the last durable checkpoint) rather than leave the
wave PENDING forever - the completion wait has no timeout, so a stranded
PENDING branch would hang the whole map.
"""

class TestExecutor(ConcurrentExecutor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.calls: dict[int, int] = {}
self.long_runner_release = threading.Event()

def execute_item(self, child_context, executable):
task_id = executable.index
self.calls[task_id] = self.calls.get(task_id, 0) + 1
if task_id == 0:
# Long-runner keeps the map alive so task 1 resumes in-process.
self.long_runner_release.wait(timeout=5)
return "result_A"
# Task 1 suspends with a past timestamp -> immediate in-process resume.
msg = "resume-me"
raise TimedSuspendExecution(msg, time.time() - 1)

executables = [Executable(0, lambda: "task_A"), Executable(1, lambda: "task_B")]
completion_config = CompletionConfig(
min_successful=2,
tolerated_failure_count=None,
tolerated_failure_percentage=None,
)

executor = TestExecutor(
executables=executables,
max_concurrency=2,
completion_config=completion_config,
sub_type_top="TOP",
sub_type_iteration="ITER",
name_prefix="test_",
serdes=None,
)

execution_state = Mock()

def checkpoint(*args, **kwargs):
# The resume refresh calls create_checkpoint() with no arguments.
# Fail that call; leave the branches' own checkpoints as no-ops.
if not args and not kwargs:
msg = "resume refresh failed"
raise RuntimeError(msg)

execution_state.create_checkpoint = Mock(side_effect=checkpoint)

executor_context = Mock()
executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa: SLF001
child_context = Mock()
child_context.state.wrap_user_function = lambda func, *args, **kwargs: func
executor_context.create_child_context = lambda *args, **kwargs: child_context

# Must re-raise (not hang): the resume failure surfaces as the original error.
with pytest.raises(RuntimeError, match="resume refresh failed"):
executor.execute(execution_state, executor_context)
executor.long_runner_release.set()


def test_concurrent_executor_with_timed_resubmit_while_other_task_running():
"""Test timed resubmission while other tasks are still running."""

Expand Down Expand Up @@ -3200,7 +3268,9 @@ def test_timer_scheduler_fifo_ordering_with_same_timestamp():
items synchronously, so callback order is deterministic.
"""
results = []
resubmit_callback = Mock(side_effect=lambda exe: results.append(exe.index))
resubmit_callback = Mock(
side_effect=lambda batch: results.extend(exe.index for exe in batch)
)

with TimerScheduler(resubmit_callback) as scheduler:
# Use a past timestamp so they trigger immediately
Expand Down
88 changes: 86 additions & 2 deletions packages/aws-durable-execution-sdk-python/tests/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1397,7 +1397,7 @@ def test_concurrent_access_to_operations_dictionary():
operation_type=OperationType.STEP,
status=OperationStatus.SUCCEEDED,
)
state.operations["op1"] = operation
state._operations["op1"] = operation

results = []
errors = []
Expand All @@ -1422,7 +1422,7 @@ def writer_thread():
status=OperationStatus.SUCCEEDED,
)
with state._operations_lock:
state.operations[f"op{i}"] = new_op
state._operations[f"op{i}"] = new_op
time.sleep(0.001)
except Exception as e:
errors.append(e)
Expand Down Expand Up @@ -4260,3 +4260,87 @@ def test_plugin_executor_not_called_for_pending_operations():


# endregion Plugin Executor Integration Tests


def _make_execution_state_for_operations(
mock_lambda_client, *, replay_status=ReplayStatus.NEW, operations=None
):
return ExecutionState(
durable_execution_arn="test_arn",
initial_checkpoint_token="token123", # noqa: S106
operations=operations or {},
service_client=mock_lambda_client,
plugin_executor=PluginExecutor(plugins=None),
replay_status=replay_status,
)


def test_operations_property_returns_snapshot_copy():
"""The operations property exposes a copy; mutating it must not affect state."""
mock_lambda_client = Mock(spec=LambdaClient)
op = Operation(
operation_id="op1",
operation_type=OperationType.STEP,
status=OperationStatus.SUCCEEDED,
)
state = _make_execution_state_for_operations(
mock_lambda_client, operations={"op1": op}
)

snapshot = state.operations
assert snapshot == {"op1": op}

snapshot["op2"] = op # mutating the returned copy must not leak into state
assert "op2" not in state.operations
assert len(state.operations) == 1


def test_track_replay_iteration_safe_under_concurrent_update():
"""track_replay must not raise when operations are updated concurrently.

A worker thread iterates operations inside track_replay while the checkpoint
path updates the same map. Without consistent locking this raises
"dictionary changed size during iteration".
"""
mock_lambda_client = Mock(spec=LambdaClient)
state = _make_execution_state_for_operations(
mock_lambda_client, replay_status=ReplayStatus.REPLAY
)
# Seed completed operations so track_replay keeps iterating (stays REPLAY).
for i in range(50):
state._operations[f"seed{i}"] = Operation(
operation_id=f"seed{i}",
operation_type=OperationType.STEP,
status=OperationStatus.SUCCEEDED,
)

errors: list[Exception] = []
stop = threading.Event()

def writer():
i = 0
while not stop.is_set():
with state._operations_lock:
state._operations[f"w{i}"] = Operation(
operation_id=f"w{i}",
operation_type=OperationType.STEP,
status=OperationStatus.SUCCEEDED,
)
i += 1

def reader():
try:
for _ in range(2000):
state.track_replay(operation_id="probe")
except Exception as e: # noqa: BLE001
errors.append(e)

writer_t = threading.Thread(target=writer, daemon=True)
reader_t = threading.Thread(target=reader, daemon=True)
writer_t.start()
reader_t.start()
reader_t.join(timeout=30)
stop.set()
writer_t.join(timeout=5)

assert not errors, f"track_replay raced with concurrent update: {errors}"