diff --git a/src/a2a/server/agent_execution/active_task.py b/src/a2a/server/agent_execution/active_task.py index b0154c8d6..db45ecaf3 100644 --- a/src/a2a/server/agent_execution/active_task.py +++ b/src/a2a/server/agent_execution/active_task.py @@ -729,6 +729,43 @@ async def cancel(self, call_context: ServerCallContext) -> Task: raise RuntimeError('Task should have been created') return task + async def aclose(self) -> None: + """Force-closes the task's queues and drains its background tasks. + + Provides a bounded, public teardown for the producer and consumer + ``asyncio.Task``s spawned in ``start()``. Without it, a producer + wedged in its ``finally`` closing an abandoned subscriber sink can + survive until event-loop shutdown and surface as + ``Task was destroyed but it is pending!``. + + Always forces: the queues are closed with ``immediate=True`` and the + background tasks are cancelled, so teardown is bounded even when a + subscriber sink was never drained. It is safe to call multiple times. + """ + await self._event_queue_agent.close(immediate=True) + await self._event_queue_subscribers.close(immediate=True) + # Set `_is_finished` and collect the background tasks under `_lock` so + # this is mutually exclusive with `start()`, which refuses to spawn + # once `_is_finished` is set. The lock is released before awaiting the + # tasks, because their teardown re-acquires it. + async with self._lock: + self._is_finished.set() + background_tasks = [ + task + for task in (self._producer_task, self._consumer_task) + if task is not None + ] + for task in background_tasks: + task.cancel() + if background_tasks: + results = await asyncio.gather( + *background_tasks, return_exceptions=True + ) + for result in results: + # CancelledError is a BaseException, so it is excluded here. + if isinstance(result, Exception): + logger.error('Error during aclose', exc_info=result) + async def _maybe_cleanup(self) -> None: """Triggers cleanup if task is finished and has no subscribers. diff --git a/src/a2a/server/agent_execution/active_task_registry.py b/src/a2a/server/agent_execution/active_task_registry.py index 9c1299ab3..ac37575d9 100644 --- a/src/a2a/server/agent_execution/active_task_registry.py +++ b/src/a2a/server/agent_execution/active_task_registry.py @@ -34,6 +34,7 @@ def __init__( self._active_tasks: dict[str, ActiveTask] = {} self._lock = asyncio.Lock() self._cleanup_tasks: set[asyncio.Task[None]] = set() + self._closed = False async def get_or_create( self, @@ -44,6 +45,8 @@ async def get_or_create( ) -> ActiveTask: """Retrieves an existing ActiveTask or creates a new one.""" async with self._lock: + if self._closed: + raise RuntimeError('ActiveTaskRegistry is closed') if task_id in self._active_tasks: return self._active_tasks[task_id] @@ -86,3 +89,39 @@ async def get(self, task_id: str) -> ActiveTask | None: """Retrieves an existing task.""" async with self._lock: return self._active_tasks.get(task_id) + + async def aclose(self) -> None: + """Closes the registry and drains all active tasks. + + Marks the registry closed so ``get_or_create`` refuses new work, then + force-closes every registered ``ActiveTask`` and awaits the in-flight + ``_remove_task`` cleanup tasks they schedule, so no SDK-owned + ``asyncio.Task`` is left pending at event-loop shutdown. Safe to call + multiple times. + + The close flag is set and the active-task snapshot is taken under + ``_lock``, and the lock is then released before awaiting, because + ``_remove_task`` re-acquires ``_lock``; holding it while draining + would deadlock. Marking closed under the same lock prevents a + concurrent ``get_or_create`` from registering a task that the drain + would miss. + """ + async with self._lock: + self._closed = True + active_tasks = list(self._active_tasks.values()) + + if active_tasks: + results = await asyncio.gather( + *(task.aclose() for task in active_tasks), + return_exceptions=True, + ) + for result in results: + if isinstance(result, Exception): + logger.error('Error draining active task', exc_info=result) + + cleanup_tasks = list(self._cleanup_tasks) + if cleanup_tasks: + await asyncio.gather(*cleanup_tasks, return_exceptions=True) + + async with self._lock: + self._active_tasks.clear() diff --git a/src/a2a/server/request_handlers/default_request_handler_v2.py b/src/a2a/server/request_handlers/default_request_handler_v2.py index 30304609a..00d7dc195 100644 --- a/src/a2a/server/request_handlers/default_request_handler_v2.py +++ b/src/a2a/server/request_handlers/default_request_handler_v2.py @@ -112,6 +112,15 @@ def __init__( # noqa: PLR0913 ) self._background_tasks = set() + async def aclose(self) -> None: + """Shuts down the handler, draining all active tasks. + + Drains the ``ActiveTaskRegistry`` so a server shutdown leaves no + pending ``asyncio.Task``. Intended to be wired into an ASGI + ``lifespan`` / ``on_shutdown`` hook. Safe to call multiple times. + """ + await self._active_task_registry.aclose() + @validate_request_params async def on_get_task( # noqa: D102 self, diff --git a/tests/server/agent_execution/test_active_task.py b/tests/server/agent_execution/test_active_task.py index ce9e2c068..2bc471e9c 100644 --- a/tests/server/agent_execution/test_active_task.py +++ b/tests/server/agent_execution/test_active_task.py @@ -895,3 +895,90 @@ async def execute_mock(req, q): assert len(events) == 0 await active_task.cancel(request_context) + + +@pytest.mark.timeout(5) +@pytest.mark.asyncio +async def test_active_task_aclose_reaps_background_tasks(): + """aclose() drains a live producer and consumer.""" + agent_executor = Mock() + task_manager = Mock() + request_context = Mock(spec=RequestContext) + + active_task = ActiveTask( + agent_executor=agent_executor, + task_id='test-task-id', + task_manager=task_manager, + push_sender=Mock(), + ) + + async def slow_execute(req, q): + await asyncio.sleep(10) + + agent_executor.execute = AsyncMock(side_effect=slow_execute) + task_manager.get_task = AsyncMock( + return_value=Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ) + + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + await active_task.aclose() + + assert active_task._producer_task is not None + assert active_task._producer_task.done() + assert active_task._consumer_task is not None + assert active_task._consumer_task.done() + assert active_task._is_finished.is_set() + + +@pytest.mark.timeout(5) +@pytest.mark.asyncio +async def test_active_task_aclose_force_closes_undrained_subscriber(): + """aclose() unblocks past an undrained subscriber sink. + + Reproduces issue #1101: a graceful close(immediate=False) would block + forever on the leaked sink's join(). + """ + agent_executor = Mock() + task_manager = Mock() + request_context = Mock(spec=RequestContext) + + active_task = ActiveTask( + agent_executor=agent_executor, + task_id='test-task-id', + task_manager=task_manager, + push_sender=Mock(), + ) + + async def slow_execute(req, q): + await asyncio.sleep(10) + + agent_executor.execute = AsyncMock(side_effect=slow_execute) + task_manager.get_task = AsyncMock( + return_value=Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ) + + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + # Leak a subscriber sink and push an event into it without draining it. + leaked = await active_task._event_queue_subscribers.tap() + await active_task._event_queue_subscribers.enqueue_event(Message()) + await asyncio.sleep(0.05) + + await active_task.aclose() + + assert active_task._producer_task is not None + assert active_task._producer_task.done() + assert leaked.is_closed() diff --git a/tests/server/agent_execution/test_active_task_registry.py b/tests/server/agent_execution/test_active_task_registry.py new file mode 100644 index 000000000..16d9c8797 --- /dev/null +++ b/tests/server/agent_execution/test_active_task_registry.py @@ -0,0 +1,107 @@ +import asyncio +import logging + +from unittest.mock import AsyncMock + +import pytest + +from a2a.server.agent_execution.active_task_registry import ActiveTaskRegistry +from a2a.server.agent_execution.agent_executor import AgentExecutor +from a2a.server.agent_execution.context import RequestContext +from a2a.server.context import ServerCallContext +from a2a.server.events.event_queue_v2 import EventQueue +from a2a.server.tasks import InMemoryTaskStore + + +class _SlowExecutor(AgentExecutor): + """An executor whose execute() blocks until cancelled.""" + + async def execute( + self, context: RequestContext, event_queue: EventQueue + ) -> None: + await asyncio.sleep(10) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ) -> None: + return None + + +def _make_registry() -> ActiveTaskRegistry: + return ActiveTaskRegistry( + agent_executor=_SlowExecutor(), + task_store=InMemoryTaskStore(), + ) + + +@pytest.mark.timeout(5) +@pytest.mark.asyncio +async def test_aclose_reaps_active_tasks_and_empties_registry(): + """aclose() reaps background tasks and removes them.""" + registry = _make_registry() + active = await registry.get_or_create( + 'task-1', + call_context=ServerCallContext(), + create_task_if_missing=True, + ) + + await registry.aclose() + + assert active._producer_task is not None + assert active._producer_task.done() + assert active._consumer_task is not None + assert active._consumer_task.done() + assert await registry.get('task-1') is None + + +@pytest.mark.timeout(5) +@pytest.mark.asyncio +async def test_aclose_is_idempotent(): + """Calling aclose() repeatedly is a safe no-op.""" + registry = _make_registry() + await registry.get_or_create( + 'task-1', + call_context=ServerCallContext(), + create_task_if_missing=True, + ) + + await registry.aclose() + await registry.aclose() + + +@pytest.mark.timeout(5) +@pytest.mark.asyncio +async def test_aclose_on_empty_registry(): + """aclose() with no active tasks returns immediately.""" + registry = _make_registry() + await registry.aclose() + + +@pytest.mark.timeout(5) +@pytest.mark.asyncio +async def test_get_or_create_rejected_after_aclose(): + """A closed registry refuses to create new tasks (no orphan race).""" + registry = _make_registry() + await registry.aclose() + + with pytest.raises(RuntimeError): + await registry.get_or_create( + 'task-1', + call_context=ServerCallContext(), + create_task_if_missing=True, + ) + + +@pytest.mark.timeout(5) +@pytest.mark.asyncio +async def test_aclose_logs_and_swallows_task_errors(caplog): + """A failing ActiveTask.aclose is logged, not propagated.""" + registry = _make_registry() + failing = AsyncMock() + failing.aclose = AsyncMock(side_effect=ValueError('boom')) + registry._active_tasks['bad'] = failing + + with caplog.at_level(logging.ERROR): + await registry.aclose() + + assert 'Error draining active task' in caplog.text diff --git a/tests/server/request_handlers/test_default_request_handler_v2.py b/tests/server/request_handlers/test_default_request_handler_v2.py index caaa4f88e..a90320473 100644 --- a/tests/server/request_handlers/test_default_request_handler_v2.py +++ b/tests/server/request_handlers/test_default_request_handler_v2.py @@ -1557,3 +1557,37 @@ async def test_on_get_task_push_notification_config_is_owner_scoped(): ), _ctx('bob'), ) + + +@pytest.mark.timeout(5) +@pytest.mark.asyncio +async def test_aclose_drains_registry(): + """aclose() drains the active-task registry on shutdown.""" + handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=InMemoryTaskStore(), + agent_card=create_default_agent_card(), + ) + await handler._active_task_registry.get_or_create( + 'task-1', + call_context=ServerCallContext(user=UnauthenticatedUser()), + create_task_if_missing=True, + ) + + await handler.aclose() + + assert await handler._active_task_registry.get('task-1') is None + + +@pytest.mark.timeout(5) +@pytest.mark.asyncio +async def test_aclose_is_idempotent_and_handles_empty(): + """aclose() is safe with no active tasks and when called twice.""" + handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=InMemoryTaskStore(), + agent_card=create_default_agent_card(), + ) + + await handler.aclose() + await handler.aclose()