From f50ee9b2a656c2b0fd1870c1567fb263a3890b4c Mon Sep 17 00:00:00 2001 From: Asher Fink Date: Wed, 27 May 2026 14:25:05 -0400 Subject: [PATCH] feat(AGX1-274): register tasks in FGAC authz graph on create/delete --- agentex/src/domain/services/task_service.py | 84 +++++- agentex/tests/fixtures/services.py | 27 +- .../fixtures/integration_client.py | 3 + agentex/tests/integration/test_task_stream.py | 28 +- .../tests/integration/use_cases/__init__.py | 0 .../use_cases/test_task_fgac_dual_write.py | 216 ++++++++++++++ .../tests/unit/services/test_task_service.py | 3 + ...p_type_backwards_compatibility_use_case.py | 16 +- .../use_cases/test_agents_acp_use_case.py | 266 +++++++++--------- 9 files changed, 478 insertions(+), 165 deletions(-) create mode 100644 agentex/tests/integration/use_cases/__init__.py create mode 100644 agentex/tests/integration/use_cases/test_task_fgac_dual_write.py diff --git a/agentex/src/domain/services/task_service.py b/agentex/src/domain/services/task_service.py index 013c6903..3cc335d3 100644 --- a/agentex/src/domain/services/task_service.py +++ b/agentex/src/domain/services/task_service.py @@ -3,7 +3,9 @@ from fastapi import Depends +from src.adapters.crud_store.exceptions import ItemDoesNotExist from src.adapters.streams.adapter_redis import DRedisStreamRepository +from src.api.schemas.authorization_types import AgentexResource from src.domain.entities.agents import ACPType, AgentEntity from src.domain.entities.events import EventEntity from src.domain.entities.task_message_updates import TaskMessageUpdateEntity @@ -14,6 +16,7 @@ from src.domain.repositories.task_repository import DTaskRepository from src.domain.repositories.task_state_repository import DTaskStateRepository from src.domain.services.agent_acp_service import DAgentACPService +from src.domain.services.authorization_service import DAuthorizationService from src.utils.ids import orm_id from src.utils.logging import make_logger from src.utils.stream_topics import get_task_event_stream_topic @@ -33,12 +36,14 @@ def __init__( task_repository: DTaskRepository, event_repository: DEventRepository, stream_repository: DRedisStreamRepository, + authorization_service: DAuthorizationService, ): self.acp_client = acp_client self.task_state_repository = task_state_repository self.task_repository = task_repository self.event_repository = event_repository self.stream_repository = stream_repository + self.authorization_service = authorization_service async def create_task( self, @@ -59,19 +64,37 @@ async def create_task( Returns: Task containing the created task info """ - - task_entity = await self.task_repository.create( - agent_id=agent.id, - task=TaskEntity( - id=orm_id(), - name=task_name, - status=TaskStatus.RUNNING, - status_reason="Task created, forwarding to ACP server", - params=task_params, - task_metadata=task_metadata, - ), + # AGX1-274: register the task in the FGAC authorization graph BEFORE + # persisting, mirroring egp-api-backend's create flows (e.g. + # KnowledgeBaseV2UseCase). ``register_resource`` is a no-op unless the + # per-account ``fgac-tasks-dual-write`` flag is enabled in agentex-auth, + # so it is always called here and gated there. Registering first means a + # registration failure aborts the request with no orphaned Postgres row; + # if the persist then fails, a compensating ``deregister_resource`` + # keeps the authz graph from holding a tuple for a task that never + # existed. + task_entity = TaskEntity( + id=orm_id(), + name=task_name, + status=TaskStatus.RUNNING, + status_reason="Task created, forwarding to ACP server", + params=task_params, + task_metadata=task_metadata, + ) + await self.authorization_service.register_resource( + AgentexResource.task(task_entity.id), + parent=AgentexResource.agent(agent.id), ) - return task_entity + try: + return await self.task_repository.create( + agent_id=agent.id, + task=task_entity, + ) + except Exception: + await self.authorization_service.deregister_resource( + AgentexResource.task(task_entity.id), + ) + raise async def create_task_and_forward_to_acp( self, @@ -91,7 +114,9 @@ async def create_task_and_forward_to_acp( Task containing the created task info """ task_entity = await self.create_task( - agent=agent, task_name=task_name, task_params=task_params + agent=agent, + task_name=task_name, + task_params=task_params, ) if agent.acp_type == ACPType.SYNC: @@ -214,8 +239,41 @@ async def delete_task(self, id: str | None = None, name: str | None = None) -> N """ Delete a task from the repository. """ + # AGX1-274: deregister the task from the FGAC authorization graph after + # the Postgres row is gone, mirroring egp-api-backend's delete flows + # (e.g. KnowledgeBaseV2UseCase). Postgres is the source of truth for + # existence, so we delete first and deregister best-effort: a deregister + # failure is logged and swallowed (it leaves an orphan tuple for a row + # that is already gone, invisible to reads) rather than failing a delete + # that already succeeded. ``deregister_resource`` is a no-op unless the + # per-account ``fgac-tasks-dual-write`` flag is enabled in agentex-auth. + # + # Resolve the id before the delete so we can deregister by id; looking + # it up by name afterwards would race. If the name doesn't resolve, + # swallow ItemDoesNotExist and let delete() surface its own native error + # so the missing-task error contract is unchanged. + task_id_for_deregister: str | None = id + if task_id_for_deregister is None and name is not None: + try: + task = await self.task_repository.get(name=name) + task_id_for_deregister = task.id + except ItemDoesNotExist: + task_id_for_deregister = None + await self.task_repository.delete(id=id, name=name) + if task_id_for_deregister is not None: + try: + await self.authorization_service.deregister_resource( + AgentexResource.task(task_id_for_deregister), + ) + except Exception: + logger.exception( + "task FGAC deregister failed for task %s; the Postgres row " + "is already deleted, leaving an orphan authz tuple", + task_id_for_deregister, + ) + async def list_tasks( self, *, diff --git a/agentex/tests/fixtures/services.py b/agentex/tests/fixtures/services.py index c30c06c8..8969f250 100644 --- a/agentex/tests/fixtures/services.py +++ b/agentex/tests/fixtures/services.py @@ -3,7 +3,7 @@ Provides factory functions and specific fixtures for creating services with test repositories. """ -from unittest.mock import MagicMock, Mock +from unittest.mock import AsyncMock, MagicMock, Mock import pytest @@ -12,6 +12,24 @@ # ============================================================================= +def make_noop_authorization_service() -> Mock: + """Shared noop AuthorizationService mock for tests that don't exercise authz. + + ``principal_context`` is ``None``, and + ``grant``/``revoke``/``register_resource``/``deregister_resource`` are async + no-ops returning ``None`` — matching the real service signature. Use this + anywhere a test just needs to construct ``AgentTaskService`` without caring + about authorization behavior. + """ + svc = Mock() + svc.principal_context = None + svc.grant = AsyncMock(return_value=None) + svc.revoke = AsyncMock(return_value=None) + svc.register_resource = AsyncMock(return_value=None) + svc.deregister_resource = AsyncMock(return_value=None) + return svc + + def create_task_message_service(task_message_repository): """Factory function to create TaskMessageService with given repository""" from src.domain.services.task_message_service import TaskMessageService @@ -52,16 +70,21 @@ def create_task_service( event_repository, agent_acp_service, redis_stream_repository, + authorization_service=None, ): - """Factory function to create AgentTaskService with given repositories and services""" + """Factory function to create AgentTaskService with given repositories and services.""" from src.domain.services.task_service import AgentTaskService + if authorization_service is None: + authorization_service = make_noop_authorization_service() + return AgentTaskService( task_repository=task_repository, task_state_repository=task_state_repository, event_repository=event_repository, acp_client=agent_acp_service, stream_repository=redis_stream_repository, + authorization_service=authorization_service, ) diff --git a/agentex/tests/integration/fixtures/integration_client.py b/agentex/tests/integration/fixtures/integration_client.py index d07fc4a1..b7c2c76a 100644 --- a/agentex/tests/integration/fixtures/integration_client.py +++ b/agentex/tests/integration/fixtures/integration_client.py @@ -22,6 +22,8 @@ from src.config.dependencies import GlobalDependencies from src.config.environment_variables import EnvironmentVariables +from tests.fixtures.services import make_noop_authorization_service + @pytest.fixture(scope="session") def event_loop(): @@ -455,6 +457,7 @@ async def send_message(self, *args, **kwargs): task_repository=isolated_repositories["task_repository"], event_repository=isolated_repositories["event_repository"], stream_repository=isolated_repositories["redis_stream_repository"], + authorization_service=make_noop_authorization_service(), ) return TasksUseCase(task_service=task_service) diff --git a/agentex/tests/integration/test_task_stream.py b/agentex/tests/integration/test_task_stream.py index 289010ee..f2a0d762 100644 --- a/agentex/tests/integration/test_task_stream.py +++ b/agentex/tests/integration/test_task_stream.py @@ -7,6 +7,8 @@ from src.domain.use_cases.tasks_use_case import TasksUseCase from src.utils.ids import orm_id +from tests.fixtures.services import make_noop_authorization_service + @pytest.mark.asyncio @pytest.mark.integration @@ -76,6 +78,7 @@ async def send_message(self, *args, **kwargs): task_repository=isolated_repositories["task_repository"], event_repository=isolated_repositories["event_repository"], stream_repository=isolated_repositories["redis_stream_repository"], + authorization_service=make_noop_authorization_service(), ) return TasksUseCase(task_service=task_service) @@ -103,6 +106,7 @@ async def send_message(self, *args, **kwargs): task_repository=isolated_repositories["task_repository"], event_repository=isolated_repositories["event_repository"], stream_repository=isolated_repositories["redis_stream_repository"], + authorization_service=make_noop_authorization_service(), ) environment_variables = EnvironmentVariables.refresh() @@ -194,17 +198,17 @@ async def collect_stream_events(): pass # Then - Verify the stream event was received - assert ( - len(stream_events) >= 1 - ), f"Expected at least 1 stream event, got {len(stream_events)}" + assert len(stream_events) >= 1, ( + f"Expected at least 1 stream event, got {len(stream_events)}" + ) # Find the task_updated event task_updated_events = [ e for e in stream_events if e.get("type") == "task_updated" ] - assert ( - len(task_updated_events) >= 1 - ), f"Expected task_updated event, got events: {[e.get('type') for e in stream_events]}" + assert len(task_updated_events) >= 1, ( + f"Expected task_updated event, got events: {[e.get('type') for e in stream_events]}" + ) task_updated_event = task_updated_events[0] @@ -389,9 +393,9 @@ async def collect_stream_events(): task_updated_events = [ e for e in stream_events if e.get("type") == "task_updated" ] - assert ( - len(task_updated_events) >= 3 - ), f"Expected at least 3 task_updated events, got {len(task_updated_events)}" + assert len(task_updated_events) >= 3, ( + f"Expected at least 3 task_updated events, got {len(task_updated_events)}" + ) # Verify each event has the correct metadata for its update versions = [ @@ -599,8 +603,8 @@ async def collect_stream_data(): pass # Then - Verify we received at least 2 pings - assert ( - ping_count >= 2 - ), f"Expected at least 2 ping messages during idle period, got {ping_count}" + assert ping_count >= 2, ( + f"Expected at least 2 ping messages during idle period, got {ping_count}" + ) print(f"✅ Stream sent {ping_count} keepalive pings during idle period") diff --git a/agentex/tests/integration/use_cases/__init__.py b/agentex/tests/integration/use_cases/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/agentex/tests/integration/use_cases/test_task_fgac_dual_write.py b/agentex/tests/integration/use_cases/test_task_fgac_dual_write.py new file mode 100644 index 00000000..0558336b --- /dev/null +++ b/agentex/tests/integration/use_cases/test_task_fgac_dual_write.py @@ -0,0 +1,216 @@ +"""Tests for the AGX1-274 task FGAC dual-write call sites. + +scale-agentex always calls ``register_resource`` / ``deregister_resource``; the +per-account ``fgac-tasks-dual-write`` flag that turns the writes into no-ops +lives in agentex-auth, so there is no flag to toggle here. These assert the +call-site ordering and failure handling: + +- ``create_task`` registers the task (with the agent as ``parent``) *before* + persisting the Postgres row. +- If the persist fails after a successful register, ``create_task`` issues a + compensating ``deregister_resource`` and re-raises. +- ``delete_task`` deregisters *after* the Postgres delete, best-effort: a + deregister failure is swallowed (Postgres is the source of truth for + existence) so a delete that already succeeded does not surface an error. + +The register-before-persist and delete paths run as integration tests because +they touch the real task repository through ``isolated_repositories``; the +compensation test uses a mock repository so the persist can be forced to fail. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, Mock +from uuid import uuid4 + +import pytest +from src.adapters.crud_store.exceptions import ItemDoesNotExist +from src.api.schemas.authorization_types import AgentexResource, AgentexResourceType +from src.domain.entities.agents import ACPType, AgentEntity, AgentStatus +from src.domain.services.task_service import AgentTaskService + + +def _build_service(*, task_repository) -> tuple[AgentTaskService, Mock]: + authorization_service = Mock() + authorization_service.principal_context = None + authorization_service.grant = AsyncMock(return_value={}) + authorization_service.revoke = AsyncMock(return_value=None) + authorization_service.register_resource = AsyncMock(return_value=None) + authorization_service.deregister_resource = AsyncMock(return_value=None) + + service = AgentTaskService( + acp_client=Mock(), + task_state_repository=Mock(), + task_repository=task_repository, + event_repository=Mock(), + stream_repository=Mock(), + authorization_service=authorization_service, + ) + return service, authorization_service + + +def _agent_entity() -> AgentEntity: + return AgentEntity( + id=str(uuid4()), + name=f"dual-write-agent-{uuid4().hex[:8]}", + description="dual-write test agent", + status=AgentStatus.READY, + acp_type=ACPType.SYNC, + acp_url="http://test-acp", + ) + + +async def _persist_agent(agent_repository) -> AgentEntity: + return await agent_repository.create(_agent_entity()) + + +async def _task_exists(task_repository, task_id: str) -> bool: + try: + await task_repository.get(id=task_id) + return True + except ItemDoesNotExist: + return False + + +async def _clear_task_agent_links(task_repository, task_id: str) -> None: + """Delete the task_agents / agent_task_tracker join rows for a task. + + ``create_task`` writes both join rows, and ``task_repository.delete`` + issues a raw ``DELETE FROM tasks`` that the ``task_agents_task_id_fkey`` + FK rejects while those rows exist — this is the established, intentional + contract (see ``test_task_repository.test_delete_task`` and + ``test_task_service.test_delete_task_with_cleanup``). Tests that exercise + the hard-delete deregister path must clear the join rows first, exactly as + a real cascading delete would, otherwise the delete FK-fails before the + deregister dual-write is ever reached. + """ + from sqlalchemy import delete as sql_delete + from src.adapters.orm import AgentTaskTrackerORM, TaskAgentORM + + async with task_repository.start_async_db_session(True) as session: + await session.execute( + sql_delete(AgentTaskTrackerORM).where( + AgentTaskTrackerORM.task_id == task_id + ) + ) + await session.execute( + sql_delete(TaskAgentORM).where(TaskAgentORM.task_id == task_id) + ) + await session.commit() + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestTaskDualWrite: + async def test_create_task_registers_before_persist_with_agent_as_parent( + self, isolated_repositories + ): + task_repo = isolated_repositories["task_repository"] + agent_repo = isolated_repositories["agent_repository"] + agent = await _persist_agent(agent_repo) + service, authorization_service = _build_service(task_repository=task_repo) + + # When register fires, the Postgres row must not exist yet — this is + # what makes a registration failure abort the request cleanly. + observed = {} + + async def _record_existence(resource, parent=None): + observed["row_exists_at_register"] = await _task_exists( + task_repo, resource.selector + ) + + authorization_service.register_resource.side_effect = _record_existence + + task = await service.create_task( + agent=agent, task_name=f"dw-create-{uuid4().hex[:8]}" + ) + + assert observed["row_exists_at_register"] is False + assert await _task_exists(task_repo, task.id) is True + + authorization_service.register_resource.assert_awaited_once() + call = authorization_service.register_resource.call_args + registered_resource: AgentexResource = call.args[0] + assert registered_resource.type == AgentexResourceType.task + assert registered_resource.selector == task.id + + parent: AgentexResource | None = call.kwargs.get("parent") + assert parent is not None + assert parent.type == AgentexResourceType.agent + assert parent.selector == agent.id + + async def test_create_compensates_with_deregister_when_persist_fails(self): + # Register succeeds, then the Postgres persist blows up. create_task + # must deregister the just-registered task (so no orphan authz tuple is + # left for a task that never persisted) and re-raise. + task_repo = Mock() + task_repo.create = AsyncMock(side_effect=RuntimeError("db down")) + service, authorization_service = _build_service(task_repository=task_repo) + + with pytest.raises(RuntimeError): + await service.create_task( + agent=_agent_entity(), task_name=f"dw-fail-{uuid4().hex[:8]}" + ) + + authorization_service.register_resource.assert_awaited_once() + authorization_service.deregister_resource.assert_awaited_once() + registered = authorization_service.register_resource.call_args.args[0] + compensated = authorization_service.deregister_resource.call_args.args[0] + assert compensated.type == AgentexResourceType.task + assert compensated.selector == registered.selector + + async def test_delete_task_deregisters_after_delete(self, isolated_repositories): + task_repo = isolated_repositories["task_repository"] + agent_repo = isolated_repositories["agent_repository"] + agent = await _persist_agent(agent_repo) + service, authorization_service = _build_service(task_repository=task_repo) + task = await service.create_task( + agent=agent, task_name=f"dw-delete-{uuid4().hex[:8]}" + ) + authorization_service.deregister_resource.reset_mock() + await _clear_task_agent_links(task_repo, task.id) + + await service.delete_task(id=task.id) + + assert await _task_exists(task_repo, task.id) is False + authorization_service.deregister_resource.assert_awaited_once() + deregistered: AgentexResource = ( + authorization_service.deregister_resource.call_args.args[0] + ) + assert deregistered.type == AgentexResourceType.task + assert deregistered.selector == task.id + + async def test_delete_task_swallows_deregister_failure(self, isolated_repositories): + # Postgres is the source of truth for existence: a deregister failure + # after a successful delete leaves an orphan tuple (invisible to reads) + # rather than failing a delete that already happened. + task_repo = isolated_repositories["task_repository"] + agent_repo = isolated_repositories["agent_repository"] + agent = await _persist_agent(agent_repo) + service, authorization_service = _build_service(task_repository=task_repo) + task = await service.create_task( + agent=agent, task_name=f"dw-dereg-fail-{uuid4().hex[:8]}" + ) + await _clear_task_agent_links(task_repo, task.id) + authorization_service.deregister_resource.reset_mock() + authorization_service.deregister_resource.side_effect = RuntimeError( + "authz down" + ) + + # Must not raise despite the deregister failure. + await service.delete_task(id=task.id) + + assert await _task_exists(task_repo, task.id) is False + authorization_service.deregister_resource.assert_awaited_once() + + async def test_delete_task_by_missing_name_does_not_deregister( + self, isolated_repositories + ): + # The pre-delete id lookup catches ItemDoesNotExist, so a missing name + # neither deregisters nor changes the underlying delete's error contract. + task_repo = isolated_repositories["task_repository"] + service, authorization_service = _build_service(task_repository=task_repo) + + await service.delete_task(name=f"missing-{uuid4().hex[:8]}") + + authorization_service.deregister_resource.assert_not_awaited() diff --git a/agentex/tests/unit/services/test_task_service.py b/agentex/tests/unit/services/test_task_service.py index eb096eb1..2e1f19c3 100644 --- a/agentex/tests/unit/services/test_task_service.py +++ b/agentex/tests/unit/services/test_task_service.py @@ -19,6 +19,8 @@ from src.domain.repositories.task_state_repository import TaskStateRepository from src.domain.services.task_service import AgentTaskService +from tests.fixtures.services import make_noop_authorization_service + async def create_or_get_agent(agent_repository, agent): """Helper to create agent or get existing one if name already exists""" @@ -84,6 +86,7 @@ def task_service( task_state_repository=task_state_repository, event_repository=event_repository, stream_repository=redis_stream_repository, + authorization_service=make_noop_authorization_service(), ) diff --git a/agentex/tests/unit/use_cases/test_acp_type_backwards_compatibility_use_case.py b/agentex/tests/unit/use_cases/test_acp_type_backwards_compatibility_use_case.py index 60914adc..e6d13949 100644 --- a/agentex/tests/unit/use_cases/test_acp_type_backwards_compatibility_use_case.py +++ b/agentex/tests/unit/use_cases/test_acp_type_backwards_compatibility_use_case.py @@ -23,6 +23,7 @@ from src.domain.services.task_service import AgentTaskService from src.domain.use_cases.agents_acp_use_case import AgentsACPUseCase from src.domain.use_cases.agents_use_case import AgentsUseCase +from tests.fixtures.services import make_noop_authorization_service @pytest.mark.unit @@ -35,9 +36,9 @@ async def test_both_agentic_and_async_have_same_allowed_methods(self): agentic_methods = set(ACP_TYPE_TO_ALLOWED_RPC_METHODS[ACPType.AGENTIC]) async_methods = set(ACP_TYPE_TO_ALLOWED_RPC_METHODS[ACPType.ASYNC]) - assert ( - agentic_methods == async_methods - ), "AGENTIC and ASYNC should have identical allowed RPC methods" + assert agentic_methods == async_methods, ( + "AGENTIC and ASYNC should have identical allowed RPC methods" + ) # Verify they include the expected methods expected_methods = { @@ -95,6 +96,7 @@ async def test_agentic_agent_forwards_task_to_acp(self): task_repository=task_repo, event_repository=event_repo, stream_repository=stream_repo, + authorization_service=make_noop_authorization_service(), ) # Create AGENTIC agent @@ -148,6 +150,7 @@ async def test_sync_agent_does_not_forward_task_to_acp(self): task_repository=task_repo, event_repository=event_repo, stream_repository=stream_repo, + authorization_service=make_noop_authorization_service(), ) # Create SYNC agent @@ -195,6 +198,7 @@ async def test_async_agent_forwards_task_to_acp(self): task_repository=task_repo, event_repository=event_repo, stream_repository=stream_repo, + authorization_service=make_noop_authorization_service(), ) # Create ASYNC agent @@ -355,6 +359,6 @@ async def test_agentic_and_async_agents_both_use_not_sync_logic(self): # Both AGENTIC and ASYNC should pass the same conditional checks agentic_is_not_sync = agentic_agent.acp_type != ACPType.SYNC async_is_not_sync = async_agent.acp_type != ACPType.SYNC - assert ( - agentic_is_not_sync == async_is_not_sync - ), "AGENTIC and ASYNC should have identical behavior in != SYNC checks" + assert agentic_is_not_sync == async_is_not_sync, ( + "AGENTIC and ASYNC should have identical behavior in != SYNC checks" + ) diff --git a/agentex/tests/unit/use_cases/test_agents_acp_use_case.py b/agentex/tests/unit/use_cases/test_agents_acp_use_case.py index b48751a4..abb7af0b 100644 --- a/agentex/tests/unit/use_cases/test_agents_acp_use_case.py +++ b/agentex/tests/unit/use_cases/test_agents_acp_use_case.py @@ -36,6 +36,7 @@ from src.domain.services.task_message_service import TaskMessageService from src.domain.services.task_service import AgentTaskService from src.domain.use_cases.agents_acp_use_case import AgentsACPUseCase +from tests.fixtures.services import make_noop_authorization_service # UTC timezone constant UTC = ZoneInfo("UTC") @@ -135,6 +136,7 @@ def task_service( event_repository=event_repository, acp_client=agent_acp_service, stream_repository=redis_stream_repository, + authorization_service=make_noop_authorization_service(), ) @@ -552,9 +554,9 @@ def create_mock_stream(*args, **kwargs): # Verify database interactions - should have created messages final_message_count = len(await task_message_repository.list()) - assert ( - final_message_count > initial_message_count - ), "New messages should have been created in database" + assert final_message_count > initial_message_count, ( + "New messages should have been created in database" + ) # Get all messages from database to verify content all_messages = await task_message_repository.list() @@ -563,33 +565,33 @@ def create_mock_stream(*args, **kwargs): ] # Only the newly created messages # Should have at least 2 new messages (input + response) - assert ( - len(new_messages) >= 2 - ), f"Expected at least 2 new messages (input + response), got {len(new_messages)}" + assert len(new_messages) >= 2, ( + f"Expected at least 2 new messages (input + response), got {len(new_messages)}" + ) # Verify we have both user input and agent response messages content_authors = {msg.content.author.value for msg in new_messages} assert "user" in content_authors, "Should have user input message in database" - assert ( - "agent" in content_authors - ), "Should have agent response message in database" + assert "agent" in content_authors, ( + "Should have agent response message in database" + ) # Find the agent response message and verify its final accumulated content agent_messages = [ msg for msg in new_messages if msg.content.author == MessageAuthor.AGENT ] - assert ( - len(agent_messages) >= 1 - ), "Should have at least one agent response message" + assert len(agent_messages) >= 1, ( + "Should have at least one agent response message" + ) # Verify the final accumulated content includes both deltas response_message = agent_messages[0] # First agent response - assert ( - "Hello" in response_message.content.content - ), f"Expected 'Hello' in final content, got '{response_message.content.content}'" - assert ( - "world!" in response_message.content.content - ), f"Expected 'world!' in final content, got '{response_message.content.content}'" + assert "Hello" in response_message.content.content, ( + f"Expected 'Hello' in final content, got '{response_message.content.content}'" + ) + assert "world!" in response_message.content.content, ( + f"Expected 'world!' in final content, got '{response_message.content.content}'" + ) async def test_handle_message_send_stream_full_message( self, @@ -662,9 +664,9 @@ def create_mock_stream(*args, **kwargs): # Verify database interactions - should have created messages final_message_count = len(await task_message_repository.list()) - assert ( - final_message_count > initial_message_count - ), "New messages should have been created in database" + assert final_message_count > initial_message_count, ( + "New messages should have been created in database" + ) # Get all messages from database to verify content all_messages = await task_message_repository.list() @@ -673,30 +675,30 @@ def create_mock_stream(*args, **kwargs): ] # Only the newly created messages # Should have at least 2 new messages (input + response) - assert ( - len(new_messages) >= 2 - ), f"Expected at least 2 new messages (input + response), got {len(new_messages)}" + assert len(new_messages) >= 2, ( + f"Expected at least 2 new messages (input + response), got {len(new_messages)}" + ) # Verify we have both user input and agent response messages content_authors = {msg.content.author.value for msg in new_messages} assert "user" in content_authors, "Should have user input message in database" - assert ( - "agent" in content_authors - ), "Should have agent response message in database" + assert "agent" in content_authors, ( + "Should have agent response message in database" + ) # Find the agent response message and verify its content agent_messages = [ msg for msg in new_messages if msg.content.author == MessageAuthor.AGENT ] - assert ( - len(agent_messages) >= 1 - ), "Should have at least one agent response message" + assert len(agent_messages) >= 1, ( + "Should have at least one agent response message" + ) # Verify the FULL message content is correctly stored response_message = agent_messages[0] # First agent response - assert ( - response_message.content.content == "Complete message in one chunk" - ), f"Expected complete message content, got '{response_message.content.content}'" + assert response_message.content.content == "Complete message in one chunk", ( + f"Expected complete message content, got '{response_message.content.content}'" + ) async def test_handle_message_send_stream_multiple_indexes( self, @@ -820,9 +822,9 @@ def create_mock_stream(*args, **kwargs): # Verify database interactions - should have created messages final_message_count = len(await task_message_repository.list()) - assert ( - final_message_count > initial_message_count - ), "New messages should have been created in database" + assert final_message_count > initial_message_count, ( + "New messages should have been created in database" + ) # Get all messages from database to verify content all_messages = await task_message_repository.list() @@ -831,30 +833,30 @@ def create_mock_stream(*args, **kwargs): ] # Only the newly created messages # Should have at least 3 new messages (input + 2 response messages for different indexes) - assert ( - len(new_messages) >= 3 - ), f"Expected at least 3 new messages (input + 2 responses), got {len(new_messages)}" + assert len(new_messages) >= 3, ( + f"Expected at least 3 new messages (input + 2 responses), got {len(new_messages)}" + ) # Verify we have both user input and agent response messages content_authors = {msg.content.author.value for msg in new_messages} assert "user" in content_authors, "Should have user input message in database" - assert ( - "agent" in content_authors - ), "Should have agent response messages in database" + assert "agent" in content_authors, ( + "Should have agent response messages in database" + ) # Find the agent response messages - should have multiple for different indexes agent_messages = [ msg for msg in new_messages if msg.content.author == MessageAuthor.AGENT ] - assert ( - len(agent_messages) >= 2 - ), f"Should have at least 2 agent response messages for different indexes, got {len(agent_messages)}" + assert len(agent_messages) >= 2, ( + f"Should have at least 2 agent response messages for different indexes, got {len(agent_messages)}" + ) # Verify the content includes expected text from both indexes agent_content = " ".join([msg.content.content for msg in agent_messages]) - assert ( - "First" in agent_content or "Second" in agent_content - ), f"Expected content from multiple indexes, got '{agent_content}'" + assert "First" in agent_content or "Second" in agent_content, ( + f"Expected content from multiple indexes, got '{agent_content}'" + ) async def test_handle_task_create_error( self, @@ -1160,9 +1162,9 @@ def create_mock_stream(*args, **kwargs): # Verify database interactions - should have created messages final_message_count = len(await task_message_repository.list()) - assert ( - final_message_count > initial_message_count - ), "New messages should have been created in database" + assert final_message_count > initial_message_count, ( + "New messages should have been created in database" + ) # Get all messages from database to verify content all_messages = await task_message_repository.list() @@ -1171,33 +1173,33 @@ def create_mock_stream(*args, **kwargs): ] # Only the newly created messages # Should have at least 2 new messages (input + response) - assert ( - len(new_messages) >= 2 - ), f"Expected at least 2 new messages (input + response), got {len(new_messages)}" + assert len(new_messages) >= 2, ( + f"Expected at least 2 new messages (input + response), got {len(new_messages)}" + ) # Verify we have both user input and agent response messages content_authors = {msg.content.author.value for msg in new_messages} assert "user" in content_authors, "Should have user input message in database" - assert ( - "agent" in content_authors - ), "Should have agent response message in database" + assert "agent" in content_authors, ( + "Should have agent response message in database" + ) # Find the agent response message and verify accumulated content was flushed agent_messages = [ msg for msg in new_messages if msg.content.author == MessageAuthor.AGENT ] - assert ( - len(agent_messages) >= 1 - ), "Should have at least one agent response message" + assert len(agent_messages) >= 1, ( + "Should have at least one agent response message" + ) # Verify the deltas were properly accumulated and flushed to database response_message = agent_messages[0] # First agent response - assert ( - "Incomplete" in response_message.content.content - ), f"Expected 'Incomplete' in flushed content, got '{response_message.content.content}'" - assert ( - "message" in response_message.content.content - ), f"Expected 'message' in flushed content, got '{response_message.content.content}'" + assert "Incomplete" in response_message.content.content, ( + f"Expected 'Incomplete' in flushed content, got '{response_message.content.content}'" + ) + assert "message" in response_message.content.content, ( + f"Expected 'message' in flushed content, got '{response_message.content.content}'" + ) async def test_handle_message_send_stream_complex_mixed_content_types( self, @@ -1432,9 +1434,9 @@ def create_mock_stream(*args, **kwargs): # Verify database interactions - should have created messages final_message_count = len(await task_message_repository.list()) - assert ( - final_message_count > initial_message_count - ), "New messages should have been created in database" + assert final_message_count > initial_message_count, ( + "New messages should have been created in database" + ) # Get all messages from database to verify content all_messages = await task_message_repository.list() @@ -1443,18 +1445,18 @@ def create_mock_stream(*args, **kwargs): ] # Only the newly created messages # Should have at least 3 new messages (one for each index) plus deltas potentially stored - assert ( - len(new_messages) >= 3 - ), f"Expected at least 3 new messages, got {len(new_messages)}" + assert len(new_messages) >= 3, ( + f"Expected at least 3 new messages, got {len(new_messages)}" + ) # Verify we have messages with different content types content_types_found = {msg.content.type.value for msg in new_messages} expected_types = {"tool_request", "tool_response", "text"} # At least some of the expected types should be present (depends on how deltas vs full messages are stored) - assert ( - len(content_types_found.intersection(expected_types)) > 0 - ), f"Expected some of {expected_types}, got {content_types_found}" + assert len(content_types_found.intersection(expected_types)) > 0, ( + f"Expected some of {expected_types}, got {content_types_found}" + ) # Verify index distribution - should have messages for different indexes indexes_found = {getattr(update, "index", None) for update in updates} @@ -1478,18 +1480,18 @@ def create_mock_stream(*args, **kwargs): u for u in updates if isinstance(u, StreamTaskMessageDoneEntity) ] - assert ( - len(start_updates) == 3 - ), f"Expected 3 START updates, got {len(start_updates)}" - assert ( - len(delta_updates) == 6 - ), f"Expected 6 DELTA updates, got {len(delta_updates)}" - assert ( - len(full_updates) == 1 - ), f"Expected 1 FULL update, got {len(full_updates)}" - assert ( - len(done_updates) == 2 - ), f"Expected 2 DONE updates, got {len(done_updates)} (index 0 completed with FULL message)" + assert len(start_updates) == 3, ( + f"Expected 3 START updates, got {len(start_updates)}" + ) + assert len(delta_updates) == 6, ( + f"Expected 6 DELTA updates, got {len(delta_updates)}" + ) + assert len(full_updates) == 1, ( + f"Expected 1 FULL update, got {len(full_updates)}" + ) + assert len(done_updates) == 2, ( + f"Expected 2 DONE updates, got {len(done_updates)} (index 0 completed with FULL message)" + ) # Verify content types in START messages start_content_types = {update.content.type.value for update in start_updates} @@ -1641,27 +1643,27 @@ async def mock_async_call(*args, **kwargs): # Verify database interactions - should have created an event final_event_count = len(await event_repository.list()) - assert ( - final_event_count > initial_event_count - ), "New event should have been created in database" + assert final_event_count > initial_event_count, ( + "New event should have been created in database" + ) # Get all events from database to verify content all_events = await event_repository.list() new_events = all_events[initial_event_count:] # Only the newly created events # Should have exactly 1 new event - assert ( - len(new_events) == 1 - ), f"Expected exactly 1 new event, got {len(new_events)}" + assert len(new_events) == 1, ( + f"Expected exactly 1 new event, got {len(new_events)}" + ) # Verify the event was properly stored created_event = new_events[0] - assert ( - created_event.task_id == created_task.id - ), f"Expected task_id {created_task.id}, got {created_event.task_id}" - assert ( - created_event.content == sample_text_content - ), "Expected event content to match input" + assert created_event.task_id == created_task.id, ( + f"Expected task_id {created_task.id}, got {created_event.task_id}" + ) + assert created_event.content == sample_text_content, ( + "Expected event content to match input" + ) async def test_handle_event_send_with_task_name( self, @@ -1713,27 +1715,27 @@ async def mock_async_call(*args, **kwargs): # Verify database interactions - should have created an event final_event_count = len(await event_repository.list()) - assert ( - final_event_count > initial_event_count - ), "New event should have been created in database" + assert final_event_count > initial_event_count, ( + "New event should have been created in database" + ) # Get all events from database to verify content all_events = await event_repository.list() new_events = all_events[initial_event_count:] # Only the newly created events # Should have exactly 1 new event - assert ( - len(new_events) == 1 - ), f"Expected exactly 1 new event, got {len(new_events)}" + assert len(new_events) == 1, ( + f"Expected exactly 1 new event, got {len(new_events)}" + ) # Verify the event was properly stored created_event = new_events[0] - assert ( - created_event.task_id == created_task.id - ), f"Expected task_id {created_task.id}, got {created_event.task_id}" - assert ( - created_event.content == sample_text_content - ), "Expected event content to match input" + assert created_event.task_id == created_task.id, ( + f"Expected task_id {created_task.id}, got {created_event.task_id}" + ) + assert created_event.content == sample_text_content, ( + "Expected event content to match input" + ) async def test_handle_event_send_with_request_headers( self, @@ -1811,9 +1813,9 @@ async def mock_async_call(*args, **kwargs): # Verify database interactions - should have created an event final_event_count = len(await event_repository.list()) - assert ( - final_event_count > initial_event_count - ), "New event should have been created in database" + assert final_event_count > initial_event_count, ( + "New event should have been created in database" + ) # Verify HTTP call was made (mock_async_call will assert headers) mock_http_gateway.async_call.assert_called_once() @@ -1871,9 +1873,9 @@ async def mock_async_call(*args, **kwargs): # Verify database interactions - should have created an event final_event_count = len(await event_repository.list()) - assert ( - final_event_count > initial_event_count - ), "New event should have been created in database" + assert final_event_count > initial_event_count, ( + "New event should have been created in database" + ) # Verify HTTP call was made (mock_async_call will assert no headers) mock_http_gateway.async_call.assert_called_once() @@ -2056,9 +2058,9 @@ def create_mock_stream(*args, **kwargs): # Verify database interactions - should have created messages final_message_count = len(await task_message_repository.list()) - assert ( - final_message_count > initial_message_count - ), "New messages should have been created in database" + assert final_message_count > initial_message_count, ( + "New messages should have been created in database" + ) # Get all messages from database to verify content all_messages = await task_message_repository.list() @@ -2067,26 +2069,26 @@ def create_mock_stream(*args, **kwargs): # Verify we have both user input and agent response messages content_authors = {msg.content.author.value for msg in new_messages} assert "user" in content_authors, "Should have user input message in database" - assert ( - "agent" in content_authors - ), "Should have agent response message in database" + assert "agent" in content_authors, ( + "Should have agent response message in database" + ) # Find the agent response message and verify its final accumulated content agent_messages = [ msg for msg in new_messages if msg.content.author == MessageAuthor.AGENT ] - assert ( - len(agent_messages) >= 1 - ), "Should have at least one agent response message" + assert len(agent_messages) >= 1, ( + "Should have at least one agent response message" + ) # Verify the final accumulated content includes both deltas response_message = agent_messages[0] - assert ( - "Stream response" in response_message.content.content - ), f"Expected 'Stream response' in final content, got '{response_message.content.content}'" - assert ( - "to named task" in response_message.content.content - ), f"Expected 'to named task' in final content, got '{response_message.content.content}'" + assert "Stream response" in response_message.content.content, ( + f"Expected 'Stream response' in final content, got '{response_message.content.content}'" + ) + assert "to named task" in response_message.content.content, ( + f"Expected 'to named task' in final content, got '{response_message.content.content}'" + ) async def test_handle_message_send_sync_with_task_params( self,