Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions src/a2a/server/agent_execution/active_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,43 @@ async def cancel(self, call_context: ServerCallContext) -> Task:
raise RuntimeError('Task should have been created')
return task

async def aclose(self) -> None:
"""Force-closes the task's queues and drains its background tasks.

Provides a bounded, public teardown for the producer and consumer
``asyncio.Task``s spawned in ``start()``. Without it, a producer
wedged in its ``finally`` closing an abandoned subscriber sink can
survive until event-loop shutdown and surface as
``Task was destroyed but it is pending!``.

Always forces: the queues are closed with ``immediate=True`` and the
background tasks are cancelled, so teardown is bounded even when a
subscriber sink was never drained. It is safe to call multiple times.
"""
await self._event_queue_agent.close(immediate=True)
await self._event_queue_subscribers.close(immediate=True)
Comment on lines +745 to +746

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If start() was never called or failed early, the background tasks are never spawned, meaning their finally blocks (which shut down self._request_queue) will not run. Explicitly shutting down self._request_queue in aclose() ensures that all queues are properly cleaned up and any pending operations on the request queue are unblocked.

Suggested change
await self._event_queue_agent.close(immediate=True)
await self._event_queue_subscribers.close(immediate=True)
await self._event_queue_agent.close(immediate=True)
await self._event_queue_subscribers.close(immediate=True)
self._request_queue.shutdown(immediate=True)

# Set `_is_finished` and collect the background tasks under `_lock` so
# this is mutually exclusive with `start()`, which refuses to spawn
# once `_is_finished` is set. The lock is released before awaiting the
# tasks, because their teardown re-acquires it.
async with self._lock:
self._is_finished.set()
background_tasks = [
task
for task in (self._producer_task, self._consumer_task)
if task is not None
]
for task in background_tasks:
task.cancel()
if background_tasks:
results = await asyncio.gather(
*background_tasks, return_exceptions=True
)
for result in results:
# CancelledError is a BaseException, so it is excluded here.
if isinstance(result, Exception):
logger.error('Error during aclose', exc_info=result)

async def _maybe_cleanup(self) -> None:
"""Triggers cleanup if task is finished and has no subscribers.

Expand Down
39 changes: 39 additions & 0 deletions src/a2a/server/agent_execution/active_task_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
self._active_tasks: dict[str, ActiveTask] = {}
self._lock = asyncio.Lock()
self._cleanup_tasks: set[asyncio.Task[None]] = set()
self._closed = False

async def get_or_create(
self,
Expand All @@ -44,6 +45,8 @@ async def get_or_create(
) -> ActiveTask:
"""Retrieves an existing ActiveTask or creates a new one."""
async with self._lock:
if self._closed:
raise RuntimeError('ActiveTaskRegistry is closed')
if task_id in self._active_tasks:
return self._active_tasks[task_id]

Expand Down Expand Up @@ -86,3 +89,39 @@ async def get(self, task_id: str) -> ActiveTask | None:
"""Retrieves an existing task."""
async with self._lock:
return self._active_tasks.get(task_id)

async def aclose(self) -> None:
"""Closes the registry and drains all active tasks.

Marks the registry closed so ``get_or_create`` refuses new work, then
force-closes every registered ``ActiveTask`` and awaits the in-flight
``_remove_task`` cleanup tasks they schedule, so no SDK-owned
``asyncio.Task`` is left pending at event-loop shutdown. Safe to call
multiple times.

The close flag is set and the active-task snapshot is taken under
``_lock``, and the lock is then released before awaiting, because
``_remove_task`` re-acquires ``_lock``; holding it while draining
would deadlock. Marking closed under the same lock prevents a
concurrent ``get_or_create`` from registering a task that the drain
would miss.
"""
async with self._lock:
self._closed = True
active_tasks = list(self._active_tasks.values())

if active_tasks:
results = await asyncio.gather(
*(task.aclose() for task in active_tasks),
return_exceptions=True,
)
for result in results:
if isinstance(result, Exception):
logger.error('Error draining active task', exc_info=result)

cleanup_tasks = list(self._cleanup_tasks)
if cleanup_tasks:
await asyncio.gather(*cleanup_tasks, return_exceptions=True)

async with self._lock:
self._active_tasks.clear()
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,15 @@ def __init__( # noqa: PLR0913
)
self._background_tasks = set()

async def aclose(self) -> None:
"""Shuts down the handler, draining all active tasks.

Drains the ``ActiveTaskRegistry`` so a server shutdown leaves no
pending ``asyncio.Task``. Intended to be wired into an ASGI
``lifespan`` / ``on_shutdown`` hook. Safe to call multiple times.
"""
await self._active_task_registry.aclose()

@validate_request_params
async def on_get_task( # noqa: D102
self,
Expand Down
87 changes: 87 additions & 0 deletions tests/server/agent_execution/test_active_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,3 +895,90 @@ async def execute_mock(req, q):
assert len(events) == 0

await active_task.cancel(request_context)


@pytest.mark.timeout(5)
@pytest.mark.asyncio
async def test_active_task_aclose_reaps_background_tasks():
"""aclose() drains a live producer and consumer."""
agent_executor = Mock()
task_manager = Mock()
request_context = Mock(spec=RequestContext)

active_task = ActiveTask(
agent_executor=agent_executor,
task_id='test-task-id',
task_manager=task_manager,
push_sender=Mock(),
)

async def slow_execute(req, q):
await asyncio.sleep(10)

agent_executor.execute = AsyncMock(side_effect=slow_execute)
task_manager.get_task = AsyncMock(
return_value=Task(
id='test-task-id',
status=TaskStatus(state=TaskState.TASK_STATE_WORKING),
)
)

await active_task.enqueue_request(request_context)
await active_task.start(
call_context=ServerCallContext(), create_task_if_missing=True
)

await active_task.aclose()

assert active_task._producer_task is not None
assert active_task._producer_task.done()
assert active_task._consumer_task is not None
assert active_task._consumer_task.done()
assert active_task._is_finished.is_set()


@pytest.mark.timeout(5)
@pytest.mark.asyncio
async def test_active_task_aclose_force_closes_undrained_subscriber():
"""aclose() unblocks past an undrained subscriber sink.

Reproduces issue #1101: a graceful close(immediate=False) would block
forever on the leaked sink's join().
"""
agent_executor = Mock()
task_manager = Mock()
request_context = Mock(spec=RequestContext)

active_task = ActiveTask(
agent_executor=agent_executor,
task_id='test-task-id',
task_manager=task_manager,
push_sender=Mock(),
)

async def slow_execute(req, q):
await asyncio.sleep(10)

agent_executor.execute = AsyncMock(side_effect=slow_execute)
task_manager.get_task = AsyncMock(
return_value=Task(
id='test-task-id',
status=TaskStatus(state=TaskState.TASK_STATE_WORKING),
)
)

await active_task.enqueue_request(request_context)
await active_task.start(
call_context=ServerCallContext(), create_task_if_missing=True
)

# Leak a subscriber sink and push an event into it without draining it.
leaked = await active_task._event_queue_subscribers.tap()
await active_task._event_queue_subscribers.enqueue_event(Message())
await asyncio.sleep(0.05)

await active_task.aclose()

assert active_task._producer_task is not None
assert active_task._producer_task.done()
assert leaked.is_closed()
107 changes: 107 additions & 0 deletions tests/server/agent_execution/test_active_task_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import asyncio
import logging

from unittest.mock import AsyncMock

import pytest

from a2a.server.agent_execution.active_task_registry import ActiveTaskRegistry
from a2a.server.agent_execution.agent_executor import AgentExecutor
from a2a.server.agent_execution.context import RequestContext
from a2a.server.context import ServerCallContext
from a2a.server.events.event_queue_v2 import EventQueue
from a2a.server.tasks import InMemoryTaskStore


class _SlowExecutor(AgentExecutor):
"""An executor whose execute() blocks until cancelled."""

async def execute(
self, context: RequestContext, event_queue: EventQueue
) -> None:
await asyncio.sleep(10)

async def cancel(
self, context: RequestContext, event_queue: EventQueue
) -> None:
return None


def _make_registry() -> ActiveTaskRegistry:
return ActiveTaskRegistry(
agent_executor=_SlowExecutor(),
task_store=InMemoryTaskStore(),
)


@pytest.mark.timeout(5)
@pytest.mark.asyncio
async def test_aclose_reaps_active_tasks_and_empties_registry():
"""aclose() reaps background tasks and removes them."""
registry = _make_registry()
active = await registry.get_or_create(
'task-1',
call_context=ServerCallContext(),
create_task_if_missing=True,
)

await registry.aclose()

assert active._producer_task is not None
assert active._producer_task.done()
assert active._consumer_task is not None
assert active._consumer_task.done()
assert await registry.get('task-1') is None


@pytest.mark.timeout(5)
@pytest.mark.asyncio
async def test_aclose_is_idempotent():
"""Calling aclose() repeatedly is a safe no-op."""
registry = _make_registry()
await registry.get_or_create(
'task-1',
call_context=ServerCallContext(),
create_task_if_missing=True,
)

await registry.aclose()
await registry.aclose()


@pytest.mark.timeout(5)
@pytest.mark.asyncio
async def test_aclose_on_empty_registry():
"""aclose() with no active tasks returns immediately."""
registry = _make_registry()
await registry.aclose()


@pytest.mark.timeout(5)
@pytest.mark.asyncio
async def test_get_or_create_rejected_after_aclose():
"""A closed registry refuses to create new tasks (no orphan race)."""
registry = _make_registry()
await registry.aclose()

with pytest.raises(RuntimeError):
await registry.get_or_create(
'task-1',
call_context=ServerCallContext(),
create_task_if_missing=True,
)


@pytest.mark.timeout(5)
@pytest.mark.asyncio
async def test_aclose_logs_and_swallows_task_errors(caplog):
"""A failing ActiveTask.aclose is logged, not propagated."""
registry = _make_registry()
failing = AsyncMock()
failing.aclose = AsyncMock(side_effect=ValueError('boom'))
registry._active_tasks['bad'] = failing

with caplog.at_level(logging.ERROR):
await registry.aclose()

assert 'Error draining active task' in caplog.text
Original file line number Diff line number Diff line change
Expand Up @@ -1557,3 +1557,37 @@ async def test_on_get_task_push_notification_config_is_owner_scoped():
),
_ctx('bob'),
)


@pytest.mark.timeout(5)
@pytest.mark.asyncio
async def test_aclose_drains_registry():
"""aclose() drains the active-task registry on shutdown."""
handler = DefaultRequestHandlerV2(
agent_executor=MockAgentExecutor(),
task_store=InMemoryTaskStore(),
agent_card=create_default_agent_card(),
)
await handler._active_task_registry.get_or_create(
'task-1',
call_context=ServerCallContext(user=UnauthenticatedUser()),
create_task_if_missing=True,
)

await handler.aclose()

assert await handler._active_task_registry.get('task-1') is None


@pytest.mark.timeout(5)
@pytest.mark.asyncio
async def test_aclose_is_idempotent_and_handles_empty():
"""aclose() is safe with no active tasks and when called twice."""
handler = DefaultRequestHandlerV2(
agent_executor=MockAgentExecutor(),
task_store=InMemoryTaskStore(),
agent_card=create_default_agent_card(),
)

await handler.aclose()
await handler.aclose()
Loading