Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions pyrit/backend/routes/scenarios.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from pyrit.backend.services.scenario_run_service import get_scenario_run_service
from pyrit.backend.services.scenario_service import get_scenario_service
from pyrit.models.scenario_result import ScenarioResult

router = APIRouter(prefix="/scenarios", tags=["scenarios"])

Expand Down Expand Up @@ -194,22 +195,21 @@ async def cancel_scenario_run(scenario_result_id: str) -> ScenarioRunSummary: #

@router.get(
"/runs/{scenario_result_id}/results",
response_model=ScenarioResult,
responses={
404: {"model": ProblemDetail, "description": "Run not found"},
409: {"model": ProblemDetail, "description": "Run not yet completed"},
},
)
async def get_scenario_run_results(scenario_result_id: str) -> dict: # pyrit-async-suffix-exempt
async def get_scenario_run_results(scenario_result_id: str) -> ScenarioResult: # pyrit-async-suffix-exempt
"""
Get detailed results for a completed scenario run.

Returns the full ScenarioResult serialization.

Args:
scenario_result_id: The scenario_result_id.

Returns:
dict: ScenarioResult.to_dict() payload.
ScenarioResult: Detailed run results. FastAPI handles JSON serialization.
"""
service = get_scenario_run_service()
try:
Expand All @@ -222,4 +222,4 @@ async def get_scenario_run_results(scenario_result_id: str) -> dict: # pyrit-as
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Scenario run '{scenario_result_id}' not found",
)
return result.to_dict()
return result
4 changes: 2 additions & 2 deletions pyrit/cli/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,12 +282,12 @@ async def print_scenario_result_async(*, result_dict: dict[str, Any]) -> None:
Print detailed scenario results using the output module.

Args:
result_dict: ``ScenarioResult.to_dict()`` payload from the REST API.
result_dict: ``ScenarioResult.model_dump(mode="json", by_alias=True)`` payload from the REST API.
"""
from pyrit.models.scenario_result import ScenarioResult
from pyrit.output.scenario_result.pretty import PrettyScenarioResultMemoryPrinter

scenario_result = ScenarioResult.from_dict(result_dict)
scenario_result = ScenarioResult.model_validate(result_dict)
printer = PrettyScenarioResultMemoryPrinter()
await printer.write_async(scenario_result)

Expand Down
2 changes: 1 addition & 1 deletion pyrit/cli/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ async def get_scenario_run_results_async(self, *, scenario_result_id: str) -> di
Get detailed results for a completed scenario run.
Returns:
dict: ``ScenarioResult.to_dict()`` payload.
dict: ``ScenarioResult.model_dump(mode="json", by_alias=True)`` payload.
"""
return await self._get_json_async(path=f"/api/scenarios/runs/{scenario_result_id}/results")

Expand Down
40 changes: 22 additions & 18 deletions tests/unit/backend/test_scenario_run_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,33 +231,37 @@ class TestGetScenarioRunResultsRoute:

def test_get_results_returns_200(self, client: TestClient) -> None:
"""Test that getting results of a completed run returns 200."""
mock_scenario_result = MagicMock()
mock_scenario_result.to_dict.return_value = {
"id": "result-uuid",
"scenario_identifier": {"name": "foundry.red_team_agent", "version": 1},
"scenario_run_state": "COMPLETED",
"attack_results": {
"base64_attack": [
{
"attack_result_id": "ar-1",
"conversation_id": "conv-1",
"objective": "Extract sensitive info",
"outcome": "success",
}
]
},
}
from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier
from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult

attack = AttackResult(
conversation_id="conv-1",
objective="Extract sensitive info",
outcome=AttackOutcome.SUCCESS,
executed_turns=1,
execution_time_ms=100,
timestamp=datetime(2025, 1, 1, tzinfo=timezone.utc),
)
scenario_result = ScenarioResult(
scenario_identifier=ScenarioIdentifier(name="foundry.red_team_agent", description="Foundry red-team agent"),
objective_target_identifier=ComponentIdentifier.from_dict(
{"__type__": "FakeTarget", "__module__": "test.mod", "params": {}}
),
objective_scorer_identifier=None,
attack_results={"base64_attack": [attack]},
scenario_run_state="COMPLETED",
)

with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get:
mock_service = MagicMock()
mock_service.get_run_results.return_value = mock_scenario_result
mock_service.get_run_results.return_value = scenario_result
mock_get.return_value = mock_service

response = client.get("/api/scenarios/runs/test-run-id/results")

assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["id"] == "result-uuid"
assert data["scenario_identifier"]["name"] == "foundry.red_team_agent"
assert "base64_attack" in data["attack_results"]

def test_get_results_not_found_returns_404(self, client: TestClient) -> None:
Expand Down
19 changes: 11 additions & 8 deletions tests/unit/cli/test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,23 +319,26 @@ async def test_print_scenario_result_async_uses_pretty_printer():
fake_printer.write_async = AsyncMock()

with (
patch("pyrit.models.scenario_result.ScenarioResult.from_dict", return_value=fake_scenario) as from_dict_mock,
patch(
"pyrit.models.scenario_result.ScenarioResult.model_validate", return_value=fake_scenario
) as model_validate_mock,
patch(
"pyrit.output.scenario_result.pretty.PrettyScenarioResultMemoryPrinter", return_value=fake_printer
) as printer_cls,
):
await _output.print_scenario_result_async(result_dict=result_dict)

from_dict_mock.assert_called_once_with(result_dict)
model_validate_mock.assert_called_once_with(result_dict)
printer_cls.assert_called_once_with()
fake_printer.write_async.assert_awaited_once_with(fake_scenario)


async def test_print_scenario_result_async_roundtrip_with_real_payload():
"""
Integration smoke test: a real ScenarioResult.to_dict() payload must flow
through ScenarioResult.from_dict() inside print_scenario_result_async
without raising. Locks the REST contract used by the CLI thin client.
Integration smoke test: a real ``ScenarioResult.model_dump(mode="json", by_alias=True)``
payload must flow through ``ScenarioResult.model_validate(...)`` inside
``print_scenario_result_async`` without raising. Locks the REST contract used by the CLI
thin client.
"""
from datetime import datetime, timezone

Expand All @@ -361,10 +364,10 @@ async def test_print_scenario_result_async_roundtrip_with_real_payload():
attack_results={"strat_a": [attack]},
scenario_run_state="COMPLETED",
)
payload = original.to_dict()
payload = original.model_dump(mode="json", by_alias=True)

# Drive print_scenario_result_async through the real from_dict path; only
# stub the printer to keep the test fast.
# Drive print_scenario_result_async through the real model_validate path;
# only stub the printer to keep the test fast.
fake_printer = MagicMock()
fake_printer.write_async = AsyncMock()
with patch(
Expand Down
Loading