diff --git a/pyoaev/apis/signature.py b/pyoaev/apis/signature.py index bed3e9b..2450c2c 100644 --- a/pyoaev/apis/signature.py +++ b/pyoaev/apis/signature.py @@ -10,7 +10,11 @@ from pyoaev import exceptions as exc from pyoaev.base import RESTManager, RESTObject from pyoaev.exceptions import SignatureTransmissionError -from pyoaev.signatures.models import SignatureCallbackPayload +from pyoaev.signatures.models import ( + ExecutionDetails, + SignatureCallbackPayload, + SignatureOutputStructure, +) class Signature(RESTObject): @@ -22,20 +26,16 @@ class Signature(RESTObject): class SignatureApiManager(RESTManager): """Manage signature callback transport to the OpenAEV backend. - Handles payload validation, auto-chunking, and retry with exponential backoff. + Handles payload validation and retry with exponential backoff. """ _path = "/injects" _obj_cls = Signature - DEFAULT_MAX_PAYLOAD_SIZE = 1_048_576 # 1 MiB + DEFAULT_MAX_PAYLOAD_SIZE = 5_242_880 # 5 MiB MAX_RETRIES = 3 RETRY_DELAYS = (1, 2, 4) - _CHUNK_METADATA_RESERVE = len( - ',"chunk_index":99999,"total_chunks":99999,"phase":"execution_complete_extended"' - ) - def __init__(self, openaev: "Any", parent: "Any" = None) -> None: """Initialize the signature API manager. @@ -68,8 +68,10 @@ def logger(self, value: logging.Logger) -> None: def send_signatures( self, inject_id: str, - phase: str, - signatures: dict[str, Any], + signatures: SignatureOutputStructure, + execution_details: ExecutionDetails, + max_payload_size: int | None = None, + logger: logging.Logger | None = None, ) -> None: """Send compiled signatures to the inject callback endpoint. @@ -77,38 +79,55 @@ def send_signatures( Args: inject_id: Inject UUID. - phase: Execution phase (e.g. 'execution_complete'). signatures: Full signatures dict (canonical or flat, grouped on the fly). + execution_details: Raises: SignatureTransmissionError: Validation failed, 4xx hit, or retries exhausted. """ - self._logger.debug("send_signatures inject_id=%s phase=%s", inject_id, phase) - signatures = self._normalize_signature_payload(signatures) - payload = self._build_callback_payload(signatures, phase=phase) + effective_max_size = ( + max_payload_size if max_payload_size is not None else self._max_payload_size + ) + effective_logger = logger if logger is not None else self._logger - serialized = json.dumps(payload, separators=(",", ":")).encode() + effective_logger.debug( + "send_signatures inject_id=%s, execution_status=%s, execution_action=%s", + inject_id, + execution_details.execution_status, + execution_details.execution_action, + ) + signatures.normalize_signature_payload() + payload = self._build_callback_payload( + signatures=signatures, execution_details=execution_details + ) + payload_size = len(json.dumps(payload).encode("utf-8")) - if len(serialized) <= self._max_payload_size: - self._send_with_retry(inject_id, payload) - else: - self._send_chunked(inject_id, payload["expectation_signature"], phase=phase) + if payload_size <= effective_max_size: + self._send_with_retry(inject_id, payload, logger=effective_logger) + return + + sig_data = json.loads(payload["execution_output_structured"]) + targets = sig_data["signatures"]["targets"] + envelopes = self._split_into_envelopes( + payload, + sig_data, + targets, + max_payload_size=effective_max_size, + logger=effective_logger, + ) + for envelope in envelopes: + self._send_with_retry(inject_id, envelope, logger=effective_logger) def _build_callback_payload( self, - signatures: dict[str, Any], - *, - phase: str | None = None, - chunk_index: int | None = None, - total_chunks: int | None = None, + signatures: SignatureOutputStructure, + execution_details: ExecutionDetails, ) -> dict[str, Any]: """Validate and wrap signatures in the strict callback envelope. Args: signatures: The inner signatures body, already normalised. - phase: Execution phase string (e.g. 'execution_complete'). - chunk_index: 0-based index when chunking, None for single POSTs. - total_chunks: Chunk count when chunking, None for single POSTs. + execution_details: The execution metadata to be stored next to the signatures in the payload. Returns: The validated dict ready for wire transmission. @@ -117,150 +136,76 @@ def _build_callback_payload( SignatureTransmissionError: Envelope failed Pydantic validation. """ try: - envelope = SignatureCallbackPayload.model_validate( - { - "expectation_signature": signatures, - "phase": phase, - "chunk_index": chunk_index, - "total_chunks": total_chunks, - } + envelope = SignatureCallbackPayload.build_from_models( + signatures, execution_details ) except ValidationError as ve: raise SignatureTransmissionError( error_message=f"Invalid signatures payload: {ve}", ) from ve - return envelope.model_dump(mode="json", exclude_none=True) - - def _normalize_signature_payload( - self, signatures: dict[str, Any] - ) -> dict[str, Any]: - """Regroup signature_values by expectation_type within each target. + envelope_dict = envelope.model_dump(mode="json", exclude_none=True) + SignatureCallbackPayload.model_validate(envelope_dict) + return envelope_dict - Accepts flat or pre-grouped input and returns canonical grouped form. + def _split_into_envelopes( + self, + base_payload: dict[str, Any], + sig_data: dict[str, Any], + targets: list[dict[str, Any]], + max_payload_size: int | None = None, + logger: logging.Logger | None = None, + ) -> list[dict[str, Any]]: + effective_max = ( + max_payload_size if max_payload_size is not None else self._max_payload_size + ) + effective_logger = logger if logger is not None else self._logger - Args: - signatures: Raw signatures dict with any mix of flat and grouped entries. + envelopes: list[dict[str, Any]] = [] + current_targets: list[dict[str, Any]] = [] - Returns: - New dict where every signature_values list is in canonical grouped form. - """ - targets = signatures.get("targets") - if not targets: - return signatures - - normalized_targets: list[dict[str, Any]] = [] for target in targets: - sig_values = target.get("signature_values") - if not sig_values: - normalized_targets.append(target) - continue - - grouped: dict[str, list[dict[str, Any]]] = {} - order: list[str] = [] - - for entry in sig_values: - etype = entry.get("expectation_type") - if etype not in grouped: - grouped[etype] = [] - order.append(etype) - - if "values" in entry and isinstance(entry["values"], list): - grouped[etype].extend(entry["values"]) + trial_targets = current_targets + [target] + trial_envelope = self._build_envelope(base_payload, sig_data, trial_targets) + trial_size = len(json.dumps(trial_envelope).encode("utf-8")) + + if trial_size > effective_max: + if current_targets: + envelopes.append( + self._build_envelope(base_payload, sig_data, current_targets) + ) + current_targets = [target] else: - grouped[etype].append( - {k: v for k, v in entry.items() if k != "expectation_type"} + effective_logger.warning( + "Single target exceeds max_payload_size (%d bytes > %d limit). Sending oversized envelope.", + trial_size, + effective_max, ) + envelopes.append(trial_envelope) + current_targets = [] + else: + current_targets = trial_targets + + if current_targets: + envelopes.append( + self._build_envelope(base_payload, sig_data, current_targets) + ) - normalized_target = dict(target) - normalized_target["signature_values"] = [ - {"expectation_type": etype, "values": grouped[etype]} for etype in order - ] - normalized_targets.append(normalized_target) - - normalized = dict(signatures) - normalized["targets"] = normalized_targets - return normalized - - def _send_chunked( - self, inject_id: str, signatures: dict[str, Any], phase: str | None = None - ) -> None: - """Split targets across sequential POSTs, each tagged with chunk metadata. - - Args: - inject_id: Inject UUID for the callback path. - signatures: Normalised inner signatures body to partition. - phase: Execution phase forwarded to each chunk envelope. - - Raises: - SignatureTransmissionError: A single target alone exceeds max_payload_size. - """ - targets = signatures.get("targets", []) - if not targets: - payload = self._build_callback_payload(signatures, phase=phase) - size = len(json.dumps(payload, separators=(",", ":")).encode()) - if size > self._max_payload_size: - self._logger.warning( - "Payload of %d bytes exceeds max_payload_size %d but has no " - "'targets' key to chunk on; sending unchunked", - size, - self._max_payload_size, - ) - self._send_with_retry(inject_id, payload) - return - - budget = max(self._max_payload_size - self._CHUNK_METADATA_RESERVE, 0) - chunks: list[list[Any]] = [] - current_chunk: list[Any] = [] + return envelopes - for target in targets: - candidate = current_chunk + [target] - size = len( - json.dumps( - {"expectation_signature": {"targets": candidate}}, - separators=(",", ":"), - ).encode() - ) + def _build_envelope( + self, + base_payload: dict[str, Any], + sig_data: dict[str, Any], + targets_subset: list[dict[str, Any]], + ) -> dict[str, Any]: + subset_sig = dict(sig_data) + subset_sig["signatures"] = dict(sig_data["signatures"]) + subset_sig["signatures"]["targets"] = targets_subset - if size <= budget: - current_chunk.append(target) - continue - - if not current_chunk: - raise SignatureTransmissionError( - error_message=( - f"Single target payload of {size} bytes exceeds " - f"max_payload_size {self._max_payload_size}; cannot chunk further" - ), - ) - - chunks.append(current_chunk) - current_chunk = [target] - solo_size = len( - json.dumps( - {"expectation_signature": {"targets": [target]}}, - separators=(",", ":"), - ).encode() - ) - if solo_size > budget: - raise SignatureTransmissionError( - error_message=( - f"Single target payload of {solo_size} bytes exceeds " - f"max_payload_size {self._max_payload_size}; cannot chunk further" - ), - ) - - if current_chunk: - chunks.append(current_chunk) - - total_chunks = len(chunks) - for idx, chunk_targets in enumerate(chunks): - chunk_payload = self._build_callback_payload( - {"targets": chunk_targets}, - phase=phase, - chunk_index=idx, - total_chunks=total_chunks, - ) - self._send_with_retry(inject_id, chunk_payload) + envelope = dict(base_payload) + envelope["execution_output_structured"] = json.dumps(subset_sig) + SignatureCallbackPayload.model_validate(envelope) + return envelope @exc.on_http_error(exc.OpenAEVUpdateError) def callback( @@ -276,12 +221,15 @@ def callback( Returns: The parsed response from the backend. """ - path = f"{self.path}/{inject_id}/callback" + path = f"{self.path}/execution/callback/{inject_id}" result = self.openaev.http_post(path, post_data=data, **kwargs) return result def _send_with_retry( - self, inject_id: str, payload: dict[str, Any] + self, + inject_id: str, + payload: dict[str, Any], + logger: logging.Logger | None = None, ) -> dict[str, Any]: """Retry callback() with exponential backoff on 5xx, immediate raise on 4xx. @@ -297,6 +245,7 @@ def _send_with_retry( """ from pyoaev.exceptions import OpenAEVError + effective_logger = logger if logger is not None else self._logger last_error: Exception | None = None for attempt in range(self.MAX_RETRIES + 1): @@ -308,7 +257,7 @@ def _send_with_retry( body_str = "" if ex.response_body: body_str = ex.response_body.decode(errors="replace") - self._logger.error( + effective_logger.error( "Client error %d sending signatures: %s", status, body_str or ex.error_message, @@ -322,7 +271,7 @@ def _send_with_retry( last_error = ex if attempt < self.MAX_RETRIES: delay = self.RETRY_DELAYS[attempt] - self._logger.warning( + effective_logger.warning( "Retry %d/%d after %ds (HTTP %s): %s", attempt + 1, self.MAX_RETRIES, diff --git a/pyoaev/signatures/__init__.py b/pyoaev/signatures/__init__.py index d8a60cd..dc2e061 100644 --- a/pyoaev/signatures/__init__.py +++ b/pyoaev/signatures/__init__.py @@ -1,31 +1,37 @@ from pyoaev.signatures.models import ( CloudInjectorConfig, ExpectationSignatureGroup, - ExternalInjectorConfig, ExtraSignatureData, InjectorConfig, NetworkInjectorConfig, SignatureCallbackPayload, SignaturePayload, + SignatureTarget, SignatureValue, TargetSignatures, build_network_configs, ) from pyoaev.signatures.signature_manager import SignatureManager -from pyoaev.signatures.types import ExpectationType, MatchTypes, SignatureTypes +from pyoaev.signatures.types import ( + ExpectationType, + InjectExecutionActions, + MatchTypes, + SignatureTypes, +) __all__ = [ "CloudInjectorConfig", "ExpectationSignatureGroup", - "ExternalInjectorConfig", "ExpectationType", "ExtraSignatureData", "InjectorConfig", + "InjectExecutionActions", "MatchTypes", "NetworkInjectorConfig", "SignatureCallbackPayload", "SignatureManager", "SignaturePayload", + "SignatureTarget", "SignatureTypes", "SignatureValue", "TargetSignatures", diff --git a/pyoaev/signatures/models.py b/pyoaev/signatures/models.py index 553c7e2..b8a2c25 100644 --- a/pyoaev/signatures/models.py +++ b/pyoaev/signatures/models.py @@ -1,6 +1,9 @@ """Pydantic schemas pinning every shape SignatureManager touches.""" import ipaddress +import math +from collections import defaultdict +from datetime import datetime, timezone from typing import Any from pydantic import ( @@ -8,11 +11,40 @@ ConfigDict, Field, JsonValue, + TypeAdapter, + computed_field, field_validator, model_validator, ) -from pyoaev.signatures.types import ExpectationType +from pyoaev.signatures.types import ExpectationType, InjectExecutionActions + + +class ToolErrorInfo(BaseModel): + """Crash report. Non-zero exit code and a timestamp if the tool left one behind.""" + + model_config = ConfigDict(extra="allow") + + exit_code: int = 0 + crash_timestamp: str | None = None + + +class ToolTimeoutInfo(BaseModel): + """Timeout report. Whatever partial loot was rescued before the kill signal.""" + + model_config = ConfigDict(extra="allow") + + partial_results: list[str] = Field(default_factory=list) + + +class ToolOutput(BaseModel): + """Whatever the tool spat out: status, error info, timeout info, or injector extras.""" + + model_config = ConfigDict(extra="allow") + + status: str | None = None + error_info: ToolErrorInfo | None = None + timeout_info: ToolTimeoutInfo | None = None class SignatureValue(BaseModel): @@ -42,11 +74,11 @@ def is_expectation_type(cls, value: str) -> str: class ExtraSignatureData(BaseModel): """Format for extra signatures added to the default signatures""" - detection: dict[str, JsonValue] | None = Field(default_factory=dict) - prevention: dict[str, JsonValue] | None = Field(default_factory=dict) - vulnerability: dict[str, JsonValue] | None = Field(default_factory=dict) + detection: dict[str, JsonValue] = Field(default_factory=dict) + prevention: dict[str, JsonValue] = Field(default_factory=dict) + vulnerability: dict[str, JsonValue] = Field(default_factory=dict) - def get_extra(self, expectation_type: str): + def get_extra(self, expectation_type: str) -> dict[str, JsonValue]: if expectation_type.lower() == "detection": return self.detection if expectation_type.lower() == "prevention": @@ -58,11 +90,22 @@ def get_extra(self, expectation_type: str): ) +class SignatureTarget(BaseModel): + """Target identity on the wire.""" + + model_config = ConfigDict(extra="forbid") + + agent: str | None = None + asset: str | None = None + asset_group: str | None = None + + class TargetSignatures(BaseModel): """A target plus everything observed about it, grouped by expectation.""" model_config = ConfigDict(extra="allow") + signature_target: SignatureTarget signature_values: list[ExpectationSignatureGroup] @@ -74,24 +117,147 @@ class SignaturePayload(BaseModel): targets: list[TargetSignatures] +class SignatureOutputStructure(BaseModel): + """Structured output to be serialized as a str in the callback payload yet data has to follow model.""" + + model_config = ConfigDict(populate_by_name=True, extra="forbid") + + signatures: SignaturePayload + + def normalize_signature_payload(self) -> None: + """ + Regroup signature_values by expectation_type within each target. + """ + normalized_targets: list[TargetSignatures] = [] + + for target in self.signatures.targets: + if not target.signature_values: + normalized_targets.append(target) + continue + + grouped: dict[str, list[dict[str, Any]]] = defaultdict(list) + order: list[str] = [] + + for entry in target.signature_values: + if entry.expectation_type not in order: + order.append(entry.expectation_type) + grouped[entry.expectation_type].extend(entry.values) + + normalized_target = TargetSignatures( + signature_target=target.signature_target, + signature_values=[ + ExpectationSignatureGroup( + expectation_type=expectation_type, + values=grouped[expectation_type], + ) + for expectation_type in order + ], + ) + + normalized_targets.append(normalized_target) + + self.signatures.targets = normalized_targets + + +class ExecutionDetails(BaseModel): + """Helper to wrap the execution-related details for the callback payload""" + + model_config = ConfigDict(extra="forbid") + + start_time: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + end_time: datetime | None = None + + execution_status: str = "unknown" + execution_action: InjectExecutionActions | None = None + + @computed_field + @property + def execution_message(self) -> str: + action = self.execution_action.value if self.execution_action else "unknown" + return f"Current action: {action} - Current status: {self.execution_status}" + + @computed_field + @property + def execution_duration(self) -> float: + try: + return (self.end_time - self.start_time).total_seconds() + except: + return 0.0 + + def post_execution_update(self, tool_output: ToolOutput, now: datetime) -> None: + """ + Update execution-related metadata according to tool output and now timestamp + """ + self.end_time = now + + if tool_output.error_info and tool_output.error_info.exit_code != 0: + self.execution_status = "failed" + if tool_output.error_info.crash_timestamp: + self.end_time = datetime.strptime( + tool_output.error_info.crash_timestamp, "%Y-%m-%dT%H:%M:%SZ" + ) + elif tool_output.timeout_info: + self.execution_status = "timeout" + elif tool_output.status == "partial": + self.execution_status = "partial" + else: + self.execution_status = "success" + + self.execution_action = InjectExecutionActions("complete") + + class SignatureCallbackPayload(BaseModel): - """Outer POST envelope. Pure ``{signatures}`` when unchunked, plus chunk fields when split.""" + """Outer POST envelope validated by ``SignatureApiManager`` before wire transmission.""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - expectation_signature: SignaturePayload - phase: str | None = None - chunk_index: int | None = None - total_chunks: int | None = None + execution_message: str + execution_output_structured: str | None = None + execution_status: str + execution_duration: int | None = None + execution_action: InjectExecutionActions | None = None + + @field_validator("execution_output_structured", mode="after") + @classmethod + def is_proper_signature_output_structure(cls, value: str) -> str | None: + if value is None: + return None + TypeAdapter(SignatureOutputStructure).validate_json(value) + return value + + @classmethod + def build_from_models( + cls, signatures: SignatureOutputStructure, execution_details: ExecutionDetails + ): + """Producing a SignatureCallbackPayload from the data of a SignatureOutputStructure and of a ExecutionDetails.""" + return cls( + execution_message=execution_details.execution_message, + execution_output_structured=signatures.model_dump_json(exclude_none=True), + execution_status=execution_details.execution_status, + execution_duration=( + math.ceil(execution_details.execution_duration) + if execution_details.execution_duration is not None + else None + ), + execution_action=execution_details.execution_action, + ) -class PreExecutionSignature(BaseModel): - """Pre-execution data dump. Field set varies by category: network, cloud, external.""" +class ExecutionSignature(BaseModel): + """ + Execution signature data. Field set varies by category: network, cloud. Plus outcome, end_time, and any partial results. + """ model_config = ConfigDict(extra="allow") # Timing always emitted at call time. - start_time: str | None = None + start_time: str = Field( + default_factory=lambda: datetime.now(timezone.utc).strftime( + "%Y-%m-%dT%H:%M:%SZ" + ) + ) + end_time: str | None = None + partial_results: list[str] | None = None # Network identity source_ipv4: str | None = None @@ -106,43 +272,15 @@ class PreExecutionSignature(BaseModel): cloud_region: str | None = None target_service: str | None = None - # External - query: str | None = None - - -class PostExecutionSignature(PreExecutionSignature): - """Post-execution view: pre-execution fields plus outcome, end_time, and any partial results.""" - - end_time: str | None = None - execution_status: str | None = None - partial_results: list[str] | None = None - - -class ToolErrorInfo(BaseModel): - """Crash report. Non-zero exit code and a timestamp if the tool left one behind.""" - - model_config = ConfigDict(extra="allow") - - exit_code: int = 0 - crash_timestamp: str | None = None - - -class ToolTimeoutInfo(BaseModel): - """Timeout report. Whatever partial loot was rescued before the kill signal.""" - - model_config = ConfigDict(extra="allow") - - partial_results: list[str] = [] + def post_execution_update(self, tool_output: ToolOutput, now: datetime) -> None: + """ """ + self.end_time = now.strftime("%Y-%m-%dT%H:%M:%SZ") + if tool_output.error_info and tool_output.error_info.crash_timestamp: + self.end_time = tool_output.error_info.crash_timestamp -class ToolOutput(BaseModel): - """Whatever the tool spat out: status, error info, timeout info, or injector extras.""" - - model_config = ConfigDict(extra="allow") - - status: str | None = None - error_info: ToolErrorInfo | None = None - timeout_info: ToolTimeoutInfo | None = None + if tool_output.timeout_info and tool_output.timeout_info.partial_results: + self.partial_results = tool_output.timeout_info.partial_results class NetworkInjectorConfig(BaseModel): @@ -159,7 +297,7 @@ class NetworkInjectorConfig(BaseModel): def check_one(cls, data): assert ( sum( - value != None + value is not None for key, value in data.items() if key in ["target_ipv4", "target_ipv6", "target_hostname"] ) @@ -179,17 +317,7 @@ class CloudInjectorConfig(BaseModel): target_service: str | None = None -class ExternalInjectorConfig(BaseModel): - """A single external scan target (e.g. Shodan): a query against an asset.""" - - model_config = ConfigDict(extra="forbid") - - query: str - target_ipv4: str | None = None - target_hostname: str | None = None - - -InjectorConfig = NetworkInjectorConfig | CloudInjectorConfig | ExternalInjectorConfig +InjectorConfig = NetworkInjectorConfig | CloudInjectorConfig # --------------------------------------------------------------------------- @@ -245,17 +373,16 @@ def build_network_configs( __all__ = [ "SignatureValue", "ExpectationSignatureGroup", + "SignatureTarget", "TargetSignatures", "SignaturePayload", "SignatureCallbackPayload", - "PreExecutionSignature", - "PostExecutionSignature", + "ExecutionSignature", "ToolErrorInfo", "ToolTimeoutInfo", "ToolOutput", "NetworkInjectorConfig", "CloudInjectorConfig", - "ExternalInjectorConfig", "InjectorConfig", "build_network_configs", ] diff --git a/pyoaev/signatures/signature_manager.py b/pyoaev/signatures/signature_manager.py index aaf101d..81db6ef 100644 --- a/pyoaev/signatures/signature_manager.py +++ b/pyoaev/signatures/signature_manager.py @@ -12,14 +12,15 @@ from pyoaev.exceptions import OpenAEVError from pyoaev.signatures.models import ( CloudInjectorConfig, + ExecutionDetails, + ExecutionSignature, ExpectationSignatureGroup, - ExternalInjectorConfig, ExtraSignatureData, InjectorConfig, NetworkInjectorConfig, - PostExecutionSignature, - PreExecutionSignature, + SignatureOutputStructure, SignaturePayload, + SignatureTarget, SignatureValue, TargetSignatures, ToolOutput, @@ -32,7 +33,7 @@ class SignatureManager: """End-to-end signature pipeline: compile, merge, transmit. One class, three jobs.""" - DEFAULT_MAX_PAYLOAD_SIZE = 1_048_576 # 1 MiB + DEFAULT_MAX_PAYLOAD_SIZE = 5_242_880 # 5 MiB def __init__( self, @@ -50,23 +51,23 @@ def _utcnow(self) -> datetime: """Current UTC time. Carved out so tests can pin the clock.""" return datetime.now(timezone.utc) - def compile_pre_execution_signatures( + def build_execution_signatures( self, config: InjectorConfig | list[InjectorConfig], - ) -> dict[str, Any] | list[dict[str, Any]]: + ) -> ExecutionSignature | list[ExecutionSignature]: """Build pre-execution signature dicts from one or more typed injector configs. The category is carried by the config type itself - (:class:`NetworkInjectorConfig`, :class:`CloudInjectorConfig`, - :class:`ExternalInjectorConfig`), so no separate ``category`` flag is needed. + (:class:`NetworkInjectorConfig`, :class:`CloudInjectorConfig`), + so no separate ``category`` flag is needed. Args: config: A single injector config or a homogeneous list of them. Multi-target injects must be expressed as a list. Returns: - One dict when a single config is given, otherwise a list of dicts in - input order. + One ExecutionSignature object when a single config is given, + otherwise a list of ExecutionSignature in input order. Raises: ValueError: Empty list, or mixed config types in a single call. @@ -74,107 +75,82 @@ def compile_pre_execution_signatures( """ configs = list(config) if isinstance(config, list) else [config] if not configs: - raise ValueError( - "compile_pre_execution_signatures requires at least one config" - ) + raise ValueError("build_execution_signatures requires at least one config") first_type = type(configs[0]) for c in configs: if not isinstance(c, first_type): raise ValueError( - "compile_pre_execution_signatures does not mix injector config types; " + "build_execution_signatures does not mix injector config types; " f"got {sorted({type(c).__name__ for c in configs})}" ) start_time = self._utcnow().strftime("%Y-%m-%dT%H:%M:%SZ") - results = [self._compile_one(cfg, start_time) for cfg in configs] + results = [self._build_one(cfg, start_time) for cfg in configs] return results[0] if len(results) == 1 else results - def _compile_one(self, config: InjectorConfig, start_time: str) -> dict[str, Any]: + def _build_one(self, config: InjectorConfig, start_time: str) -> ExecutionSignature: """Project a single injector config into a flat pre-execution signature dict. Common pipeline for every category: 1. Seed the base dict with ``start_time`` and category-specific context - (network gets resolved source IPs; cloud/external add nothing). + (network gets resolved source IPs; cloud add nothing). 2. Layer the config's own fields on top. - 3. Run it through :class:`PreExecutionSignature` for validation + 3. Run it through :class:`ExecutionSignature` for validation and emit JSON-ready output stripped of ``None``\\ s. """ base: dict[str, Any] = {"start_time": start_time} base.update(self._source_context(config)) base.update(config.model_dump(exclude_none=True)) - return PreExecutionSignature(**base).model_dump(mode="json", exclude_none=True) + return ExecutionSignature(**base) def _source_context(self, config: InjectorConfig) -> dict[str, Any]: """Return the source identity bits injected for the config's category. Only network signatures need the running container's source IPs; - cloud and external rows have no source identity to carry. + cloud rows have no source identity to carry. """ if isinstance(config, NetworkInjectorConfig): return { "source_ipv4": self.resolve_container_ip(), "source_ipv6": self._cached_ipv6, } - if isinstance(config, (CloudInjectorConfig, ExternalInjectorConfig)): + if isinstance(config, CloudInjectorConfig): return {} raise TypeError(f"unsupported injector config type: {type(config).__name__}") - def compile_post_execution_signatures( + def post_execution_updates( self, - pre_signatures: dict[str, Any] | list[dict[str, Any]], + execution_details: ExecutionDetails, + execution_signatures: ExecutionSignature | list[ExecutionSignature], tool_output: dict[str, Any], - ) -> dict[str, Any] | list[dict[str, Any]]: - """Merge pre-execution dicts with the tool's verdict into post-execution dicts. - - Args: - pre_signatures: One pre-execution dict or a list of them. - tool_output: Tool result with optional `error_info` / `timeout_info` / `status`. - - Returns: - Same shape as `pre_signatures`, now carrying `end_time` and `execution_status`. + ) -> None: + """ + Update both execution details and execution signatures according to tool output """ - if isinstance(pre_signatures, list): - return [self._merge_post(sig, tool_output) for sig in pre_signatures] - return self._merge_post(pre_signatures, tool_output) - - def _merge_post( - self, pre_sig: dict[str, Any], tool_output: dict[str, Any] - ) -> dict[str, Any]: try: - tool = ToolOutput.model_validate(tool_output or {}) + tool_output = ToolOutput.model_validate(tool_output or {}) except ValidationError as exc: raise OpenAEVError( error_message=f"Invalid tool_output: {exc}", ) from exc - post = PostExecutionSignature.model_validate(pre_sig) - now = self._utcnow().strftime("%Y-%m-%dT%H:%M:%SZ") - - if tool.error_info and tool.error_info.exit_code != 0: - post.execution_status = "failed" - post.end_time = tool.error_info.crash_timestamp or now - elif tool.timeout_info: - post.execution_status = "timeout" - post.end_time = now - if tool.timeout_info.partial_results: - post.partial_results = tool.timeout_info.partial_results - elif tool.status == "partial": - post.execution_status = "partial" - post.end_time = now - else: - post.execution_status = "success" - post.end_time = now + now = self._utcnow() - merged = post.model_dump(mode="json", exclude_none=True) + execution_details.post_execution_update(tool_output, now) - return merged + if isinstance(execution_signatures, list): + for exec_sig in execution_signatures: + exec_sig.post_execution_update(tool_output, now) + else: + execution_signatures.post_execution_update(tool_output, now) @staticmethod def build_payload( - post_signatures: dict[str, Any] | list[dict[str, Any]], + execution_signatures: ExecutionSignature | list[ExecutionSignature], + targets_meta: Any | list[Any], expectation_types: list[str], - extra_signatures: ExtraSignatureData | None = None, + extra_signatures: ExtraSignatureData = ExtraSignatureData(), ) -> dict[str, Any]: """Build the nested wire payload from flat post-execution signatures. @@ -182,7 +158,8 @@ def build_payload( and send_signatures input (nested wire format). Args: - post_signatures: A single post-execution dict or a list (multi-targets). + execution_signatures: A single post-execution ExecutionSignature or a list (multi-targets). + targets_meta: Target metadata dict(s) with keys like agent, asset, asset_group. expectation_types: The 1+ expectation type labels (e.g. ['DETECTION', 'PREVENTION']). extra_signatures: Optional mapping of expectation types to additional signature fields that will be merged into the base post_signatures. @@ -190,25 +167,21 @@ def build_payload( Returns: A payload dict ready for send_signatures. """ - if isinstance(post_signatures, dict): - post_signatures = [post_signatures] + if not isinstance(execution_signatures, list): + execution_signatures = [execution_signatures] + if isinstance(targets_meta, dict): + targets_meta = [targets_meta] * len(execution_signatures) targets = [] - for signature in post_signatures: + for signature, target in zip(execution_signatures, targets_meta): signature_values = [] - for expectation_type in expectation_types: - signature_data = signature.copy() + signature_data = signature.model_dump(exclude_none=True) signature_data.update(extra_signatures.get_extra(expectation_type)) - values = [ - SignatureValue( - signature_type=key, - signature_value=value, - ) + SignatureValue(signature_type=key, signature_value=value) for key, value in signature_data.items() ] - signature_values.append( ExpectationSignatureGroup( expectation_type=expectation_type, @@ -217,6 +190,11 @@ def build_payload( ) targets.append( TargetSignatures( + signature_target=SignatureTarget( + agent=target.agent_id, + asset=target.asset_id, + asset_group=target.asset_group_id, + ), signature_values=signature_values, ) ) @@ -226,32 +204,40 @@ def build_payload( def send_signatures( self, inject_id: str, - phase: str, + execution_details: ExecutionDetails, signatures: dict[str, Any], ) -> None: """Ship signatures to the callback endpoint via the Signature API manager. - Delegates transport (retry, chunking, validation) to ``client.signature``. + Constructs typed ``SignatureOutputStructure`` and ``ExecutionDetails`` + models, then delegates transport (retry, envelope splitting, validation) + to ``client.signature``. Args: inject_id: Inject UUID. - phase: Execution phase. - signatures: Full signatures dict, canonical or flat, both grouped on the fly. + execution_details: execution-related metadata as an ExecutionDetails object + signatures: Full signatures dict with a ``targets`` list. Raises: SignatureTransmissionError: Validation failed, 4xx hit, or retries exhausted. """ - self.client.signature.max_payload_size = self.max_payload_size - self.client.signature.logger = self.logger - self.client.signature.send_signatures(inject_id, phase, signatures) + sig_output = SignatureOutputStructure(signatures=SignaturePayload(**signatures)) + + self.client.signature.send_signatures( + inject_id, + sig_output, + execution_details, + max_payload_size=self.max_payload_size, + logger=self.logger, + ) def resolve_container_ip(self) -> str: """Sniff the container's primary IPv4. Env var, hostname, then ``hostname -i``. Returns: - The IPv4 string, or ``'unknown'`` with a single warning when all strategies fail. + The IPv4 string, or ``'unknown'`` with a warning when all strategies fail. """ - if self._cached_ipv4: + if self._cached_ipv4 and self._cached_ipv4 != "unknown": return self._cached_ipv4 env_ip = os.environ.get("CONTAINER_IP") diff --git a/pyoaev/signatures/signature_type.py b/pyoaev/signatures/signature_type.py index e0b0500..bab72db 100644 --- a/pyoaev/signatures/signature_type.py +++ b/pyoaev/signatures/signature_type.py @@ -19,7 +19,7 @@ def __init__( self, label: SignatureTypes, match_type: MatchTypes = MatchTypes.MATCH_TYPE_SIMPLE, - match_score: int = None, + match_score: int | None = None, ): self.label = label self.match_policy = SignatureMatch(match_type, match_score) diff --git a/pyoaev/signatures/types.py b/pyoaev/signatures/types.py index ca2737d..0c4f78b 100644 --- a/pyoaev/signatures/types.py +++ b/pyoaev/signatures/types.py @@ -7,6 +7,17 @@ class ExpectationType(str, Enum): VULNERABILITY = "VULNERABILITY" +class InjectExecutionActions(str, Enum): + PREREQUISITE_CHECK = "prerequisite_check" + PREREQUISITE_EXECUTION = "prerequisite_execution" + CLEANUP_EXECUTION = "cleanup_execution" + COMMAND_EXECUTION = "command_execution" + DNS_RESOLUTION = "dns_resolution" + FILE_EXECUTION = "file_execution" + FILE_DROP = "file_drop" + COMPLETE = "complete" + + class MatchTypes(str, Enum): MATCH_TYPE_FUZZY = "fuzzy" MATCH_TYPE_SIMPLE = "simple" diff --git a/test/signatures/constraints/signature_manager_post_execution_constraints.feature b/test/signatures/constraints/signature_manager_post_execution_constraints.feature index 6d163bf..16d7b47 100644 --- a/test/signatures/constraints/signature_manager_post_execution_constraints.feature +++ b/test/signatures/constraints/signature_manager_post_execution_constraints.feature @@ -5,29 +5,35 @@ Feature: SignatureManager post-execution constraints Background: Given a SignatureManager initialised with constructor SignatureManager(client, logger) - And a pre_signatures dict containing: + And a execution_signatures object containing: | key | value | | source_ipv4 | 172.17.0.2 | | target_ipv4 | 10.0.0.1 | | target_hostname | host-a.internal | | start_time | 2024-06-26T06:00:00Z | + And a execution_details object containing: + | key | value | + | start_time | 2024-06-26T06:00:00Z | Scenario: Tool crash sets execution_status to failed and uses crash timestamp as end_time Given a tool_output containing error_info with exit_code=1 and crash_timestamp="2024-06-26T06:05:00Z" - When I call compile_post_execution_signatures with the pre_signatures dict and tool_output + When I call post_execution_updates with the execution_details, execution_signatures and tool_output Then execution_status equals "failed" And end_time equals "2024-06-26T06:05:00Z" - And all pre-execution fields from pre_signatures are present and unchanged in the returned dict + And the execution signature model contains every previous parameter unchanged + And the execution details model contain every previous parameter pair unchanged Scenario: Timeout sets execution_status to timeout and includes available partial results Given a tool_output containing timeout_info with partial_results=["result-A", "result-B"] - When I call compile_post_execution_signatures with the pre_signatures dict and tool_output + When I call post_execution_updates with the execution_details, execution_signatures and tool_output Then execution_status equals "timeout" And the returned dict contains the partial results ["result-A", "result-B"] from timeout_info - And all pre-execution fields from pre_signatures are present and unchanged in the returned dict + And the execution signature model contains every previous parameter unchanged + And the execution details model contain every previous parameter pair unchanged Scenario: Timeout with no partial results still sets execution_status to timeout Given a tool_output containing timeout_info with no partial results available - When I call compile_post_execution_signatures with the pre_signatures dict and tool_output + When I call post_execution_updates with the execution_details, execution_signatures and tool_output Then execution_status equals "timeout" - And all pre-execution fields from pre_signatures are present and unchanged in the returned dict + And the execution signature model contains every previous parameter unchanged + And the execution details model contain every previous parameter pair unchanged diff --git a/test/signatures/constraints/signature_manager_pre_execution_constraints.feature b/test/signatures/constraints/signature_manager_pre_execution_constraints.feature index e37a1e5..f6b25ca 100644 --- a/test/signatures/constraints/signature_manager_pre_execution_constraints.feature +++ b/test/signatures/constraints/signature_manager_pre_execution_constraints.feature @@ -10,6 +10,6 @@ Feature: SignatureManager pre-execution constraints Given a SignatureManager that was instantiated at timestamp T0 And 5 seconds elapse after instantiation And a NetworkInjectorConfig with target_ipv4="192.168.1.10" - When I call compile_pre_execution_signatures with the config at timestamp T1 + When I call build_execution_signatures with the config at timestamp T1 Then the start_time in the returned dict equals T1 within 1 second tolerance And start_time does not equal T0 diff --git a/test/signatures/constraints/signature_manager_transmission_constraints.feature b/test/signatures/constraints/signature_manager_transmission_constraints.feature index 0fcf4f1..956256d 100644 --- a/test/signatures/constraints/signature_manager_transmission_constraints.feature +++ b/test/signatures/constraints/signature_manager_transmission_constraints.feature @@ -6,22 +6,23 @@ Feature: SignatureManager transmission constraints Background: Given a SignatureManager initialised with constructor SignatureManager(client, logger) - Scenario: Payload exceeding MAX_PAYLOAD_SIZE is auto-chunked with chunk metadata + Scenario: Payload exceeding MAX_PAYLOAD_SIZE is split into multiple sequential envelopes Given a compiled payload whose serialised size exceeds MAX_PAYLOAD_SIZE by at least a factor of 2 + And an updated post-execution execution details object And the backend responds with HTTP 200 - When I call send_signatures for inject_id "inject-abc-001" with phase "execution_complete" - Then the payload is sent as multiple sequential POST requests to /injects/inject-abc-001/callback - And each POST request body contains chunk_index as a 0-based integer - And each POST request body contains total_chunks as a positive integer matching the total number of chunks sent - And each POST request body contains only "signatures", "chunk_index" and "total_chunks" at the top level + When I call send_signatures for inject_id "inject-abc-001" + Then the payload is sent as multiple sequential POST requests to /injects/execution/callback/inject-abc-001 + And each POST request body is a valid self-contained envelope with the same structure as a single-send payload + And no POST request body contains chunk_index or total_chunks keys And the union of targets across all POST requests equals the original target set - And no individual POST request body exceeds MAX_PAYLOAD_SIZE bytes + And no individual POST request body exceeds MAX_PAYLOAD_SIZE bytes without warning Scenario: HTTP 5xx response triggers exponential backoff retry for up to 3 additional attempts Given a compiled post-execution payload for inject_id "inject-abc-001" + And an updated post-execution execution details object And the backend responds with HTTP 503 on every attempt - When I call send_signatures for inject_id "inject-abc-001" with phase "execution_complete" - Then send_signatures sends a total of 4 POST requests to /injects/inject-abc-001/callback + When I call send_signatures for inject_id "inject-abc-001" + Then send_signatures sends a total of 4 POST requests to /injects/execution/callback/inject-abc-001 And a WARNING log message containing the retry attempt number is emitted before each of the 3 retry attempts And the wait before attempt 2 is 1 second And the wait before attempt 3 is 2 seconds @@ -30,9 +31,10 @@ Feature: SignatureManager transmission constraints Scenario: HTTP 4xx response raises an exception immediately with no retries and no sleep Given a compiled post-execution payload for inject_id "inject-abc-001" + And an updated post-execution execution details object And the backend responds with HTTP 400 and body '{"error": "bad request"}' - When I call send_signatures for inject_id "inject-abc-001" with phase "execution_complete" - Then only 1 POST request is sent to /injects/inject-abc-001/callback + When I call send_signatures for inject_id "inject-abc-001" + Then only 1 POST request is sent to /injects/execution/callback/inject-abc-001 And an ERROR log message containing status code 400 and the response body is emitted And an exception is raised immediately And no sleep or wait occurs before the exception is raised diff --git a/test/signatures/features/signature_manager_post_execution.feature b/test/signatures/features/signature_manager_post_execution.feature index 856b1dd..ffa78a4 100644 --- a/test/signatures/features/signature_manager_post_execution.feature +++ b/test/signatures/features/signature_manager_post_execution.feature @@ -1,31 +1,28 @@ -Feature: SignatureManager post-execution signature compilation +Feature: SignatureManager post-execution execution elements update As an injector using the OpenAEV client - I want to merge execution results into pre-execution signatures - So that each inject has a complete signature record including outcome and timing + I want to update both execution signatures and execution details with execution results + So that each inject has a complete execution record including outcome and timing Background: Given a SignatureManager initialised with constructor SignatureManager(client, logger) - And a pre_signatures dict containing: + And a execution_signatures object containing: | key | value | | source_ipv4 | 172.17.0.2 | | target_ipv4 | 10.0.0.1 | | target_hostname | host-a.internal | | start_time | 2024-06-26T06:00:00Z | + And a execution_details object containing: + | key | value | + | start_time | 2024-06-26T06:00:00Z | - Scenario: Successful execution merges end_time and execution_status into pre-execution fields + Scenario: Successful execution updates end_time and execution_status in execution signatures and execution details Given a tool_output indicating successful completion with no errors and no timeout - When I call compile_post_execution_signatures with the pre_signatures dict and tool_output - Then the returned dict contains every key-value pair from pre_signatures unchanged - And the returned dict contains end_time as a UTC ISO 8601 string - And end_time is chronologically greater than or equal to start_time "2024-06-26T06:00:00Z" - And the returned dict contains execution_status equal to "success" - - - Scenario: Multi-target pre-signatures merge into a list of post-signatures - Given the pre_signatures is replaced by a list of 3 dicts each with a distinct target_ipv4 - And a tool_output indicating successful completion with no errors and no timeout - When I call compile_post_execution_signatures with the pre_signatures dict and tool_output - Then the returned value is a list of exactly 3 dicts - And every dict in the returned list contains execution_status equal to "success" - And every dict in the returned list contains end_time as a UTC ISO 8601 string - And every dict in the returned list preserves its original target_ipv4 and source_ipv4 fields + When I call post_execution_updates with the execution_details, execution_signatures and tool_output + Then the execution signature model contains every previous parameter unchanged + And the end_time parameter in the execution signature model is a UTC ISO 8601 string + And this end_time is chronologically greater than or equal to start_time "2024-06-26T06:00:00Z" + And the execution details model contain every previous parameter pair unchanged + And the end_time parameter in the execution details model is a datetime object + And this end_time is chronologically greater than or equal to start_time "2024-06-26T06:00:00Z" + And the execution_status parameter in the execution details model is equal to "success" + And the execution_action parameter in the execution details model is equal to "complete" diff --git a/test/signatures/features/signature_manager_pre_execution.feature b/test/signatures/features/signature_manager_pre_execution.feature index 1ba3ed5..cd52045 100644 --- a/test/signatures/features/signature_manager_pre_execution.feature +++ b/test/signatures/features/signature_manager_pre_execution.feature @@ -9,7 +9,7 @@ Feature: SignatureManager pre-execution signature compilation Scenario: Network category returns required IP and timing fields and no cloud or query fields Given a NetworkInjectorConfig with target_ipv4="192.168.1.10" And the running container has a resolvable IPv4 address - When I call compile_pre_execution_signatures with the config + When I call build_execution_signatures with the config Then the returned dict contains source_ipv4 as a non-empty valid IPv4 address string And the returned dict contains start_time as a UTC ISO 8601 string And the returned dict contains target_ipv4 equal to "192.168.1.10" @@ -22,14 +22,14 @@ Feature: SignatureManager pre-execution signature compilation Scenario: Network hostname target returns hostname and no target_ipv4 Given a NetworkInjectorConfig with target_hostname="target.example.com" And the running container has a resolvable IPv4 address - When I call compile_pre_execution_signatures with the config + When I call build_execution_signatures with the config Then the returned dict contains target_hostname equal to "target.example.com" And the returned dict contains source_ipv4 as a non-empty valid IPv4 address string But the returned dict does not contain target_ipv4 Scenario: Cloud category returns required cloud identity fields and no IP fields Given a CloudInjectorConfig with cloud_provider="aws", cloud_account_id="123456789012", cloud_region="eu-west-1", and target_service="ec2" - When I call compile_pre_execution_signatures with the config + When I call build_execution_signatures with the config Then the returned dict contains cloud_provider equal to "aws" And the returned dict contains cloud_account_id equal to "123456789012" And the returned dict contains cloud_region equal to "eu-west-1" @@ -40,18 +40,10 @@ Feature: SignatureManager pre-execution signature compilation And the returned dict does not contain target_ipv4 And the returned dict does not contain target_ipv6 - Scenario: External category returns scan target fields and no source IP - Given an ExternalInjectorConfig with target_ipv4="203.0.113.5" and query="port:22 os:linux" - When I call compile_pre_execution_signatures with the config - Then the returned dict contains target_ipv4 equal to "203.0.113.5" - And the returned dict contains query equal to "port:22 os:linux" - And the returned dict contains start_time as a UTC ISO 8601 string - But the returned dict does not contain source_ipv4 - Scenario Outline: Network multi-target returns one dict per target with a shared source IP Given a list of 3 NetworkInjectorConfig with target_ipv4 "10.0.0.1", "10.0.0.2", "10.0.0.3" And the running container has a resolvable IPv4 address "172.17.0.2" - When I call compile_pre_execution_signatures with the config list + When I call build_execution_signatures with the config list Then the return value is a list of exactly 3 dicts And the dict at position contains target_ipv4 equal to "" And the dict at position contains source_ipv4 equal to "172.17.0.2" @@ -65,13 +57,13 @@ Feature: SignatureManager pre-execution signature compilation Scenario: All network multi-target dicts share the same source_ipv4 Given a list of 3 NetworkInjectorConfig built from default IPv4 targets And the running container has a resolvable IPv4 address - When I call compile_pre_execution_signatures with the config list + When I call build_execution_signatures with the config list Then the return value is a list of 3 dicts And all 3 dicts contain the same source_ipv4 value Scenario Outline: Cloud multi-region returns one dict per region with a shared account ID Given a list of 3 CloudInjectorConfig with cloud_account_id="123456789012" and regions "us-east-1", "eu-west-1", "ap-southeast-1" - When I call compile_pre_execution_signatures with the config list + When I call build_execution_signatures with the config list Then the return value is a list of exactly 3 dicts And the dict at position contains cloud_region equal to "" And the dict at position contains cloud_account_id equal to "123456789012" diff --git a/test/signatures/features/signature_manager_transmission.feature b/test/signatures/features/signature_manager_transmission.feature index 94e34bd..12d9612 100644 --- a/test/signatures/features/signature_manager_transmission.feature +++ b/test/signatures/features/signature_manager_transmission.feature @@ -8,8 +8,9 @@ Feature: SignatureManager signature transmission and container IP resolution Scenario Outline: HTTP 2xx response is treated as successful transmission Given a compiled post-execution payload for inject_id "inject-abc-001" + And an updated post-execution execution details object And the backend responds with HTTP - When I call send_signatures for inject_id "inject-abc-001" with phase "execution_complete" + When I call send_signatures for inject_id "inject-abc-001" Then send_signatures completes without raising an exception Examples: @@ -19,9 +20,10 @@ Feature: SignatureManager signature transmission and container IP resolution Scenario: send_signatures posts to the inject callback with the agreed nested schema Given a compiled payload with 1 target, expectation_type "DETECTION", signature_type "public_ip", signature_value "203.0.113.5" + And an updated post-execution execution details object And the backend responds with HTTP 200 - When I call send_signatures for inject_id "inject-abc-001" with phase "execution_complete" - Then a POST request is sent to /injects/inject-abc-001/callback + When I call send_signatures for inject_id "inject-abc-001" + Then a POST request is sent to /injects/execution/callback/inject-abc-001 And the POST request body contains signatures.targets as a list And signatures.targets[0].signature_values[0].expectation_type equals "DETECTION" And signatures.targets[0].signature_values[0].values[0].signature_type equals "public_ip" @@ -43,8 +45,9 @@ Feature: SignatureManager signature transmission and container IP resolution Scenario: Payload schema groups signature values by expectation_type within each target Given a compiled payload for 1 target with signatures of expectation_type "DETECTION" and expectation_type "PREVENTION" + And an updated post-execution execution details object And the backend responds with HTTP 200 - When I call send_signatures for inject_id "inject-abc-001" with phase "execution_complete" + When I call send_signatures for inject_id "inject-abc-001" Then the POST request body nests signature values under separate expectation_type entries within signatures.targets[0].signature_values And the entry with expectation_type "DETECTION" contains only DETECTION signature values And the entry with expectation_type "PREVENTION" contains only PREVENTION signature values diff --git a/test/signatures/test_signature_manager_post_execution.py b/test/signatures/test_signature_manager_post_execution.py index cabaf14..356959b 100644 --- a/test/signatures/test_signature_manager_post_execution.py +++ b/test/signatures/test_signature_manager_post_execution.py @@ -4,12 +4,13 @@ import pytest from pytest_bdd import given, parsers, scenario, then, when +from pyoaev.signatures.models import ExecutionDetails, ExecutionSignature from pyoaev.signatures.signature_manager import SignatureManager @scenario( "features/signature_manager_post_execution.feature", - "Successful execution merges end_time and execution_status into pre-execution fields", + "Successful execution updates end_time and execution_status in execution signatures and execution details", ) def test_successful_execution_merges_post_execution_fields(): pass @@ -39,14 +40,6 @@ def test_timeout_without_partial_results_still_sets_timeout_status(): pass -@scenario( - "features/signature_manager_post_execution.feature", - "Multi-target pre-signatures merge into a list of post-signatures", -) -def test_multi_target_pre_signatures_merge_into_a_list_of_post_signatures(): - pass - - @pytest.fixture def context(): return {} @@ -64,16 +57,27 @@ def signature_manager(context): @given( - "a pre_signatures dict containing:", - target_fixture="pre_signatures", + "a execution_signatures object containing:", + target_fixture="execution_signatures", +) +def execution_signatures(): + return ExecutionSignature( + source_ipv4="172.17.0.2", + target_ipv4="10.0.0.1", + target_hostname="host-a.internal", + start_time="2024-06-26T06:00:00Z", + ) + + +@given( + "a execution_details object containing:", + target_fixture="execution_details", ) -def pre_signatures(): - return { - "source_ipv4": "172.17.0.2", - "target_ipv4": "10.0.0.1", - "target_hostname": "host-a.internal", - "start_time": "2024-06-26T06:00:00Z", - } +def execution_details(): + return ExecutionDetails( + execution_status="", + start_time=datetime.strptime("2024-06-26T06:00:00Z", "%Y-%m-%dT%H:%M:%SZ"), + ) @given( @@ -114,121 +118,99 @@ def timeout_tool_output_with_no_partial_results(): @when( - "I call compile_post_execution_signatures with the pre_signatures dict and tool_output" + "I call post_execution_updates with the execution_details, execution_signatures and tool_output" ) -def compile_post_execution_signatures(context, pre_signatures, tool_output): - context["result"] = context["signature_manager"].compile_post_execution_signatures( - pre_signatures, tool_output +def post_execution_update( + context, execution_details, execution_signatures, tool_output +): + context["signature_manager"].post_execution_updates( + execution_details, execution_signatures, tool_output ) + context["execution_details_result"] = execution_details + context["execution_signatures_result"] = execution_signatures -@then("the returned dict contains every key-value pair from pre_signatures unchanged") -@then( - "all pre-execution fields from pre_signatures are present and unchanged in the returned dict" -) -def pre_signatures_unchanged(context, pre_signatures): - result = context["result"] - for key, value in pre_signatures.items(): - assert key in result - assert result[key] == value +@then("the execution signature model contains every previous parameter unchanged") +def execution_signatures_unchanged(context, execution_signatures): + basic_exec_sig = execution_signatures + exec_sig_result = context["execution_signatures_result"] + assert exec_sig_result.source_ipv4 == basic_exec_sig.source_ipv4 + assert exec_sig_result.source_ipv6 == basic_exec_sig.source_ipv6 + assert exec_sig_result.target_ipv4 == basic_exec_sig.target_ipv4 + assert exec_sig_result.target_ipv6 == basic_exec_sig.target_ipv6 + assert exec_sig_result.target_hostname == basic_exec_sig.target_hostname + assert exec_sig_result.cloud_provider == basic_exec_sig.cloud_provider + assert exec_sig_result.cloud_account_id == basic_exec_sig.cloud_account_id + assert exec_sig_result.cloud_region == basic_exec_sig.cloud_region + assert exec_sig_result.target_service == basic_exec_sig.target_service -@then("the returned dict contains end_time as a UTC ISO 8601 string") + +@then( + "the end_time parameter in the execution signature model is a UTC ISO 8601 string" +) def result_contains_iso8601_end_time(context): - end_time = context["result"]["end_time"] + end_time = context["execution_signatures_result"].end_time assert isinstance(end_time, str) _parse_iso8601_utc(end_time) @then( parsers.parse( - 'end_time is chronologically greater than or equal to start_time "{start_time}"' + 'this end_time is chronologically greater than or equal to start_time "{start_time}"' ) ) def end_time_at_or_after_start_time(context, start_time): - end_time_dt = _parse_iso8601_utc(context["result"]["end_time"]) + end_time_dt = _parse_iso8601_utc(context["execution_signatures_result"].end_time) start_time_dt = _parse_iso8601_utc(start_time) assert end_time_dt >= start_time_dt -@then(parsers.parse('the returned dict contains execution_status equal to "{status}"')) -@then(parsers.parse('execution_status equals "{status}"')) -def execution_status_equals(context, status): - assert context["result"]["execution_status"] == status - - @then(parsers.parse('end_time equals "{expected_end_time}"')) def end_time_equals(context, expected_end_time): - assert context["result"]["end_time"] == expected_end_time + assert context["execution_signatures_result"].end_time == expected_end_time @then( 'the returned dict contains the partial results ["result-A", "result-B"] from timeout_info' ) def contains_timeout_partial_results(context): - assert context["result"]["partial_results"] == ["result-A", "result-B"] - + assert context["execution_signatures_result"].partial_results == [ + "result-A", + "result-B", + ] -# -------------------------------------------------- -# Multi-target post-execution scenario -# -------------------------------------------------- +@then("the execution details model contain every previous parameter pair unchanged") +def execution_details_unchanged(context, execution_details): + basic_exec_details = execution_details + exec_details_result = context["execution_details_result"] -@given( - "the pre_signatures is replaced by a list of 3 dicts each with a distinct target_ipv4", - target_fixture="pre_signatures", -) -def pre_signatures_multi_target_list(): - return [ - { - "source_ipv4": "172.17.0.2", - "target_ipv4": "10.0.0.1", - "start_time": "2024-06-26T06:00:00Z", - }, - { - "source_ipv4": "172.17.0.2", - "target_ipv4": "10.0.0.2", - "start_time": "2024-06-26T06:00:00Z", - }, - { - "source_ipv4": "172.17.0.2", - "target_ipv4": "10.0.0.3", - "start_time": "2024-06-26T06:00:00Z", - }, - ] + assert exec_details_result.start_time == basic_exec_details.start_time + assert exec_details_result.execution_message == basic_exec_details.execution_message -@then("the returned value is a list of exactly 3 dicts") -def result_is_list_of_three_dicts(context): - result = context["result"] - assert isinstance(result, list) - assert len(result) == 3 - assert all(isinstance(item, dict) for item in result) +@then("the end_time parameter in the execution details model is a datetime object") +def result_contains_datetime_end_time(context): + end_time = context["execution_details_result"].end_time + assert isinstance(end_time, datetime) @then( parsers.parse( - 'every dict in the returned list contains execution_status equal to "{status}"' + 'the execution_status parameter in the execution details model is equal to "{status}"' ) ) -def every_dict_has_execution_status(context, status): - for item in context["result"]: - assert item["execution_status"] == status - - -@then("every dict in the returned list contains end_time as a UTC ISO 8601 string") -def every_dict_has_iso8601_end_time(context): - for item in context["result"]: - assert isinstance(item["end_time"], str) - _parse_iso8601_utc(item["end_time"]) +@then(parsers.parse('execution_status equals "{status}"')) +def execution_status_equals(context, status): + assert context["execution_details_result"].execution_status == status @then( - "every dict in the returned list preserves its original target_ipv4 and source_ipv4 fields" -) -def every_dict_preserves_pre_execution_fields(context, pre_signatures): - result = context["result"] - assert len(result) == len(pre_signatures) - for original, merged in zip(pre_signatures, result): - assert merged["target_ipv4"] == original["target_ipv4"] - assert merged["source_ipv4"] == original["source_ipv4"] + parsers.parse( + 'the execution_action parameter in the execution details model is equal to "{action}"' + ) +) +@then(parsers.parse('execution_action equals "{action}"')) +def execution_action_equals(context, action): + assert context["execution_details_result"].execution_action == action diff --git a/test/signatures/test_signature_manager_pre_execution.py b/test/signatures/test_signature_manager_pre_execution.py index f7581dc..c98bf70 100644 --- a/test/signatures/test_signature_manager_pre_execution.py +++ b/test/signatures/test_signature_manager_pre_execution.py @@ -7,7 +7,7 @@ from pyoaev.signatures.models import ( CloudInjectorConfig, - ExternalInjectorConfig, + ExecutionSignature, NetworkInjectorConfig, build_network_configs, ) @@ -42,14 +42,6 @@ def test_cloud_category_required_fields(): pass -@scenario( - "features/signature_manager_pre_execution.feature", - "External category returns scan target fields and no source IP", -) -def test_external_category_fields(): - pass - - @scenario( "features/signature_manager_pre_execution.feature", "Network multi-target returns one dict per target with a shared source IP", @@ -161,16 +153,6 @@ def cloud_config_single( ) -@given( - parsers.parse( - 'an ExternalInjectorConfig with target_ipv4="{target_ipv4}" and query="{query}"' - ), - target_fixture="config", -) -def external_config_single(target_ipv4, query): - return ExternalInjectorConfig(target_ipv4=target_ipv4, query=query) - - @given( parsers.parse( "a list of 3 NetworkInjectorConfig with target_ipv4 " @@ -282,28 +264,28 @@ def elapsed_5_seconds(context): @when( - "I call compile_pre_execution_signatures with the config", + "I call build_execution_signatures with the config", target_fixture="result", ) -def call_compile_with_config(signature_manager, config): - return signature_manager.compile_pre_execution_signatures(config=config) +def call_build_with_config(signature_manager, config): + return signature_manager.build_execution_signatures(config=config) @when( - "I call compile_pre_execution_signatures with the config list", + "I call build_execution_signatures with the config list", target_fixture="result", ) -def call_compile_with_config_list(signature_manager, config): - return signature_manager.compile_pre_execution_signatures(config=config) +def call_build_with_config_list(signature_manager, config): + return signature_manager.build_execution_signatures(config=config) @when( - "I call compile_pre_execution_signatures with the config at timestamp T1", + "I call build_execution_signatures with the config at timestamp T1", target_fixture="result", ) -def call_compile_at_t1(signature_manager, config, t1): +def call_build_at_t1(signature_manager, config, t1): with patch.object(signature_manager, "_utcnow", return_value=t1): - return signature_manager.compile_pre_execution_signatures(config=config) + return signature_manager.build_execution_signatures(config=config) @when( @@ -321,51 +303,51 @@ def build_configs_from_raw(raw_targets): @then("the returned dict contains source_ipv4 as a non-empty valid IPv4 address string") def source_ipv4_is_valid(result): - source_ipv4 = result["source_ipv4"] + source_ipv4 = result.source_ipv4 assert source_ipv4 ipaddress.IPv4Address(source_ipv4) @then("the returned dict contains start_time as a UTC ISO 8601 string") def start_time_is_utc_iso8601(result): - start_time = result["start_time"] + start_time = result.start_time parsed = parse_utc_iso8601(start_time) assert parsed.tzinfo is not None @then(parsers.parse('the returned dict contains target_ipv4 equal to "{value}"')) def returned_dict_target_ipv4(result, value): - assert result["target_ipv4"] == value + assert result.target_ipv4 == value @then(parsers.parse('the returned dict contains target_hostname equal to "{value}"')) def returned_dict_target_hostname(result, value): - assert result["target_hostname"] == value + assert result.target_hostname == value @then(parsers.parse('the returned dict contains cloud_provider equal to "{value}"')) def returned_dict_cloud_provider(result, value): - assert result["cloud_provider"] == value + assert result.cloud_provider == value @then(parsers.parse('the returned dict contains cloud_account_id equal to "{value}"')) def returned_dict_cloud_account_id(result, value): - assert result["cloud_account_id"] == value + assert result.cloud_account_id == value @then(parsers.parse('the returned dict contains cloud_region equal to "{value}"')) def returned_dict_cloud_region(result, value): - assert result["cloud_region"] == value + assert result.cloud_region == value @then(parsers.parse('the returned dict contains target_service equal to "{value}"')) def returned_dict_target_service(result, value): - assert result["target_service"] == value + assert result.target_service == value @then(parsers.parse('the returned dict contains query equal to "{value}"')) def returned_dict_query(result, value): - assert result["query"] == value + assert result.query == value @then(parsers.parse("the returned dict does not contain {field}")) @@ -377,14 +359,14 @@ def returned_dict_does_not_contain_field(result, field): def return_value_is_list_of_three_dicts(result): assert isinstance(result, list) assert len(result) == 3 - assert all(isinstance(item, dict) for item in result) + assert all(isinstance(item, ExecutionSignature) for item in result) @then(parsers.parse("the return value is a list of {count:d} dicts")) def return_value_is_list_of_n_dicts(result, count): assert isinstance(result, list) assert len(result) == count - assert all(isinstance(item, dict) for item in result) + assert all(isinstance(item, ExecutionSignature) for item in result) @then( @@ -393,7 +375,7 @@ def return_value_is_list_of_n_dicts(result, count): ) ) def list_dict_contains_target_ipv4_at_position(result, index, target_ip): - assert result[index]["target_ipv4"] == target_ip + assert result[index].target_ipv4 == target_ip @then( @@ -406,7 +388,7 @@ def list_dict_contains_source_ipv4_at_position( index, source_ipv4, ): - assert result[index]["source_ipv4"] == source_ipv4 + assert result[index].source_ipv4 == source_ipv4 @then( @@ -415,7 +397,7 @@ def list_dict_contains_source_ipv4_at_position( ) ) def list_dict_contains_cloud_region_at_position(result, index, region): - assert result[index]["cloud_region"] == region + assert result[index].cloud_region == region @then( @@ -424,28 +406,28 @@ def list_dict_contains_cloud_region_at_position(result, index, region): ) ) def list_dict_contains_cloud_account_id_at_position(result, index, account_id): - assert result[index]["cloud_account_id"] == account_id + assert result[index].cloud_account_id == account_id @then("all 3 dicts contain the same source_ipv4 value") def all_dicts_share_same_source_ipv4(result): assert isinstance(result, list) assert len(result) == 3 - source_values = {item["source_ipv4"] for item in result} + source_values = {item.source_ipv4 for item in result} assert len(source_values) == 1 ipaddress.IPv4Address(next(iter(source_values))) @then("the start_time in the returned dict equals T1 within 1 second tolerance") def start_time_equals_t1_with_tolerance(result, t1): - start_time = parse_utc_iso8601(result["start_time"]) + start_time = parse_utc_iso8601(result.start_time) delta_seconds = abs((start_time - t1).total_seconds()) assert delta_seconds <= 1 @then("start_time does not equal T0") def start_time_not_equal_t0(result, signature_manager): - start_time = parse_utc_iso8601(result["start_time"]) + start_time = parse_utc_iso8601(result.start_time) assert start_time != signature_manager._test_t0 diff --git a/test/signatures/test_signature_manager_transmission.py b/test/signatures/test_signature_manager_transmission.py index 6519bd6..e414983 100644 --- a/test/signatures/test_signature_manager_transmission.py +++ b/test/signatures/test_signature_manager_transmission.py @@ -1,5 +1,6 @@ import ipaddress import json +from datetime import timedelta from types import SimpleNamespace from unittest.mock import MagicMock, call @@ -8,6 +9,7 @@ from pyoaev.apis.signature import SignatureApiManager from pyoaev.exceptions import OpenAEVUpdateError, SignatureTransmissionError +from pyoaev.signatures.models import ExecutionDetails from pyoaev.signatures.signature_manager import SignatureManager @@ -29,9 +31,9 @@ def test_send_signatures_posts_with_agreed_nested_schema(): @scenario( "constraints/signature_manager_transmission_constraints.feature", - "Payload exceeding MAX_PAYLOAD_SIZE is auto-chunked with chunk metadata", + "Payload exceeding MAX_PAYLOAD_SIZE is split into multiple sequential envelopes", ) -def test_payload_exceeding_max_payload_size_is_split_into_sequential_chunks(): +def test_payload_exceeding_max_payload_size_is_split_into_sequential_envelopes(): pass @@ -87,6 +89,12 @@ def context(): } +def _extract_targets(body: dict) -> list[dict]: + """Parse targets from the SignatureCallbackPayload wire format.""" + sig_data = json.loads(body["execution_output_structured"]) + return sig_data["signatures"]["targets"] + + def _build_signature_payload( signature_value="203.0.113.5", expectation_types=None, @@ -164,7 +172,6 @@ def _http_post(*args, **kwargs): context["status_plan"] = [200] context["error_body"] = "" context["inject_id"] = "inject-abc-001" - context["phase"] = "execution_complete" context["signatures"] = _build_signature_payload() context["signature_manager"] = SignatureManager(mock_client, logger=logger) @@ -175,6 +182,16 @@ def compiled_post_execution_payload(context, inject_id): context["signatures"] = _build_signature_payload() +@given(parsers.parse("an updated post-execution execution details object")) +def updated_post_execution_execution_details(context): + execution_details = ExecutionDetails( + execution_status="success", + execution_action="complete", + ) + execution_details.end_time = execution_details.start_time + timedelta(0.1) + context["execution_details"] = execution_details + + @given( parsers.parse( 'a compiled payload with 1 target, expectation_type "{expectation_type}", signature_type "{signature_type}", signature_value "{signature_value}"' @@ -233,7 +250,7 @@ def compiled_large_payload(context): }, { "signature_type": "hostname", - "signature_value": f"host-{index}." + ("a" * 140), + "signature_value": f"host-{index}." + ("a" * 94), }, ], } @@ -325,18 +342,25 @@ def compiled_payload_grouped_by_expectation( "signature_values": [ { "expectation_type": expectation_a, - "signature_type": "public_ip", - "signature_value": "203.0.113.5", + "values": [ + { + "signature_type": "public_ip", + "signature_value": "203.0.113.5", + }, + { + "signature_type": "hostname", + "signature_value": "host-a.internal", + }, + ], }, { "expectation_type": expectation_b, - "signature_type": "public_ip", - "signature_value": "198.51.100.10", - }, - { - "expectation_type": expectation_a, - "signature_type": "hostname", - "signature_value": "host-a.internal", + "values": [ + { + "signature_type": "public_ip", + "signature_value": "198.51.100.10", + }, + ], }, ], } @@ -344,19 +368,14 @@ def compiled_payload_grouped_by_expectation( } -@when( - parsers.parse( - 'I call send_signatures for inject_id "{inject_id}" with phase "{phase}"' - ) -) -def call_send_signatures(context, inject_id, phase): +@when(parsers.parse('I call send_signatures for inject_id "{inject_id}"')) +def call_send_signatures(context, inject_id): context["inject_id"] = inject_id - context["phase"] = phase context["send_exception"] = None try: context["signature_manager"].send_signatures( inject_id, - phase, + context["execution_details"], context["signatures"], ) except Exception as exc: @@ -379,18 +398,22 @@ def send_signatures_completes_without_exception(context): @then( parsers.parse( - "a POST request is sent to /injects/{inject_id}/callback", + "a POST request is sent to /injects/execution/callback/{inject_id}", ) ) def assert_post_request_sent_to_callback(context, inject_id): assert context["captured_calls"] - assert context["captured_calls"][-1]["path"] == f"/injects/{inject_id}/callback" + assert ( + context["captured_calls"][-1]["path"] + == f"/injects/execution/callback/{inject_id}" + ) @then("the POST request body contains signatures.targets as a list") def assert_targets_is_list(context): body = context["captured_calls"][-1]["post_data"] - assert isinstance(body["expectation_signature"]["targets"], list) + targets = _extract_targets(body) + assert isinstance(targets, list) @then( @@ -400,9 +423,8 @@ def assert_targets_is_list(context): ) def assert_expectation_type(context, expected_value): body = context["captured_calls"][-1]["post_data"] - assert body["expectation_signature"]["targets"][0]["signature_values"][0][ - "expectation_type" - ] == (expected_value) + targets = _extract_targets(body) + assert targets[0]["signature_values"][0]["expectation_type"] == expected_value @then( @@ -412,10 +434,9 @@ def assert_expectation_type(context, expected_value): ) def assert_signature_type(context, expected_value): body = context["captured_calls"][-1]["post_data"] + targets = _extract_targets(body) assert ( - body["expectation_signature"]["targets"][0]["signature_values"][0]["values"][0][ - "signature_type" - ] + targets[0]["signature_values"][0]["values"][0]["signature_type"] == expected_value ) @@ -427,10 +448,9 @@ def assert_signature_type(context, expected_value): ) def assert_signature_value(context, expected_value): body = context["captured_calls"][-1]["post_data"] + targets = _extract_targets(body) assert ( - body["expectation_signature"]["targets"][0]["signature_values"][0]["values"][0][ - "signature_value" - ] + targets[0]["signature_values"][0]["values"][0]["signature_value"] == expected_value ) @@ -438,54 +458,42 @@ def assert_signature_value(context, expected_value): @then("signatures.targets[0] contains a signature_target key") def assert_signature_target_key(context): body = context["captured_calls"][-1]["post_data"] - assert "signature_target" in body["expectation_signature"]["targets"][0] + targets = _extract_targets(body) + assert "signature_target" in targets[0] @then( parsers.parse( - "the payload is sent as multiple sequential POST requests to /injects/{inject_id}/callback", + "the payload is sent as multiple sequential POST requests to /injects/execution/callback/{inject_id}", ) ) def assert_payload_sent_as_multiple_chunks(context, inject_id): assert context["send_exception"] is None assert len(context["captured_calls"]) > 1 assert all( - call_item["path"] == f"/injects/{inject_id}/callback" + call_item["path"] == f"/injects/execution/callback/{inject_id}" for call_item in context["captured_calls"] ) -@then("each POST request body contains chunk_index as a 0-based integer") -def assert_chunk_index_present(context): - for index, call_item in enumerate(context["captured_calls"]): - post_data = call_item["post_data"] - assert isinstance(post_data["chunk_index"], int) - assert post_data["chunk_index"] == index - - @then( - "each POST request body contains total_chunks as a positive integer matching the total number of chunks sent" + "each POST request body is a valid self-contained envelope with the same structure as a single-send payload" ) -def assert_total_chunks_present(context): - total_chunks = len(context["captured_calls"]) +def assert_each_envelope_is_self_contained(context): for call_item in context["captured_calls"]: post_data = call_item["post_data"] - assert isinstance(post_data["total_chunks"], int) - assert post_data["total_chunks"] > 0 - assert post_data["total_chunks"] == total_chunks + assert "execution_output_structured" in post_data + targets = _extract_targets(post_data) + assert isinstance(targets, list) + assert len(targets) > 0 -@then( - 'each POST request body contains only "signatures", "chunk_index" and "total_chunks" at the top level' -) -def assert_chunked_envelope_is_strict(context): - expected_keys = {"expectation_signature", "chunk_index", "total_chunks", "phase"} +@then("no POST request body contains chunk_index or total_chunks keys") +def assert_no_chunk_metadata(context): for call_item in context["captured_calls"]: post_data = call_item["post_data"] - assert set(post_data.keys()) == expected_keys, ( - f"Chunked envelope must contain exactly {expected_keys}, " - f"got {set(post_data.keys())}" - ) + assert "chunk_index" not in post_data + assert "total_chunks" not in post_data @then("the union of targets across all POST requests equals the original target set") @@ -494,17 +502,17 @@ def assert_targets_union_matches_original(context): sent_targets = [ target for call_item in context["captured_calls"] - for target in call_item["post_data"]["expectation_signature"]["targets"] + for target in _extract_targets(call_item["post_data"]) ] assert len(sent_targets) == len(original_targets), ( - f"Expected {len(original_targets)} targets across all chunks, " + f"Expected {len(original_targets)} targets across all envelopes, " f"got {len(sent_targets)}" ) for original, sent in zip(original_targets, sent_targets): assert sent["signature_target"] == original["signature_target"] -@then("no individual POST request body exceeds MAX_PAYLOAD_SIZE bytes") +@then("no individual POST request body exceeds MAX_PAYLOAD_SIZE bytes without warning") def assert_payload_size_per_chunk(context): max_payload_size = context["signature_manager"].max_payload_size for call_item in context["captured_calls"]: @@ -515,13 +523,13 @@ def assert_payload_size_per_chunk(context): @then( parsers.parse( - "send_signatures sends a total of {total_requests:d} POST requests to /injects/{inject_id}/callback" + "send_signatures sends a total of {total_requests:d} POST requests to /injects/execution/callback/{inject_id}" ) ) def assert_total_post_requests(context, total_requests, inject_id): assert len(context["captured_calls"]) == total_requests assert all( - call_item["path"] == f"/injects/{inject_id}/callback" + call_item["path"] == f"/injects/execution/callback/{inject_id}" for call_item in context["captured_calls"] ) @@ -553,12 +561,15 @@ def assert_signature_transmission_error_after_retries(context): @then( parsers.parse( - "only {request_count:d} POST request is sent to /injects/{inject_id}/callback" + "only {request_count:d} POST request is sent to /injects/execution/callback/{inject_id}" ) ) def assert_single_post_request(context, request_count, inject_id): assert len(context["captured_calls"]) == request_count - assert context["captured_calls"][0]["path"] == f"/injects/{inject_id}/callback" + assert ( + context["captured_calls"][0]["path"] + == f"/injects/execution/callback/{inject_id}" + ) @then( @@ -616,7 +627,8 @@ def assert_no_exception_from_resolve_container_ip(context): ) def assert_signature_values_nested_by_expectation_type(context): body = context["captured_calls"][-1]["post_data"] - entries = body["expectation_signature"]["targets"][0]["signature_values"] + targets = _extract_targets(body) + entries = targets[0]["signature_values"] expectation_types = {entry["expectation_type"] for entry in entries} assert expectation_types == {"DETECTION", "PREVENTION"} @@ -626,7 +638,8 @@ def assert_signature_values_nested_by_expectation_type(context): ) def assert_detection_values_grouped_correctly(context): body = context["captured_calls"][-1]["post_data"] - entries = body["expectation_signature"]["targets"][0]["signature_values"] + targets = _extract_targets(body) + entries = targets[0]["signature_values"] detection_entry = next( entry for entry in entries if entry["expectation_type"] == "DETECTION" ) @@ -640,7 +653,8 @@ def assert_detection_values_grouped_correctly(context): ) def assert_prevention_values_grouped_correctly(context): body = context["captured_calls"][-1]["post_data"] - entries = body["expectation_signature"]["targets"][0]["signature_values"] + targets = _extract_targets(body) + entries = targets[0]["signature_values"] prevention_entry = next( entry for entry in entries if entry["expectation_type"] == "PREVENTION" )