diff --git a/pyrit/backend/routes/scenarios.py b/pyrit/backend/routes/scenarios.py index 941d8021fb..a83d1e9ac6 100644 --- a/pyrit/backend/routes/scenarios.py +++ b/pyrit/backend/routes/scenarios.py @@ -24,6 +24,7 @@ ) from pyrit.backend.services.scenario_run_service import get_scenario_run_service from pyrit.backend.services.scenario_service import get_scenario_service +from pyrit.models.scenario_result import ScenarioResult router = APIRouter(prefix="/scenarios", tags=["scenarios"]) @@ -194,22 +195,21 @@ async def cancel_scenario_run(scenario_result_id: str) -> ScenarioRunSummary: # @router.get( "/runs/{scenario_result_id}/results", + response_model=ScenarioResult, responses={ 404: {"model": ProblemDetail, "description": "Run not found"}, 409: {"model": ProblemDetail, "description": "Run not yet completed"}, }, ) -async def get_scenario_run_results(scenario_result_id: str) -> dict: # pyrit-async-suffix-exempt +async def get_scenario_run_results(scenario_result_id: str) -> ScenarioResult: # pyrit-async-suffix-exempt """ Get detailed results for a completed scenario run. - Returns the full ScenarioResult serialization. - Args: scenario_result_id: The scenario_result_id. Returns: - dict: ScenarioResult.to_dict() payload. + ScenarioResult: Detailed run results. FastAPI handles JSON serialization. """ service = get_scenario_run_service() try: @@ -222,4 +222,4 @@ async def get_scenario_run_results(scenario_result_id: str) -> dict: # pyrit-as status_code=status.HTTP_404_NOT_FOUND, detail=f"Scenario run '{scenario_result_id}' not found", ) - return result.to_dict() + return result diff --git a/pyrit/cli/_output.py b/pyrit/cli/_output.py index 3580c8e5b6..596fcbd682 100644 --- a/pyrit/cli/_output.py +++ b/pyrit/cli/_output.py @@ -282,12 +282,12 @@ async def print_scenario_result_async(*, result_dict: dict[str, Any]) -> None: Print detailed scenario results using the output module. Args: - result_dict: ``ScenarioResult.to_dict()`` payload from the REST API. + result_dict: ``ScenarioResult.model_dump(mode="json", by_alias=True)`` payload from the REST API. """ from pyrit.models.scenario_result import ScenarioResult from pyrit.output.scenario_result.pretty import PrettyScenarioResultMemoryPrinter - scenario_result = ScenarioResult.from_dict(result_dict) + scenario_result = ScenarioResult.model_validate(result_dict) printer = PrettyScenarioResultMemoryPrinter() await printer.write_async(scenario_result) diff --git a/pyrit/cli/api_client.py b/pyrit/cli/api_client.py index bfd75ca420..937dfad9a4 100644 --- a/pyrit/cli/api_client.py +++ b/pyrit/cli/api_client.py @@ -228,7 +228,7 @@ async def get_scenario_run_results_async(self, *, scenario_result_id: str) -> di Get detailed results for a completed scenario run. Returns: - dict: ``ScenarioResult.to_dict()`` payload. + dict: ``ScenarioResult.model_dump(mode="json", by_alias=True)`` payload. """ return await self._get_json_async(path=f"/api/scenarios/runs/{scenario_result_id}/results") diff --git a/tests/unit/backend/test_scenario_run_routes.py b/tests/unit/backend/test_scenario_run_routes.py index faf40a5b8f..846e91300a 100644 --- a/tests/unit/backend/test_scenario_run_routes.py +++ b/tests/unit/backend/test_scenario_run_routes.py @@ -231,33 +231,37 @@ class TestGetScenarioRunResultsRoute: def test_get_results_returns_200(self, client: TestClient) -> None: """Test that getting results of a completed run returns 200.""" - mock_scenario_result = MagicMock() - mock_scenario_result.to_dict.return_value = { - "id": "result-uuid", - "scenario_identifier": {"name": "foundry.red_team_agent", "version": 1}, - "scenario_run_state": "COMPLETED", - "attack_results": { - "base64_attack": [ - { - "attack_result_id": "ar-1", - "conversation_id": "conv-1", - "objective": "Extract sensitive info", - "outcome": "success", - } - ] - }, - } + from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier + from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult + + attack = AttackResult( + conversation_id="conv-1", + objective="Extract sensitive info", + outcome=AttackOutcome.SUCCESS, + executed_turns=1, + execution_time_ms=100, + timestamp=datetime(2025, 1, 1, tzinfo=timezone.utc), + ) + scenario_result = ScenarioResult( + scenario_identifier=ScenarioIdentifier(name="foundry.red_team_agent", description="Foundry red-team agent"), + objective_target_identifier=ComponentIdentifier.from_dict( + {"__type__": "FakeTarget", "__module__": "test.mod", "params": {}} + ), + objective_scorer_identifier=None, + attack_results={"base64_attack": [attack]}, + scenario_run_state="COMPLETED", + ) with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: mock_service = MagicMock() - mock_service.get_run_results.return_value = mock_scenario_result + mock_service.get_run_results.return_value = scenario_result mock_get.return_value = mock_service response = client.get("/api/scenarios/runs/test-run-id/results") assert response.status_code == status.HTTP_200_OK data = response.json() - assert data["id"] == "result-uuid" + assert data["scenario_identifier"]["name"] == "foundry.red_team_agent" assert "base64_attack" in data["attack_results"] def test_get_results_not_found_returns_404(self, client: TestClient) -> None: diff --git a/tests/unit/cli/test_output.py b/tests/unit/cli/test_output.py index 26f41cccc2..e288480fd0 100644 --- a/tests/unit/cli/test_output.py +++ b/tests/unit/cli/test_output.py @@ -319,23 +319,26 @@ async def test_print_scenario_result_async_uses_pretty_printer(): fake_printer.write_async = AsyncMock() with ( - patch("pyrit.models.scenario_result.ScenarioResult.from_dict", return_value=fake_scenario) as from_dict_mock, + patch( + "pyrit.models.scenario_result.ScenarioResult.model_validate", return_value=fake_scenario + ) as model_validate_mock, patch( "pyrit.output.scenario_result.pretty.PrettyScenarioResultMemoryPrinter", return_value=fake_printer ) as printer_cls, ): await _output.print_scenario_result_async(result_dict=result_dict) - from_dict_mock.assert_called_once_with(result_dict) + model_validate_mock.assert_called_once_with(result_dict) printer_cls.assert_called_once_with() fake_printer.write_async.assert_awaited_once_with(fake_scenario) async def test_print_scenario_result_async_roundtrip_with_real_payload(): """ - Integration smoke test: a real ScenarioResult.to_dict() payload must flow - through ScenarioResult.from_dict() inside print_scenario_result_async - without raising. Locks the REST contract used by the CLI thin client. + Integration smoke test: a real ``ScenarioResult.model_dump(mode="json", by_alias=True)`` + payload must flow through ``ScenarioResult.model_validate(...)`` inside + ``print_scenario_result_async`` without raising. Locks the REST contract used by the CLI + thin client. """ from datetime import datetime, timezone @@ -361,10 +364,10 @@ async def test_print_scenario_result_async_roundtrip_with_real_payload(): attack_results={"strat_a": [attack]}, scenario_run_state="COMPLETED", ) - payload = original.to_dict() + payload = original.model_dump(mode="json", by_alias=True) - # Drive print_scenario_result_async through the real from_dict path; only - # stub the printer to keep the test fast. + # Drive print_scenario_result_async through the real model_validate path; + # only stub the printer to keep the test fast. fake_printer = MagicMock() fake_printer.write_async = AsyncMock() with patch(