diff --git a/doc/code/memory/3_memory_data_types.md b/doc/code/memory/3_memory_data_types.md index 71e73a0bf0..607960bb3c 100644 --- a/doc/code/memory/3_memory_data_types.md +++ b/doc/code/memory/3_memory_data_types.md @@ -150,7 +150,7 @@ Identifiers are content-addressed: the same configuration always produces the sa ### Composite Identifiers -For atomic attacks, `build_atomic_attack_identifier` composes a tree of identifiers: +For atomic attacks, `AtomicAttackIdentifier.build` composes a tree of identifiers: - **`attack_technique`** — the attack strategy and its children (target, converters, scorer, technique seeds) - **`seed_identifiers`** — all seeds from the seed group, for traceability diff --git a/doc/code/targets/0_prompt_targets.md b/doc/code/targets/0_prompt_targets.md index 74c60b737c..058a88b3dd 100644 --- a/doc/code/targets/0_prompt_targets.md +++ b/doc/code/targets/0_prompt_targets.md @@ -64,7 +64,7 @@ A `TargetConfiguration` composes three concerns: Each target class defines defaults; instances can override individual capabilities when they depend on deployment configuration (e.g. `HTTPTarget`, `PlaywrightTarget`). -For well-known underlying models, you can look up a profile with `TargetCapabilities.get_known_capabilities(underlying_model="gpt-4o")`. +For well-known underlying models, you can look up a profile with `get_known_capabilities(underlying_model="gpt-4o")` from `pyrit.prompt_target`. ### How consumers use capabilities diff --git a/doc/code/targets/6_1_target_capabilities.ipynb b/doc/code/targets/6_1_target_capabilities.ipynb index 596971a200..70b6e319d1 100644 --- a/doc/code/targets/6_1_target_capabilities.ipynb +++ b/doc/code/targets/6_1_target_capabilities.ipynb @@ -92,7 +92,7 @@ "\n", "Each target class declares a `_DEFAULT_CONFIGURATION` class attribute. For well-known underlying models,\n", "`get_default_configuration(underlying_model=...)` returns a richer profile from\n", - "`TargetCapabilities.get_known_capabilities` — for example, `gpt-5` gains `supports_json_schema=True`\n", + "`get_known_capabilities` — for example, `gpt-5` gains `supports_json_schema=True`\n", "and other models pick up the right modality combinations automatically. Unknown models fall back to\n", "the class default." ] diff --git a/doc/code/targets/6_1_target_capabilities.py b/doc/code/targets/6_1_target_capabilities.py index 23ab296366..ee2ceec82f 100644 --- a/doc/code/targets/6_1_target_capabilities.py +++ b/doc/code/targets/6_1_target_capabilities.py @@ -55,7 +55,7 @@ # # Each target class declares a `_DEFAULT_CONFIGURATION` class attribute. For well-known underlying models, # `get_default_configuration(underlying_model=...)` returns a richer profile from -# `TargetCapabilities.get_known_capabilities` — for example, `gpt-5` gains `supports_json_schema=True` +# `get_known_capabilities` — for example, `gpt-5` gains `supports_json_schema=True` # and other models pick up the right modality combinations automatically. Unknown models fall back to # the class default. diff --git a/pyrit/backend/mappers/converter_mappers.py b/pyrit/backend/mappers/converter_mappers.py index a71b5aa537..ce5f3190cf 100644 --- a/pyrit/backend/mappers/converter_mappers.py +++ b/pyrit/backend/mappers/converter_mappers.py @@ -3,17 +3,18 @@ """ Converter mappers – domain → DTO translation for converter-related models. + +Identity vs. presentation: +``ConverterIdentifier`` is the typed, lossless +*identity* projection of a converter's ``ComponentIdentifier``; +``ConverterInstance`` is the backend *presentation* view (adds ``converter_id`` +binding, ``display_name``, and ``sub_converter_ids``). """ -from pyrit.backend.models.converters import ConverterInstance +from pyrit.backend.models import ConverterInstance +from pyrit.models import ConverterIdentifier from pyrit.prompt_converter import PromptConverter -# Base keys from PromptConverter._create_identifier that are NOT converter-specific -_BASE_CONVERTER_PARAM_KEYS = { - "supported_input_types", - "supported_output_types", -} - def converter_object_to_instance( converter_id: str, @@ -35,17 +36,19 @@ def converter_object_to_instance( Returns: ConverterInstance DTO with metadata derived from the object. """ - identifier = converter_obj.get_identifier() + converter_identifier = ConverterIdentifier.from_component_identifier(converter_obj.get_identifier()) - supported_input = identifier.params.get("supported_input_types") - supported_output = identifier.params.get("supported_output_types") + supported_input = converter_identifier.supported_input_types + supported_output = converter_identifier.supported_output_types - # Extract converter-specific params by filtering out base keys - converter_specific = {k: v for k, v in identifier.params.items() if k not in _BASE_CONVERTER_PARAM_KEYS} or None + # supported_input/output_types are promoted to typed fields and mirrored into + # params; strip them so only converter-specific params remain. + promoted_param_names = set(ConverterIdentifier._promoted_param_fields()) + converter_specific = {k: v for k, v in converter_identifier.params.items() if k not in promoted_param_names} or None return ConverterInstance( converter_id=converter_id, - converter_type=identifier.class_name, + converter_type=converter_identifier.class_name, display_name=None, supported_input_types=list(supported_input) if supported_input else [], supported_output_types=list(supported_output) if supported_output else [], diff --git a/pyrit/backend/mappers/target_mappers.py b/pyrit/backend/mappers/target_mappers.py index b7f715aca0..07bb05e629 100644 --- a/pyrit/backend/mappers/target_mappers.py +++ b/pyrit/backend/mappers/target_mappers.py @@ -3,9 +3,17 @@ """ Target mappers – domain → DTO translation for target-related models. + +Identity vs. presentation: ``TargetIdentifier`` +is the typed, lossless *identity* projection of a target's +``ComponentIdentifier``. ``TargetInstance`` is the backend *presentation* view — +it adds registry binding (``target_registry_name``), flattened capabilities, and +composite ``inner_targets`` for the frontend. These mappers read typed fields off +``TargetIdentifier`` instead of poking ``identifier.params`` by string key. """ -from pyrit.backend.models.targets import TargetCapabilitiesInfo, TargetInstance +from pyrit.backend.models import TargetCapabilitiesInfo, TargetInstance +from pyrit.models import TargetIdentifier from pyrit.prompt_target import PromptTarget from pyrit.prompt_target.common.target_capabilities import CapabilityName, TargetCapabilities from pyrit.prompt_target.round_robin_target import RoundRobinTarget @@ -15,7 +23,9 @@ _CAPABILITY_PARAM_NAMES = frozenset(cap.value for cap in CapabilityName) -def _target_capabilities_to_info(capabilities: TargetCapabilities) -> TargetCapabilitiesInfo: +def _target_capabilities_to_info( + capabilities: TargetCapabilities, +) -> TargetCapabilitiesInfo: """ Build a TargetCapabilitiesInfo DTO from a domain TargetCapabilities object. @@ -55,27 +65,27 @@ def target_object_to_instance(target_registry_name: str, target_obj: PromptTarge Returns: TargetInstance DTO with metadata derived from the object. """ - identifier = target_obj.get_identifier() - params = identifier.params - - # Keys that are extracted as top-level TargetInstance fields, are internal-only - # (e.g., target_configuration is the verbose capabilities blob), or duplicate - # capability flags (filtered via _CAPABILITY_PARAM_NAMES) — those are sourced - # solely from target_obj.capabilities and must not leak into target_specific_params. - extracted_keys = { - "endpoint", - "model_name", - "underlying_model_name", - "temperature", - "top_p", - "max_requests_per_minute", - "target_specific_params", - "target_configuration", - } | _CAPABILITY_PARAM_NAMES + target_identifier = TargetIdentifier.from_component_identifier(target_obj.get_identifier()) + + # Promoted params (endpoint, model_name, …) are mirrored into params and also + # exposed as typed fields; strip them so they don't leak into + # target_specific_params. Capabilities are no longer part of the identifier at + # all. The strip set is also defensive: it drops the explicit + # target_specific_params bag (merged in separately) plus any legacy capability / + # configuration keys that might appear in older persisted identifiers. + extracted_keys = ( + { + "target_specific_params", + "target_configuration", + } + | _CAPABILITY_PARAM_NAMES + | set(TargetIdentifier._promoted_param_fields()) + ) # Collect remaining params as target_specific_params so the frontend can display them - explicit_specific = params.get("target_specific_params") or {} - extra = {k: v for k, v in params.items() if k not in extracted_keys and v is not None} + raw_specific = target_identifier.params.get("target_specific_params") + explicit_specific = raw_specific if isinstance(raw_specific, dict) else {} + extra = {k: v for k, v in target_identifier.params.items() if k not in extracted_keys and v is not None} combined_specific = {**extra, **explicit_specific} or None inner_targets = _build_inner_targets(target_obj) @@ -84,8 +94,8 @@ def target_object_to_instance(target_registry_name: str, target_obj: PromptTarge # only when ALL inner targets share the same deployment name. When they differ # (e.g. "gpt-4o-japan-nilfilter" vs "pyrit-github-gpt4"), show "—" for # consistency with how other targets display model_name. - model_name = params.get("model_name") or None - underlying_model_name = params.get("underlying_model_name") or None + model_name = target_identifier.model_name or None + underlying_model_name = target_identifier.underlying_model_name or None if model_name is None and inner_targets: inner_models = {t.model_name for t in inner_targets} model_name = inner_models.pop() if len(inner_models) == 1 else None @@ -95,17 +105,17 @@ def target_object_to_instance(target_registry_name: str, target_obj: PromptTarge return TargetInstance( target_registry_name=target_registry_name, - target_type=identifier.class_name, - endpoint=params.get("endpoint") or None, + target_type=target_identifier.class_name, + endpoint=target_identifier.endpoint or None, model_name=model_name, underlying_model_name=underlying_model_name, - temperature=params.get("temperature"), - top_p=params.get("top_p"), - max_requests_per_minute=params.get("max_requests_per_minute"), + temperature=target_identifier.temperature, + top_p=target_identifier.top_p, + max_requests_per_minute=target_identifier.max_requests_per_minute, capabilities=_target_capabilities_to_info(target_obj.capabilities), target_specific_params=combined_specific, inner_targets=inner_targets, - identifier_hash=identifier.hash, + identifier_hash=target_identifier.hash, ) diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index 866a9df364..e2b09149d5 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -52,6 +52,8 @@ from pyrit.backend.services.target_service import get_target_service from pyrit.memory import CentralMemory, data_serializer_factory from pyrit.models import ( + AtomicAttackIdentifier, + AttackIdentifier, AttackOutcome, AttackResult, ComponentIdentifier, @@ -60,7 +62,6 @@ ConversationType, MessagePiece, PromptDataType, - build_atomic_attack_identifier, ) from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer @@ -322,11 +323,11 @@ async def create_attack_async(self, *, request: CreateAttackRequest) -> CreateAt attack_result = AttackResult( conversation_id=conversation_id, objective=request.name or "Manual attack via GUI", - atomic_attack_identifier=build_atomic_attack_identifier( - attack_identifier=ComponentIdentifier( + atomic_attack_identifier=AtomicAttackIdentifier.build( + attack_identifier=AttackIdentifier( class_name=request.name or "ManualAttack", class_module="pyrit.backend", - children={"objective_target": target_identifier} if target_identifier else {}, + objective_target=target_identifier, ), ), outcome=AttackOutcome.UNDETERMINED, diff --git a/pyrit/executor/attack/core/attack_strategy.py b/pyrit/executor/attack/core/attack_strategy.py index 00ccb57644..19929f8888 100644 --- a/pyrit/executor/attack/core/attack_strategy.py +++ b/pyrit/executor/attack/core/attack_strategy.py @@ -26,13 +26,17 @@ ) from pyrit.memory.central_memory import CentralMemory from pyrit.models import ( + AttackIdentifier, AttackOutcome, AttackResult, ComponentIdentifier, ConversationReference, + ConverterIdentifier, Identifiable, Message, + ScorerIdentifier, SeedPrompt, + TargetIdentifier, ) from pyrit.prompt_target.common.target_requirements import TargetRequirements @@ -458,48 +462,64 @@ def _create_identifier( Returns: ComponentIdentifier: The identifier for this attack strategy. """ - all_children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] = { - "objective_target": self.get_objective_target().get_identifier(), - } - + all_children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] = dict(children) if children else {} merged_params: dict[str, Any] = dict(params) if params else {} + objective_target = TargetIdentifier.from_component_identifier(self.get_objective_target().get_identifier()) + # Add scorer if present + objective_scorer: ScorerIdentifier | None = None scoring_config = self.get_attack_scoring_config() if scoring_config and scoring_config.objective_scorer: - all_children["objective_scorer"] = scoring_config.objective_scorer.get_identifier() + objective_scorer = ScorerIdentifier.from_component_identifier( + scoring_config.objective_scorer.get_identifier() + ) # Add adversarial chat target and its effective prompts if present. The adversarial # target becomes a child (filtered to model params by the eval rule), while the # effective system/seed prompts land on the attack-strategy node so they are included - # in both the full component hash and the eval hash. None-valued params are dropped by - # ComponentIdentifier.of, so strategies that do not use a given prompt simply omit it. + # in both the full component hash and the eval hash. None-valued promoted fields are + # dropped by ComponentIdentifier.of, so strategies that do not use a given prompt + # simply omit it. + adversarial_chat: TargetIdentifier | None = None + adversarial_system_prompt: str | None = None + adversarial_seed_prompt: str | None = None adversarial_config = self.get_attack_adversarial_config() if adversarial_config is not None and getattr(adversarial_config, "target", None) is not None: - all_children["adversarial_chat"] = adversarial_config.target.get_identifier() - merged_params["adversarial_system_prompt"] = self._extract_adversarial_prompt_text( - adversarial_config.system_prompt - ) - merged_params["adversarial_seed_prompt"] = self._extract_adversarial_prompt_text( - adversarial_config.seed_prompt - ) + adversarial_chat = TargetIdentifier.from_component_identifier(adversarial_config.target.get_identifier()) + adversarial_system_prompt = self._extract_adversarial_prompt_text(adversarial_config.system_prompt) + adversarial_seed_prompt = self._extract_adversarial_prompt_text(adversarial_config.seed_prompt) # Add request converter identifiers if present + request_converters: list[ConverterIdentifier] | None = None if self._request_converters: - all_children["request_converters"] = [ - converter.get_identifier() for config in self._request_converters for converter in config.converters + request_converters = [ + ConverterIdentifier.from_component_identifier(converter.get_identifier()) + for config in self._request_converters + for converter in config.converters ] # Add response converter identifiers if present + response_converters: list[ConverterIdentifier] | None = None if self._response_converters: - all_children["response_converters"] = [ - converter.get_identifier() for config in self._response_converters for converter in config.converters + response_converters = [ + ConverterIdentifier.from_component_identifier(converter.get_identifier()) + for config in self._response_converters + for converter in config.converters ] - if children: - all_children.update(children) - - return ComponentIdentifier.of(self, params=merged_params or None, children=all_children) + return AttackIdentifier.of( + self, + params=merged_params or None, + children=all_children or None, + objective_target=objective_target, + adversarial_chat=adversarial_chat, + objective_scorer=objective_scorer, + request_converters=request_converters, + response_converters=response_converters, + adversarial_system_prompt=adversarial_system_prompt, + adversarial_seed_prompt=adversarial_seed_prompt, + ) @staticmethod def _extract_adversarial_prompt_text(value: str | SeedPrompt | None) -> str | None: diff --git a/pyrit/executor/attack/multi_turn/chunked_request.py b/pyrit/executor/attack/multi_turn/chunked_request.py index b7b04ec10e..13f9b7ab9e 100644 --- a/pyrit/executor/attack/multi_turn/chunked_request.py +++ b/pyrit/executor/attack/multi_turn/chunked_request.py @@ -21,11 +21,11 @@ MultiTurnAttackStrategy, ) from pyrit.models import ( + AtomicAttackIdentifier, AttackOutcome, AttackResult, Message, Score, - build_atomic_attack_identifier, ) from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptTarget @@ -316,7 +316,7 @@ async def _perform_async(self, *, context: ChunkedRequestAttackContext) -> Attac return AttackResult( conversation_id=context.session.conversation_id, objective=context.objective, - atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=self.get_identifier()), + atomic_attack_identifier=AtomicAttackIdentifier.build(attack_identifier=self.get_identifier()), last_response=response.get_piece() if response else None, last_score=score, related_conversations=context.related_conversations, diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index e988bfc3d6..c102e247f0 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -37,6 +37,7 @@ from pyrit.memory.central_memory import CentralMemory from pyrit.message_normalizer import ConversationContextNormalizer from pyrit.models import ( + AtomicAttackIdentifier, AttackOutcome, AttackResult, ConversationReference, @@ -44,7 +45,6 @@ Message, Score, SeedPrompt, - build_atomic_attack_identifier, ) from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import CapabilityName, TargetRequirements @@ -427,7 +427,7 @@ async def _perform_async(self, *, context: CrescendoAttackContext) -> CrescendoA # Prepare the result result = CrescendoAttackResult( - atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=self.get_identifier()), + atomic_attack_identifier=AtomicAttackIdentifier.build(attack_identifier=self.get_identifier()), conversation_id=context.session.conversation_id, objective=context.objective, outcome=(AttackOutcome.SUCCESS if achieved_objective else AttackOutcome.FAILURE), diff --git a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py index cc4c53531d..61e4dd2e55 100644 --- a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py +++ b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py @@ -20,12 +20,12 @@ MultiTurnAttackStrategy, ) from pyrit.models import ( + AtomicAttackIdentifier, AttackOutcome, AttackResult, Message, Score, SeedAttackGroup, - build_atomic_attack_identifier, ) from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import CapabilityName, PromptTarget @@ -287,7 +287,7 @@ async def _perform_async(self, *, context: MultiTurnAttackContext[Any]) -> Attac return AttackResult( conversation_id=context.session.conversation_id, objective=context.objective, - atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=self.get_identifier()), + atomic_attack_identifier=AtomicAttackIdentifier.build(attack_identifier=self.get_identifier()), last_response=response.get_piece() if response else None, last_score=score, related_conversations=context.related_conversations, diff --git a/pyrit/executor/attack/multi_turn/red_teaming.py b/pyrit/executor/attack/multi_turn/red_teaming.py index 2bc78036b9..85512d16d7 100644 --- a/pyrit/executor/attack/multi_turn/red_teaming.py +++ b/pyrit/executor/attack/multi_turn/red_teaming.py @@ -30,6 +30,7 @@ ) from pyrit.memory import CentralMemory from pyrit.models import ( + AtomicAttackIdentifier, AttackOutcome, AttackResult, Conversation, @@ -38,7 +39,6 @@ Message, Score, SeedPrompt, - build_atomic_attack_identifier, ) from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import CapabilityName @@ -354,7 +354,7 @@ async def _perform_async(self, *, context: MultiTurnAttackContext[Any]) -> Attac # Prepare the result return AttackResult( - atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=self.get_identifier()), + atomic_attack_identifier=AtomicAttackIdentifier.build(attack_identifier=self.get_identifier()), conversation_id=context.session.conversation_id, objective=context.objective, outcome=(AttackOutcome.SUCCESS if achieved_objective else AttackOutcome.FAILURE), diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index 703278d91d..65f2f60c99 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -41,6 +41,7 @@ from pyrit.executor.attack.multi_turn import MultiTurnAttackContext from pyrit.memory import CentralMemory from pyrit.models import ( + AtomicAttackIdentifier, AttackOutcome, AttackResult, ComponentIdentifier, @@ -51,7 +52,6 @@ MessagePiece, Score, SeedPrompt, - build_atomic_attack_identifier, ) from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer from pyrit.prompt_target import CapabilityName, PromptTarget @@ -2173,7 +2173,7 @@ def _create_attack_result( # Create the result with basic information result = TAPAttackResult( - atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=self.get_identifier()), + atomic_attack_identifier=AtomicAttackIdentifier.build(attack_identifier=self.get_identifier()), conversation_id=context.best_conversation_id or "", objective=context.objective, outcome=outcome, diff --git a/pyrit/executor/attack/single_turn/prompt_sending.py b/pyrit/executor/attack/single_turn/prompt_sending.py index 32e4db677b..6f7954a165 100644 --- a/pyrit/executor/attack/single_turn/prompt_sending.py +++ b/pyrit/executor/attack/single_turn/prompt_sending.py @@ -16,13 +16,13 @@ SingleTurnAttackStrategy, ) from pyrit.models import ( + AtomicAttackIdentifier, AttackOutcome, AttackResult, ConversationReference, ConversationType, Message, Score, - build_atomic_attack_identifier, ) from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptTarget @@ -230,7 +230,7 @@ async def _perform_async(self, *, context: SingleTurnAttackContext[Any]) -> Atta return AttackResult( conversation_id=context.conversation_id, objective=context.objective, - atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=self.get_identifier()), + atomic_attack_identifier=AtomicAttackIdentifier.build(attack_identifier=self.get_identifier()), last_response=response.get_piece() if response else None, last_score=score, related_conversations=context.related_conversations, diff --git a/pyrit/executor/attack/streaming/barge_in.py b/pyrit/executor/attack/streaming/barge_in.py index 2c96fbdb62..717e269d71 100644 --- a/pyrit/executor/attack/streaming/barge_in.py +++ b/pyrit/executor/attack/streaming/barge_in.py @@ -20,7 +20,7 @@ AttackResult, Message, ) -from pyrit.models.identifiers.atomic_attack_identifier import build_atomic_attack_identifier +from pyrit.models.identifiers.atomic_attack_identifier import AtomicAttackIdentifier from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target.common.target_capabilities import CapabilityName from pyrit.prompt_target.common.target_requirements import TargetRequirements @@ -199,7 +199,7 @@ def _build_result( return AttackResult( conversation_id=context.conversation_id, objective=context.objective, - atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=self.get_identifier()), + atomic_attack_identifier=AtomicAttackIdentifier.build(attack_identifier=self.get_identifier()), last_response=(last_response.message_pieces[0] if last_response else None), last_score=None, related_conversations=context.related_conversations, diff --git a/pyrit/executor/benchmark/fairness_bias.py b/pyrit/executor/benchmark/fairness_bias.py index ef4aba1180..7607387e04 100644 --- a/pyrit/executor/benchmark/fairness_bias.py +++ b/pyrit/executor/benchmark/fairness_bias.py @@ -19,11 +19,11 @@ from pyrit.executor.core import Strategy, StrategyContext from pyrit.memory import CentralMemory from pyrit.models import ( + AtomicAttackIdentifier, AttackOutcome, AttackResult, ComponentIdentifier, Message, - build_atomic_attack_identifier, ) from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptTarget @@ -197,7 +197,7 @@ async def _perform_async(self, *, context: FairnessBiasBenchmarkContext) -> Atta conversation_id=str(uuid.UUID(int=0)), objective=context.generated_objective, outcome=AttackOutcome.FAILURE, - atomic_attack_identifier=build_atomic_attack_identifier( + atomic_attack_identifier=AtomicAttackIdentifier.build( attack_identifier=ComponentIdentifier.of(self), ), labels=context.memory_labels, diff --git a/pyrit/identifiers/__init__.py b/pyrit/identifiers/__init__.py index f3353096af..d33a8a1a3b 100644 --- a/pyrit/identifiers/__init__.py +++ b/pyrit/identifiers/__init__.py @@ -32,8 +32,6 @@ ObjectiveTargetEvaluationIdentifier, ScorerEvaluationIdentifier, ScorerIdentifier, - build_atomic_attack_identifier, - build_seed_identifier, class_name_to_snake_case, compute_eval_hash, config_hash, @@ -43,8 +41,6 @@ __all__ = [ "AtomicAttackEvaluationIdentifier", - "build_atomic_attack_identifier", - "build_seed_identifier", "ChildEvalRule", "class_name_to_snake_case", "ComponentIdentifier", @@ -66,17 +62,11 @@ _warned: set[str] = set() -# Names that have an additional deprecation warning at the new pyrit.models.identifiers path — -# for these, skip the shim's path-migration warning and let the deeper module's __getattr__ -# emit the (more informative) name-deprecation warning pointing at the actual replacement -# class. Otherwise users would see two warnings on a single access. -_NAMES_DEPRECATED_AT_NEW_PATH = frozenset({"ScorerIdentifier"}) - def __getattr__(name: str) -> Any: if name not in __all__: raise AttributeError(f"module 'pyrit.identifiers' has no attribute {name!r}") - if name not in _NAMES_DEPRECATED_AT_NEW_PATH and name not in _warned: + if name not in _warned: print_deprecation_message( old_item=f"pyrit.identifiers.{name}", new_item=f"pyrit.models.identifiers.{name}", diff --git a/pyrit/identifiers/atomic_attack_identifier.py b/pyrit/identifiers/atomic_attack_identifier.py deleted file mode 100644 index 18e105dd0f..0000000000 --- a/pyrit/identifiers/atomic_attack_identifier.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -"""Deprecation shim — moved to pyrit.models.identifiers.atomic_attack_identifier in 0.14.""" - -from typing import TYPE_CHECKING, Any - -from pyrit.common.deprecation import print_deprecation_message -from pyrit.models.identifiers import atomic_attack_identifier as _new - -if TYPE_CHECKING: - from pyrit.models.identifiers.atomic_attack_identifier import ( - build_atomic_attack_identifier, - build_seed_identifier, - ) - -__all__ = ["build_atomic_attack_identifier", "build_seed_identifier"] - -_warned: set[str] = set() - - -def __getattr__(name: str) -> Any: - if name not in __all__: - raise AttributeError(f"module 'pyrit.identifiers.atomic_attack_identifier' has no attribute {name!r}") - if name not in _warned: - print_deprecation_message( - old_item=f"pyrit.identifiers.atomic_attack_identifier.{name}", - new_item=f"pyrit.models.identifiers.atomic_attack_identifier.{name}", - removed_in="0.16.0", - ) - _warned.add(name) - return getattr(_new, name) - - -def __dir__() -> list[str]: - return sorted(__all__) diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py index b5ffe6be7b..65d6bb6493 100644 --- a/pyrit/models/__init__.py +++ b/pyrit/models/__init__.py @@ -18,7 +18,7 @@ """ import importlib -from typing import TYPE_CHECKING, Any +from typing import Any from pyrit.common.deprecation import print_deprecation_message from pyrit.models.chat_message import ( @@ -35,16 +35,21 @@ TARGET_EVAL_PARAM_FALLBACKS, TARGET_EVAL_PARAMS, AtomicAttackEvaluationIdentifier, + AtomicAttackIdentifier, + AttackIdentifier, + AttackTechniqueIdentifier, ChildEvalRule, ComponentIdentifier, + ConverterIdentifier, EvaluationIdentifier, Identifiable, IdentifierFilter, IdentifierType, ObjectiveTargetEvaluationIdentifier, ScorerEvaluationIdentifier, - build_atomic_attack_identifier, - build_seed_identifier, + ScorerIdentifier, + SeedIdentifier, + TargetIdentifier, class_name_to_snake_case, compute_eval_hash, config_hash, @@ -101,27 +106,31 @@ SeedUnion, SimulatedTargetSystemPromptPaths, ) +from pyrit.models.target_capabilities import CapabilityName, TargetCapabilities __all__ = [ "ALLOWED_CHAT_MESSAGE_ROLES", "AllowedCategories", "AtomicAttackEvaluationIdentifier", + "AtomicAttackIdentifier", + "AttackIdentifier", + "AttackTechniqueIdentifier", "AttackResult", "AttackResultT", "AttackOutcome", "AudioPathDataTypeSerializer", "AzureBlobStorageIO", "BinaryPathDataTypeSerializer", - "build_atomic_attack_identifier", - "build_seed_identifier", "ChatMessage", "ChatMessagesDataset", "ChatMessageRole", "ChildEvalRule", "class_name_to_snake_case", + "CapabilityName", "ComponentIdentifier", "compute_eval_hash", "config_hash", + "ConverterIdentifier", "Conversation", "ConversationReference", "ConversationStats", @@ -180,6 +189,7 @@ "SeedPrompt", "SeedDataset", "SeedGroup", + "SeedIdentifier", "SeedSimulatedConversation", "SeedType", "SeedUnion", @@ -191,6 +201,8 @@ "StrategyResultT", "TARGET_EVAL_PARAM_FALLBACKS", "TARGET_EVAL_PARAMS", + "TargetCapabilities", + "TargetIdentifier", "TextDataTypeSerializer", "UnvalidatedScore", "validate_registry_name", @@ -198,17 +210,6 @@ "RetryEvent", ] -if TYPE_CHECKING: - # Type-only alias so static checkers can resolve ``from pyrit.models import ScorerIdentifier``. - # At runtime the symbol is served by ``__getattr__`` below so accessing it emits a one-shot - # DeprecationWarning per process. Will be removed in 0.16.0. - ScorerIdentifier = ComponentIdentifier - -# Deprecated rename aliases (pre-#1387 names that were collapsed into ComponentIdentifier). -_DEPRECATED_RENAME_ALIASES: dict[str, Any] = { - "ScorerIdentifier": ComponentIdentifier, -} - # Names that moved to ``pyrit.memory.storage``. Served lazily via importlib so that # importing ``pyrit.models`` stays import-boundary clean and fires no warning until a # moved name is actually accessed. Will be removed in 0.17.0. @@ -231,16 +232,6 @@ def __getattr__(name: str) -> Any: - if name in _DEPRECATED_RENAME_ALIASES: - target = _DEPRECATED_RENAME_ALIASES[name] - if name not in _warned: - print_deprecation_message( - old_item=f"{__name__}.{name}", - new_item=target, - removed_in="0.16.0", - ) - _warned.add(name) - return target if name in _MOVED_TO_MEMORY_STORAGE: target_module = _MOVED_TO_MEMORY_STORAGE[name] if name not in _warned: diff --git a/pyrit/models/identifiers/__init__.py b/pyrit/models/identifiers/__init__.py index 4e260fdf8c..4606829b8c 100644 --- a/pyrit/models/identifiers/__init__.py +++ b/pyrit/models/identifiers/__init__.py @@ -3,13 +3,11 @@ """Identifiers module for PyRIT components.""" -from typing import TYPE_CHECKING, Any - -from pyrit.common.deprecation import print_deprecation_message from pyrit.models.identifiers.atomic_attack_identifier import ( - build_atomic_attack_identifier, - build_seed_identifier, + AtomicAttackIdentifier, ) +from pyrit.models.identifiers.attack_identifier import AttackIdentifier +from pyrit.models.identifiers.attack_technique_identifier import AttackTechniqueIdentifier from pyrit.models.identifiers.class_name_utils import ( REGISTRY_NAME_PATTERN, class_name_to_snake_case, @@ -17,6 +15,7 @@ validate_registry_name, ) from pyrit.models.identifiers.component_identifier import ComponentIdentifier, Identifiable, config_hash +from pyrit.models.identifiers.converter_identifier import ConverterIdentifier from pyrit.models.identifiers.evaluation_identifier import ( TARGET_EVAL_PARAM_FALLBACKS, TARGET_EVAL_PARAMS, @@ -29,56 +28,34 @@ compute_inner_attack_eval_hash, ) from pyrit.models.identifiers.identifier_filters import IdentifierFilter, IdentifierType - -if TYPE_CHECKING: - # Type-only alias so static checkers can resolve ``from pyrit.models.identifiers import - # ScorerIdentifier``. At runtime the symbol is served by ``__getattr__`` below so we can - # emit a one-shot DeprecationWarning per process. - ScorerIdentifier = ComponentIdentifier +from pyrit.models.identifiers.scorer_identifier import ScorerIdentifier +from pyrit.models.identifiers.seed_identifier import SeedIdentifier +from pyrit.models.identifiers.target_identifier import TargetIdentifier __all__ = [ "AtomicAttackEvaluationIdentifier", - "build_atomic_attack_identifier", - "build_seed_identifier", + "AtomicAttackIdentifier", + "AttackIdentifier", + "AttackTechniqueIdentifier", "ChildEvalRule", "class_name_to_snake_case", "ComponentIdentifier", "compute_eval_hash", "compute_inner_attack_eval_hash", + "ConverterIdentifier", "EvaluationIdentifier", "Identifiable", "ObjectiveTargetEvaluationIdentifier", "REGISTRY_NAME_PATTERN", "ScorerEvaluationIdentifier", "ScorerIdentifier", + "SeedIdentifier", "snake_case_to_class_name", "TARGET_EVAL_PARAM_FALLBACKS", "TARGET_EVAL_PARAMS", + "TargetIdentifier", "validate_registry_name", "config_hash", "IdentifierFilter", "IdentifierType", ] - -# Deprecated rename aliases (pre-#1387 names that were collapsed into ComponentIdentifier). -# Served via ``__getattr__`` rather than as static module attributes so accessing them emits -# a one-shot DeprecationWarning per process. Will be removed in 0.16.0. -_DEPRECATED_RENAME_ALIASES: dict[str, Any] = { - "ScorerIdentifier": ComponentIdentifier, -} - -_warned: set[str] = set() - - -def __getattr__(name: str) -> Any: - if name in _DEPRECATED_RENAME_ALIASES: - target = _DEPRECATED_RENAME_ALIASES[name] - if name not in _warned: - print_deprecation_message( - old_item=f"{__name__}.{name}", - new_item=target, - removed_in="0.16.0", - ) - _warned.add(name) - return target - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/pyrit/models/identifiers/atomic_attack_identifier.py b/pyrit/models/identifiers/atomic_attack_identifier.py index bb4618156c..60d9a47170 100644 --- a/pyrit/models/identifiers/atomic_attack_identifier.py +++ b/pyrit/models/identifiers/atomic_attack_identifier.py @@ -2,13 +2,10 @@ # Licensed under the MIT license. """ -Atomic attack identity builder functions. +Composite identifier for an atomic attack run. -Builds a composite ComponentIdentifier that uniquely identifies an attack run -by combining the attack strategy's identity with the seed identifiers from -the dataset. - -The composite identifier has this shape:: +Combines an attack technique with the seed identifiers from the dataset. The +composite identifier has this shape:: AtomicAttack ├── attack_technique (class_name="AttackTechnique") @@ -17,17 +14,18 @@ └── seed_identifiers (list of ALL seed ComponentIdentifiers, for traceability) """ -import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING + +from pydantic import Field +from pyrit.models.identifiers.attack_identifier import AttackIdentifier +from pyrit.models.identifiers.attack_technique_identifier import AttackTechniqueIdentifier from pyrit.models.identifiers.component_identifier import ComponentIdentifier +from pyrit.models.identifiers.seed_identifier import SeedIdentifier if TYPE_CHECKING: - from pyrit.models.seeds.seed import Seed from pyrit.models.seeds.seed_group import SeedGroup -logger = logging.getLogger(__name__) - # Class metadata for the composite identifier _ATOMIC_ATTACK_CLASS_NAME = "AtomicAttack" _ATOMIC_ATTACK_CLASS_MODULE = "pyrit.scenario.core.atomic_attack" @@ -36,91 +34,76 @@ _ATTACK_TECHNIQUE_CLASS_MODULE = "pyrit.scenario.core.attack_technique" -def build_seed_identifier(seed: "Seed") -> ComponentIdentifier: +class AtomicAttackIdentifier(ComponentIdentifier): """ - Build a ComponentIdentifier from a seed's behavioral properties. + Strongly-typed projection of an atomic attack's ``ComponentIdentifier``. - Captures the seed's content hash, dataset name, and class type so that - different seeds produce different identifiers while the same seed content - always produces the same identifier. - - Args: - seed: The seed to build an identifier for. - - Returns: - An identifier capturing the seed's behavioral properties. - """ - params: dict[str, Any] = { - "value": seed.value, - "value_sha256": seed.value_sha256, - "dataset_name": seed.dataset_name, - "is_general_technique": seed.is_general_technique, - } - - return ComponentIdentifier( - class_name=seed.__class__.__name__, - class_module=seed.__class__.__module__, - params=params, - ) - - -def build_atomic_attack_identifier( - *, - technique_identifier: ComponentIdentifier | None = None, - attack_identifier: ComponentIdentifier | None = None, - seed_group: "SeedGroup | None" = None, -) -> ComponentIdentifier: + Promotes the attack technique (``attack_technique``) and all seed identifiers + from the dataset (``seed_identifiers``). """ - Build a composite ComponentIdentifier for an atomic attack. - - The identifier places the attack technique in ``children["attack_technique"]`` - and all seeds from the seed group in ``children["seed_identifiers"]`` for traceability. - - Callers that have an ``AttackTechnique`` object should pass - ``technique_identifier=attack_technique.get_identifier()``. - Callers that only have a raw attack strategy identifier (e.g. legacy - backward-compat paths) can pass ``attack_identifier`` instead, which is - wrapped in a minimal technique node automatically. - - Args: - technique_identifier: Pre-built technique identifier from - ``AttackTechnique.get_identifier()``. Mutually exclusive with - ``attack_identifier``. - attack_identifier: Raw attack strategy identifier. Used when no - ``AttackTechnique`` instance is available. Mutually exclusive - with ``technique_identifier``. - seed_group: The seed group to extract all seeds from. - - Returns: - A composite ComponentIdentifier with class_name="AtomicAttack". - - Raises: - ValueError: If both or neither of ``technique_identifier`` and - ``attack_identifier`` are provided. - """ - if technique_identifier is not None and attack_identifier is not None: - raise ValueError("Provide technique_identifier or attack_identifier, not both") - - if technique_identifier is None: - if attack_identifier is None: - raise ValueError("Either technique_identifier or attack_identifier must be provided") - technique_identifier = ComponentIdentifier( - class_name=_ATTACK_TECHNIQUE_CLASS_NAME, - class_module=_ATTACK_TECHNIQUE_CLASS_MODULE, - children={"attack": attack_identifier}, - ) - - seed_identifiers: list[ComponentIdentifier] = [] - if seed_group is not None: - seed_identifiers.extend(build_seed_identifier(seed) for seed in seed_group.seeds) - children: dict[str, Any] = { - "attack_technique": technique_identifier, - "seed_identifiers": seed_identifiers, - } - - return ComponentIdentifier( - class_name=_ATOMIC_ATTACK_CLASS_NAME, - class_module=_ATOMIC_ATTACK_CLASS_MODULE, - children=children, - ) + #: The attack technique executed. + attack_technique: AttackTechniqueIdentifier | None = None + #: All seed identifiers from the dataset, for traceability. + seed_identifiers: list[SeedIdentifier] = Field(default_factory=list) + + @classmethod + def build( + cls, + *, + technique_identifier: ComponentIdentifier | None = None, + attack_identifier: ComponentIdentifier | None = None, + seed_group: "SeedGroup | None" = None, + ) -> "AtomicAttackIdentifier": + """ + Build a composite AtomicAttackIdentifier for an atomic attack. + + The identifier places the attack technique in ``children["attack_technique"]`` + and all seeds from the seed group in ``children["seed_identifiers"]`` for traceability. + + Callers that have an ``AttackTechnique`` object should pass + ``technique_identifier=attack_technique.get_identifier()``. + Callers that only have a raw attack strategy identifier (e.g. legacy + backward-compat paths) can pass ``attack_identifier`` instead, which is + wrapped in a minimal technique node automatically. + + Args: + technique_identifier: Pre-built technique identifier from + ``AttackTechnique.get_identifier()``. Mutually exclusive with + ``attack_identifier``. + attack_identifier: Raw attack strategy identifier. Used when no + ``AttackTechnique`` instance is available. Mutually exclusive + with ``technique_identifier``. + seed_group: The seed group to extract all seeds from. + + Returns: + A composite AtomicAttackIdentifier with class_name="AtomicAttack". + + Raises: + ValueError: If both or neither of ``technique_identifier`` and + ``attack_identifier`` are provided. + """ + if technique_identifier is not None and attack_identifier is not None: + raise ValueError("Provide technique_identifier or attack_identifier, not both") + + if technique_identifier is None: + if attack_identifier is None: + raise ValueError("Either technique_identifier or attack_identifier must be provided") + technique_identifier = AttackTechniqueIdentifier( + class_name=_ATTACK_TECHNIQUE_CLASS_NAME, + class_module=_ATTACK_TECHNIQUE_CLASS_MODULE, + attack=AttackIdentifier.from_component_identifier(attack_identifier), + ) + + technique = AttackTechniqueIdentifier.from_component_identifier(technique_identifier) + + seed_identifiers: list[SeedIdentifier] = [] + if seed_group is not None: + seed_identifiers.extend(SeedIdentifier.from_seed(seed) for seed in seed_group.seeds) + + return cls( + class_name=_ATOMIC_ATTACK_CLASS_NAME, + class_module=_ATOMIC_ATTACK_CLASS_MODULE, + attack_technique=technique, + seed_identifiers=seed_identifiers, + ) diff --git a/pyrit/models/identifiers/attack_identifier.py b/pyrit/models/identifiers/attack_identifier.py new file mode 100644 index 0000000000..db26405625 --- /dev/null +++ b/pyrit/models/identifiers/attack_identifier.py @@ -0,0 +1,44 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Strongly-typed projection of an attack strategy's identifier.""" + +from __future__ import annotations + +from pydantic import Field + +from pyrit.models.identifiers.component_identifier import ComponentIdentifier +from pyrit.models.identifiers.converter_identifier import ( # noqa: TC001 + ConverterIdentifier, # runtime-required by Pydantic field annotations +) +from pyrit.models.identifiers.scorer_identifier import ( # noqa: TC001 + ScorerIdentifier, # runtime-required by Pydantic field annotations +) +from pyrit.models.identifiers.target_identifier import ( # noqa: TC001 + TargetIdentifier, # runtime-required by Pydantic field annotations +) + + +class AttackIdentifier(ComponentIdentifier): + """ + Strongly-typed projection of an ``AttackStrategy``'s ``ComponentIdentifier``. + + Promotes the effective adversarial system/seed prompts and the attack's own + child slots — objective target, adversarial chat target, objective scorer, and + the request/response converter pipelines. + """ + + #: Effective adversarial system prompt text, if the strategy uses one. + adversarial_system_prompt: str | None = None + #: Effective adversarial seed prompt text, if the strategy uses one. + adversarial_seed_prompt: str | None = None + #: The objective target the attack drives. + objective_target: TargetIdentifier | None = None + #: The adversarial chat target, if the strategy uses one. + adversarial_chat: TargetIdentifier | None = None + #: The objective scorer, if the strategy uses one. + objective_scorer: ScorerIdentifier | None = None + #: Request-side converter pipeline. + request_converters: list[ConverterIdentifier] = Field(default_factory=list) + #: Response-side converter pipeline. + response_converters: list[ConverterIdentifier] = Field(default_factory=list) diff --git a/pyrit/models/identifiers/attack_technique_identifier.py b/pyrit/models/identifiers/attack_technique_identifier.py new file mode 100644 index 0000000000..412c00dbed --- /dev/null +++ b/pyrit/models/identifiers/attack_technique_identifier.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Strongly-typed projection of an attack technique's identifier.""" + +from __future__ import annotations + +from pydantic import Field + +from pyrit.models.identifiers.attack_identifier import ( # noqa: TC001 + AttackIdentifier, # runtime-required by Pydantic field annotations +) +from pyrit.models.identifiers.component_identifier import ComponentIdentifier +from pyrit.models.identifiers.seed_identifier import ( # noqa: TC001 + SeedIdentifier, # runtime-required by Pydantic field annotations +) + + +class AttackTechniqueIdentifier(ComponentIdentifier): + """ + Strongly-typed projection of an ``AttackTechnique``'s ``ComponentIdentifier``. + + Promotes the attack strategy child (``attack``) and the optional technique + seeds (``technique_seeds``). + """ + + #: The attack strategy that defines the technique. + attack: AttackIdentifier | None = None + #: Optional seeds that specialize the technique. + technique_seeds: list[SeedIdentifier] = Field(default_factory=list) diff --git a/pyrit/models/identifiers/component_identifier.py b/pyrit/models/identifiers/component_identifier.py index b8e7ad1895..19834e3450 100644 --- a/pyrit/models/identifiers/component_identifier.py +++ b/pyrit/models/identifiers/component_identifier.py @@ -20,13 +20,25 @@ import json import logging from abc import ABC, abstractmethod -from typing import Any, ClassVar +from typing import Any, ClassVar, get_args, get_origin from pydantic import BaseModel, ConfigDict, Field, SerializationInfo, model_serializer, model_validator +from typing_extensions import Self, TypeAliasType import pyrit from pyrit.common.deprecation import print_deprecation_message +#: The set of value types allowed inside ``ComponentIdentifier.params``. Params +#: must be JSON-serializable scalars (``str`` / ``int`` / ``float`` / ``bool`` / +#: ``None``) or arbitrarily nested ``list`` / ``dict`` containers of those. This +#: mirrors exactly what ``config_hash``'s ``json.dumps`` can serialize, so an +#: identifier that validates is guaranteed to hash. Composite identity belongs in +#: ``children`` (typed as ``ComponentIdentifier``), never in ``params``. +JSONValue = TypeAliasType( + "JSONValue", + "str | int | float | bool | None | list[JSONValue] | dict[str, JSONValue]", +) + #: Param names that collide with reserved top-level keys in the flat storage #: shape. Forbidden inside ``ComponentIdentifier.params`` so storage / REST #: round-trips stay lossless. @@ -72,8 +84,8 @@ def _build_hash_dict( *, class_name: str, class_module: str, - params: dict[str, Any], - children: dict[str, Any], + params: dict[str, JSONValue], + children: dict[str, ComponentIdentifier | list[ComponentIdentifier]], ) -> dict[str, Any]: """ Build the canonical dictionary used for hash computation. @@ -85,8 +97,9 @@ def _build_hash_dict( Args: class_name (str): The component's class name. class_module (str): The component's module path. - params (dict[str, Any]): Behavioral parameters (non-None values only). - children (dict[str, Any]): Child name to ComponentIdentifier or list of ComponentIdentifier. + params (dict[str, JSONValue]): Behavioral parameters (non-None values only). + children (dict[str, ComponentIdentifier | list[ComponentIdentifier]]): Child name to + ComponentIdentifier or list of ComponentIdentifier. Returns: dict[str, Any]: The canonical dictionary for hashing. @@ -114,6 +127,32 @@ def _build_hash_dict( return hash_dict +def _dump_child_identifiers_to_dict(value: Any) -> Any: + """ + Replace ``ComponentIdentifier`` instances in a child value with their flat dict form. + + A promoted child field is typed as a specific ``ComponentIdentifier`` + subclass (e.g. ``TargetIdentifier``). Build sites and DB loads may supply a + base ``ComponentIdentifier`` (or a different subclass) for that slot, which + Pydantic's strict model validation would reject. Dumping such instances to + their flat ``model_dump()`` dict lets validation re-parse them into the + declared subclass; the stored ``hash`` rides along, so identity is preserved. + + Args: + value (Any): The raw child value (an identifier instance, a dict, a list + of either, or ``None``). + + Returns: + Any: The value with any ``ComponentIdentifier`` instances replaced by + their flat dict form. + """ + if isinstance(value, ComponentIdentifier): + return value.model_dump() + if isinstance(value, list): + return [_dump_child_identifiers_to_dict(item) for item in value] + return value + + class ComponentIdentifier(BaseModel): """ Immutable snapshot of a component's behavioral configuration. @@ -126,22 +165,26 @@ class ComponentIdentifier(BaseModel): params, and children produce the same hash. This enables deterministic metrics lookup, DB deduplication, and registry keying. - Serialization - ------------- - ``model_dump()`` returns a **flat** dict where reserved keys - (``class_name``, ``class_module``, ``hash``, ``pyrit_version``, - ``eval_hash``, ``children``) sit at the top level alongside the inlined - param values. This shape is also the storage / REST format. Pass - ``context={"max_value_length": N}`` to truncate long string param values. - ``model_validate()`` accepts the same flat shape (plus a structured form - with an explicit ``params`` dict). - - Mutability - ---------- - The model is frozen, but ``params`` and ``children`` are dicts whose + Typed projections: subclasses (``TargetIdentifier``, ``ConverterIdentifier``, …) + may promote well-known params and children to ordinary typed fields. Promotion is + automatic and keyed off the field's annotation: a scalar field maps to a ``params`` + entry; a field annotated as a ``ComponentIdentifier`` subclass (or a ``list`` + thereof) maps to a ``children`` slot of the same name. The promoted value is + mirrored back into ``params`` / ``children`` before hashing, so a typed subclass + serializes and hashes identically to a plain ``ComponentIdentifier`` built with the + same params/children. Non-promoted members simply stay in ``params`` / ``children``. + + Serialization: ``model_dump()`` returns a flat dict where reserved keys + (``class_name``, ``class_module``, ``hash``, ``pyrit_version``, ``eval_hash``, + ``children``) sit at the top level alongside the inlined param values. This shape is + also the storage / REST format. Pass ``context={"max_value_length": N}`` to truncate + long string param values. ``model_validate()`` accepts the same flat shape (plus a + structured form with an explicit ``params`` dict). + + Mutability: the model is frozen, but ``params`` and ``children`` are dicts whose contents are not deep-frozen — mutating them after construction creates an - identifier whose stored ``hash`` no longer matches its content. Treat - every identifier as a fully immutable value. + identifier whose stored ``hash`` no longer matches its content. Treat every + identifier as a fully immutable value. """ model_config = ConfigDict(frozen=True, extra="forbid") @@ -159,8 +202,10 @@ class ComponentIdentifier(BaseModel): class_name: str #: Full module path (e.g., "pyrit.score.self_ask_scale_scorer"). class_module: str - #: Behavioral parameters that affect output. - params: dict[str, Any] = Field(default_factory=dict) + #: Behavioral parameters that affect output. Values must be JSON-serializable + #: scalars or nested ``list`` / ``dict`` containers of them (see ``JSONValue``); + #: composite identity belongs in ``children`` instead. + params: dict[str, JSONValue] = Field(default_factory=dict) #: Named child identifiers for compositional identity (e.g., a scorer's target). children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] = Field(default_factory=dict) #: Content-addressed SHA256 hash. Computed automatically when ``None``; @@ -173,6 +218,64 @@ class ComponentIdentifier(BaseModel): #: to the identifier so it survives DB round-trips with truncated params. eval_hash: str | None = None + # ------------------------------------------------------------------ + # Promotion (typed projection — derived from the subclass's own fields) + # ------------------------------------------------------------------ + + @staticmethod + def _is_child_field(annotation: Any) -> bool: + """ + Return whether a field annotation denotes a child identifier. + + Args: + annotation (Any): The resolved field annotation (from + ``model_fields[name].annotation``). + + Returns: + bool: ``True`` if the annotation is a ``ComponentIdentifier`` subclass + or a ``list`` thereof (optionally wrapped in ``| None``); ``False`` for + scalar (param) fields. + """ + if get_origin(annotation) is list: + args = get_args(annotation) + inner = args[0] if args else None + return isinstance(inner, type) and issubclass(inner, ComponentIdentifier) + + candidates: tuple[Any, ...] = get_args(annotation) or (annotation,) + return any(isinstance(c, type) and issubclass(c, ComponentIdentifier) for c in candidates) + + @classmethod + def _promoted_fields(cls) -> tuple[str, ...]: + """ + Return the subclass's own fields (everything beyond the base structural fields). + + Returns: + tuple[str, ...]: Field names declared by ``cls`` but not by the base + ``ComponentIdentifier``, in field-definition order. + """ + base_fields = set(ComponentIdentifier.model_fields) + return tuple(name for name in cls.model_fields if name not in base_fields) + + @classmethod + def _promoted_param_fields(cls) -> tuple[str, ...]: + """ + Return the subclass's own scalar fields, which map to ``params`` entries. + + Returns: + tuple[str, ...]: Promoted param field names, in field-definition order. + """ + return tuple(n for n in cls._promoted_fields() if not cls._is_child_field(cls.model_fields[n].annotation)) + + @classmethod + def _promoted_child_fields(cls) -> tuple[str, ...]: + """ + Return the subclass's own identifier-typed fields, which map to ``children`` slots. + + Returns: + tuple[str, ...]: Promoted child field names, in field-definition order. + """ + return tuple(n for n in cls._promoted_fields() if cls._is_child_field(cls.model_fields[n].annotation)) + # ------------------------------------------------------------------ # Validators # ------------------------------------------------------------------ @@ -228,6 +331,7 @@ def _normalize_input(cls, data: Any) -> Any: data.setdefault(cls.KEY_CLASS_NAME, "Unknown") data.setdefault(cls.KEY_CLASS_MODULE, "unknown") + promoted_fields = cls._promoted_fields() reserved_top = { cls.KEY_CLASS_NAME, cls.KEY_CLASS_MODULE, @@ -235,6 +339,7 @@ def _normalize_input(cls, data: Any) -> Any: cls.KEY_PYRIT_VERSION, cls.KEY_EVAL_HASH, cls.KEY_CHILDREN, + *promoted_fields, } if "params" in data: @@ -257,20 +362,66 @@ def _normalize_input(cls, data: Any) -> Any: if collisions: raise ValueError(f"ComponentIdentifier params must not use reserved names: {sorted(collisions)}") + # Promotion: lift any promoted value that arrived inside the flat + # ``params`` / ``children`` buckets (e.g. the storage shape) up to its + # matching top-level field so Pydantic validates it into the typed field. + # Build-site construction already passes promoted values top-level, so + # those are left untouched here. + if promoted_fields: + params_bucket = params_dict if isinstance(params_dict, dict) else {} + children_value = data.get(cls.KEY_CHILDREN) + children_bucket = children_value if isinstance(children_value, dict) else {} + for name in promoted_fields: + if name in data: + continue + if name in params_bucket: + data[name] = params_bucket[name] + elif name in children_bucket: + data[name] = children_bucket[name] + + # Promoted child values may arrive as ComponentIdentifier instances + # (possibly a base ComponentIdentifier or a different subclass than + # the typed field declares). Dump them to their flat dict form so + # Pydantic re-parses them into the declared identifier subclass. + # Round-tripping through model_dump preserves the stored hash. + for name in cls._promoted_child_fields(): + if name in data: + data[name] = _dump_child_identifiers_to_dict(data[name]) + return data @model_validator(mode="after") - def _compute_hash_if_missing(self) -> ComponentIdentifier: + def _promote_and_compute_hash(self) -> ComponentIdentifier: """ - Compute the content-addressed hash if it was not provided. + Mirror promoted typed fields into ``params`` / ``children`` and hash. - Preserves any pre-set hash (e.g. one reconstructed from a truncated - DB row, where recomputing from the truncated params would produce a - wrong identity). + Promoted scalar fields are written into ``params`` and promoted + identifier fields into ``children`` (``None`` / empty list dropped), so a + typed subclass serializes and hashes identically to a plain + ``ComponentIdentifier`` with the same values. The content-addressed hash + is then computed if it was not provided — a pre-set hash (e.g. one + reconstructed from a truncated DB row) is preserved. Returns: - ``self`` (mutated in-place via ``object.__setattr__``). - """ + ``self`` (mutated in-place). + """ + for name in self._promoted_param_fields(): + value = getattr(self, name) + if value is not None: + self.params[name] = value + for name in self._promoted_child_fields(): + value = getattr(self, name) + if value is None: + continue + if isinstance(value, list): + # Store non-empty lists always; store an empty list only when it + # was set explicitly (preserves hashes for builders that include + # an empty child slot, while a defaulted empty list stays absent). + if value or name in self.model_fields_set: + self.children[name] = value + else: + self.children[name] = value + if self.hash is None: hash_dict = _build_hash_dict( class_name=self.class_name, @@ -437,7 +588,8 @@ def of( *, params: dict[str, Any] | None = None, children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None = None, - ) -> ComponentIdentifier: + **promoted: Any, + ) -> Self: """ Build a ComponentIdentifier from a live object instance. @@ -450,20 +602,45 @@ def of( identifier. params: Optional behavioral params. children: Optional child identifiers. + **promoted: Optional promoted typed fields (for subclasses). Passed + by name; ``None`` values are dropped. These are mirrored back + into ``params`` / ``children`` automatically. Returns: A new ComponentIdentifier describing ``obj``. """ clean_params = {k: v for k, v in (params or {}).items() if v is not None} clean_children = {k: v for k, v in (children or {}).items() if v is not None} + clean_promoted = {k: v for k, v in promoted.items() if v is not None} return cls( class_name=obj.__class__.__name__, class_module=obj.__class__.__module__, params=clean_params, children=clean_children, + **clean_promoted, ) + @classmethod + def from_component_identifier(cls, identifier: ComponentIdentifier) -> Self: + """ + Return ``identifier`` as an instance of this typed subclass. + + Pass-through when ``identifier`` is already an instance of ``cls``; + otherwise revalidate its flat dump into ``cls`` (e.g. a base identifier + loaded from the DB), rehydrating promoted typed fields. The hash is + preserved across the round-trip. + + Args: + identifier: A ``ComponentIdentifier`` (possibly the base type). + + Returns: + An instance of ``cls`` describing the same identity. + """ + if isinstance(identifier, cls): + return identifier + return cls.model_validate(identifier.model_dump()) + def get_child(self, key: str) -> ComponentIdentifier | None: """ Get a single child by key. diff --git a/pyrit/models/identifiers/converter_identifier.py b/pyrit/models/identifiers/converter_identifier.py new file mode 100644 index 0000000000..555147d699 --- /dev/null +++ b/pyrit/models/identifiers/converter_identifier.py @@ -0,0 +1,35 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Strongly-typed projection of a converter's identifier.""" + +from __future__ import annotations + +from pydantic import Field + +from pyrit.models.identifiers.component_identifier import ComponentIdentifier +from pyrit.models.identifiers.target_identifier import ( # noqa: TC001 + TargetIdentifier, # runtime-required by Pydantic field annotations +) +from pyrit.models.literals import PromptDataType # noqa: TC001 (runtime-required by Pydantic field annotations) + + +class ConverterIdentifier(ComponentIdentifier): + """ + Strongly-typed projection of a ``PromptConverter``'s ``ComponentIdentifier``. + + Promotes the supported input/output data types; any converter-specific params + stay in ``params``. The converter's own child slots — ``converter_target`` + (an LLM target) and ``sub_converters`` (nested converters) — are promoted to + typed fields. + """ + + #: Input data types supported by this converter. + supported_input_types: list[PromptDataType] | None = None + #: Output data types produced by this converter. + supported_output_types: list[PromptDataType] | None = None + #: Target an LLM-backed converter calls (e.g., ``LLMGenericTextConverter``). + converter_target: TargetIdentifier | None = None + #: Nested converters a composite wraps (e.g., ``SelectiveTextConverter``), + #: typed recursively. + sub_converters: list[ConverterIdentifier] = Field(default_factory=list) diff --git a/pyrit/models/identifiers/evaluation_identifier.py b/pyrit/models/identifiers/evaluation_identifier.py index fafea5f57c..03c6e713d3 100644 --- a/pyrit/models/identifiers/evaluation_identifier.py +++ b/pyrit/models/identifiers/evaluation_identifier.py @@ -364,7 +364,7 @@ class ObjectiveTargetEvaluationIdentifier(EvaluationIdentifier): ) -def compute_inner_attack_eval_hash(*, attack: AttackStrategy) -> str: +def compute_inner_attack_eval_hash(*, attack: AttackStrategy[Any, Any]) -> str: """ Predict the eval hash the executor will stamp on persisted child rows for this attack. @@ -380,7 +380,7 @@ def compute_inner_attack_eval_hash(*, attack: AttackStrategy) -> str: str: The eval hash that will appear on persisted child rows. """ # Local import avoids a circular dependency inside the identifiers package. - from pyrit.models.identifiers.atomic_attack_identifier import build_atomic_attack_identifier + from pyrit.models.identifiers.atomic_attack_identifier import AtomicAttackIdentifier - composite = build_atomic_attack_identifier(attack_identifier=attack.get_identifier()) + composite = AtomicAttackIdentifier.build(attack_identifier=attack.get_identifier()) return AtomicAttackEvaluationIdentifier(composite).eval_hash diff --git a/pyrit/models/identifiers/scorer_identifier.py b/pyrit/models/identifiers/scorer_identifier.py new file mode 100644 index 0000000000..c1d4a50b93 --- /dev/null +++ b/pyrit/models/identifiers/scorer_identifier.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Strongly-typed projection of a scorer's identifier.""" + +from __future__ import annotations + +from pydantic import Field + +from pyrit.models.identifiers.component_identifier import ComponentIdentifier +from pyrit.models.identifiers.target_identifier import ( # noqa: TC001 + TargetIdentifier, # runtime-required by Pydantic field annotations +) + + +class ScorerIdentifier(ComponentIdentifier): + """ + Strongly-typed projection of a ``Scorer``'s ``ComponentIdentifier``. + + Promotes the ``scorer_type`` discriminator, the ``score_aggregator`` name, and + the scorer's own child slots — ``prompt_target`` (an LLM target) and + ``sub_scorers`` (nested scorers). + """ + + #: The scorer category (e.g., ``"true_false"`` or ``"float_scale"``). + scorer_type: str | None = None + #: Name of the aggregator function combining sub-scores (e.g., ``"AND_"``). + score_aggregator: str | None = None + #: Target an LLM-backed scorer calls (e.g., ``SelfAskScaleScorer``). + prompt_target: TargetIdentifier | None = None + #: Nested scorers a composite wraps, typed recursively. + sub_scorers: list[ScorerIdentifier] = Field(default_factory=list) diff --git a/pyrit/models/identifiers/seed_identifier.py b/pyrit/models/identifiers/seed_identifier.py new file mode 100644 index 0000000000..e3e11af728 --- /dev/null +++ b/pyrit/models/identifiers/seed_identifier.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Strongly-typed projection of a seed's identifier.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from pyrit.models.identifiers.component_identifier import ComponentIdentifier +from pyrit.models.literals import PromptDataType # noqa: TC001 (runtime-required by Pydantic field annotations) + +if TYPE_CHECKING: + from pyrit.models.seeds.seed import Seed + + +class SeedIdentifier(ComponentIdentifier): + """ + Strongly-typed projection of a ``Seed``'s ``ComponentIdentifier``. + + Promotes the seed properties that define its identity: the raw value, its + SHA256, the originating dataset, the data type, and whether it is a general + technique. + """ + + #: The seed's raw value. + value: str | None = None + #: SHA256 of the seed value. + value_sha256: str | None = None + #: The seed's data type (e.g. ``"text"``, ``"image_path"``). + data_type: PromptDataType | None = None + #: Name of the dataset the seed came from. + dataset_name: str | None = None + #: Whether the seed represents a general (non-objective-specific) technique. + is_general_technique: bool | None = None + + @classmethod + def from_seed(cls, seed: Seed) -> SeedIdentifier: + """ + Build a SeedIdentifier from a seed's behavioral properties. + + Captures the seed's content hash, dataset name, and class type so that + different seeds produce different identifiers while the same seed content + always produces the same identifier. + + Args: + seed: The seed to build an identifier for. + + Returns: + An identifier capturing the seed's behavioral properties. + """ + return cls.of( + seed, + value=seed.value, + value_sha256=seed.value_sha256, + data_type=seed.data_type, + dataset_name=seed.dataset_name, + is_general_technique=seed.is_general_technique, + ) diff --git a/pyrit/models/identifiers/target_identifier.py b/pyrit/models/identifiers/target_identifier.py new file mode 100644 index 0000000000..970a561b81 --- /dev/null +++ b/pyrit/models/identifiers/target_identifier.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Strongly-typed projection of a target's identifier.""" + +from __future__ import annotations + +from pydantic import Field + +from pyrit.models.identifiers.component_identifier import ComponentIdentifier + + +class TargetIdentifier(ComponentIdentifier): + """ + Strongly-typed projection of a ``PromptTarget``'s ``ComponentIdentifier``. + + Promotes the common target params to typed fields; any other params stay in + ``params``. Capabilities are not part of identity and are not surfaced here. + + Promotes the one child slot a target owns in its own constructor: + ``targets`` (inner targets of a multi-target like ``RoundRobinTarget``), + typed recursively as ``TargetIdentifier``. + """ + + #: Target endpoint URL. + endpoint: str | None = None + #: Model or deployment name used in API calls. + model_name: str | None = None + #: Underlying model name if different (e.g., "gpt-4o"). + underlying_model_name: str | None = None + #: Temperature parameter for generation. + temperature: float | None = None + #: Top-p parameter for generation. + top_p: float | None = None + #: Maximum requests per minute. + max_requests_per_minute: int | None = None + #: Inner targets of a multi-target (e.g., ``RoundRobinTarget``), typed recursively. + targets: list[TargetIdentifier] = Field(default_factory=list) diff --git a/pyrit/models/target_capabilities.py b/pyrit/models/target_capabilities.py new file mode 100644 index 0000000000..2f0269292e --- /dev/null +++ b/pyrit/models/target_capabilities.py @@ -0,0 +1,116 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +``TargetCapabilities`` — a declarative description of what a target supports. + +This is canonical *data* (what modalities and behaviors a target natively +handles), so it lives in ``pyrit.models`` next to the other core models rather +than in the ``pyrit.prompt_target`` package. Handling concerns that depend on +the message-normalization machinery (``CapabilityHandlingPolicy``, +``UnsupportedCapabilityBehavior``) and the known-model capability profiles +(``get_known_capabilities``) stay in +``pyrit.prompt_target.common.target_capabilities``. + +Capabilities describe a target but are deliberately **not** part of identity: +they are not modeled on the typed identifier projections in +``pyrit.models.identifiers``. +""" + +from __future__ import annotations + +from enum import Enum +from typing import cast + +from pydantic import BaseModel, ConfigDict, Field + +from pyrit.models.literals import PromptDataType # noqa: TC001 (runtime-required by Pydantic field annotations) + +# Immutable text-only default shared by input/output modality fields. ``cast`` is used because +# ty infers the literal ``"text"`` as ``str`` (widening), which is not assignable to the invariant +# ``frozenset[frozenset[PromptDataType]]`` element type. +_DEFAULT_TEXT_MODALITIES: frozenset[frozenset[PromptDataType]] = cast( + "frozenset[frozenset[PromptDataType]]", frozenset({frozenset({"text"})}) +) + + +class CapabilityName(str, Enum): + """ + Canonical identifiers for target capabilities. + + This keeps capability identity in one place so policy, requirements, and + normalization code do not duplicate string field names. + """ + + MULTI_TURN = "supports_multi_turn" + MULTI_MESSAGE_PIECES = "supports_multi_message_pieces" + JSON_SCHEMA = "supports_json_schema" + JSON_OUTPUT = "supports_json_output" + EDITABLE_HISTORY = "supports_editable_history" + SYSTEM_PROMPT = "supports_system_prompt" + STREAMING_AUDIO = "supports_streaming_audio" + + +class TargetCapabilities(BaseModel): + """ + Describes the capabilities of a PromptTarget so that attacks + and other components can adapt their behavior accordingly. + + Each target class defines default capabilities via the _DEFAULT_CONFIGURATION + class attribute. Users can override individual capabilities per instance + through constructor parameters, which is useful for targets whose + capabilities depend on deployment configuration (e.g., Playwright, HTTP). + + Immutable (``frozen``) so a single capabilities object can be safely shared + across targets and reused as a known-model profile. + """ + + model_config = ConfigDict(frozen=True) + + #: Whether the target natively supports multi-turn conversations + #: (i.e., it accepts and uses conversation history or maintains state + #: across turns via external mechanisms like WebSocket connections). + supports_multi_turn: bool = False + + #: Whether the target natively supports multiple message pieces in a single request. + supports_multi_message_pieces: bool = False + + #: Whether the target natively supports constraining output to a provided JSON schema. + supports_json_schema: bool = False + + #: Whether the target natively supports JSON output (e.g., via a "json" response + #: format), which ensures the output is valid JSON. + supports_json_output: bool = False + + #: Whether the target allows the attack history to be modified. Implies that the + #: target supports multi-turn interactions and that the attack history is not + #: immutable once set. + supports_editable_history: bool = False + + #: Whether the target natively supports system prompts. + supports_system_prompt: bool = False + + #: Whether the target supports the streaming audio API: opening a long-lived + #: streaming session via ``open_streaming_session`` that pushes user audio chunks, + #: delivers VAD-committed audio to the attack for converter work, swaps committed + #: items in place, and drives manual ``response.create`` turns. Required by + #: ``BargeInAttack``. + supports_streaming_audio: bool = False + + #: The input modalities supported by the target (e.g., "text", "image"). + input_modalities: frozenset[frozenset[PromptDataType]] = Field(default=_DEFAULT_TEXT_MODALITIES) + + #: The output modalities supported by the target (e.g., "text", "image"). + output_modalities: frozenset[frozenset[PromptDataType]] = Field(default=_DEFAULT_TEXT_MODALITIES) + + def includes(self, *, capability: CapabilityName) -> bool: + """ + Return whether this target supports the given capability. + + Args: + capability: The capability to check. + + Returns: + bool: True if supported, otherwise False. + """ + return bool(getattr(self, capability.value)) diff --git a/pyrit/prompt_converter/image_prompt_style_converter.py b/pyrit/prompt_converter/image_prompt_style_converter.py index 25201d009c..6a6e0c0fef 100644 --- a/pyrit/prompt_converter/image_prompt_style_converter.py +++ b/pyrit/prompt_converter/image_prompt_style_converter.py @@ -140,7 +140,7 @@ def _build_identifier(self) -> ComponentIdentifier: "filter_name": self._filter_name, "variation": self._variation, }, - children={"converter_target": self._converter_target.get_identifier()}, + converter_target=self._converter_target.get_identifier(), ) async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text") -> ConverterResult: diff --git a/pyrit/prompt_converter/llm_generic_text_converter.py b/pyrit/prompt_converter/llm_generic_text_converter.py index 210e6a5ab9..1adee84967 100644 --- a/pyrit/prompt_converter/llm_generic_text_converter.py +++ b/pyrit/prompt_converter/llm_generic_text_converter.py @@ -129,7 +129,7 @@ def _build_identifier(self) -> ComponentIdentifier: "system_prompt_template_hash": system_prompt_hash, "user_prompt_template_hash": user_prompt_hash, }, - children={"converter_target": self._converter_target.get_identifier()}, + converter_target=self._converter_target.get_identifier(), ) async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text") -> ConverterResult: diff --git a/pyrit/prompt_converter/noise_converter.py b/pyrit/prompt_converter/noise_converter.py index 77e418e731..08715e3da3 100644 --- a/pyrit/prompt_converter/noise_converter.py +++ b/pyrit/prompt_converter/noise_converter.py @@ -74,5 +74,5 @@ def _build_identifier(self) -> ComponentIdentifier: "noise": self._noise, "number_errors": self._number_errors, }, - children={"converter_target": self._converter_target.get_identifier()}, + converter_target=self._converter_target.get_identifier(), ) diff --git a/pyrit/prompt_converter/persuasion_converter.py b/pyrit/prompt_converter/persuasion_converter.py index 8695bf32f1..1d52538124 100644 --- a/pyrit/prompt_converter/persuasion_converter.py +++ b/pyrit/prompt_converter/persuasion_converter.py @@ -94,7 +94,7 @@ def _build_identifier(self) -> ComponentIdentifier: params={ "persuasion_technique": self._persuasion_technique, }, - children={"converter_target": self.converter_target.get_identifier()}, + converter_target=self.converter_target.get_identifier(), ) def _process_response(self, response_text: str) -> str: diff --git a/pyrit/prompt_converter/prompt_converter.py b/pyrit/prompt_converter/prompt_converter.py index 282f2b01a3..4dc10901ad 100644 --- a/pyrit/prompt_converter/prompt_converter.py +++ b/pyrit/prompt_converter/prompt_converter.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Optional, get_args from pyrit import prompt_converter -from pyrit.models import ComponentIdentifier, Identifiable, PromptDataType +from pyrit.models import ComponentIdentifier, ConverterIdentifier, Identifiable, PromptDataType from pyrit.prompt_target.common.target_requirements import TargetRequirements if TYPE_CHECKING: @@ -202,14 +202,17 @@ def _create_identifier( self, *, params: dict[str, Any] | None = None, - children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None = None, + converter_target: ComponentIdentifier | None = None, + sub_converters: list[ComponentIdentifier] | None = None, ) -> ComponentIdentifier: """ Construct and return the converter identifier. - Builds a ComponentIdentifier with the base converter parameters - (supported_input_types, supported_output_types) and merges in any - additional params or children provided by subclasses. + Builds a ``ConverterIdentifier`` with the base converter params + (supported_input_types, supported_output_types) and the converter's promoted + child slots. The child slots are exposed as explicit named parameters + (mirroring ``ConverterIdentifier``'s promoted fields) so they cannot drift + into untyped ``children`` dicts. Subclasses should call this method in their _build_identifier() implementation to set the identifier with their specific parameters. @@ -217,20 +220,22 @@ def _create_identifier( Args: params (dict[str, Any] | None): Additional behavioral parameters from the subclass (e.g., font, encoding_func). Merged into the base params. - children (dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None): - Named child component identifiers (e.g., sub-converters, converter targets). + converter_target (ComponentIdentifier | None): The target an LLM-backed + converter calls, promoted to ``ConverterIdentifier.converter_target``. + sub_converters (list[ComponentIdentifier] | None): Nested converters a + composite wraps, promoted to ``ConverterIdentifier.sub_converters``. Returns: ComponentIdentifier: The identifier for this converter. """ - all_params: dict[str, Any] = { - "supported_input_types": self.SUPPORTED_INPUT_TYPES, - "supported_output_types": self.SUPPORTED_OUTPUT_TYPES, - } - if params: - all_params.update(params) - - return ComponentIdentifier.of(self, params=all_params, children=children) + return ConverterIdentifier.of( + self, + params=params, + supported_input_types=self.SUPPORTED_INPUT_TYPES, + supported_output_types=self.SUPPORTED_OUTPUT_TYPES, + converter_target=converter_target, + sub_converters=sub_converters, + ) @property def supported_input_types(self) -> list[PromptDataType]: diff --git a/pyrit/prompt_converter/scientific_translation_converter.py b/pyrit/prompt_converter/scientific_translation_converter.py index 88f688f166..3ac567e792 100644 --- a/pyrit/prompt_converter/scientific_translation_converter.py +++ b/pyrit/prompt_converter/scientific_translation_converter.py @@ -102,5 +102,5 @@ def _build_identifier(self) -> ComponentIdentifier: params={ "mode": self._mode, }, - children={"converter_target": self._converter_target.get_identifier()}, + converter_target=self._converter_target.get_identifier(), ) diff --git a/pyrit/prompt_converter/selective_text_converter.py b/pyrit/prompt_converter/selective_text_converter.py index 3899500566..e056fb9663 100644 --- a/pyrit/prompt_converter/selective_text_converter.py +++ b/pyrit/prompt_converter/selective_text_converter.py @@ -103,9 +103,7 @@ def _build_identifier(self) -> ComponentIdentifier: "start_token": self._start_token, "end_token": self._end_token, }, - children={ - "sub_converters": [self._converter.get_identifier()], - }, + sub_converters=[self._converter.get_identifier()], ) def _validate_converter( diff --git a/pyrit/prompt_converter/tense_converter.py b/pyrit/prompt_converter/tense_converter.py index 8f0852b2c2..c29268c1f0 100644 --- a/pyrit/prompt_converter/tense_converter.py +++ b/pyrit/prompt_converter/tense_converter.py @@ -62,5 +62,5 @@ def _build_identifier(self) -> ComponentIdentifier: params={ "tense": self._tense, }, - children={"converter_target": self._converter_target.get_identifier()}, + converter_target=self._converter_target.get_identifier(), ) diff --git a/pyrit/prompt_converter/tone_converter.py b/pyrit/prompt_converter/tone_converter.py index 562a4ee6af..9372f20657 100644 --- a/pyrit/prompt_converter/tone_converter.py +++ b/pyrit/prompt_converter/tone_converter.py @@ -65,5 +65,5 @@ def _build_identifier(self) -> ComponentIdentifier: params={ "tone": self._tone, }, - children={"converter_target": self._converter_target.get_identifier()}, + converter_target=self._converter_target.get_identifier(), ) diff --git a/pyrit/prompt_converter/translation_converter.py b/pyrit/prompt_converter/translation_converter.py index ec40f01e35..1da38f2c47 100644 --- a/pyrit/prompt_converter/translation_converter.py +++ b/pyrit/prompt_converter/translation_converter.py @@ -83,7 +83,7 @@ def _build_identifier(self) -> ComponentIdentifier: params={ "language": self.language, }, - children={"converter_target": self.converter_target.get_identifier()}, + converter_target=self.converter_target.get_identifier(), ) def _process_response(self, response_text: str) -> str: diff --git a/pyrit/prompt_converter/variation_converter.py b/pyrit/prompt_converter/variation_converter.py index 4e05f962b9..9f31130cb9 100644 --- a/pyrit/prompt_converter/variation_converter.py +++ b/pyrit/prompt_converter/variation_converter.py @@ -80,7 +80,7 @@ def _build_identifier(self) -> ComponentIdentifier: ComponentIdentifier: The identifier for this converter. """ return self._create_identifier( - children={"converter_target": self.converter_target.get_identifier()}, + converter_target=self.converter_target.get_identifier(), ) def _process_response(self, response_text: str) -> str: diff --git a/pyrit/prompt_target/__init__.py b/pyrit/prompt_target/__init__.py index 366b3f2f7f..e27cecebe1 100644 --- a/pyrit/prompt_target/__init__.py +++ b/pyrit/prompt_target/__init__.py @@ -25,6 +25,7 @@ CapabilityName, TargetCapabilities, UnsupportedCapabilityBehavior, + get_known_capabilities, ) from pyrit.prompt_target.common.target_configuration import TargetConfiguration from pyrit.prompt_target.common.target_requirements import CHAT_TARGET_REQUIREMENTS, TargetRequirements @@ -109,5 +110,6 @@ def __getattr__(name: str) -> object: "UnsupportedCapabilityBehavior", "TextTarget", "discover_target_capabilities_async", + "get_known_capabilities", "WebSocketCopilotTarget", ] diff --git a/pyrit/prompt_target/common/discover_target_capabilities.py b/pyrit/prompt_target/common/discover_target_capabilities.py index 3712f97000..b47a90a3c7 100644 --- a/pyrit/prompt_target/common/discover_target_capabilities.py +++ b/pyrit/prompt_target/common/discover_target_capabilities.py @@ -41,7 +41,6 @@ import uuid from collections.abc import Awaitable, Callable, Iterable, Iterator from contextlib import contextmanager -from dataclasses import replace from pathlib import Path from pyrit.common.path import DATASETS_PATH @@ -147,8 +146,7 @@ def _permissive_configuration( merged_modalities = original.capabilities.input_modalities | _TEXT_MODALITY if extra_input_modalities is not None: merged_modalities = frozenset(merged_modalities | frozenset(extra_input_modalities)) - permissive_caps = replace( - original.capabilities, + permissive_caps = TargetCapabilities( supports_multi_turn=True, supports_multi_message_pieces=True, supports_json_schema=True, @@ -157,6 +155,7 @@ def _permissive_configuration( supports_system_prompt=True, supports_streaming_audio=True, input_modalities=merged_modalities, + output_modalities=original.capabilities.output_modalities, ) # Rebuild a fresh configuration from the instance's native capabilities so # probes bypass preflight validation without inheriting ADAPT policy or diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 4bef9e3a26..65000a6139 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -7,9 +7,13 @@ from pyrit.common.deprecation import print_deprecation_message from pyrit.memory import CentralMemory, MemoryInterface -from pyrit.models import ComponentIdentifier, Conversation, Identifiable, Message, MessagePiece +from pyrit.models import ComponentIdentifier, Conversation, Identifiable, Message, MessagePiece, TargetIdentifier from pyrit.prompt_target.common.json_response_config import _JsonResponseConfig -from pyrit.prompt_target.common.target_capabilities import CapabilityName, TargetCapabilities +from pyrit.prompt_target.common.target_capabilities import ( + CapabilityName, + TargetCapabilities, + get_known_capabilities, +) from pyrit.prompt_target.common.target_configuration import TargetConfiguration logger = logging.getLogger(__name__) @@ -365,14 +369,16 @@ def _create_identifier( self, *, params: dict[str, Any] | None = None, - children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None = None, + targets: list[ComponentIdentifier] | None = None, ) -> ComponentIdentifier: """ Construct the target identifier. - Builds a ComponentIdentifier with the base target parameters (endpoint, - model_name, max_requests_per_minute) and merges in any additional params - or children provided by subclasses. + Builds a ``TargetIdentifier`` with the base target params (endpoint, + model_name, max_requests_per_minute) and the target's promoted child slot. + The child slot is exposed as an explicit named parameter (mirroring + ``TargetIdentifier``'s promoted field) so it cannot drift into an untyped + ``children`` dict. Subclasses should call this method in their _build_identifier() implementation to set the identifier with their specific parameters. @@ -380,23 +386,22 @@ def _create_identifier( Args: params (dict[str, Any] | None): Additional behavioral parameters from the subclass (e.g., temperature, top_p). Merged into the base params. - children (dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None): - Named child component identifiers. + targets (list[ComponentIdentifier] | None): Inner targets of a + multi-target (e.g., ``RoundRobinTarget``), promoted to + ``TargetIdentifier.targets``. Returns: ComponentIdentifier: The identifier for this prompt target. """ - all_params: dict[str, Any] = { - "endpoint": self._endpoint, - "model_name": self._model_name or "", - "underlying_model_name": self._underlying_model or "", - "max_requests_per_minute": self._max_requests_per_minute, - "target_configuration": self.configuration.as_identifier_params(), - } - if params: - all_params.update(params) - - return ComponentIdentifier.of(self, params=all_params, children=children) + return TargetIdentifier.of( + self, + params=params, + endpoint=self._endpoint, + model_name=self._model_name or "", + underlying_model_name=self._underlying_model or "", + max_requests_per_minute=self._max_requests_per_minute, + targets=targets, + ) @property def configuration(self) -> TargetConfiguration: @@ -463,7 +468,7 @@ def get_default_configuration(cls, underlying_model: str | None = None) -> Targe ``_DEFAULT_CONFIGURATION`` if the model is unrecognized or not provided. """ if underlying_model: - known = TargetCapabilities.get_known_capabilities(underlying_model) + known = get_known_capabilities(underlying_model) if known is not None: return TargetConfiguration(capabilities=known) logger.info( diff --git a/pyrit/prompt_target/common/target_capabilities.py b/pyrit/prompt_target/common/target_capabilities.py index c030b8b053..10c1ec0ce2 100644 --- a/pyrit/prompt_target/common/target_capabilities.py +++ b/pyrit/prompt_target/common/target_capabilities.py @@ -7,24 +7,16 @@ from types import MappingProxyType from typing import NoReturn, cast -from pyrit.models import PromptDataType +from pyrit.models.literals import PromptDataType +from pyrit.models.target_capabilities import CapabilityName, TargetCapabilities - -class CapabilityName(str, Enum): - """ - Canonical identifiers for target capabilities. - - This keeps capability identity in one place so policy, requirements, and - normalization code do not duplicate string field names. - """ - - MULTI_TURN = "supports_multi_turn" - MULTI_MESSAGE_PIECES = "supports_multi_message_pieces" - JSON_SCHEMA = "supports_json_schema" - JSON_OUTPUT = "supports_json_output" - EDITABLE_HISTORY = "supports_editable_history" - SYSTEM_PROMPT = "supports_system_prompt" - STREAMING_AUDIO = "supports_streaming_audio" +__all__ = [ + "CapabilityHandlingPolicy", + "CapabilityName", + "TargetCapabilities", + "UnsupportedCapabilityBehavior", + "get_known_capabilities", +] class UnsupportedCapabilityBehavior(str, Enum): @@ -106,80 +98,6 @@ def __post_init__(self) -> None: object.__setattr__(self, "behaviors", MappingProxyType(dict(self.behaviors))) -@dataclass(frozen=True) -class TargetCapabilities: - """ - Describes the capabilities of a PromptTarget so that attacks - and other components can adapt their behavior accordingly. - - Each target class defines default capabilities via the _DEFAULT_CONFIGURATION - class attribute. Users can override individual capabilities per instance - through constructor parameters, which is useful for targets whose - capabilities depend on deployment configuration (e.g., Playwright, HTTP). - """ - - # Whether the target natively supports multi-turn conversations - # (i.e., it accepts and uses conversation history or maintains state - # across turns via external mechanisms like WebSocket connections). - supports_multi_turn: bool = False - - # Whether the target natively supports multiple message pieces in a single request. - supports_multi_message_pieces: bool = False - - # Whether the target natively supports constraining output to a provided JSON schema. - supports_json_schema: bool = False - - # Whether the target natively supports JSON output (e.g., via a "json" response format), which ensures the output - # is valid JSON. - supports_json_output: bool = False - - # Whether the target allows the attack history to be modified. Implies that the target supports - # multi-turn interactions and that the attack history is not immutable once set. - supports_editable_history: bool = False - - # Whether the target natively supports system prompts. - supports_system_prompt: bool = False - - # Whether the target supports the streaming audio API: opening a long-lived - # streaming session via ``open_streaming_session`` that pushes user audio chunks, - # delivers VAD-committed audio to the attack for converter work, swaps committed - # items in place, and drives manual ``response.create`` turns. Required by - # ``BargeInAttack``. - supports_streaming_audio: bool = False - - # The input modalities supported by the target (e.g., "text", "image"). - input_modalities: frozenset[frozenset[PromptDataType]] = frozenset({frozenset(["text"])}) - - # The output modalities supported by the target (e.g., "text", "image"). - output_modalities: frozenset[frozenset[PromptDataType]] = frozenset({frozenset(["text"])}) - - def includes(self, *, capability: CapabilityName) -> bool: - """ - Return whether this target supports the given capability. - - Args: - capability: The capability to check. - - Returns: - bool: True if supported, otherwise False. - """ - return bool(getattr(self, capability.value)) - - @staticmethod - def get_known_capabilities(underlying_model: str) -> "TargetCapabilities | None": - """ - Return the known capabilities for a specific underlying model, or None if unrecognized. - - Args: - underlying_model (str): The underlying model name (e.g., "gpt-4o"). - - Returns: - TargetCapabilities | None: The known capabilities for the model, or None if the model - is not recognized. - """ - return _KNOWN_CAPABILITIES.get(underlying_model) - - # --------------------------------------------------------------------------- # Known capability profiles — add new models here. # Shared profiles are defined once and referenced by multiple model names. @@ -259,3 +177,17 @@ def get_known_capabilities(underlying_model: str) -> "TargetCapabilities | None" "tts": _TTS, "sora-2": _SORA_2, } + + +def get_known_capabilities(underlying_model: str) -> TargetCapabilities | None: + """ + Return the known capabilities for a specific underlying model, or None if unrecognized. + + Args: + underlying_model (str): The underlying model name (e.g., "gpt-4o"). + + Returns: + TargetCapabilities | None: The known capabilities for the model, or None if the model + is not recognized. + """ + return _KNOWN_CAPABILITIES.get(underlying_model) diff --git a/pyrit/prompt_target/common/target_configuration.py b/pyrit/prompt_target/common/target_configuration.py index 7e11a04673..6613b9dbb0 100644 --- a/pyrit/prompt_target/common/target_configuration.py +++ b/pyrit/prompt_target/common/target_configuration.py @@ -3,7 +3,6 @@ import logging from collections.abc import Mapping -from dataclasses import fields from typing import Any from pyrit.message_normalizer import MessageListNormalizer @@ -177,7 +176,7 @@ def _capabilities_to_identifier_params(capabilities: TargetCapabilities) -> dict Project a ``TargetCapabilities`` instance into a deterministic dict suitable for inclusion in a ``ComponentIdentifier``. - Fields are discovered dynamically via ``dataclasses.fields`` so new + Fields are discovered dynamically via the pydantic model fields so new capability fields are picked up automatically. Set-valued fields (e.g., the modality frozensets) are detected by type and normalized to sorted lists of sorted lists; all other fields are passed through as-is. @@ -189,15 +188,15 @@ def _capabilities_to_identifier_params(capabilities: TargetCapabilities) -> dict dict[str, Any]: Field-name to serialized-value mapping. """ params: dict[str, Any] = {} - for dataclass_field in fields(capabilities): - value = getattr(capabilities, dataclass_field.name) + for field_name in type(capabilities).model_fields: + value = getattr(capabilities, field_name) # Normalize set-valued fields (e.g., modality frozensets) to a # deterministic representation. Handles both frozenset[frozenset[...]] # (modality combinations) and plain frozensets. if isinstance(value, (frozenset, set)): - params[dataclass_field.name] = sorted( + params[field_name] = sorted( sorted(item) if isinstance(item, (frozenset, set)) else item for item in value ) else: - params[dataclass_field.name] = value + params[field_name] = value return params diff --git a/pyrit/prompt_target/round_robin_target.py b/pyrit/prompt_target/round_robin_target.py index 584e963c41..3086601d72 100644 --- a/pyrit/prompt_target/round_robin_target.py +++ b/pyrit/prompt_target/round_robin_target.py @@ -190,7 +190,7 @@ def _build_identifier(self) -> ComponentIdentifier: """ return self._create_identifier( params={"weights": self._weights}, - children={"targets": [t.get_identifier() for t in self._targets]}, + targets=[t.get_identifier() for t in self._targets], ) diff --git a/pyrit/scenario/core/atomic_attack.py b/pyrit/scenario/core/atomic_attack.py index 6a48ec69ef..c35c521da6 100644 --- a/pyrit/scenario/core/atomic_attack.py +++ b/pyrit/scenario/core/atomic_attack.py @@ -23,7 +23,7 @@ from pyrit.executor.attack.core.attack_result_attribution import AttackResultAttribution from pyrit.memory import CentralMemory from pyrit.memory.memory_models import MAX_IDENTIFIER_VALUE_LENGTH -from pyrit.models import AtomicAttackEvaluationIdentifier, AttackResult, SeedAttackGroup, build_atomic_attack_identifier +from pyrit.models import AtomicAttackEvaluationIdentifier, AtomicAttackIdentifier, AttackResult, SeedAttackGroup from pyrit.scenario.core.attack_technique import AttackTechnique if TYPE_CHECKING: @@ -203,7 +203,7 @@ def technique_eval_hash(self) -> str: which is what makes it usable as the resume disambiguator alongside ``atomic_attack_name``. """ - composite = build_atomic_attack_identifier( + composite = AtomicAttackIdentifier.build( technique_identifier=self._attack_technique.get_identifier(), seed_group=None, ) @@ -433,7 +433,7 @@ def _enrich_atomic_attack_identifiers(self, *, results: AttackExecutorResult[Att for result, idx in zip(results.completed_results, results.input_indices, strict=True): if idx < len(self._seed_groups): - identifier = build_atomic_attack_identifier( + identifier = AtomicAttackIdentifier.build( technique_identifier=self._attack_technique.get_identifier(), seed_group=self._seed_groups[idx], ) diff --git a/pyrit/scenario/core/attack_technique.py b/pyrit/scenario/core/attack_technique.py index 019335a5da..58525242dc 100644 --- a/pyrit/scenario/core/attack_technique.py +++ b/pyrit/scenario/core/attack_technique.py @@ -11,7 +11,12 @@ from typing import TYPE_CHECKING, Any -from pyrit.models import ComponentIdentifier, Identifiable, build_seed_identifier +from pyrit.models import ( + AttackTechniqueIdentifier, + ComponentIdentifier, + Identifiable, + SeedIdentifier, +) if TYPE_CHECKING: from pyrit.executor.attack import AttackStrategy @@ -59,13 +64,14 @@ def _build_identifier(self) -> ComponentIdentifier: Returns: ComponentIdentifier: The frozen identity snapshot. """ - children: dict[str, Any] = { - "attack": self._attack.get_identifier(), - } - + technique_seeds: list[SeedIdentifier] | None = None if self._seed_technique is not None: - technique_seed_ids = [build_seed_identifier(seed) for seed in self._seed_technique.seeds] + technique_seed_ids = [SeedIdentifier.from_seed(seed) for seed in self._seed_technique.seeds] if technique_seed_ids: - children["technique_seeds"] = technique_seed_ids + technique_seeds = list(technique_seed_ids) - return ComponentIdentifier.of(self, children=children) + return AttackTechniqueIdentifier.of( + self, + attack=self._attack.get_identifier(), + technique_seeds=technique_seeds, + ) diff --git a/pyrit/scenario/core/attack_technique_factory.py b/pyrit/scenario/core/attack_technique_factory.py index 000e296e89..cd9da5a3c8 100644 --- a/pyrit/scenario/core/attack_technique_factory.py +++ b/pyrit/scenario/core/attack_technique_factory.py @@ -37,9 +37,9 @@ ComponentIdentifier, Identifiable, SeedAttackTechniqueGroup, + SeedIdentifier, SeedPrompt, SeedSimulatedConversation, - build_seed_identifier, ) from pyrit.models.seeds.seed_simulated_conversation import NextMessageSystemPromptPaths from pyrit.scenario.core.attack_technique import AttackTechnique @@ -758,7 +758,7 @@ def _build_identifier(self) -> ComponentIdentifier: children: dict[str, Any] = {} if self._seed_technique is not None: - technique_seed_ids = [build_seed_identifier(seed) for seed in self._seed_technique.seeds] + technique_seed_ids = [SeedIdentifier.from_seed(seed) for seed in self._seed_technique.seeds] if technique_seed_ids: children["technique_seeds"] = technique_seed_ids diff --git a/pyrit/score/conversation_scorer.py b/pyrit/score/conversation_scorer.py index e22b563096..0e377acd4f 100644 --- a/pyrit/score/conversation_scorer.py +++ b/pyrit/score/conversation_scorer.py @@ -219,9 +219,7 @@ def _build_identifier(self) -> ComponentIdentifier: ComponentIdentifier: The identifier for this scorer. """ return self._create_identifier( - children={ - "sub_scorers": [self._wrapped_scorer.get_identifier()], - }, + sub_scorers=[self._wrapped_scorer.get_identifier()], ) return DynamicConversationScorer() diff --git a/pyrit/score/float_scale/audio_float_scale_scorer.py b/pyrit/score/float_scale/audio_float_scale_scorer.py index d2c216050e..183c1379b0 100644 --- a/pyrit/score/float_scale/audio_float_scale_scorer.py +++ b/pyrit/score/float_scale/audio_float_scale_scorer.py @@ -48,9 +48,7 @@ def _build_identifier(self) -> ComponentIdentifier: ComponentIdentifier: The identifier for this scorer. """ return self._create_identifier( - children={ - "sub_scorers": [self._audio_helper.text_scorer.get_identifier()], - }, + sub_scorers=[self._audio_helper.text_scorer.get_identifier()], ) async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: diff --git a/pyrit/score/float_scale/insecure_code_scorer.py b/pyrit/score/float_scale/insecure_code_scorer.py index f46b635110..5340be8fd2 100644 --- a/pyrit/score/float_scale/insecure_code_scorer.py +++ b/pyrit/score/float_scale/insecure_code_scorer.py @@ -66,9 +66,7 @@ def _build_identifier(self) -> ComponentIdentifier: params={ "system_prompt_template": self._system_prompt, }, - children={ - "prompt_target": self._prompt_target.get_identifier(), - }, + prompt_target=self._prompt_target.get_identifier(), ) async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: diff --git a/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py b/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py index 9631f944a9..f99d48c0f7 100644 --- a/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py @@ -107,9 +107,7 @@ def _build_identifier(self) -> ComponentIdentifier: "min_value": self._min_value, "max_value": self._max_value, }, - children={ - "prompt_target": self._prompt_target.get_identifier(), - }, + prompt_target=self._prompt_target.get_identifier(), ) async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: diff --git a/pyrit/score/float_scale/self_ask_likert_scorer.py b/pyrit/score/float_scale/self_ask_likert_scorer.py index 750a86e7c6..a82157c7ab 100644 --- a/pyrit/score/float_scale/self_ask_likert_scorer.py +++ b/pyrit/score/float_scale/self_ask_likert_scorer.py @@ -244,9 +244,7 @@ def _build_identifier(self) -> ComponentIdentifier: params={ "system_prompt_template": self._system_prompt, }, - children={ - "prompt_target": self._prompt_target.get_identifier(), - }, + prompt_target=self._prompt_target.get_identifier(), ) def _set_likert_scale_system_prompt(self, likert_scale_path: Path) -> None: diff --git a/pyrit/score/float_scale/self_ask_scale_scorer.py b/pyrit/score/float_scale/self_ask_scale_scorer.py index 87e8e73b51..c1e2305656 100644 --- a/pyrit/score/float_scale/self_ask_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_scale_scorer.py @@ -96,9 +96,7 @@ def _build_identifier(self) -> ComponentIdentifier: "system_prompt_template": self._system_prompt, "user_prompt_template": "objective: {objective}\nresponse: {response}", }, - children={ - "prompt_target": self._prompt_target.get_identifier(), - }, + prompt_target=self._prompt_target.get_identifier(), ) async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: diff --git a/pyrit/score/float_scale/video_float_scale_scorer.py b/pyrit/score/float_scale/video_float_scale_scorer.py index 8e32bd9064..b2a2c62d83 100644 --- a/pyrit/score/float_scale/video_float_scale_scorer.py +++ b/pyrit/score/float_scale/video_float_scale_scorer.py @@ -105,15 +105,13 @@ def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier( params={ - "score_aggregator": self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] "num_sampled_frames": self._video_helper.num_sampled_frames, "has_audio_scorer": self.audio_scorer is not None, "image_objective_template": self._video_helper.image_objective_template, "audio_objective_template": self._video_helper.audio_objective_template, }, - children={ - "sub_scorers": sub_scorer_ids, - }, + score_aggregator=self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] + sub_scorers=sub_scorer_ids, ) async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index ee72a2970d..ee0dc4a9fc 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -34,6 +34,7 @@ PromptDataType, Score, ScorerEvaluationIdentifier, + ScorerIdentifier, ScoreType, UnvalidatedScore, ) @@ -166,34 +167,43 @@ def _create_identifier( self, *, params: dict[str, Any] | None = None, - children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None = None, + score_aggregator: str | None = None, + prompt_target: ComponentIdentifier | None = None, + sub_scorers: list[ComponentIdentifier] | None = None, ) -> ComponentIdentifier: """ Construct the scorer identifier. - Builds a ComponentIdentifier with the base scorer parameters (scorer_type) - and merges in any additional params or children provided by subclasses. + Builds a ``ScorerIdentifier`` with the base scorer ``scorer_type`` and + the scorer's promoted params/child slots. The promoted fields are exposed + as explicit named parameters (mirroring ``ScorerIdentifier``'s fields) so + they cannot drift into untyped ``params`` / ``children`` dicts. Subclasses should call this method in their _build_identifier() implementation to set the identifier with their specific parameters. Args: params (dict[str, Any] | None): Additional behavioral parameters from - the subclass (e.g., system_prompt_template, score_aggregator). Merged - into the base params. - children (dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None): - Named child component identifiers (e.g., prompt_target, sub_scorers). + the subclass (e.g., system_prompt_template, threshold). Merged into + the base params. + score_aggregator (str | None): Name of the aggregator function that + combines sub-scores, promoted to ``ScorerIdentifier.score_aggregator``. + prompt_target (ComponentIdentifier | None): The target an LLM-backed + scorer calls, promoted to ``ScorerIdentifier.prompt_target``. + sub_scorers (list[ComponentIdentifier] | None): Nested scorers a + composite wraps, promoted to ``ScorerIdentifier.sub_scorers``. Returns: ComponentIdentifier: The identifier for this scorer. """ - all_params: dict[str, Any] = { - "scorer_type": self.scorer_type, - } - if params: - all_params.update(params) - - return ComponentIdentifier.of(self, params=all_params, children=children) + return ScorerIdentifier.of( + self, + params=params, + scorer_type=self.scorer_type, + score_aggregator=score_aggregator, + prompt_target=prompt_target, + sub_scorers=sub_scorers, + ) async def score_async( self, diff --git a/pyrit/score/true_false/audio_true_false_scorer.py b/pyrit/score/true_false/audio_true_false_scorer.py index c9caa023f3..eb291e6597 100644 --- a/pyrit/score/true_false/audio_true_false_scorer.py +++ b/pyrit/score/true_false/audio_true_false_scorer.py @@ -48,9 +48,7 @@ def _build_identifier(self) -> ComponentIdentifier: ComponentIdentifier: The identifier for this scorer. """ return self._create_identifier( - children={ - "sub_scorers": [self._audio_helper.text_scorer.get_identifier()], - }, + sub_scorers=[self._audio_helper.text_scorer.get_identifier()], ) async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: diff --git a/pyrit/score/true_false/decoding_scorer.py b/pyrit/score/true_false/decoding_scorer.py index ec17af03fc..d6d936c19d 100644 --- a/pyrit/score/true_false/decoding_scorer.py +++ b/pyrit/score/true_false/decoding_scorer.py @@ -59,9 +59,9 @@ def _build_identifier(self) -> ComponentIdentifier: """ return self._create_identifier( params={ - "score_aggregator": self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] "text_matcher": self._text_matcher.__class__.__name__, }, + score_aggregator=self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] ) async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: diff --git a/pyrit/score/true_false/float_scale_threshold_scorer.py b/pyrit/score/true_false/float_scale_threshold_scorer.py index 828b98a9dd..c329c80688 100644 --- a/pyrit/score/true_false/float_scale_threshold_scorer.py +++ b/pyrit/score/true_false/float_scale_threshold_scorer.py @@ -66,13 +66,11 @@ def _build_identifier(self) -> ComponentIdentifier: """ return self._create_identifier( params={ - "score_aggregator": self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] "threshold": self._threshold, "float_scale_aggregator": self._float_scale_aggregator.__name__, # type: ignore[ty:unresolved-attribute] }, - children={ - "sub_scorers": [self._scorer.get_identifier()], - }, + score_aggregator=self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] + sub_scorers=[self._scorer.get_identifier()], ) def get_chat_target(self) -> Optional["PromptTarget"]: diff --git a/pyrit/score/true_false/gandalf_scorer.py b/pyrit/score/true_false/gandalf_scorer.py index 1064aaa28e..bb1340839e 100644 --- a/pyrit/score/true_false/gandalf_scorer.py +++ b/pyrit/score/true_false/gandalf_scorer.py @@ -67,12 +67,8 @@ def _build_identifier(self) -> ComponentIdentifier: ComponentIdentifier: The identifier for this scorer. """ return self._create_identifier( - params={ - "score_aggregator": self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] - }, - children={ - "prompt_target": self._prompt_target.get_identifier(), - }, + score_aggregator=self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] + prompt_target=self._prompt_target.get_identifier(), ) @pyrit_target_retry diff --git a/pyrit/score/true_false/markdown_injection.py b/pyrit/score/true_false/markdown_injection.py index 33a678469f..3bf773cf15 100644 --- a/pyrit/score/true_false/markdown_injection.py +++ b/pyrit/score/true_false/markdown_injection.py @@ -49,9 +49,7 @@ def _build_identifier(self) -> ComponentIdentifier: ComponentIdentifier: The identifier for this scorer. """ return self._create_identifier( - params={ - "score_aggregator": self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] - }, + score_aggregator=self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] ) async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: diff --git a/pyrit/score/true_false/prompt_shield_scorer.py b/pyrit/score/true_false/prompt_shield_scorer.py index d6ac555610..85bc32c183 100644 --- a/pyrit/score/true_false/prompt_shield_scorer.py +++ b/pyrit/score/true_false/prompt_shield_scorer.py @@ -56,12 +56,8 @@ def _build_identifier(self) -> ComponentIdentifier: ComponentIdentifier: The identifier for this scorer. """ return self._create_identifier( - params={ - "score_aggregator": self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] - }, - children={ - "prompt_target": self._prompt_target.get_identifier(), - }, + score_aggregator=self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] + prompt_target=self._prompt_target.get_identifier(), ) async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: diff --git a/pyrit/score/true_false/question_answer_scorer.py b/pyrit/score/true_false/question_answer_scorer.py index 46e7b94527..1f2604ee35 100644 --- a/pyrit/score/true_false/question_answer_scorer.py +++ b/pyrit/score/true_false/question_answer_scorer.py @@ -62,9 +62,9 @@ def _build_identifier(self) -> ComponentIdentifier: """ return self._create_identifier( params={ - "score_aggregator": self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] "correct_answer_matching_patterns": self._correct_answer_matching_patterns, }, + score_aggregator=self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] ) async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: diff --git a/pyrit/score/true_false/regex/regex_scorer.py b/pyrit/score/true_false/regex/regex_scorer.py index f0f6e47f98..f01ce2296d 100644 --- a/pyrit/score/true_false/regex/regex_scorer.py +++ b/pyrit/score/true_false/regex/regex_scorer.py @@ -47,7 +47,9 @@ def __init__( raise ValueError("patterns must be a non-empty dict") self._patterns = dict(patterns) - self._compiled: dict[str, re.Pattern] = {name: re.compile(pattern) for name, pattern in self._patterns.items()} + self._compiled: dict[str, re.Pattern[str]] = { + name: re.compile(pattern) for name, pattern in self._patterns.items() + } self._score_categories = categories or [] super().__init__(validator=validator or self._DEFAULT_VALIDATOR, score_aggregator=score_aggregator) @@ -61,9 +63,9 @@ def _build_identifier(self) -> ComponentIdentifier: """ return self._create_identifier( params={ - "score_aggregator": self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] "pattern_count": len(self._patterns), }, + score_aggregator=self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] ) async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: diff --git a/pyrit/score/true_false/self_ask_category_scorer.py b/pyrit/score/true_false/self_ask_category_scorer.py index 9c526deff4..e84b7d3c00 100644 --- a/pyrit/score/true_false/self_ask_category_scorer.py +++ b/pyrit/score/true_false/self_ask_category_scorer.py @@ -91,11 +91,9 @@ def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier( params={ "system_prompt_template": self._system_prompt, - "score_aggregator": self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] - }, - children={ - "prompt_target": self._prompt_target.get_identifier(), }, + score_aggregator=self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] + prompt_target=self._prompt_target.get_identifier(), ) def _content_classifier_to_string(self, categories: list[dict[str, str]]) -> str: diff --git a/pyrit/score/true_false/self_ask_general_true_false_scorer.py b/pyrit/score/true_false/self_ask_general_true_false_scorer.py index f706efbcbe..8a1ff56d40 100644 --- a/pyrit/score/true_false/self_ask_general_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_general_true_false_scorer.py @@ -105,11 +105,9 @@ def _build_identifier(self) -> ComponentIdentifier: params={ "system_prompt_template": self._system_prompt_format_string, "user_prompt_template": self._prompt_format_string, - "score_aggregator": self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] - }, - children={ - "prompt_target": self._prompt_target.get_identifier(), }, + score_aggregator=self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] + prompt_target=self._prompt_target.get_identifier(), ) async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: diff --git a/pyrit/score/true_false/self_ask_refusal_scorer.py b/pyrit/score/true_false/self_ask_refusal_scorer.py index 04cc555320..d447b0275b 100644 --- a/pyrit/score/true_false/self_ask_refusal_scorer.py +++ b/pyrit/score/true_false/self_ask_refusal_scorer.py @@ -136,11 +136,9 @@ def _build_identifier(self) -> ComponentIdentifier: "system_prompt_template": self._system_prompt, "user_prompt_template": self._prompt_format_string, "response_json_schema": self._response_json_schema, - "score_aggregator": self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] - }, - children={ - "prompt_target": self._prompt_target.get_identifier(), }, + score_aggregator=self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] + prompt_target=self._prompt_target.get_identifier(), ) async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: diff --git a/pyrit/score/true_false/self_ask_true_false_scorer.py b/pyrit/score/true_false/self_ask_true_false_scorer.py index e526e84f94..27228ae69c 100644 --- a/pyrit/score/true_false/self_ask_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_true_false_scorer.py @@ -188,11 +188,9 @@ def _build_identifier(self) -> ComponentIdentifier: params={ "system_prompt_template": self._system_prompt, "user_prompt_template": "objective: {objective}\nresponse: {response}", - "score_aggregator": self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] - }, - children={ - "prompt_target": self._prompt_target.get_identifier(), }, + score_aggregator=self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] + prompt_target=self._prompt_target.get_identifier(), ) async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: diff --git a/pyrit/score/true_false/substring_scorer.py b/pyrit/score/true_false/substring_scorer.py index 4429930e50..4d733bff05 100644 --- a/pyrit/score/true_false/substring_scorer.py +++ b/pyrit/score/true_false/substring_scorer.py @@ -58,10 +58,10 @@ def _build_identifier(self) -> ComponentIdentifier: """ return self._create_identifier( params={ - "score_aggregator": self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] "substring": self._substring, "text_matcher": self._text_matcher.__class__.__name__, }, + score_aggregator=self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] ) async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: diff --git a/pyrit/score/true_false/true_false_composite_scorer.py b/pyrit/score/true_false/true_false_composite_scorer.py index 0fece73d64..fd3b6fa53e 100644 --- a/pyrit/score/true_false/true_false_composite_scorer.py +++ b/pyrit/score/true_false/true_false_composite_scorer.py @@ -63,12 +63,8 @@ def _build_identifier(self) -> ComponentIdentifier: ComponentIdentifier: The identifier for this scorer. """ return self._create_identifier( - params={ - "score_aggregator": self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] - }, - children={ - "sub_scorers": [s.get_identifier() for s in self._scorers], - }, + score_aggregator=self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] + sub_scorers=[s.get_identifier() for s in self._scorers], ) def get_chat_target(self) -> Optional["PromptTarget"]: diff --git a/pyrit/score/true_false/true_false_inverter_scorer.py b/pyrit/score/true_false/true_false_inverter_scorer.py index c3b894edda..e98b8791fb 100644 --- a/pyrit/score/true_false/true_false_inverter_scorer.py +++ b/pyrit/score/true_false/true_false_inverter_scorer.py @@ -41,12 +41,8 @@ def _build_identifier(self) -> ComponentIdentifier: ComponentIdentifier: The identifier for this scorer. """ return self._create_identifier( - params={ - "score_aggregator": self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] - }, - children={ - "sub_scorers": [self._scorer.get_identifier()], - }, + score_aggregator=self._score_aggregator.__name__, # type: ignore[ty:unresolved-attribute] + sub_scorers=[self._scorer.get_identifier()], ) def get_chat_target(self) -> Optional["PromptTarget"]: diff --git a/pyrit/score/true_false/video_true_false_scorer.py b/pyrit/score/true_false/video_true_false_scorer.py index 5c45eae477..752532a984 100644 --- a/pyrit/score/true_false/video_true_false_scorer.py +++ b/pyrit/score/true_false/video_true_false_scorer.py @@ -88,9 +88,7 @@ def _build_identifier(self) -> ComponentIdentifier: "image_objective_template": self._video_helper.image_objective_template, "audio_objective_template": self._video_helper.audio_objective_template, }, - children={ - "sub_scorers": sub_scorer_ids, - }, + sub_scorers=sub_scorer_ids, ) async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: diff --git a/tests/integration/memory/test_azure_sql_memory_integration.py b/tests/integration/memory/test_azure_sql_memory_integration.py index 7a35452e98..c80b601f40 100644 --- a/tests/integration/memory/test_azure_sql_memory_integration.py +++ b/tests/integration/memory/test_azure_sql_memory_integration.py @@ -16,6 +16,7 @@ ScenarioResultEntry, ) from pyrit.models import ( + AtomicAttackIdentifier, AttackOutcome, AttackResult, ComponentIdentifier, @@ -23,7 +24,6 @@ ScenarioIdentifier, ScenarioResult, SeedPrompt, - build_atomic_attack_identifier, ) @@ -48,7 +48,7 @@ def get_test_atomic_attack_identifier() -> ComponentIdentifier: class_name="TestAttack", class_module="tests.integration.memory.test_azure_sql_memory_integration", ) - return build_atomic_attack_identifier(attack_identifier=attack_strategy) + return AtomicAttackIdentifier.build(attack_identifier=attack_strategy) def get_test_scorer_identifier(**kwargs) -> ComponentIdentifier: diff --git a/tests/unit/analytics/test_result_analysis.py b/tests/unit/analytics/test_result_analysis.py index 6af17d5c8c..b07b84f567 100644 --- a/tests/unit/analytics/test_result_analysis.py +++ b/tests/unit/analytics/test_result_analysis.py @@ -14,13 +14,13 @@ ) from pyrit.memory.memory_interface import MemoryInterface from pyrit.models import ( + AtomicAttackIdentifier, AttackOutcome, AttackResult, ComponentIdentifier, IdentifierFilter, IdentifierType, ObjectiveTargetEvaluationIdentifier, - build_atomic_attack_identifier, ) @@ -36,7 +36,7 @@ def make_attack( atomic_attack_identifier: ComponentIdentifier | None = None if attack_type is not None: attack_identifier = ComponentIdentifier(class_name=attack_type, class_module="tests.unit.analytics") - atomic_attack_identifier = build_atomic_attack_identifier(attack_identifier=attack_identifier) + atomic_attack_identifier = AtomicAttackIdentifier.build(attack_identifier=attack_identifier) return AttackResult( conversation_id=conversation_id, diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index 2749c6fd67..c2310582fe 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -27,12 +27,12 @@ get_attack_service, ) from pyrit.models import ( + AtomicAttackIdentifier, AttackOutcome, AttackResult, ComponentIdentifier, Message, MessagePiece, - build_atomic_attack_identifier, ) from pyrit.models.conversation_stats import ConversationStats @@ -91,7 +91,7 @@ def make_attack_result( return AttackResult( conversation_id=conversation_id, objective=objective, - atomic_attack_identifier=build_atomic_attack_identifier( + atomic_attack_identifier=AtomicAttackIdentifier.build( attack_identifier=ComponentIdentifier( class_name=name, class_module="pyrit.backend", @@ -286,7 +286,7 @@ async def test_list_attacks_forwards_has_converters_false(self, attack_service, async def test_list_attacks_filters_by_converter_types_and_logic(self, attack_service, mock_memory) -> None: """Test that list_attacks passes converter_types to memory layer.""" ar1 = make_attack_result(conversation_id="attack-1", name="Attack One") - ar1.atomic_attack_identifier = build_atomic_attack_identifier( + ar1.atomic_attack_identifier = AtomicAttackIdentifier.build( attack_identifier=ComponentIdentifier( class_name="Attack One", class_module="pyrit.backend", @@ -2226,7 +2226,7 @@ async def test_add_message_merges_converter_identifiers_without_duplicates(self, ar = make_attack_result(conversation_id="attack-1") # Rebuild the atomic_attack_identifier to include an existing converter child strategy = ar.get_attack_strategy_identifier() - ar.atomic_attack_identifier = build_atomic_attack_identifier( + ar.atomic_attack_identifier = AtomicAttackIdentifier.build( attack_identifier=ComponentIdentifier( class_name="ManualAttack", class_module="pyrit.backend", diff --git a/tests/unit/backend/test_converter_service.py b/tests/unit/backend/test_converter_service.py index 9033941dec..c2eea9d26b 100644 --- a/tests/unit/backend/test_converter_service.py +++ b/tests/unit/backend/test_converter_service.py @@ -530,13 +530,13 @@ def _try_instantiate_converter(converter_name: str): ): mock_target = MagicMock(spec=PromptTarget) mock_target.__class__.__name__ = "MockChatTarget" - # Configure get_identifier() to return a proper identifier-like object - # so that _create_identifier can extract class_name, model_name, etc. - mock_id = MagicMock() - mock_id.class_name = "MockChatTarget" - mock_id.model_name = "test-model" - mock_id.temperature = None - mock_id.top_p = None + # Configure get_identifier() to return a real identifier so that + # _create_identifier can promote it into the typed child slot. + mock_id = ComponentIdentifier( + class_name="MockChatTarget", + class_module="mock", + params={"model_name": "test-model"}, + ) mock_target.get_identifier.return_value = mock_id kwargs[pname] = mock_target # PromptConverter — use a real simple converter to avoid JSON serialization issues diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index 9170f031df..c0886d4676 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -30,13 +30,13 @@ from pyrit.backend.models._media import build_filename, infer_mime_type from pyrit.backend.models.attacks import ScoreView from pyrit.models import ( + AtomicAttackIdentifier, AttackOutcome, AttackResult, ComponentIdentifier, Message, MessagePiece, Score, - build_atomic_attack_identifier, ) from pyrit.models.conversation_stats import ConversationStats from pyrit.prompt_target import PromptTarget, TargetCapabilities @@ -73,7 +73,7 @@ def _make_attack_result( conversation_id=conversation_id, objective="test", attack_result_id=str(uuid.uuid4()), - atomic_attack_identifier=build_atomic_attack_identifier( + atomic_attack_identifier=AtomicAttackIdentifier.build( attack_identifier=ComponentIdentifier( class_name=name, class_module="pyrit.backend", @@ -276,7 +276,7 @@ async def test_converters_extracted_from_identifier(self) -> None: conversation_id="attack-conv", objective="test", attack_result_id=str(uuid.uuid4()), - atomic_attack_identifier=build_atomic_attack_identifier( + atomic_attack_identifier=AtomicAttackIdentifier.build( attack_identifier=ComponentIdentifier( class_name="TestAttack", class_module="pyrit.backend", diff --git a/tests/unit/backend/test_response_contracts.py b/tests/unit/backend/test_response_contracts.py index bb7f403df3..a86fe754ae 100644 --- a/tests/unit/backend/test_response_contracts.py +++ b/tests/unit/backend/test_response_contracts.py @@ -22,12 +22,12 @@ ScoreView, ) from pyrit.models import ( + AtomicAttackIdentifier, AttackResult, ComponentIdentifier, MessagePiece, RetryEvent, Score, - build_atomic_attack_identifier, ) from pyrit.models.conversation_reference import ConversationReference, ConversationType @@ -60,7 +60,7 @@ def _make_attack_result(*, name: str = "CrescendoAttack") -> AttackResult: conversation_id="attack-1", objective="test objective", attack_result_id="ar-attack-1", - atomic_attack_identifier=build_atomic_attack_identifier( + atomic_attack_identifier=AtomicAttackIdentifier.build( attack_identifier=ComponentIdentifier( class_name=name, class_module="pyrit.attacks", diff --git a/tests/unit/executor/attack/core/test_attack_strategy.py b/tests/unit/executor/attack/core/test_attack_strategy.py index a0eb00b63e..3f0847892e 100644 --- a/tests/unit/executor/attack/core/test_attack_strategy.py +++ b/tests/unit/executor/attack/core/test_attack_strategy.py @@ -25,7 +25,7 @@ ) from pyrit.models.identifiers import ( AtomicAttackEvaluationIdentifier, - build_atomic_attack_identifier, + AtomicAttackIdentifier, ) from pyrit.models.retry_event import RetryEvent from pyrit.prompt_target import PromptTarget @@ -896,7 +896,7 @@ def get_attack_adversarial_config(self): def _eval_hash(attack_identifier: ComponentIdentifier) -> str: - composite = build_atomic_attack_identifier(attack_identifier=attack_identifier) + composite = AtomicAttackIdentifier.build(attack_identifier=attack_identifier) return AtomicAttackEvaluationIdentifier(composite).eval_hash diff --git a/tests/unit/identifiers/test_deprecation_shim.py b/tests/unit/identifiers/test_deprecation_shim.py index a1ca39ea9e..af5336f11d 100644 --- a/tests/unit/identifiers/test_deprecation_shim.py +++ b/tests/unit/identifiers/test_deprecation_shim.py @@ -22,14 +22,12 @@ import pytest import pyrit.identifiers as shim -import pyrit.identifiers.atomic_attack_identifier as shim_atomic import pyrit.identifiers.class_name_utils as shim_class_name import pyrit.identifiers.component_identifier as shim_component import pyrit.identifiers.evaluation_identifier as shim_eval import pyrit.identifiers.identifier_filters as shim_filters import pyrit.models as models_pkg import pyrit.models.identifiers as new -import pyrit.models.identifiers.atomic_attack_identifier as new_atomic import pyrit.models.identifiers.class_name_utils as new_class_name import pyrit.models.identifiers.component_identifier as new_component import pyrit.models.identifiers.evaluation_identifier as new_eval @@ -37,27 +35,21 @@ SUBMODULE_PAIRS = [ (shim_component, new_component, "component_identifier"), - (shim_atomic, new_atomic, "atomic_attack_identifier"), (shim_eval, new_eval, "evaluation_identifier"), (shim_class_name, new_class_name, "class_name_utils"), (shim_filters, new_filters, "identifier_filters"), ] -# Names that are deprecated at BOTH the pyrit.identifiers shim path AND the new -# pyrit.models.identifiers canonical path (because the underlying class was itself -# renamed). The shim's __getattr__ suppresses its standard path-migration warning -# for these names so a single access produces a single, more informative warning -# pointing at the actual replacement class. Tested separately in -# ``test_scorer_identifier_*`` below. -NAMES_DEPRECATED_AT_NEW_PATH = {"ScorerIdentifier"} -FORWARD_ONLY_NAMES = [n for n in shim.__all__ if n not in NAMES_DEPRECATED_AT_NEW_PATH] +# Every public name on the shim forwards to its canonical ``pyrit.models.identifiers`` +# location and emits the standard one-shot path-migration warning. +FORWARD_ONLY_NAMES = list(shim.__all__) @pytest.fixture(autouse=True) def _reset_warning_caches(): """Reset every shim's per-process `_warned` set so each test starts clean.""" saved = {} - modules = [shim, new, models_pkg] + [m for m, _, _ in SUBMODULE_PAIRS] + modules = [shim, models_pkg] + [m for m, _, _ in SUBMODULE_PAIRS] for mod in modules: saved[mod] = set(mod._warned) mod._warned.clear() @@ -94,60 +86,6 @@ def test_top_level_shim_emits_one_warning_per_name(name): assert "0.16.0" in message -def test_scorer_identifier_via_shim_emits_single_rename_warning(): - """`from pyrit.identifiers import ScorerIdentifier` produces ONE warning that points at the - actual replacement (ComponentIdentifier), not at the deprecated pyrit.models.identifiers path. - - The shim's standard path-migration warning is suppressed for this name so the partner sees a - single actionable signal in one step. - """ - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always", DeprecationWarning) - result = shim.ScorerIdentifier - _ = shim.ScorerIdentifier - _ = shim.ScorerIdentifier - - dep = [w for w in caught if issubclass(w.category, DeprecationWarning)] - assert len(dep) == 1, f"Expected 1 DeprecationWarning, got {len(dep)}: {[str(w.message) for w in dep]}" - message = str(dep[0].message) - assert "pyrit.models.identifiers.ScorerIdentifier" in message - assert "ComponentIdentifier" in message - assert "0.16.0" in message - assert result is new.ComponentIdentifier - - -def test_scorer_identifier_via_canonical_path_emits_single_warning(): - """`from pyrit.models.identifiers import ScorerIdentifier` warns once per process.""" - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always", DeprecationWarning) - result = new.ScorerIdentifier - _ = new.ScorerIdentifier - - dep = [w for w in caught if issubclass(w.category, DeprecationWarning)] - assert len(dep) == 1, f"Expected 1 DeprecationWarning, got {len(dep)}" - message = str(dep[0].message) - assert "pyrit.models.identifiers.ScorerIdentifier" in message - assert "ComponentIdentifier" in message - assert "0.16.0" in message - assert result is new.ComponentIdentifier - - -def test_scorer_identifier_via_models_package_emits_single_warning(): - """`from pyrit.models import ScorerIdentifier` warns once per process.""" - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always", DeprecationWarning) - result = models_pkg.ScorerIdentifier - _ = models_pkg.ScorerIdentifier - - dep = [w for w in caught if issubclass(w.category, DeprecationWarning)] - assert len(dep) == 1, f"Expected 1 DeprecationWarning, got {len(dep)}" - message = str(dep[0].message) - assert "pyrit.models.ScorerIdentifier" in message - assert "ComponentIdentifier" in message - assert "0.16.0" in message - assert result is models_pkg.ComponentIdentifier - - def test_top_level_shim_attribute_error_for_unknown_name(): with pytest.raises(AttributeError, match="has no attribute 'definitely_not_a_real_name'"): _ = shim.definitely_not_a_real_name diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index debe035316..d57b0ec776 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -11,6 +11,7 @@ from pyrit.memory import MemoryInterface from pyrit.memory.memory_models import AttackResultEntry from pyrit.models import ( + AtomicAttackIdentifier, AttackOutcome, AttackResult, ComponentIdentifier, @@ -20,7 +21,6 @@ IdentifierType, MessagePiece, Score, - build_atomic_attack_identifier, ) if TYPE_CHECKING: @@ -478,7 +478,7 @@ def test_attack_result_all_outcomes(sqlite_instance: MemoryInterface): attack_result = AttackResult( conversation_id=f"conv_{i}", objective=f"Test objective {i}", - atomic_attack_identifier=build_atomic_attack_identifier( + atomic_attack_identifier=AtomicAttackIdentifier.build( attack_identifier=ComponentIdentifier(class_name=f"TestAttack{i}", class_module="test.module"), ), executed_turns=i + 1, @@ -1236,7 +1236,7 @@ def _make_attack_result_with_identifier( return AttackResult( conversation_id=conversation_id, objective=f"Objective for {conversation_id}", - atomic_attack_identifier=build_atomic_attack_identifier( + atomic_attack_identifier=AtomicAttackIdentifier.build( attack_identifier=ComponentIdentifier( class_name=class_name, class_module="pyrit.attacks", @@ -1314,7 +1314,7 @@ def _eval_hash_for(class_name: str) -> str: from pyrit.models.identifiers.evaluation_identifier import AtomicAttackEvaluationIdentifier return AtomicAttackEvaluationIdentifier( - build_atomic_attack_identifier( + AtomicAttackIdentifier.build( attack_identifier=ComponentIdentifier( class_name=class_name, class_module="pyrit.attacks", diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index 04c4ead899..34862fa0e1 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -14,6 +14,7 @@ from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.memory import MemoryInterface, PromptMemoryEntry from pyrit.models import ( + AtomicAttackIdentifier, AttackResult, ComponentIdentifier, Conversation, @@ -23,7 +24,6 @@ MessagePiece, Score, SeedPrompt, - build_atomic_attack_identifier, ) @@ -1015,12 +1015,12 @@ def test_get_message_pieces_attack(sqlite_instance: MemoryInterface): AttackResult( conversation_id="c1", objective="objective 1", - atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=attack1.get_identifier()), + atomic_attack_identifier=AtomicAttackIdentifier.build(attack_identifier=attack1.get_identifier()), ), AttackResult( conversation_id="c2", objective="objective 2", - atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=attack2.get_identifier()), + atomic_attack_identifier=AtomicAttackIdentifier.build(attack_identifier=attack2.get_identifier()), ), ] ) diff --git a/tests/unit/memory/memory_interface/test_interface_scores.py b/tests/unit/memory/memory_interface/test_interface_scores.py index ed12d309c8..a9b109da06 100644 --- a/tests/unit/memory/memory_interface/test_interface_scores.py +++ b/tests/unit/memory/memory_interface/test_interface_scores.py @@ -11,6 +11,7 @@ from pyrit.memory import MemoryInterface, PromptMemoryEntry from pyrit.models import ( + AtomicAttackIdentifier, AttackResult, ComponentIdentifier, IdentifierFilter, @@ -18,7 +19,6 @@ MessagePiece, Score, SeedPrompt, - build_atomic_attack_identifier, ) @@ -49,7 +49,7 @@ def test_get_scores_by_attack_id_and_label( AttackResult( conversation_id=sample_conversations[0].conversation_id, objective="test objective", - atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=attack_strategy_id), + atomic_attack_identifier=AtomicAttackIdentifier.build(attack_identifier=attack_strategy_id), ) ] ) diff --git a/tests/unit/memory/test_memory_models.py b/tests/unit/memory/test_memory_models.py index fe63b1546b..852b6de9d8 100644 --- a/tests/unit/memory/test_memory_models.py +++ b/tests/unit/memory/test_memory_models.py @@ -22,6 +22,7 @@ _load_identifier, ) from pyrit.models import ( + AtomicAttackIdentifier, AttackOutcome, AttackResult, ComponentIdentifier, @@ -34,7 +35,6 @@ SeedObjective, SeedPrompt, SeedSimulatedConversation, - build_atomic_attack_identifier, ) # --------------------------------------------------------------------------- @@ -528,7 +528,7 @@ def test_filter_json_serializable_metadata_mixed(self): def test_get_attack_result_prefers_atomic_over_stale_attack_identifier(self): """When atomic_attack_identifier and attack_identifier disagree, atomic wins.""" correct_attack_id = ComponentIdentifier(class_name="CorrectAttack", class_module="pyrit.backend") - atomic_id = build_atomic_attack_identifier(attack_identifier=correct_attack_id) + atomic_id = AtomicAttackIdentifier.build(attack_identifier=correct_attack_id) ar = _make_attack_result(atomic_attack_identifier=atomic_id) entry = AttackResultEntry(entry=ar) diff --git a/tests/unit/models/identifiers/test_atomic_attack_identifier.py b/tests/unit/models/identifiers/test_atomic_attack_identifier.py index 933c233207..12d8f499ae 100644 --- a/tests/unit/models/identifiers/test_atomic_attack_identifier.py +++ b/tests/unit/models/identifiers/test_atomic_attack_identifier.py @@ -4,9 +4,9 @@ from pyrit.models.identifiers import ( AtomicAttackEvaluationIdentifier, + AtomicAttackIdentifier, ComponentIdentifier, - build_atomic_attack_identifier, - build_seed_identifier, + SeedIdentifier, compute_eval_hash, ) from pyrit.models.seeds.seed_prompt import SeedPrompt @@ -48,99 +48,112 @@ def _make_attack( # ========================================================================= -# build_seed_identifier +# SeedIdentifier.from_seed # ========================================================================= class TestBuildSeedIdentifier: - """Tests for build_seed_identifier.""" + """Tests for SeedIdentifier.from_seed.""" def test_returns_component_identifier(self): seed = SeedPrompt(value="hello", value_sha256="abc123", dataset_name="test_ds", name="seed1") - result = build_seed_identifier(seed) + result = SeedIdentifier.from_seed(seed) assert isinstance(result, ComponentIdentifier) def test_captures_class_name(self): seed = SeedPrompt(value="hello", value_sha256="abc123") - assert build_seed_identifier(seed).class_name == "SeedPrompt" + assert SeedIdentifier.from_seed(seed).class_name == "SeedPrompt" def test_includes_value_and_sha256_and_dataset(self): seed = SeedPrompt(value="hello", value_sha256="abc", dataset_name="my_dataset") - result = build_seed_identifier(seed) + result = SeedIdentifier.from_seed(seed) assert result.params["value"] == "hello" assert result.params["value_sha256"] == "abc" assert result.params["dataset_name"] == "my_dataset" + def test_includes_data_type(self): + seed = SeedPrompt(value="hello", value_sha256="abc") + result = SeedIdentifier.from_seed(seed) + assert result.data_type == "text" + assert result.params["data_type"] == "text" + + def test_data_type_distinguishes_identity(self): + text_seed = SeedPrompt(value="payload", value_sha256="abc", data_type="text") + image_seed = SeedPrompt(value="payload", value_sha256="abc", data_type="image_path") + assert SeedIdentifier.from_seed(text_seed).hash != SeedIdentifier.from_seed(image_seed).hash + def test_includes_is_general_technique_true(self): seed = SeedPrompt(value="hello", value_sha256="abc", is_general_technique=True) - result = build_seed_identifier(seed) + result = SeedIdentifier.from_seed(seed) assert result.params["is_general_technique"] is True def test_includes_is_general_technique_false(self): seed = SeedPrompt(value="hello", value_sha256="abc", is_general_technique=False) - result = build_seed_identifier(seed) + result = SeedIdentifier.from_seed(seed) assert result.params["is_general_technique"] is False - def test_none_values_present_in_params(self): + def test_none_values_dropped_from_params(self): + # Promoted None-valued fields are dropped from params (consistent with + # ComponentIdentifier.of semantics). The hash excludes None either way, + # so identity is unchanged. seed = SeedPrompt(value="hello") seed.value_sha256 = None seed.dataset_name = None - result = build_seed_identifier(seed) - assert "value_sha256" in result.params - assert result.params["value_sha256"] is None - assert "dataset_name" in result.params - assert result.params["dataset_name"] is None + result = SeedIdentifier.from_seed(seed) + assert "value_sha256" not in result.params + assert "dataset_name" not in result.params + assert result.params["value"] == "hello" def test_deterministic_hash(self): seed1 = SeedPrompt(value="hello", value_sha256="abc123", dataset_name="ds") seed2 = SeedPrompt(value="hello", value_sha256="abc123", dataset_name="ds") - assert build_seed_identifier(seed1).hash == build_seed_identifier(seed2).hash + assert SeedIdentifier.from_seed(seed1).hash == SeedIdentifier.from_seed(seed2).hash def test_different_content_different_hash(self): seed1 = SeedPrompt(value="hello", value_sha256="abc123") seed2 = SeedPrompt(value="world", value_sha256="def456") - assert build_seed_identifier(seed1).hash != build_seed_identifier(seed2).hash + assert SeedIdentifier.from_seed(seed1).hash != SeedIdentifier.from_seed(seed2).hash # ========================================================================= -# build_atomic_attack_identifier +# AtomicAttackIdentifier.build # ========================================================================= class TestBuildAtomicAttackIdentifier: - """Tests for build_atomic_attack_identifier.""" + """Tests for AtomicAttackIdentifier.build.""" def test_returns_component_identifier(self): - result = build_atomic_attack_identifier(attack_identifier=_make_attack()) + result = AtomicAttackIdentifier.build(attack_identifier=_make_attack()) assert isinstance(result, ComponentIdentifier) def test_class_name_is_atomic_attack(self): - result = build_atomic_attack_identifier(attack_identifier=_make_attack()) + result = AtomicAttackIdentifier.build(attack_identifier=_make_attack()) assert result.class_name == "AtomicAttack" def test_class_module_is_correct(self): - result = build_atomic_attack_identifier(attack_identifier=_make_attack()) + result = AtomicAttackIdentifier.build(attack_identifier=_make_attack()) assert result.class_module == "pyrit.scenario.core.atomic_attack" def test_attack_technique_child_is_present(self): attack_id = _make_attack() - result = build_atomic_attack_identifier(attack_identifier=attack_id) + result = AtomicAttackIdentifier.build(attack_identifier=attack_id) technique = result.children["attack_technique"] assert technique.class_name == "AttackTechnique" assert technique.children["attack"] == attack_id def test_no_seed_group_empty_seed_identifiers(self): - result = build_atomic_attack_identifier(attack_identifier=_make_attack()) + result = AtomicAttackIdentifier.build(attack_identifier=_make_attack()) assert result.children["seed_identifiers"] == [] def test_empty_seed_group_empty_seed_identifiers(self): - result = build_atomic_attack_identifier(attack_identifier=_make_attack(), seed_group=_FakeSeedGroup(seeds=[])) + result = AtomicAttackIdentifier.build(attack_identifier=_make_attack(), seed_group=_FakeSeedGroup(seeds=[])) assert result.children["seed_identifiers"] == [] def test_includes_all_seeds(self): general_seed = SeedPrompt(value="technique", value_sha256="abc", is_general_technique=True) non_general_seed = SeedPrompt(value="objective", value_sha256="def", is_general_technique=False) - result = build_atomic_attack_identifier( + result = AtomicAttackIdentifier.build( attack_identifier=_make_attack(), seed_group=_FakeSeedGroup(seeds=[general_seed, non_general_seed]), ) @@ -154,7 +167,7 @@ def test_includes_all_seeds(self): def test_multiple_seeds(self): seed1 = SeedPrompt(value="tech1", value_sha256="aaa", is_general_technique=True) seed2 = SeedPrompt(value="tech2", value_sha256="bbb", is_general_technique=True) - result = build_atomic_attack_identifier( + result = AtomicAttackIdentifier.build( attack_identifier=_make_attack(), seed_group=_FakeSeedGroup(seeds=[seed1, seed2]), ) @@ -163,26 +176,26 @@ def test_multiple_seeds(self): def test_deterministic_hash(self): attack_id = _make_attack() seed = SeedPrompt(value="technique", value_sha256="abc", is_general_technique=True) - r1 = build_atomic_attack_identifier(attack_identifier=attack_id, seed_group=_FakeSeedGroup(seeds=[seed])) - r2 = build_atomic_attack_identifier(attack_identifier=attack_id, seed_group=_FakeSeedGroup(seeds=[seed])) + r1 = AtomicAttackIdentifier.build(attack_identifier=attack_id, seed_group=_FakeSeedGroup(seeds=[seed])) + r2 = AtomicAttackIdentifier.build(attack_identifier=attack_id, seed_group=_FakeSeedGroup(seeds=[seed])) assert r1.hash == r2.hash def test_different_seeds_different_hash(self): attack_id = _make_attack() seed1 = SeedPrompt(value="tech1", value_sha256="aaa", is_general_technique=True) seed2 = SeedPrompt(value="tech2", value_sha256="bbb", is_general_technique=True) - r1 = build_atomic_attack_identifier(attack_identifier=attack_id, seed_group=_FakeSeedGroup(seeds=[seed1])) - r2 = build_atomic_attack_identifier(attack_identifier=attack_id, seed_group=_FakeSeedGroup(seeds=[seed2])) + r1 = AtomicAttackIdentifier.build(attack_identifier=attack_id, seed_group=_FakeSeedGroup(seeds=[seed1])) + r2 = AtomicAttackIdentifier.build(attack_identifier=attack_id, seed_group=_FakeSeedGroup(seeds=[seed2])) assert r1.hash != r2.hash def test_different_attacks_different_hash(self): - r1 = build_atomic_attack_identifier(attack_identifier=_make_attack(class_name="PromptSendingAttack")) - r2 = build_atomic_attack_identifier(attack_identifier=_make_attack(class_name="CrescendoAttack")) + r1 = AtomicAttackIdentifier.build(attack_identifier=_make_attack(class_name="PromptSendingAttack")) + r2 = AtomicAttackIdentifier.build(attack_identifier=_make_attack(class_name="CrescendoAttack")) assert r1.hash != r2.hash def test_serialization_round_trip(self): seed = SeedPrompt(value="technique", value_sha256="abc", is_general_technique=True, dataset_name="ds") - original = build_atomic_attack_identifier( + original = AtomicAttackIdentifier.build( attack_identifier=_make_attack(), seed_group=_FakeSeedGroup(seeds=[seed]), ) @@ -228,19 +241,19 @@ def test_seed_identifiers_rule(self): # -- Basic properties -------------------------------------------------- def test_identifier_property_returns_original(self): - composite = build_atomic_attack_identifier(attack_identifier=_make_attack()) + composite = AtomicAttackIdentifier.build(attack_identifier=_make_attack()) identity = AtomicAttackEvaluationIdentifier(composite) assert identity.identifier is composite def test_eval_hash_is_64_char_hex(self): - composite = build_atomic_attack_identifier(attack_identifier=_make_attack()) + composite = AtomicAttackIdentifier.build(attack_identifier=_make_attack()) identity = AtomicAttackEvaluationIdentifier(composite) assert isinstance(identity.eval_hash, str) and len(identity.eval_hash) == 64 # -- Consistency with free functions ----------------------------------- def test_eval_hash_matches_compute_eval_hash_with_rules(self): - composite = build_atomic_attack_identifier( + composite = AtomicAttackIdentifier.build( attack_identifier=_make_attack(children={"objective_target": _make_target(params={"temperature": 0.5})}) ) identity = AtomicAttackEvaluationIdentifier(composite) @@ -256,15 +269,15 @@ def test_objective_target_operational_params_ignored(self): """Same temperature, different endpoint/model -> same eval hash.""" t1 = _make_target(params={"model_name": "gpt-4o", "endpoint": "https://a.com", "temperature": 0.7}) t2 = _make_target(params={"model_name": "gpt-3.5", "endpoint": "https://b.com", "temperature": 0.7}) - c1 = build_atomic_attack_identifier(attack_identifier=_make_attack(children={"objective_target": t1})) - c2 = build_atomic_attack_identifier(attack_identifier=_make_attack(children={"objective_target": t2})) + c1 = AtomicAttackIdentifier.build(attack_identifier=_make_attack(children={"objective_target": t1})) + c2 = AtomicAttackIdentifier.build(attack_identifier=_make_attack(children={"objective_target": t2})) assert AtomicAttackEvaluationIdentifier(c1).eval_hash == AtomicAttackEvaluationIdentifier(c2).eval_hash def test_objective_target_different_temperature_different_hash(self): t1 = _make_target(params={"temperature": 0.7}) t2 = _make_target(params={"temperature": 0.0}) - c1 = build_atomic_attack_identifier(attack_identifier=_make_attack(children={"objective_target": t1})) - c2 = build_atomic_attack_identifier(attack_identifier=_make_attack(children={"objective_target": t2})) + c1 = AtomicAttackIdentifier.build(attack_identifier=_make_attack(children={"objective_target": t1})) + c2 = AtomicAttackIdentifier.build(attack_identifier=_make_attack(children={"objective_target": t2})) assert AtomicAttackEvaluationIdentifier(c1).eval_hash != AtomicAttackEvaluationIdentifier(c2).eval_hash # -- adversarial_chat filtering ---------------------------------------- @@ -275,8 +288,8 @@ def test_adversarial_chat_model_name_affects_hash(self): chat2 = ComponentIdentifier(class_name="Chat", class_module="m", params={"model_name": "gpt-3.5"}) a1 = _make_attack(children={"adversarial_chat": chat1}) a2 = _make_attack(children={"adversarial_chat": chat2}) - c1 = build_atomic_attack_identifier(attack_identifier=a1) - c2 = build_atomic_attack_identifier(attack_identifier=a2) + c1 = AtomicAttackIdentifier.build(attack_identifier=a1) + c2 = AtomicAttackIdentifier.build(attack_identifier=a2) assert AtomicAttackEvaluationIdentifier(c1).eval_hash != AtomicAttackEvaluationIdentifier(c2).eval_hash def test_adversarial_chat_endpoint_ignored(self): @@ -293,8 +306,8 @@ def test_adversarial_chat_endpoint_ignored(self): ) a1 = _make_attack(children={"adversarial_chat": chat1}) a2 = _make_attack(children={"adversarial_chat": chat2}) - c1 = build_atomic_attack_identifier(attack_identifier=a1) - c2 = build_atomic_attack_identifier(attack_identifier=a2) + c1 = AtomicAttackIdentifier.build(attack_identifier=a1) + c2 = AtomicAttackIdentifier.build(attack_identifier=a2) assert AtomicAttackEvaluationIdentifier(c1).eval_hash == AtomicAttackEvaluationIdentifier(c2).eval_hash def test_adversarial_chat_wrapper_unwrapped_via_inner_child_name(self): @@ -310,8 +323,8 @@ def test_adversarial_chat_wrapper_unwrapped_via_inner_child_name(self): ) a_bare = _make_attack(children={"adversarial_chat": bare}) a_wrapped = _make_attack(children={"adversarial_chat": wrapper}) - c_bare = build_atomic_attack_identifier(attack_identifier=a_bare) - c_wrapped = build_atomic_attack_identifier(attack_identifier=a_wrapped) + c_bare = AtomicAttackIdentifier.build(attack_identifier=a_bare) + c_wrapped = AtomicAttackIdentifier.build(attack_identifier=a_wrapped) assert ( AtomicAttackEvaluationIdentifier(c_bare).eval_hash == AtomicAttackEvaluationIdentifier(c_wrapped).eval_hash ) @@ -328,8 +341,8 @@ def test_objective_scorer_excluded_from_eval_hash(self): ) a1 = _make_attack(children={"objective_scorer": scorer1}) a2 = _make_attack(children={"objective_scorer": scorer2}) - c1 = build_atomic_attack_identifier(attack_identifier=a1) - c2 = build_atomic_attack_identifier(attack_identifier=a2) + c1 = AtomicAttackIdentifier.build(attack_identifier=a1) + c2 = AtomicAttackIdentifier.build(attack_identifier=a2) assert AtomicAttackEvaluationIdentifier(c1).eval_hash == AtomicAttackEvaluationIdentifier(c2).eval_hash def test_objective_scorer_presence_vs_absence_same_hash(self): @@ -339,8 +352,8 @@ def test_objective_scorer_presence_vs_absence_same_hash(self): ) a_with = _make_attack(children={"objective_scorer": scorer}) a_without = _make_attack() - c1 = build_atomic_attack_identifier(attack_identifier=a_with) - c2 = build_atomic_attack_identifier(attack_identifier=a_without) + c1 = AtomicAttackIdentifier.build(attack_identifier=a_with) + c2 = AtomicAttackIdentifier.build(attack_identifier=a_without) assert AtomicAttackEvaluationIdentifier(c1).eval_hash == AtomicAttackEvaluationIdentifier(c2).eval_hash # -- Converters (non-target, fully included) --------------------------- @@ -350,16 +363,16 @@ def test_different_request_converters_different_hash(self): conv2 = ComponentIdentifier(class_name="ROT13Converter", class_module="pyrit.prompt_converter") a1 = _make_attack(children={"request_converters": [conv1]}) a2 = _make_attack(children={"request_converters": [conv2]}) - c1 = build_atomic_attack_identifier(attack_identifier=a1) - c2 = build_atomic_attack_identifier(attack_identifier=a2) + c1 = AtomicAttackIdentifier.build(attack_identifier=a1) + c2 = AtomicAttackIdentifier.build(attack_identifier=a2) assert AtomicAttackEvaluationIdentifier(c1).eval_hash != AtomicAttackEvaluationIdentifier(c2).eval_hash def test_same_request_converters_same_hash(self): conv = ComponentIdentifier(class_name="Base64Converter", class_module="pyrit.prompt_converter") a1 = _make_attack(children={"request_converters": [conv]}) a2 = _make_attack(children={"request_converters": [conv]}) - c1 = build_atomic_attack_identifier(attack_identifier=a1) - c2 = build_atomic_attack_identifier(attack_identifier=a2) + c1 = AtomicAttackIdentifier.build(attack_identifier=a1) + c2 = AtomicAttackIdentifier.build(attack_identifier=a2) assert AtomicAttackEvaluationIdentifier(c1).eval_hash == AtomicAttackEvaluationIdentifier(c2).eval_hash def test_response_converters_contribute(self): @@ -367,8 +380,8 @@ def test_response_converters_contribute(self): conv2 = ComponentIdentifier(class_name="ROT13Converter", class_module="pyrit.prompt_converter") a1 = _make_attack(children={"response_converters": [conv1]}) a2 = _make_attack(children={"response_converters": [conv2]}) - c1 = build_atomic_attack_identifier(attack_identifier=a1) - c2 = build_atomic_attack_identifier(attack_identifier=a2) + c1 = AtomicAttackIdentifier.build(attack_identifier=a1) + c2 = AtomicAttackIdentifier.build(attack_identifier=a2) assert AtomicAttackEvaluationIdentifier(c1).eval_hash != AtomicAttackEvaluationIdentifier(c2).eval_hash def test_converters_contribute_while_target_endpoint_ignored(self): @@ -378,8 +391,8 @@ def test_converters_contribute_while_target_endpoint_ignored(self): conv = ComponentIdentifier(class_name="Base64Converter", class_module="pyrit.prompt_converter") a1 = _make_attack(children={"objective_target": t1, "request_converters": [conv]}) a2 = _make_attack(children={"objective_target": t2, "request_converters": [conv]}) - c1 = build_atomic_attack_identifier(attack_identifier=a1) - c2 = build_atomic_attack_identifier(attack_identifier=a2) + c1 = AtomicAttackIdentifier.build(attack_identifier=a1) + c2 = AtomicAttackIdentifier.build(attack_identifier=a2) assert AtomicAttackEvaluationIdentifier(c1).eval_hash == AtomicAttackEvaluationIdentifier(c2).eval_hash # -- Seeds and technique_seeds (eval hash uses technique_seeds, excludes seeds) --- @@ -391,15 +404,15 @@ def test_different_technique_seeds_different_eval_hash(self): technique1 = ComponentIdentifier( class_name="AttackTechnique", class_module="pyrit.scenario.core.attack_technique", - children={"attack": attack_id, "technique_seeds": [build_seed_identifier(seed1)]}, + children={"attack": attack_id, "technique_seeds": [SeedIdentifier.from_seed(seed1)]}, ) technique2 = ComponentIdentifier( class_name="AttackTechnique", class_module="pyrit.scenario.core.attack_technique", - children={"attack": attack_id, "technique_seeds": [build_seed_identifier(seed2)]}, + children={"attack": attack_id, "technique_seeds": [SeedIdentifier.from_seed(seed2)]}, ) - c1 = build_atomic_attack_identifier(technique_identifier=technique1) - c2 = build_atomic_attack_identifier(technique_identifier=technique2) + c1 = AtomicAttackIdentifier.build(technique_identifier=technique1) + c2 = AtomicAttackIdentifier.build(technique_identifier=technique2) assert AtomicAttackEvaluationIdentifier(c1).eval_hash != AtomicAttackEvaluationIdentifier(c2).eval_hash def test_seeds_in_seed_group_ignored_in_eval_hash(self): @@ -407,11 +420,11 @@ def test_seeds_in_seed_group_ignored_in_eval_hash(self): attack_id = _make_attack() non_general_1 = SeedPrompt(value="obj1", value_sha256="xxx", is_general_technique=False) non_general_2 = SeedPrompt(value="obj2", value_sha256="yyy", is_general_technique=False) - c1 = build_atomic_attack_identifier( + c1 = AtomicAttackIdentifier.build( attack_identifier=attack_id, seed_group=_FakeSeedGroup(seeds=[non_general_1]), ) - c2 = build_atomic_attack_identifier( + c2 = AtomicAttackIdentifier.build( attack_identifier=attack_id, seed_group=_FakeSeedGroup(seeds=[non_general_2]), ) @@ -421,11 +434,11 @@ def test_general_technique_seeds_in_seed_group_ignored_in_eval_hash(self): """Even general technique seeds in seed_group are excluded from eval hash.""" attack_id = _make_attack() general_seed = SeedPrompt(value="technique", value_sha256="abc", is_general_technique=True) - c_with = build_atomic_attack_identifier( + c_with = AtomicAttackIdentifier.build( attack_identifier=attack_id, seed_group=_FakeSeedGroup(seeds=[general_seed]), ) - c_without = build_atomic_attack_identifier( + c_without = AtomicAttackIdentifier.build( attack_identifier=attack_id, ) assert ( @@ -437,11 +450,11 @@ def test_identifier_hash_differs_with_different_seeds(self): attack_id = _make_attack() non_general_1 = SeedPrompt(value="obj1", value_sha256="xxx", is_general_technique=False) non_general_2 = SeedPrompt(value="obj2", value_sha256="yyy", is_general_technique=False) - c1 = build_atomic_attack_identifier( + c1 = AtomicAttackIdentifier.build( attack_identifier=attack_id, seed_group=_FakeSeedGroup(seeds=[non_general_1]), ) - c2 = build_atomic_attack_identifier( + c2 = AtomicAttackIdentifier.build( attack_identifier=attack_id, seed_group=_FakeSeedGroup(seeds=[non_general_2]), ) @@ -474,7 +487,7 @@ def test_full_composite_eval_hash(self): "request_converters": [converter], } ) - composite = build_atomic_attack_identifier( + composite = AtomicAttackIdentifier.build( attack_identifier=attack_id, seed_group=_FakeSeedGroup(seeds=[seed]), ) @@ -491,7 +504,7 @@ def test_full_composite_eval_hash(self): "request_converters": [converter], } ) - composite2 = build_atomic_attack_identifier( + composite2 = AtomicAttackIdentifier.build( attack_identifier=attack_id2, seed_group=_FakeSeedGroup(seeds=[seed]), ) @@ -509,7 +522,7 @@ def test_full_composite_eval_hash(self): "request_converters": [converter], } ) - composite3 = build_atomic_attack_identifier( + composite3 = AtomicAttackIdentifier.build( attack_identifier=attack_id3, seed_group=_FakeSeedGroup(seeds=[seed]), ) diff --git a/tests/unit/models/identifiers/test_component_identifier.py b/tests/unit/models/identifiers/test_component_identifier.py index a7c35a0c54..ec51120619 100644 --- a/tests/unit/models/identifiers/test_component_identifier.py +++ b/tests/unit/models/identifiers/test_component_identifier.py @@ -1427,6 +1427,42 @@ def test_ambiguous_flat_and_params_shape_rejected(self): ) +class TestComponentIdentifierParamsTyping: + """Params must be JSON-serializable scalars / nested list / dict containers.""" + + def test_accepts_json_scalars_and_nested_containers(self): + identifier = ComponentIdentifier( + class_name="Foo", + class_module="m", + params={ + "s": "text", + "i": 3, + "f": 1.5, + "b": True, + "n": None, + "lst": [1, "two", [3, 4]], + "nested": {"a": {"b": [1, 2]}}, + }, + ) + assert identifier.params["nested"] == {"a": {"b": [1, 2]}} + + def test_tuple_param_coerced_to_list(self): + """Tuples coerce to lists (JSON has no tuple), keeping the hash stable.""" + identifier = ComponentIdentifier(class_name="Foo", class_module="m", params={"t": (1, 2, 3)}) + assert identifier.params["t"] == [1, 2, 3] + assert isinstance(identifier.params["t"], list) + list_form = ComponentIdentifier(class_name="Foo", class_module="m", params={"t": [1, 2, 3]}) + assert identifier.hash == list_form.hash + + def test_non_json_object_value_rejected(self): + with pytest.raises(ValidationError): + ComponentIdentifier(class_name="Foo", class_module="m", params={"bad": object()}) + + def test_non_json_nested_value_rejected(self): + with pytest.raises(ValidationError): + ComponentIdentifier(class_name="Foo", class_module="m", params={"bad": [1, object()]}) + + class TestComponentIdentifierDeprecationWarnings: def test_to_dict_warns(self): ident = ComponentIdentifier(class_name="Foo", class_module="m", params={"a": 1}) diff --git a/tests/unit/models/identifiers/test_evaluation_identifier.py b/tests/unit/models/identifiers/test_evaluation_identifier.py index 46a6898fed..e4db3da13b 100644 --- a/tests/unit/models/identifiers/test_evaluation_identifier.py +++ b/tests/unit/models/identifiers/test_evaluation_identifier.py @@ -597,10 +597,10 @@ def _attack_with_identifier(self, identifier: ComponentIdentifier): return attack def test_matches_manual_two_step_composition(self): - """Helper equals the executor recipe (build_atomic_attack_identifier + AtomicAttackEvaluationIdentifier).""" + """Helper equals the executor recipe (AtomicAttackIdentifier.build + AtomicAttackEvaluationIdentifier).""" from pyrit.models.identifiers import ( AtomicAttackEvaluationIdentifier, - build_atomic_attack_identifier, + AtomicAttackIdentifier, compute_inner_attack_eval_hash, ) @@ -611,7 +611,7 @@ def test_matches_manual_two_step_composition(self): attack = self._attack_with_identifier(inner_id) expected = AtomicAttackEvaluationIdentifier( - build_atomic_attack_identifier(attack_identifier=inner_id), + AtomicAttackIdentifier.build(attack_identifier=inner_id), ).eval_hash assert compute_inner_attack_eval_hash(attack=attack) == expected @@ -639,7 +639,7 @@ def test_matches_persisted_row_eval_hash(self): identifier must yield an entry with the same eval_hash.""" from pyrit.memory.memory_models import AttackResultEntry from pyrit.models import AttackResult - from pyrit.models.identifiers import build_atomic_attack_identifier, compute_inner_attack_eval_hash + from pyrit.models.identifiers import AtomicAttackIdentifier, compute_inner_attack_eval_hash inner_id = ComponentIdentifier( class_name="MyAttack", @@ -651,7 +651,7 @@ def test_matches_persisted_row_eval_hash(self): result = AttackResult( conversation_id="conv_1", objective="o", - atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=inner_id), + atomic_attack_identifier=AtomicAttackIdentifier.build(attack_identifier=inner_id), ) entry = AttackResultEntry(entry=result) assert entry.atomic_attack_identifier["eval_hash"] == predicted diff --git a/tests/unit/models/identifiers/test_typed_identifier.py b/tests/unit/models/identifiers/test_typed_identifier.py new file mode 100644 index 0000000000..4922e4d4d4 --- /dev/null +++ b/tests/unit/models/identifiers/test_typed_identifier.py @@ -0,0 +1,330 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Tests for the strongly-typed identifier projections (ComponentIdentifier subclasses).""" + +import pytest + +from pyrit.models.identifiers import ( + AtomicAttackIdentifier, + AttackIdentifier, + AttackTechniqueIdentifier, + ComponentIdentifier, + ConverterIdentifier, + ScorerIdentifier, + SeedIdentifier, + TargetIdentifier, +) + + +def _target_identifier() -> ComponentIdentifier: + """A representative target ComponentIdentifier.""" + return ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + params={ + "endpoint": "https://example.openai.azure.com", + "model_name": "gpt-4o", + "underlying_model_name": "gpt-4o", + "temperature": 0.7, + "top_p": 0.9, + "max_requests_per_minute": 60, + "custom_thing": "keep-me", + }, + ) + + +def _converter_identifier() -> ComponentIdentifier: + """A representative converter ComponentIdentifier.""" + return ComponentIdentifier( + class_name="Base64Converter", + class_module="pyrit.prompt_converter.base64_converter", + params={ + "supported_input_types": ["text"], + "supported_output_types": ["text"], + "some_option": 3, + }, + ) + + +def _round_robin_identifier() -> ComponentIdentifier: + """A composite target identifier with list-valued children.""" + inner_a = ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + params={"endpoint": "https://a", "model_name": "gpt-4o"}, + ) + inner_b = ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + params={"endpoint": "https://b", "model_name": "gpt-4o-mini"}, + ) + return ComponentIdentifier( + class_name="RoundRobinTarget", + class_module="pyrit.prompt_target.round_robin_target", + params={"weights": [1, 1]}, + children={"targets": [inner_a, inner_b]}, + ) + + +def _scorer_with_child_identifier() -> ComponentIdentifier: + """A scorer-shaped identifier with a single (non-list) child.""" + child = ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + params={"endpoint": "https://c", "model_name": "gpt-4o", "temperature": 0.0}, + ) + return ComponentIdentifier( + class_name="SelfAskScaleScorer", + class_module="pyrit.score.self_ask_scale_scorer", + params={"scorer_type": "float_scale", "some_param": "v"}, + children={"prompt_target": child}, + ) + + +class TestPromoteInvariant: + """The core round-trip guarantee: from_component_identifier() preserves identity and serialization.""" + + @pytest.mark.parametrize( + ("typed_cls", "factory"), + [ + (TargetIdentifier, _target_identifier), + (ConverterIdentifier, _converter_identifier), + (TargetIdentifier, _round_robin_identifier), + (ScorerIdentifier, _scorer_with_child_identifier), + ], + ) + def test_hash_preserved(self, typed_cls, factory): + """promote(ci).hash == ci.hash.""" + ci = factory() + assert typed_cls.from_component_identifier(ci).hash == ci.hash + + @pytest.mark.parametrize( + ("typed_cls", "factory"), + [ + (TargetIdentifier, _target_identifier), + (ConverterIdentifier, _converter_identifier), + (TargetIdentifier, _round_robin_identifier), + (ScorerIdentifier, _scorer_with_child_identifier), + ], + ) + def test_full_structure_preserved(self, typed_cls, factory): + """The full flat serialization round-trips byte-for-byte.""" + ci = factory() + assert typed_cls.from_component_identifier(ci).model_dump() == ci.model_dump() + + @pytest.mark.parametrize( + ("typed_cls", "factory"), + [ + (TargetIdentifier, _target_identifier), + (ConverterIdentifier, _converter_identifier), + (TargetIdentifier, _round_robin_identifier), + (ScorerIdentifier, _scorer_with_child_identifier), + ], + ) + def test_hash_recomputed_from_scratch_matches(self, typed_cls, factory): + """ + Recomputing the hash from the projected params/children (not forwarding the + stored hash) still equals the original — proves params are losslessly captured. + """ + ci = factory() + rebuilt = typed_cls.from_component_identifier(ci) + recomputed = ComponentIdentifier( + class_name=rebuilt.class_name, + class_module=rebuilt.class_module, + params=rebuilt.params, + children=rebuilt.children, + ) + assert recomputed.hash == ci.hash + + def test_promote_is_pass_through_for_same_type(self): + td = TargetIdentifier.from_component_identifier(_target_identifier()) + assert TargetIdentifier.from_component_identifier(td) is td + + +class TestTargetIdentifier: + """Promotion of well-known target params; capabilities are intentionally not projected.""" + + def test_promoted_fields(self): + td = TargetIdentifier.from_component_identifier(_target_identifier()) + assert td.endpoint == "https://example.openai.azure.com" + assert td.model_name == "gpt-4o" + assert td.underlying_model_name == "gpt-4o" + assert td.temperature == 0.7 + assert td.top_p == 0.9 + assert td.max_requests_per_minute == 60 + + def test_unknown_params_stay_in_params(self): + td = TargetIdentifier.from_component_identifier(_target_identifier()) + assert td.params["custom_thing"] == "keep-me" + # Promoted params are mirrored into params (so hashing/serialization is identical). + assert td.params["endpoint"] == "https://example.openai.azure.com" + + def test_capabilities_not_projected_as_typed_field(self): + # Capabilities describe a target but are deliberately not part of its + # typed identity, so there is no ``capabilities`` field. + assert "capabilities" not in TargetIdentifier.model_fields + + def test_inner_targets_typed_as_targets(self): + td = TargetIdentifier.from_component_identifier(_round_robin_identifier()) + inner = td.targets + assert isinstance(inner, list) + assert all(isinstance(child, TargetIdentifier) for child in inner) + assert inner[0].endpoint == "https://a" + + +class TestConverterIdentifier: + """Promotion of converter input/output types.""" + + def test_promoted_fields(self): + cd = ConverterIdentifier.from_component_identifier(_converter_identifier()) + assert cd.supported_input_types == ["text"] + assert cd.supported_output_types == ["text"] + assert cd.params["some_option"] == 3 + + def test_promoted_children_typed_per_field(self): + target_child = ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + params={"endpoint": "https://obj", "model_name": "gpt-4o"}, + ) + sub_converter_child = ComponentIdentifier( + class_name="Base64Converter", + class_module="pyrit.prompt_converter.base64_converter", + params={"supported_input_types": ["text"]}, + ) + ci = ComponentIdentifier( + class_name="LLMGenericTextConverter", + class_module="pyrit.prompt_converter.llm_generic_text_converter", + params={}, + children={ + "converter_target": target_child, + "sub_converters": [sub_converter_child], + }, + ) + cd = ConverterIdentifier.from_component_identifier(ci) + assert isinstance(cd.converter_target, TargetIdentifier) + assert cd.converter_target.endpoint == "https://obj" + assert isinstance(cd.sub_converters, list) + assert all(isinstance(c, ConverterIdentifier) for c in cd.sub_converters) + assert cd.hash == ci.hash + + +class TestScorerIdentifier: + """Promotion of the scorer type discriminator and child target.""" + + def test_promoted_fields(self): + sd = ScorerIdentifier.from_component_identifier(_scorer_with_child_identifier()) + assert sd.scorer_type == "float_scale" + assert isinstance(sd.prompt_target, TargetIdentifier) + assert sd.prompt_target.endpoint == "https://c" + + +class TestDirectConstruction: + """Building a typed identifier by hand yields a valid ComponentIdentifier.""" + + def test_hand_built_target(self): + td = TargetIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + endpoint="https://hand", + model_name="gpt-4o", + ) + assert td.class_name == "OpenAIChatTarget" + assert td.params["endpoint"] == "https://hand" + assert td.params["model_name"] == "gpt-4o" + assert td.hash is not None + # A hand-built typed identifier hashes identically to the plain projection. + plain = ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + params={"endpoint": "https://hand", "model_name": "gpt-4o"}, + ) + assert td.hash == plain.hash + + def test_none_promoted_fields_are_dropped(self): + td = TargetIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + endpoint="https://hand", + ) + assert "temperature" not in td.params + assert "top_p" not in td.params + + def test_of_with_promoted_kwargs(self): + class _Obj: + pass + + sd = ScorerIdentifier.of(_Obj(), scorer_type="true_false") + assert sd.scorer_type == "true_false" + assert sd.params["scorer_type"] == "true_false" + + +class TestCompositeIdentifiers: + """Attack / technique / atomic identifiers compose typed children.""" + + def test_attack_identifier_children_typed(self): + objective_target = TargetIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + endpoint="https://obj", + ) + scorer = ScorerIdentifier( + class_name="SelfAskScaleScorer", + class_module="pyrit.score.self_ask_scale_scorer", + scorer_type="float_scale", + ) + attack = AttackIdentifier( + class_name="RedTeamingAttack", + class_module="pyrit.executor.attack.red_teaming", + objective_target=objective_target, + objective_scorer=scorer, + ) + assert attack.children["objective_target"].hash == objective_target.hash + assert attack.children["objective_scorer"].hash == scorer.hash + + rebuilt = AttackIdentifier.from_component_identifier(ComponentIdentifier.model_validate(attack.model_dump())) + assert isinstance(rebuilt.objective_target, TargetIdentifier) + assert isinstance(rebuilt.objective_scorer, ScorerIdentifier) + assert rebuilt.hash == attack.hash + + def test_atomic_identifier_empty_seed_list_preserved(self): + attack = AttackIdentifier( + class_name="RedTeamingAttack", + class_module="pyrit.executor.attack.red_teaming", + ) + technique = AttackTechniqueIdentifier( + class_name="AttackTechnique", + class_module="pyrit.scenario.core.attack_technique", + attack=attack, + ) + atomic = AtomicAttackIdentifier( + class_name="AtomicAttack", + class_module="pyrit.scenario.core.atomic_attack", + attack_technique=technique, + seed_identifiers=[], + ) + # An explicitly-set empty list is preserved in children (hash-affecting). + assert atomic.children["seed_identifiers"] == [] + plain = ComponentIdentifier( + class_name="AtomicAttack", + class_module="pyrit.scenario.core.atomic_attack", + children={"attack_technique": technique, "seed_identifiers": []}, + ) + assert atomic.hash == plain.hash + + def test_seed_identifier_promoted_fields(self): + sid = SeedIdentifier( + class_name="Seed", + class_module="pyrit.models.seeds.seed", + value="hello", + value_sha256="abc", + dataset_name="ds", + is_general_technique=False, + ) + assert sid.params == { + "value": "hello", + "value_sha256": "abc", + "dataset_name": "ds", + "is_general_technique": False, + } diff --git a/tests/unit/output/attack_result/test_markdown.py b/tests/unit/output/attack_result/test_markdown.py index b61081867a..ee8ffc4b4d 100644 --- a/tests/unit/output/attack_result/test_markdown.py +++ b/tests/unit/output/attack_result/test_markdown.py @@ -8,6 +8,7 @@ from pyrit.memory import MemoryInterface from pyrit.models import ( + AtomicAttackIdentifier, AttackOutcome, AttackResult, ComponentIdentifier, @@ -15,7 +16,6 @@ Message, MessagePiece, Score, - build_atomic_attack_identifier, ) from pyrit.models.conversation_reference import ConversationReference from pyrit.output.attack_result.markdown import MarkdownAttackResultMemoryPrinter @@ -57,7 +57,7 @@ def printer(patch_central_database): def attack_result(): return AttackResult( objective="Test objective", - atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=_attack_id()), + atomic_attack_identifier=AtomicAttackIdentifier.build(attack_identifier=_attack_id()), conversation_id="conv-main", executed_turns=3, execution_time_ms=1500, diff --git a/tests/unit/output/attack_result/test_pretty.py b/tests/unit/output/attack_result/test_pretty.py index dd7d02c6d7..c493c27fdc 100644 --- a/tests/unit/output/attack_result/test_pretty.py +++ b/tests/unit/output/attack_result/test_pretty.py @@ -7,6 +7,7 @@ from pyrit.memory import MemoryInterface from pyrit.models import ( + AtomicAttackIdentifier, AttackOutcome, AttackResult, ComponentIdentifier, @@ -14,7 +15,6 @@ Message, MessagePiece, Score, - build_atomic_attack_identifier, ) from pyrit.models.conversation_reference import ConversationReference from pyrit.output.attack_result.pretty import PrettyAttackResultMemoryPrinter @@ -57,7 +57,7 @@ def printer(patch_central_database): def attack_result(): return AttackResult( objective="Test objective", - atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=_attack_id()), + atomic_attack_identifier=AtomicAttackIdentifier.build(attack_identifier=_attack_id()), conversation_id="conv-main", executed_turns=3, execution_time_ms=1500, diff --git a/tests/unit/prompt_converter/test_add_image_text_converter.py b/tests/unit/prompt_converter/test_add_image_text_converter.py index a372fdb9b7..410a246500 100644 --- a/tests/unit/prompt_converter/test_add_image_text_converter.py +++ b/tests/unit/prompt_converter/test_add_image_text_converter.py @@ -216,7 +216,7 @@ def test_add_image_text_converter_bounding_box_identifier(large_sample_image): ) identifier = converter.get_identifier() params = identifier.params - assert params["bounding_box"] == (100, 100, 400, 300) + assert params["bounding_box"] == [100, 100, 400, 300] assert params["rotation"] == 10.0 assert params["center_text"] is True assert params["font_size_min"] == 8 diff --git a/tests/unit/prompt_target/target/test_prompt_target.py b/tests/unit/prompt_target/target/test_prompt_target.py index 651fe5c770..77168f101c 100644 --- a/tests/unit/prompt_target/target/test_prompt_target.py +++ b/tests/unit/prompt_target/target/test_prompt_target.py @@ -557,7 +557,7 @@ async def test_no_warning_when_message_count_unchanged(): # --------------------------------------------------------------------------- -# _create_identifier — target configuration in the identifier +# _create_identifier — capabilities are NOT part of the identifier # --------------------------------------------------------------------------- @@ -581,7 +581,7 @@ def _make_identifier_target( @pytest.mark.usefixtures("patch_central_database") -def test_identifier_includes_capability_params(): +def test_identifier_excludes_capability_params(): target = _make_identifier_target( capabilities=TargetCapabilities( supports_multi_turn=True, @@ -594,36 +594,25 @@ def test_identifier_includes_capability_params(): ) params = target.get_identifier().params - target_config = params["target_configuration"] - capabilities = target_config["capabilities"] - # Config-derived fields are nested under ``target_configuration``, not - # spread at the top level — guards against accidental re-flattening. + # Capabilities can change with deployment configuration, so they are + # deliberately not part of a target's identity. + assert "target_configuration" not in params assert "supports_multi_turn" not in params - assert set(target_config.keys()) == {"capabilities", "capability_policy", "normalization_pipeline"} - - assert capabilities["supports_multi_turn"] is True - assert capabilities["supports_multi_message_pieces"] is True - assert capabilities["supports_json_schema"] is True - assert capabilities["supports_json_output"] is True - assert capabilities["supports_editable_history"] is False - assert capabilities["supports_system_prompt"] is True - assert capabilities["input_modalities"] == [["text"]] - assert capabilities["output_modalities"] == [["text"]] - assert isinstance(target_config["capability_policy"], dict) - assert isinstance(target_config["normalization_pipeline"], list) @pytest.mark.usefixtures("patch_central_database") -def test_identifier_differs_when_capabilities_differ(): +def test_identifier_same_when_capabilities_differ(): a = _make_identifier_target(capabilities=TargetCapabilities(supports_json_schema=False)) b = _make_identifier_target(capabilities=TargetCapabilities(supports_json_schema=True)) - assert a.get_identifier().hash != b.get_identifier().hash + # Capabilities are not part of identity, so differing capabilities alone + # must not change the identifier hash. + assert a.get_identifier().hash == b.get_identifier().hash @pytest.mark.usefixtures("patch_central_database") -def test_identifier_differs_when_policy_differs(): +def test_identifier_same_when_policy_differs(): capabilities = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) a = _make_identifier_target( capabilities=capabilities, @@ -644,7 +633,8 @@ def test_identifier_differs_when_policy_differs(): ), ) - assert a.get_identifier().hash != b.get_identifier().hash + # Handling policy is part of the (non-identity) configuration, not identity. + assert a.get_identifier().hash == b.get_identifier().hash @pytest.mark.usefixtures("patch_central_database") @@ -663,7 +653,7 @@ def test_identifier_is_deterministic_across_instances(): @pytest.mark.usefixtures("patch_central_database") -def test_identifier_differs_when_normalizer_overrides_differ(): +def test_identifier_same_when_normalizer_overrides_differ(): from pyrit.message_normalizer import GenericSystemSquashNormalizer, MessageListNormalizer from pyrit.models import Message from pyrit.prompt_target.common.target_capabilities import CapabilityName @@ -699,7 +689,8 @@ async def normalize_async(self, messages): # pragma: no cover - not exercised custom_configuration=custom_cfg, ) - assert a.get_identifier().hash != b.get_identifier().hash + # The resolved normalization pipeline is configuration, not identity. + assert a.get_identifier().hash == b.get_identifier().hash def test_apply_capabilities_replaces_capabilities_and_preserves_policy(patch_central_database): diff --git a/tests/unit/prompt_target/target/test_target_capabilities.py b/tests/unit/prompt_target/target/test_target_capabilities.py index 3ea7cf3984..30c47ce1ca 100644 --- a/tests/unit/prompt_target/target/test_target_capabilities.py +++ b/tests/unit/prompt_target/target/test_target_capabilities.py @@ -11,6 +11,7 @@ CapabilityName, TargetCapabilities, UnsupportedCapabilityBehavior, + get_known_capabilities, ) from pyrit.prompt_target.common.target_configuration import TargetConfiguration @@ -397,36 +398,36 @@ def test_custom_configuration_override_modalities(self): class TestGetKnownCapabilities: - """Test TargetCapabilities.get_known_capabilities for every recognized model.""" + """Test get_known_capabilities for every recognized model.""" def test_gpt_4o_supports_multi_turn_and_json_output(self): - caps = TargetCapabilities.get_known_capabilities("gpt-4o") + caps = get_known_capabilities("gpt-4o") assert caps is not None assert caps.supports_multi_turn is True assert caps.supports_multi_message_pieces is True assert caps.supports_json_output is True def test_gpt_4o_does_not_set_json_schema_or_editable_history(self): - caps = TargetCapabilities.get_known_capabilities("gpt-4o") + caps = get_known_capabilities("gpt-4o") assert caps is not None assert caps.supports_json_schema is False assert caps.supports_editable_history is True def test_gpt_4o_input_modalities_include_text_image_and_combined(self): - caps = TargetCapabilities.get_known_capabilities("gpt-4o") + caps = get_known_capabilities("gpt-4o") assert caps is not None assert frozenset({"text"}) in caps.input_modalities assert frozenset({"image_path"}) in caps.input_modalities assert frozenset({"text", "image_path"}) in caps.input_modalities def test_gpt_4o_output_modalities_are_text_only(self): - caps = TargetCapabilities.get_known_capabilities("gpt-4o") + caps = get_known_capabilities("gpt-4o") assert caps is not None assert caps.output_modalities == frozenset({frozenset({"text"})}) def test_gpt_5_returns_json_schema_and_json_output(self): for model in ["gpt-5", "gpt-5.1", "gpt-5.4"]: - caps = TargetCapabilities.get_known_capabilities(model) + caps = get_known_capabilities(model) assert caps is not None, f"Expected caps for {model}" assert caps.supports_multi_turn is True assert caps.supports_multi_message_pieces is True @@ -435,7 +436,7 @@ def test_gpt_5_returns_json_schema_and_json_output(self): def test_gpt_5_input_modalities_include_text_image_path_and_combined(self): for model in ["gpt-5", "gpt-5.1", "gpt-5.4"]: - caps = TargetCapabilities.get_known_capabilities(model) + caps = get_known_capabilities(model) assert caps is not None assert frozenset({"text"}) in caps.input_modalities assert frozenset({"image_path"}) in caps.input_modalities @@ -443,12 +444,12 @@ def test_gpt_5_input_modalities_include_text_image_path_and_combined(self): def test_gpt_5_output_modalities_are_text_only(self): for model in ["gpt-5", "gpt-5.1", "gpt-5.4"]: - caps = TargetCapabilities.get_known_capabilities(model) + caps = get_known_capabilities(model) assert caps is not None assert caps.output_modalities == frozenset({frozenset({"text"})}) def test_gpt_realtime_1_5_returns_multi_turn_text_defaults(self): - caps = TargetCapabilities.get_known_capabilities("gpt-realtime-1.5") + caps = get_known_capabilities("gpt-realtime-1.5") assert caps is not None assert caps.supports_multi_turn is True assert caps.supports_multi_message_pieces is True @@ -459,13 +460,13 @@ def test_gpt_realtime_1_5_returns_multi_turn_text_defaults(self): assert frozenset({"audio_path"}) in caps.output_modalities def test_tts_returns_text_input_audio_output(self): - caps = TargetCapabilities.get_known_capabilities("tts") + caps = get_known_capabilities("tts") assert caps is not None assert caps.input_modalities == frozenset({frozenset(["text"])}) assert caps.output_modalities == frozenset({frozenset({"audio_path"})}) def test_sora_2_input_modalities_include_text_image_path_and_combined(self): - caps = TargetCapabilities.get_known_capabilities("sora-2") + caps = get_known_capabilities("sora-2") assert caps is not None assert caps.supports_multi_turn is True assert caps.supports_multi_message_pieces is True @@ -474,16 +475,16 @@ def test_sora_2_input_modalities_include_text_image_path_and_combined(self): assert frozenset({"text", "image_path"}) in caps.input_modalities def test_sora_2_output_modalities_include_video_and_audio(self): - caps = TargetCapabilities.get_known_capabilities("sora-2") + caps = get_known_capabilities("sora-2") assert caps is not None assert frozenset({"video_path"}) in caps.output_modalities assert frozenset({"audio_path", "video_path"}) in caps.output_modalities def test_unknown_model_returns_none(self): - assert TargetCapabilities.get_known_capabilities("unknown-model-xyz") is None + assert get_known_capabilities("unknown-model-xyz") is None def test_empty_string_returns_none(self): - assert TargetCapabilities.get_known_capabilities("") is None + assert get_known_capabilities("") is None @pytest.mark.usefixtures("patch_central_database") @@ -513,7 +514,7 @@ def test_returns_known_config_when_model_is_recognized(self): custom_config = TargetConfiguration(capabilities=TargetCapabilities()) cls = self._make_target_class(default_config=custom_config) result = cls.get_default_configuration("gpt-4o") - expected = TargetCapabilities.get_known_capabilities("gpt-4o") + expected = get_known_capabilities("gpt-4o") assert result.capabilities == expected def test_returns_class_default_and_warns_when_model_is_unrecognized(self): diff --git a/tests/unit/prompt_target/target/test_target_configuration.py b/tests/unit/prompt_target/target/test_target_configuration.py index 9e1efe4caa..36f0c62fc5 100644 --- a/tests/unit/prompt_target/target/test_target_configuration.py +++ b/tests/unit/prompt_target/target/test_target_configuration.py @@ -1,8 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from dataclasses import fields - import pytest from pyrit.message_normalizer import ( @@ -303,8 +301,8 @@ def test_capabilities_to_identifier_params_includes_all_fields(): params = TargetConfiguration._capabilities_to_identifier_params(caps) - # Every dataclass field on TargetCapabilities must appear in the params. - assert set(params.keys()) == {f.name for f in fields(caps)} + # Every model field on TargetCapabilities must appear in the params. + assert set(params.keys()) == set(type(caps).model_fields) def test_capabilities_to_identifier_params_scalar_fields_passthrough(): diff --git a/tests/unit/registry/test_converter_registry.py b/tests/unit/registry/test_converter_registry.py index 7fa2de4599..a4b6262254 100644 --- a/tests/unit/registry/test_converter_registry.py +++ b/tests/unit/registry/test_converter_registry.py @@ -209,7 +209,7 @@ def test_build_metadata_includes_supported_input_types(self): self.registry.register_instance(converter, name="text_converter") metadata = self.registry.list_metadata() - assert metadata[0].params["supported_input_types"] == ("text",) + assert metadata[0].params["supported_input_types"] == ["text"] def test_build_metadata_includes_supported_output_types(self): """Test that metadata includes supported_output_types in params.""" @@ -217,7 +217,7 @@ def test_build_metadata_includes_supported_output_types(self): self.registry.register_instance(converter, name="text_converter") metadata = self.registry.list_metadata() - assert metadata[0].params["supported_output_types"] == ("text",) + assert metadata[0].params["supported_output_types"] == ["text"] def test_build_metadata_is_component_identifier(self): """Test that metadata is the converter's ComponentIdentifier.""" @@ -234,8 +234,8 @@ def test_build_metadata_different_modalities(self): self.registry.register_instance(converter, name="image_converter") metadata = self.registry.list_metadata() - assert metadata[0].params["supported_input_types"] == ("image_path",) - assert metadata[0].params["supported_output_types"] == ("text",) + assert metadata[0].params["supported_input_types"] == ["image_path"] + assert metadata[0].params["supported_output_types"] == ["text"] assert metadata[0].class_name == "MockImageConverter" diff --git a/tests/unit/scenario/benchmark/test_adversarial.py b/tests/unit/scenario/benchmark/test_adversarial.py index d0693c691e..bebe095b11 100644 --- a/tests/unit/scenario/benchmark/test_adversarial.py +++ b/tests/unit/scenario/benchmark/test_adversarial.py @@ -131,7 +131,11 @@ def _register_mock_factory(*, name: str, tags: list[str] | None = None, seed_tec factory.uses_adversarial = True factory.strategy_tags = tags if tags is not None else ["core", "light"] factory.seed_technique = seed_technique - factory.create.return_value = MagicMock(name="AttackTechnique") + technique_instance = MagicMock(name="AttackTechnique") + technique_instance.get_identifier.return_value = ComponentIdentifier( + class_name="MockTechnique", class_module="pyrit.test" + ) + factory.create.return_value = technique_instance factory.attack_class = MagicMock(__name__=name) AttackTechniqueRegistry.get_registry_singleton().register_from_factories([factory]) return factory diff --git a/tests/unit/scenario/core/test_atomic_attack.py b/tests/unit/scenario/core/test_atomic_attack.py index af28362fc0..5b0242b783 100644 --- a/tests/unit/scenario/core/test_atomic_attack.py +++ b/tests/unit/scenario/core/test_atomic_attack.py @@ -12,6 +12,7 @@ from pyrit.executor.attack import AttackExecutor, AttackStrategy from pyrit.executor.attack.core import AttackExecutorResult from pyrit.models import ( + AtomicAttackIdentifier, AttackOutcome, AttackResult, ComponentIdentifier, @@ -19,7 +20,6 @@ SeedGroup, SeedObjective, SeedPrompt, - build_atomic_attack_identifier, ) from pyrit.scenario import AtomicAttack from pyrit.scenario.core.attack_technique import AttackTechnique @@ -28,7 +28,9 @@ @pytest.fixture def mock_attack(): """Create a mock AttackStrategy for testing.""" - return MagicMock(spec=AttackStrategy) + attack = MagicMock(spec=AttackStrategy) + attack.get_identifier.return_value = ComponentIdentifier(class_name="MockAttack", class_module="pyrit.test") + return attack @pytest.fixture @@ -789,7 +791,7 @@ async def test_enrichment_populates_atomic_attack_identifier(self, mock_attack): objective="obj1", outcome=AttackOutcome.SUCCESS, executed_turns=1, - atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=attack_id), + atomic_attack_identifier=AtomicAttackIdentifier.build(attack_identifier=attack_id), ) atomic = AtomicAttack( @@ -844,7 +846,7 @@ async def test_enrichment_skips_out_of_range_index(self, mock_attack): objective="obj1", outcome=AttackOutcome.SUCCESS, executed_turns=1, - atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=attack_id), + atomic_attack_identifier=AtomicAttackIdentifier.build(attack_identifier=attack_id), ) atomic = AtomicAttack( @@ -884,7 +886,7 @@ async def test_enrichment_includes_all_seeds(self, mock_attack): objective="obj1", outcome=AttackOutcome.SUCCESS, executed_turns=1, - atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=attack_id), + atomic_attack_identifier=AtomicAttackIdentifier.build(attack_identifier=attack_id), ) atomic = AtomicAttack( @@ -927,14 +929,14 @@ async def test_enrichment_maps_multiple_results_to_correct_seed_groups(self, moc objective="obj1", outcome=AttackOutcome.SUCCESS, executed_turns=1, - atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=attack_id), + atomic_attack_identifier=AtomicAttackIdentifier.build(attack_identifier=attack_id), ), AttackResult( conversation_id="c2", objective="obj2", outcome=AttackOutcome.SUCCESS, executed_turns=1, - atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=attack_id), + atomic_attack_identifier=AtomicAttackIdentifier.build(attack_identifier=attack_id), ), ] @@ -973,7 +975,7 @@ async def test_enrichment_persists_to_db(self, mock_attack): outcome=AttackOutcome.SUCCESS, executed_turns=1, attack_result_id="00000000-0000-0000-0000-000000000001", - atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=attack_id), + atomic_attack_identifier=AtomicAttackIdentifier.build(attack_identifier=attack_id), ) atomic = AtomicAttack( @@ -1014,7 +1016,7 @@ async def test_enrichment_skips_db_update_when_no_attack_result_id(self, mock_at outcome=AttackOutcome.SUCCESS, executed_turns=1, attack_result_id="", - atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=attack_id), + atomic_attack_identifier=AtomicAttackIdentifier.build(attack_identifier=attack_id), ) atomic = AtomicAttack( diff --git a/tests/unit/scenario/scenarios/adaptive/test_dispatcher.py b/tests/unit/scenario/scenarios/adaptive/test_dispatcher.py index 0ab722f224..e8be54e66e 100644 --- a/tests/unit/scenario/scenarios/adaptive/test_dispatcher.py +++ b/tests/unit/scenario/scenarios/adaptive/test_dispatcher.py @@ -261,5 +261,5 @@ async def test_predicted_hash_matches_persisted_row(self, sqlite_instance): assert stamped_hash == predicted_hash, ( f"Selector-side eval_hash ({predicted_hash}) drifted from executor-stamped " f"eval_hash ({stamped_hash}) on persisted row {row.id}. " - f"compute_inner_attack_eval_hash and build_atomic_attack_identifier must agree." + f"compute_inner_attack_eval_hash and AtomicAttackIdentifier.build must agree." ) diff --git a/tests/unit/score/test_insecure_code_scorer.py b/tests/unit/score/test_insecure_code_scorer.py index 40b1d0dc36..44637c25c0 100644 --- a/tests/unit/score/test_insecure_code_scorer.py +++ b/tests/unit/score/test_insecure_code_scorer.py @@ -13,7 +13,9 @@ @pytest.fixture def mock_chat_target(patch_central_database): - return MagicMock(spec=PromptTarget) + target = MagicMock(spec=PromptTarget) + target.get_identifier.return_value = ComponentIdentifier(class_name="MockChatTarget", class_module="mock") + return target async def test_insecure_code_scorer_valid_response(mock_chat_target): diff --git a/tests/unit/score/test_self_ask_question_answer_scorer.py b/tests/unit/score/test_self_ask_question_answer_scorer.py index ba57bf9dcc..d5cfb2ad24 100644 --- a/tests/unit/score/test_self_ask_question_answer_scorer.py +++ b/tests/unit/score/test_self_ask_question_answer_scorer.py @@ -12,7 +12,9 @@ @pytest.fixture def mock_chat_target(patch_central_database): - return MagicMock(spec=PromptTarget) + target = MagicMock(spec=PromptTarget) + target.get_identifier.return_value = ComponentIdentifier(class_name="MockChatTarget", class_module="mock") + return target async def test_score_async_returns_score_from_unvalidated(mock_chat_target): diff --git a/tests/unit/setup/test_load_default_datasets.py b/tests/unit/setup/test_load_default_datasets.py index 84c4b10a26..f3ecd53d5f 100644 --- a/tests/unit/setup/test_load_default_datasets.py +++ b/tests/unit/setup/test_load_default_datasets.py @@ -139,9 +139,14 @@ async def test_all_required_datasets_available_in_seed_provider(self, populated_ # Patch OpenAIChatTarget at the fallback construction site so registry # introspection does not depend on OPENAI_CHAT_MODEL or other env vars. + from pyrit.models.identifiers import ComponentIdentifier from pyrit.score import TrueFalseScorer fallback_target = MagicMock() + fallback_target.get_identifier.return_value = ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + ) fallback_scorer = MagicMock(spec=TrueFalseScorer) with ( patch("pyrit.scenario.core.scenario_target_defaults.OpenAIChatTarget", return_value=fallback_target), diff --git a/tests/unit/setup/test_scorer_initializer.py b/tests/unit/setup/test_scorer_initializer.py index 2aeb6883eb..67884d284a 100644 --- a/tests/unit/setup/test_scorer_initializer.py +++ b/tests/unit/setup/test_scorer_initializer.py @@ -557,9 +557,14 @@ def _register_mock_target(self, *, name: str, underlying_model: str = "gpt-4o") def _register_mock_rr_target(self, *, name: str) -> MagicMock: """Register a mock RoundRobinTarget under the given name.""" + from pyrit.models.identifiers import ComponentIdentifier from pyrit.prompt_target import RoundRobinTarget rr_mock = MagicMock(spec=RoundRobinTarget) + rr_mock.get_identifier.return_value = ComponentIdentifier( + class_name="RoundRobinTarget", + class_module="pyrit.prompt_target.round_robin_target", + ) registry = TargetRegistry.get_registry_singleton() registry.register_instance(rr_mock, name=name) return rr_mock