diff --git a/pyrit/cli/_output.py b/pyrit/cli/_output.py index 3580c8e5b6..507b14629f 100644 --- a/pyrit/cli/_output.py +++ b/pyrit/cli/_output.py @@ -284,7 +284,7 @@ async def print_scenario_result_async(*, result_dict: dict[str, Any]) -> None: Args: result_dict: ``ScenarioResult.to_dict()`` payload from the REST API. """ - from pyrit.models.scenario_result import ScenarioResult + from pyrit.models import ScenarioResult from pyrit.output.scenario_result.pretty import PrettyScenarioResultMemoryPrinter scenario_result = ScenarioResult.from_dict(result_dict) diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index ab467c31a7..f3a20abdd4 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -55,7 +55,6 @@ SeedSimulatedConversation, SeedType, ) -from pyrit.models.scenario_result import ScenarioRunState logger = logging.getLogger(__name__) diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py index b5ffe6be7b..12da83e36b 100644 --- a/pyrit/models/__init__.py +++ b/pyrit/models/__init__.py @@ -21,12 +21,6 @@ from typing import TYPE_CHECKING, Any from pyrit.common.deprecation import print_deprecation_message -from pyrit.models.chat_message import ( - ALLOWED_CHAT_MESSAGE_ROLES, - ChatMessage, - ChatMessagesDataset, -) -from pyrit.models.conversation_reference import ConversationReference, ConversationType from pyrit.models.conversation_stats import ConversationStats from pyrit.models.embeddings import EmbeddingData, EmbeddingResponse, EmbeddingSupport, EmbeddingUsageInformation from pyrit.models.harm_definition import HarmDefinition, ScaleDescription, get_all_harm_definitions @@ -79,11 +73,18 @@ group_message_pieces_into_conversations, sort_message_pieces, ) +from pyrit.models.messages.chat_message import ( + ALLOWED_CHAT_MESSAGE_ROLES, + ChatMessage, + ChatMessagesDataset, + ToolCall, +) +from pyrit.models.messages.conversation_reference import ConversationReference, ConversationType from pyrit.models.question_answering import QuestionAnsweringDataset, QuestionAnsweringEntry, QuestionChoice from pyrit.models.results.attack_result import AttackOutcome, AttackResult, AttackResultT +from pyrit.models.results.scenario_result import ScenarioIdentifier, ScenarioResult, ScenarioRunState from pyrit.models.results.strategy_result import StrategyResult, StrategyResultT from pyrit.models.retry_event import RetryEvent -from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult, ScenarioRunState from pyrit.models.score import Score, ScoreType, UnvalidatedScore # Seeds - import from new seeds submodule for forward compatibility @@ -192,6 +193,7 @@ "TARGET_EVAL_PARAM_FALLBACKS", "TARGET_EVAL_PARAMS", "TextDataTypeSerializer", + "ToolCall", "UnvalidatedScore", "validate_registry_name", "VideoPathDataTypeSerializer", diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py deleted file mode 100644 index 19ab95a058..0000000000 --- a/pyrit/models/attack_result.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Backward-compatibility shim. - -``AttackResult`` and ``AttackOutcome`` now live in ``pyrit.models.results``. -Import from there (or from ``pyrit.models``) instead. This module re-exports the -public names so existing ``from pyrit.models.attack_result import ...`` imports -keep working. -""" - -from typing import Any - -from pyrit.models.results import attack_result as _attack_result -from pyrit.models.results.attack_result import AttackOutcome, AttackResult, AttackResultT - - -def __getattr__(name: str) -> Any: - return getattr(_attack_result, name) - - -__all__ = ["AttackOutcome", "AttackResult", "AttackResultT"] diff --git a/pyrit/models/message.py b/pyrit/models/message.py deleted file mode 100644 index 27936dc04a..0000000000 --- a/pyrit/models/message.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Backward-compatibility shim. - -``Message`` and the conversation helpers now live in ``pyrit.models.messages``. -Import from there (or from ``pyrit.models``) instead. This module re-exports the -public names so existing ``from pyrit.models.message import ...`` imports keep -working. -""" - -from typing import Any - -from pyrit.models.messages import message as _message -from pyrit.models.messages.conversations import ( - construct_response_from_request, - flatten_to_message_pieces, - get_all_values, - group_conversation_message_pieces_by_sequence, - group_message_pieces_into_conversations, -) -from pyrit.models.messages.message import Message - - -def __getattr__(name: str) -> Any: - return getattr(_message, name) - - -__all__ = [ - "Message", - "construct_response_from_request", - "flatten_to_message_pieces", - "get_all_values", - "group_conversation_message_pieces_by_sequence", - "group_message_pieces_into_conversations", -] diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py deleted file mode 100644 index b5d92f3036..0000000000 --- a/pyrit/models/message_piece.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Backward-compatibility shim. - -``MessagePiece`` now lives in ``pyrit.models.messages``. Import from there (or -from ``pyrit.models``) instead. This module re-exports the public names so -existing ``from pyrit.models.message_piece import ...`` imports keep working. -""" - -from typing import Any - -from pyrit.models.messages import message_piece as _message_piece -from pyrit.models.messages.message_piece import MessagePiece, sort_message_pieces - - -def __getattr__(name: str) -> Any: - return getattr(_message_piece, name) - - -__all__ = ["MessagePiece", "sort_message_pieces"] diff --git a/pyrit/models/messages/__init__.py b/pyrit/models/messages/__init__.py index 58c9f1a63e..c18b2a0964 100644 --- a/pyrit/models/messages/__init__.py +++ b/pyrit/models/messages/__init__.py @@ -6,11 +6,21 @@ - MessagePiece: A single piece of a message exchanged with a target. - Message: One request/response to a target, made up of one or more pieces. +- ChatMessage: OpenAI-style wire shape consumed/emitted by prompt targets. +- Conversation: Conversation-scoped metadata shared by every piece. +- ConversationReference: Immutable reference to a conversation in an attack. - conversations: Free functions that operate on collections of messages/pieces. """ -from pyrit.models.messages.conversation import Conversation +from pyrit.models.messages.chat_message import ( + ALLOWED_CHAT_MESSAGE_ROLES, + ChatMessage, + ChatMessagesDataset, + ToolCall, +) +from pyrit.models.messages.conversation_reference import ConversationReference, ConversationType from pyrit.models.messages.conversations import ( + Conversation, construct_response_from_request, flatten_to_message_pieces, get_all_values, @@ -21,9 +31,15 @@ from pyrit.models.messages.message_piece import MessagePiece, sort_message_pieces __all__ = [ + "ALLOWED_CHAT_MESSAGE_ROLES", + "ChatMessage", + "ChatMessagesDataset", "Conversation", + "ConversationReference", + "ConversationType", "Message", "MessagePiece", + "ToolCall", "construct_response_from_request", "flatten_to_message_pieces", "get_all_values", diff --git a/pyrit/models/chat_message.py b/pyrit/models/messages/chat_message.py similarity index 65% rename from pyrit/models/chat_message.py rename to pyrit/models/messages/chat_message.py index 99939704a5..57e2b0dbab 100644 --- a/pyrit/models/chat_message.py +++ b/pyrit/models/messages/chat_message.py @@ -1,6 +1,20 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +""" +OpenAI-format chat message types. + +``ChatMessage`` is the OpenAI Chat Completions wire shape — a ``role`` plus a +string-or-multipart ``content``, with the OpenAI ``name`` / ``tool_calls`` / +``tool_call_id`` fields. Prompt targets that speak the OpenAI API (and the many +providers that mirror it) consume and emit these objects directly. + +It is intentionally distinct from the PyRIT domain ``Message`` / ``MessagePiece`` +types in this same package: those model a persisted request/response exchange, +whereas ``ChatMessage`` is the lightweight OpenAI-shaped transport representation +handed to a model API. +""" + from typing import Any from pydantic import BaseModel, ConfigDict @@ -21,9 +35,9 @@ class ToolCall(BaseModel): class ChatMessage(BaseModel): """ - Represents a chat message for API consumption. + Represents a single OpenAI Chat Completions message. - The content field can be: + Mirrors the OpenAI message schema. The content field can be: - A simple string for single-part text messages - A list of dicts for multipart messages (e.g., text + images) """ diff --git a/pyrit/models/messages/conversation.py b/pyrit/models/messages/conversation.py deleted file mode 100644 index f5b8d956fb..0000000000 --- a/pyrit/models/messages/conversation.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from __future__ import annotations - -from pydantic import BaseModel, ConfigDict - -from pyrit.models.score import ( # noqa: TC001 (runtime-required by Pydantic field annotations) - ComponentIdentifierField, -) - - -class Conversation(BaseModel): - """ - Conversation-scoped metadata shared by every piece in a conversation. - - A ``Conversation`` records identifiers that belong to the conversation as a - whole rather than to any individual ``MessagePiece`` -- most importantly the - target the conversation is held with. Persisting these once per conversation - (instead of stamping them onto every piece/row) is what keeps ``MessagePiece`` - small. - """ - - model_config = ConfigDict( - arbitrary_types_allowed=True, - extra="forbid", - validate_assignment=False, - ) - - conversation_id: str - target_identifier: ComponentIdentifierField | None = None diff --git a/pyrit/models/conversation_reference.py b/pyrit/models/messages/conversation_reference.py similarity index 100% rename from pyrit/models/conversation_reference.py rename to pyrit/models/messages/conversation_reference.py diff --git a/pyrit/models/messages/conversations.py b/pyrit/models/messages/conversations.py index 9e829a28c8..d132679bb1 100644 --- a/pyrit/models/messages/conversations.py +++ b/pyrit/models/messages/conversations.py @@ -1,14 +1,19 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -"""Helpers that operate on collections of ``Message`` / ``MessagePiece``.""" +"""``Conversation`` model plus helpers that operate on collections of ``Message`` / ``MessagePiece``.""" from __future__ import annotations from typing import TYPE_CHECKING +from pydantic import BaseModel, ConfigDict + from pyrit.models.messages.message import Message from pyrit.models.messages.message_piece import MessagePiece +from pyrit.models.score import ( # noqa: TC001 (runtime-required by Pydantic field annotations) + ComponentIdentifierField, +) if TYPE_CHECKING: from collections.abc import MutableSequence, Sequence @@ -16,6 +21,27 @@ from pyrit.models.literals import PromptDataType, PromptResponseError +class Conversation(BaseModel): + """ + Conversation-scoped metadata shared by every piece in a conversation. + + A ``Conversation`` records identifiers that belong to the conversation as a + whole rather than to any individual ``MessagePiece`` -- most importantly the + target the conversation is held with. Persisting these once per conversation + (instead of stamping them onto every piece/row) is what keeps ``MessagePiece`` + small. + """ + + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + validate_assignment=False, + ) + + conversation_id: str + target_identifier: ComponentIdentifierField | None = None + + def get_all_values(messages: Sequence[Message]) -> list[str]: """ Return all converted values across the provided messages. diff --git a/pyrit/models/results/__init__.py b/pyrit/models/results/__init__.py index 4bcc2f8848..b57cb1ef37 100644 --- a/pyrit/models/results/__init__.py +++ b/pyrit/models/results/__init__.py @@ -2,20 +2,31 @@ # Licensed under the MIT license. """ -Results module - strategy and attack result types for PyRIT. +Results module - strategy, attack, and scenario result types for PyRIT. - StrategyResult: Base class for all strategy results. - AttackResult: Result of an attack execution, with conversation/scoring evidence. - AttackOutcome: Enum of possible attack outcomes. +- ScenarioResult: Aggregate result of a scenario run. +- ScenarioIdentifier: Identifier describing the executed scenario. +- ScenarioRunState: Lifecycle state of a scenario run. """ from pyrit.models.results.attack_result import AttackOutcome, AttackResult, AttackResultT +from pyrit.models.results.scenario_result import ( + ScenarioIdentifier, + ScenarioResult, + ScenarioRunState, +) from pyrit.models.results.strategy_result import StrategyResult, StrategyResultT __all__ = [ "AttackOutcome", "AttackResult", "AttackResultT", + "ScenarioIdentifier", + "ScenarioResult", + "ScenarioRunState", "StrategyResult", "StrategyResultT", ] diff --git a/pyrit/models/results/attack_result.py b/pyrit/models/results/attack_result.py index 648c837214..138d3dd38e 100644 --- a/pyrit/models/results/attack_result.py +++ b/pyrit/models/results/attack_result.py @@ -11,8 +11,8 @@ from pydantic import AwareDatetime, Field from pyrit.common.deprecation import print_deprecation_message -from pyrit.models.conversation_reference import ConversationReference, ConversationType from pyrit.models.identifiers.component_identifier import ComponentIdentifier +from pyrit.models.messages.conversation_reference import ConversationReference, ConversationType from pyrit.models.messages.message_piece import MessagePiece from pyrit.models.results.strategy_result import StrategyResult from pyrit.models.retry_event import RetryEvent diff --git a/pyrit/models/scenario_result.py b/pyrit/models/results/scenario_result.py similarity index 100% rename from pyrit/models/scenario_result.py rename to pyrit/models/results/scenario_result.py diff --git a/pyrit/models/strategy_result.py b/pyrit/models/strategy_result.py deleted file mode 100644 index bd4367695b..0000000000 --- a/pyrit/models/strategy_result.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Backward-compatibility shim. - -``StrategyResult`` now lives in ``pyrit.models.results``. Import from there (or -from ``pyrit.models``) instead. This module re-exports the public names so -existing ``from pyrit.models.strategy_result import ...`` imports keep working. -""" - -from typing import Any - -from pyrit.models.results import strategy_result as _strategy_result -from pyrit.models.results.strategy_result import StrategyResult, StrategyResultT - - -def __getattr__(name: str) -> Any: - return getattr(_strategy_result, name) - - -__all__ = ["StrategyResult", "StrategyResultT"] diff --git a/pyrit/output/helpers.py b/pyrit/output/helpers.py index 87923d5862..c1a4b21b15 100644 --- a/pyrit/output/helpers.py +++ b/pyrit/output/helpers.py @@ -11,8 +11,7 @@ import os -from pyrit.models import AttackResult, ComponentIdentifier, Message, Score -from pyrit.models.scenario_result import ScenarioResult +from pyrit.models import AttackResult, ComponentIdentifier, Message, ScenarioResult, Score from pyrit.output.attack_result.markdown import MarkdownAttackResultMemoryPrinter from pyrit.output.attack_result.pretty import PrettyAttackResultMemoryPrinter from pyrit.output.conversation.pretty import PrettyConversationMemoryPrinter diff --git a/pyrit/output/scenario_result/base.py b/pyrit/output/scenario_result/base.py index 579d480acc..13972d9ac5 100644 --- a/pyrit/output/scenario_result/base.py +++ b/pyrit/output/scenario_result/base.py @@ -4,7 +4,7 @@ from abc import abstractmethod from pyrit.common.deprecation import print_deprecation_message -from pyrit.models.scenario_result import ScenarioResult +from pyrit.models import ScenarioResult from pyrit.output.base import PrinterBase diff --git a/pyrit/output/scenario_result/pretty.py b/pyrit/output/scenario_result/pretty.py index 5abbc807ec..d8654c0bfd 100644 --- a/pyrit/output/scenario_result/pretty.py +++ b/pyrit/output/scenario_result/pretty.py @@ -6,8 +6,7 @@ from colorama import Fore, Style from pyrit.common.deprecation import print_deprecation_message -from pyrit.models import AttackOutcome -from pyrit.models.scenario_result import ScenarioResult +from pyrit.models import AttackOutcome, ScenarioResult from pyrit.output.scenario_result.base import ScenarioResultPrinterBase from pyrit.output.scorer.base import ScorerPrinterBase from pyrit.output.sink import Sink diff --git a/pyrit/scenario/__init__.py b/pyrit/scenario/__init__.py index a93f3098c1..2ef485abdf 100644 --- a/pyrit/scenario/__init__.py +++ b/pyrit/scenario/__init__.py @@ -19,7 +19,7 @@ from types import ModuleType from pyrit.common.parameter import Parameter -from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult +from pyrit.models import ScenarioIdentifier, ScenarioResult from pyrit.scenario.core import ( AtomicAttack, AttackTechnique, diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index 4611128eda..dce77c3bbd 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -37,8 +37,14 @@ from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.memory import CentralMemory from pyrit.memory.memory_models import ScenarioResultEntry -from pyrit.models import AttackOutcome, AttackResult, SeedAttackGroup -from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult, ScenarioRunState +from pyrit.models import ( + AttackOutcome, + AttackResult, + ScenarioIdentifier, + ScenarioResult, + ScenarioRunState, + SeedAttackGroup, +) from pyrit.prompt_target import PromptTarget from pyrit.prompt_target.common.target_requirements import TargetRequirements from pyrit.registry import ScorerRegistry diff --git a/tests/partner_integration/azure_ai_evaluation/test_foundry_scenario_contract.py b/tests/partner_integration/azure_ai_evaluation/test_foundry_scenario_contract.py index 3e806f03ef..4ff4d77b57 100644 --- a/tests/partner_integration/azure_ai_evaluation/test_foundry_scenario_contract.py +++ b/tests/partner_integration/azure_ai_evaluation/test_foundry_scenario_contract.py @@ -57,7 +57,7 @@ class TestScenarioResultContract: def test_scenario_result_importable(self): """ScenarioOrchestrator reads ScenarioResult.""" - from pyrit.models.scenario_result import ScenarioResult + from pyrit.models import ScenarioResult assert ScenarioResult is not None diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index 2749c6fd67..5f7bee1809 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -1757,7 +1757,7 @@ async def test_conversation_summary_formats_media_preview(self, attack_service, async def test_returns_main_and_related_conversations(self, attack_service, mock_memory): """Should return main and PRUNED conversations sorted by timestamp.""" - from pyrit.models.conversation_reference import ConversationReference, ConversationType + from pyrit.models import ConversationReference, ConversationType ar = make_attack_result(conversation_id="attack-1") ar.related_conversations.add( @@ -1910,7 +1910,7 @@ async def test_raises_when_conversation_not_part_of_attack(self, attack_service, async def test_swaps_main_conversation(self, attack_service, mock_memory): """Changing the main to a related conversation should swap it with the main.""" - from pyrit.models.conversation_reference import ConversationReference, ConversationType + from pyrit.models import ConversationReference, ConversationType ar = make_attack_result(conversation_id="attack-1") ar.related_conversations = { @@ -1950,7 +1950,7 @@ class TestAddMessageTargetConversation: async def test_stores_message_in_target_conversation(self, attack_service, mock_memory): """When target_conversation_id is set, messages should go to that conversation.""" from pyrit.backend.models.attacks import AttackSummary, ConversationMessagesResponse - from pyrit.models.conversation_reference import ConversationReference, ConversationType + from pyrit.models import ConversationReference, ConversationType ar = make_attack_result(conversation_id="attack-1") ar.related_conversations = { @@ -2017,7 +2017,7 @@ class TestConversationCount: async def test_list_attacks_includes_related_conversation_ids(self, attack_service, mock_memory): """Attacks with related conversations should expose them in the summary.""" - from pyrit.models.conversation_reference import ConversationReference, ConversationType + from pyrit.models import ConversationReference, ConversationType ar = make_attack_result(conversation_id="attack-1") ar.related_conversations = { @@ -2069,7 +2069,7 @@ async def test_create_conversation_increments_count(self, attack_service, mock_m async def test_create_second_conversation_preserves_first(self, attack_service, mock_memory): """Creating a second related conversation should keep the first one.""" from pyrit.backend.models.attacks import CreateConversationRequest - from pyrit.models.conversation_reference import ConversationReference, ConversationType + from pyrit.models import ConversationReference, ConversationType ar = make_attack_result(conversation_id="attack-1") ar.related_conversations = { @@ -2099,7 +2099,7 @@ class TestConversationSorting: async def test_conversations_sorted_by_created_at_earliest_first(self, attack_service, mock_memory): """Conversations should be sorted by created_at with earliest first.""" - from pyrit.models.conversation_reference import ConversationReference, ConversationType + from pyrit.models import ConversationReference, ConversationType ar = make_attack_result(conversation_id="attack-1") ar.related_conversations = { @@ -2127,7 +2127,7 @@ async def test_conversations_sorted_by_created_at_earliest_first(self, attack_se async def test_empty_conversations_sorted_last(self, attack_service, mock_memory): """Conversations with no timestamp should appear at the bottom.""" - from pyrit.models.conversation_reference import ConversationReference, ConversationType + from pyrit.models import ConversationReference, ConversationType ar = make_attack_result(conversation_id="attack-1") ar.related_conversations = { @@ -2153,7 +2153,7 @@ async def test_empty_conversations_sorted_last(self, attack_service, mock_memory async def test_empty_conversations_all_sort_last(self, attack_service, mock_memory): """Multiple empty conversations should all have created_at=None.""" - from pyrit.models.conversation_reference import ConversationReference, ConversationType + from pyrit.models import ConversationReference, ConversationType ar = make_attack_result(conversation_id="attack-1") ar.related_conversations = { diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index 9170f031df..1aba27d4b5 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -322,7 +322,7 @@ async def test_no_converters_returns_empty_list(self) -> None: async def test_related_conversation_ids_from_related_conversations(self) -> None: """Test that related_conversation_ids includes all related conversation IDs.""" - from pyrit.models.conversation_reference import ConversationReference, ConversationType + from pyrit.models import ConversationReference, ConversationType ar = _make_attack_result() ar.related_conversations = { diff --git a/tests/unit/backend/test_response_contracts.py b/tests/unit/backend/test_response_contracts.py index bb7f403df3..15ff045b6a 100644 --- a/tests/unit/backend/test_response_contracts.py +++ b/tests/unit/backend/test_response_contracts.py @@ -24,12 +24,13 @@ from pyrit.models import ( AttackResult, ComponentIdentifier, + ConversationReference, + ConversationType, MessagePiece, RetryEvent, Score, build_atomic_attack_identifier, ) -from pyrit.models.conversation_reference import ConversationReference, ConversationType def _make_score() -> Score: diff --git a/tests/unit/cli/test_output.py b/tests/unit/cli/test_output.py index 26f41cccc2..b1c67cbbc6 100644 --- a/tests/unit/cli/test_output.py +++ b/tests/unit/cli/test_output.py @@ -319,7 +319,7 @@ async def test_print_scenario_result_async_uses_pretty_printer(): fake_printer.write_async = AsyncMock() with ( - patch("pyrit.models.scenario_result.ScenarioResult.from_dict", return_value=fake_scenario) as from_dict_mock, + patch("pyrit.models.ScenarioResult.from_dict", return_value=fake_scenario) as from_dict_mock, patch( "pyrit.output.scenario_result.pretty.PrettyScenarioResultMemoryPrinter", return_value=fake_printer ) as printer_cls, @@ -339,8 +339,7 @@ async def test_print_scenario_result_async_roundtrip_with_real_payload(): """ from datetime import datetime, timezone - from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier - from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult + from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier, ScenarioIdentifier, ScenarioResult identifier = ScenarioIdentifier(name="test.scenario", description="A test") target_identifier = ComponentIdentifier.from_dict( diff --git a/tests/unit/models/test_attack_result.py b/tests/unit/models/test_attack_result.py index 43e7522052..5b19316dfb 100644 --- a/tests/unit/models/test_attack_result.py +++ b/tests/unit/models/test_attack_result.py @@ -8,10 +8,8 @@ import pytest from pyrit.memory.memory_models import AttackResultEntry -from pyrit.models import ComponentIdentifier -from pyrit.models.conversation_reference import ConversationReference, ConversationType +from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier, ConversationReference, ConversationType from pyrit.models.messages.message_piece import MessagePiece -from pyrit.models.results.attack_result import AttackOutcome, AttackResult from pyrit.models.retry_event import RetryEvent from pyrit.models.score import Score @@ -393,24 +391,3 @@ def test_duplicate_preserves_subclass_type(self) -> None: assert copy.backtrack_count == 3 copy.backtrack_count = 9 assert original.backtrack_count == 3 - - -class TestAttackResultShim: - """The relocated module must be importable from the legacy path silently.""" - - def test_shim_reexports_same_classes_silently(self) -> None: - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - from pyrit.models.attack_result import AttackOutcome as ShimOutcome - from pyrit.models.attack_result import AttackResult as ShimResult - - assert ShimResult is AttackResult - assert ShimOutcome is AttackOutcome - deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] - assert len(deprecation_warnings) == 0, "Shim import must be silent" - - def test_shim_getattr_reexports_dynamic_names(self) -> None: - """The module __getattr__ falls through to the relocated module.""" - import pyrit.models.attack_result as shim - - assert shim.AttackResultT is not None diff --git a/tests/unit/models/test_chat_message.py b/tests/unit/models/test_chat_message.py index c3475285b7..8fe206cf7e 100644 --- a/tests/unit/models/test_chat_message.py +++ b/tests/unit/models/test_chat_message.py @@ -6,11 +6,7 @@ import pytest from pydantic import ValidationError -from pyrit.models.chat_message import ( - ChatMessage, - ChatMessagesDataset, - ToolCall, -) +from pyrit.models import ChatMessage, ChatMessagesDataset, ToolCall def test_tool_call_init(): diff --git a/tests/unit/models/test_conversation.py b/tests/unit/models/test_conversation.py new file mode 100644 index 0000000000..b2576cd2e9 --- /dev/null +++ b/tests/unit/models/test_conversation.py @@ -0,0 +1,62 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import pytest +from pydantic import ValidationError + +from pyrit.models import ComponentIdentifier, Conversation + + +def test_init_requires_conversation_id(): + with pytest.raises(ValidationError): + Conversation() # type: ignore[call-arg] + + +def test_init_defaults_target_identifier_to_none(): + conversation = Conversation(conversation_id="conv-1") + assert conversation.conversation_id == "conv-1" + assert conversation.target_identifier is None + + +def test_init_forbids_extra_fields(): + with pytest.raises(ValidationError): + Conversation(conversation_id="conv-1", unexpected="value") # type: ignore[call-arg] + + +def test_init_accepts_component_identifier(): + identifier = ComponentIdentifier(class_name="OpenAIChatTarget", class_module="pyrit.prompt_target") + conversation = Conversation(conversation_id="conv-1", target_identifier=identifier) + assert conversation.target_identifier == identifier + + +def test_target_identifier_accepts_flat_dict(): + identifier = ComponentIdentifier(class_name="OpenAIChatTarget", class_module="pyrit.prompt_target") + conversation = Conversation(conversation_id="conv-1", target_identifier=identifier.model_dump()) + assert isinstance(conversation.target_identifier, ComponentIdentifier) + assert conversation.target_identifier.class_name == "OpenAIChatTarget" + + +def test_model_dump_serializes_target_identifier_to_flat_dict(): + identifier = ComponentIdentifier(class_name="OpenAIChatTarget", class_module="pyrit.prompt_target") + conversation = Conversation(conversation_id="conv-1", target_identifier=identifier) + + dumped = conversation.model_dump() + + assert dumped["conversation_id"] == "conv-1" + assert dumped["target_identifier"]["class_name"] == "OpenAIChatTarget" + assert dumped["target_identifier"]["class_module"] == "pyrit.prompt_target" + + +def test_model_dump_with_no_target_identifier(): + conversation = Conversation(conversation_id="conv-1") + assert conversation.model_dump()["target_identifier"] is None + + +def test_round_trips_through_model_validate(): + identifier = ComponentIdentifier(class_name="OpenAIChatTarget", class_module="pyrit.prompt_target") + conversation = Conversation(conversation_id="conv-1", target_identifier=identifier) + + restored = Conversation.model_validate(conversation.model_dump()) + + assert restored.conversation_id == "conv-1" + assert restored.target_identifier == identifier diff --git a/tests/unit/models/test_conversation_reference.py b/tests/unit/models/test_conversation_reference.py index de6263cd95..b229bc27eb 100644 --- a/tests/unit/models/test_conversation_reference.py +++ b/tests/unit/models/test_conversation_reference.py @@ -4,7 +4,7 @@ import pytest from pydantic import ValidationError -from pyrit.models.conversation_reference import ConversationReference, ConversationType +from pyrit.models import ConversationReference, ConversationType def test_conversation_type_values(): diff --git a/tests/unit/models/test_message.py b/tests/unit/models/test_message.py index 27468734f9..0d7913d4d3 100644 --- a/tests/unit/models/test_message.py +++ b/tests/unit/models/test_message.py @@ -462,12 +462,3 @@ def test_conversation_helpers_live_in_conversations_module(self) -> None: "construct_response_from_request", ): assert getattr(conversations, name) is getattr(messages, name) - - def test_legacy_module_paths_reexport_same_objects(self) -> None: - import pyrit.models.message as legacy_message - import pyrit.models.message_piece as legacy_message_piece - from pyrit.models.messages.message import Message as PackagedMessage - from pyrit.models.messages.message_piece import MessagePiece as PackagedMessagePiece - - assert legacy_message.Message is PackagedMessage - assert legacy_message_piece.MessagePiece is PackagedMessagePiece diff --git a/tests/unit/models/test_scenario_result.py b/tests/unit/models/test_scenario_result.py index e15abc53e7..ab2da79e8c 100644 --- a/tests/unit/models/test_scenario_result.py +++ b/tests/unit/models/test_scenario_result.py @@ -7,11 +7,15 @@ import pytest import pyrit -from pyrit.models import ComponentIdentifier -from pyrit.models.conversation_reference import ConversationReference, ConversationType +from pyrit.models import ( + ComponentIdentifier, + ConversationReference, + ConversationType, + ScenarioIdentifier, + ScenarioResult, +) from pyrit.models.results.attack_result import AttackOutcome, AttackResult from pyrit.models.retry_event import RetryEvent -from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult def _make_scenario_identifier(**kwargs): diff --git a/tests/unit/models/test_strategy_result.py b/tests/unit/models/test_strategy_result.py index 07508abef7..63d5a66c04 100644 --- a/tests/unit/models/test_strategy_result.py +++ b/tests/unit/models/test_strategy_result.py @@ -1,11 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import warnings - import pytest -from pyrit.models.results.strategy_result import StrategyResult +from pyrit.models import StrategyResult class ConcreteResult(StrategyResult): @@ -39,14 +37,3 @@ def test_strategy_result_duplicate_preserves_type(): def test_strategy_result_forbids_extra_fields(): with pytest.raises(ValueError): ConcreteResult(value="hello", count=1, unexpected="boom") - - -def test_strategy_result_shim_reexports_same_class_silently(): - """The old import path must re-export the identical class without warning.""" - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - from pyrit.models.strategy_result import StrategyResult as ShimStrategyResult - - assert ShimStrategyResult is StrategyResult - deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] - assert len(deprecation_warnings) == 0, "Shim import must be silent" diff --git a/tests/unit/output/attack_result/test_markdown.py b/tests/unit/output/attack_result/test_markdown.py index b61081867a..7ce0c8b265 100644 --- a/tests/unit/output/attack_result/test_markdown.py +++ b/tests/unit/output/attack_result/test_markdown.py @@ -11,13 +11,13 @@ AttackOutcome, AttackResult, ComponentIdentifier, + ConversationReference, ConversationType, Message, MessagePiece, Score, build_atomic_attack_identifier, ) -from pyrit.models.conversation_reference import ConversationReference from pyrit.output.attack_result.markdown import MarkdownAttackResultMemoryPrinter diff --git a/tests/unit/output/attack_result/test_pretty.py b/tests/unit/output/attack_result/test_pretty.py index dd7d02c6d7..7b9e79737d 100644 --- a/tests/unit/output/attack_result/test_pretty.py +++ b/tests/unit/output/attack_result/test_pretty.py @@ -10,13 +10,13 @@ AttackOutcome, AttackResult, ComponentIdentifier, + ConversationReference, ConversationType, Message, MessagePiece, Score, build_atomic_attack_identifier, ) -from pyrit.models.conversation_reference import ConversationReference from pyrit.output.attack_result.pretty import PrettyAttackResultMemoryPrinter diff --git a/tests/unit/output/scenario_result/test_base.py b/tests/unit/output/scenario_result/test_base.py index fe64b39b2d..0d3bb06413 100644 --- a/tests/unit/output/scenario_result/test_base.py +++ b/tests/unit/output/scenario_result/test_base.py @@ -5,7 +5,7 @@ import pytest -from pyrit.models.scenario_result import ScenarioResult +from pyrit.models import ScenarioResult from pyrit.output.scenario_result.base import ScenarioResultPrinterBase diff --git a/tests/unit/output/scenario_result/test_pretty.py b/tests/unit/output/scenario_result/test_pretty.py index b2f8cced9c..f1ba89c431 100644 --- a/tests/unit/output/scenario_result/test_pretty.py +++ b/tests/unit/output/scenario_result/test_pretty.py @@ -5,8 +5,7 @@ import pytest -from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier -from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult +from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier, ScenarioIdentifier, ScenarioResult from pyrit.output.scenario_result.pretty import PrettyScenarioResultMemoryPrinter diff --git a/tests/unit/scenario/core/test_scenario_parameters.py b/tests/unit/scenario/core/test_scenario_parameters.py index 4a013f4365..17fee40f67 100644 --- a/tests/unit/scenario/core/test_scenario_parameters.py +++ b/tests/unit/scenario/core/test_scenario_parameters.py @@ -419,7 +419,7 @@ class TestResumeParameterValidation: @staticmethod def _make_stored_result(*, scenario_name: str, version: int, init_data): """Build a minimal ScenarioResult with a controlled identifier for resume tests.""" - from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult + from pyrit.models import ScenarioIdentifier, ScenarioResult identifier = ScenarioIdentifier( name=scenario_name,