From 8004874bc117d526d343fc0a514ef23a31019eae Mon Sep 17 00:00:00 2001 From: Jim Blomo Date: Wed, 10 Jun 2026 12:42:27 -0700 Subject: [PATCH 1/9] [bedrock_sigv4_auth] Add AWS-native Bedrock authentication --- README.md | 17 +- pyproject.toml | 5 + requirements-dev.lock | 7 + requirements.lock | 8 + src/openai/lib/bedrock.py | 478 +++++++++++++++++++++++++++++++++++--- tests/lib/test_bedrock.py | 212 ++++++++++++++++- 6 files changed, 692 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index 6f87246470..c89273612e 100644 --- a/README.md +++ b/README.md @@ -944,7 +944,7 @@ response = client.responses.create( print(response.output_text) ``` -`BedrockOpenAI` configures AWS bearer auth and the Bedrock Mantle endpoint, then uses the normal SDK resources. AWS controls which endpoints and features are supported; unsupported calls surface the provider's normal HTTP errors through the SDK. +`BedrockOpenAI` configures AWS authentication and the Bedrock Mantle endpoint, then uses the normal SDK resources. AWS controls which endpoints and features are supported; unsupported calls surface the provider's normal HTTP errors through the SDK. Pass `base_url` or set `AWS_BEDROCK_BASE_URL` to override the derived `https://bedrock-mantle..api.aws/openai/v1` endpoint. The legacy module client supports `openai.api_type = "amazon-bedrock"` or `OPENAI_API_TYPE=amazon-bedrock`. @@ -957,6 +957,21 @@ client = BedrockOpenAI( ) ``` +To use the standard AWS credential chain and SigV4 authentication, install the Bedrock extra and omit bearer-token configuration: + +```sh +pip install 'openai[bedrock]' +``` + +```py +client = BedrockOpenAI( + aws_region="us-west-2", + aws_profile="my-profile", # optional; otherwise uses the default AWS credential chain +) +``` + +You can also pass explicit temporary credentials or an `aws_credentials_provider` that returns botocore-compatible credentials. Explicit bearer and AWS credential options are mutually exclusive. Without explicit authentication, `AWS_BEARER_TOKEN_BEDROCK` takes precedence over the default AWS credential chain. + ## Versioning This package generally follows [SemVer](https://semver.org/spec/v2.0.0.html) conventions, though certain backwards-incompatible changes may be released as minor versions: diff --git a/pyproject.toml b/pyproject.toml index 75d0d5e246..d84ebfb1b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,10 @@ aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.9"] realtime = ["websockets >= 13, < 16"] datalib = ["numpy >= 1", "pandas >= 1.2.3", "pandas-stubs >= 1.1.0.11"] voice_helpers = ["sounddevice>=0.5.1", "numpy>=2.0.2"] +bedrock = [ + "botocore[crt]>=1.42.0,<1.43; python_version < '3.10'", + "botocore[crt]>=1.42.0,<2; python_version >= '3.10'", +] [tool.rye] managed = true @@ -65,6 +69,7 @@ dev-dependencies = [ "rich>=13.7.1", "inline-snapshot>=0.28.0", "azure-identity >=1.14.1", + "botocore==1.42.97", "types-tqdm > 4", "types-pyaudio > 0", "trio >=0.22.2", diff --git a/requirements-dev.lock b/requirements-dev.lock index 73312bcee1..7d6764ed7a 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -33,11 +33,15 @@ attrs==25.4.0 # via nox # via outcome # via trio +awscrt==0.31.2 + # via botocore azure-core==1.36.0 # via azure-identity azure-identity==1.25.1 backports-asyncio-runner==1.2.0 # via pytest-asyncio +botocore==1.42.97 + # via openai certifi==2026.1.4 # via httpcore # via httpx @@ -100,6 +104,8 @@ iniconfig==2.1.0 inline-snapshot==0.31.1 jiter==0.12.0 # via openai +jmespath==1.1.0 + # via botocore markdown-it-py==3.0.0 # via rich mdurl==0.1.2 @@ -218,6 +224,7 @@ typing-inspection==0.4.2 tzdata==2025.2 # via pandas urllib3==2.5.0 + # via botocore # via requests # via types-requests virtualenv==20.35.4 diff --git a/requirements.lock b/requirements.lock index af6e4f99e8..5f5158bd4f 100644 --- a/requirements.lock +++ b/requirements.lock @@ -26,6 +26,10 @@ async-timeout==5.0.1 # via aiohttp attrs==25.4.0 # via aiohttp +awscrt==0.31.2 + # via botocore +botocore==1.42.97 + # via openai certifi==2026.1.4 # via httpcore # via httpx @@ -53,6 +57,8 @@ idna==3.11 # via yarl jiter==0.12.0 # via openai +jmespath==1.1.0 + # via botocore multidict==6.7.0 # via aiohttp # via yarl @@ -100,6 +106,8 @@ typing-inspection==0.4.2 # via pydantic tzdata==2025.2 # via pandas +urllib3==1.26.20 + # via botocore websockets==15.0.1 # via openai yarl==1.22.0 diff --git a/src/openai/lib/bedrock.py b/src/openai/lib/bedrock.py index 266a2e9358..a7dfae7be7 100644 --- a/src/openai/lib/bedrock.py +++ b/src/openai/lib/bedrock.py @@ -3,13 +3,14 @@ import os import re import inspect +import importlib from typing import Any, Mapping, Callable, Awaitable, cast from typing_extensions import Self, override import httpx from ..auth import WorkloadIdentity -from .._types import NOT_GIVEN, Timeout, NotGiven +from .._types import NOT_GIVEN, Headers, Timeout, NotGiven from .._utils import is_given from .._client import OpenAI, AsyncOpenAI from .._models import SecurityOptions, FinalRequestOptions @@ -18,6 +19,181 @@ BedrockTokenProvider = Callable[[], str] AsyncBedrockTokenProvider = Callable[[], "str | Awaitable[str]"] +AwsCredentialsProvider = Callable[[], object] + + +class _BedrockAwsBearerAuth: + def __init__(self) -> None: + try: + auth_module = importlib.import_module("botocore.auth") + awsrequest_module = importlib.import_module("botocore.awsrequest") + tokens_module = importlib.import_module("botocore.tokens") + except ImportError: + self._bearer_auth_cls = None + self._aws_request_cls = None + self._frozen_auth_token_cls = None + return + + self._bearer_auth_cls = auth_module.BearerAuth + self._aws_request_cls = awsrequest_module.AWSRequest + self._frozen_auth_token_cls = tokens_module.FrozenAuthToken + + def sign(self, request: httpx.Request, token: str) -> None: + if self._bearer_auth_cls is None or self._aws_request_cls is None or self._frozen_auth_token_cls is None: + return + + headers = dict(request.headers) + headers.pop("authorization", None) + aws_request = self._aws_request_cls( + method=request.method, + url=str(request.url), + data=request.read(), + headers=headers, + ) + self._bearer_auth_cls(self._frozen_auth_token_cls(token)).add_auth(aws_request) + request.headers.clear() + request.headers.update(dict(aws_request.headers.items())) + + +class _BedrockAwsAuth: + def __init__( + self, + *, + region: str, + profile: str | None, + access_key_id: str | None, + secret_access_key: str | None, + session_token: str | None, + credentials_provider: AwsCredentialsProvider | None, + ) -> None: + try: + auth_module = importlib.import_module("botocore.auth") + session_module = importlib.import_module("botocore.session") + awsrequest_module = importlib.import_module("botocore.awsrequest") + credentials_module = importlib.import_module("botocore.credentials") + except ImportError as exc: + raise OpenAIError( + "AWS credential authentication requires botocore. Install it with `pip install openai[bedrock]`." + ) from exc + + session = session_module.Session(profile=profile) + service_model = session.get_service_model("bedrock-runtime") + auth_options = cast("list[str]", service_model.metadata.get("auth", [])) + if auth_module.resolve_auth_scheme_preference(["sigv4"], auth_options) != "v4": + raise OpenAIError("The installed botocore version does not support Bedrock SigV4 authentication.") + + self._region = region + self._session = session + self._credentials_provider = credentials_provider + self._explicit_credentials = ( + credentials_module.Credentials(access_key_id, secret_access_key, session_token) + if access_key_id is not None and secret_access_key is not None + else None + ) + self._aws_request_cls = awsrequest_module.AWSRequest + self._sigv4_auth_cls = auth_module.SigV4Auth + + def sign(self, request: httpx.Request) -> None: + credentials = ( + self._credentials_provider() + if self._credentials_provider is not None + else self._explicit_credentials or self._session.get_credentials() + ) + if credentials is None: + raise OpenAIError( + "Could not resolve AWS credentials. Configure the standard AWS credential chain or pass explicit " + "AWS credentials to the Bedrock client." + ) + + get_frozen_credentials = getattr(credentials, "get_frozen_credentials", None) + if callable(get_frozen_credentials): + credentials = get_frozen_credentials() + + headers = dict(request.headers) + headers.pop("authorization", None) + aws_request = self._aws_request_cls( + method=request.method, + url=str(request.url), + data=request.read(), + headers=headers, + ) + self._sigv4_auth_cls(credentials, "bedrock-mantle", self._region).add_auth(aws_request) + request.headers.clear() + request.headers.update(dict(aws_request.headers.items())) + + +def _has_explicit_aws_auth( + *, + aws_profile: str | None, + aws_access_key_id: str | None, + aws_secret_access_key: str | None, + aws_session_token: str | None, + aws_credentials_provider: AwsCredentialsProvider | None, +) -> bool: + return any( + value is not None + for value in ( + aws_profile, + aws_access_key_id, + aws_secret_access_key, + aws_session_token, + aws_credentials_provider, + ) + ) + + +def _validate_explicit_aws_auth( + *, + aws_profile: str | None, + aws_access_key_id: str | None, + aws_secret_access_key: str | None, + aws_session_token: str | None, + aws_credentials_provider: AwsCredentialsProvider | None, +) -> None: + if (aws_access_key_id is None) != (aws_secret_access_key is None): + raise OpenAIError("The `aws_access_key_id` and `aws_secret_access_key` arguments must be provided together.") + + credential_sources = sum( + ( + aws_profile is not None, + aws_access_key_id is not None, + aws_credentials_provider is not None, + ) + ) + if credential_sources > 1: + raise OpenAIError( + "The `aws_profile`, explicit AWS credentials, and `aws_credentials_provider` arguments are mutually exclusive." + ) + + if aws_session_token is not None and aws_access_key_id is None: + raise OpenAIError("The `aws_session_token` argument requires explicit AWS access key credentials.") + + +def _resolve_aws_region(aws_region: str | None) -> str: + region = aws_region or os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION") + if region is None or not region.strip(): + raise OpenAIError("AWS credential authentication requires `aws_region`, `AWS_REGION`, or `AWS_DEFAULT_REGION`.") + return region.strip() + + +def _resolve_bedrock_env_token() -> str | None: + if "AWS_BEARER_TOKEN_BEDROCK" not in os.environ: + return None + + try: + session_module = importlib.import_module("botocore.session") + except ImportError: + return os.environ.get("AWS_BEARER_TOKEN_BEDROCK") or None + + auth_token = session_module.Session().get_auth_token(signing_name="bedrock") + if auth_token is None: + return None + + get_frozen_token = getattr(auth_token, "get_frozen_token", None) + if callable(get_frozen_token): + auth_token = get_frozen_token() + token = cast(str, auth_token.token) + return token or None def _normalize_bedrock_base_url(base_url: str | httpx.URL) -> httpx.URL: @@ -98,7 +274,14 @@ class BedrockOpenAI(OpenAI): """API client for Amazon Bedrock's OpenAI-compatible endpoint.""" _bedrock_token_provider: BedrockTokenProvider | None + _bedrock_aws_bearer_auth: _BedrockAwsBearerAuth | None + _bedrock_aws_auth: _BedrockAwsAuth | None _uses_region_derived_base_url: bool + _aws_profile: str | None + _aws_access_key_id: str | None + _aws_secret_access_key: str | None + _aws_session_token: str | None + _aws_credentials_provider: AwsCredentialsProvider | None aws_region: str | None def __init__( @@ -107,6 +290,11 @@ def __init__( api_key: str | None = None, bedrock_token_provider: BedrockTokenProvider | None = None, aws_region: str | None = None, + aws_profile: str | None = None, + aws_access_key_id: str | None = None, + aws_secret_access_key: str | None = None, + aws_session_token: str | None = None, + aws_credentials_provider: AwsCredentialsProvider | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -123,30 +311,68 @@ def __init__( """Construct a new synchronous Amazon Bedrock client instance. This automatically infers the following arguments from their corresponding environment variables if they are not provided: - - `api_key` from `AWS_BEARER_TOKEN_BEDROCK` + - bearer authentication from `AWS_BEARER_TOKEN_BEDROCK` - `aws_region` from `AWS_REGION` or `AWS_DEFAULT_REGION` when `base_url` and `AWS_BEDROCK_BASE_URL` are not set - `base_url` from `AWS_BEDROCK_BASE_URL` - `bedrock_token_provider` is invoked before each request when provided. + `bedrock_token_provider` is invoked before each request when provided. When no bearer token is configured, + the client uses the standard AWS credential chain and SigV4 authentication. """ - if api_key is None and bedrock_token_provider is None: - api_key = os.environ.get("AWS_BEARER_TOKEN_BEDROCK") - if callable(cast(object, api_key)): raise OpenAIError("Pass refreshable Bedrock credentials via `bedrock_token_provider`, not `api_key`.") + if api_key == "": + raise OpenAIError("The `api_key` argument must not be empty.") + if api_key is not None and bedrock_token_provider is not None: raise OpenAIError("The `api_key` and `bedrock_token_provider` arguments are mutually exclusive.") - if _enforce_credentials and not api_key and bedrock_token_provider is None: - raise OpenAIError( - "Missing credentials. Please pass an `api_key` or `bedrock_token_provider`, or set the " - "`AWS_BEARER_TOKEN_BEDROCK` environment variable." - ) + explicit_bearer_auth = api_key is not None or bedrock_token_provider is not None + explicit_aws_auth = _has_explicit_aws_auth( + aws_profile=aws_profile, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + aws_credentials_provider=aws_credentials_provider, + ) + if explicit_bearer_auth and explicit_aws_auth: + raise OpenAIError("Bearer token and AWS credential authentication arguments are mutually exclusive.") + + _validate_explicit_aws_auth( + aws_profile=aws_profile, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + aws_credentials_provider=aws_credentials_provider, + ) + + if not explicit_bearer_auth and not explicit_aws_auth: + api_key = _resolve_bedrock_env_token() + + use_aws_auth = api_key is None and bedrock_token_provider is None + resolved_region = _resolve_aws_region(aws_region) if use_aws_auth else aws_region self._bedrock_token_provider = bedrock_token_provider + self._bedrock_aws_bearer_auth = _BedrockAwsBearerAuth() if not use_aws_auth else None + self._bedrock_aws_auth = ( + _BedrockAwsAuth( + region=cast(str, resolved_region), + profile=aws_profile, + access_key_id=aws_access_key_id, + secret_access_key=aws_secret_access_key, + session_token=aws_session_token, + credentials_provider=aws_credentials_provider, + ) + if use_aws_auth and _enforce_credentials + else None + ) self._uses_region_derived_base_url = _uses_region_derived_bedrock_base_url(base_url) - self.aws_region = aws_region + self._aws_profile = aws_profile + self._aws_access_key_id = aws_access_key_id + self._aws_secret_access_key = aws_secret_access_key + self._aws_session_token = aws_session_token + self._aws_credentials_provider = aws_credentials_provider + self.aws_region = resolved_region super().__init__( api_key=_bedrock_token_provider(bedrock_token_provider) @@ -156,7 +382,7 @@ def __init__( organization=organization, project=project, webhook_secret=webhook_secret, - base_url=_resolve_bedrock_base_url(base_url, aws_region), + base_url=_resolve_bedrock_base_url(base_url, resolved_region), websocket_base_url=websocket_base_url, timeout=timeout, max_retries=max_retries, @@ -169,11 +395,21 @@ def __init__( @override def _auth_headers(self, security: SecurityOptions) -> dict[str, str]: + if self._bedrock_aws_auth is not None: + return {} + if security.get("bearer_auth", False) or security.get("admin_api_key_auth", False): return self._bearer_auth return {} + @override + def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None: + if self._bedrock_aws_auth is not None: + return + + super()._validate_headers(headers, custom_headers) + @override def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: if ( @@ -185,6 +421,13 @@ def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: return super()._prepare_options(options) + @override + def _prepare_request(self, request: httpx.Request) -> None: + if self._bedrock_aws_auth is not None: + self._bedrock_aws_auth.sign(request) + elif self._bedrock_aws_bearer_auth is not None: + self._bedrock_aws_bearer_auth.sign(request, self.api_key) + @override def copy( self, @@ -194,6 +437,11 @@ def copy( workload_identity: WorkloadIdentity | None = None, bedrock_token_provider: BedrockTokenProvider | None = None, aws_region: str | None = None, + aws_profile: str | None = None, + aws_access_key_id: str | None = None, + aws_secret_access_key: str | None = None, + aws_session_token: str | None = None, + aws_credentials_provider: AwsCredentialsProvider | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -219,7 +467,7 @@ def copy( raise OpenAIError("Pass refreshable Bedrock credentials via `bedrock_token_provider`, not `api_key`.") if admin_api_key is not None or workload_identity is not None: - raise OpenAIError("BedrockOpenAI only supports Bedrock bearer token authentication.") + raise OpenAIError("BedrockOpenAI only supports Bedrock bearer token or AWS credential authentication.") if api_key is not None and bedrock_token_provider is not None: raise OpenAIError("The `api_key` and `bedrock_token_provider` arguments are mutually exclusive.") @@ -236,14 +484,33 @@ def copy( elif set_default_query is not None: params = set_default_query - if api_key is not None: + aws_auth_override = _has_explicit_aws_auth( + aws_profile=aws_profile, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + aws_credentials_provider=aws_credentials_provider, + ) + if api_key is not None or aws_auth_override: next_token_provider = None elif bedrock_token_provider is not None: next_token_provider = bedrock_token_provider else: next_token_provider = self._bedrock_token_provider - next_api_key = api_key if api_key is not None else (None if next_token_provider is not None else self.api_key) + preserve_aws_auth = ( + self._bedrock_aws_auth is not None + and not aws_auth_override + and api_key is None + and next_token_provider is None + ) + next_api_key = ( + api_key + if api_key is not None + else None + if next_token_provider is not None or preserve_aws_auth or aws_auth_override + else self.api_key + ) next_base_url = base_url if next_base_url is None and not (aws_region is not None and self._uses_region_derived_base_url): next_base_url = self.base_url @@ -252,6 +519,35 @@ def copy( api_key=next_api_key, bedrock_token_provider=next_token_provider, aws_region=aws_region if aws_region is not None else self.aws_region, + aws_profile=aws_profile if aws_profile is not None else self._aws_profile if preserve_aws_auth else None, + aws_access_key_id=( + aws_access_key_id + if aws_access_key_id is not None + else self._aws_access_key_id + if preserve_aws_auth + else None + ), + aws_secret_access_key=( + aws_secret_access_key + if aws_secret_access_key is not None + else self._aws_secret_access_key + if preserve_aws_auth + else None + ), + aws_session_token=( + aws_session_token + if aws_session_token is not None + else self._aws_session_token + if preserve_aws_auth + else None + ), + aws_credentials_provider=( + aws_credentials_provider + if aws_credentials_provider is not None + else self._aws_credentials_provider + if preserve_aws_auth + else None + ), organization=organization if organization is not None else self.organization, project=project if project is not None else self.project, webhook_secret=webhook_secret if webhook_secret is not None else self.webhook_secret, @@ -273,7 +569,14 @@ class AsyncBedrockOpenAI(AsyncOpenAI): """Async API client for Amazon Bedrock's OpenAI-compatible endpoint.""" _bedrock_token_provider: AsyncBedrockTokenProvider | None + _bedrock_aws_bearer_auth: _BedrockAwsBearerAuth | None + _bedrock_aws_auth: _BedrockAwsAuth | None _uses_region_derived_base_url: bool + _aws_profile: str | None + _aws_access_key_id: str | None + _aws_secret_access_key: str | None + _aws_session_token: str | None + _aws_credentials_provider: AwsCredentialsProvider | None aws_region: str | None def __init__( @@ -282,6 +585,11 @@ def __init__( api_key: str | None = None, bedrock_token_provider: AsyncBedrockTokenProvider | None = None, aws_region: str | None = None, + aws_profile: str | None = None, + aws_access_key_id: str | None = None, + aws_secret_access_key: str | None = None, + aws_session_token: str | None = None, + aws_credentials_provider: AwsCredentialsProvider | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -298,30 +606,68 @@ def __init__( """Construct a new asynchronous Amazon Bedrock client instance. This automatically infers the following arguments from their corresponding environment variables if they are not provided: - - `api_key` from `AWS_BEARER_TOKEN_BEDROCK` + - bearer authentication from `AWS_BEARER_TOKEN_BEDROCK` - `aws_region` from `AWS_REGION` or `AWS_DEFAULT_REGION` when `base_url` and `AWS_BEDROCK_BASE_URL` are not set - `base_url` from `AWS_BEDROCK_BASE_URL` - `bedrock_token_provider` is invoked before each request when provided. + `bedrock_token_provider` is invoked before each request when provided. When no bearer token is configured, + the client uses the standard AWS credential chain and SigV4 authentication. """ - if api_key is None and bedrock_token_provider is None: - api_key = os.environ.get("AWS_BEARER_TOKEN_BEDROCK") - if callable(cast(object, api_key)): raise OpenAIError("Pass refreshable Bedrock credentials via `bedrock_token_provider`, not `api_key`.") + if api_key == "": + raise OpenAIError("The `api_key` argument must not be empty.") + if api_key is not None and bedrock_token_provider is not None: raise OpenAIError("The `api_key` and `bedrock_token_provider` arguments are mutually exclusive.") - if _enforce_credentials and not api_key and bedrock_token_provider is None: - raise OpenAIError( - "Missing credentials. Please pass an `api_key` or `bedrock_token_provider`, or set the " - "`AWS_BEARER_TOKEN_BEDROCK` environment variable." - ) + explicit_bearer_auth = api_key is not None or bedrock_token_provider is not None + explicit_aws_auth = _has_explicit_aws_auth( + aws_profile=aws_profile, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + aws_credentials_provider=aws_credentials_provider, + ) + if explicit_bearer_auth and explicit_aws_auth: + raise OpenAIError("Bearer token and AWS credential authentication arguments are mutually exclusive.") + + _validate_explicit_aws_auth( + aws_profile=aws_profile, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + aws_credentials_provider=aws_credentials_provider, + ) + + if not explicit_bearer_auth and not explicit_aws_auth: + api_key = _resolve_bedrock_env_token() + + use_aws_auth = api_key is None and bedrock_token_provider is None + resolved_region = _resolve_aws_region(aws_region) if use_aws_auth else aws_region self._bedrock_token_provider = bedrock_token_provider + self._bedrock_aws_bearer_auth = _BedrockAwsBearerAuth() if not use_aws_auth else None + self._bedrock_aws_auth = ( + _BedrockAwsAuth( + region=cast(str, resolved_region), + profile=aws_profile, + access_key_id=aws_access_key_id, + secret_access_key=aws_secret_access_key, + session_token=aws_session_token, + credentials_provider=aws_credentials_provider, + ) + if use_aws_auth and _enforce_credentials + else None + ) self._uses_region_derived_base_url = _uses_region_derived_bedrock_base_url(base_url) - self.aws_region = aws_region + self._aws_profile = aws_profile + self._aws_access_key_id = aws_access_key_id + self._aws_secret_access_key = aws_secret_access_key + self._aws_session_token = aws_session_token + self._aws_credentials_provider = aws_credentials_provider + self.aws_region = resolved_region super().__init__( api_key=( @@ -333,7 +679,7 @@ def __init__( organization=organization, project=project, webhook_secret=webhook_secret, - base_url=_resolve_bedrock_base_url(base_url, aws_region), + base_url=_resolve_bedrock_base_url(base_url, resolved_region), websocket_base_url=websocket_base_url, timeout=timeout, max_retries=max_retries, @@ -346,11 +692,21 @@ def __init__( @override def _auth_headers(self, security: SecurityOptions) -> dict[str, str]: + if self._bedrock_aws_auth is not None: + return {} + if security.get("bearer_auth", False) or security.get("admin_api_key_auth", False): return self._bearer_auth return {} + @override + def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None: + if self._bedrock_aws_auth is not None: + return + + super()._validate_headers(headers, custom_headers) + @override async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: if ( @@ -362,6 +718,13 @@ async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOp return await super()._prepare_options(options) + @override + async def _prepare_request(self, request: httpx.Request) -> None: + if self._bedrock_aws_auth is not None: + self._bedrock_aws_auth.sign(request) + elif self._bedrock_aws_bearer_auth is not None: + self._bedrock_aws_bearer_auth.sign(request, self.api_key) + @override def copy( self, @@ -371,6 +734,11 @@ def copy( workload_identity: WorkloadIdentity | None = None, bedrock_token_provider: AsyncBedrockTokenProvider | None = None, aws_region: str | None = None, + aws_profile: str | None = None, + aws_access_key_id: str | None = None, + aws_secret_access_key: str | None = None, + aws_session_token: str | None = None, + aws_credentials_provider: AwsCredentialsProvider | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -396,7 +764,7 @@ def copy( raise OpenAIError("Pass refreshable Bedrock credentials via `bedrock_token_provider`, not `api_key`.") if admin_api_key is not None or workload_identity is not None: - raise OpenAIError("AsyncBedrockOpenAI only supports Bedrock bearer token authentication.") + raise OpenAIError("AsyncBedrockOpenAI only supports Bedrock bearer token or AWS credential authentication.") if api_key is not None and bedrock_token_provider is not None: raise OpenAIError("The `api_key` and `bedrock_token_provider` arguments are mutually exclusive.") @@ -413,14 +781,33 @@ def copy( elif set_default_query is not None: params = set_default_query - if api_key is not None: + aws_auth_override = _has_explicit_aws_auth( + aws_profile=aws_profile, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + aws_credentials_provider=aws_credentials_provider, + ) + if api_key is not None or aws_auth_override: next_token_provider = None elif bedrock_token_provider is not None: next_token_provider = bedrock_token_provider else: next_token_provider = self._bedrock_token_provider - next_api_key = api_key if api_key is not None else (None if next_token_provider is not None else self.api_key) + preserve_aws_auth = ( + self._bedrock_aws_auth is not None + and not aws_auth_override + and api_key is None + and next_token_provider is None + ) + next_api_key = ( + api_key + if api_key is not None + else None + if next_token_provider is not None or preserve_aws_auth or aws_auth_override + else self.api_key + ) next_base_url = base_url if next_base_url is None and not (aws_region is not None and self._uses_region_derived_base_url): next_base_url = self.base_url @@ -429,6 +816,35 @@ def copy( api_key=next_api_key, bedrock_token_provider=next_token_provider, aws_region=aws_region if aws_region is not None else self.aws_region, + aws_profile=aws_profile if aws_profile is not None else self._aws_profile if preserve_aws_auth else None, + aws_access_key_id=( + aws_access_key_id + if aws_access_key_id is not None + else self._aws_access_key_id + if preserve_aws_auth + else None + ), + aws_secret_access_key=( + aws_secret_access_key + if aws_secret_access_key is not None + else self._aws_secret_access_key + if preserve_aws_auth + else None + ), + aws_session_token=( + aws_session_token + if aws_session_token is not None + else self._aws_session_token + if preserve_aws_auth + else None + ), + aws_credentials_provider=( + aws_credentials_provider + if aws_credentials_provider is not None + else self._aws_credentials_provider + if preserve_aws_auth + else None + ), organization=organization if organization is not None else self.organization, project=project if project is not None else self.project, webhook_secret=webhook_secret if webhook_secret is not None else self.webhook_secret, diff --git a/tests/lib/test_bedrock.py b/tests/lib/test_bedrock.py index dab9abd1cf..656dcb55a0 100644 --- a/tests/lib/test_bedrock.py +++ b/tests/lib/test_bedrock.py @@ -2,12 +2,14 @@ import json from typing import Any, Union, Protocol, cast +from pathlib import Path import httpx import pytest from httpx import URL from respx import MockRouter +import openai.lib.bedrock as bedrock_module from openai import OpenAIError, NotFoundError from tests.utils import update_env from openai._types import Omit @@ -76,6 +78,13 @@ class MockRequestCall(Protocol): request: httpx.Request +class MockAwsCredentials: + def __init__(self, access_key: str, secret_key: str, token: str | None = None) -> None: + self.access_key = access_key + self.secret_key = secret_key + self.token = token + + def make_sync_client(**kwargs: Any) -> BedrockOpenAI: return BedrockOpenAI(http_client=httpx.Client(trust_env=False), **kwargs) @@ -123,6 +132,79 @@ def test_bedrock_config_precedence(client_cls: type[Client]) -> None: assert client.api_key == "explicit token" +@pytest.mark.respx() +def test_env_bearer_does_not_require_botocore(monkeypatch: pytest.MonkeyPatch, respx_mock: MockRouter) -> None: + real_import_module = bedrock_module.importlib.import_module + + def import_module(name: str) -> Any: + if name.startswith("botocore"): + raise ImportError(name) + return real_import_module(name) + + monkeypatch.setattr(bedrock_module.importlib, "import_module", import_module) + respx_mock.post("https://example.com/openai/v1/responses").mock( + return_value=httpx.Response(200, json=RESPONSE_BODY) + ) + with update_env( + AWS_BEDROCK_BASE_URL="https://example.com/openai/v1", + AWS_BEARER_TOKEN_BEDROCK="env token", + ): + client = make_sync_client() + + client.responses.create(model="gpt-4o", input="hello") + + request = cast("list[MockRequestCall]", respx_mock.calls)[0].request + assert request.headers["Authorization"] == "Bearer env token" + + +def test_empty_env_bearer_without_botocore_uses_aws_credentials(monkeypatch: pytest.MonkeyPatch) -> None: + real_import_module = bedrock_module.importlib.import_module + + def import_module(name: str) -> Any: + if name.startswith("botocore"): + raise ImportError(name) + return real_import_module(name) + + monkeypatch.setattr(bedrock_module.importlib, "import_module", import_module) + with update_env(AWS_BEARER_TOKEN_BEDROCK="", AWS_REGION="us-east-1"): + with pytest.raises(OpenAIError, match="requires botocore"): + BedrockOpenAI() + + +@pytest.mark.respx() +def test_env_bearer_uses_botocore_bearer_auth(monkeypatch: pytest.MonkeyPatch, respx_mock: MockRouter) -> None: + auth_module = bedrock_module.importlib.import_module("botocore.auth") + calls = 0 + real_add_auth = auth_module.BearerAuth.add_auth + + def add_auth(auth: object, request: object) -> None: + nonlocal calls + calls += 1 + real_add_auth(auth, request) + + monkeypatch.setattr(auth_module.BearerAuth, "add_auth", add_auth) + respx_mock.post("https://example.com/openai/v1/responses").mock( + return_value=httpx.Response(200, json=RESPONSE_BODY) + ) + with update_env(AWS_BEARER_TOKEN_BEDROCK="env token"): + client = make_sync_client(base_url="https://example.com/openai/v1") + + client.responses.create(model="gpt-4o", input="hello") + + request = cast("list[MockRequestCall]", respx_mock.calls)[0].request + assert request.headers["Authorization"] == "Bearer env token" + assert calls == 1 + + +@pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) +def test_empty_env_bearer_falls_back_to_aws_credentials(client_cls: type[Client]) -> None: + with update_env(AWS_BEARER_TOKEN_BEDROCK="", AWS_REGION="us-east-1"): + client = make_sync_client() if client_cls is BedrockOpenAI else make_async_client() + + assert client.api_key == "" + assert client._bedrock_aws_auth is not None + + @pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) def test_bedrock_region_precedence(client_cls: type[Client]) -> None: with update_env(AWS_BEDROCK_BASE_URL=Omit(), AWS_REGION="us-east-1", AWS_DEFAULT_REGION="us-west-2"): @@ -170,8 +252,10 @@ def test_does_not_use_openai_api_key(client_cls: type[Client]) -> None: AWS_BEARER_TOKEN_BEDROCK=Omit(), AWS_BEDROCK_BASE_URL="https://example.com/openai/v1", ): - with pytest.raises(OpenAIError, match="AWS_BEARER_TOKEN_BEDROCK"): - client_cls() + client = client_cls() + + assert client.api_key == "" + assert client._bedrock_aws_auth is not None @pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) @@ -184,6 +268,33 @@ def test_rejects_static_token_and_provider(client_cls: type[Client]) -> None: ) +@pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) +def test_rejects_empty_explicit_bearer_token(client_cls: type[Client]) -> None: + with pytest.raises(OpenAIError, match="must not be empty"): + client_cls(base_url="https://example.com/openai/v1", api_key="") + + +@pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) +def test_rejects_bearer_and_aws_credentials(client_cls: type[Client]) -> None: + with pytest.raises(OpenAIError, match="mutually exclusive"): + client_cls( + base_url="https://example.com/openai/v1", + api_key="token", + aws_access_key_id="access key", + aws_secret_access_key="secret key", + ) + + +@pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) +def test_rejects_partial_explicit_aws_credentials(client_cls: type[Client]) -> None: + with pytest.raises(OpenAIError, match="must be provided together"): + client_cls( + base_url="https://example.com/openai/v1", + aws_region="us-east-1", + aws_access_key_id="access key", + ) + + @pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) def test_requires_refreshable_tokens_to_use_provider_option(client_cls: type[Client]) -> None: with pytest.raises(OpenAIError, match="bedrock_token_provider"): @@ -240,6 +351,59 @@ async def test_token_provider_refresh_async(respx_mock: MockRouter) -> None: assert calls[1].request.headers["Authorization"] == "Bearer second" +@pytest.mark.respx() +def test_explicit_aws_credentials_override_ambient_bearer(respx_mock: MockRouter) -> None: + respx_mock.post("https://example.com/openai/v1/responses").mock( + return_value=httpx.Response(200, json=RESPONSE_BODY) + ) + with update_env(AWS_BEARER_TOKEN_BEDROCK="ambient token"): + client = BedrockOpenAI( + base_url="https://example.com/openai/v1", + aws_region="us-east-1", + aws_access_key_id="access key", + aws_secret_access_key="secret key", + aws_session_token="session token", + http_client=httpx.Client(trust_env=False), + ) + + client.responses.create(model="gpt-4o", input="hello") + + request = cast("list[MockRequestCall]", respx_mock.calls)[0].request + assert request.headers["Authorization"].startswith("AWS4-HMAC-SHA256 Credential=access key/") + assert request.headers["X-Amz-Security-Token"] == "session token" + + +@pytest.mark.respx() +def test_aws_credentials_provider_refreshes_before_retries(respx_mock: MockRouter) -> None: + respx_mock.post("https://example.com/openai/v1/responses").mock( + side_effect=[ + httpx.Response(500, json={"error": "server error"}), + httpx.Response(200, json=RESPONSE_BODY), + ] + ) + credentials = iter( + [ + MockAwsCredentials("first access key", "first secret", "first session token"), + MockAwsCredentials("second access key", "second secret", "second session token"), + ] + ) + client = BedrockOpenAI( + base_url="https://example.com/openai/v1", + aws_region="us-east-1", + aws_credentials_provider=lambda: next(credentials), + http_client=httpx.Client(trust_env=False), + max_retries=1, + ) + + client.responses.create(model="gpt-4o", input="hello") + + calls = cast("list[MockRequestCall]", respx_mock.calls) + assert "Credential=first access key/" in calls[0].request.headers["Authorization"] + assert calls[0].request.headers["X-Amz-Security-Token"] == "first session token" + assert "Credential=second access key/" in calls[1].request.headers["Authorization"] + assert calls[1].request.headers["X-Amz-Security-Token"] == "second session token" + + def test_preserves_token_provider_across_with_options() -> None: client = BedrockOpenAI( base_url="https://example.com/openai/v1", @@ -252,6 +416,48 @@ def test_preserves_token_provider_across_with_options() -> None: assert copied_client._refresh_api_key() == "provider token" +def test_preserves_aws_credentials_across_with_options() -> None: + client = BedrockOpenAI( + base_url="https://example.com/openai/v1", + aws_region="us-east-1", + aws_access_key_id="access key", + aws_secret_access_key="secret key", + http_client=httpx.Client(trust_env=False), + ) + + copied_client = client.with_options(timeout=1) + + assert copied_client._bedrock_aws_auth is not None + assert copied_client._aws_access_key_id == "access key" + + +@pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) +def test_with_options_replaces_the_aws_credential_source(client_cls: type[Client], tmp_path: Path) -> None: + config_path = tmp_path / "config" + config_path.write_text("[profile other-profile]\nregion = us-east-1\n") + explicit_credentials_client = client_cls( + base_url="https://example.com/openai/v1", + aws_region="us-east-1", + aws_access_key_id="access key", + aws_secret_access_key="secret key", + ) + with update_env(AWS_CONFIG_FILE=str(config_path)): + profile_client = explicit_credentials_client.with_options(aws_profile="other-profile") + + assert profile_client._aws_profile == "other-profile" + assert profile_client._aws_access_key_id is None + assert profile_client._aws_secret_access_key is None + + explicit_credentials_client = profile_client.with_options( + aws_access_key_id="replacement access key", + aws_secret_access_key="replacement secret key", + ) + + assert explicit_credentials_client._aws_profile is None + assert explicit_credentials_client._aws_access_key_id == "replacement access key" + assert explicit_credentials_client._aws_secret_access_key == "replacement secret key" + + @pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) def test_with_options_api_key_replaces_token_provider(client_cls: type[Client]) -> None: client = ( @@ -311,7 +517,7 @@ def test_with_options_aws_region_keeps_explicit_base_url(client_cls: type[Client def test_rejects_non_bedrock_copy_auth(copy_kwargs: dict[str, Any]) -> None: client = make_sync_client(base_url="https://example.com/openai/v1", api_key="token") - with pytest.raises(OpenAIError, match="only supports Bedrock bearer token authentication"): + with pytest.raises(OpenAIError, match="only supports Bedrock bearer token or AWS credential authentication"): client.with_options(**copy_kwargs) From ee2213e0c190bae68d950510f80f8f4cc97967a5 Mon Sep 17 00:00:00 2001 From: Hayden Date: Fri, 12 Jun 2026 10:26:32 -0700 Subject: [PATCH 2/9] Harden AWS authentication for Bedrock --- README.md | 32 +- pyproject.toml | 1 + requirements-dev.lock | 20 +- src/openai/__init__.py | 7 + src/openai/_utils/_logs.py | 2 +- src/openai/lib/_bedrock_auth.py | 253 ++++++ src/openai/lib/bedrock.py | 950 ++++++++++++--------- tests/fixtures/bedrock_auth/v1/cases.json | 252 ++++++ tests/fixtures/bedrock_auth/v1/schema.json | 78 ++ tests/lib/test_bedrock.py | 253 +++++- tests/lib/test_bedrock_auth_conformance.py | 403 +++++++++ tests/test_module_client.py | 53 ++ 12 files changed, 1866 insertions(+), 438 deletions(-) create mode 100644 src/openai/lib/_bedrock_auth.py create mode 100644 tests/fixtures/bedrock_auth/v1/cases.json create mode 100644 tests/fixtures/bedrock_auth/v1/schema.json create mode 100644 tests/lib/test_bedrock_auth_conformance.py diff --git a/README.md b/README.md index c89273612e..feb1dd0715 100644 --- a/README.md +++ b/README.md @@ -930,11 +930,18 @@ An example of using the client with Microsoft Entra ID (formerly known as Azure To use this library with [Amazon Bedrock's OpenAI-compatible API](https://docs.aws.amazon.com/bedrock/latest/userguide/models-api-compatibility.html), use the `BedrockOpenAI` class instead of the `OpenAI` class. +Install the optional Bedrock dependencies to use the standard AWS credential chain and SigV4 authentication: + +```sh +pip install 'openai[bedrock]' +``` + ```py from openai import BedrockOpenAI -# gets the bearer token from AWS_BEARER_TOKEN_BEDROCK and the region from AWS_REGION/AWS_DEFAULT_REGION -client = BedrockOpenAI() +# Uses your normal AWS credentials. You can omit aws_region when it is +# configured through AWS_REGION, AWS_DEFAULT_REGION, or your AWS profile. +client = BedrockOpenAI(aws_region="us-west-2") response = client.responses.create( model="openai.gpt-5.4", @@ -946,31 +953,30 @@ print(response.output_text) `BedrockOpenAI` configures AWS authentication and the Bedrock Mantle endpoint, then uses the normal SDK resources. AWS controls which endpoints and features are supported; unsupported calls surface the provider's normal HTTP errors through the SDK. -Pass `base_url` or set `AWS_BEDROCK_BASE_URL` to override the derived `https://bedrock-mantle..api.aws/openai/v1` endpoint. The legacy module client supports `openai.api_type = "amazon-bedrock"` or `OPENAI_API_TYPE=amazon-bedrock`. - -Set `AWS_BEARER_TOKEN_BEDROCK` to an [Amazon Bedrock API key](https://docs.aws.amazon.com/bedrock/latest/userguide/api-keys.html). To refresh tokens yourself, pass a provider instead of `api_key`: +The default AWS credential chain supports environment credentials, shared credentials and config files, named profiles, SSO and assume-role profiles, and workload credentials such as ECS, EKS, and EC2 metadata. To select a named profile: ```py client = BedrockOpenAI( - aws_region="us-west-2", - bedrock_token_provider=lambda: refresh_bedrock_token(), + aws_profile="my-profile", ) ``` -To use the standard AWS credential chain and SigV4 authentication, install the Bedrock extra and omit bearer-token configuration: +You can also pass explicit temporary credentials or an `aws_credentials_provider` that returns botocore-compatible credentials. Explicit bearer and AWS credential options are mutually exclusive. -```sh -pip install 'openai[bedrock]' -``` +Pass `base_url` or set `AWS_BEDROCK_BASE_URL` to override the derived `https://bedrock-mantle..api.aws/openai/v1` endpoint. The legacy module client supports `openai.api_type = "amazon-bedrock"` or `OPENAI_API_TYPE=amazon-bedrock`. + +Normal SDK requests use replayable, fully signed bodies. Low-level one-shot request streams are signed with `UNSIGNED-PAYLOAD` only when retries are disabled with `max_retries=0`; buffering is recommended because streamed request bodies cannot be safely retried. + +Bearer tokens remain available as a compatibility or manual authentication mode. Set `AWS_BEARER_TOKEN_BEDROCK` to an [Amazon Bedrock API key](https://docs.aws.amazon.com/bedrock/latest/userguide/api-keys.html), pass `api_key`, or provide a refresh callback: ```py client = BedrockOpenAI( aws_region="us-west-2", - aws_profile="my-profile", # optional; otherwise uses the default AWS credential chain + bedrock_token_provider=lambda: refresh_bedrock_token(), ) ``` -You can also pass explicit temporary credentials or an `aws_credentials_provider` that returns botocore-compatible credentials. Explicit bearer and AWS credential options are mutually exclusive. Without explicit authentication, `AWS_BEARER_TOKEN_BEDROCK` takes precedence over the default AWS credential chain. +Without explicit authentication, `AWS_BEARER_TOKEN_BEDROCK` takes precedence over the default AWS credential chain for backwards compatibility. ## Versioning diff --git a/pyproject.toml b/pyproject.toml index d84ebfb1b4..0faa3167ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ dev-dependencies = [ "respx", "pytest", "pytest-asyncio", + "jsonschema>=4.23.0", "ruff", "time-machine", "nox", diff --git a/requirements-dev.lock b/requirements-dev.lock index 7d6764ed7a..61f1f31d58 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -30,8 +30,10 @@ async-timeout==5.0.1 # via aiohttp attrs==25.4.0 # via aiohttp + # via jsonschema # via nox # via outcome + # via referencing # via trio awscrt==0.31.2 # via botocore @@ -106,6 +108,9 @@ jiter==0.12.0 # via openai jmespath==1.1.0 # via botocore +jsonschema==4.25.1 +jsonschema-specifications==2025.9.1 + # via jsonschema markdown-it-py==3.0.0 # via rich mdurl==0.1.2 @@ -167,16 +172,23 @@ pytest==8.4.2 pytest-asyncio==1.2.0 pytest-xdist==3.8.0 python-dateutil==2.9.0.post0 + # via botocore # via pandas # via time-machine pytz==2025.2 # via pandas +referencing==0.36.2 + # via jsonschema + # via jsonschema-specifications requests==2.32.5 # via azure-core # via msal respx==0.22.0 rich==14.2.0 # via inline-snapshot +rpds-py==0.27.1 + # via jsonschema + # via referencing ruff==0.14.7 six==1.17.0 # via python-dateutil @@ -200,9 +212,11 @@ trio==0.31.0 types-pyaudio==0.2.16.20250801 types-pytz==2025.2.0.20251108 # via pandas-stubs -types-requests==2.32.4.20250913 +types-requests==2.31.0.6 # via types-tqdm types-tqdm==4.67.0.20250809 +types-urllib3==1.26.25.14 + # via types-requests typing-extensions==4.15.0 # via aiosignal # via anyio @@ -217,16 +231,16 @@ typing-extensions==4.15.0 # via pydantic-core # via pyright # via pytest-asyncio + # via referencing # via typing-inspection # via virtualenv typing-inspection==0.4.2 # via pydantic tzdata==2025.2 # via pandas -urllib3==2.5.0 +urllib3==1.26.20 # via botocore # via requests - # via types-requests virtualenv==20.35.4 # via nox websockets==15.0.1 diff --git a/src/openai/__init__.py b/src/openai/__init__.py index 3786d106cb..b4de9c5754 100644 --- a/src/openai/__init__.py +++ b/src/openai/__init__.py @@ -313,6 +313,13 @@ def api_key(self, value: str | None) -> None: # type: ignore _bedrock_api_key = value + @override + def _refresh_api_key(self) -> str: + if api_key is not None: + return api_key + + return super()._refresh_api_key() + class _AmbiguousModuleClientUsageError(OpenAIError): def __init__(self) -> None: diff --git a/src/openai/_utils/_logs.py b/src/openai/_utils/_logs.py index 376946933c..eaffa5ec7a 100644 --- a/src/openai/_utils/_logs.py +++ b/src/openai/_utils/_logs.py @@ -8,7 +8,7 @@ httpx_logger: logging.Logger = logging.getLogger("httpx") -SENSITIVE_HEADERS = {"api-key", "authorization"} +SENSITIVE_HEADERS = {"api-key", "authorization", "x-amz-security-token"} def _basic_config() -> None: diff --git a/src/openai/lib/_bedrock_auth.py b/src/openai/lib/_bedrock_auth.py new file mode 100644 index 0000000000..140ca894b0 --- /dev/null +++ b/src/openai/lib/_bedrock_auth.py @@ -0,0 +1,253 @@ +from __future__ import annotations + +import os +import importlib +from typing import Literal, Mapping, Callable, Protocol, cast +from dataclasses import field, dataclass + +from .._exceptions import OpenAIError + +AwsCredentialsProvider = Callable[[], object] + + +class _BotocoreSession(Protocol): + def get_credentials(self) -> object | None: ... + + +_AUTHORIZATION = "authorization" +_UNSIGNED_PAYLOAD = "UNSIGNED-PAYLOAD" +_AWS_SIGNING_HEADERS = ( + _AUTHORIZATION, + "x-amz-content-sha256", + "x-amz-date", + "x-amz-security-token", +) + + +@dataclass(frozen=True) +class BedrockBearerAuthConfig: + source: Literal["explicit", "provider", "environment"] + region_source: Literal["explicit", "environment"] | None = None + + +@dataclass(frozen=True) +class BedrockAwsAuthConfig: + region: str + source: Literal["static", "profile", "provider", "default"] + region_source: Literal["explicit", "environment", "profile"] = "explicit" + profile: str | None = None + access_key_id: str | None = field(default=None, repr=False) + secret_access_key: str | None = field(default=None, repr=False) + session_token: str | None = field(default=None, repr=False) + credentials_provider: AwsCredentialsProvider | None = field(default=None, repr=False, compare=False) + + +class BedrockAwsAuth: + def __init__(self, config: BedrockAwsAuthConfig, *, session: _BotocoreSession | None = None) -> None: + try: + auth_module = importlib.import_module("botocore.auth") + session_module = importlib.import_module("botocore.session") + awsrequest_module = importlib.import_module("botocore.awsrequest") + credentials_module = importlib.import_module("botocore.credentials") + except ImportError as exc: + raise OpenAIError( + "Bedrock AWS authentication requires optional AWS dependencies. " + "Install them with `pip install openai[bedrock]` and try again." + ) from exc + + if session is None: + try: + session = session_module.Session(profile=config.profile) + except Exception as exc: + raise OpenAIError( + "Failed to resolve AWS credentials for Bedrock. Verify your AWS profile, environment variables, " + "or runtime identity configuration and try again." + ) from exc + + assert session is not None + self.config = config + self._session = session + self._credentials_provider = config.credentials_provider + self._explicit_credentials = ( + credentials_module.Credentials(config.access_key_id, config.secret_access_key, config.session_token) + if config.access_key_id is not None and config.secret_access_key is not None + else None + ) + self._aws_request_cls = awsrequest_module.AWSRequest + self._sigv4_auth_cls = auth_module.SigV4Auth + + @classmethod + def resolve( + cls, + *, + region: str | None, + profile: str | None, + access_key_id: str | None, + secret_access_key: str | None, + session_token: str | None, + credentials_provider: AwsCredentialsProvider | None, + ) -> BedrockAwsAuth: + try: + session_module = importlib.import_module("botocore.session") + except ImportError as exc: + raise OpenAIError( + "Bedrock AWS authentication requires optional AWS dependencies. " + "Install them with `pip install openai[bedrock]` and try again." + ) from exc + + try: + session = session_module.Session(profile=profile) + resolved_region, region_source = resolve_aws_region_with_source(region, session=session) + except OpenAIError: + raise + except Exception as exc: + raise OpenAIError( + "Failed to resolve AWS credentials for Bedrock. Verify your AWS profile, environment variables, " + "or runtime identity configuration and try again." + ) from exc + + source: Literal["static", "profile", "provider", "default"] + if access_key_id is not None: + source = "static" + elif profile is not None: + source = "profile" + elif credentials_provider is not None: + source = "provider" + else: + source = "default" + + config = BedrockAwsAuthConfig( + region=resolved_region, + source=source, + region_source=region_source, + profile=profile, + access_key_id=access_key_id, + secret_access_key=secret_access_key, + session_token=session_token, + credentials_provider=credentials_provider, + ) + return cls(config, session=session) + + def sign(self, *, method: str, url: str, headers: Mapping[str, str], body: bytes | None) -> dict[str, str]: + try: + credentials = ( + self._credentials_provider() + if self._credentials_provider is not None + else self._explicit_credentials or self._session.get_credentials() + ) + if credentials is None: + raise OpenAIError( + "Could not find credentials for Bedrock. Pass a bearer credential or AWS credentials, " + "set `AWS_BEARER_TOKEN_BEDROCK`, or configure the default AWS credential chain." + ) + + get_frozen_credentials = getattr(credentials, "get_frozen_credentials", None) + if callable(get_frozen_credentials): + credentials = get_frozen_credentials() + + signed_headers = { + name: value for name, value in headers.items() if name.lower() not in _AWS_SIGNING_HEADERS + } + if body is None: + signed_headers["X-Amz-Content-SHA256"] = _UNSIGNED_PAYLOAD + + aws_request = self._aws_request_cls( + method=method, + url=url, + data=body, + headers=signed_headers, + ) + self._sigv4_auth_cls(credentials, "bedrock-mantle", self.config.region).add_auth(aws_request) + except OpenAIError: + raise + except Exception as exc: + raise OpenAIError( + "Failed to resolve AWS credentials for Bedrock. Verify your AWS profile, environment variables, " + "or runtime identity configuration and try again." + ) from exc + + return dict(aws_request.headers.items()) + + +def resolve_aws_region_with_source( + aws_region: str | None, *, session: object | None = None +) -> tuple[str, Literal["explicit", "environment", "profile"]]: + region = aws_region + source: Literal["explicit", "environment", "profile"] = "explicit" + if region is None or not region.strip(): + region = os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION") + source = "environment" + if (region is None or not region.strip()) and session is not None: + get_config_variable = getattr(session, "get_config_variable", None) + if callable(get_config_variable): + region = cast("str | None", get_config_variable("region")) + source = "profile" + + if region is None or not region.strip(): + raise OpenAIError( + "Bedrock requires an AWS region. Pass `aws_region`, or set `AWS_REGION` or `AWS_DEFAULT_REGION`." + ) + + return region.strip(), source + + +def resolve_aws_region(aws_region: str | None, *, session: object | None = None) -> str: + return resolve_aws_region_with_source(aws_region, session=session)[0] + + +def resolve_bedrock_env_token() -> str | None: + return os.environ.get("AWS_BEARER_TOKEN_BEDROCK") or None + + +def has_explicit_aws_auth( + *, + aws_profile: str | None, + aws_access_key_id: str | None, + aws_secret_access_key: str | None, + aws_session_token: str | None, + aws_credentials_provider: AwsCredentialsProvider | None, +) -> bool: + return any( + value is not None + for value in ( + aws_profile, + aws_access_key_id, + aws_secret_access_key, + aws_session_token, + aws_credentials_provider, + ) + ) + + +def validate_explicit_aws_auth( + *, + aws_profile: str | None, + aws_access_key_id: str | None, + aws_secret_access_key: str | None, + aws_session_token: str | None, + aws_credentials_provider: AwsCredentialsProvider | None, +) -> None: + if (aws_access_key_id is None) != (aws_secret_access_key is None): + raise OpenAIError( + "Static AWS credentials require both `aws_access_key_id` and `aws_secret_access_key`. " + "An `aws_session_token` may only be used with both." + ) + + credential_sources = sum( + ( + aws_profile is not None, + aws_access_key_id is not None, + aws_credentials_provider is not None, + ) + ) + if credential_sources > 1: + raise OpenAIError( + "Bedrock authentication is ambiguous. Configure exactly one explicit mode: bearer credential, " + "static AWS credentials, profile, or credential provider." + ) + + if aws_session_token is not None and aws_access_key_id is None: + raise OpenAIError( + "Static AWS credentials require both `aws_access_key_id` and `aws_secret_access_key`. " + "An `aws_session_token` may only be used with both." + ) diff --git a/src/openai/lib/bedrock.py b/src/openai/lib/bedrock.py index a7dfae7be7..802f476d84 100644 --- a/src/openai/lib/bedrock.py +++ b/src/openai/lib/bedrock.py @@ -3,197 +3,78 @@ import os import re import inspect -import importlib -from typing import Any, Mapping, Callable, Awaitable, cast -from typing_extensions import Self, override +from typing import Any, Literal, Mapping, Callable, Awaitable, cast +from dataclasses import replace +from typing_extensions import Self, Unpack, override import httpx from ..auth import WorkloadIdentity -from .._types import NOT_GIVEN, Headers, Timeout, NotGiven -from .._utils import is_given +from .._types import NOT_GIVEN, Omit, Headers, Timeout, NotGiven, HttpxSendArgs +from .._utils import asyncify, is_given from .._client import OpenAI, AsyncOpenAI from .._models import SecurityOptions, FinalRequestOptions from .._exceptions import OpenAIError from .._base_client import DEFAULT_MAX_RETRIES +from ._bedrock_auth import ( + BedrockAwsAuth as _BedrockAwsAuth, + BedrockAwsAuthConfig as _BedrockAwsAuthConfig, + AwsCredentialsProvider, + BedrockBearerAuthConfig as _BedrockBearerAuthConfig, + resolve_aws_region as _resolve_aws_region, + has_explicit_aws_auth as _has_explicit_aws_auth, + resolve_bedrock_env_token as _resolve_bedrock_env_token, + validate_explicit_aws_auth as _validate_explicit_aws_auth, + resolve_aws_region_with_source as _resolve_aws_region_with_source, +) BedrockTokenProvider = Callable[[], str] AsyncBedrockTokenProvider = Callable[[], "str | Awaitable[str]"] -AwsCredentialsProvider = Callable[[], object] - - -class _BedrockAwsBearerAuth: - def __init__(self) -> None: - try: - auth_module = importlib.import_module("botocore.auth") - awsrequest_module = importlib.import_module("botocore.awsrequest") - tokens_module = importlib.import_module("botocore.tokens") - except ImportError: - self._bearer_auth_cls = None - self._aws_request_cls = None - self._frozen_auth_token_cls = None - return - self._bearer_auth_cls = auth_module.BearerAuth - self._aws_request_cls = awsrequest_module.AWSRequest - self._frozen_auth_token_cls = tokens_module.FrozenAuthToken +_BEDROCK_AUTH_INTENT_EXTENSION = "openai.bedrock_auth_intent" +_BEDROCK_AUTH_INTENT_DEFAULT = "default" +_BEDROCK_AUTH_INTENT_OMIT = "omit" +_BEDROCK_AUTH_INTENT_OVERRIDE = "override" +_BEDROCK_MAX_RETRIES_EXTENSION = "openai.bedrock_max_retries" +_AWS_SIGNING_HEADERS = ("authorization", "x-amz-content-sha256", "x-amz-date", "x-amz-security-token") - def sign(self, request: httpx.Request, token: str) -> None: - if self._bearer_auth_cls is None or self._aws_request_cls is None or self._frozen_auth_token_cls is None: - return - headers = dict(request.headers) - headers.pop("authorization", None) - aws_request = self._aws_request_cls( - method=request.method, - url=str(request.url), - data=request.read(), - headers=headers, - ) - self._bearer_auth_cls(self._frozen_auth_token_cls(token)).add_auth(aws_request) - request.headers.clear() - request.headers.update(dict(aws_request.headers.items())) +def _authorization_intent(*header_sets: Mapping[str, str | Omit]) -> str: + intent = _BEDROCK_AUTH_INTENT_DEFAULT + for headers in header_sets: + for name, value in headers.items(): + if name.lower() == "authorization": + intent = _BEDROCK_AUTH_INTENT_OMIT if isinstance(value, Omit) else _BEDROCK_AUTH_INTENT_OVERRIDE + return intent -class _BedrockAwsAuth: - def __init__( - self, - *, - region: str, - profile: str | None, - access_key_id: str | None, - secret_access_key: str | None, - session_token: str | None, - credentials_provider: AwsCredentialsProvider | None, - ) -> None: - try: - auth_module = importlib.import_module("botocore.auth") - session_module = importlib.import_module("botocore.session") - awsrequest_module = importlib.import_module("botocore.awsrequest") - credentials_module = importlib.import_module("botocore.credentials") - except ImportError as exc: - raise OpenAIError( - "AWS credential authentication requires botocore. Install it with `pip install openai[bedrock]`." - ) from exc - - session = session_module.Session(profile=profile) - service_model = session.get_service_model("bedrock-runtime") - auth_options = cast("list[str]", service_model.metadata.get("auth", [])) - if auth_module.resolve_auth_scheme_preference(["sigv4"], auth_options) != "v4": - raise OpenAIError("The installed botocore version does not support Bedrock SigV4 authentication.") - - self._region = region - self._session = session - self._credentials_provider = credentials_provider - self._explicit_credentials = ( - credentials_module.Credentials(access_key_id, secret_access_key, session_token) - if access_key_id is not None and secret_access_key is not None - else None - ) - self._aws_request_cls = awsrequest_module.AWSRequest - self._sigv4_auth_cls = auth_module.SigV4Auth - - def sign(self, request: httpx.Request) -> None: - credentials = ( - self._credentials_provider() - if self._credentials_provider is not None - else self._explicit_credentials or self._session.get_credentials() - ) - if credentials is None: - raise OpenAIError( - "Could not resolve AWS credentials. Configure the standard AWS credential chain or pass explicit " - "AWS credentials to the Bedrock client." - ) - - get_frozen_credentials = getattr(credentials, "get_frozen_credentials", None) - if callable(get_frozen_credentials): - credentials = get_frozen_credentials() - - headers = dict(request.headers) - headers.pop("authorization", None) - aws_request = self._aws_request_cls( - method=request.method, - url=str(request.url), - data=request.read(), - headers=headers, - ) - self._sigv4_auth_cls(credentials, "bedrock-mantle", self._region).add_auth(aws_request) - request.headers.clear() - request.headers.update(dict(aws_request.headers.items())) +def _same_origin(left: httpx.URL, right: httpx.URL) -> bool: + return (left.scheme, left.host, left.port) == (right.scheme, right.host, right.port) -def _has_explicit_aws_auth( - *, - aws_profile: str | None, - aws_access_key_id: str | None, - aws_secret_access_key: str | None, - aws_session_token: str | None, - aws_credentials_provider: AwsCredentialsProvider | None, -) -> bool: - return any( - value is not None - for value in ( - aws_profile, - aws_access_key_id, - aws_secret_access_key, - aws_session_token, - aws_credentials_provider, - ) - ) - +def _constructor_accepts_keyword(constructor: Callable[..., object], name: str) -> bool: + try: + parameters = inspect.signature(constructor).parameters + except (TypeError, ValueError): + return False -def _validate_explicit_aws_auth( - *, - aws_profile: str | None, - aws_access_key_id: str | None, - aws_secret_access_key: str | None, - aws_session_token: str | None, - aws_credentials_provider: AwsCredentialsProvider | None, -) -> None: - if (aws_access_key_id is None) != (aws_secret_access_key is None): - raise OpenAIError("The `aws_access_key_id` and `aws_secret_access_key` arguments must be provided together.") - - credential_sources = sum( - ( - aws_profile is not None, - aws_access_key_id is not None, - aws_credentials_provider is not None, - ) + return name in parameters or any( + parameter.kind is inspect.Parameter.VAR_KEYWORD for parameter in parameters.values() ) - if credential_sources > 1: - raise OpenAIError( - "The `aws_profile`, explicit AWS credentials, and `aws_credentials_provider` arguments are mutually exclusive." - ) - if aws_session_token is not None and aws_access_key_id is None: - raise OpenAIError("The `aws_session_token` argument requires explicit AWS access key credentials.") - - -def _resolve_aws_region(aws_region: str | None) -> str: - region = aws_region or os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION") - if region is None or not region.strip(): - raise OpenAIError("AWS credential authentication requires `aws_region`, `AWS_REGION`, or `AWS_DEFAULT_REGION`.") - return region.strip() - - -def _resolve_bedrock_env_token() -> str | None: - if "AWS_BEARER_TOKEN_BEDROCK" not in os.environ: - return None +def _body_for_signing(request: httpx.Request) -> bytes | None: try: - session_module = importlib.import_module("botocore.session") - except ImportError: - return os.environ.get("AWS_BEARER_TOKEN_BEDROCK") or None - - auth_token = session_module.Session().get_auth_token(signing_name="bedrock") - if auth_token is None: - return None + return request.content + except httpx.RequestNotRead as exc: + max_retries = request.extensions.get(_BEDROCK_MAX_RETRIES_EXTENSION) + if max_retries == 0: + return None - get_frozen_token = getattr(auth_token, "get_frozen_token", None) - if callable(get_frozen_token): - auth_token = get_frozen_token() - token = cast(str, auth_token.token) - return token or None + raise OpenAIError( + "Bedrock SigV4 authentication requires a replayable request body when retries are enabled. " + "Buffer the body, set `max_retries=0` to use `UNSIGNED-PAYLOAD`, or use bearer authentication." + ) from exc def _normalize_bedrock_base_url(base_url: str | httpx.URL) -> httpx.URL: @@ -207,22 +88,41 @@ def _normalize_bedrock_base_url(base_url: str | httpx.URL) -> httpx.URL: return url.copy_with(path=path or "/") -def _resolve_bedrock_base_url(base_url: str | httpx.URL | None, aws_region: str | None) -> httpx.URL: +def _configured_aws_region(aws_region: str | None) -> str | None: + region = aws_region if aws_region is not None and aws_region.strip() else None + region = region or os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION") + return region.strip() if region is not None and region.strip() else None + + +def _configured_aws_region_source(aws_region: str | None) -> Literal["explicit", "environment"] | None: + if aws_region is not None and aws_region.strip(): + return "explicit" + environment_region = os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION") + if environment_region is not None and environment_region.strip(): + return "environment" + return None + + +def _resolve_bedrock_base_url( + base_url: str | httpx.URL | None, + aws_region: str | None, + *, + use_environment: bool = True, +) -> httpx.URL: """Resolve Bedrock base URL precedence from explicit, env, then region config.""" if isinstance(base_url, str) and not base_url.strip(): base_url = None - if base_url is None: + if base_url is None and use_environment: env_base_url = os.environ.get("AWS_BEDROCK_BASE_URL") if env_base_url is not None and env_base_url.strip(): base_url = env_base_url if base_url is None: - region = aws_region or os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION") - if region is None or not region.strip(): + region = _configured_aws_region(aws_region) + if region is None: raise OpenAIError( - "Must provide one of the `base_url` or `aws_region` arguments, or set the " - "`AWS_BEDROCK_BASE_URL`, `AWS_REGION`, or `AWS_DEFAULT_REGION` environment variable." + "Bedrock requires an AWS region. Pass `aws_region`, or set `AWS_REGION` or `AWS_DEFAULT_REGION`." ) base_url = f"https://bedrock-mantle.{region}.api.aws/openai/v1" @@ -270,18 +170,115 @@ async def get_token() -> str: return get_token +def _resolve_bedrock_auth( + *, + api_key: str | None, + token_provider: object | None, + aws_region: str | None, + aws_profile: str | None, + aws_access_key_id: str | None, + aws_secret_access_key: str | None, + aws_session_token: str | None, + aws_credentials_provider: AwsCredentialsProvider | None, + auth_config: _BedrockBearerAuthConfig | _BedrockAwsAuthConfig | None, + enforce_credentials: bool, +) -> tuple[_BedrockBearerAuthConfig | _BedrockAwsAuthConfig, _BedrockAwsAuth | None, str | None, str | None]: + if auth_config is not None: + if isinstance(auth_config, _BedrockAwsAuthConfig): + aws_auth = _BedrockAwsAuth(auth_config) if enforce_credentials else None + return auth_config, aws_auth, api_key, auth_config.region + + return auth_config, None, api_key, aws_region + + explicit_bearer_auth = api_key is not None or token_provider is not None + explicit_aws_auth = _has_explicit_aws_auth( + aws_profile=aws_profile, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + aws_credentials_provider=aws_credentials_provider, + ) + if explicit_bearer_auth and explicit_aws_auth: + raise OpenAIError( + "Bedrock authentication is ambiguous. Configure exactly one explicit mode: bearer credential, " + "static AWS credentials, profile, or credential provider." + ) + + _validate_explicit_aws_auth( + aws_profile=aws_profile, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + aws_credentials_provider=aws_credentials_provider, + ) + + if explicit_bearer_auth: + source: Literal["explicit", "provider"] = "provider" if token_provider is not None else "explicit" + return ( + _BedrockBearerAuthConfig(source=source, region_source=_configured_aws_region_source(aws_region)), + None, + api_key, + _configured_aws_region(aws_region), + ) + + if not explicit_aws_auth: + api_key = _resolve_bedrock_env_token() + if api_key is not None: + return ( + _BedrockBearerAuthConfig( + source="environment", + region_source=_configured_aws_region_source(aws_region), + ), + None, + api_key, + _configured_aws_region(aws_region), + ) + + if enforce_credentials: + aws_auth = _BedrockAwsAuth.resolve( + region=aws_region, + profile=aws_profile, + access_key_id=aws_access_key_id, + secret_access_key=aws_secret_access_key, + session_token=aws_session_token, + credentials_provider=aws_credentials_provider, + ) + return aws_auth.config, aws_auth, None, aws_auth.config.region + + resolved_region, region_source = _resolve_aws_region_with_source(aws_region) + aws_source: Literal["static", "profile", "provider", "default"] + if aws_access_key_id is not None: + aws_source = "static" + elif aws_profile is not None: + aws_source = "profile" + elif aws_credentials_provider is not None: + aws_source = "provider" + else: + aws_source = "default" + return ( + _BedrockAwsAuthConfig( + region=resolved_region, + source=aws_source, + region_source=region_source, + profile=aws_profile, + access_key_id=aws_access_key_id, + secret_access_key=aws_secret_access_key, + session_token=aws_session_token, + credentials_provider=aws_credentials_provider, + ), + None, + None, + resolved_region, + ) + + class BedrockOpenAI(OpenAI): """API client for Amazon Bedrock's OpenAI-compatible endpoint.""" _bedrock_token_provider: BedrockTokenProvider | None - _bedrock_aws_bearer_auth: _BedrockAwsBearerAuth | None + _bedrock_auth_config: _BedrockBearerAuthConfig | _BedrockAwsAuthConfig _bedrock_aws_auth: _BedrockAwsAuth | None _uses_region_derived_base_url: bool - _aws_profile: str | None - _aws_access_key_id: str | None - _aws_secret_access_key: str | None - _aws_session_token: str | None - _aws_credentials_provider: AwsCredentialsProvider | None aws_region: str | None def __init__( @@ -307,6 +304,8 @@ def __init__( http_client: httpx.Client | None = None, _strict_response_validation: bool = False, _enforce_credentials: bool = True, + _auth_config: _BedrockBearerAuthConfig | _BedrockAwsAuthConfig | None = None, + _base_url_is_region_derived: bool | None = None, ) -> None: """Construct a new synchronous Amazon Bedrock client instance. @@ -325,53 +324,32 @@ def __init__( raise OpenAIError("The `api_key` argument must not be empty.") if api_key is not None and bedrock_token_provider is not None: - raise OpenAIError("The `api_key` and `bedrock_token_provider` arguments are mutually exclusive.") - - explicit_bearer_auth = api_key is not None or bedrock_token_provider is not None - explicit_aws_auth = _has_explicit_aws_auth( - aws_profile=aws_profile, - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - aws_session_token=aws_session_token, - aws_credentials_provider=aws_credentials_provider, - ) - if explicit_bearer_auth and explicit_aws_auth: - raise OpenAIError("Bearer token and AWS credential authentication arguments are mutually exclusive.") + raise OpenAIError( + "Bedrock authentication is ambiguous. Configure exactly one explicit mode: bearer credential, " + "static AWS credentials, profile, or credential provider." + ) - _validate_explicit_aws_auth( + auth_config, aws_auth, api_key, resolved_region = _resolve_bedrock_auth( + api_key=api_key, + token_provider=bedrock_token_provider, + aws_region=aws_region, aws_profile=aws_profile, aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token, aws_credentials_provider=aws_credentials_provider, + auth_config=_auth_config, + enforce_credentials=_enforce_credentials, ) - if not explicit_bearer_auth and not explicit_aws_auth: - api_key = _resolve_bedrock_env_token() - - use_aws_auth = api_key is None and bedrock_token_provider is None - resolved_region = _resolve_aws_region(aws_region) if use_aws_auth else aws_region - self._bedrock_token_provider = bedrock_token_provider - self._bedrock_aws_bearer_auth = _BedrockAwsBearerAuth() if not use_aws_auth else None - self._bedrock_aws_auth = ( - _BedrockAwsAuth( - region=cast(str, resolved_region), - profile=aws_profile, - access_key_id=aws_access_key_id, - secret_access_key=aws_secret_access_key, - session_token=aws_session_token, - credentials_provider=aws_credentials_provider, - ) - if use_aws_auth and _enforce_credentials - else None + self._bedrock_auth_config = auth_config + self._bedrock_aws_auth = aws_auth + self._uses_region_derived_base_url = ( + _uses_region_derived_bedrock_base_url(base_url) + if _base_url_is_region_derived is None + else _base_url_is_region_derived ) - self._uses_region_derived_base_url = _uses_region_derived_bedrock_base_url(base_url) - self._aws_profile = aws_profile - self._aws_access_key_id = aws_access_key_id - self._aws_secret_access_key = aws_secret_access_key - self._aws_session_token = aws_session_token - self._aws_credentials_provider = aws_credentials_provider self.aws_region = resolved_region super().__init__( @@ -382,7 +360,11 @@ def __init__( organization=organization, project=project, webhook_secret=webhook_secret, - base_url=_resolve_bedrock_base_url(base_url, resolved_region), + base_url=_resolve_bedrock_base_url( + base_url, + resolved_region, + use_environment=_base_url_is_region_derived is not True, + ), websocket_base_url=websocket_base_url, timeout=timeout, max_retries=max_retries, @@ -393,9 +375,16 @@ def __init__( _enforce_credentials=False, ) + def _uses_aws_auth(self) -> bool: + return ( + isinstance(self._bedrock_auth_config, _BedrockAwsAuthConfig) + and not self.api_key + and self._api_key_provider is None + ) + @override def _auth_headers(self, security: SecurityOptions) -> dict[str, str]: - if self._bedrock_aws_auth is not None: + if self._uses_aws_auth(): return {} if security.get("bearer_auth", False) or security.get("admin_api_key_auth", False): @@ -405,14 +394,21 @@ def _auth_headers(self, security: SecurityOptions) -> dict[str, str]: @override def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None: - if self._bedrock_aws_auth is not None: + if self._uses_aws_auth(): return super()._validate_headers(headers, custom_headers) @override def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: - if ( + if self._uses_aws_auth(): + if options.follow_redirects: + raise OpenAIError( + "Bedrock SigV4 authentication does not support automatic redirects. " + "Send a new request to the redirect target so it can be signed again." + ) + options.follow_redirects = False + elif ( self._api_key_provider is not None and options.security.get("admin_api_key_auth", False) and not options.security.get("bearer_auth", False) @@ -421,12 +417,58 @@ def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: return super()._prepare_options(options) + @override + def _build_request(self, options: FinalRequestOptions, *, retries_taken: int = 0) -> httpx.Request: + request = super()._build_request(options, retries_taken=retries_taken) + if not self._uses_aws_auth(): + return request + + option_headers: Headers = options.headers if is_given(options.headers) else {} + request.extensions[_BEDROCK_AUTH_INTENT_EXTENSION] = _authorization_intent( + self._custom_headers, + option_headers, + ) + request.extensions[_BEDROCK_MAX_RETRIES_EXTENSION] = options.get_max_retries(self.max_retries) + return request + @override def _prepare_request(self, request: httpx.Request) -> None: - if self._bedrock_aws_auth is not None: - self._bedrock_aws_auth.sign(request) - elif self._bedrock_aws_bearer_auth is not None: - self._bedrock_aws_bearer_auth.sign(request, self.api_key) + if not self._uses_aws_auth(): + return + if self._bedrock_aws_auth is None: + assert isinstance(self._bedrock_auth_config, _BedrockAwsAuthConfig) + self._bedrock_aws_auth = _BedrockAwsAuth(self._bedrock_auth_config) + + intent = request.extensions.get(_BEDROCK_AUTH_INTENT_EXTENSION, _BEDROCK_AUTH_INTENT_DEFAULT) + if intent == _BEDROCK_AUTH_INTENT_OMIT: + for header in _AWS_SIGNING_HEADERS: + request.headers.pop(header, None) + return + if intent == _BEDROCK_AUTH_INTENT_OVERRIDE or "Authorization" in request.headers: + return + if not _same_origin(request.url, self.base_url): + raise OpenAIError("Refusing to sign a Bedrock request for an origin other than the configured `base_url`.") + + signed_headers = self._bedrock_aws_auth.sign( + method=request.method, + url=str(request.url), + headers=dict(request.headers), + body=_body_for_signing(request), + ) + request.headers.clear() + request.headers.update(signed_headers) + + @override + def _send_request( + self, + request: httpx.Request, + *, + stream: bool, + **kwargs: Unpack[HttpxSendArgs], + ) -> httpx.Response: + if self._uses_aws_auth(): + kwargs["auth"] = httpx.Auth() + return super()._send_request(request, stream=stream, **kwargs) @override def copy( @@ -470,7 +512,10 @@ def copy( raise OpenAIError("BedrockOpenAI only supports Bedrock bearer token or AWS credential authentication.") if api_key is not None and bedrock_token_provider is not None: - raise OpenAIError("The `api_key` and `bedrock_token_provider` arguments are mutually exclusive.") + raise OpenAIError( + "Bedrock authentication is ambiguous. Configure exactly one explicit mode: bearer credential, " + "static AWS credentials, profile, or credential provider." + ) headers = self._custom_headers if default_headers is not None: @@ -491,6 +536,12 @@ def copy( aws_session_token=aws_session_token, aws_credentials_provider=aws_credentials_provider, ) + if (api_key is not None or bedrock_token_provider is not None) and aws_auth_override: + raise OpenAIError( + "Bedrock authentication is ambiguous. Configure exactly one explicit mode: bearer credential, " + "static AWS credentials, profile, or credential provider." + ) + auth_override = api_key is not None or bedrock_token_provider is not None or aws_auth_override if api_key is not None or aws_auth_override: next_token_provider = None elif bedrock_token_provider is not None: @@ -498,69 +549,100 @@ def copy( else: next_token_provider = self._bedrock_token_provider - preserve_aws_auth = ( - self._bedrock_aws_auth is not None - and not aws_auth_override - and api_key is None - and next_token_provider is None - ) - next_api_key = ( - api_key - if api_key is not None - else None - if next_token_provider is not None or preserve_aws_auth or aws_auth_override - else self.api_key + next_auth_config: _BedrockBearerAuthConfig | _BedrockAwsAuthConfig | None + if auth_override: + next_auth_config = None + elif isinstance(self._bedrock_auth_config, _BedrockAwsAuthConfig) and self.api_key: + # The legacy module client allows a module-level API key to replace + # its construction-time default AWS authentication. + next_auth_config = None + elif aws_region is not None: + if isinstance(self._bedrock_auth_config, _BedrockAwsAuthConfig): + next_auth_config = replace( + self._bedrock_auth_config, + region=_resolve_aws_region(aws_region), + region_source="explicit", + ) + else: + next_auth_config = replace(self._bedrock_auth_config, region_source="explicit") + else: + next_auth_config = self._bedrock_auth_config + + next_aws_region = aws_region if aws_region is not None else self.aws_region + if aws_profile is not None and aws_region is None and self._bedrock_auth_config.region_source != "explicit": + next_aws_region = None + + next_api_key = api_key + if next_api_key is None and next_token_provider is None: + next_api_key = ( + None if aws_auth_override or isinstance(next_auth_config, _BedrockAwsAuthConfig) else self.api_key + ) + + blank_base_url_override = isinstance(base_url, str) and not base_url.strip() + next_base_url = None if blank_base_url_override else base_url + next_base_url_is_region_derived = False + recompute_region_base_url = self._uses_region_derived_base_url and ( + aws_region is not None or (aws_profile is not None and next_aws_region is None) ) - next_base_url = base_url - if next_base_url is None and not (aws_region is not None and self._uses_region_derived_base_url): + if blank_base_url_override: + next_base_url_is_region_derived = _uses_region_derived_bedrock_base_url(None) + elif next_base_url is None and not recompute_region_base_url: next_base_url = self.base_url - - return self.__class__( - api_key=next_api_key, - bedrock_token_provider=next_token_provider, - aws_region=aws_region if aws_region is not None else self.aws_region, - aws_profile=aws_profile if aws_profile is not None else self._aws_profile if preserve_aws_auth else None, - aws_access_key_id=( - aws_access_key_id - if aws_access_key_id is not None - else self._aws_access_key_id - if preserve_aws_auth - else None - ), - aws_secret_access_key=( - aws_secret_access_key - if aws_secret_access_key is not None - else self._aws_secret_access_key - if preserve_aws_auth - else None - ), - aws_session_token=( - aws_session_token - if aws_session_token is not None - else self._aws_session_token - if preserve_aws_auth - else None - ), - aws_credentials_provider=( - aws_credentials_provider - if aws_credentials_provider is not None - else self._aws_credentials_provider - if preserve_aws_auth - else None - ), - organization=organization if organization is not None else self.organization, - project=project if project is not None else self.project, - webhook_secret=webhook_secret if webhook_secret is not None else self.webhook_secret, - websocket_base_url=websocket_base_url if websocket_base_url is not None else self.websocket_base_url, - base_url=next_base_url, - timeout=self.timeout if isinstance(timeout, NotGiven) else timeout, - http_client=http_client or self._client, - max_retries=max_retries if is_given(max_retries) else self.max_retries, - default_headers=headers, - default_query=params, - _enforce_credentials=True if _enforce_credentials is None else _enforce_credentials, + next_base_url_is_region_derived = self._uses_region_derived_base_url + elif next_base_url is None and next_aws_region is not None: + next_base_url = f"https://bedrock-mantle.{next_aws_region}.api.aws/openai/v1" + next_base_url_is_region_derived = True + elif next_base_url is None: + next_base_url_is_region_derived = True + + constructor_kwargs: dict[str, Any] = { + "api_key": next_api_key, + "bedrock_token_provider": next_token_provider, + "aws_region": next_aws_region, + "organization": organization if organization is not None else self.organization, + "project": project if project is not None else self.project, + "webhook_secret": webhook_secret if webhook_secret is not None else self.webhook_secret, + "websocket_base_url": websocket_base_url if websocket_base_url is not None else self.websocket_base_url, + "base_url": next_base_url, + "timeout": self.timeout if isinstance(timeout, NotGiven) else timeout, + "http_client": http_client or self._client, + "max_retries": max_retries if is_given(max_retries) else self.max_retries, + "default_headers": headers, + "default_query": params, + "_enforce_credentials": True if _enforce_credentials is None else _enforce_credentials, **_extra_kwargs, + } + aws_overrides = { + "aws_profile": aws_profile, + "aws_access_key_id": aws_access_key_id, + "aws_secret_access_key": aws_secret_access_key, + "aws_session_token": aws_session_token, + "aws_credentials_provider": aws_credentials_provider, + } + constructor_kwargs.update({name: value for name, value in aws_overrides.items() if value is not None}) + + supports_auth_config = _constructor_accepts_keyword(self.__class__.__init__, "_auth_config") + supports_base_url_provenance = _constructor_accepts_keyword( + self.__class__.__init__, "_base_url_is_region_derived" ) + if supports_auth_config: + constructor_kwargs["_auth_config"] = next_auth_config + if supports_base_url_provenance: + constructor_kwargs["_base_url_is_region_derived"] = next_base_url_is_region_derived + + copied = self.__class__(**constructor_kwargs) + if not supports_auth_config and next_auth_config is not None: + copied._bedrock_auth_config = next_auth_config + if isinstance(next_auth_config, _BedrockAwsAuthConfig): + copied._bedrock_aws_auth = _BedrockAwsAuth(next_auth_config) + copied._bedrock_token_provider = None + copied.api_key = "" + copied._api_key_provider = None + copied.aws_region = next_auth_config.region + if not supports_base_url_provenance: + copied._uses_region_derived_base_url = next_base_url_is_region_derived + + return copied with_options = copy @@ -569,14 +651,9 @@ class AsyncBedrockOpenAI(AsyncOpenAI): """Async API client for Amazon Bedrock's OpenAI-compatible endpoint.""" _bedrock_token_provider: AsyncBedrockTokenProvider | None - _bedrock_aws_bearer_auth: _BedrockAwsBearerAuth | None + _bedrock_auth_config: _BedrockBearerAuthConfig | _BedrockAwsAuthConfig _bedrock_aws_auth: _BedrockAwsAuth | None _uses_region_derived_base_url: bool - _aws_profile: str | None - _aws_access_key_id: str | None - _aws_secret_access_key: str | None - _aws_session_token: str | None - _aws_credentials_provider: AwsCredentialsProvider | None aws_region: str | None def __init__( @@ -602,6 +679,8 @@ def __init__( http_client: httpx.AsyncClient | None = None, _strict_response_validation: bool = False, _enforce_credentials: bool = True, + _auth_config: _BedrockBearerAuthConfig | _BedrockAwsAuthConfig | None = None, + _base_url_is_region_derived: bool | None = None, ) -> None: """Construct a new asynchronous Amazon Bedrock client instance. @@ -620,53 +699,32 @@ def __init__( raise OpenAIError("The `api_key` argument must not be empty.") if api_key is not None and bedrock_token_provider is not None: - raise OpenAIError("The `api_key` and `bedrock_token_provider` arguments are mutually exclusive.") - - explicit_bearer_auth = api_key is not None or bedrock_token_provider is not None - explicit_aws_auth = _has_explicit_aws_auth( - aws_profile=aws_profile, - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - aws_session_token=aws_session_token, - aws_credentials_provider=aws_credentials_provider, - ) - if explicit_bearer_auth and explicit_aws_auth: - raise OpenAIError("Bearer token and AWS credential authentication arguments are mutually exclusive.") + raise OpenAIError( + "Bedrock authentication is ambiguous. Configure exactly one explicit mode: bearer credential, " + "static AWS credentials, profile, or credential provider." + ) - _validate_explicit_aws_auth( + auth_config, aws_auth, api_key, resolved_region = _resolve_bedrock_auth( + api_key=api_key, + token_provider=bedrock_token_provider, + aws_region=aws_region, aws_profile=aws_profile, aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token, aws_credentials_provider=aws_credentials_provider, + auth_config=_auth_config, + enforce_credentials=_enforce_credentials, ) - if not explicit_bearer_auth and not explicit_aws_auth: - api_key = _resolve_bedrock_env_token() - - use_aws_auth = api_key is None and bedrock_token_provider is None - resolved_region = _resolve_aws_region(aws_region) if use_aws_auth else aws_region - self._bedrock_token_provider = bedrock_token_provider - self._bedrock_aws_bearer_auth = _BedrockAwsBearerAuth() if not use_aws_auth else None - self._bedrock_aws_auth = ( - _BedrockAwsAuth( - region=cast(str, resolved_region), - profile=aws_profile, - access_key_id=aws_access_key_id, - secret_access_key=aws_secret_access_key, - session_token=aws_session_token, - credentials_provider=aws_credentials_provider, - ) - if use_aws_auth and _enforce_credentials - else None + self._bedrock_auth_config = auth_config + self._bedrock_aws_auth = aws_auth + self._uses_region_derived_base_url = ( + _uses_region_derived_bedrock_base_url(base_url) + if _base_url_is_region_derived is None + else _base_url_is_region_derived ) - self._uses_region_derived_base_url = _uses_region_derived_bedrock_base_url(base_url) - self._aws_profile = aws_profile - self._aws_access_key_id = aws_access_key_id - self._aws_secret_access_key = aws_secret_access_key - self._aws_session_token = aws_session_token - self._aws_credentials_provider = aws_credentials_provider self.aws_region = resolved_region super().__init__( @@ -679,7 +737,11 @@ def __init__( organization=organization, project=project, webhook_secret=webhook_secret, - base_url=_resolve_bedrock_base_url(base_url, resolved_region), + base_url=_resolve_bedrock_base_url( + base_url, + resolved_region, + use_environment=_base_url_is_region_derived is not True, + ), websocket_base_url=websocket_base_url, timeout=timeout, max_retries=max_retries, @@ -690,9 +752,16 @@ def __init__( _enforce_credentials=False, ) + def _uses_aws_auth(self) -> bool: + return ( + isinstance(self._bedrock_auth_config, _BedrockAwsAuthConfig) + and not self.api_key + and self._api_key_provider is None + ) + @override def _auth_headers(self, security: SecurityOptions) -> dict[str, str]: - if self._bedrock_aws_auth is not None: + if self._uses_aws_auth(): return {} if security.get("bearer_auth", False) or security.get("admin_api_key_auth", False): @@ -702,14 +771,21 @@ def _auth_headers(self, security: SecurityOptions) -> dict[str, str]: @override def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None: - if self._bedrock_aws_auth is not None: + if self._uses_aws_auth(): return super()._validate_headers(headers, custom_headers) @override async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: - if ( + if self._uses_aws_auth(): + if options.follow_redirects: + raise OpenAIError( + "Bedrock SigV4 authentication does not support automatic redirects. " + "Send a new request to the redirect target so it can be signed again." + ) + options.follow_redirects = False + elif ( self._api_key_provider is not None and options.security.get("admin_api_key_auth", False) and not options.security.get("bearer_auth", False) @@ -718,12 +794,58 @@ async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOp return await super()._prepare_options(options) + @override + def _build_request(self, options: FinalRequestOptions, *, retries_taken: int = 0) -> httpx.Request: + request = super()._build_request(options, retries_taken=retries_taken) + if not self._uses_aws_auth(): + return request + + option_headers: Headers = options.headers if is_given(options.headers) else {} + request.extensions[_BEDROCK_AUTH_INTENT_EXTENSION] = _authorization_intent( + self._custom_headers, + option_headers, + ) + request.extensions[_BEDROCK_MAX_RETRIES_EXTENSION] = options.get_max_retries(self.max_retries) + return request + @override async def _prepare_request(self, request: httpx.Request) -> None: - if self._bedrock_aws_auth is not None: - self._bedrock_aws_auth.sign(request) - elif self._bedrock_aws_bearer_auth is not None: - self._bedrock_aws_bearer_auth.sign(request, self.api_key) + if not self._uses_aws_auth(): + return + if self._bedrock_aws_auth is None: + assert isinstance(self._bedrock_auth_config, _BedrockAwsAuthConfig) + self._bedrock_aws_auth = await asyncify(_BedrockAwsAuth)(self._bedrock_auth_config) + + intent = request.extensions.get(_BEDROCK_AUTH_INTENT_EXTENSION, _BEDROCK_AUTH_INTENT_DEFAULT) + if intent == _BEDROCK_AUTH_INTENT_OMIT: + for header in _AWS_SIGNING_HEADERS: + request.headers.pop(header, None) + return + if intent == _BEDROCK_AUTH_INTENT_OVERRIDE or "Authorization" in request.headers: + return + if not _same_origin(request.url, self.base_url): + raise OpenAIError("Refusing to sign a Bedrock request for an origin other than the configured `base_url`.") + + signed_headers = await asyncify(self._bedrock_aws_auth.sign)( + method=request.method, + url=str(request.url), + headers=dict(request.headers), + body=_body_for_signing(request), + ) + request.headers.clear() + request.headers.update(signed_headers) + + @override + async def _send_request( + self, + request: httpx.Request, + *, + stream: bool, + **kwargs: Unpack[HttpxSendArgs], + ) -> httpx.Response: + if self._uses_aws_auth(): + kwargs["auth"] = httpx.Auth() + return await super()._send_request(request, stream=stream, **kwargs) @override def copy( @@ -767,7 +889,10 @@ def copy( raise OpenAIError("AsyncBedrockOpenAI only supports Bedrock bearer token or AWS credential authentication.") if api_key is not None and bedrock_token_provider is not None: - raise OpenAIError("The `api_key` and `bedrock_token_provider` arguments are mutually exclusive.") + raise OpenAIError( + "Bedrock authentication is ambiguous. Configure exactly one explicit mode: bearer credential, " + "static AWS credentials, profile, or credential provider." + ) headers = self._custom_headers if default_headers is not None: @@ -788,6 +913,12 @@ def copy( aws_session_token=aws_session_token, aws_credentials_provider=aws_credentials_provider, ) + if (api_key is not None or bedrock_token_provider is not None) and aws_auth_override: + raise OpenAIError( + "Bedrock authentication is ambiguous. Configure exactly one explicit mode: bearer credential, " + "static AWS credentials, profile, or credential provider." + ) + auth_override = api_key is not None or bedrock_token_provider is not None or aws_auth_override if api_key is not None or aws_auth_override: next_token_provider = None elif bedrock_token_provider is not None: @@ -795,68 +926,97 @@ def copy( else: next_token_provider = self._bedrock_token_provider - preserve_aws_auth = ( - self._bedrock_aws_auth is not None - and not aws_auth_override - and api_key is None - and next_token_provider is None - ) - next_api_key = ( - api_key - if api_key is not None - else None - if next_token_provider is not None or preserve_aws_auth or aws_auth_override - else self.api_key + next_auth_config: _BedrockBearerAuthConfig | _BedrockAwsAuthConfig | None + if auth_override: + next_auth_config = None + elif isinstance(self._bedrock_auth_config, _BedrockAwsAuthConfig) and self.api_key: + next_auth_config = None + elif aws_region is not None: + if isinstance(self._bedrock_auth_config, _BedrockAwsAuthConfig): + next_auth_config = replace( + self._bedrock_auth_config, + region=_resolve_aws_region(aws_region), + region_source="explicit", + ) + else: + next_auth_config = replace(self._bedrock_auth_config, region_source="explicit") + else: + next_auth_config = self._bedrock_auth_config + + next_aws_region = aws_region if aws_region is not None else self.aws_region + if aws_profile is not None and aws_region is None and self._bedrock_auth_config.region_source != "explicit": + next_aws_region = None + + next_api_key = api_key + if next_api_key is None and next_token_provider is None: + next_api_key = ( + None if aws_auth_override or isinstance(next_auth_config, _BedrockAwsAuthConfig) else self.api_key + ) + + blank_base_url_override = isinstance(base_url, str) and not base_url.strip() + next_base_url = None if blank_base_url_override else base_url + next_base_url_is_region_derived = False + recompute_region_base_url = self._uses_region_derived_base_url and ( + aws_region is not None or (aws_profile is not None and next_aws_region is None) ) - next_base_url = base_url - if next_base_url is None and not (aws_region is not None and self._uses_region_derived_base_url): + if blank_base_url_override: + next_base_url_is_region_derived = _uses_region_derived_bedrock_base_url(None) + elif next_base_url is None and not recompute_region_base_url: next_base_url = self.base_url - - return self.__class__( - api_key=next_api_key, - bedrock_token_provider=next_token_provider, - aws_region=aws_region if aws_region is not None else self.aws_region, - aws_profile=aws_profile if aws_profile is not None else self._aws_profile if preserve_aws_auth else None, - aws_access_key_id=( - aws_access_key_id - if aws_access_key_id is not None - else self._aws_access_key_id - if preserve_aws_auth - else None - ), - aws_secret_access_key=( - aws_secret_access_key - if aws_secret_access_key is not None - else self._aws_secret_access_key - if preserve_aws_auth - else None - ), - aws_session_token=( - aws_session_token - if aws_session_token is not None - else self._aws_session_token - if preserve_aws_auth - else None - ), - aws_credentials_provider=( - aws_credentials_provider - if aws_credentials_provider is not None - else self._aws_credentials_provider - if preserve_aws_auth - else None - ), - organization=organization if organization is not None else self.organization, - project=project if project is not None else self.project, - webhook_secret=webhook_secret if webhook_secret is not None else self.webhook_secret, - websocket_base_url=websocket_base_url if websocket_base_url is not None else self.websocket_base_url, - base_url=next_base_url, - timeout=self.timeout if isinstance(timeout, NotGiven) else timeout, - http_client=http_client or self._client, - max_retries=max_retries if is_given(max_retries) else self.max_retries, - default_headers=headers, - default_query=params, - _enforce_credentials=True if _enforce_credentials is None else _enforce_credentials, + next_base_url_is_region_derived = self._uses_region_derived_base_url + elif next_base_url is None and next_aws_region is not None: + next_base_url = f"https://bedrock-mantle.{next_aws_region}.api.aws/openai/v1" + next_base_url_is_region_derived = True + elif next_base_url is None: + next_base_url_is_region_derived = True + + constructor_kwargs: dict[str, Any] = { + "api_key": next_api_key, + "bedrock_token_provider": next_token_provider, + "aws_region": next_aws_region, + "organization": organization if organization is not None else self.organization, + "project": project if project is not None else self.project, + "webhook_secret": webhook_secret if webhook_secret is not None else self.webhook_secret, + "websocket_base_url": websocket_base_url if websocket_base_url is not None else self.websocket_base_url, + "base_url": next_base_url, + "timeout": self.timeout if isinstance(timeout, NotGiven) else timeout, + "http_client": http_client or self._client, + "max_retries": max_retries if is_given(max_retries) else self.max_retries, + "default_headers": headers, + "default_query": params, + "_enforce_credentials": True if _enforce_credentials is None else _enforce_credentials, **_extra_kwargs, + } + aws_overrides = { + "aws_profile": aws_profile, + "aws_access_key_id": aws_access_key_id, + "aws_secret_access_key": aws_secret_access_key, + "aws_session_token": aws_session_token, + "aws_credentials_provider": aws_credentials_provider, + } + constructor_kwargs.update({name: value for name, value in aws_overrides.items() if value is not None}) + + supports_auth_config = _constructor_accepts_keyword(self.__class__.__init__, "_auth_config") + supports_base_url_provenance = _constructor_accepts_keyword( + self.__class__.__init__, "_base_url_is_region_derived" ) + if supports_auth_config: + constructor_kwargs["_auth_config"] = next_auth_config + if supports_base_url_provenance: + constructor_kwargs["_base_url_is_region_derived"] = next_base_url_is_region_derived + + copied = self.__class__(**constructor_kwargs) + if not supports_auth_config and next_auth_config is not None: + copied._bedrock_auth_config = next_auth_config + if isinstance(next_auth_config, _BedrockAwsAuthConfig): + copied._bedrock_aws_auth = _BedrockAwsAuth(next_auth_config) + copied._bedrock_token_provider = None + copied.api_key = "" + copied._api_key_provider = None + copied.aws_region = next_auth_config.region + if not supports_base_url_provenance: + copied._uses_region_derived_base_url = next_base_url_is_region_derived + + return copied with_options = copy diff --git a/tests/fixtures/bedrock_auth/v1/cases.json b/tests/fixtures/bedrock_auth/v1/cases.json new file mode 100644 index 0000000000..a53e68165a --- /dev/null +++ b/tests/fixtures/bedrock_auth/v1/cases.json @@ -0,0 +1,252 @@ +{ + "schema_version": 1, + "suite": "aws-bedrock-auth", + "cases": [ + { + "id": "auth.explicit-bearer", + "kind": "auth_selection", + "given": { + "explicit": { "bearer": "explicit-bearer-token" }, + "environment": { "AWS_BEARER_TOKEN_BEDROCK": "environment-bearer-token" }, + "default_chain_available": true + }, + "expected": { "auth_mode": "bearer", "auth_source": "explicit" } + }, + { + "id": "auth.explicit-aws-over-environment-bearer", + "kind": "auth_selection", + "given": { + "explicit": { + "aws": { + "access_key_id": "AKIDEXAMPLE", + "secret_access_key": "fixture-secret-access-key" + } + }, + "environment": { "AWS_BEARER_TOKEN_BEDROCK": "environment-bearer-token" }, + "default_chain_available": true + }, + "expected": { "auth_mode": "sigv4", "auth_source": "static" } + }, + { + "id": "auth.environment-bearer-over-default-chain", + "kind": "auth_selection", + "given": { + "explicit": {}, + "environment": { "AWS_BEARER_TOKEN_BEDROCK": "environment-bearer-token" }, + "default_chain_available": true + }, + "expected": { "auth_mode": "bearer", "auth_source": "environment" } + }, + { + "id": "auth.default-chain", + "kind": "auth_selection", + "given": { + "explicit": {}, + "environment": {}, + "default_chain_available": true + }, + "expected": { "auth_mode": "sigv4", "auth_source": "default" } + }, + { + "id": "auth.empty-environment-bearer-is-absent", + "kind": "auth_selection", + "given": { + "explicit": {}, + "environment": { "AWS_BEARER_TOKEN_BEDROCK": "" }, + "default_chain_available": true + }, + "expected": { "auth_mode": "sigv4", "auth_source": "default" } + }, + { + "id": "auth.conflicting-explicit-modes", + "kind": "auth_selection", + "given": { + "explicit": { + "bearer": "explicit-bearer-token", + "aws": { + "access_key_id": "AKIDEXAMPLE", + "secret_access_key": "fixture-secret-access-key" + } + }, + "environment": {}, + "default_chain_available": true + }, + "expected": { "error": "bedrock_conflicting_auth" } + }, + { + "id": "sigv4.responses.static-credentials", + "kind": "sigv4", + "given": { + "credentials": { + "access_key_id": "AKIDEXAMPLE", + "secret_access_key": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY" + }, + "signing": { + "service": "bedrock-mantle", + "region": "us-east-1", + "timestamp": "2026-06-01T12:34:56Z" + }, + "request": { + "method": "POST", + "url": "https://bedrock-mantle.us-east-1.api.aws/openai/v1/responses", + "headers": { + "content-length": "47", + "content-type": "application/json", + "host": "bedrock-mantle.us-east-1.api.aws" + }, + "body_base64": "eyJpbnB1dCI6ImhlbGxvIiwibW9kZWwiOiJvcGVuYWkuZ3B0LW9zcy0xMjBiIn0=" + } + }, + "expected": { + "payload_sha256": "9d161ca213ef31dd5b8f01f52294bcd5057bc778cf6ee7719e67b7ff43ef1022", + "canonical_request_sha256": "2941daa7f544cd6d05e5c14615ca3ed4fe206a230214a71971092254690c0f1c", + "headers": { + "x-amz-date": "20260601T123456Z", + "authorization": "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20260601/us-east-1/bedrock-mantle/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=85b44442b454238644a1605b04febb4ecf96d2c5a7698db21ce563c0a5646cb6" + } + } + }, + { + "id": "sigv4.responses.temporary-credentials", + "kind": "sigv4", + "given": { + "credentials": { + "access_key_id": "AKIDEXAMPLE", + "secret_access_key": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + "session_token": "fixture-session-token" + }, + "signing": { + "service": "bedrock-mantle", + "region": "us-east-1", + "timestamp": "2026-06-01T12:34:56Z" + }, + "request": { + "method": "POST", + "url": "https://bedrock-mantle.us-east-1.api.aws/openai/v1/responses", + "headers": { + "content-length": "47", + "content-type": "application/json", + "host": "bedrock-mantle.us-east-1.api.aws" + }, + "body_base64": "eyJpbnB1dCI6ImhlbGxvIiwibW9kZWwiOiJvcGVuYWkuZ3B0LW9zcy0xMjBiIn0=" + } + }, + "expected": { + "payload_sha256": "9d161ca213ef31dd5b8f01f52294bcd5057bc778cf6ee7719e67b7ff43ef1022", + "canonical_request_sha256": "f60762d083c88cc0956eb3e9d3b3a966bfe3f396efd49bc1ef7c8922395bcecb", + "headers": { + "x-amz-date": "20260601T123456Z", + "x-amz-security-token": "fixture-session-token", + "authorization": "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20260601/us-east-1/bedrock-mantle/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-security-token, Signature=d17a0ba1525dfe52b63f163b4a0ac1109723905e5ac9711bcdb6c35befdbaac2" + } + } + }, + { + "id": "sigv4.responses.unsigned-payload", + "kind": "sigv4", + "given": { + "credentials": { + "access_key_id": "AKIDEXAMPLE", + "secret_access_key": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY" + }, + "signing": { + "service": "bedrock-mantle", + "region": "us-east-1", + "timestamp": "2026-06-01T12:34:56Z" + }, + "request": { + "method": "POST", + "url": "https://bedrock-mantle.us-east-1.api.aws/openai/v1/responses", + "headers": { + "content-type": "application/octet-stream", + "host": "bedrock-mantle.us-east-1.api.aws" + }, + "body_mode": "unsigned" + } + }, + "expected": { + "payload_sha256": "UNSIGNED-PAYLOAD", + "canonical_request_sha256": "9d9a089a0d274e69db40db264ea8ec6f9f6afaf03cf779121ec8dfa55278337d", + "headers": { + "x-amz-content-sha256": "UNSIGNED-PAYLOAD", + "x-amz-date": "20260601T123456Z", + "authorization": "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20260601/us-east-1/bedrock-mantle/aws4_request, SignedHeaders=content-type;host;x-amz-content-sha256;x-amz-date, Signature=f0d51439dfe33967eb590c0e33bc26fd7f8103d9345f5ac613ca299e90b555c8" + } + } + }, + { + "id": "retry.fresh-credentials-and-time", + "kind": "retry_signing", + "given": { + "response_statuses": [500, 200], + "timestamps": ["2026-06-01T12:34:56Z", "2026-06-01T12:35:01Z"], + "access_key_ids": ["FIRSTACCESSKEY", "SECONDACCESSKEY"], + "body_base64": "eyJpbnB1dCI6ImhlbGxvIiwibW9kZWwiOiJvcGVuYWkuZ3B0LW9zcy0xMjBiIn0=" + }, + "expected": { + "attempts": 2, + "credential_provider_calls": 2, + "x_amz_dates": ["20260601T123456Z", "20260601T123501Z"], + "body_sha256": [ + "9d161ca213ef31dd5b8f01f52294bcd5057bc778cf6ee7719e67b7ff43ef1022", + "9d161ca213ef31dd5b8f01f52294bcd5057bc778cf6ee7719e67b7ff43ef1022" + ] + } + }, + { + "id": "body.replayable-bytes", + "kind": "body_replay", + "given": { + "body_kind": "bytes", + "body_base64": "eyJpbnB1dCI6ImhlbGxvIiwibW9kZWwiOiJvcGVuYWkuZ3B0LW9zcy0xMjBiIn0=", + "response_statuses": [500, 200] + }, + "expected": { "attempts": 2, "result": "replayed" } + }, + { + "id": "body.non-replayable-stream-with-retries", + "kind": "body_replay", + "given": { + "body_kind": "one_shot_stream", + "chunks_base64": ["Zmlyc3Q=", "c2Vjb25k"], + "max_retries": 1 + }, + "expected": { "network_attempts": 0, "result": "bedrock_non_replayable_body" } + }, + { + "id": "body.unsigned-one-shot-stream", + "kind": "body_replay", + "given": { + "body_kind": "one_shot_stream", + "chunks_base64": ["Zmlyc3Q=", "c2Vjb25k"], + "max_retries": 0 + }, + "expected": { + "attempts": 1, + "credential_provider_calls": 1, + "body_reads": 1, + "x_amz_content_sha256": "UNSIGNED-PAYLOAD", + "result": "unsigned_payload" + } + }, + { + "id": "redaction.sigv4-and-bearer", + "kind": "redaction", + "given": { + "headers": { + "authorization": "AWS4-HMAC-SHA256 fixture-authorization", + "x-amz-security-token": "fixture-session-token", + "x-request-id": "req_fixture" + } + }, + "expected": { + "headers": { + "authorization": "", + "x-amz-security-token": "", + "x-request-id": "req_fixture" + }, + "forbidden_substrings": ["fixture-authorization", "fixture-session-token"] + } + } + ] +} diff --git a/tests/fixtures/bedrock_auth/v1/schema.json b/tests/fixtures/bedrock_auth/v1/schema.json new file mode 100644 index 0000000000..52940a6df5 --- /dev/null +++ b/tests/fixtures/bedrock_auth/v1/schema.json @@ -0,0 +1,78 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://openai.com/sdk-fixtures/bedrock-auth/v1/schema.json", + "title": "OpenAI SDK AWS Bedrock authentication conformance fixtures", + "type": "object", + "additionalProperties": false, + "required": ["schema_version", "suite", "cases"], + "properties": { + "schema_version": { "const": 1 }, + "suite": { "const": "aws-bedrock-auth" }, + "cases": { + "type": "array", + "minItems": 1, + "items": { "$ref": "#/$defs/case" } + } + }, + "$defs": { + "case": { + "type": "object", + "additionalProperties": false, + "required": ["id", "kind", "given", "expected"], + "properties": { + "id": { "type": "string", "minLength": 1 }, + "kind": { + "enum": ["auth_selection", "sigv4", "retry_signing", "body_replay", "redaction"] + }, + "given": { "type": "object" }, + "expected": { "type": "object" } + }, + "allOf": [ + { + "if": { "properties": { "kind": { "const": "auth_selection" } } }, + "then": { + "properties": { + "given": { "required": ["explicit", "environment", "default_chain_available"] } + } + } + }, + { + "if": { "properties": { "kind": { "const": "sigv4" } } }, + "then": { + "properties": { + "given": { "required": ["credentials", "signing", "request"] }, + "expected": { "required": ["payload_sha256", "canonical_request_sha256", "headers"] } + } + } + }, + { + "if": { "properties": { "kind": { "const": "retry_signing" } } }, + "then": { + "properties": { + "given": { "required": ["response_statuses", "timestamps", "access_key_ids", "body_base64"] }, + "expected": { "required": ["attempts", "credential_provider_calls", "x_amz_dates", "body_sha256"] } + } + } + }, + { + "if": { "properties": { "kind": { "const": "body_replay" } } }, + "then": { + "properties": { + "given": { "required": ["body_kind"] }, + "expected": { "required": ["result"] } + } + } + }, + { + "if": { "properties": { "kind": { "const": "redaction" } } }, + "then": { + "properties": { + "given": { "required": ["headers"] }, + "expected": { "required": ["headers", "forbidden_substrings"] } + } + } + } + ] + } + } +} diff --git a/tests/lib/test_bedrock.py b/tests/lib/test_bedrock.py index 656dcb55a0..87ab50b218 100644 --- a/tests/lib/test_bedrock.py +++ b/tests/lib/test_bedrock.py @@ -9,11 +9,12 @@ from httpx import URL from respx import MockRouter -import openai.lib.bedrock as bedrock_module +import openai.lib._bedrock_auth as bedrock_auth_module from openai import OpenAIError, NotFoundError from tests.utils import update_env from openai._types import Omit from openai.lib.bedrock import BedrockOpenAI, AsyncBedrockOpenAI +from openai.lib._bedrock_auth import BedrockAwsAuthConfig Client = Union[BedrockOpenAI, AsyncBedrockOpenAI] @@ -134,14 +135,14 @@ def test_bedrock_config_precedence(client_cls: type[Client]) -> None: @pytest.mark.respx() def test_env_bearer_does_not_require_botocore(monkeypatch: pytest.MonkeyPatch, respx_mock: MockRouter) -> None: - real_import_module = bedrock_module.importlib.import_module + real_import_module = bedrock_auth_module.importlib.import_module def import_module(name: str) -> Any: if name.startswith("botocore"): raise ImportError(name) return real_import_module(name) - monkeypatch.setattr(bedrock_module.importlib, "import_module", import_module) + monkeypatch.setattr(bedrock_auth_module.importlib, "import_module", import_module) respx_mock.post("https://example.com/openai/v1/responses").mock( return_value=httpx.Response(200, json=RESPONSE_BODY) ) @@ -158,22 +159,22 @@ def import_module(name: str) -> Any: def test_empty_env_bearer_without_botocore_uses_aws_credentials(monkeypatch: pytest.MonkeyPatch) -> None: - real_import_module = bedrock_module.importlib.import_module + real_import_module = bedrock_auth_module.importlib.import_module def import_module(name: str) -> Any: if name.startswith("botocore"): raise ImportError(name) return real_import_module(name) - monkeypatch.setattr(bedrock_module.importlib, "import_module", import_module) + monkeypatch.setattr(bedrock_auth_module.importlib, "import_module", import_module) with update_env(AWS_BEARER_TOKEN_BEDROCK="", AWS_REGION="us-east-1"): - with pytest.raises(OpenAIError, match="requires botocore"): + with pytest.raises(OpenAIError, match="requires optional AWS dependencies"): BedrockOpenAI() @pytest.mark.respx() -def test_env_bearer_uses_botocore_bearer_auth(monkeypatch: pytest.MonkeyPatch, respx_mock: MockRouter) -> None: - auth_module = bedrock_module.importlib.import_module("botocore.auth") +def test_env_bearer_does_not_use_botocore_bearer_auth(monkeypatch: pytest.MonkeyPatch, respx_mock: MockRouter) -> None: + auth_module = bedrock_auth_module.importlib.import_module("botocore.auth") calls = 0 real_add_auth = auth_module.BearerAuth.add_auth @@ -193,7 +194,7 @@ def add_auth(auth: object, request: object) -> None: request = cast("list[MockRequestCall]", respx_mock.calls)[0].request assert request.headers["Authorization"] == "Bearer env token" - assert calls == 1 + assert calls == 0 @pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) @@ -227,6 +228,26 @@ def test_bedrock_region_precedence(client_cls: type[Client]) -> None: assert default_region_client.base_url == URL("https://bedrock-mantle.us-west-2.api.aws/openai/v1/") +@pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) +def test_aws_profile_supplies_region(client_cls: type[Client], tmp_path: Path) -> None: + config_path = tmp_path / "config" + config_path.write_text("[profile production]\nregion = eu-central-1\n") + with update_env( + AWS_CONFIG_FILE=str(config_path), + AWS_BEDROCK_BASE_URL=Omit(), + AWS_REGION=Omit(), + AWS_DEFAULT_REGION=Omit(), + ): + client = ( + make_sync_client(aws_profile="production") + if client_cls is BedrockOpenAI + else make_async_client(aws_profile="production") + ) + + assert client.aws_region == "eu-central-1" + assert client.base_url == URL("https://bedrock-mantle.eu-central-1.api.aws/openai/v1/") + + @pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) def test_normalizes_responses_url(client_cls: type[Client]) -> None: client = ( @@ -241,7 +262,7 @@ def test_normalizes_responses_url(client_cls: type[Client]) -> None: @pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) def test_requires_endpoint_configuration(client_cls: type[Client]) -> None: with update_env(AWS_BEDROCK_BASE_URL=Omit(), AWS_REGION=Omit(), AWS_DEFAULT_REGION=Omit()): - with pytest.raises(OpenAIError, match="Must provide one of the `base_url` or `aws_region`"): + with pytest.raises(OpenAIError, match="Bedrock requires an AWS region"): client_cls(api_key="token") @@ -251,8 +272,9 @@ def test_does_not_use_openai_api_key(client_cls: type[Client]) -> None: OPENAI_API_KEY="openai token", AWS_BEARER_TOKEN_BEDROCK=Omit(), AWS_BEDROCK_BASE_URL="https://example.com/openai/v1", + AWS_REGION="us-east-1", ): - client = client_cls() + client = make_sync_client() if client_cls is BedrockOpenAI else make_async_client() assert client.api_key == "" assert client._bedrock_aws_auth is not None @@ -260,7 +282,7 @@ def test_does_not_use_openai_api_key(client_cls: type[Client]) -> None: @pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) def test_rejects_static_token_and_provider(client_cls: type[Client]) -> None: - with pytest.raises(OpenAIError, match="mutually exclusive"): + with pytest.raises(OpenAIError, match="authentication is ambiguous"): client_cls( base_url="https://example.com/openai/v1", api_key="token", @@ -276,7 +298,7 @@ def test_rejects_empty_explicit_bearer_token(client_cls: type[Client]) -> None: @pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) def test_rejects_bearer_and_aws_credentials(client_cls: type[Client]) -> None: - with pytest.raises(OpenAIError, match="mutually exclusive"): + with pytest.raises(OpenAIError, match="authentication is ambiguous"): client_cls( base_url="https://example.com/openai/v1", api_key="token", @@ -287,7 +309,7 @@ def test_rejects_bearer_and_aws_credentials(client_cls: type[Client]) -> None: @pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) def test_rejects_partial_explicit_aws_credentials(client_cls: type[Client]) -> None: - with pytest.raises(OpenAIError, match="must be provided together"): + with pytest.raises(OpenAIError, match="require both"): client_cls( base_url="https://example.com/openai/v1", aws_region="us-east-1", @@ -428,34 +450,197 @@ def test_preserves_aws_credentials_across_with_options() -> None: copied_client = client.with_options(timeout=1) assert copied_client._bedrock_aws_auth is not None - assert copied_client._aws_access_key_id == "access key" + assert isinstance(copied_client._bedrock_auth_config, BedrockAwsAuthConfig) + assert copied_client._bedrock_auth_config.access_key_id == "access key" + + +@pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) +def test_preserves_default_chain_mode_across_with_options(client_cls: type[Client]) -> None: + with update_env(AWS_BEARER_TOKEN_BEDROCK=Omit(), AWS_REGION="us-east-1"): + client = make_sync_client() if client_cls is BedrockOpenAI else make_async_client() + + with update_env(AWS_BEARER_TOKEN_BEDROCK="late bearer", AWS_REGION="us-east-1"): + copied_client = client.with_options(timeout=1) + + assert isinstance(copied_client._bedrock_auth_config, BedrockAwsAuthConfig) + assert copied_client._bedrock_auth_config.source == "default" + assert copied_client.api_key == "" + + +@pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) +def test_preserves_region_derived_url_provenance_across_multiple_copies(client_cls: type[Client]) -> None: + with update_env(AWS_BEDROCK_BASE_URL=Omit(), AWS_REGION=Omit(), AWS_DEFAULT_REGION=Omit()): + client = ( + make_sync_client(aws_region="us-east-1", api_key="token") + if client_cls is BedrockOpenAI + else make_async_client(aws_region="us-east-1", api_key="token") + ) + copied_client = client.with_options(timeout=1).with_options(aws_region="eu-west-1") + + assert copied_client.base_url == URL("https://bedrock-mantle.eu-west-1.api.aws/openai/v1/") + + +@pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) +def test_blank_base_url_restores_region_derived_url_provenance(client_cls: type[Client]) -> None: + with update_env(AWS_BEDROCK_BASE_URL=Omit(), AWS_REGION=Omit(), AWS_DEFAULT_REGION=Omit()): + client = ( + make_sync_client(aws_region="us-east-1", api_key="token") + if client_cls is BedrockOpenAI + else make_async_client(aws_region="us-east-1", api_key="token") + ) + copied_client = client.with_options(base_url="").with_options(aws_region="eu-west-1") + + assert copied_client.base_url == URL("https://bedrock-mantle.eu-west-1.api.aws/openai/v1/") @pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) def test_with_options_replaces_the_aws_credential_source(client_cls: type[Client], tmp_path: Path) -> None: config_path = tmp_path / "config" config_path.write_text("[profile other-profile]\nregion = us-east-1\n") - explicit_credentials_client = client_cls( - base_url="https://example.com/openai/v1", - aws_region="us-east-1", - aws_access_key_id="access key", - aws_secret_access_key="secret key", + explicit_credentials_client = ( + make_sync_client( + base_url="https://example.com/openai/v1", + aws_region="us-east-1", + aws_access_key_id="access key", + aws_secret_access_key="secret key", + ) + if client_cls is BedrockOpenAI + else make_async_client( + base_url="https://example.com/openai/v1", + aws_region="us-east-1", + aws_access_key_id="access key", + aws_secret_access_key="secret key", + ) ) with update_env(AWS_CONFIG_FILE=str(config_path)): profile_client = explicit_credentials_client.with_options(aws_profile="other-profile") - assert profile_client._aws_profile == "other-profile" - assert profile_client._aws_access_key_id is None - assert profile_client._aws_secret_access_key is None + assert isinstance(profile_client._bedrock_auth_config, BedrockAwsAuthConfig) + assert profile_client._bedrock_auth_config.profile == "other-profile" + assert profile_client._bedrock_auth_config.access_key_id is None + assert profile_client._bedrock_auth_config.secret_access_key is None explicit_credentials_client = profile_client.with_options( aws_access_key_id="replacement access key", aws_secret_access_key="replacement secret key", ) - assert explicit_credentials_client._aws_profile is None - assert explicit_credentials_client._aws_access_key_id == "replacement access key" - assert explicit_credentials_client._aws_secret_access_key == "replacement secret key" + assert isinstance(explicit_credentials_client._bedrock_auth_config, BedrockAwsAuthConfig) + assert explicit_credentials_client._bedrock_auth_config.profile is None + assert explicit_credentials_client._bedrock_auth_config.access_key_id == "replacement access key" + assert explicit_credentials_client._bedrock_auth_config.secret_access_key == "replacement secret key" + + +@pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) +def test_with_options_replacing_profile_re_resolves_profile_region(client_cls: type[Client], tmp_path: Path) -> None: + config_path = tmp_path / "config" + config_path.write_text("[profile east]\nregion = us-east-1\n[profile west]\nregion = us-west-2\n") + + with update_env( + AWS_CONFIG_FILE=str(config_path), + AWS_BEARER_TOKEN_BEDROCK=Omit(), + AWS_BEDROCK_BASE_URL=Omit(), + AWS_REGION=Omit(), + AWS_DEFAULT_REGION=Omit(), + ): + client = ( + make_sync_client(aws_profile="east") + if client_cls is BedrockOpenAI + else make_async_client(aws_profile="east") + ) + + with update_env( + AWS_CONFIG_FILE=str(config_path), + AWS_BEARER_TOKEN_BEDROCK=Omit(), + AWS_BEDROCK_BASE_URL="https://late-environment.example.com/v1", + AWS_REGION=Omit(), + AWS_DEFAULT_REGION=Omit(), + ): + copied_client = client.with_options(aws_profile="west") + + assert copied_client.aws_region == "us-west-2" + assert copied_client.base_url == URL("https://bedrock-mantle.us-west-2.api.aws/openai/v1/") + assert isinstance(copied_client._bedrock_auth_config, BedrockAwsAuthConfig) + assert copied_client._bedrock_auth_config.profile == "west" + + +@pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) +def test_with_options_switching_from_bearer_to_profile_re_resolves_environment_region( + client_cls: type[Client], tmp_path: Path +) -> None: + config_path = tmp_path / "config" + config_path.write_text("[profile west]\nregion = us-west-2\n") + + with update_env( + AWS_CONFIG_FILE=str(config_path), + AWS_BEDROCK_BASE_URL=Omit(), + AWS_REGION="us-east-1", + AWS_DEFAULT_REGION=Omit(), + ): + client = ( + make_sync_client(api_key="token") if client_cls is BedrockOpenAI else make_async_client(api_key="token") + ) + + with update_env( + AWS_CONFIG_FILE=str(config_path), + AWS_BEARER_TOKEN_BEDROCK=Omit(), + AWS_BEDROCK_BASE_URL=Omit(), + AWS_REGION=Omit(), + AWS_DEFAULT_REGION=Omit(), + ): + copied_client = client.with_options(aws_profile="west") + + assert copied_client.aws_region == "us-west-2" + assert copied_client.base_url == URL("https://bedrock-mantle.us-west-2.api.aws/openai/v1/") + + +def test_with_options_supports_subclasses_with_the_previous_constructor_signature() -> None: + class LegacyBedrockOpenAI(BedrockOpenAI): + def __init__( + self, + *, + api_key: str | None = None, + bedrock_token_provider: Any = None, + aws_region: str | None = None, + organization: str | None = None, + project: str | None = None, + webhook_secret: str | None = None, + base_url: str | httpx.URL | None = None, + websocket_base_url: str | httpx.URL | None = None, + timeout: Any = None, + max_retries: int = 2, + default_headers: Any = None, + default_query: Any = None, + http_client: httpx.Client | None = None, + _enforce_credentials: bool = True, + ) -> None: + super().__init__( + api_key=api_key, + bedrock_token_provider=bedrock_token_provider, + aws_region=aws_region, + organization=organization, + project=project, + webhook_secret=webhook_secret, + base_url=base_url, + websocket_base_url=websocket_base_url, + timeout=timeout, + max_retries=max_retries, + default_headers=default_headers, + default_query=default_query, + http_client=http_client, + _enforce_credentials=_enforce_credentials, + ) + + client = LegacyBedrockOpenAI( + api_key="token", + aws_region="us-east-1", + http_client=httpx.Client(trust_env=False), + ) + + copied_client = client.with_options(timeout=1) + + assert isinstance(copied_client, LegacyBedrockOpenAI) + assert copied_client.api_key == "token" @pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) @@ -478,6 +663,22 @@ def test_with_options_api_key_replaces_token_provider(client_cls: type[Client]) assert copied_client._bedrock_token_provider is None +@pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) +def test_with_options_rejects_explicit_bearer_provider_and_aws_credentials(client_cls: type[Client]) -> None: + client = ( + make_sync_client(base_url="https://example.com/openai/v1", api_key="token") + if client_cls is BedrockOpenAI + else make_async_client(base_url="https://example.com/openai/v1", api_key="token") + ) + + with pytest.raises(OpenAIError, match="authentication is ambiguous"): + client.with_options( + bedrock_token_provider=lambda: "provider token", + aws_access_key_id="access key", + aws_secret_access_key="secret key", + ) + + @pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) def test_with_options_aws_region_recomputes_region_derived_base_url(client_cls: type[Client]) -> None: with update_env(AWS_BEDROCK_BASE_URL=Omit(), AWS_REGION=Omit(), AWS_DEFAULT_REGION=Omit()): diff --git a/tests/lib/test_bedrock_auth_conformance.py b/tests/lib/test_bedrock_auth_conformance.py new file mode 100644 index 0000000000..b08f160c35 --- /dev/null +++ b/tests/lib/test_bedrock_auth_conformance.py @@ -0,0 +1,403 @@ +from __future__ import annotations + +import json +import base64 +import hashlib +import logging +import threading +from typing import Any, Iterator, AsyncIterator, cast +from pathlib import Path +from datetime import datetime + +import httpx +import pytest +import jsonschema + +from openai import OpenAIError, APIStatusError +from openai._types import Omit +from openai._utils import SensitiveHeadersFilter +from openai.lib.bedrock import BedrockOpenAI, AsyncBedrockOpenAI +from openai.lib._bedrock_auth import BedrockAwsAuth, BedrockAwsAuthConfig, BedrockBearerAuthConfig + +FIXTURE_PATH = Path(__file__).parents[1] / "fixtures" / "bedrock_auth" / "v1" / "cases.json" +SCHEMA_PATH = FIXTURE_PATH.with_name("schema.json") +FIXTURES = cast(dict[str, Any], json.loads(FIXTURE_PATH.read_text())) +SCHEMA = cast(dict[str, Any], json.loads(SCHEMA_PATH.read_text())) + + +def _cases(kind: str) -> list[dict[str, Any]]: + return [cast(dict[str, Any], case) for case in FIXTURES["cases"] if case["kind"] == kind] + + +def _fixed_datetime(value: str) -> datetime: + return datetime.fromisoformat(value.replace("Z", "+00:00")) + + +def _lower_headers(headers: httpx.Headers | dict[str, str]) -> dict[str, str]: + return {name.lower(): value for name, value in headers.items()} + + +def _canonical_request_sha256(case: dict[str, Any], signed_headers: dict[str, str], payload_hash: str) -> str: + request = case["given"]["request"] + authorization = signed_headers["authorization"] + signed_header_names = authorization.split("SignedHeaders=", 1)[1].split(",", 1)[0].split(";") + canonical_headers = "".join(f"{name}:{' '.join(signed_headers[name].split())}\n" for name in signed_header_names) + url = httpx.URL(request["url"]) + canonical_request = "\n".join( + ( + request["method"], + url.raw_path.split(b"?", 1)[0].decode(), + url.query.decode(), + canonical_headers, + ";".join(signed_header_names), + payload_hash, + ) + ) + return hashlib.sha256(canonical_request.encode()).hexdigest() + + +@pytest.mark.parametrize("case", _cases("auth_selection"), ids=lambda case: case["id"]) +def test_auth_selection_fixture(case: dict[str, Any], monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("AWS_BEARER_TOKEN_BEDROCK", raising=False) + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + for name, value in case["given"]["environment"].items(): + monkeypatch.setenv(name, value) + + explicit = case["given"]["explicit"] + kwargs: dict[str, Any] = { + "aws_region": "us-east-1", + "http_client": httpx.Client(trust_env=False), + "_enforce_credentials": False, + } + if "bearer" in explicit: + kwargs["api_key"] = explicit["bearer"] + if "aws" in explicit: + kwargs["aws_access_key_id"] = explicit["aws"]["access_key_id"] + kwargs["aws_secret_access_key"] = explicit["aws"]["secret_access_key"] + + if case["expected"].get("error") == "bedrock_conflicting_auth": + with pytest.raises(OpenAIError, match="authentication is ambiguous"): + BedrockOpenAI(**kwargs) + kwargs["http_client"].close() + return + + with BedrockOpenAI(**kwargs) as client: + config = client._bedrock_auth_config + if isinstance(config, BedrockBearerAuthConfig): + mode = "bearer" + else: + mode = "sigv4" + + assert mode == case["expected"]["auth_mode"] + assert config.source == case["expected"]["auth_source"] + + +@pytest.mark.parametrize("case", _cases("sigv4"), ids=lambda case: case["id"]) +def test_sigv4_fixture(case: dict[str, Any], monkeypatch: pytest.MonkeyPatch) -> None: + credentials = case["given"]["credentials"] + signing = case["given"]["signing"] + request = case["given"]["request"] + body = None if request.get("body_mode") == "unsigned" else base64.b64decode(request["body_base64"]) + payload_hash = "UNSIGNED-PAYLOAD" if body is None else hashlib.sha256(body).hexdigest() + + auth = BedrockAwsAuth( + BedrockAwsAuthConfig( + region=signing["region"], + source="static", + access_key_id=credentials["access_key_id"], + secret_access_key=credentials["secret_access_key"], + session_token=credentials.get("session_token"), + ) + ) + botocore_auth = pytest.importorskip("botocore.auth") + monkeypatch.setattr(botocore_auth, "get_current_datetime", lambda: _fixed_datetime(signing["timestamp"])) + + signed_headers = _lower_headers( + auth.sign( + method=request["method"], + url=request["url"], + headers=request["headers"], + body=body, + ) + ) + + assert signing["service"] == "bedrock-mantle" + assert payload_hash == case["expected"]["payload_sha256"] + assert _canonical_request_sha256(case, signed_headers, payload_hash) == case["expected"]["canonical_request_sha256"] + for name, value in case["expected"]["headers"].items(): + assert signed_headers[name] == value + + +class _Credentials: + def __init__(self, access_key: str, secret_key: str, token: str | None = None) -> None: + self.access_key = access_key + self.secret_key = secret_key + self.token = token + + +def test_retry_signing_fixture(monkeypatch: pytest.MonkeyPatch) -> None: + case = _cases("retry_signing")[0] + timestamps = iter(_fixed_datetime(value) for value in case["given"]["timestamps"]) + credentials = iter( + _Credentials(access_key, f"{access_key}-secret") for access_key in case["given"]["access_key_ids"] + ) + provider_calls = 0 + + def credentials_provider() -> _Credentials: + nonlocal provider_calls + provider_calls += 1 + return next(credentials) + + botocore_auth = pytest.importorskip("botocore.auth") + monkeypatch.setattr(botocore_auth, "get_current_datetime", lambda: next(timestamps)) + + requests: list[httpx.Request] = [] + statuses = iter(case["given"]["response_statuses"]) + + def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + return httpx.Response(next(statuses), request=request, json={}) + + body = base64.b64decode(case["given"]["body_base64"]) + with BedrockOpenAI( + base_url="https://bedrock-mantle.us-east-1.api.aws/openai/v1", + aws_region="us-east-1", + aws_credentials_provider=credentials_provider, + max_retries=case["given"].get("max_retries", 1), + http_client=httpx.Client(transport=httpx.MockTransport(handler), trust_env=False), + ) as client: + client.post( + "/responses", + content=body, + cast_to=httpx.Response, + options={"headers": {"Content-Type": "application/json"}}, + ) + + assert len(requests) == case["expected"]["attempts"] + assert provider_calls == case["expected"]["credential_provider_calls"] + assert [request.headers["X-Amz-Date"] for request in requests] == case["expected"]["x_amz_dates"] + assert [hashlib.sha256(request.content).hexdigest() for request in requests] == case["expected"]["body_sha256"] + for request, access_key in zip(requests, case["given"]["access_key_ids"]): + assert f"Credential={access_key}/" in request.headers["Authorization"] + + +@pytest.mark.parametrize("case", _cases("body_replay"), ids=lambda case: case["id"]) +def test_body_replay_fixture(case: dict[str, Any]) -> None: + provider_calls = 0 + network_calls = 0 + body_reads = 0 + requests: list[httpx.Request] = [] + + def credentials_provider() -> _Credentials: + nonlocal provider_calls + provider_calls += 1 + return _Credentials("access-key", "secret-key") + + def body() -> Iterator[bytes]: + nonlocal body_reads + body_reads += 1 + for chunk in case["given"].get("chunks_base64", []): + yield base64.b64decode(chunk) + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal network_calls + network_calls += 1 + requests.append(request) + statuses = case["given"].get("response_statuses", [200]) + return httpx.Response(statuses[network_calls - 1], request=request, json={}) + + body_kind = case["given"]["body_kind"] + content: bytes | Iterator[bytes] + if body_kind == "bytes": + content = base64.b64decode(case["given"]["body_base64"]) + else: + assert body_kind == "one_shot_stream" + content = body() + + with BedrockOpenAI( + base_url="https://bedrock-mantle.us-east-1.api.aws/openai/v1", + aws_region="us-east-1", + aws_credentials_provider=credentials_provider, + max_retries=case["given"].get("max_retries", 1), + http_client=httpx.Client(transport=httpx.MockTransport(handler), trust_env=False), + ) as client: + if case["expected"]["result"] == "bedrock_non_replayable_body": + with pytest.raises(OpenAIError, match="requires a replayable request body"): + client.post("/responses", content=content, cast_to=httpx.Response) + else: + client.post("/responses", content=content, cast_to=httpx.Response) + + assert network_calls == case["expected"].get("network_attempts", case["expected"].get("attempts")) + if body_kind == "bytes": + assert all(request.content == content for request in requests) + assert provider_calls == case["expected"]["attempts"] + elif case["expected"]["result"] == "unsigned_payload": + assert body_reads == case["expected"]["body_reads"] + assert provider_calls == case["expected"]["credential_provider_calls"] + assert requests[0].headers["X-Amz-Content-SHA256"] == case["expected"]["x_amz_content_sha256"] + assert "x-amz-content-sha256" in requests[0].headers["Authorization"] + else: + assert (body_reads, provider_calls) == (0, 0) + + +@pytest.mark.asyncio +async def test_non_replayable_async_body_fails_before_credentials_or_network() -> None: + provider_calls = 0 + network_calls = 0 + body_reads = 0 + + def credentials_provider() -> _Credentials: + nonlocal provider_calls + provider_calls += 1 + return _Credentials("access-key", "secret-key") + + async def body() -> AsyncIterator[bytes]: + nonlocal body_reads + body_reads += 1 + yield b"body" + + async def handler(request: httpx.Request) -> httpx.Response: + nonlocal network_calls + network_calls += 1 + return httpx.Response(200, request=request) + + async with AsyncBedrockOpenAI( + base_url="https://bedrock-mantle.us-east-1.api.aws/openai/v1", + aws_region="us-east-1", + aws_credentials_provider=credentials_provider, + http_client=httpx.AsyncClient(transport=httpx.MockTransport(handler), trust_env=False), + ) as client: + with pytest.raises(OpenAIError, match="requires a replayable request body"): + await client.post("/responses", content=body(), cast_to=httpx.Response) + + assert (body_reads, provider_calls, network_calls) == (0, 0, 0) + + +@pytest.mark.asyncio +async def test_async_one_shot_body_uses_unsigned_payload_when_retries_are_disabled() -> None: + case = next(case for case in _cases("body_replay") if case["expected"]["result"] == "unsigned_payload") + requests: list[httpx.Request] = [] + + async def body() -> AsyncIterator[bytes]: + for chunk in case["given"]["chunks_base64"]: + yield base64.b64decode(chunk) + + async def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + return httpx.Response(200, request=request, json={}) + + async with AsyncBedrockOpenAI( + base_url="https://bedrock-mantle.us-east-1.api.aws/openai/v1", + aws_region="us-east-1", + aws_access_key_id="access-key", + aws_secret_access_key="secret-key", + max_retries=0, + http_client=httpx.AsyncClient(transport=httpx.MockTransport(handler), trust_env=False), + ) as client: + await client.post("/responses", content=body(), cast_to=httpx.Response) + + assert requests[0].headers["X-Amz-Content-SHA256"] == case["expected"]["x_amz_content_sha256"] + assert "x-amz-content-sha256" in requests[0].headers["Authorization"] + + +@pytest.mark.asyncio +async def test_async_credentials_are_resolved_off_event_loop() -> None: + event_loop_thread = threading.get_ident() + provider_threads: list[int] = [] + + def credentials_provider() -> _Credentials: + provider_threads.append(threading.get_ident()) + return _Credentials("access-key", "secret-key") + + async def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, request=request, json={}) + + async with AsyncBedrockOpenAI( + base_url="https://bedrock-mantle.us-east-1.api.aws/openai/v1", + aws_region="us-east-1", + aws_credentials_provider=credentials_provider, + http_client=httpx.AsyncClient(transport=httpx.MockTransport(handler), trust_env=False), + ) as client: + await client.post("/responses", content=b"{}", cast_to=httpx.Response) + + assert provider_threads + assert all(thread_id != event_loop_thread for thread_id in provider_threads) + + +def test_explicit_authorization_omit_and_override_are_preserved() -> None: + requests: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + return httpx.Response(200, request=request, json={}) + + with BedrockOpenAI( + base_url="https://bedrock-mantle.us-east-1.api.aws/openai/v1", + aws_region="us-east-1", + aws_access_key_id="access-key", + aws_secret_access_key="secret-key", + http_client=httpx.Client( + headers={"Authorization": "Bearer client-default"}, + auth=httpx.BasicAuth("username", "password"), + transport=httpx.MockTransport(handler), + trust_env=False, + ), + ) as client: + client.get( + "/models", + cast_to=httpx.Response, + options={"headers": {"Authorization": Omit()}}, + ) + client.get( + "/models", + cast_to=httpx.Response, + options={"headers": {"Authorization": "Bearer explicit-override"}}, + ) + + assert "Authorization" not in requests[0].headers + assert requests[1].headers["Authorization"] == "Bearer explicit-override" + + +def test_sigv4_redirects_are_not_followed() -> None: + requests: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + if len(requests) == 1: + return httpx.Response(307, request=request, headers={"Location": "/redirected"}) + return httpx.Response(200, request=request) + + with BedrockOpenAI( + base_url="https://bedrock-mantle.us-east-1.api.aws/openai/v1", + aws_region="us-east-1", + aws_access_key_id="access-key", + aws_secret_access_key="secret-key", + http_client=httpx.Client(transport=httpx.MockTransport(handler), trust_env=False), + ) as client: + with pytest.raises(APIStatusError) as exc: + client.get("/models", cast_to=httpx.Response) + + assert exc.value.status_code == 307 + assert len(requests) == 1 + + +def test_redaction_fixture(caplog: pytest.LogCaptureFixture) -> None: + case = _cases("redaction")[0] + logger = logging.getLogger("test_bedrock_redaction") + logger.addFilter(SensitiveHeadersFilter()) + + with caplog.at_level(logging.DEBUG): + logger.debug("Request options: %s", {"headers": case["given"]["headers"]}) + + headers = cast(dict[str, Any], caplog.records[0].args)["headers"] + assert headers == case["expected"]["headers"] + for value in case["expected"]["forbidden_substrings"]: + assert value not in caplog.messages[0] + + +def test_fixture_envelope_is_versioned_and_ids_are_unique() -> None: + jsonschema.Draft202012Validator(SCHEMA).validate(FIXTURES) # pyright: ignore[reportUnknownMemberType] + assert FIXTURES["schema_version"] == 1 + assert FIXTURES["suite"] == "aws-bedrock-auth" + ids = [case["id"] for case in FIXTURES["cases"]] + assert len(ids) == len(set(ids)) diff --git a/tests/test_module_client.py b/tests/test_module_client.py index 23bcb61716..cb509d3d19 100644 --- a/tests/test_module_client.py +++ b/tests/test_module_client.py @@ -253,6 +253,30 @@ def test_bedrock_module_api_key_overrides_cached_env_token_after_load() -> None: assert client.api_key == "new Bedrock token" +def test_bedrock_module_api_key_switches_cached_aws_client_to_bearer() -> None: + requests: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + return httpx.Response(200, request=request, json={}) + + with fresh_env(): + openai.api_type = "amazon-bedrock" + openai.http_client = httpx.Client(transport=httpx.MockTransport(handler), trust_env=False) + _os.environ["AWS_ACCESS_KEY_ID"] = "access key" + _os.environ["AWS_SECRET_ACCESS_KEY"] = "secret key" + _os.environ["AWS_REGION"] = "us-west-2" + + client = openai.responses._client + assert isinstance(client, BedrockOpenAI) + assert client._uses_aws_auth() + + openai.api_key = "new Bedrock token" + client.get("/models", cast_to=httpx.Response) + + assert requests[0].headers["Authorization"] == "Bearer new Bedrock token" + + def test_bedrock_api_type_uses_token_provider_without_mutating_module_api_key() -> None: with fresh_env(): openai.api_type = "amazon-bedrock" @@ -263,3 +287,32 @@ def test_bedrock_api_type_uses_token_provider_without_mutating_module_api_key() assert isinstance(client, BedrockOpenAI) assert client._refresh_api_key() == "provider Bedrock token" assert openai.api_key is None + + +def test_bedrock_module_api_key_overrides_cached_token_provider() -> None: + requests: list[httpx.Request] = [] + provider_calls = 0 + + def token_provider() -> str: + nonlocal provider_calls + provider_calls += 1 + raise AssertionError("the replaced token provider must not be called") + + def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + return httpx.Response(200, request=request, json={}) + + with fresh_env(): + openai.api_type = "amazon-bedrock" + openai.bedrock_token_provider = token_provider + openai.http_client = httpx.Client(transport=httpx.MockTransport(handler), trust_env=False) + _os.environ["AWS_REGION"] = "us-west-2" + + client = openai.responses._client + assert isinstance(client, BedrockOpenAI) + + openai.api_key = "new Bedrock token" + client.get("/models", cast_to=httpx.Response) + + assert provider_calls == 0 + assert requests[0].headers["Authorization"] == "Bearer new Bedrock token" From 64230fd07d644aa1bd19c7a7d46af7678024de82 Mon Sep 17 00:00:00 2001 From: Hayden Date: Fri, 12 Jun 2026 11:18:15 -0700 Subject: [PATCH 3/9] Add Bedrock provider authentication --- README.md | 48 +- examples/bedrock.py | 12 +- src/openai/__init__.py | 6 + src/openai/_client.py | 246 +++- src/openai/_provider.py | 59 + src/openai/lib/_bedrock_auth.py | 80 +- src/openai/lib/azure.py | 9 + src/openai/lib/bedrock.py | 1188 ++++++++------------ src/openai/providers/__init__.py | 3 + src/openai/providers/bedrock.py | 404 +++++++ tests/fixtures/bedrock/v1/sigv4.json | 22 + tests/fixtures/bedrock_auth/v1/cases.json | 151 +-- tests/lib/test_bedrock.py | 305 ++++- tests/lib/test_bedrock_auth_conformance.py | 183 +-- tests/lib/test_bedrock_provider.py | 337 ++++++ 15 files changed, 2077 insertions(+), 976 deletions(-) create mode 100644 src/openai/_provider.py create mode 100644 src/openai/providers/__init__.py create mode 100644 src/openai/providers/bedrock.py create mode 100644 tests/fixtures/bedrock/v1/sigv4.json create mode 100644 tests/lib/test_bedrock_provider.py diff --git a/README.md b/README.md index feb1dd0715..df46c6b690 100644 --- a/README.md +++ b/README.md @@ -928,7 +928,7 @@ An example of using the client with Microsoft Entra ID (formerly known as Azure ## Amazon Bedrock -To use this library with [Amazon Bedrock's OpenAI-compatible API](https://docs.aws.amazon.com/bedrock/latest/userguide/models-api-compatibility.html), use the `BedrockOpenAI` class instead of the `OpenAI` class. +To use this library with [Amazon Bedrock's OpenAI-compatible API](https://docs.aws.amazon.com/bedrock/latest/userguide/models-api-compatibility.html), configure the standard `OpenAI` client with the Bedrock provider. Install the optional Bedrock dependencies to use the standard AWS credential chain and SigV4 authentication: @@ -937,11 +937,16 @@ pip install 'openai[bedrock]' ``` ```py -from openai import BedrockOpenAI +from openai import OpenAI +from openai.providers import bedrock -# Uses your normal AWS credentials. You can omit aws_region when it is +# Uses your normal AWS credentials. You can omit region when it is # configured through AWS_REGION, AWS_DEFAULT_REGION, or your AWS profile. -client = BedrockOpenAI(aws_region="us-west-2") +client = OpenAI( + provider=bedrock( + region="us-west-2", + ) +) response = client.responses.create( model="openai.gpt-5.4", @@ -951,32 +956,51 @@ response = client.responses.create( print(response.output_text) ``` -`BedrockOpenAI` configures AWS authentication and the Bedrock Mantle endpoint, then uses the normal SDK resources. AWS controls which endpoints and features are supported; unsupported calls surface the provider's normal HTTP errors through the SDK. +The provider configures AWS authentication and the Bedrock Mantle endpoint while retaining the normal SDK resources, retries, streaming, and error handling. AWS controls which endpoints and features are supported; unsupported calls surface the provider's normal HTTP errors through the SDK. The default AWS credential chain supports environment credentials, shared credentials and config files, named profiles, SSO and assume-role profiles, and workload credentials such as ECS, EKS, and EC2 metadata. To select a named profile: ```py -client = BedrockOpenAI( - aws_profile="my-profile", +client = OpenAI( + provider=bedrock( + profile="my-profile", + ) ) ``` -You can also pass explicit temporary credentials or an `aws_credentials_provider` that returns botocore-compatible credentials. Explicit bearer and AWS credential options are mutually exclusive. +You can also pass `access_key_id` and `secret_access_key`, with an optional `session_token`, or a refreshable `credential_provider` that returns botocore-compatible credentials. Explicit bearer and AWS credential options are mutually exclusive. -Pass `base_url` or set `AWS_BEDROCK_BASE_URL` to override the derived `https://bedrock-mantle..api.aws/openai/v1` endpoint. The legacy module client supports `openai.api_type = "amazon-bedrock"` or `OPENAI_API_TYPE=amazon-bedrock`. +Pass `base_url` to `bedrock(...)` or set `AWS_BEDROCK_BASE_URL` to override the derived `https://bedrock-mantle..api.aws/openai/v1` endpoint. -Normal SDK requests use replayable, fully signed bodies. Low-level one-shot request streams are signed with `UNSIGNED-PAYLOAD` only when retries are disabled with `max_retries=0`; buffering is recommended because streamed request bodies cannot be safely retried. +SigV4 requests require replayable, fully serialized request bodies. Standard JSON requests already meet this requirement, and response streaming is unaffected. Low-level one-shot request streams must be buffered before sending, or sent with bearer authentication and retries disabled. Bearer tokens remain available as a compatibility or manual authentication mode. Set `AWS_BEARER_TOKEN_BEDROCK` to an [Amazon Bedrock API key](https://docs.aws.amazon.com/bedrock/latest/userguide/api-keys.html), pass `api_key`, or provide a refresh callback: ```py +client = OpenAI( + provider=bedrock( + region="us-west-2", + token_provider=lambda: refresh_bedrock_token(), + ) +) +``` + +Without explicit authentication, `AWS_BEARER_TOKEN_BEDROCK` takes precedence over the default AWS credential chain for backwards compatibility. + +### Legacy `BedrockOpenAI` client + +`BedrockOpenAI` and `AsyncBedrockOpenAI` remain available for existing applications and delegate to the same provider implementation. New applications should prefer `OpenAI(provider=bedrock(...))`. + +```py +from openai import BedrockOpenAI + client = BedrockOpenAI( aws_region="us-west-2", - bedrock_token_provider=lambda: refresh_bedrock_token(), + aws_profile="my-profile", ) ``` -Without explicit authentication, `AWS_BEARER_TOKEN_BEDROCK` takes precedence over the default AWS credential chain for backwards compatibility. +The legacy module client also continues to support `openai.api_type = "amazon-bedrock"` or `OPENAI_API_TYPE=amazon-bedrock`. ## Versioning diff --git a/examples/bedrock.py b/examples/bedrock.py index 24dafb5b80..6a1837ef49 100644 --- a/examples/bedrock.py +++ b/examples/bedrock.py @@ -1,9 +1,11 @@ -from openai import BedrockOpenAI +from openai import OpenAI +from openai.providers import bedrock -client = BedrockOpenAI() - -# For refreshed Bedrock bearer tokens: -# client = BedrockOpenAI(aws_region="us-west-2", bedrock_token_provider=get_bedrock_token) +client = OpenAI( + provider=bedrock( + region="us-west-2", + ) +) response = client.responses.create( model="openai.gpt-5.4", diff --git a/src/openai/__init__.py b/src/openai/__init__.py index b4de9c5754..a3f3237c38 100644 --- a/src/openai/__init__.py +++ b/src/openai/__init__.py @@ -320,6 +320,12 @@ def _refresh_api_key(self) -> str: return super()._refresh_api_key() + @override + def _legacy_auth_configuration(self) -> _bedrock._LegacyAuthConfiguration: + if api_key is not None: + return ("bearer", api_key) + return super()._legacy_auth_configuration() + class _AmbiguousModuleClientUsageError(OpenAIError): def __init__(self) -> None: diff --git a/src/openai/_client.py b/src/openai/_client.py index 499a62dfe5..0bd3d4a402 100644 --- a/src/openai/_client.py +++ b/src/openai/_client.py @@ -31,6 +31,7 @@ from ._compat import cached_property from ._models import SecurityOptions, FinalRequestOptions from ._version import __version__ +from ._provider import _Provider, _provider_name, _ProviderRuntime, _configure_provider from ._streaming import Stream as Stream, AsyncStream as AsyncStream from ._exceptions import OpenAIError, APIStatusError from ._base_client import ( @@ -110,6 +111,8 @@ class OpenAI(SyncAPIClient): project: str | None webhook_secret: str | None _workload_identity_auth: WorkloadIdentityAuth | None + _provider: _Provider | None + _provider_runtime: _ProviderRuntime | None websocket_base_url: str | httpx.URL | None """Base URL for WebSocket connections. @@ -128,6 +131,7 @@ def __init__( organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, + provider: _Provider | None = None, base_url: str | httpx.URL | None = None, websocket_base_url: str | httpx.URL | None = None, timeout: float | Timeout | None | NotGiven = not_given, @@ -157,13 +161,44 @@ def __init__( - `organization` from `OPENAI_ORG_ID` - `project` from `OPENAI_PROJECT_ID` - `webhook_secret` from `OPENAI_WEBHOOK_SECRET` + + When `provider` is supplied, authentication and the base URL are configured by that provider instead. """ + provider_runtime: _ProviderRuntime | None = None + if provider is not None: + provider_name = _provider_name(provider) + conflicts = [ + name + for name, value in ( + ("api_key", api_key), + ("admin_api_key", admin_api_key), + ("workload_identity", workload_identity), + ("base_url", base_url), + ) + if value is not None + ] + if conflicts: + formatted = ", ".join(f"`{name}`" for name in conflicts) + raise OpenAIError( + f"`provider` cannot be combined with top-level {formatted}. " + f"Move provider authentication and routing options into `{provider_name}(...)`." + ) + + provider_runtime = _configure_provider(provider) + + self._provider = provider + self._provider_runtime = provider_runtime + if api_key is not None and api_key != WORKLOAD_IDENTITY_API_KEY_PLACEHOLDER and workload_identity is not None: raise OpenAIError("The `api_key` and `workload_identity` arguments are mutually exclusive") - self.workload_identity = workload_identity + self.workload_identity = workload_identity if provider_runtime is None else None - if workload_identity is not None: + if provider_runtime is not None: + self.api_key = "" + self._api_key_provider = None + self._workload_identity_auth = None + elif workload_identity is not None: self.api_key = WORKLOAD_IDENTITY_API_KEY_PLACEHOLDER self._api_key_provider = None self._workload_identity_auth = WorkloadIdentityAuth( @@ -180,12 +215,13 @@ def __init__( self._api_key_provider = None self._workload_identity_auth = None - if admin_api_key is None: + if admin_api_key is None and provider_runtime is None: admin_api_key = os.environ.get("OPENAI_ADMIN_KEY") - self.admin_api_key = admin_api_key + self.admin_api_key = admin_api_key if provider_runtime is None else None if ( - _enforce_credentials + provider_runtime is None + and _enforce_credentials and not self.api_key and self._api_key_provider is None and workload_identity is None @@ -195,11 +231,11 @@ def __init__( "Missing credentials. Please pass an `api_key`, `workload_identity`, `admin_api_key`, or set the `OPENAI_API_KEY` or `OPENAI_ADMIN_KEY` environment variable." ) - if organization is None: + if organization is None and provider_runtime is None: organization = os.environ.get("OPENAI_ORG_ID") self.organization = organization - if project is None: + if project is None and provider_runtime is None: project = os.environ.get("OPENAI_PROJECT_ID") self.project = project @@ -209,12 +245,14 @@ def __init__( self.websocket_base_url = websocket_base_url - if base_url is None: + if provider_runtime is not None: + base_url = provider_runtime.base_url + elif base_url is None: base_url = os.environ.get("OPENAI_BASE_URL") if base_url is None: base_url = f"https://api.openai.com/v1" - custom_headers_env = os.environ.get("OPENAI_CUSTOM_HEADERS") + custom_headers_env = os.environ.get("OPENAI_CUSTOM_HEADERS") if provider_runtime is None else None if custom_headers_env is not None: parsed: dict[str, str] = {} for line in custom_headers_env.split("\n"): @@ -437,10 +475,16 @@ def _send_request( stream: bool, **kwargs: Unpack[HttpxSendArgs], ) -> httpx.Response: - return self._send_with_auth_retry(request, stream=stream, **kwargs) + response = self._send_with_auth_retry(request, stream=stream, **kwargs) + if self._provider_runtime is not None and self._provider_runtime.normalize_response is not None: + response = self._provider_runtime.normalize_response(response) + return response @override def _auth_headers(self, security: SecurityOptions) -> dict[str, str]: + if self._provider_runtime is not None: + return {} + if security.get("bearer_auth", False): headers = self._bearer_auth if headers: @@ -461,6 +505,9 @@ def _bearer_auth(self) -> dict[str, str]: @property @override def auth_headers(self) -> dict[str, str]: + if self._provider_runtime is not None: + return {} + api_key = self.api_key if not api_key or api_key == WORKLOAD_IDENTITY_API_KEY_PLACEHOLDER: return {} @@ -486,6 +533,9 @@ def default_headers(self) -> dict[str, str | Omit]: @override def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None: + if self._provider_runtime is not None: + return + if _has_header(headers, "Authorization") or _has_omitted_header(custom_headers, "Authorization"): return @@ -495,11 +545,26 @@ def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None: @override def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: - if self._api_key_provider is not None and options.security.get("bearer_auth", False): + if self._provider_runtime is not None: + if self._provider_runtime.transform_request is not None: + options = self._provider_runtime.transform_request(options) + elif self._api_key_provider is not None and options.security.get("bearer_auth", False): self._refresh_api_key() return super()._prepare_options(options) + @override + def _prepare_request(self, request: httpx.Request) -> None: + if self._provider_runtime is not None and self._provider_runtime.prepare_request is not None: + self._provider_runtime.prepare_request(request) + + @override + def _custom_auth(self, security: SecurityOptions) -> httpx.Auth | None: + if self._provider_runtime is not None: + return httpx.Auth() + + return super()._custom_auth(security) + def _refresh_api_key(self) -> str: if self._api_key_provider is not None: self.api_key = self._api_key_provider() @@ -512,6 +577,7 @@ def copy( api_key: str | Callable[[], str] | None = None, admin_api_key: str | None = None, workload_identity: WorkloadIdentity | None = None, + provider: _Provider | None | NotGiven = not_given, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -550,21 +616,43 @@ def copy( http_client = http_client or self._client + next_provider = self._provider if isinstance(provider, NotGiven) else provider + auth_options: dict[str, Any] + if next_provider is not None: + auth_options = { + "provider": next_provider, + "api_key": api_key, + "admin_api_key": admin_api_key, + "workload_identity": workload_identity, + "base_url": base_url, + } + elif self._provider is not None: + auth_options = { + "api_key": api_key, + "admin_api_key": admin_api_key, + "workload_identity": workload_identity, + "base_url": base_url, + } + else: + auth_options = { + "api_key": api_key or self._api_key_provider or self.api_key, + "admin_api_key": admin_api_key or self.admin_api_key, + "workload_identity": workload_identity or self.workload_identity, + "base_url": base_url or self.base_url, + } + return self.__class__( - api_key=api_key or self._api_key_provider or self.api_key, - admin_api_key=admin_api_key or self.admin_api_key, - workload_identity=workload_identity or self.workload_identity, organization=organization or self.organization, project=project or self.project, webhook_secret=webhook_secret or self.webhook_secret, websocket_base_url=websocket_base_url or self.websocket_base_url, - base_url=base_url or self.base_url, timeout=self.timeout if isinstance(timeout, NotGiven) else timeout, http_client=http_client, max_retries=max_retries if is_given(max_retries) else self.max_retries, default_headers=headers, default_query=params, _enforce_credentials=True if _enforce_credentials is None else _enforce_credentials, + **auth_options, **_extra_kwargs, ) @@ -616,6 +704,8 @@ class AsyncOpenAI(AsyncAPIClient): project: str | None webhook_secret: str | None _workload_identity_auth: WorkloadIdentityAuth | None + _provider: _Provider | None + _provider_runtime: _ProviderRuntime | None websocket_base_url: str | httpx.URL | None """Base URL for WebSocket connections. @@ -634,6 +724,7 @@ def __init__( organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, + provider: _Provider | None = None, base_url: str | httpx.URL | None = None, websocket_base_url: str | httpx.URL | None = None, timeout: float | Timeout | None | NotGiven = not_given, @@ -663,13 +754,44 @@ def __init__( - `organization` from `OPENAI_ORG_ID` - `project` from `OPENAI_PROJECT_ID` - `webhook_secret` from `OPENAI_WEBHOOK_SECRET` + + When `provider` is supplied, authentication and the base URL are configured by that provider instead. """ + provider_runtime: _ProviderRuntime | None = None + if provider is not None: + provider_name = _provider_name(provider) + conflicts = [ + name + for name, value in ( + ("api_key", api_key), + ("admin_api_key", admin_api_key), + ("workload_identity", workload_identity), + ("base_url", base_url), + ) + if value is not None + ] + if conflicts: + formatted = ", ".join(f"`{name}`" for name in conflicts) + raise OpenAIError( + f"`provider` cannot be combined with top-level {formatted}. " + f"Move provider authentication and routing options into `{provider_name}(...)`." + ) + + provider_runtime = _configure_provider(provider) + + self._provider = provider + self._provider_runtime = provider_runtime + if api_key is not None and api_key != WORKLOAD_IDENTITY_API_KEY_PLACEHOLDER and workload_identity is not None: raise OpenAIError("The `api_key` and `workload_identity` arguments are mutually exclusive") - self.workload_identity = workload_identity + self.workload_identity = workload_identity if provider_runtime is None else None - if workload_identity is not None: + if provider_runtime is not None: + self.api_key = "" + self._api_key_provider = None + self._workload_identity_auth = None + elif workload_identity is not None: self.api_key = WORKLOAD_IDENTITY_API_KEY_PLACEHOLDER self._api_key_provider = None self._workload_identity_auth = WorkloadIdentityAuth( @@ -686,12 +808,13 @@ def __init__( self._api_key_provider = None self._workload_identity_auth = None - if admin_api_key is None: + if admin_api_key is None and provider_runtime is None: admin_api_key = os.environ.get("OPENAI_ADMIN_KEY") - self.admin_api_key = admin_api_key + self.admin_api_key = admin_api_key if provider_runtime is None else None if ( - _enforce_credentials + provider_runtime is None + and _enforce_credentials and not self.api_key and self._api_key_provider is None and workload_identity is None @@ -701,11 +824,11 @@ def __init__( "Missing credentials. Please pass an `api_key`, `workload_identity`, `admin_api_key`, or set the `OPENAI_API_KEY` or `OPENAI_ADMIN_KEY` environment variable." ) - if organization is None: + if organization is None and provider_runtime is None: organization = os.environ.get("OPENAI_ORG_ID") self.organization = organization - if project is None: + if project is None and provider_runtime is None: project = os.environ.get("OPENAI_PROJECT_ID") self.project = project @@ -715,12 +838,14 @@ def __init__( self.websocket_base_url = websocket_base_url - if base_url is None: + if provider_runtime is not None: + base_url = provider_runtime.base_url + elif base_url is None: base_url = os.environ.get("OPENAI_BASE_URL") if base_url is None: base_url = f"https://api.openai.com/v1" - custom_headers_env = os.environ.get("OPENAI_CUSTOM_HEADERS") + custom_headers_env = os.environ.get("OPENAI_CUSTOM_HEADERS") if provider_runtime is None else None if custom_headers_env is not None: parsed: dict[str, str] = {} for line in custom_headers_env.split("\n"): @@ -943,10 +1068,19 @@ async def _send_request( stream: bool, **kwargs: Unpack[HttpxSendArgs], ) -> httpx.Response: - return await self._send_with_auth_retry(request, stream=stream, **kwargs) + response = await self._send_with_auth_retry(request, stream=stream, **kwargs) + if self._provider_runtime is not None: + if self._provider_runtime.normalize_async_response is not None: + response = await self._provider_runtime.normalize_async_response(response) + elif self._provider_runtime.normalize_response is not None: + response = self._provider_runtime.normalize_response(response) + return response @override def _auth_headers(self, security: SecurityOptions) -> dict[str, str]: + if self._provider_runtime is not None: + return {} + if security.get("bearer_auth", False): headers = self._bearer_auth if headers: @@ -967,6 +1101,9 @@ def _bearer_auth(self) -> dict[str, str]: @property @override def auth_headers(self) -> dict[str, str]: + if self._provider_runtime is not None: + return {} + api_key = self.api_key if not api_key or api_key == WORKLOAD_IDENTITY_API_KEY_PLACEHOLDER: return {} @@ -992,6 +1129,9 @@ def default_headers(self) -> dict[str, str | Omit]: @override def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None: + if self._provider_runtime is not None: + return + if _has_header(headers, "Authorization") or _has_omitted_header(custom_headers, "Authorization"): return @@ -1001,11 +1141,34 @@ def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None: @override async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: - if self._api_key_provider is not None and options.security.get("bearer_auth", False): + if self._provider_runtime is not None: + if self._provider_runtime.transform_async_request is not None: + options = await self._provider_runtime.transform_async_request(options) + elif self._provider_runtime.transform_request is not None: + options = self._provider_runtime.transform_request(options) + elif self._api_key_provider is not None and options.security.get("bearer_auth", False): await self._refresh_api_key() return await super()._prepare_options(options) + @override + async def _prepare_request(self, request: httpx.Request) -> None: + if self._provider_runtime is None: + return + + if self._provider_runtime.prepare_async_request is not None: + await self._provider_runtime.prepare_async_request(request) + elif self._provider_runtime.prepare_request is not None: + self._provider_runtime.prepare_request(request) + + @property + @override + def custom_auth(self) -> httpx.Auth | None: + if self._provider_runtime is not None: + return httpx.Auth() + + return super().custom_auth + async def _refresh_api_key(self) -> str: if self._api_key_provider is not None: self.api_key = await self._api_key_provider() @@ -1018,6 +1181,7 @@ def copy( api_key: str | Callable[[], Awaitable[str]] | None = None, admin_api_key: str | None = None, workload_identity: WorkloadIdentity | None = None, + provider: _Provider | None | NotGiven = not_given, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -1055,21 +1219,43 @@ def copy( params = set_default_query http_client = http_client or self._client + next_provider = self._provider if isinstance(provider, NotGiven) else provider + auth_options: dict[str, Any] + if next_provider is not None: + auth_options = { + "provider": next_provider, + "api_key": api_key, + "admin_api_key": admin_api_key, + "workload_identity": workload_identity, + "base_url": base_url, + } + elif self._provider is not None: + auth_options = { + "api_key": api_key, + "admin_api_key": admin_api_key, + "workload_identity": workload_identity, + "base_url": base_url, + } + else: + auth_options = { + "api_key": api_key or self._api_key_provider or self.api_key, + "admin_api_key": admin_api_key or self.admin_api_key, + "workload_identity": workload_identity or self.workload_identity, + "base_url": base_url or self.base_url, + } + return self.__class__( - api_key=api_key or self._api_key_provider or self.api_key, - admin_api_key=admin_api_key or self.admin_api_key, - workload_identity=workload_identity or self.workload_identity, organization=organization or self.organization, project=project or self.project, webhook_secret=webhook_secret or self.webhook_secret, websocket_base_url=websocket_base_url or self.websocket_base_url, - base_url=base_url or self.base_url, timeout=self.timeout if isinstance(timeout, NotGiven) else timeout, http_client=http_client, max_retries=max_retries if is_given(max_retries) else self.max_retries, default_headers=headers, default_query=params, _enforce_credentials=True if _enforce_credentials is None else _enforce_credentials, + **auth_options, **_extra_kwargs, ) diff --git a/src/openai/_provider.py b/src/openai/_provider.py new file mode 100644 index 0000000000..f7b9d0bb84 --- /dev/null +++ b/src/openai/_provider.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from typing import Callable, Protocol, Awaitable +from weakref import WeakKeyDictionary +from dataclasses import dataclass + +import httpx + +from ._models import FinalRequestOptions +from ._exceptions import OpenAIError + + +class _Provider: + """Opaque configuration returned by an OpenAI-owned provider factory.""" + + __slots__ = ("__weakref__",) + + +@dataclass +class _ProviderRuntime: + name: str + base_url: str | httpx.URL + transform_request: Callable[[FinalRequestOptions], FinalRequestOptions] | None = None + transform_async_request: Callable[[FinalRequestOptions], Awaitable[FinalRequestOptions]] | None = None + prepare_request: Callable[[httpx.Request], None] | None = None + prepare_async_request: Callable[[httpx.Request], Awaitable[None]] | None = None + normalize_response: Callable[[httpx.Response], httpx.Response] | None = None + normalize_async_response: Callable[[httpx.Response], Awaitable[httpx.Response]] | None = None + + +class _ProviderDefinition(Protocol): + @property + def name(self) -> str: ... + + def configure(self) -> _ProviderRuntime: ... + + +_provider_definitions: WeakKeyDictionary[_Provider, _ProviderDefinition] = WeakKeyDictionary() + + +def _create_provider(definition: _ProviderDefinition) -> _Provider: # pyright: ignore[reportUnusedFunction] + provider = _Provider() + _provider_definitions[provider] = definition + return provider + + +def _provider_name(provider: _Provider) -> str: # pyright: ignore[reportUnusedFunction] + return _get_provider_definition(provider).name + + +def _configure_provider(provider: _Provider) -> _ProviderRuntime: # pyright: ignore[reportUnusedFunction] + return _get_provider_definition(provider).configure() + + +def _get_provider_definition(provider: _Provider) -> _ProviderDefinition: + try: + return _provider_definitions[provider] + except (KeyError, TypeError) as exc: + raise OpenAIError("Invalid provider. Providers must be created by an OpenAI provider factory.") from exc diff --git a/src/openai/lib/_bedrock_auth.py b/src/openai/lib/_bedrock_auth.py index 140ca894b0..fc92c489f2 100644 --- a/src/openai/lib/_bedrock_auth.py +++ b/src/openai/lib/_bedrock_auth.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import hashlib import importlib from typing import Literal, Mapping, Callable, Protocol, cast from dataclasses import field, dataclass @@ -15,7 +16,6 @@ def get_credentials(self) -> object | None: ... _AUTHORIZATION = "authorization" -_UNSIGNED_PAYLOAD = "UNSIGNED-PAYLOAD" _AWS_SIGNING_HEADERS = ( _AUTHORIZATION, "x-amz-content-sha256", @@ -24,12 +24,6 @@ def get_credentials(self) -> object | None: ... ) -@dataclass(frozen=True) -class BedrockBearerAuthConfig: - source: Literal["explicit", "provider", "environment"] - region_source: Literal["explicit", "environment"] | None = None - - @dataclass(frozen=True) class BedrockAwsAuthConfig: region: str @@ -137,7 +131,8 @@ def sign(self, *, method: str, url: str, headers: Mapping[str, str], body: bytes ) if credentials is None: raise OpenAIError( - "Could not find credentials for Bedrock. Pass a bearer credential or AWS credentials, " + "Could not find credentials for Bedrock. Pass a bearer credential or AWS credentials to " + "`bedrock(...)`, " "set `AWS_BEARER_TOKEN_BEDROCK`, or configure the default AWS credential chain." ) @@ -148,9 +143,7 @@ def sign(self, *, method: str, url: str, headers: Mapping[str, str], body: bytes signed_headers = { name: value for name, value in headers.items() if name.lower() not in _AWS_SIGNING_HEADERS } - if body is None: - signed_headers["X-Amz-Content-SHA256"] = _UNSIGNED_PAYLOAD - + signed_headers["X-Amz-Content-SHA256"] = hashlib.sha256(body or b"").hexdigest() aws_request = self._aws_request_cls( method=method, url=url, @@ -185,69 +178,8 @@ def resolve_aws_region_with_source( if region is None or not region.strip(): raise OpenAIError( - "Bedrock requires an AWS region. Pass `aws_region`, or set `AWS_REGION` or `AWS_DEFAULT_REGION`." + "Bedrock requires an AWS region. Pass `region` to `bedrock(...)`, or set `AWS_REGION` or " + "`AWS_DEFAULT_REGION`." ) return region.strip(), source - - -def resolve_aws_region(aws_region: str | None, *, session: object | None = None) -> str: - return resolve_aws_region_with_source(aws_region, session=session)[0] - - -def resolve_bedrock_env_token() -> str | None: - return os.environ.get("AWS_BEARER_TOKEN_BEDROCK") or None - - -def has_explicit_aws_auth( - *, - aws_profile: str | None, - aws_access_key_id: str | None, - aws_secret_access_key: str | None, - aws_session_token: str | None, - aws_credentials_provider: AwsCredentialsProvider | None, -) -> bool: - return any( - value is not None - for value in ( - aws_profile, - aws_access_key_id, - aws_secret_access_key, - aws_session_token, - aws_credentials_provider, - ) - ) - - -def validate_explicit_aws_auth( - *, - aws_profile: str | None, - aws_access_key_id: str | None, - aws_secret_access_key: str | None, - aws_session_token: str | None, - aws_credentials_provider: AwsCredentialsProvider | None, -) -> None: - if (aws_access_key_id is None) != (aws_secret_access_key is None): - raise OpenAIError( - "Static AWS credentials require both `aws_access_key_id` and `aws_secret_access_key`. " - "An `aws_session_token` may only be used with both." - ) - - credential_sources = sum( - ( - aws_profile is not None, - aws_access_key_id is not None, - aws_credentials_provider is not None, - ) - ) - if credential_sources > 1: - raise OpenAIError( - "Bedrock authentication is ambiguous. Configure exactly one explicit mode: bearer credential, " - "static AWS credentials, profile, or credential provider." - ) - - if aws_session_token is not None and aws_access_key_id is None: - raise OpenAIError( - "Static AWS credentials require both `aws_access_key_id` and `aws_secret_access_key`. " - "An `aws_session_token` may only be used with both." - ) diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py index 4fcae24788..888b480dbb 100644 --- a/src/openai/lib/azure.py +++ b/src/openai/lib/azure.py @@ -13,6 +13,7 @@ from .._client import OpenAI, AsyncOpenAI from .._compat import model_copy from .._models import SecurityOptions, FinalRequestOptions +from .._provider import _Provider from .._streaming import Stream, AsyncStream from .._exceptions import OpenAIError from .._base_client import DEFAULT_MAX_RETRIES, BaseClient @@ -283,6 +284,7 @@ def copy( api_key: str | Callable[[], str] | None = None, admin_api_key: str | None = None, workload_identity: WorkloadIdentity | None = None, + provider: _Provider | None | NotGiven = NOT_GIVEN, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -304,6 +306,9 @@ def copy( """ Create a new client instance re-using the same options given to the current client with optional overriding. """ + if not isinstance(provider, NotGiven): + raise OpenAIError("Configure `provider` on `OpenAI`, not on `AzureOpenAI.with_options()`.") + return super().copy( api_key=api_key, admin_api_key=admin_api_key, @@ -603,6 +608,7 @@ def copy( api_key: str | Callable[[], Awaitable[str]] | None = None, admin_api_key: str | None = None, workload_identity: WorkloadIdentity | None = None, + provider: _Provider | None | NotGiven = NOT_GIVEN, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -624,6 +630,9 @@ def copy( """ Create a new client instance re-using the same options given to the current client with optional overriding. """ + if not isinstance(provider, NotGiven): + raise OpenAIError("Configure `provider` on `AsyncOpenAI`, not on `AsyncAzureOpenAI.with_options()`.") + return super().copy( api_key=api_key, admin_api_key=admin_api_key, diff --git a/src/openai/lib/bedrock.py b/src/openai/lib/bedrock.py index 802f476d84..b05474f077 100644 --- a/src/openai/lib/bedrock.py +++ b/src/openai/lib/bedrock.py @@ -2,54 +2,57 @@ import os import re +import hashlib import inspect -from typing import Any, Literal, Mapping, Callable, Awaitable, cast -from dataclasses import replace -from typing_extensions import Self, Unpack, override +from typing import Any, Literal, Mapping, Callable, Optional, Awaitable, cast +from dataclasses import field, replace, dataclass +from typing_extensions import Self, override import httpx from ..auth import WorkloadIdentity -from .._types import NOT_GIVEN, Omit, Headers, Timeout, NotGiven, HttpxSendArgs -from .._utils import asyncify, is_given +from .._types import NOT_GIVEN, Timeout, NotGiven +from .._utils import is_given from .._client import OpenAI, AsyncOpenAI -from .._models import SecurityOptions, FinalRequestOptions +from .._models import FinalRequestOptions +from .._provider import _Provider, _configure_provider from .._exceptions import OpenAIError from .._base_client import DEFAULT_MAX_RETRIES -from ._bedrock_auth import ( - BedrockAwsAuth as _BedrockAwsAuth, - BedrockAwsAuthConfig as _BedrockAwsAuthConfig, - AwsCredentialsProvider, - BedrockBearerAuthConfig as _BedrockBearerAuthConfig, - resolve_aws_region as _resolve_aws_region, - has_explicit_aws_auth as _has_explicit_aws_auth, - resolve_bedrock_env_token as _resolve_bedrock_env_token, - validate_explicit_aws_auth as _validate_explicit_aws_auth, - resolve_aws_region_with_source as _resolve_aws_region_with_source, -) +from ..providers.bedrock import AwsCredentialsProvider, bedrock BedrockTokenProvider = Callable[[], str] AsyncBedrockTokenProvider = Callable[[], "str | Awaitable[str]"] +_LegacyAuthMode = Literal["bearer", "token_provider", "aws"] +_LegacyAuthConfiguration = tuple[_LegacyAuthMode, Optional[object]] +_LEGACY_SIGNATURE_KEY = os.urandom(32) -_BEDROCK_AUTH_INTENT_EXTENSION = "openai.bedrock_auth_intent" -_BEDROCK_AUTH_INTENT_DEFAULT = "default" -_BEDROCK_AUTH_INTENT_OMIT = "omit" -_BEDROCK_AUTH_INTENT_OVERRIDE = "override" -_BEDROCK_MAX_RETRIES_EXTENSION = "openai.bedrock_max_retries" -_AWS_SIGNING_HEADERS = ("authorization", "x-amz-content-sha256", "x-amz-date", "x-amz-security-token") +@dataclass(frozen=True) +class _LegacyRuntimeSignature: + mode: _LegacyAuthMode + base_url: str + region: str | None + credential_identity: object = field(repr=False) -def _authorization_intent(*header_sets: Mapping[str, str | Omit]) -> str: - intent = _BEDROCK_AUTH_INTENT_DEFAULT - for headers in header_sets: - for name, value in headers.items(): - if name.lower() == "authorization": - intent = _BEDROCK_AUTH_INTENT_OMIT if isinstance(value, Omit) else _BEDROCK_AUTH_INTENT_OVERRIDE - return intent + +@dataclass(frozen=True) +class _LegacyBedrockState: + explicit_api_key: str | None = field(repr=False) + token_provider: BedrockTokenProvider | AsyncBedrockTokenProvider | None = field(repr=False, compare=False) + aws_region: str | None + region_was_explicit: bool + aws_profile: str | None + aws_access_key_id: str | None = field(repr=False) + aws_secret_access_key: str | None = field(repr=False) + aws_session_token: str | None = field(repr=False) + aws_credentials_provider: AwsCredentialsProvider | None = field(repr=False, compare=False) + uses_environment_bearer: bool + environment_bearer_token: str | None = field(repr=False) + uses_region_derived_base_url: bool -def _same_origin(left: httpx.URL, right: httpx.URL) -> bool: - return (left.scheme, left.host, left.port) == (right.scheme, right.host, right.port) +def _state_api_key(state: _LegacyBedrockState) -> str: + return state.explicit_api_key or (state.environment_bearer_token if state.uses_environment_bearer else "") or "" def _constructor_accepts_keyword(constructor: Callable[..., object], name: str) -> bool: @@ -63,134 +66,73 @@ def _constructor_accepts_keyword(constructor: Callable[..., object], name: str) ) -def _body_for_signing(request: httpx.Request) -> bytes | None: - try: - return request.content - except httpx.RequestNotRead as exc: - max_retries = request.extensions.get(_BEDROCK_MAX_RETRIES_EXTENSION) - if max_retries == 0: - return None - - raise OpenAIError( - "Bedrock SigV4 authentication requires a replayable request body when retries are enabled. " - "Buffer the body, set `max_retries=0` to use `UNSIGNED-PAYLOAD`, or use bearer authentication." - ) from exc - - -def _normalize_bedrock_base_url(base_url: str | httpx.URL) -> httpx.URL: - """Normalize a Bedrock Responses URL variant back to the provider API root.""" - url = httpx.URL(base_url) - path = url.path.rstrip("/") - responses_match = re.search(r"/responses(?:/.*)?$", path) - if responses_match is not None: - path = path[: responses_match.start()] - - return url.copy_with(path=path or "/") +def _configured_region(region: str | None) -> str | None: + configured = region or os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION") + return configured.strip() if configured is not None and configured.strip() else None -def _configured_aws_region(aws_region: str | None) -> str | None: - region = aws_region if aws_region is not None and aws_region.strip() else None - region = region or os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION") - return region.strip() if region is not None and region.strip() else None - - -def _configured_aws_region_source(aws_region: str | None) -> Literal["explicit", "environment"] | None: - if aws_region is not None and aws_region.strip(): - return "explicit" - environment_region = os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION") - if environment_region is not None and environment_region.strip(): - return "environment" - return None - - -def _resolve_bedrock_base_url( - base_url: str | httpx.URL | None, - aws_region: str | None, - *, - use_environment: bool = True, -) -> httpx.URL: - """Resolve Bedrock base URL precedence from explicit, env, then region config.""" +def _uses_region_derived_base_url(base_url: str | httpx.URL | None) -> bool: if isinstance(base_url, str) and not base_url.strip(): base_url = None - - if base_url is None and use_environment: - env_base_url = os.environ.get("AWS_BEDROCK_BASE_URL") - if env_base_url is not None and env_base_url.strip(): - base_url = env_base_url - - if base_url is None: - region = _configured_aws_region(aws_region) - if region is None: - raise OpenAIError( - "Bedrock requires an AWS region. Pass `aws_region`, or set `AWS_REGION` or `AWS_DEFAULT_REGION`." - ) - - base_url = f"https://bedrock-mantle.{region}.api.aws/openai/v1" - - return _normalize_bedrock_base_url(base_url) - - -def _uses_region_derived_bedrock_base_url(base_url: str | httpx.URL | None) -> bool: - if isinstance(base_url, str) and not base_url.strip(): - base_url = None - if base_url is not None: return False - env_base_url = os.environ.get("AWS_BEDROCK_BASE_URL") - return env_base_url is None or not env_base_url.strip() - - -def _bedrock_token_provider(provider: BedrockTokenProvider) -> BedrockTokenProvider: - """Adapt a sync Bedrock token provider to the base client's api_key callback.""" - - def get_token() -> str: - token = cast(object, provider()) - if not isinstance(token, str) or not token: - raise ValueError(f"Expected `bedrock_token_provider` argument to return a string but it returned {token}") - - return token - - return get_token - + environment_base_url = os.environ.get("AWS_BEDROCK_BASE_URL") + return environment_base_url is None or not environment_base_url.strip() -def _async_bedrock_token_provider(provider: AsyncBedrockTokenProvider) -> Callable[[], Awaitable[str]]: - """Adapt a sync or async Bedrock token provider to the async api_key callback.""" - async def get_token() -> str: - token = cast(object, provider()) - if inspect.isawaitable(token): - token = await token - - if not isinstance(token, str) or not token: - raise ValueError(f"Expected `bedrock_token_provider` argument to return a string but it returned {token}") +def _has_explicit_aws_auth( + *, + aws_profile: str | None, + aws_access_key_id: str | None, + aws_secret_access_key: str | None, + aws_session_token: str | None, + aws_credentials_provider: AwsCredentialsProvider | None, +) -> bool: + return any( + value is not None + for value in ( + aws_profile, + aws_access_key_id, + aws_secret_access_key, + aws_session_token, + aws_credentials_provider, + ) + ) - return token - return get_token +def _environment_bearer_token() -> str: + token = os.environ.get("AWS_BEARER_TOKEN_BEDROCK") + if not token: + raise OpenAIError( + "Could not find credentials for Bedrock. Set `AWS_BEARER_TOKEN_BEDROCK` or configure the default " + "AWS credential chain." + ) + return token -def _resolve_bedrock_auth( +def _legacy_provider( *, api_key: str | None, - token_provider: object | None, + token_provider: BedrockTokenProvider | AsyncBedrockTokenProvider | None, aws_region: str | None, aws_profile: str | None, aws_access_key_id: str | None, aws_secret_access_key: str | None, aws_session_token: str | None, aws_credentials_provider: AwsCredentialsProvider | None, - auth_config: _BedrockBearerAuthConfig | _BedrockAwsAuthConfig | None, - enforce_credentials: bool, -) -> tuple[_BedrockBearerAuthConfig | _BedrockAwsAuthConfig, _BedrockAwsAuth | None, str | None, str | None]: - if auth_config is not None: - if isinstance(auth_config, _BedrockAwsAuthConfig): - aws_auth = _BedrockAwsAuth(auth_config) if enforce_credentials else None - return auth_config, aws_auth, api_key, auth_config.region - - return auth_config, None, api_key, aws_region + base_url: str | httpx.URL | None, +) -> tuple[_Provider, _LegacyBedrockState, str]: + if callable(cast(object, api_key)): + raise OpenAIError("Pass refreshable Bedrock credentials via `bedrock_token_provider`, not `api_key`.") + if api_key == "": + raise OpenAIError("The `api_key` argument must not be empty.") + if api_key is not None and token_provider is not None: + raise OpenAIError( + "Bedrock authentication is ambiguous. Configure exactly one explicit mode: bearer credential, " + "static AWS credentials, profile, or credential provider." + ) - explicit_bearer_auth = api_key is not None or token_provider is not None explicit_aws_auth = _has_explicit_aws_auth( aws_profile=aws_profile, aws_access_key_id=aws_access_key_id, @@ -198,87 +140,245 @@ def _resolve_bedrock_auth( aws_session_token=aws_session_token, aws_credentials_provider=aws_credentials_provider, ) - if explicit_bearer_auth and explicit_aws_auth: + if (api_key is not None or token_provider is not None) and explicit_aws_auth: raise OpenAIError( "Bedrock authentication is ambiguous. Configure exactly one explicit mode: bearer credential, " "static AWS credentials, profile, or credential provider." ) - _validate_explicit_aws_auth( + environment_token = os.environ.get("AWS_BEARER_TOKEN_BEDROCK") + uses_environment_bearer = ( + api_key is None and token_provider is None and not explicit_aws_auth and bool(environment_token) + ) + resolved_region = _configured_region(aws_region) + uses_region_derived_base_url = _uses_region_derived_base_url(base_url) + + provider_base_url: str | httpx.URL | None | NotGiven + if isinstance(base_url, str) and not base_url.strip(): + provider_base_url = None + elif base_url is None: + provider_base_url = NOT_GIVEN + else: + provider_base_url = base_url + + provider = bedrock( + region=aws_region, + base_url=provider_base_url, + api_key=api_key if api_key is not None else environment_token if uses_environment_bearer else NOT_GIVEN, + token_provider=token_provider, + access_key_id=aws_access_key_id, + secret_access_key=aws_secret_access_key, + session_token=aws_session_token, + profile=aws_profile, + credential_provider=aws_credentials_provider, + ) + state = _LegacyBedrockState( + explicit_api_key=api_key, + token_provider=token_provider, + aws_region=resolved_region, + region_was_explicit=bool(aws_region and aws_region.strip()), aws_profile=aws_profile, aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token, aws_credentials_provider=aws_credentials_provider, + uses_environment_bearer=uses_environment_bearer, + environment_bearer_token=environment_token if uses_environment_bearer else None, + uses_region_derived_base_url=uses_region_derived_base_url, ) + return provider, state, api_key or (environment_token if uses_environment_bearer else "") or "" - if explicit_bearer_auth: - source: Literal["explicit", "provider"] = "provider" if token_provider is not None else "explicit" - return ( - _BedrockBearerAuthConfig(source=source, region_source=_configured_aws_region_source(aws_region)), - None, - api_key, - _configured_aws_region(aws_region), + +def _copy_configuration( + client: BedrockOpenAI | AsyncBedrockOpenAI, + *, + api_key: str | None, + token_provider: BedrockTokenProvider | AsyncBedrockTokenProvider | None, + aws_region: str | None, + aws_profile: str | None, + aws_access_key_id: str | None, + aws_secret_access_key: str | None, + aws_session_token: str | None, + aws_credentials_provider: AwsCredentialsProvider | None, + base_url: str | httpx.URL | None, +) -> tuple[dict[str, object], _Provider | None, _LegacyBedrockState | None]: + _synchronize_legacy_routing_state(client) + state = client._bedrock_state + current_api_key = client.api_key or "" + api_key_was_mutated = state.token_provider is None and current_api_key != _state_api_key(state) + aws_override = _has_explicit_aws_auth( + aws_profile=aws_profile, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + aws_credentials_provider=aws_credentials_provider, + ) + explicit_bearer_override = api_key is not None or token_provider is not None + if explicit_bearer_override and aws_override: + raise OpenAIError( + "Bedrock authentication is ambiguous. Configure exactly one explicit mode: bearer credential, " + "static AWS credentials, profile, or credential provider." ) - if not explicit_aws_auth: - api_key = _resolve_bedrock_env_token() - if api_key is not None: - return ( - _BedrockBearerAuthConfig( - source="environment", - region_source=_configured_aws_region_source(aws_region), - ), - None, - api_key, - _configured_aws_region(aws_region), - ) + effective_api_key = ( + api_key + if api_key is not None + else current_api_key + if api_key_was_mutated and token_provider is None and not aws_override + else None + ) + bearer_override = effective_api_key is not None or token_provider is not None + + routing_override = aws_region is not None or base_url is not None + if not bearer_override and not aws_override and not routing_override: + _refresh_legacy_provider_runtime(client) + return {}, client._bedrock_provider, client._bedrock_state + + if bearer_override: + next_api_key = effective_api_key + next_token_provider = token_provider + next_profile = None + next_access_key_id = None + next_secret_access_key = None + next_session_token = None + next_credentials_provider = None + elif aws_override: + next_api_key = None + next_token_provider = None + next_profile = aws_profile + next_access_key_id = aws_access_key_id + next_secret_access_key = aws_secret_access_key + next_session_token = aws_session_token + next_credentials_provider = aws_credentials_provider + else: + next_api_key = state.explicit_api_key + next_token_provider = state.token_provider + if state.uses_environment_bearer: + next_api_key = state.environment_bearer_token or _environment_bearer_token() + next_token_provider = None + next_profile = state.aws_profile + next_access_key_id = state.aws_access_key_id + next_secret_access_key = state.aws_secret_access_key + next_session_token = state.aws_session_token + next_credentials_provider = state.aws_credentials_provider - if enforce_credentials: - aws_auth = _BedrockAwsAuth.resolve( - region=aws_region, - profile=aws_profile, - access_key_id=aws_access_key_id, - secret_access_key=aws_secret_access_key, - session_token=aws_session_token, - credentials_provider=aws_credentials_provider, - ) - return aws_auth.config, aws_auth, None, aws_auth.config.region - - resolved_region, region_source = _resolve_aws_region_with_source(aws_region) - aws_source: Literal["static", "profile", "provider", "default"] - if aws_access_key_id is not None: - aws_source = "static" - elif aws_profile is not None: - aws_source = "profile" - elif aws_credentials_provider is not None: - aws_source = "provider" + next_region = aws_region if aws_region is not None else client.aws_region + if aws_profile is not None and aws_region is None and not state.region_was_explicit: + next_region = None + + if base_url is not None: + next_base_url: str | httpx.URL | None = base_url + elif state.uses_region_derived_base_url: + next_base_url = "" else: - aws_source = "default" + next_base_url = client.base_url + return ( - _BedrockAwsAuthConfig( - region=resolved_region, - source=aws_source, - region_source=region_source, - profile=aws_profile, - access_key_id=aws_access_key_id, - secret_access_key=aws_secret_access_key, - session_token=aws_session_token, - credentials_provider=aws_credentials_provider, - ), + { + "api_key": next_api_key, + "bedrock_token_provider": next_token_provider, + "aws_region": next_region, + "aws_profile": next_profile, + "aws_access_key_id": next_access_key_id, + "aws_secret_access_key": next_secret_access_key, + "aws_session_token": next_session_token, + "aws_credentials_provider": next_credentials_provider, + "base_url": next_base_url, + }, None, None, - resolved_region, ) +def _legacy_runtime_signature( + client: BedrockOpenAI | AsyncBedrockOpenAI, + configuration: _LegacyAuthConfiguration, +) -> _LegacyRuntimeSignature: + mode, credential = configuration + credential_identity: object = ( + hashlib.blake2s(credential.encode(), key=_LEGACY_SIGNATURE_KEY).digest() + if isinstance(credential, str) + else id(credential) + ) + return _LegacyRuntimeSignature( + mode=mode, + base_url=str(client.base_url), + region=client.aws_region, + credential_identity=credential_identity, + ) + + +def _provider_for_legacy_client( + client: BedrockOpenAI | AsyncBedrockOpenAI, + configuration: _LegacyAuthConfiguration, +) -> _Provider: + mode, credential = configuration + if mode == "bearer": + if not isinstance(credential, str) or not credential: + raise OpenAIError("The Bedrock bearer credential must not be empty.") + return bedrock( + region=client.aws_region, + base_url=client.base_url, + api_key=credential, + ) + if mode == "token_provider": + return bedrock( + region=client.aws_region, + base_url=client.base_url, + token_provider=cast("AsyncBedrockTokenProvider", credential), + ) + + state = client._bedrock_state + return bedrock( + region=client.aws_region, + base_url=client.base_url, + profile=state.aws_profile, + access_key_id=state.aws_access_key_id, + secret_access_key=state.aws_secret_access_key, + session_token=state.aws_session_token, + credential_provider=state.aws_credentials_provider, + ) + + +def _synchronize_legacy_routing_state(client: BedrockOpenAI | AsyncBedrockOpenAI) -> None: + previous_signature = client._bedrock_runtime_signature + base_url_changed = str(client.base_url) != previous_signature.base_url + region_changed = client.aws_region != previous_signature.region + if base_url_changed: + client._bedrock_state = replace(client._bedrock_state, uses_region_derived_base_url=False) + client._uses_region_derived_base_url = False + if region_changed: + client._bedrock_state = replace( + client._bedrock_state, + aws_region=client.aws_region, + region_was_explicit=client.aws_region is not None, + ) + if client._bedrock_state.uses_region_derived_base_url and client.aws_region is not None: + client.base_url = f"https://bedrock-mantle.{client.aws_region}.api.aws/openai/v1" + + +def _refresh_legacy_provider_runtime(client: BedrockOpenAI | AsyncBedrockOpenAI) -> None: + _synchronize_legacy_routing_state(client) + configuration = client._legacy_auth_configuration() + signature = _legacy_runtime_signature(client, configuration) + if signature == client._bedrock_runtime_signature: + return + + provider = _provider_for_legacy_client(client, configuration) + client._bedrock_provider = provider + client._provider = provider + client._provider_runtime = _configure_provider(provider) + client._bedrock_runtime_signature = signature + + class BedrockOpenAI(OpenAI): - """API client for Amazon Bedrock's OpenAI-compatible endpoint.""" + """Compatibility client for Amazon Bedrock's OpenAI-compatible endpoint.""" + _bedrock_provider: _Provider + _bedrock_state: _LegacyBedrockState _bedrock_token_provider: BedrockTokenProvider | None - _bedrock_auth_config: _BedrockBearerAuthConfig | _BedrockAwsAuthConfig - _bedrock_aws_auth: _BedrockAwsAuth | None _uses_region_derived_base_url: bool + _bedrock_runtime_signature: _LegacyRuntimeSignature aws_region: str | None def __init__( @@ -304,67 +404,33 @@ def __init__( http_client: httpx.Client | None = None, _strict_response_validation: bool = False, _enforce_credentials: bool = True, - _auth_config: _BedrockBearerAuthConfig | _BedrockAwsAuthConfig | None = None, - _base_url_is_region_derived: bool | None = None, + _provider: _Provider | None = None, + _state: _LegacyBedrockState | None = None, ) -> None: - """Construct a new synchronous Amazon Bedrock client instance. - - This automatically infers the following arguments from their corresponding environment variables if they are not provided: - - bearer authentication from `AWS_BEARER_TOKEN_BEDROCK` - - `aws_region` from `AWS_REGION` or `AWS_DEFAULT_REGION` when `base_url` and `AWS_BEDROCK_BASE_URL` are not set - - `base_url` from `AWS_BEDROCK_BASE_URL` - - `bedrock_token_provider` is invoked before each request when provided. When no bearer token is configured, - the client uses the standard AWS credential chain and SigV4 authentication. - """ - if callable(cast(object, api_key)): - raise OpenAIError("Pass refreshable Bedrock credentials via `bedrock_token_provider`, not `api_key`.") - - if api_key == "": - raise OpenAIError("The `api_key` argument must not be empty.") - - if api_key is not None and bedrock_token_provider is not None: - raise OpenAIError( - "Bedrock authentication is ambiguous. Configure exactly one explicit mode: bearer credential, " - "static AWS credentials, profile, or credential provider." + if _provider is None or _state is None: + _provider, _state, public_api_key = _legacy_provider( + api_key=api_key, + token_provider=bedrock_token_provider, + aws_region=aws_region, + aws_profile=aws_profile, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + aws_credentials_provider=aws_credentials_provider, + base_url=base_url, + ) + else: + public_api_key = ( + _state.explicit_api_key + or (_state.environment_bearer_token if _state.uses_environment_bearer else "") + or "" ) - - auth_config, aws_auth, api_key, resolved_region = _resolve_bedrock_auth( - api_key=api_key, - token_provider=bedrock_token_provider, - aws_region=aws_region, - aws_profile=aws_profile, - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - aws_session_token=aws_session_token, - aws_credentials_provider=aws_credentials_provider, - auth_config=_auth_config, - enforce_credentials=_enforce_credentials, - ) - - self._bedrock_token_provider = bedrock_token_provider - self._bedrock_auth_config = auth_config - self._bedrock_aws_auth = aws_auth - self._uses_region_derived_base_url = ( - _uses_region_derived_bedrock_base_url(base_url) - if _base_url_is_region_derived is None - else _base_url_is_region_derived - ) - self.aws_region = resolved_region super().__init__( - api_key=_bedrock_token_provider(bedrock_token_provider) - if bedrock_token_provider is not None - else api_key or "", - admin_api_key="", + provider=_provider, organization=organization, project=project, webhook_secret=webhook_secret, - base_url=_resolve_bedrock_base_url( - base_url, - resolved_region, - use_environment=_base_url_is_region_derived is not True, - ), websocket_base_url=websocket_base_url, timeout=timeout, max_retries=max_retries, @@ -375,101 +441,51 @@ def __init__( _enforce_credentials=False, ) + self._bedrock_provider = _provider + self._bedrock_state = _state + self._bedrock_token_provider = cast("BedrockTokenProvider | None", _state.token_provider) + self._uses_region_derived_base_url = _state.uses_region_derived_base_url + canonical_region = re.fullmatch(r"bedrock-mantle\.([a-z0-9-]+)\.api\.aws", self.base_url.host) + self.aws_region = _state.aws_region or (canonical_region.group(1) if canonical_region is not None else None) + self.api_key = public_api_key or "" + self._bedrock_runtime_signature = _legacy_runtime_signature(self, self._legacy_auth_configuration()) + + def _legacy_auth_configuration(self) -> _LegacyAuthConfiguration: + if self._bedrock_token_provider is not None: + return ("token_provider", self._bedrock_token_provider) + if ( + self._bedrock_state.explicit_api_key is not None + or self._bedrock_state.uses_environment_bearer + or self.api_key + ): + return ("bearer", self.api_key) + return ("aws", None) + def _uses_aws_auth(self) -> bool: return ( - isinstance(self._bedrock_auth_config, _BedrockAwsAuthConfig) + self._bedrock_state.explicit_api_key is None and not self.api_key - and self._api_key_provider is None + and self._bedrock_token_provider is None + and not self._bedrock_state.uses_environment_bearer ) @override - def _auth_headers(self, security: SecurityOptions) -> dict[str, str]: - if self._uses_aws_auth(): - return {} - - if security.get("bearer_auth", False) or security.get("admin_api_key_auth", False): - return self._bearer_auth - - return {} - - @override - def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None: - if self._uses_aws_auth(): - return - - super()._validate_headers(headers, custom_headers) + def _refresh_api_key(self) -> str: + if self._bedrock_state.uses_environment_bearer: + captured = self._bedrock_state.environment_bearer_token or "" + return self.api_key if self.api_key and self.api_key != captured else captured + if self._bedrock_token_provider is not None: + token = cast(object, self._bedrock_token_provider()) + if not isinstance(token, str) or not token: + raise ValueError("Expected `bedrock_token_provider` argument to return a non-empty string.") + return token + return self.api_key @override def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: - if self._uses_aws_auth(): - if options.follow_redirects: - raise OpenAIError( - "Bedrock SigV4 authentication does not support automatic redirects. " - "Send a new request to the redirect target so it can be signed again." - ) - options.follow_redirects = False - elif ( - self._api_key_provider is not None - and options.security.get("admin_api_key_auth", False) - and not options.security.get("bearer_auth", False) - ): - self._refresh_api_key() - + _refresh_legacy_provider_runtime(self) return super()._prepare_options(options) - @override - def _build_request(self, options: FinalRequestOptions, *, retries_taken: int = 0) -> httpx.Request: - request = super()._build_request(options, retries_taken=retries_taken) - if not self._uses_aws_auth(): - return request - - option_headers: Headers = options.headers if is_given(options.headers) else {} - request.extensions[_BEDROCK_AUTH_INTENT_EXTENSION] = _authorization_intent( - self._custom_headers, - option_headers, - ) - request.extensions[_BEDROCK_MAX_RETRIES_EXTENSION] = options.get_max_retries(self.max_retries) - return request - - @override - def _prepare_request(self, request: httpx.Request) -> None: - if not self._uses_aws_auth(): - return - if self._bedrock_aws_auth is None: - assert isinstance(self._bedrock_auth_config, _BedrockAwsAuthConfig) - self._bedrock_aws_auth = _BedrockAwsAuth(self._bedrock_auth_config) - - intent = request.extensions.get(_BEDROCK_AUTH_INTENT_EXTENSION, _BEDROCK_AUTH_INTENT_DEFAULT) - if intent == _BEDROCK_AUTH_INTENT_OMIT: - for header in _AWS_SIGNING_HEADERS: - request.headers.pop(header, None) - return - if intent == _BEDROCK_AUTH_INTENT_OVERRIDE or "Authorization" in request.headers: - return - if not _same_origin(request.url, self.base_url): - raise OpenAIError("Refusing to sign a Bedrock request for an origin other than the configured `base_url`.") - - signed_headers = self._bedrock_aws_auth.sign( - method=request.method, - url=str(request.url), - headers=dict(request.headers), - body=_body_for_signing(request), - ) - request.headers.clear() - request.headers.update(signed_headers) - - @override - def _send_request( - self, - request: httpx.Request, - *, - stream: bool, - **kwargs: Unpack[HttpxSendArgs], - ) -> httpx.Response: - if self._uses_aws_auth(): - kwargs["auth"] = httpx.Auth() - return super()._send_request(request, stream=stream, **kwargs) - @override def copy( self, @@ -477,6 +493,7 @@ def copy( api_key: str | BedrockTokenProvider | None = None, admin_api_key: str | None = None, workload_identity: WorkloadIdentity | None = None, + provider: _Provider | None | NotGiven = NOT_GIVEN, bedrock_token_provider: BedrockTokenProvider | None = None, aws_region: str | None = None, aws_profile: str | None = None, @@ -499,111 +516,46 @@ def copy( _enforce_credentials: bool | None = None, _extra_kwargs: Mapping[str, Any] = {}, ) -> Self: - if default_headers is not None and set_default_headers is not None: - raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive") - - if default_query is not None and set_default_query is not None: - raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive") - if callable(api_key): raise OpenAIError("Pass refreshable Bedrock credentials via `bedrock_token_provider`, not `api_key`.") - + if not isinstance(provider, NotGiven): + raise OpenAIError("Configure `provider` on `OpenAI`, not on `BedrockOpenAI.with_options()`.") if admin_api_key is not None or workload_identity is not None: raise OpenAIError("BedrockOpenAI only supports Bedrock bearer token or AWS credential authentication.") - - if api_key is not None and bedrock_token_provider is not None: - raise OpenAIError( - "Bedrock authentication is ambiguous. Configure exactly one explicit mode: bearer credential, " - "static AWS credentials, profile, or credential provider." - ) + if default_headers is not None and set_default_headers is not None: + raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive") + if default_query is not None and set_default_query is not None: + raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive") headers = self._custom_headers if default_headers is not None: headers = {**headers, **default_headers} elif set_default_headers is not None: headers = set_default_headers - params = self._custom_query if default_query is not None: params = {**params, **default_query} elif set_default_query is not None: params = set_default_query - aws_auth_override = _has_explicit_aws_auth( + provider_kwargs, inherited_provider, inherited_state = _copy_configuration( + self, + api_key=api_key, + token_provider=bedrock_token_provider, + aws_region=aws_region, aws_profile=aws_profile, aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token, aws_credentials_provider=aws_credentials_provider, + base_url=base_url, ) - if (api_key is not None or bedrock_token_provider is not None) and aws_auth_override: - raise OpenAIError( - "Bedrock authentication is ambiguous. Configure exactly one explicit mode: bearer credential, " - "static AWS credentials, profile, or credential provider." - ) - auth_override = api_key is not None or bedrock_token_provider is not None or aws_auth_override - if api_key is not None or aws_auth_override: - next_token_provider = None - elif bedrock_token_provider is not None: - next_token_provider = bedrock_token_provider - else: - next_token_provider = self._bedrock_token_provider - - next_auth_config: _BedrockBearerAuthConfig | _BedrockAwsAuthConfig | None - if auth_override: - next_auth_config = None - elif isinstance(self._bedrock_auth_config, _BedrockAwsAuthConfig) and self.api_key: - # The legacy module client allows a module-level API key to replace - # its construction-time default AWS authentication. - next_auth_config = None - elif aws_region is not None: - if isinstance(self._bedrock_auth_config, _BedrockAwsAuthConfig): - next_auth_config = replace( - self._bedrock_auth_config, - region=_resolve_aws_region(aws_region), - region_source="explicit", - ) - else: - next_auth_config = replace(self._bedrock_auth_config, region_source="explicit") - else: - next_auth_config = self._bedrock_auth_config - - next_aws_region = aws_region if aws_region is not None else self.aws_region - if aws_profile is not None and aws_region is None and self._bedrock_auth_config.region_source != "explicit": - next_aws_region = None - - next_api_key = api_key - if next_api_key is None and next_token_provider is None: - next_api_key = ( - None if aws_auth_override or isinstance(next_auth_config, _BedrockAwsAuthConfig) else self.api_key - ) - - blank_base_url_override = isinstance(base_url, str) and not base_url.strip() - next_base_url = None if blank_base_url_override else base_url - next_base_url_is_region_derived = False - recompute_region_base_url = self._uses_region_derived_base_url and ( - aws_region is not None or (aws_profile is not None and next_aws_region is None) - ) - if blank_base_url_override: - next_base_url_is_region_derived = _uses_region_derived_bedrock_base_url(None) - elif next_base_url is None and not recompute_region_base_url: - next_base_url = self.base_url - next_base_url_is_region_derived = self._uses_region_derived_base_url - elif next_base_url is None and next_aws_region is not None: - next_base_url = f"https://bedrock-mantle.{next_aws_region}.api.aws/openai/v1" - next_base_url_is_region_derived = True - elif next_base_url is None: - next_base_url_is_region_derived = True - constructor_kwargs: dict[str, Any] = { - "api_key": next_api_key, - "bedrock_token_provider": next_token_provider, - "aws_region": next_aws_region, + **provider_kwargs, "organization": organization if organization is not None else self.organization, "project": project if project is not None else self.project, "webhook_secret": webhook_secret if webhook_secret is not None else self.webhook_secret, "websocket_base_url": websocket_base_url if websocket_base_url is not None else self.websocket_base_url, - "base_url": next_base_url, "timeout": self.timeout if isinstance(timeout, NotGiven) else timeout, "http_client": http_client or self._client, "max_retries": max_retries if is_given(max_retries) else self.max_retries, @@ -612,48 +564,45 @@ def copy( "_enforce_credentials": True if _enforce_credentials is None else _enforce_credentials, **_extra_kwargs, } - aws_overrides = { - "aws_profile": aws_profile, - "aws_access_key_id": aws_access_key_id, - "aws_secret_access_key": aws_secret_access_key, - "aws_session_token": aws_session_token, - "aws_credentials_provider": aws_credentials_provider, - } - constructor_kwargs.update({name: value for name, value in aws_overrides.items() if value is not None}) - - supports_auth_config = _constructor_accepts_keyword(self.__class__.__init__, "_auth_config") - supports_base_url_provenance = _constructor_accepts_keyword( - self.__class__.__init__, "_base_url_is_region_derived" - ) - if supports_auth_config: - constructor_kwargs["_auth_config"] = next_auth_config - if supports_base_url_provenance: - constructor_kwargs["_base_url_is_region_derived"] = next_base_url_is_region_derived - - copied = self.__class__(**constructor_kwargs) - if not supports_auth_config and next_auth_config is not None: - copied._bedrock_auth_config = next_auth_config - if isinstance(next_auth_config, _BedrockAwsAuthConfig): - copied._bedrock_aws_auth = _BedrockAwsAuth(next_auth_config) - copied._bedrock_token_provider = None - copied.api_key = "" - copied._api_key_provider = None - copied.aws_region = next_auth_config.region - if not supports_base_url_provenance: - copied._uses_region_derived_base_url = next_base_url_is_region_derived - - return copied + if inherited_provider is not None and _constructor_accepts_keyword(self.__class__.__init__, "_provider"): + constructor_kwargs["_provider"] = inherited_provider + constructor_kwargs["_state"] = inherited_state + elif inherited_provider is not None: + constructor_kwargs.update( + api_key=self._bedrock_state.explicit_api_key or self._bedrock_state.environment_bearer_token, + bedrock_token_provider=self._bedrock_state.token_provider, + aws_region=self._bedrock_state.aws_region, + aws_profile=self._bedrock_state.aws_profile, + aws_access_key_id=self._bedrock_state.aws_access_key_id, + aws_secret_access_key=self._bedrock_state.aws_secret_access_key, + aws_session_token=self._bedrock_state.aws_session_token, + aws_credentials_provider=self._bedrock_state.aws_credentials_provider, + base_url="" if self._bedrock_state.uses_region_derived_base_url else self.base_url, + ) + constructor_kwargs = { + name: value + for name, value in constructor_kwargs.items() + if _constructor_accepts_keyword(self.__class__.__init__, name) + } + elif self.__class__ is not BedrockOpenAI: + constructor_kwargs = { + name: value + for name, value in constructor_kwargs.items() + if value is not None or _constructor_accepts_keyword(self.__class__.__init__, name) + } + return self.__class__(**constructor_kwargs) with_options = copy class AsyncBedrockOpenAI(AsyncOpenAI): - """Async API client for Amazon Bedrock's OpenAI-compatible endpoint.""" + """Async compatibility client for Amazon Bedrock's OpenAI-compatible endpoint.""" + _bedrock_provider: _Provider + _bedrock_state: _LegacyBedrockState _bedrock_token_provider: AsyncBedrockTokenProvider | None - _bedrock_auth_config: _BedrockBearerAuthConfig | _BedrockAwsAuthConfig - _bedrock_aws_auth: _BedrockAwsAuth | None _uses_region_derived_base_url: bool + _bedrock_runtime_signature: _LegacyRuntimeSignature aws_region: str | None def __init__( @@ -679,69 +628,33 @@ def __init__( http_client: httpx.AsyncClient | None = None, _strict_response_validation: bool = False, _enforce_credentials: bool = True, - _auth_config: _BedrockBearerAuthConfig | _BedrockAwsAuthConfig | None = None, - _base_url_is_region_derived: bool | None = None, + _provider: _Provider | None = None, + _state: _LegacyBedrockState | None = None, ) -> None: - """Construct a new asynchronous Amazon Bedrock client instance. - - This automatically infers the following arguments from their corresponding environment variables if they are not provided: - - bearer authentication from `AWS_BEARER_TOKEN_BEDROCK` - - `aws_region` from `AWS_REGION` or `AWS_DEFAULT_REGION` when `base_url` and `AWS_BEDROCK_BASE_URL` are not set - - `base_url` from `AWS_BEDROCK_BASE_URL` - - `bedrock_token_provider` is invoked before each request when provided. When no bearer token is configured, - the client uses the standard AWS credential chain and SigV4 authentication. - """ - if callable(cast(object, api_key)): - raise OpenAIError("Pass refreshable Bedrock credentials via `bedrock_token_provider`, not `api_key`.") - - if api_key == "": - raise OpenAIError("The `api_key` argument must not be empty.") - - if api_key is not None and bedrock_token_provider is not None: - raise OpenAIError( - "Bedrock authentication is ambiguous. Configure exactly one explicit mode: bearer credential, " - "static AWS credentials, profile, or credential provider." + if _provider is None or _state is None: + _provider, _state, public_api_key = _legacy_provider( + api_key=api_key, + token_provider=bedrock_token_provider, + aws_region=aws_region, + aws_profile=aws_profile, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + aws_credentials_provider=aws_credentials_provider, + base_url=base_url, + ) + else: + public_api_key = ( + _state.explicit_api_key + or (_state.environment_bearer_token if _state.uses_environment_bearer else "") + or "" ) - - auth_config, aws_auth, api_key, resolved_region = _resolve_bedrock_auth( - api_key=api_key, - token_provider=bedrock_token_provider, - aws_region=aws_region, - aws_profile=aws_profile, - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - aws_session_token=aws_session_token, - aws_credentials_provider=aws_credentials_provider, - auth_config=_auth_config, - enforce_credentials=_enforce_credentials, - ) - - self._bedrock_token_provider = bedrock_token_provider - self._bedrock_auth_config = auth_config - self._bedrock_aws_auth = aws_auth - self._uses_region_derived_base_url = ( - _uses_region_derived_bedrock_base_url(base_url) - if _base_url_is_region_derived is None - else _base_url_is_region_derived - ) - self.aws_region = resolved_region super().__init__( - api_key=( - _async_bedrock_token_provider(bedrock_token_provider) - if bedrock_token_provider is not None - else api_key or "" - ), - admin_api_key="", + provider=_provider, organization=organization, project=project, webhook_secret=webhook_secret, - base_url=_resolve_bedrock_base_url( - base_url, - resolved_region, - use_environment=_base_url_is_region_derived is not True, - ), websocket_base_url=websocket_base_url, timeout=timeout, max_retries=max_retries, @@ -752,101 +665,53 @@ def __init__( _enforce_credentials=False, ) + self._bedrock_provider = _provider + self._bedrock_state = _state + self._bedrock_token_provider = cast("AsyncBedrockTokenProvider | None", _state.token_provider) + self._uses_region_derived_base_url = _state.uses_region_derived_base_url + canonical_region = re.fullmatch(r"bedrock-mantle\.([a-z0-9-]+)\.api\.aws", self.base_url.host) + self.aws_region = _state.aws_region or (canonical_region.group(1) if canonical_region is not None else None) + self.api_key = public_api_key or "" + self._bedrock_runtime_signature = _legacy_runtime_signature(self, self._legacy_auth_configuration()) + + def _legacy_auth_configuration(self) -> _LegacyAuthConfiguration: + if self._bedrock_token_provider is not None: + return ("token_provider", self._bedrock_token_provider) + if ( + self._bedrock_state.explicit_api_key is not None + or self._bedrock_state.uses_environment_bearer + or self.api_key + ): + return ("bearer", self.api_key) + return ("aws", None) + def _uses_aws_auth(self) -> bool: return ( - isinstance(self._bedrock_auth_config, _BedrockAwsAuthConfig) + self._bedrock_state.explicit_api_key is None and not self.api_key - and self._api_key_provider is None + and self._bedrock_token_provider is None + and not self._bedrock_state.uses_environment_bearer ) @override - def _auth_headers(self, security: SecurityOptions) -> dict[str, str]: - if self._uses_aws_auth(): - return {} - - if security.get("bearer_auth", False) or security.get("admin_api_key_auth", False): - return self._bearer_auth - - return {} - - @override - def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None: - if self._uses_aws_auth(): - return - - super()._validate_headers(headers, custom_headers) + async def _refresh_api_key(self) -> str: + if self._bedrock_state.uses_environment_bearer: + captured = self._bedrock_state.environment_bearer_token or "" + return self.api_key if self.api_key and self.api_key != captured else captured + if self._bedrock_token_provider is not None: + token = cast(object, self._bedrock_token_provider()) + if inspect.isawaitable(token): + token = await token + if not isinstance(token, str) or not token: + raise ValueError("Expected `bedrock_token_provider` argument to return a non-empty string.") + return token + return self.api_key @override async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: - if self._uses_aws_auth(): - if options.follow_redirects: - raise OpenAIError( - "Bedrock SigV4 authentication does not support automatic redirects. " - "Send a new request to the redirect target so it can be signed again." - ) - options.follow_redirects = False - elif ( - self._api_key_provider is not None - and options.security.get("admin_api_key_auth", False) - and not options.security.get("bearer_auth", False) - ): - await self._refresh_api_key() - + _refresh_legacy_provider_runtime(self) return await super()._prepare_options(options) - @override - def _build_request(self, options: FinalRequestOptions, *, retries_taken: int = 0) -> httpx.Request: - request = super()._build_request(options, retries_taken=retries_taken) - if not self._uses_aws_auth(): - return request - - option_headers: Headers = options.headers if is_given(options.headers) else {} - request.extensions[_BEDROCK_AUTH_INTENT_EXTENSION] = _authorization_intent( - self._custom_headers, - option_headers, - ) - request.extensions[_BEDROCK_MAX_RETRIES_EXTENSION] = options.get_max_retries(self.max_retries) - return request - - @override - async def _prepare_request(self, request: httpx.Request) -> None: - if not self._uses_aws_auth(): - return - if self._bedrock_aws_auth is None: - assert isinstance(self._bedrock_auth_config, _BedrockAwsAuthConfig) - self._bedrock_aws_auth = await asyncify(_BedrockAwsAuth)(self._bedrock_auth_config) - - intent = request.extensions.get(_BEDROCK_AUTH_INTENT_EXTENSION, _BEDROCK_AUTH_INTENT_DEFAULT) - if intent == _BEDROCK_AUTH_INTENT_OMIT: - for header in _AWS_SIGNING_HEADERS: - request.headers.pop(header, None) - return - if intent == _BEDROCK_AUTH_INTENT_OVERRIDE or "Authorization" in request.headers: - return - if not _same_origin(request.url, self.base_url): - raise OpenAIError("Refusing to sign a Bedrock request for an origin other than the configured `base_url`.") - - signed_headers = await asyncify(self._bedrock_aws_auth.sign)( - method=request.method, - url=str(request.url), - headers=dict(request.headers), - body=_body_for_signing(request), - ) - request.headers.clear() - request.headers.update(signed_headers) - - @override - async def _send_request( - self, - request: httpx.Request, - *, - stream: bool, - **kwargs: Unpack[HttpxSendArgs], - ) -> httpx.Response: - if self._uses_aws_auth(): - kwargs["auth"] = httpx.Auth() - return await super()._send_request(request, stream=stream, **kwargs) - @override def copy( self, @@ -854,6 +719,7 @@ def copy( api_key: str | AsyncBedrockTokenProvider | None = None, admin_api_key: str | None = None, workload_identity: WorkloadIdentity | None = None, + provider: _Provider | None | NotGiven = NOT_GIVEN, bedrock_token_provider: AsyncBedrockTokenProvider | None = None, aws_region: str | None = None, aws_profile: str | None = None, @@ -876,109 +742,46 @@ def copy( _enforce_credentials: bool | None = None, _extra_kwargs: Mapping[str, Any] = {}, ) -> Self: - if default_headers is not None and set_default_headers is not None: - raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive") - - if default_query is not None and set_default_query is not None: - raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive") - if callable(api_key): raise OpenAIError("Pass refreshable Bedrock credentials via `bedrock_token_provider`, not `api_key`.") - + if not isinstance(provider, NotGiven): + raise OpenAIError("Configure `provider` on `AsyncOpenAI`, not on `AsyncBedrockOpenAI.with_options()`.") if admin_api_key is not None or workload_identity is not None: raise OpenAIError("AsyncBedrockOpenAI only supports Bedrock bearer token or AWS credential authentication.") - - if api_key is not None and bedrock_token_provider is not None: - raise OpenAIError( - "Bedrock authentication is ambiguous. Configure exactly one explicit mode: bearer credential, " - "static AWS credentials, profile, or credential provider." - ) + if default_headers is not None and set_default_headers is not None: + raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive") + if default_query is not None and set_default_query is not None: + raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive") headers = self._custom_headers if default_headers is not None: headers = {**headers, **default_headers} elif set_default_headers is not None: headers = set_default_headers - params = self._custom_query if default_query is not None: params = {**params, **default_query} elif set_default_query is not None: params = set_default_query - aws_auth_override = _has_explicit_aws_auth( + provider_kwargs, inherited_provider, inherited_state = _copy_configuration( + self, + api_key=api_key, + token_provider=bedrock_token_provider, + aws_region=aws_region, aws_profile=aws_profile, aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token, aws_credentials_provider=aws_credentials_provider, + base_url=base_url, ) - if (api_key is not None or bedrock_token_provider is not None) and aws_auth_override: - raise OpenAIError( - "Bedrock authentication is ambiguous. Configure exactly one explicit mode: bearer credential, " - "static AWS credentials, profile, or credential provider." - ) - auth_override = api_key is not None or bedrock_token_provider is not None or aws_auth_override - if api_key is not None or aws_auth_override: - next_token_provider = None - elif bedrock_token_provider is not None: - next_token_provider = bedrock_token_provider - else: - next_token_provider = self._bedrock_token_provider - - next_auth_config: _BedrockBearerAuthConfig | _BedrockAwsAuthConfig | None - if auth_override: - next_auth_config = None - elif isinstance(self._bedrock_auth_config, _BedrockAwsAuthConfig) and self.api_key: - next_auth_config = None - elif aws_region is not None: - if isinstance(self._bedrock_auth_config, _BedrockAwsAuthConfig): - next_auth_config = replace( - self._bedrock_auth_config, - region=_resolve_aws_region(aws_region), - region_source="explicit", - ) - else: - next_auth_config = replace(self._bedrock_auth_config, region_source="explicit") - else: - next_auth_config = self._bedrock_auth_config - - next_aws_region = aws_region if aws_region is not None else self.aws_region - if aws_profile is not None and aws_region is None and self._bedrock_auth_config.region_source != "explicit": - next_aws_region = None - - next_api_key = api_key - if next_api_key is None and next_token_provider is None: - next_api_key = ( - None if aws_auth_override or isinstance(next_auth_config, _BedrockAwsAuthConfig) else self.api_key - ) - - blank_base_url_override = isinstance(base_url, str) and not base_url.strip() - next_base_url = None if blank_base_url_override else base_url - next_base_url_is_region_derived = False - recompute_region_base_url = self._uses_region_derived_base_url and ( - aws_region is not None or (aws_profile is not None and next_aws_region is None) - ) - if blank_base_url_override: - next_base_url_is_region_derived = _uses_region_derived_bedrock_base_url(None) - elif next_base_url is None and not recompute_region_base_url: - next_base_url = self.base_url - next_base_url_is_region_derived = self._uses_region_derived_base_url - elif next_base_url is None and next_aws_region is not None: - next_base_url = f"https://bedrock-mantle.{next_aws_region}.api.aws/openai/v1" - next_base_url_is_region_derived = True - elif next_base_url is None: - next_base_url_is_region_derived = True - constructor_kwargs: dict[str, Any] = { - "api_key": next_api_key, - "bedrock_token_provider": next_token_provider, - "aws_region": next_aws_region, + **provider_kwargs, "organization": organization if organization is not None else self.organization, "project": project if project is not None else self.project, "webhook_secret": webhook_secret if webhook_secret is not None else self.webhook_secret, "websocket_base_url": websocket_base_url if websocket_base_url is not None else self.websocket_base_url, - "base_url": next_base_url, "timeout": self.timeout if isinstance(timeout, NotGiven) else timeout, "http_client": http_client or self._client, "max_retries": max_retries if is_given(max_retries) else self.max_retries, @@ -987,36 +790,41 @@ def copy( "_enforce_credentials": True if _enforce_credentials is None else _enforce_credentials, **_extra_kwargs, } - aws_overrides = { - "aws_profile": aws_profile, - "aws_access_key_id": aws_access_key_id, - "aws_secret_access_key": aws_secret_access_key, - "aws_session_token": aws_session_token, - "aws_credentials_provider": aws_credentials_provider, - } - constructor_kwargs.update({name: value for name, value in aws_overrides.items() if value is not None}) - - supports_auth_config = _constructor_accepts_keyword(self.__class__.__init__, "_auth_config") - supports_base_url_provenance = _constructor_accepts_keyword( - self.__class__.__init__, "_base_url_is_region_derived" - ) - if supports_auth_config: - constructor_kwargs["_auth_config"] = next_auth_config - if supports_base_url_provenance: - constructor_kwargs["_base_url_is_region_derived"] = next_base_url_is_region_derived - - copied = self.__class__(**constructor_kwargs) - if not supports_auth_config and next_auth_config is not None: - copied._bedrock_auth_config = next_auth_config - if isinstance(next_auth_config, _BedrockAwsAuthConfig): - copied._bedrock_aws_auth = _BedrockAwsAuth(next_auth_config) - copied._bedrock_token_provider = None - copied.api_key = "" - copied._api_key_provider = None - copied.aws_region = next_auth_config.region - if not supports_base_url_provenance: - copied._uses_region_derived_base_url = next_base_url_is_region_derived - - return copied + if inherited_provider is not None and _constructor_accepts_keyword(self.__class__.__init__, "_provider"): + constructor_kwargs["_provider"] = inherited_provider + constructor_kwargs["_state"] = inherited_state + elif inherited_provider is not None: + constructor_kwargs.update( + api_key=self._bedrock_state.explicit_api_key or self._bedrock_state.environment_bearer_token, + bedrock_token_provider=self._bedrock_state.token_provider, + aws_region=self._bedrock_state.aws_region, + aws_profile=self._bedrock_state.aws_profile, + aws_access_key_id=self._bedrock_state.aws_access_key_id, + aws_secret_access_key=self._bedrock_state.aws_secret_access_key, + aws_session_token=self._bedrock_state.aws_session_token, + aws_credentials_provider=self._bedrock_state.aws_credentials_provider, + base_url="" if self._bedrock_state.uses_region_derived_base_url else self.base_url, + ) + constructor_kwargs = { + name: value + for name, value in constructor_kwargs.items() + if _constructor_accepts_keyword(self.__class__.__init__, name) + } + elif self.__class__ is not AsyncBedrockOpenAI: + constructor_kwargs = { + name: value + for name, value in constructor_kwargs.items() + if value is not None or _constructor_accepts_keyword(self.__class__.__init__, name) + } + return self.__class__(**constructor_kwargs) with_options = copy + + +__all__ = [ + "BedrockOpenAI", + "AsyncBedrockOpenAI", + "BedrockTokenProvider", + "AsyncBedrockTokenProvider", + "AwsCredentialsProvider", +] diff --git a/src/openai/providers/__init__.py b/src/openai/providers/__init__.py new file mode 100644 index 0000000000..bb5bcdbd9e --- /dev/null +++ b/src/openai/providers/__init__.py @@ -0,0 +1,3 @@ +from .bedrock import bedrock as bedrock + +__all__ = ["bedrock"] diff --git a/src/openai/providers/bedrock.py b/src/openai/providers/bedrock.py new file mode 100644 index 0000000000..e5cc5268ab --- /dev/null +++ b/src/openai/providers/bedrock.py @@ -0,0 +1,404 @@ +from __future__ import annotations + +import os +import re +import inspect +from typing import Literal, Callable, Awaitable, cast +from dataclasses import field, dataclass + +import httpx + +from .._types import NOT_GIVEN, NotGiven +from .._utils import asyncify +from .._models import FinalRequestOptions +from .._provider import _Provider, _create_provider, _ProviderRuntime +from .._exceptions import OpenAIError +from ..lib._bedrock_auth import ( + BedrockAwsAuth, + BedrockAwsAuthConfig, + AwsCredentialsProvider, +) + +BedrockTokenProvider = Callable[[], "str | Awaitable[str]"] + +_AWS_SIGNING_HEADERS = ("authorization", "x-amz-content-sha256", "x-amz-date", "x-amz-security-token") +_CANONICAL_BEDROCK_HOST = re.compile(r"^bedrock-mantle\.([a-z0-9-]+)\.api\.aws$", re.IGNORECASE) + + +def _normalize_optional_string(value: str | None) -> str | None: + if value is None: + return None + + normalized = value.strip() + return normalized or None + + +def _normalize_base_url(base_url: str | httpx.URL) -> httpx.URL: + url = httpx.URL(base_url) + path = url.path.rstrip("/") + responses_match = re.search(r"/responses(?:/.*)?$", path) + if responses_match is not None: + path = path[: responses_match.start()] + + return url.copy_with(path=path or "/") + + +def _same_origin(left: httpx.URL, right: httpx.URL) -> bool: + return (left.scheme, left.host, left.port) == (right.scheme, right.host, right.port) + + +def _body_for_signing(request: httpx.Request) -> bytes: + try: + return request.content + except httpx.RequestNotRead as exc: + raise OpenAIError( + "Bedrock SigV4 authentication requires a replayable request body. " + "Buffer the body before sending or use bearer authentication." + ) from exc + + +def _assert_provider_owns_authorization(request: httpx.Request) -> None: + if "Authorization" in request.headers: + raise OpenAIError("Bedrock provider authentication cannot be combined with a custom `Authorization` header.") + + +def _without_redirects(options: FinalRequestOptions) -> FinalRequestOptions: + if options.follow_redirects: + raise OpenAIError( + "Bedrock SigV4 authentication does not support automatic redirects. " + "Send a new request to the redirect target so it can be signed again." + ) + options.follow_redirects = False + return options + + +class _BedrockBearerAuth: + def __init__(self, token_provider: BedrockTokenProvider, *, base_url: httpx.URL) -> None: + self._token_provider = token_provider + self._base_url = base_url + + def _validate_request(self, request: httpx.Request) -> None: + _assert_provider_owns_authorization(request) + if not _same_origin(request.url, self._base_url): + raise OpenAIError( + "Refusing to authenticate a Bedrock request for an origin other than the configured provider URL." + ) + + def _resolve_token(self) -> str: + try: + token = cast(object, self._token_provider()) + except OpenAIError: + raise + except Exception as exc: + raise OpenAIError("Failed to resolve a bearer credential for Bedrock.") from exc + + if inspect.isawaitable(token): + close = getattr(token, "close", None) + if callable(close): + close() + raise OpenAIError("An async Bedrock token provider requires `AsyncOpenAI`.") + if not isinstance(token, str) or not token.strip(): + raise OpenAIError("The Bedrock bearer credential provider must return a non-empty string.") + return token + + async def _resolve_token_async(self) -> str: + try: + token = cast(object, self._token_provider()) + if inspect.isawaitable(token): + token = await token + except OpenAIError: + raise + except Exception as exc: + raise OpenAIError("Failed to resolve a bearer credential for Bedrock.") from exc + + if not isinstance(token, str) or not token.strip(): + raise OpenAIError("The Bedrock bearer credential provider must return a non-empty string.") + return token + + def prepare_request(self, request: httpx.Request) -> None: + self._validate_request(request) + request.headers["Authorization"] = f"Bearer {self._resolve_token()}" + + async def prepare_async_request(self, request: httpx.Request) -> None: + self._validate_request(request) + request.headers["Authorization"] = f"Bearer {await self._resolve_token_async()}" + + +class _BedrockSigV4Auth: + def __init__( + self, + *, + config: BedrockAwsAuthConfig, + base_url: httpx.URL, + auth: BedrockAwsAuth | None = None, + ) -> None: + self._config = config + self._base_url = base_url + self._auth = auth + + def _validate_request(self, request: httpx.Request) -> bytes: + _assert_provider_owns_authorization(request) + if not _same_origin(request.url, self._base_url): + raise OpenAIError( + "Refusing to sign a Bedrock request for an origin other than the configured provider URL." + ) + + endpoint_region_match = _CANONICAL_BEDROCK_HOST.fullmatch(request.url.host) + if endpoint_region_match is not None and endpoint_region_match.group(1) != self._config.region: + raise OpenAIError( + f"The Bedrock endpoint region `{endpoint_region_match.group(1)}` does not match the " + f"SigV4 region `{self._config.region}`." + ) + + return _body_for_signing(request) + + def _sign(self, request: httpx.Request, *, auth: BedrockAwsAuth, body: bytes) -> None: + for header in _AWS_SIGNING_HEADERS: + request.headers.pop(header, None) + + signed_headers = auth.sign( + method=request.method, + url=str(request.url), + headers=dict(request.headers), + body=body, + ) + request.headers.clear() + request.headers.update(signed_headers) + + def prepare_request(self, request: httpx.Request) -> None: + body = self._validate_request(request) + if self._auth is None: + self._auth = BedrockAwsAuth(self._config) + self._sign(request, auth=self._auth, body=body) + + async def prepare_async_request(self, request: httpx.Request) -> None: + body = self._validate_request(request) + if self._auth is None: + self._auth = await asyncify(BedrockAwsAuth)(self._config) + + signed_headers = await asyncify(self._auth.sign)( + method=request.method, + url=str(request.url), + headers={ + name: value for name, value in request.headers.items() if name.lower() not in _AWS_SIGNING_HEADERS + }, + body=body, + ) + request.headers.clear() + request.headers.update(signed_headers) + + +@dataclass(frozen=True) +class _BedrockProviderDefinition: + configured_region: str | None + region_source: Literal["explicit", "environment"] | None + configured_base_url: httpx.URL | None + api_key: str | None = field(default=None, repr=False) + token_provider: BedrockTokenProvider | None = field(default=None, repr=False, compare=False) + use_environment_bearer: bool = False + profile: str | None = None + access_key_id: str | None = field(default=None, repr=False) + secret_access_key: str | None = field(default=None, repr=False) + session_token: str | None = field(default=None, repr=False) + credential_provider: AwsCredentialsProvider | None = field(default=None, repr=False, compare=False) + name: str = field(default="bedrock", init=False) + + def _aws_source(self) -> Literal["static", "profile", "provider", "default"]: + if self.access_key_id is not None: + return "static" + if self.profile is not None: + return "profile" + if self.credential_provider is not None: + return "provider" + return "default" + + def _resolve_aws_auth(self) -> tuple[BedrockAwsAuthConfig, BedrockAwsAuth | None]: + if self.configured_region is not None: + return ( + BedrockAwsAuthConfig( + region=self.configured_region, + source=self._aws_source(), + region_source=self.region_source or "explicit", + profile=self.profile, + access_key_id=self.access_key_id, + secret_access_key=self.secret_access_key, + session_token=self.session_token, + credentials_provider=self.credential_provider, + ), + None, + ) + + auth = BedrockAwsAuth.resolve( + region=None, + profile=self.profile, + access_key_id=self.access_key_id, + secret_access_key=self.secret_access_key, + session_token=self.session_token, + credentials_provider=self.credential_provider, + ) + return auth.config, auth + + def configure(self) -> _ProviderRuntime: + def environment_token() -> str: + token = os.environ.get("AWS_BEARER_TOKEN_BEDROCK") + if not token: + raise OpenAIError( + "Could not find credentials for Bedrock. Pass a bearer credential or AWS credentials to " + "`bedrock(...)`, set `AWS_BEARER_TOKEN_BEDROCK`, or configure the default AWS credential chain." + ) + return token + + auth: _BedrockBearerAuth | _BedrockSigV4Auth | None = None + bearer_provider: BedrockTokenProvider | None = None + if self.api_key is not None: + bearer_provider = lambda: self.api_key or "" + region = self.configured_region + elif self.token_provider is not None: + bearer_provider = self.token_provider + region = self.configured_region + elif self.use_environment_bearer: + bearer_provider = environment_token + region = self.configured_region + else: + aws_config, aws_auth = self._resolve_aws_auth() + region = aws_config.region + base_url = self.configured_base_url or _normalize_base_url( + f"https://bedrock-mantle.{region}.api.aws/openai/v1" + ) + auth = _BedrockSigV4Auth(config=aws_config, base_url=base_url, auth=aws_auth) + + if self.configured_base_url is not None: + base_url = self.configured_base_url + elif region is not None: + base_url = _normalize_base_url(f"https://bedrock-mantle.{region}.api.aws/openai/v1") + else: + raise OpenAIError( + "Bedrock requires an AWS region. Pass `region` to `bedrock(...)`, or set `AWS_REGION` or " + "`AWS_DEFAULT_REGION`." + ) + + if bearer_provider is not None: + auth = _BedrockBearerAuth(bearer_provider, base_url=base_url) + + assert auth is not None + if isinstance(auth, _BedrockSigV4Auth): + return _ProviderRuntime( + name=self.name, + base_url=base_url, + transform_request=_without_redirects, + prepare_request=auth.prepare_request, + prepare_async_request=auth.prepare_async_request, + ) + + return _ProviderRuntime( + name=self.name, + base_url=base_url, + prepare_request=auth.prepare_request, + prepare_async_request=auth.prepare_async_request, + ) + + +def bedrock( + *, + region: str | None = None, + base_url: str | httpx.URL | None | NotGiven = NOT_GIVEN, + api_key: str | None | NotGiven = NOT_GIVEN, + token_provider: BedrockTokenProvider | None = None, + access_key_id: str | None = None, + secret_access_key: str | None = None, + session_token: str | None = None, + profile: str | None = None, + credential_provider: AwsCredentialsProvider | None = None, +) -> _Provider: + """Configure the standard OpenAI client for Amazon Bedrock Mantle.""" + + normalized_region = _normalize_optional_string(region) + if region is not None and normalized_region is None: + raise OpenAIError("The Bedrock AWS `region` must not be empty.") + + region_source: Literal["explicit", "environment"] | None = None + if normalized_region is not None: + region_source = "explicit" + else: + normalized_region = _normalize_optional_string( + os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION") + ) + if normalized_region is not None: + region_source = "environment" + + configured_base_url: httpx.URL | None + if isinstance(base_url, NotGiven): + environment_base_url = _normalize_optional_string(os.environ.get("AWS_BEDROCK_BASE_URL")) + configured_base_url = _normalize_base_url(environment_base_url) if environment_base_url else None + elif base_url is None: + configured_base_url = None + else: + if isinstance(base_url, str) and not base_url.strip(): + raise OpenAIError("The Bedrock `base_url` must not be empty.") + configured_base_url = _normalize_base_url(base_url) + + normalized_profile = _normalize_optional_string(profile) + if profile is not None and normalized_profile is None: + raise OpenAIError("The Bedrock AWS `profile` must not be empty.") + + if (access_key_id is None) != (secret_access_key is None) or (session_token is not None and access_key_id is None): + raise OpenAIError( + "Static AWS credentials require both `access_key_id` and `secret_access_key`. " + "A `session_token` may only be used with both." + ) + if access_key_id is not None and (not access_key_id.strip() or not cast(str, secret_access_key).strip()): + raise OpenAIError("Static AWS credentials require non-empty `access_key_id` and `secret_access_key` values.") + if session_token is not None and not session_token.strip(): + raise OpenAIError("A static AWS `session_token` must not be empty when provided.") + + explicit_api_key = not isinstance(api_key, NotGiven) and api_key is not None + if explicit_api_key and (not isinstance(api_key, str) or not api_key.strip()): + raise OpenAIError("The Bedrock bearer credential must not be empty.") + if explicit_api_key and token_provider is not None: + raise OpenAIError("The `api_key` and `token_provider` options are mutually exclusive. Configure only one.") + + explicit_bearer = explicit_api_key or token_provider is not None + aws_modes = sum( + ( + access_key_id is not None, + normalized_profile is not None, + credential_provider is not None, + ) + ) + if aws_modes > 1: + raise OpenAIError( + "Bedrock authentication is ambiguous. Configure exactly one explicit AWS mode: static credentials, " + "profile, or credential provider." + ) + if explicit_bearer and aws_modes: + raise OpenAIError( + "Bedrock authentication is ambiguous. Configure exactly one explicit mode: bearer credential, " + "static AWS credentials, profile, or credential provider." + ) + + skip_environment_bearer = not isinstance(api_key, NotGiven) and api_key is None + use_environment_bearer = ( + not explicit_bearer + and not aws_modes + and not skip_environment_bearer + and bool(os.environ.get("AWS_BEARER_TOKEN_BEDROCK")) + ) + + return _create_provider( + _BedrockProviderDefinition( + configured_region=normalized_region, + region_source=region_source, + configured_base_url=configured_base_url, + api_key=cast("str | None", api_key) if explicit_api_key else None, + token_provider=token_provider, + use_environment_bearer=use_environment_bearer, + profile=normalized_profile, + access_key_id=access_key_id, + secret_access_key=secret_access_key, + session_token=session_token, + credential_provider=credential_provider, + ) + ) + + +__all__ = ["bedrock", "BedrockTokenProvider", "AwsCredentialsProvider"] diff --git a/tests/fixtures/bedrock/v1/sigv4.json b/tests/fixtures/bedrock/v1/sigv4.json new file mode 100644 index 0000000000..e0e552ae1c --- /dev/null +++ b/tests/fixtures/bedrock/v1/sigv4.json @@ -0,0 +1,22 @@ +{ + "signingDate": "2025-01-02T03:04:05.000Z", + "region": "us-east-1", + "service": "bedrock-mantle", + "credentials": { + "accessKeyId": "AKIDEXAMPLE", + "secretAccessKey": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + "sessionToken": "fixture-session-token" + }, + "request": { + "method": "POST", + "url": "https://bedrock-mantle.us-east-1.api.aws/openai/v1/responses", + "body": "{\"model\":\"gpt-4o\",\"input\":\"hello\"}", + "contentType": "application/json" + }, + "expected": { + "date": "20250102T030405Z", + "payloadHash": "50329e51ad520f21b77bad0b01999930ff556cd1bf18434701251ba6c9f877bc", + "canonicalRequestHash": "1b69b17ef7548a7bf16a6ee749acfd44b7793e04216345dddd6cfaf4c01bfde5", + "authorization": "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20250102/us-east-1/bedrock-mantle/aws4_request, SignedHeaders=content-type;host;x-amz-content-sha256;x-amz-date;x-amz-security-token, Signature=dc20c8fbe516daf0ccae7e5b7a78fc2936870413a9f1af6e1b2d44b970ce411f" + } +} diff --git a/tests/fixtures/bedrock_auth/v1/cases.json b/tests/fixtures/bedrock_auth/v1/cases.json index a53e68165a..2d05c8827e 100644 --- a/tests/fixtures/bedrock_auth/v1/cases.json +++ b/tests/fixtures/bedrock_auth/v1/cases.json @@ -6,11 +6,18 @@ "id": "auth.explicit-bearer", "kind": "auth_selection", "given": { - "explicit": { "bearer": "explicit-bearer-token" }, - "environment": { "AWS_BEARER_TOKEN_BEDROCK": "environment-bearer-token" }, + "explicit": { + "bearer": "explicit-bearer-token" + }, + "environment": { + "AWS_BEARER_TOKEN_BEDROCK": "environment-bearer-token" + }, "default_chain_available": true }, - "expected": { "auth_mode": "bearer", "auth_source": "explicit" } + "expected": { + "auth_mode": "bearer", + "auth_source": "explicit" + } }, { "id": "auth.explicit-aws-over-environment-bearer", @@ -22,20 +29,30 @@ "secret_access_key": "fixture-secret-access-key" } }, - "environment": { "AWS_BEARER_TOKEN_BEDROCK": "environment-bearer-token" }, + "environment": { + "AWS_BEARER_TOKEN_BEDROCK": "environment-bearer-token" + }, "default_chain_available": true }, - "expected": { "auth_mode": "sigv4", "auth_source": "static" } + "expected": { + "auth_mode": "sigv4", + "auth_source": "static" + } }, { "id": "auth.environment-bearer-over-default-chain", "kind": "auth_selection", "given": { "explicit": {}, - "environment": { "AWS_BEARER_TOKEN_BEDROCK": "environment-bearer-token" }, + "environment": { + "AWS_BEARER_TOKEN_BEDROCK": "environment-bearer-token" + }, "default_chain_available": true }, - "expected": { "auth_mode": "bearer", "auth_source": "environment" } + "expected": { + "auth_mode": "bearer", + "auth_source": "environment" + } }, { "id": "auth.default-chain", @@ -45,17 +62,25 @@ "environment": {}, "default_chain_available": true }, - "expected": { "auth_mode": "sigv4", "auth_source": "default" } + "expected": { + "auth_mode": "sigv4", + "auth_source": "default" + } }, { "id": "auth.empty-environment-bearer-is-absent", "kind": "auth_selection", "given": { "explicit": {}, - "environment": { "AWS_BEARER_TOKEN_BEDROCK": "" }, + "environment": { + "AWS_BEARER_TOKEN_BEDROCK": "" + }, "default_chain_available": true }, - "expected": { "auth_mode": "sigv4", "auth_source": "default" } + "expected": { + "auth_mode": "sigv4", + "auth_source": "default" + } }, { "id": "auth.conflicting-explicit-modes", @@ -71,7 +96,9 @@ "environment": {}, "default_chain_available": true }, - "expected": { "error": "bedrock_conflicting_auth" } + "expected": { + "error": "bedrock_conflicting_auth" + } }, { "id": "sigv4.responses.static-credentials", @@ -99,10 +126,11 @@ }, "expected": { "payload_sha256": "9d161ca213ef31dd5b8f01f52294bcd5057bc778cf6ee7719e67b7ff43ef1022", - "canonical_request_sha256": "2941daa7f544cd6d05e5c14615ca3ed4fe206a230214a71971092254690c0f1c", + "canonical_request_sha256": "c8da8b3104c310e27f93e2b58560f59a9397cbca97d954f50654e9305dc1c9ab", "headers": { + "x-amz-content-sha256": "9d161ca213ef31dd5b8f01f52294bcd5057bc778cf6ee7719e67b7ff43ef1022", "x-amz-date": "20260601T123456Z", - "authorization": "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20260601/us-east-1/bedrock-mantle/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=85b44442b454238644a1605b04febb4ecf96d2c5a7698db21ce563c0a5646cb6" + "authorization": "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20260601/us-east-1/bedrock-mantle/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-content-sha256;x-amz-date, Signature=bba77a80c7b39301d1aa29db9bc2f2b349deb932ae2dc5f800cec17baaf0f5de" } } }, @@ -133,44 +161,12 @@ }, "expected": { "payload_sha256": "9d161ca213ef31dd5b8f01f52294bcd5057bc778cf6ee7719e67b7ff43ef1022", - "canonical_request_sha256": "f60762d083c88cc0956eb3e9d3b3a966bfe3f396efd49bc1ef7c8922395bcecb", + "canonical_request_sha256": "fe294f3ccfa69899816c7e59516dbfae6bf89959da4d64296aa7e0fcb1f9d130", "headers": { + "x-amz-content-sha256": "9d161ca213ef31dd5b8f01f52294bcd5057bc778cf6ee7719e67b7ff43ef1022", "x-amz-date": "20260601T123456Z", "x-amz-security-token": "fixture-session-token", - "authorization": "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20260601/us-east-1/bedrock-mantle/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-security-token, Signature=d17a0ba1525dfe52b63f163b4a0ac1109723905e5ac9711bcdb6c35befdbaac2" - } - } - }, - { - "id": "sigv4.responses.unsigned-payload", - "kind": "sigv4", - "given": { - "credentials": { - "access_key_id": "AKIDEXAMPLE", - "secret_access_key": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY" - }, - "signing": { - "service": "bedrock-mantle", - "region": "us-east-1", - "timestamp": "2026-06-01T12:34:56Z" - }, - "request": { - "method": "POST", - "url": "https://bedrock-mantle.us-east-1.api.aws/openai/v1/responses", - "headers": { - "content-type": "application/octet-stream", - "host": "bedrock-mantle.us-east-1.api.aws" - }, - "body_mode": "unsigned" - } - }, - "expected": { - "payload_sha256": "UNSIGNED-PAYLOAD", - "canonical_request_sha256": "9d9a089a0d274e69db40db264ea8ec6f9f6afaf03cf779121ec8dfa55278337d", - "headers": { - "x-amz-content-sha256": "UNSIGNED-PAYLOAD", - "x-amz-date": "20260601T123456Z", - "authorization": "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20260601/us-east-1/bedrock-mantle/aws4_request, SignedHeaders=content-type;host;x-amz-content-sha256;x-amz-date, Signature=f0d51439dfe33967eb590c0e33bc26fd7f8103d9345f5ac613ca299e90b555c8" + "authorization": "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20260601/us-east-1/bedrock-mantle/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-content-sha256;x-amz-date;x-amz-security-token, Signature=693a4fef0373c0929f1f71d0da29c748f6bf42e75059648ceb97d52a204fb37c" } } }, @@ -178,15 +174,27 @@ "id": "retry.fresh-credentials-and-time", "kind": "retry_signing", "given": { - "response_statuses": [500, 200], - "timestamps": ["2026-06-01T12:34:56Z", "2026-06-01T12:35:01Z"], - "access_key_ids": ["FIRSTACCESSKEY", "SECONDACCESSKEY"], + "response_statuses": [ + 500, + 200 + ], + "timestamps": [ + "2026-06-01T12:34:56Z", + "2026-06-01T12:35:01Z" + ], + "access_key_ids": [ + "FIRSTACCESSKEY", + "SECONDACCESSKEY" + ], "body_base64": "eyJpbnB1dCI6ImhlbGxvIiwibW9kZWwiOiJvcGVuYWkuZ3B0LW9zcy0xMjBiIn0=" }, "expected": { "attempts": 2, "credential_provider_calls": 2, - "x_amz_dates": ["20260601T123456Z", "20260601T123501Z"], + "x_amz_dates": [ + "20260601T123456Z", + "20260601T123501Z" + ], "body_sha256": [ "9d161ca213ef31dd5b8f01f52294bcd5057bc778cf6ee7719e67b7ff43ef1022", "9d161ca213ef31dd5b8f01f52294bcd5057bc778cf6ee7719e67b7ff43ef1022" @@ -199,34 +207,46 @@ "given": { "body_kind": "bytes", "body_base64": "eyJpbnB1dCI6ImhlbGxvIiwibW9kZWwiOiJvcGVuYWkuZ3B0LW9zcy0xMjBiIn0=", - "response_statuses": [500, 200] + "response_statuses": [ + 500, + 200 + ] }, - "expected": { "attempts": 2, "result": "replayed" } + "expected": { + "attempts": 2, + "result": "replayed" + } }, { "id": "body.non-replayable-stream-with-retries", "kind": "body_replay", "given": { "body_kind": "one_shot_stream", - "chunks_base64": ["Zmlyc3Q=", "c2Vjb25k"], + "chunks_base64": [ + "Zmlyc3Q=", + "c2Vjb25k" + ], "max_retries": 1 }, - "expected": { "network_attempts": 0, "result": "bedrock_non_replayable_body" } + "expected": { + "network_attempts": 0, + "result": "bedrock_non_replayable_body" + } }, { - "id": "body.unsigned-one-shot-stream", + "id": "body.non-replayable-stream-without-retries", "kind": "body_replay", "given": { "body_kind": "one_shot_stream", - "chunks_base64": ["Zmlyc3Q=", "c2Vjb25k"], + "chunks_base64": [ + "Zmlyc3Q=", + "c2Vjb25k" + ], "max_retries": 0 }, "expected": { - "attempts": 1, - "credential_provider_calls": 1, - "body_reads": 1, - "x_amz_content_sha256": "UNSIGNED-PAYLOAD", - "result": "unsigned_payload" + "network_attempts": 0, + "result": "bedrock_non_replayable_body" } }, { @@ -245,7 +265,10 @@ "x-amz-security-token": "", "x-request-id": "req_fixture" }, - "forbidden_substrings": ["fixture-authorization", "fixture-session-token"] + "forbidden_substrings": [ + "fixture-authorization", + "fixture-session-token" + ] } } ] diff --git a/tests/lib/test_bedrock.py b/tests/lib/test_bedrock.py index 87ab50b218..244dfab79f 100644 --- a/tests/lib/test_bedrock.py +++ b/tests/lib/test_bedrock.py @@ -14,7 +14,6 @@ from tests.utils import update_env from openai._types import Omit from openai.lib.bedrock import BedrockOpenAI, AsyncBedrockOpenAI -from openai.lib._bedrock_auth import BedrockAwsAuthConfig Client = Union[BedrockOpenAI, AsyncBedrockOpenAI] @@ -168,8 +167,9 @@ def import_module(name: str) -> Any: monkeypatch.setattr(bedrock_auth_module.importlib, "import_module", import_module) with update_env(AWS_BEARER_TOKEN_BEDROCK="", AWS_REGION="us-east-1"): + client = make_sync_client() with pytest.raises(OpenAIError, match="requires optional AWS dependencies"): - BedrockOpenAI() + client.get("/models", cast_to=httpx.Response) @pytest.mark.respx() @@ -203,7 +203,7 @@ def test_empty_env_bearer_falls_back_to_aws_credentials(client_cls: type[Client] client = make_sync_client() if client_cls is BedrockOpenAI else make_async_client() assert client.api_key == "" - assert client._bedrock_aws_auth is not None + assert client._uses_aws_auth() @pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) @@ -277,7 +277,7 @@ def test_does_not_use_openai_api_key(client_cls: type[Client]) -> None: client = make_sync_client() if client_cls is BedrockOpenAI else make_async_client() assert client.api_key == "" - assert client._bedrock_aws_auth is not None + assert client._uses_aws_auth() @pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) @@ -438,6 +438,214 @@ def test_preserves_token_provider_across_with_options() -> None: assert copied_client._refresh_api_key() == "provider token" +def test_preserves_environment_bearer_across_with_options() -> None: + requests: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + return httpx.Response(200, request=request, json={}) + + with update_env(AWS_BEARER_TOKEN_BEDROCK="first token"): + client = BedrockOpenAI( + base_url="https://example.com/openai/v1", + http_client=httpx.Client(transport=httpx.MockTransport(handler), trust_env=False), + ) + + with update_env(AWS_BEARER_TOKEN_BEDROCK="second token"): + copied_client = client.with_options(timeout=1) + copied_client.get("/models", cast_to=httpx.Response) + + assert copied_client.api_key == "first token" + assert requests[0].headers["Authorization"] == "Bearer first token" + + +def test_environment_bearer_routing_copy_remains_mutable() -> None: + requests: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + return httpx.Response(200, request=request, json={}) + + with update_env(AWS_BEARER_TOKEN_BEDROCK="first token"): + client = BedrockOpenAI( + aws_region="us-east-1", + http_client=httpx.Client(transport=httpx.MockTransport(handler), trust_env=False), + ) + copied_client = client.with_options(aws_region="us-west-2") + copied_client.api_key = "second token" + copied_client.get("/models", cast_to=httpx.Response) + + assert copied_client.api_key == "second token" + assert requests[0].headers["Authorization"] == "Bearer second token" + + +def test_legacy_api_key_mutation_updates_requests_and_copies() -> None: + requests: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + return httpx.Response(200, request=request, json={}) + + client = BedrockOpenAI( + base_url="https://example.com/openai/v1", + api_key="first token", + http_client=httpx.Client(transport=httpx.MockTransport(handler), trust_env=False), + ) + client.api_key = "second token" + client.get("/models", cast_to=httpx.Response) + copied_client = client.with_options(timeout=1) + copied_client.get("/models", cast_to=httpx.Response) + client.api_key = "first token" + reverted_client = client.with_options(timeout=2) + reverted_client.get("/models", cast_to=httpx.Response) + + assert copied_client.api_key == "second token" + assert reverted_client.api_key == "first token" + assert [request.headers["Authorization"] for request in requests] == [ + "Bearer second token", + "Bearer second token", + "Bearer first token", + ] + + +def test_legacy_api_key_mutation_switches_aws_client_to_bearer() -> None: + requests: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + return httpx.Response(200, request=request, json={}) + + client = BedrockOpenAI( + base_url="https://example.com/openai/v1", + aws_region="us-east-1", + aws_access_key_id="access key", + aws_secret_access_key="secret key", + http_client=httpx.Client(transport=httpx.MockTransport(handler), trust_env=False), + ) + client.api_key = "bearer token" + client.get("/models", cast_to=httpx.Response, options={"follow_redirects": True}) + + assert requests[0].headers["Authorization"] == "Bearer bearer token" + + +def test_explicit_aws_copy_override_wins_over_mutated_api_key() -> None: + client = BedrockOpenAI( + base_url="https://example.com/openai/v1", + api_key="first token", + http_client=httpx.Client(trust_env=False), + ) + client.api_key = "second token" + + copied_client = client.with_options( + aws_region="us-east-1", + aws_access_key_id="access key", + aws_secret_access_key="secret key", + ) + + assert copied_client._uses_aws_auth() + assert copied_client.api_key == "" + + +def test_clearing_legacy_bearer_does_not_switch_to_aws_authentication() -> None: + network_calls = 0 + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal network_calls + network_calls += 1 + return httpx.Response(200, request=request) + + with update_env(AWS_ACCESS_KEY_ID="access key", AWS_SECRET_ACCESS_KEY="secret key"): + client = BedrockOpenAI( + base_url="https://example.com/openai/v1", + api_key="bearer token", + http_client=httpx.Client(transport=httpx.MockTransport(handler), trust_env=False), + ) + client.api_key = None # type: ignore[assignment] + with pytest.raises(OpenAIError, match="bearer credential must not be empty"): + client.get("/models", cast_to=httpx.Response) + + assert network_calls == 0 + + +def test_legacy_state_repr_does_not_expose_credentials() -> None: + client = BedrockOpenAI( + base_url="https://example.com/openai/v1", + aws_region="us-east-1", + aws_access_key_id="secret access key id", + aws_secret_access_key="secret access key", + aws_session_token="secret session token", + http_client=httpx.Client(trust_env=False), + ) + + assert "secret" not in repr(client._bedrock_state) + + bearer_client = BedrockOpenAI( + base_url="https://example.com/openai/v1", + api_key="secret bearer token", + http_client=httpx.Client(trust_env=False), + ) + assert "secret bearer token" not in repr(bearer_client._bedrock_runtime_signature) + + +@pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) +def test_direct_routing_mutations_survive_clone(client_cls: type[Client]) -> None: + client = ( + make_sync_client(base_url="https://first.example/openai/v1", aws_region="us-east-1", api_key="token") + if client_cls is BedrockOpenAI + else make_async_client( + base_url="https://first.example/openai/v1", + aws_region="us-east-1", + api_key="token", + ) + ) + client.base_url = "https://second.example/openai/v1" + client.aws_region = "us-west-2" + + copied_client = client.with_options(timeout=1) + + assert copied_client.base_url == URL("https://second.example/openai/v1/") + assert copied_client.aws_region == "us-west-2" + assert not copied_client._bedrock_state.uses_region_derived_base_url + + +@pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) +def test_direct_region_mutation_survives_clone(client_cls: type[Client]) -> None: + with update_env(AWS_BEDROCK_BASE_URL=Omit(), AWS_REGION=Omit(), AWS_DEFAULT_REGION=Omit()): + client = ( + make_sync_client(aws_region="us-east-1", api_key="token") + if client_cls is BedrockOpenAI + else make_async_client(aws_region="us-east-1", api_key="token") + ) + client.aws_region = "us-west-2" + copied_client = client.with_options(timeout=1) + + assert copied_client.aws_region == "us-west-2" + assert copied_client.base_url == URL("https://bedrock-mantle.us-west-2.api.aws/openai/v1/") + + +@pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) +def test_direct_base_url_mutation_survives_auth_override(client_cls: type[Client]) -> None: + with update_env(AWS_BEDROCK_BASE_URL=Omit(), AWS_REGION=Omit(), AWS_DEFAULT_REGION=Omit()): + client = ( + make_sync_client( + aws_region="us-east-1", + aws_access_key_id="access key", + aws_secret_access_key="secret key", + ) + if client_cls is BedrockOpenAI + else make_async_client( + aws_region="us-east-1", + aws_access_key_id="access key", + aws_secret_access_key="secret key", + ) + ) + client.base_url = "https://custom.example/openai/v1" + copied_client = client.with_options(api_key="bearer token") + + assert copied_client.base_url == URL("https://custom.example/openai/v1/") + assert not copied_client._uses_aws_auth() + + def test_preserves_aws_credentials_across_with_options() -> None: client = BedrockOpenAI( base_url="https://example.com/openai/v1", @@ -449,9 +657,8 @@ def test_preserves_aws_credentials_across_with_options() -> None: copied_client = client.with_options(timeout=1) - assert copied_client._bedrock_aws_auth is not None - assert isinstance(copied_client._bedrock_auth_config, BedrockAwsAuthConfig) - assert copied_client._bedrock_auth_config.access_key_id == "access key" + assert copied_client._uses_aws_auth() + assert copied_client._bedrock_state.aws_access_key_id == "access key" @pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) @@ -462,8 +669,10 @@ def test_preserves_default_chain_mode_across_with_options(client_cls: type[Clien with update_env(AWS_BEARER_TOKEN_BEDROCK="late bearer", AWS_REGION="us-east-1"): copied_client = client.with_options(timeout=1) - assert isinstance(copied_client._bedrock_auth_config, BedrockAwsAuthConfig) - assert copied_client._bedrock_auth_config.source == "default" + assert copied_client._uses_aws_auth() + assert copied_client._bedrock_state.aws_profile is None + assert copied_client._bedrock_state.aws_access_key_id is None + assert copied_client._bedrock_state.aws_credentials_provider is None assert copied_client.api_key == "" @@ -480,6 +689,22 @@ def test_preserves_region_derived_url_provenance_across_multiple_copies(client_c assert copied_client.base_url == URL("https://bedrock-mantle.eu-west-1.api.aws/openai/v1/") +@pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) +def test_preserves_region_derived_url_after_auth_override(client_cls: type[Client]) -> None: + with update_env(AWS_BEDROCK_BASE_URL=Omit(), AWS_REGION=Omit(), AWS_DEFAULT_REGION=Omit()): + client = ( + make_sync_client(aws_region="us-east-1", api_key="token") + if client_cls is BedrockOpenAI + else make_async_client(aws_region="us-east-1", api_key="token") + ) + copied_client = client.with_options( + aws_access_key_id="access key", + aws_secret_access_key="secret key", + ).with_options(aws_region="us-west-2") + + assert copied_client.base_url == URL("https://bedrock-mantle.us-west-2.api.aws/openai/v1/") + + @pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) def test_blank_base_url_restores_region_derived_url_provenance(client_cls: type[Client]) -> None: with update_env(AWS_BEDROCK_BASE_URL=Omit(), AWS_REGION=Omit(), AWS_DEFAULT_REGION=Omit()): @@ -515,20 +740,18 @@ def test_with_options_replaces_the_aws_credential_source(client_cls: type[Client with update_env(AWS_CONFIG_FILE=str(config_path)): profile_client = explicit_credentials_client.with_options(aws_profile="other-profile") - assert isinstance(profile_client._bedrock_auth_config, BedrockAwsAuthConfig) - assert profile_client._bedrock_auth_config.profile == "other-profile" - assert profile_client._bedrock_auth_config.access_key_id is None - assert profile_client._bedrock_auth_config.secret_access_key is None + assert profile_client._bedrock_state.aws_profile == "other-profile" + assert profile_client._bedrock_state.aws_access_key_id is None + assert profile_client._bedrock_state.aws_secret_access_key is None explicit_credentials_client = profile_client.with_options( aws_access_key_id="replacement access key", aws_secret_access_key="replacement secret key", ) - assert isinstance(explicit_credentials_client._bedrock_auth_config, BedrockAwsAuthConfig) - assert explicit_credentials_client._bedrock_auth_config.profile is None - assert explicit_credentials_client._bedrock_auth_config.access_key_id == "replacement access key" - assert explicit_credentials_client._bedrock_auth_config.secret_access_key == "replacement secret key" + assert explicit_credentials_client._bedrock_state.aws_profile is None + assert explicit_credentials_client._bedrock_state.aws_access_key_id == "replacement access key" + assert explicit_credentials_client._bedrock_state.aws_secret_access_key == "replacement secret key" @pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) @@ -560,8 +783,49 @@ def test_with_options_replacing_profile_re_resolves_profile_region(client_cls: t assert copied_client.aws_region == "us-west-2" assert copied_client.base_url == URL("https://bedrock-mantle.us-west-2.api.aws/openai/v1/") - assert isinstance(copied_client._bedrock_auth_config, BedrockAwsAuthConfig) - assert copied_client._bedrock_auth_config.profile == "west" + assert copied_client._bedrock_state.aws_profile == "west" + + +@pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) +@pytest.mark.parametrize( + ("copy_kwargs", "uses_aws_auth"), + [ + ({"api_key": "bearer token"}, False), + ( + { + "aws_access_key_id": "access key", + "aws_secret_access_key": "secret key", + }, + True, + ), + ], +) +def test_profile_derived_region_survives_auth_override( + client_cls: type[Client], + copy_kwargs: dict[str, Any], + uses_aws_auth: bool, + tmp_path: Path, +) -> None: + config_path = tmp_path / "config" + config_path.write_text("[profile west]\nregion = us-west-2\n") + + with update_env( + AWS_CONFIG_FILE=str(config_path), + AWS_BEARER_TOKEN_BEDROCK=Omit(), + AWS_BEDROCK_BASE_URL=Omit(), + AWS_REGION=Omit(), + AWS_DEFAULT_REGION=Omit(), + ): + client = ( + make_sync_client(aws_profile="west") + if client_cls is BedrockOpenAI + else make_async_client(aws_profile="west") + ) + copied_client = client.with_options(**copy_kwargs) + + assert copied_client.aws_region == "us-west-2" + assert copied_client.base_url == URL("https://bedrock-mantle.us-west-2.api.aws/openai/v1/") + assert copied_client._uses_aws_auth() is uses_aws_auth @pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) @@ -637,10 +901,11 @@ def __init__( http_client=httpx.Client(trust_env=False), ) - copied_client = client.with_options(timeout=1) + copied_client = client.with_options(timeout=1).with_options(aws_region="us-west-2") assert isinstance(copied_client, LegacyBedrockOpenAI) assert copied_client.api_key == "token" + assert copied_client.base_url == URL("https://bedrock-mantle.us-west-2.api.aws/openai/v1/") @pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) diff --git a/tests/lib/test_bedrock_auth_conformance.py b/tests/lib/test_bedrock_auth_conformance.py index b08f160c35..526de29b43 100644 --- a/tests/lib/test_bedrock_auth_conformance.py +++ b/tests/lib/test_bedrock_auth_conformance.py @@ -13,16 +13,18 @@ import pytest import jsonschema -from openai import OpenAIError, APIStatusError -from openai._types import Omit +from openai import OpenAI, AsyncOpenAI, OpenAIError, APIStatusError from openai._utils import SensitiveHeadersFilter -from openai.lib.bedrock import BedrockOpenAI, AsyncBedrockOpenAI -from openai.lib._bedrock_auth import BedrockAwsAuth, BedrockAwsAuthConfig, BedrockBearerAuthConfig +from openai.providers import bedrock +from openai.lib.bedrock import BedrockOpenAI +from openai.lib._bedrock_auth import BedrockAwsAuth, BedrockAwsAuthConfig FIXTURE_PATH = Path(__file__).parents[1] / "fixtures" / "bedrock_auth" / "v1" / "cases.json" SCHEMA_PATH = FIXTURE_PATH.with_name("schema.json") +SHARED_SIGV4_FIXTURE_PATH = Path(__file__).parents[1] / "fixtures" / "bedrock" / "v1" / "sigv4.json" FIXTURES = cast(dict[str, Any], json.loads(FIXTURE_PATH.read_text())) SCHEMA = cast(dict[str, Any], json.loads(SCHEMA_PATH.read_text())) +SHARED_SIGV4_FIXTURE = cast(dict[str, Any], json.loads(SHARED_SIGV4_FIXTURE_PATH.read_text())) def _cases(kind: str) -> list[dict[str, Any]]: @@ -56,6 +58,47 @@ def _canonical_request_sha256(case: dict[str, Any], signed_headers: dict[str, st return hashlib.sha256(canonical_request.encode()).hexdigest() +def test_shared_sigv4_fixture_matches_node(monkeypatch: pytest.MonkeyPatch) -> None: + fixture = SHARED_SIGV4_FIXTURE + credentials = fixture["credentials"] + request = fixture["request"] + body = request["body"].encode() + payload_hash = hashlib.sha256(body).hexdigest() + botocore_auth = pytest.importorskip("botocore.auth") + monkeypatch.setattr(botocore_auth, "get_current_datetime", lambda: _fixed_datetime(fixture["signingDate"])) + + auth = BedrockAwsAuth( + BedrockAwsAuthConfig( + region=fixture["region"], + source="static", + access_key_id=credentials["accessKeyId"], + secret_access_key=credentials["secretAccessKey"], + session_token=credentials["sessionToken"], + ) + ) + signed_headers = _lower_headers( + auth.sign( + method=request["method"], + url=request["url"], + headers={ + "content-type": request["contentType"], + "host": httpx.URL(request["url"]).host, + }, + body=body, + ) + ) + canonical_case = {"given": {"request": request}} + + assert fixture["service"] == "bedrock-mantle" + assert payload_hash == fixture["expected"]["payloadHash"] + assert ( + _canonical_request_sha256(canonical_case, signed_headers, payload_hash) + == fixture["expected"]["canonicalRequestHash"] + ) + assert signed_headers["authorization"] == fixture["expected"]["authorization"] + assert signed_headers["x-amz-date"] == fixture["expected"]["date"] + + @pytest.mark.parametrize("case", _cases("auth_selection"), ids=lambda case: case["id"]) def test_auth_selection_fixture(case: dict[str, Any], monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.delenv("AWS_BEARER_TOKEN_BEDROCK", raising=False) @@ -82,14 +125,23 @@ def test_auth_selection_fixture(case: dict[str, Any], monkeypatch: pytest.Monkey return with BedrockOpenAI(**kwargs) as client: - config = client._bedrock_auth_config - if isinstance(config, BedrockBearerAuthConfig): + state = client._bedrock_state + if not client._uses_aws_auth(): mode = "bearer" + source = "explicit" if state.explicit_api_key is not None else "environment" else: mode = "sigv4" + if state.aws_access_key_id is not None: + source = "static" + elif state.aws_profile is not None: + source = "profile" + elif state.aws_credentials_provider is not None: + source = "provider" + else: + source = "default" assert mode == case["expected"]["auth_mode"] - assert config.source == case["expected"]["auth_source"] + assert source == case["expected"]["auth_source"] @pytest.mark.parametrize("case", _cases("sigv4"), ids=lambda case: case["id"]) @@ -97,8 +149,8 @@ def test_sigv4_fixture(case: dict[str, Any], monkeypatch: pytest.MonkeyPatch) -> credentials = case["given"]["credentials"] signing = case["given"]["signing"] request = case["given"]["request"] - body = None if request.get("body_mode") == "unsigned" else base64.b64decode(request["body_base64"]) - payload_hash = "UNSIGNED-PAYLOAD" if body is None else hashlib.sha256(body).hexdigest() + body = base64.b64decode(request["body_base64"]) + payload_hash = hashlib.sha256(body).hexdigest() auth = BedrockAwsAuth( BedrockAwsAuthConfig( @@ -159,10 +211,12 @@ def handler(request: httpx.Request) -> httpx.Response: return httpx.Response(next(statuses), request=request, json={}) body = base64.b64decode(case["given"]["body_base64"]) - with BedrockOpenAI( - base_url="https://bedrock-mantle.us-east-1.api.aws/openai/v1", - aws_region="us-east-1", - aws_credentials_provider=credentials_provider, + with OpenAI( + provider=bedrock( + base_url="https://bedrock-mantle.us-east-1.api.aws/openai/v1", + region="us-east-1", + credential_provider=credentials_provider, + ), max_retries=case["given"].get("max_retries", 1), http_client=httpx.Client(transport=httpx.MockTransport(handler), trust_env=False), ) as client: @@ -214,10 +268,12 @@ def handler(request: httpx.Request) -> httpx.Response: assert body_kind == "one_shot_stream" content = body() - with BedrockOpenAI( - base_url="https://bedrock-mantle.us-east-1.api.aws/openai/v1", - aws_region="us-east-1", - aws_credentials_provider=credentials_provider, + with OpenAI( + provider=bedrock( + base_url="https://bedrock-mantle.us-east-1.api.aws/openai/v1", + region="us-east-1", + credential_provider=credentials_provider, + ), max_retries=case["given"].get("max_retries", 1), http_client=httpx.Client(transport=httpx.MockTransport(handler), trust_env=False), ) as client: @@ -231,11 +287,6 @@ def handler(request: httpx.Request) -> httpx.Response: if body_kind == "bytes": assert all(request.content == content for request in requests) assert provider_calls == case["expected"]["attempts"] - elif case["expected"]["result"] == "unsigned_payload": - assert body_reads == case["expected"]["body_reads"] - assert provider_calls == case["expected"]["credential_provider_calls"] - assert requests[0].headers["X-Amz-Content-SHA256"] == case["expected"]["x_amz_content_sha256"] - assert "x-amz-content-sha256" in requests[0].headers["Authorization"] else: assert (body_reads, provider_calls) == (0, 0) @@ -261,10 +312,12 @@ async def handler(request: httpx.Request) -> httpx.Response: network_calls += 1 return httpx.Response(200, request=request) - async with AsyncBedrockOpenAI( - base_url="https://bedrock-mantle.us-east-1.api.aws/openai/v1", - aws_region="us-east-1", - aws_credentials_provider=credentials_provider, + async with AsyncOpenAI( + provider=bedrock( + base_url="https://bedrock-mantle.us-east-1.api.aws/openai/v1", + region="us-east-1", + credential_provider=credentials_provider, + ), http_client=httpx.AsyncClient(transport=httpx.MockTransport(handler), trust_env=False), ) as client: with pytest.raises(OpenAIError, match="requires a replayable request body"): @@ -273,33 +326,6 @@ async def handler(request: httpx.Request) -> httpx.Response: assert (body_reads, provider_calls, network_calls) == (0, 0, 0) -@pytest.mark.asyncio -async def test_async_one_shot_body_uses_unsigned_payload_when_retries_are_disabled() -> None: - case = next(case for case in _cases("body_replay") if case["expected"]["result"] == "unsigned_payload") - requests: list[httpx.Request] = [] - - async def body() -> AsyncIterator[bytes]: - for chunk in case["given"]["chunks_base64"]: - yield base64.b64decode(chunk) - - async def handler(request: httpx.Request) -> httpx.Response: - requests.append(request) - return httpx.Response(200, request=request, json={}) - - async with AsyncBedrockOpenAI( - base_url="https://bedrock-mantle.us-east-1.api.aws/openai/v1", - aws_region="us-east-1", - aws_access_key_id="access-key", - aws_secret_access_key="secret-key", - max_retries=0, - http_client=httpx.AsyncClient(transport=httpx.MockTransport(handler), trust_env=False), - ) as client: - await client.post("/responses", content=body(), cast_to=httpx.Response) - - assert requests[0].headers["X-Amz-Content-SHA256"] == case["expected"]["x_amz_content_sha256"] - assert "x-amz-content-sha256" in requests[0].headers["Authorization"] - - @pytest.mark.asyncio async def test_async_credentials_are_resolved_off_event_loop() -> None: event_loop_thread = threading.get_ident() @@ -312,10 +338,12 @@ def credentials_provider() -> _Credentials: async def handler(request: httpx.Request) -> httpx.Response: return httpx.Response(200, request=request, json={}) - async with AsyncBedrockOpenAI( - base_url="https://bedrock-mantle.us-east-1.api.aws/openai/v1", - aws_region="us-east-1", - aws_credentials_provider=credentials_provider, + async with AsyncOpenAI( + provider=bedrock( + base_url="https://bedrock-mantle.us-east-1.api.aws/openai/v1", + region="us-east-1", + credential_provider=credentials_provider, + ), http_client=httpx.AsyncClient(transport=httpx.MockTransport(handler), trust_env=False), ) as client: await client.post("/responses", content=b"{}", cast_to=httpx.Response) @@ -324,38 +352,29 @@ async def handler(request: httpx.Request) -> httpx.Response: assert all(thread_id != event_loop_thread for thread_id in provider_threads) -def test_explicit_authorization_omit_and_override_are_preserved() -> None: +def test_custom_http_client_auth_cannot_replace_sigv4() -> None: requests: list[httpx.Request] = [] def handler(request: httpx.Request) -> httpx.Response: requests.append(request) return httpx.Response(200, request=request, json={}) - with BedrockOpenAI( - base_url="https://bedrock-mantle.us-east-1.api.aws/openai/v1", - aws_region="us-east-1", - aws_access_key_id="access-key", - aws_secret_access_key="secret-key", + with OpenAI( + provider=bedrock( + base_url="https://bedrock-mantle.us-east-1.api.aws/openai/v1", + region="us-east-1", + access_key_id="access-key", + secret_access_key="secret-key", + ), http_client=httpx.Client( - headers={"Authorization": "Bearer client-default"}, auth=httpx.BasicAuth("username", "password"), transport=httpx.MockTransport(handler), trust_env=False, ), ) as client: - client.get( - "/models", - cast_to=httpx.Response, - options={"headers": {"Authorization": Omit()}}, - ) - client.get( - "/models", - cast_to=httpx.Response, - options={"headers": {"Authorization": "Bearer explicit-override"}}, - ) + client.get("/models", cast_to=httpx.Response) - assert "Authorization" not in requests[0].headers - assert requests[1].headers["Authorization"] == "Bearer explicit-override" + assert requests[0].headers["Authorization"].startswith("AWS4-HMAC-SHA256 Credential=access-key/") def test_sigv4_redirects_are_not_followed() -> None: @@ -367,11 +386,13 @@ def handler(request: httpx.Request) -> httpx.Response: return httpx.Response(307, request=request, headers={"Location": "/redirected"}) return httpx.Response(200, request=request) - with BedrockOpenAI( - base_url="https://bedrock-mantle.us-east-1.api.aws/openai/v1", - aws_region="us-east-1", - aws_access_key_id="access-key", - aws_secret_access_key="secret-key", + with OpenAI( + provider=bedrock( + base_url="https://bedrock-mantle.us-east-1.api.aws/openai/v1", + region="us-east-1", + access_key_id="access-key", + secret_access_key="secret-key", + ), http_client=httpx.Client(transport=httpx.MockTransport(handler), trust_env=False), ) as client: with pytest.raises(APIStatusError) as exc: diff --git a/tests/lib/test_bedrock_provider.py b/tests/lib/test_bedrock_provider.py new file mode 100644 index 0000000000..20dee2e367 --- /dev/null +++ b/tests/lib/test_bedrock_provider.py @@ -0,0 +1,337 @@ +from __future__ import annotations + +from typing import Any, Iterator, cast + +import httpx +import pytest + +import openai.lib._bedrock_auth as bedrock_auth_module +from openai import OpenAI, AsyncOpenAI, OpenAIError +from tests.utils import update_env +from openai._provider import _create_provider, _ProviderRuntime +from openai.providers import bedrock + + +def test_sync_provider_owns_endpoint_and_bearer_authentication() -> None: + requests: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + return httpx.Response(200, request=request, json={}) + + client = OpenAI( + provider=bedrock(region="us-east-1", api_key="bedrock token"), + http_client=httpx.Client(transport=httpx.MockTransport(handler), trust_env=False), + ) + client.get("/models", cast_to=httpx.Response) + + assert client.base_url == httpx.URL("https://bedrock-mantle.us-east-1.api.aws/openai/v1/") + assert requests[0].url == httpx.URL("https://bedrock-mantle.us-east-1.api.aws/openai/v1/models") + assert requests[0].headers["Authorization"] == "Bearer bedrock token" + + +@pytest.mark.asyncio +async def test_async_provider_owns_endpoint_and_bearer_authentication() -> None: + requests: list[httpx.Request] = [] + + async def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + return httpx.Response(200, request=request, json={}) + + client = AsyncOpenAI( + provider=bedrock(region="us-east-1", token_provider=lambda: "bedrock token"), + http_client=httpx.AsyncClient(transport=httpx.MockTransport(handler), trust_env=False), + ) + await client.get("/models", cast_to=httpx.Response) + await client.close() + + assert requests[0].headers["Authorization"] == "Bearer bedrock token" + + +def test_provider_ignores_openai_environment_configuration() -> None: + requests: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + return httpx.Response(200, request=request, json={}) + + with update_env( + OPENAI_API_KEY="openai token", + OPENAI_BASE_URL="https://api.openai.invalid/v1", + OPENAI_CUSTOM_HEADERS="Authorization: Bearer openai custom token", + ): + client = OpenAI( + provider=bedrock(region="us-east-1", api_key="bedrock token"), + http_client=httpx.Client(transport=httpx.MockTransport(handler), trust_env=False), + ) + client.get("/models", cast_to=httpx.Response) + + assert client.api_key == "" + assert requests[0].url.host == "bedrock-mantle.us-east-1.api.aws" + assert requests[0].headers["Authorization"] == "Bearer bedrock token" + + +@pytest.mark.parametrize( + ("option", "value"), + [ + ("api_key", "openai token"), + ("admin_api_key", "admin token"), + ("workload_identity", cast(Any, object())), + ("base_url", "https://api.openai.invalid/v1"), + ], +) +def test_provider_rejects_top_level_authentication_and_routing(option: str, value: object) -> None: + with pytest.raises( + OpenAIError, + match=rf"`provider` cannot be combined with top-level `{option}`.*`bedrock\(\.\.\.\)`", + ): + OpenAI(provider=bedrock(region="us-east-1", api_key="bedrock token"), **{option: value}) # type: ignore[arg-type] + + +def test_provider_survives_with_options_and_can_be_replaced() -> None: + client = OpenAI(provider=bedrock(region="us-east-1", api_key="first")) + + copied = client.with_options(timeout=1) + replaced = client.with_options(provider=bedrock(region="eu-west-1", api_key="second")) + + assert copied.base_url == client.base_url + assert copied._provider is client._provider + assert replaced.base_url == httpx.URL("https://bedrock-mantle.eu-west-1.api.aws/openai/v1/") + assert replaced._provider is not client._provider + + +def test_provider_normalizes_responses_before_status_handling() -> None: + class NormalizingProvider: + name = "normalizing" + + def configure(self) -> _ProviderRuntime: + def normalize(response: httpx.Response) -> httpx.Response: + return httpx.Response(200, request=response.request, json={"normalized": True}) + + return _ProviderRuntime( + name=self.name, + base_url="https://provider.example/v1", + normalize_response=normalize, + ) + + client = OpenAI( + provider=_create_provider(NormalizingProvider()), + max_retries=0, + http_client=httpx.Client( + transport=httpx.MockTransport(lambda request: httpx.Response(500, request=request, json={})), + trust_env=False, + ), + ) + + response = client.get("/models", cast_to=httpx.Response) + + assert response.status_code == 200 + assert response.json() == {"normalized": True} + + +def test_environment_bearer_mode_survives_clone_and_refreshes_each_attempt() -> None: + requests: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + return httpx.Response(200, request=request, json={}) + + with update_env(AWS_BEARER_TOKEN_BEDROCK="first token"): + client = OpenAI( + provider=bedrock(region="us-east-1"), + http_client=httpx.Client(transport=httpx.MockTransport(handler), trust_env=False), + ) + client.get("/models", cast_to=httpx.Response) + + copied = client.with_options(timeout=1) + with update_env(AWS_BEARER_TOKEN_BEDROCK="second token"): + copied.get("/models", cast_to=httpx.Response) + + assert [request.headers["Authorization"] for request in requests] == ["Bearer first token", "Bearer second token"] + + +def test_provider_can_be_removed_with_explicit_openai_credentials() -> None: + client = OpenAI(provider=bedrock(region="us-east-1", api_key="bedrock token")) + + copied = client.with_options(provider=None, api_key="openai token") + + assert copied._provider is None + assert copied.api_key == "openai token" + assert copied.base_url == httpx.URL("https://api.openai.com/v1/") + + +def test_bearer_provider_does_not_load_botocore(monkeypatch: pytest.MonkeyPatch) -> None: + real_import_module = bedrock_auth_module.importlib.import_module + + def import_module(name: str) -> Any: + if name.startswith("botocore"): + raise AssertionError("bearer authentication must not import botocore") + return real_import_module(name) + + monkeypatch.setattr(bedrock_auth_module.importlib, "import_module", import_module) + + client = OpenAI(provider=bedrock(region="us-east-1", api_key="bedrock token")) + request = client._build_request(client._prepare_options(_get_options())) + client._prepare_request(request) + + assert request.headers["Authorization"] == "Bearer bedrock token" + + +def test_missing_aws_dependency_is_actionable_and_lazy(monkeypatch: pytest.MonkeyPatch) -> None: + real_import_module = bedrock_auth_module.importlib.import_module + network_calls = 0 + + def import_module(name: str) -> Any: + if name.startswith("botocore"): + raise ImportError(name) + return real_import_module(name) + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal network_calls + network_calls += 1 + return httpx.Response(200, request=request) + + monkeypatch.setattr(bedrock_auth_module.importlib, "import_module", import_module) + client = OpenAI( + provider=bedrock(region="us-east-1"), + http_client=httpx.Client(transport=httpx.MockTransport(handler), trust_env=False), + ) + + with pytest.raises(OpenAIError, match=r"pip install openai\[bedrock\]"): + client.get("/models", cast_to=httpx.Response) + + assert network_calls == 0 + + +def test_api_key_none_skips_environment_bearer_fallback() -> None: + requests: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + return httpx.Response(200, request=request, json={}) + + with update_env( + AWS_BEARER_TOKEN_BEDROCK="environment bearer", + AWS_ACCESS_KEY_ID="access key", + AWS_SECRET_ACCESS_KEY="secret key", + ): + client = OpenAI( + provider=bedrock(region="us-east-1", api_key=None), + http_client=httpx.Client(transport=httpx.MockTransport(handler), trust_env=False), + ) + client.get("/models", cast_to=httpx.Response) + + assert requests[0].headers["Authorization"].startswith("AWS4-HMAC-SHA256 Credential=access key/") + + +def test_provider_rejects_custom_authorization_before_network() -> None: + network_calls = 0 + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal network_calls + network_calls += 1 + return httpx.Response(200, request=request) + + client = OpenAI( + provider=bedrock(region="us-east-1", api_key="bedrock token"), + http_client=httpx.Client(transport=httpx.MockTransport(handler), trust_env=False), + ) + + with pytest.raises(OpenAIError, match="cannot be combined with a custom `Authorization` header"): + client.get( + "/models", + cast_to=httpx.Response, + options={"headers": {"Authorization": "Bearer custom"}}, + ) + + assert network_calls == 0 + + +def test_bearer_provider_rejects_cross_origin_requests_before_resolving_credentials() -> None: + network_calls = 0 + provider_calls = 0 + + def token_provider() -> str: + nonlocal provider_calls + provider_calls += 1 + return "bedrock token" + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal network_calls + network_calls += 1 + return httpx.Response(200, request=request) + + client = OpenAI( + provider=bedrock(base_url="https://bedrock.example/openai/v1", token_provider=token_provider), + http_client=httpx.Client(transport=httpx.MockTransport(handler), trust_env=False), + ) + + with pytest.raises(OpenAIError, match="origin other than the configured provider URL"): + client.get("https://attacker.example/steal", cast_to=httpx.Response) + + assert (provider_calls, network_calls) == (0, 0) + + +@pytest.mark.asyncio +async def test_async_bearer_provider_rejects_cross_origin_requests_before_resolving_credentials() -> None: + network_calls = 0 + provider_calls = 0 + + async def token_provider() -> str: + nonlocal provider_calls + provider_calls += 1 + return "bedrock token" + + async def handler(request: httpx.Request) -> httpx.Response: + nonlocal network_calls + network_calls += 1 + return httpx.Response(200, request=request) + + client = AsyncOpenAI( + provider=bedrock(base_url="https://bedrock.example/openai/v1", token_provider=token_provider), + http_client=httpx.AsyncClient(transport=httpx.MockTransport(handler), trust_env=False), + ) + + with pytest.raises(OpenAIError, match="origin other than the configured provider URL"): + await client.get("https://attacker.example/steal", cast_to=httpx.Response) + + await client.close() + assert (provider_calls, network_calls) == (0, 0) + + +def test_bearer_provider_allows_one_shot_body_when_retries_are_disabled() -> None: + requests: list[httpx.Request] = [] + + def body() -> Iterator[bytes]: + yield b"body" + + def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + return httpx.Response(200, request=request) + + client = OpenAI( + provider=bedrock(base_url="https://bedrock.example/openai/v1", api_key="bedrock token"), + max_retries=0, + http_client=httpx.Client(transport=httpx.MockTransport(handler), trust_env=False), + ) + + client.post("/responses", content=body(), cast_to=httpx.Response) + + assert requests[0].content == b"body" + + +def test_opaque_provider_repr_does_not_expose_credentials() -> None: + provider = bedrock( + region="us-east-1", + access_key_id="secret access key id", + secret_access_key="secret access key", + session_token="secret session token", + ) + + assert "secret" not in repr(provider) + + +def _get_options() -> Any: + from openai._models import FinalRequestOptions + + return FinalRequestOptions(method="get", url="/models", security={"bearer_auth": True}) From 09c071cf75637a1d2f550e6c898d24592c87dceb Mon Sep 17 00:00:00 2001 From: Hayden Date: Fri, 12 Jun 2026 11:34:27 -0700 Subject: [PATCH 4/9] Cover explicit profile auth precedence --- tests/fixtures/bedrock_auth/v1/cases.json | 17 +++++++++++++++++ tests/lib/test_bedrock_auth_conformance.py | 2 ++ 2 files changed, 19 insertions(+) diff --git a/tests/fixtures/bedrock_auth/v1/cases.json b/tests/fixtures/bedrock_auth/v1/cases.json index 2d05c8827e..602ab703c3 100644 --- a/tests/fixtures/bedrock_auth/v1/cases.json +++ b/tests/fixtures/bedrock_auth/v1/cases.json @@ -39,6 +39,23 @@ "auth_source": "static" } }, + { + "id": "auth.explicit-profile-over-environment-bearer", + "kind": "auth_selection", + "given": { + "explicit": { + "profile": "fixture-profile" + }, + "environment": { + "AWS_BEARER_TOKEN_BEDROCK": "environment-bearer-token" + }, + "default_chain_available": true + }, + "expected": { + "auth_mode": "sigv4", + "auth_source": "profile" + } + }, { "id": "auth.environment-bearer-over-default-chain", "kind": "auth_selection", diff --git a/tests/lib/test_bedrock_auth_conformance.py b/tests/lib/test_bedrock_auth_conformance.py index 526de29b43..58f6889650 100644 --- a/tests/lib/test_bedrock_auth_conformance.py +++ b/tests/lib/test_bedrock_auth_conformance.py @@ -117,6 +117,8 @@ def test_auth_selection_fixture(case: dict[str, Any], monkeypatch: pytest.Monkey if "aws" in explicit: kwargs["aws_access_key_id"] = explicit["aws"]["access_key_id"] kwargs["aws_secret_access_key"] = explicit["aws"]["secret_access_key"] + if "profile" in explicit: + kwargs["aws_profile"] = explicit["profile"] if case["expected"].get("error") == "bedrock_conflicting_auth": with pytest.raises(OpenAIError, match="authentication is ambiguous"): From 803331f684ccddccab3d20dbda478adad6ecb3a1 Mon Sep 17 00:00:00 2001 From: Hayden Date: Mon, 15 Jun 2026 09:55:54 -0700 Subject: [PATCH 5/9] Add opt-in Bedrock live test --- tests/lib/bedrock_live.py | 93 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 tests/lib/bedrock_live.py diff --git a/tests/lib/bedrock_live.py b/tests/lib/bedrock_live.py new file mode 100644 index 0000000000..579344eabb --- /dev/null +++ b/tests/lib/bedrock_live.py @@ -0,0 +1,93 @@ +"""Opt-in live test for the Amazon Bedrock provider. + +This file is intentionally named ``bedrock_live.py`` so the standard pytest +suite does not collect it. Run it explicitly with: + + rye run pytest -q -s tests/lib/bedrock_live.py + +The test loads ``.env`` from the repository root when it exists. Set +``BEDROCK_LIVE_ENV_FILE`` to load a different file. Existing environment +variables take precedence over values from the file. + +Useful environment variables: + +- ``BEDROCK_LIVE_MODEL`` (defaults to ``openai.gpt-5.4``) +- ``BEDROCK_LIVE_REGION`` (otherwise uses the normal AWS region chain) +- ``BEDROCK_LIVE_PROFILE`` (otherwise uses the normal AWS credential chain) +- ``AWS_BEDROCK_BASE_URL`` (optional endpoint override) + +The normal AWS environment variables, shared config, credential-process, SSO, +and workload credential sources are resolved by botocore. The test explicitly +disables the ``AWS_BEARER_TOKEN_BEDROCK`` fallback so a passing run proves that +AWS credentials and SigV4 signing worked. +""" + +from __future__ import annotations + +import os +import re +from pathlib import Path + +from openai import OpenAI +from openai.providers import bedrock + +_REPOSITORY_ROOT = Path(__file__).resolve().parents[2] +_ENV_NAME = re.compile(r"[A-Za-z_][A-Za-z0-9_]*") + + +def _load_env_file() -> None: + configured_path = os.environ.get("BEDROCK_LIVE_ENV_FILE") + path = Path(configured_path).expanduser() if configured_path else _REPOSITORY_ROOT / ".env" + if not path.exists(): + if configured_path: + raise RuntimeError(f"BEDROCK_LIVE_ENV_FILE does not exist: {path}") + return + + for line_number, raw_line in enumerate(path.read_text().splitlines(), start=1): + line = raw_line.strip() + if not line or line.startswith("#"): + continue + if line.startswith("export "): + line = line.removeprefix("export ").lstrip() + + name, separator, raw_value = line.partition("=") + name = name.strip() + if not separator or _ENV_NAME.fullmatch(name) is None: + raise RuntimeError(f"Invalid environment assignment at {path}:{line_number}") + + value = raw_value.strip() + if len(value) >= 2 and value[0] == value[-1] and value[0] in {"'", '"'}: + value = value[1:-1] + elif " #" in value: + value = value.split(" #", 1)[0].rstrip() + + os.environ.setdefault(name, value) + + +_load_env_file() + + +def test_bedrock_live_response() -> None: + model = os.environ.get("BEDROCK_LIVE_MODEL") or "openai.gpt-5.4" + region = os.environ.get("BEDROCK_LIVE_REGION") or None + profile = os.environ.get("BEDROCK_LIVE_PROFILE") or None + + with OpenAI( + provider=bedrock( + region=region, + profile=profile, + api_key=None, + ), + timeout=60, + max_retries=2, + ) as client: + response = client.responses.create( + model=model, + input="Reply with exactly: bedrock live test ok", + max_output_tokens=64, + ) + + output_text = response.output_text.strip() + assert output_text, f"Bedrock returned no output text for response {response.id}" + assert "bedrock live test ok" in output_text.lower() + print(f"Bedrock live response {response.id}: {output_text}") From 6f96c7ea0009ae38617df739e0220b8b4595f347 Mon Sep 17 00:00:00 2001 From: Hayden Date: Mon, 15 Jun 2026 11:00:23 -0700 Subject: [PATCH 6/9] Preserve Bedrock profile regions for custom endpoints --- src/openai/lib/bedrock.py | 71 +++++++++++++++++++++++---------- src/openai/providers/bedrock.py | 11 ++++- tests/lib/test_bedrock.py | 47 ++++++++++++++++++++++ 3 files changed, 107 insertions(+), 22 deletions(-) diff --git a/src/openai/lib/bedrock.py b/src/openai/lib/bedrock.py index b05474f077..466e8ed75e 100644 --- a/src/openai/lib/bedrock.py +++ b/src/openai/lib/bedrock.py @@ -18,7 +18,7 @@ from .._provider import _Provider, _configure_provider from .._exceptions import OpenAIError from .._base_client import DEFAULT_MAX_RETRIES -from ..providers.bedrock import AwsCredentialsProvider, bedrock +from ..providers.bedrock import AwsCredentialsProvider, bedrock, _BedrockProviderRuntime BedrockTokenProvider = Callable[[], str] AsyncBedrockTokenProvider = Callable[[], "str | Awaitable[str]"] @@ -122,6 +122,7 @@ def _legacy_provider( aws_session_token: str | None, aws_credentials_provider: AwsCredentialsProvider | None, base_url: str | httpx.URL | None, + region_was_explicit: bool | None = None, ) -> tuple[_Provider, _LegacyBedrockState, str]: if callable(cast(object, api_key)): raise OpenAIError("Pass refreshable Bedrock credentials via `bedrock_token_provider`, not `api_key`.") @@ -176,7 +177,9 @@ def _legacy_provider( explicit_api_key=api_key, token_provider=token_provider, aws_region=resolved_region, - region_was_explicit=bool(aws_region and aws_region.strip()), + region_was_explicit=( + bool(aws_region and aws_region.strip()) if region_was_explicit is None else region_was_explicit + ), aws_profile=aws_profile, aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, @@ -263,6 +266,7 @@ def _copy_configuration( next_credentials_provider = state.aws_credentials_provider next_region = aws_region if aws_region is not None else client.aws_region + next_region_was_explicit = aws_region is not None or state.region_was_explicit if aws_profile is not None and aws_region is None and not state.region_was_explicit: next_region = None @@ -273,21 +277,21 @@ def _copy_configuration( else: next_base_url = client.base_url - return ( - { - "api_key": next_api_key, - "bedrock_token_provider": next_token_provider, - "aws_region": next_region, - "aws_profile": next_profile, - "aws_access_key_id": next_access_key_id, - "aws_secret_access_key": next_secret_access_key, - "aws_session_token": next_session_token, - "aws_credentials_provider": next_credentials_provider, - "base_url": next_base_url, - }, - None, - None, - ) + provider_kwargs: dict[str, object] = { + "api_key": next_api_key, + "bedrock_token_provider": next_token_provider, + "aws_region": next_region, + "aws_profile": next_profile, + "aws_access_key_id": next_access_key_id, + "aws_secret_access_key": next_secret_access_key, + "aws_session_token": next_session_token, + "aws_credentials_provider": next_credentials_provider, + "base_url": next_base_url, + } + if _constructor_accepts_keyword(client.__class__.__init__, "_region_was_explicit"): + provider_kwargs["_region_was_explicit"] = next_region_was_explicit + + return provider_kwargs, None, None def _legacy_runtime_signature( @@ -368,7 +372,14 @@ def _refresh_legacy_provider_runtime(client: BedrockOpenAI | AsyncBedrockOpenAI) client._bedrock_provider = provider client._provider = provider client._provider_runtime = _configure_provider(provider) - client._bedrock_runtime_signature = signature + if ( + isinstance(client._provider_runtime, _BedrockProviderRuntime) + and client.aws_region is None + and client._provider_runtime.region is not None + ): + client.aws_region = client._provider_runtime.region + client._bedrock_state = replace(client._bedrock_state, aws_region=client.aws_region) + client._bedrock_runtime_signature = _legacy_runtime_signature(client, configuration) class BedrockOpenAI(OpenAI): @@ -406,6 +417,7 @@ def __init__( _enforce_credentials: bool = True, _provider: _Provider | None = None, _state: _LegacyBedrockState | None = None, + _region_was_explicit: bool | None = None, ) -> None: if _provider is None or _state is None: _provider, _state, public_api_key = _legacy_provider( @@ -418,6 +430,7 @@ def __init__( aws_session_token=aws_session_token, aws_credentials_provider=aws_credentials_provider, base_url=base_url, + region_was_explicit=_region_was_explicit, ) else: public_api_key = ( @@ -446,7 +459,15 @@ def __init__( self._bedrock_token_provider = cast("BedrockTokenProvider | None", _state.token_provider) self._uses_region_derived_base_url = _state.uses_region_derived_base_url canonical_region = re.fullmatch(r"bedrock-mantle\.([a-z0-9-]+)\.api\.aws", self.base_url.host) - self.aws_region = _state.aws_region or (canonical_region.group(1) if canonical_region is not None else None) + provider_region = ( + self._provider_runtime.region if isinstance(self._provider_runtime, _BedrockProviderRuntime) else None + ) + self.aws_region = ( + _state.aws_region + or provider_region + or (canonical_region.group(1) if canonical_region is not None else None) + ) + self._bedrock_state = replace(_state, aws_region=self.aws_region) self.api_key = public_api_key or "" self._bedrock_runtime_signature = _legacy_runtime_signature(self, self._legacy_auth_configuration()) @@ -630,6 +651,7 @@ def __init__( _enforce_credentials: bool = True, _provider: _Provider | None = None, _state: _LegacyBedrockState | None = None, + _region_was_explicit: bool | None = None, ) -> None: if _provider is None or _state is None: _provider, _state, public_api_key = _legacy_provider( @@ -642,6 +664,7 @@ def __init__( aws_session_token=aws_session_token, aws_credentials_provider=aws_credentials_provider, base_url=base_url, + region_was_explicit=_region_was_explicit, ) else: public_api_key = ( @@ -670,7 +693,15 @@ def __init__( self._bedrock_token_provider = cast("AsyncBedrockTokenProvider | None", _state.token_provider) self._uses_region_derived_base_url = _state.uses_region_derived_base_url canonical_region = re.fullmatch(r"bedrock-mantle\.([a-z0-9-]+)\.api\.aws", self.base_url.host) - self.aws_region = _state.aws_region or (canonical_region.group(1) if canonical_region is not None else None) + provider_region = ( + self._provider_runtime.region if isinstance(self._provider_runtime, _BedrockProviderRuntime) else None + ) + self.aws_region = ( + _state.aws_region + or provider_region + or (canonical_region.group(1) if canonical_region is not None else None) + ) + self._bedrock_state = replace(_state, aws_region=self.aws_region) self.api_key = public_api_key or "" self._bedrock_runtime_signature = _legacy_runtime_signature(self, self._legacy_auth_configuration()) diff --git a/src/openai/providers/bedrock.py b/src/openai/providers/bedrock.py index e5cc5268ab..5cee093bfd 100644 --- a/src/openai/providers/bedrock.py +++ b/src/openai/providers/bedrock.py @@ -188,6 +188,11 @@ async def prepare_async_request(self, request: httpx.Request) -> None: request.headers.update(signed_headers) +@dataclass +class _BedrockProviderRuntime(_ProviderRuntime): + region: str | None = None + + @dataclass(frozen=True) class _BedrockProviderDefinition: configured_region: str | None @@ -282,17 +287,19 @@ def environment_token() -> str: assert auth is not None if isinstance(auth, _BedrockSigV4Auth): - return _ProviderRuntime( + return _BedrockProviderRuntime( name=self.name, base_url=base_url, + region=region, transform_request=_without_redirects, prepare_request=auth.prepare_request, prepare_async_request=auth.prepare_async_request, ) - return _ProviderRuntime( + return _BedrockProviderRuntime( name=self.name, base_url=base_url, + region=region, prepare_request=auth.prepare_request, prepare_async_request=auth.prepare_async_request, ) diff --git a/tests/lib/test_bedrock.py b/tests/lib/test_bedrock.py index 244dfab79f..2a6dc2b7a2 100644 --- a/tests/lib/test_bedrock.py +++ b/tests/lib/test_bedrock.py @@ -828,6 +828,53 @@ def test_profile_derived_region_survives_auth_override( assert copied_client._uses_aws_auth() is uses_aws_auth +@pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) +@pytest.mark.parametrize( + ("aws_profile", "config_contents"), + [ + ("west", "[profile west]\nregion = us-west-2\n[profile east]\nregion = us-east-1\n"), + (None, "[default]\nregion = us-west-2\n[profile east]\nregion = us-east-1\n"), + ], + ids=["named-profile", "default-profile"], +) +def test_profile_derived_region_survives_credential_override_with_custom_base_url( + client_cls: type[Client], aws_profile: str | None, config_contents: str, tmp_path: Path +) -> None: + config_path = tmp_path / "config" + config_path.write_text(config_contents) + + with update_env( + AWS_CONFIG_FILE=str(config_path), + AWS_PROFILE=Omit(), + AWS_BEARER_TOKEN_BEDROCK=Omit(), + AWS_REGION=Omit(), + AWS_DEFAULT_REGION=Omit(), + ): + client = ( + make_sync_client(base_url="https://custom.example/openai/v1", aws_profile=aws_profile) + if client_cls is BedrockOpenAI + else make_async_client(base_url="https://custom.example/openai/v1", aws_profile=aws_profile) + ) + assert client.aws_region == "us-west-2" + + client.aws_region = None + client.with_options(timeout=1) + + copied_client = client.with_options( + aws_access_key_id="replacement access key", + aws_secret_access_key="replacement secret key", + ) + profile_client = copied_client.with_options(aws_profile="east") + + assert client.aws_region == "us-west-2" + assert copied_client.aws_region == "us-west-2" + assert copied_client._bedrock_state.aws_region == "us-west-2" + assert copied_client.base_url == URL("https://custom.example/openai/v1/") + assert copied_client._uses_aws_auth() + assert profile_client.aws_region == "us-east-1" + assert profile_client.base_url == URL("https://custom.example/openai/v1/") + + @pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI]) def test_with_options_switching_from_bearer_to_profile_re_resolves_environment_region( client_cls: type[Client], tmp_path: Path From c1c1848bac47f8104157f5bbab23607fd84427e1 Mon Sep 17 00:00:00 2001 From: Hayden Date: Mon, 15 Jun 2026 11:52:44 -0700 Subject: [PATCH 7/9] Align Bedrock live test with GPT-OSS --- tests/lib/bedrock_live.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/lib/bedrock_live.py b/tests/lib/bedrock_live.py index 579344eabb..95c1a71c1d 100644 --- a/tests/lib/bedrock_live.py +++ b/tests/lib/bedrock_live.py @@ -11,10 +11,10 @@ Useful environment variables: -- ``BEDROCK_LIVE_MODEL`` (defaults to ``openai.gpt-5.4``) +- ``BEDROCK_LIVE_MODEL`` (defaults to ``openai.gpt-oss-120b``) - ``BEDROCK_LIVE_REGION`` (otherwise uses the normal AWS region chain) - ``BEDROCK_LIVE_PROFILE`` (otherwise uses the normal AWS credential chain) -- ``AWS_BEDROCK_BASE_URL`` (optional endpoint override) +- ``AWS_BEDROCK_BASE_URL`` (required; GPT-OSS uses the ``/v1`` endpoint) The normal AWS environment variables, shared config, credential-process, SSO, and workload credential sources are resolved by botocore. The test explicitly @@ -68,14 +68,21 @@ def _load_env_file() -> None: def test_bedrock_live_response() -> None: - model = os.environ.get("BEDROCK_LIVE_MODEL") or "openai.gpt-5.4" + model = os.environ.get("BEDROCK_LIVE_MODEL") or "openai.gpt-oss-120b" region = os.environ.get("BEDROCK_LIVE_REGION") or None profile = os.environ.get("BEDROCK_LIVE_PROFILE") or None + base_url = os.environ.get("AWS_BEDROCK_BASE_URL") or None + if base_url is None: + raise RuntimeError( + "Set AWS_BEDROCK_BASE_URL to the Bedrock GPT-OSS endpoint, for example " + "https://bedrock-mantle.us-west-2.api.aws/v1." + ) with OpenAI( provider=bedrock( region=region, profile=profile, + base_url=base_url, api_key=None, ), timeout=60, @@ -84,7 +91,7 @@ def test_bedrock_live_response() -> None: response = client.responses.create( model=model, input="Reply with exactly: bedrock live test ok", - max_output_tokens=64, + store=False, ) output_text = response.output_text.strip() From c89f2bf6ab0f9450ed224f6735c2549f8d3928d0 Mon Sep 17 00:00:00 2001 From: Hayden Date: Mon, 15 Jun 2026 12:08:45 -0700 Subject: [PATCH 8/9] Test Bedrock default credential chain --- tests/lib/test_bedrock_auth_conformance.py | 73 ++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/tests/lib/test_bedrock_auth_conformance.py b/tests/lib/test_bedrock_auth_conformance.py index 58f6889650..e8ed6344a1 100644 --- a/tests/lib/test_bedrock_auth_conformance.py +++ b/tests/lib/test_bedrock_auth_conformance.py @@ -146,6 +146,79 @@ def test_auth_selection_fixture(case: dict[str, Any], monkeypatch: pytest.Monkey assert source == case["expected"]["auth_source"] +def test_default_chain_uses_environment_session_credentials(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("AWS_BEARER_TOKEN_BEDROCK", raising=False) + monkeypatch.delenv("AWS_PROFILE", raising=False) + monkeypatch.delenv("AWS_DEFAULT_PROFILE", raising=False) + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "environment-access-key") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "environment-secret-key") + monkeypatch.setenv("AWS_SESSION_TOKEN", "environment-session-token") + monkeypatch.setenv("AWS_REGION", "us-east-1") + + requests: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + return httpx.Response(200, request=request, json={}) + + with OpenAI( + provider=bedrock(), + http_client=httpx.Client(transport=httpx.MockTransport(handler), trust_env=False), + ) as client: + client.get("/models", cast_to=httpx.Response) + + assert "Credential=environment-access-key/" in requests[0].headers["Authorization"] + assert requests[0].headers["X-Amz-Security-Token"] == "environment-session-token" + + +def test_default_chain_uses_aws_profile_and_shared_files(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + credentials_path = tmp_path / "credentials" + credentials_path.write_text( + "[work]\n" + "aws_access_key_id = profile-access-key\n" + "aws_secret_access_key = profile-secret-key\n" + "aws_session_token = profile-session-token\n" + ) + config_path = tmp_path / "config" + config_path.write_text("[profile work]\nregion = us-west-2\n") + + for name in ( + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_SESSION_TOKEN", + "AWS_SECURITY_TOKEN", + "AWS_DEFAULT_PROFILE", + "AWS_REGION", + "AWS_DEFAULT_REGION", + "AWS_BEARER_TOKEN_BEDROCK", + "AWS_WEB_IDENTITY_TOKEN_FILE", + "AWS_ROLE_ARN", + "AWS_CONTAINER_CREDENTIALS_FULL_URI", + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", + ): + monkeypatch.delenv(name, raising=False) + monkeypatch.setenv("AWS_PROFILE", "work") + monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", str(credentials_path)) + monkeypatch.setenv("AWS_CONFIG_FILE", str(config_path)) + monkeypatch.setenv("AWS_EC2_METADATA_DISABLED", "true") + + requests: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + return httpx.Response(200, request=request, json={}) + + with OpenAI( + provider=bedrock(), + http_client=httpx.Client(transport=httpx.MockTransport(handler), trust_env=False), + ) as client: + client.get("/models", cast_to=httpx.Response) + + assert requests[0].url.host == "bedrock-mantle.us-west-2.api.aws" + assert "Credential=profile-access-key/" in requests[0].headers["Authorization"] + assert requests[0].headers["X-Amz-Security-Token"] == "profile-session-token" + + @pytest.mark.parametrize("case", _cases("sigv4"), ids=lambda case: case["id"]) def test_sigv4_fixture(case: dict[str, Any], monkeypatch: pytest.MonkeyPatch) -> None: credentials = case["given"]["credentials"] From e09d0c28a56d6061ff110697f135e5e827c2ad20 Mon Sep 17 00:00:00 2001 From: Hayden Date: Mon, 15 Jun 2026 14:06:24 -0700 Subject: [PATCH 9/9] Cover the Bedrock AWS credential chain --- tests/lib/test_bedrock.py | 15 +- tests/lib/test_bedrock_auth_conformance.py | 73 ---- tests/lib/test_bedrock_credential_chain.py | 389 +++++++++++++++++++++ 3 files changed, 397 insertions(+), 80 deletions(-) create mode 100644 tests/lib/test_bedrock_credential_chain.py diff --git a/tests/lib/test_bedrock.py b/tests/lib/test_bedrock.py index 2a6dc2b7a2..e8d89c8018 100644 --- a/tests/lib/test_bedrock.py +++ b/tests/lib/test_bedrock.py @@ -166,7 +166,7 @@ def import_module(name: str) -> Any: return real_import_module(name) monkeypatch.setattr(bedrock_auth_module.importlib, "import_module", import_module) - with update_env(AWS_BEARER_TOKEN_BEDROCK="", AWS_REGION="us-east-1"): + with update_env(AWS_BEDROCK_BASE_URL=Omit(), AWS_BEARER_TOKEN_BEDROCK="", AWS_REGION="us-east-1"): client = make_sync_client() with pytest.raises(OpenAIError, match="requires optional AWS dependencies"): client.get("/models", cast_to=httpx.Response) @@ -942,13 +942,14 @@ def __init__( _enforce_credentials=_enforce_credentials, ) - client = LegacyBedrockOpenAI( - api_key="token", - aws_region="us-east-1", - http_client=httpx.Client(trust_env=False), - ) + with update_env(AWS_BEDROCK_BASE_URL=Omit()): + client = LegacyBedrockOpenAI( + api_key="token", + aws_region="us-east-1", + http_client=httpx.Client(trust_env=False), + ) - copied_client = client.with_options(timeout=1).with_options(aws_region="us-west-2") + copied_client = client.with_options(timeout=1).with_options(aws_region="us-west-2") assert isinstance(copied_client, LegacyBedrockOpenAI) assert copied_client.api_key == "token" diff --git a/tests/lib/test_bedrock_auth_conformance.py b/tests/lib/test_bedrock_auth_conformance.py index e8ed6344a1..58f6889650 100644 --- a/tests/lib/test_bedrock_auth_conformance.py +++ b/tests/lib/test_bedrock_auth_conformance.py @@ -146,79 +146,6 @@ def test_auth_selection_fixture(case: dict[str, Any], monkeypatch: pytest.Monkey assert source == case["expected"]["auth_source"] -def test_default_chain_uses_environment_session_credentials(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.delenv("AWS_BEARER_TOKEN_BEDROCK", raising=False) - monkeypatch.delenv("AWS_PROFILE", raising=False) - monkeypatch.delenv("AWS_DEFAULT_PROFILE", raising=False) - monkeypatch.setenv("AWS_ACCESS_KEY_ID", "environment-access-key") - monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "environment-secret-key") - monkeypatch.setenv("AWS_SESSION_TOKEN", "environment-session-token") - monkeypatch.setenv("AWS_REGION", "us-east-1") - - requests: list[httpx.Request] = [] - - def handler(request: httpx.Request) -> httpx.Response: - requests.append(request) - return httpx.Response(200, request=request, json={}) - - with OpenAI( - provider=bedrock(), - http_client=httpx.Client(transport=httpx.MockTransport(handler), trust_env=False), - ) as client: - client.get("/models", cast_to=httpx.Response) - - assert "Credential=environment-access-key/" in requests[0].headers["Authorization"] - assert requests[0].headers["X-Amz-Security-Token"] == "environment-session-token" - - -def test_default_chain_uses_aws_profile_and_shared_files(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: - credentials_path = tmp_path / "credentials" - credentials_path.write_text( - "[work]\n" - "aws_access_key_id = profile-access-key\n" - "aws_secret_access_key = profile-secret-key\n" - "aws_session_token = profile-session-token\n" - ) - config_path = tmp_path / "config" - config_path.write_text("[profile work]\nregion = us-west-2\n") - - for name in ( - "AWS_ACCESS_KEY_ID", - "AWS_SECRET_ACCESS_KEY", - "AWS_SESSION_TOKEN", - "AWS_SECURITY_TOKEN", - "AWS_DEFAULT_PROFILE", - "AWS_REGION", - "AWS_DEFAULT_REGION", - "AWS_BEARER_TOKEN_BEDROCK", - "AWS_WEB_IDENTITY_TOKEN_FILE", - "AWS_ROLE_ARN", - "AWS_CONTAINER_CREDENTIALS_FULL_URI", - "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", - ): - monkeypatch.delenv(name, raising=False) - monkeypatch.setenv("AWS_PROFILE", "work") - monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", str(credentials_path)) - monkeypatch.setenv("AWS_CONFIG_FILE", str(config_path)) - monkeypatch.setenv("AWS_EC2_METADATA_DISABLED", "true") - - requests: list[httpx.Request] = [] - - def handler(request: httpx.Request) -> httpx.Response: - requests.append(request) - return httpx.Response(200, request=request, json={}) - - with OpenAI( - provider=bedrock(), - http_client=httpx.Client(transport=httpx.MockTransport(handler), trust_env=False), - ) as client: - client.get("/models", cast_to=httpx.Response) - - assert requests[0].url.host == "bedrock-mantle.us-west-2.api.aws" - assert "Credential=profile-access-key/" in requests[0].headers["Authorization"] - assert requests[0].headers["X-Amz-Security-Token"] == "profile-session-token" - - @pytest.mark.parametrize("case", _cases("sigv4"), ids=lambda case: case["id"]) def test_sigv4_fixture(case: dict[str, Any], monkeypatch: pytest.MonkeyPatch) -> None: credentials = case["given"]["credentials"] diff --git a/tests/lib/test_bedrock_credential_chain.py b/tests/lib/test_bedrock_credential_chain.py new file mode 100644 index 0000000000..93e476fe5b --- /dev/null +++ b/tests/lib/test_bedrock_credential_chain.py @@ -0,0 +1,389 @@ +from __future__ import annotations + +import sys +import json +import threading +from typing import Any, Iterator, cast +from pathlib import Path +from datetime import datetime, timezone, timedelta +from contextlib import contextmanager +from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler +from typing_extensions import override + +import httpx +import pytest + +from openai import OpenAI +from openai.providers import bedrock + +_FUTURE_EXPIRATION = "2099-01-01T00:00:00Z" +_AWS_ENVIRONMENT_NAMES = ( + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_SESSION_TOKEN", + "AWS_SECURITY_TOKEN", + "AWS_PROFILE", + "AWS_DEFAULT_PROFILE", + "AWS_REGION", + "AWS_DEFAULT_REGION", + "AWS_CONFIG_FILE", + "AWS_SHARED_CREDENTIALS_FILE", + "AWS_CREDENTIAL_FILE", + "AWS_WEB_IDENTITY_TOKEN_FILE", + "AWS_ROLE_ARN", + "AWS_ROLE_SESSION_NAME", + "AWS_CONTAINER_CREDENTIALS_FULL_URI", + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", + "AWS_CONTAINER_AUTHORIZATION_TOKEN", + "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE", + "AWS_EC2_METADATA_DISABLED", + "AWS_EC2_METADATA_SERVICE_ENDPOINT", + "AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE", + "AWS_BEARER_TOKEN_BEDROCK", + "AWS_BEDROCK_BASE_URL", + "BOTO_CONFIG", + "HTTP_PROXY", + "HTTPS_PROXY", + "ALL_PROXY", + "http_proxy", + "https_proxy", + "all_proxy", + "NO_PROXY", + "no_proxy", +) + + +def _isolate_aws_environment(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> tuple[Path, Path]: + for name in _AWS_ENVIRONMENT_NAMES: + monkeypatch.delenv(name, raising=False) + + config_path = tmp_path / "config" + credentials_path = tmp_path / "credentials" + config_path.write_text("") + credentials_path.write_text("") + monkeypatch.setenv("AWS_CONFIG_FILE", str(config_path)) + monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", str(credentials_path)) + monkeypatch.setenv("AWS_EC2_METADATA_DISABLED", "true") + monkeypatch.setenv("NO_PROXY", "127.0.0.1,localhost") + monkeypatch.setenv("no_proxy", "127.0.0.1,localhost") + return config_path, credentials_path + + +def _signed_request(**provider_options: Any) -> httpx.Request: + requests: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + return httpx.Response(200, request=request, json={}) + + with OpenAI( + provider=bedrock(**provider_options), + http_client=httpx.Client(transport=httpx.MockTransport(handler), trust_env=False), + ) as client: + client.get("/models", cast_to=httpx.Response) + + assert len(requests) == 1 + return requests[0] + + +def _credentials_metadata(name: str, *, expiration: str = _FUTURE_EXPIRATION) -> dict[str, str]: + return { + "access_key": f"{name}-access-key", + "secret_key": f"{name}-secret-key", + "token": f"{name}-session-token", + "expiry_time": expiration, + } + + +def _assert_signed_with(request: httpx.Request, name: str, *, region: str) -> None: + assert request.url.host == f"bedrock-mantle.{region}.api.aws" + assert f"Credential={name}-access-key/" in request.headers["Authorization"] + assert request.headers["X-Amz-Security-Token"] == f"{name}-session-token" + + +@contextmanager +def _metadata_server() -> Iterator[tuple[str, list[tuple[str, str, str | None]]]]: + calls: list[tuple[str, str, str | None]] = [] + + class Handler(BaseHTTPRequestHandler): + def _respond(self, body: str, *, content_type: str = "text/plain") -> None: + encoded = body.encode() + self.send_response(200) + self.send_header("Content-Type", content_type) + self.send_header("Content-Length", str(len(encoded))) + self.end_headers() + self.wfile.write(encoded) + + def do_PUT(self) -> None: + calls.append(("PUT", self.path, self.headers.get("Authorization"))) + if self.path != "/latest/api/token": + self.send_error(404) + return + self._respond("metadata-token") + + def do_GET(self) -> None: + calls.append(("GET", self.path, self.headers.get("Authorization"))) + if self.path == "/container-credentials": + self._respond( + json.dumps( + { + "AccessKeyId": "container-access-key", + "SecretAccessKey": "container-secret-key", + "Token": "container-session-token", + "Expiration": _FUTURE_EXPIRATION, + } + ), + content_type="application/json", + ) + return + if self.path == "/latest/meta-data/iam/security-credentials/": + self._respond("instance-role") + return + if self.path == "/latest/meta-data/iam/security-credentials/instance-role": + self._respond( + json.dumps( + { + "Code": "Success", + "LastUpdated": "2026-01-01T00:00:00Z", + "Type": "AWS-HMAC", + "AccessKeyId": "imds-access-key", + "SecretAccessKey": "imds-secret-key", + "Token": "imds-session-token", + "Expiration": _FUTURE_EXPIRATION, + } + ), + content_type="application/json", + ) + return + self.send_error(404) + + @override + def log_message(self, format: str, *args: Any) -> None: + del format, args + return + + server = ThreadingHTTPServer(("127.0.0.1", 0), Handler) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + host, port = cast("tuple[str, int]", server.server_address) + try: + yield f"http://{host}:{port}", calls + finally: + server.shutdown() + server.server_close() + thread.join() + + +def test_default_chain_uses_environment_session_credentials(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + _isolate_aws_environment(monkeypatch, tmp_path) + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "environment-access-key") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "environment-secret-key") + monkeypatch.setenv("AWS_SESSION_TOKEN", "environment-session-token") + monkeypatch.setenv("AWS_REGION", "us-east-1") + + request = _signed_request() + + _assert_signed_with(request, "environment", region="us-east-1") + + +def test_default_chain_uses_aws_profile_and_shared_files(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + config_path, credentials_path = _isolate_aws_environment(monkeypatch, tmp_path) + credentials_path.write_text( + "[work]\n" + "aws_access_key_id = profile-access-key\n" + "aws_secret_access_key = profile-secret-key\n" + "aws_session_token = profile-session-token\n" + ) + config_path.write_text("[profile work]\nregion = us-west-2\n") + monkeypatch.setenv("AWS_PROFILE", "work") + + request = _signed_request() + + _assert_signed_with(request, "profile", region="us-west-2") + + +def test_default_chain_uses_assume_role_profile(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + config_path, credentials_path = _isolate_aws_environment(monkeypatch, tmp_path) + credentials_path.write_text( + "[source]\naws_access_key_id = source-access-key\naws_secret_access_key = source-secret-key\n" + ) + config_path.write_text( + "[profile assumed]\n" + "role_arn = arn:aws:iam::123456789012:role/example\n" + "source_profile = source\n" + "region = us-east-2\n" + ) + monkeypatch.setenv("AWS_PROFILE", "assumed") + credentials_module = pytest.importorskip("botocore.credentials") + fetches = 0 + + def fetch_credentials(_: object) -> dict[str, str]: + nonlocal fetches + fetches += 1 + return _credentials_metadata("assume-role") + + monkeypatch.setattr(credentials_module.AssumeRoleCredentialFetcher, "fetch_credentials", fetch_credentials) + + request = _signed_request() + + assert fetches == 1 + _assert_signed_with(request, "assume-role", region="us-east-2") + + +def test_default_chain_uses_sso_profile(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + config_path, _ = _isolate_aws_environment(monkeypatch, tmp_path) + config_path.write_text( + "[profile sso]\n" + "sso_session = company\n" + "sso_account_id = 123456789012\n" + "sso_role_name = Example\n" + "region = us-west-1\n\n" + "[sso-session company]\n" + "sso_start_url = https://example.awsapps.com/start\n" + "sso_region = us-east-1\n" + ) + monkeypatch.setenv("AWS_PROFILE", "sso") + credentials_module = pytest.importorskip("botocore.credentials") + fetches = 0 + + def fetch_credentials(_: object) -> dict[str, str]: + nonlocal fetches + fetches += 1 + return _credentials_metadata("sso") + + monkeypatch.setattr(credentials_module.SSOCredentialFetcher, "fetch_credentials", fetch_credentials) + + request = _signed_request() + + assert fetches == 1 + _assert_signed_with(request, "sso", region="us-west-1") + + +def test_default_chain_uses_web_identity(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + _isolate_aws_environment(monkeypatch, tmp_path) + token_path = tmp_path / "web-identity-token" + token_path.write_text("web-identity-token") + monkeypatch.setenv("AWS_ROLE_ARN", "arn:aws:iam::123456789012:role/web-identity") + monkeypatch.setenv("AWS_ROLE_SESSION_NAME", "bedrock-test") + monkeypatch.setenv("AWS_WEB_IDENTITY_TOKEN_FILE", str(token_path)) + monkeypatch.setenv("AWS_REGION", "eu-west-1") + credentials_module = pytest.importorskip("botocore.credentials") + fetches = 0 + + def fetch_credentials(_: object) -> dict[str, str]: + nonlocal fetches + fetches += 1 + return _credentials_metadata("web-identity") + + monkeypatch.setattr( + credentials_module.AssumeRoleWithWebIdentityCredentialFetcher, + "fetch_credentials", + fetch_credentials, + ) + + request = _signed_request() + + assert fetches == 1 + _assert_signed_with(request, "web-identity", region="eu-west-1") + + +def test_default_chain_uses_credential_process(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + config_path, _ = _isolate_aws_environment(monkeypatch, tmp_path) + process_path = tmp_path / "credentials_process.py" + process_output = { + "Version": 1, + "AccessKeyId": "process-access-key", + "SecretAccessKey": "process-secret-key", + "SessionToken": "process-session-token", + "Expiration": _FUTURE_EXPIRATION, + } + process_path.write_text(f"print({json.dumps(process_output)!r})\n") + command = f"{json.dumps(sys.executable)} {json.dumps(str(process_path))}" + config_path.write_text(f"[profile process]\nregion = ap-southeast-2\ncredential_process = {command}\n") + monkeypatch.setenv("AWS_PROFILE", "process") + + request = _signed_request() + + _assert_signed_with(request, "process", region="ap-southeast-2") + + +@pytest.mark.parametrize("use_token_file", [False, True], ids=["ecs", "eks-pod-identity"]) +def test_default_chain_uses_container_credentials( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path, use_token_file: bool +) -> None: + _isolate_aws_environment(monkeypatch, tmp_path) + monkeypatch.setenv("AWS_REGION", "us-west-2") + with _metadata_server() as (base_url, metadata_calls): + monkeypatch.setenv("AWS_CONTAINER_CREDENTIALS_FULL_URI", f"{base_url}/container-credentials") + expected_authorization = None + if use_token_file: + token_path = tmp_path / "container-authorization-token" + token_path.write_text("pod-identity-token") + monkeypatch.setenv("AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE", str(token_path)) + expected_authorization = "pod-identity-token" + + request = _signed_request() + + assert metadata_calls == [("GET", "/container-credentials", expected_authorization)] + _assert_signed_with(request, "container", region="us-west-2") + + +def test_default_chain_uses_ec2_instance_metadata(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + _isolate_aws_environment(monkeypatch, tmp_path) + monkeypatch.setenv("AWS_REGION", "us-east-1") + monkeypatch.setenv("AWS_EC2_METADATA_DISABLED", "false") + with _metadata_server() as (base_url, metadata_calls): + monkeypatch.setenv("AWS_EC2_METADATA_SERVICE_ENDPOINT", base_url) + + request = _signed_request() + + assert [(method, path) for method, path, _ in metadata_calls] == [ + ("PUT", "/latest/api/token"), + ("GET", "/latest/meta-data/iam/security-credentials/"), + ("GET", "/latest/meta-data/iam/security-credentials/instance-role"), + ] + _assert_signed_with(request, "imds", region="us-east-1") + + +def test_default_chain_refreshes_credentials_before_retry(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + _isolate_aws_environment(monkeypatch, tmp_path) + monkeypatch.setenv("AWS_REGION", "us-east-1") + credentials_module = pytest.importorskip("botocore.credentials") + session_module = pytest.importorskip("botocore.session") + refreshes = 0 + + def expires_soon() -> str: + return (datetime.now(timezone.utc) + timedelta(minutes=5)).isoformat() + + def refresh() -> dict[str, str]: + nonlocal refreshes + refreshes += 1 + return _credentials_metadata(f"refresh-{refreshes}", expiration=expires_soon()) + + credentials = credentials_module.RefreshableCredentials.create_from_metadata( + metadata=_credentials_metadata("initial", expiration=expires_soon()), + refresh_using=refresh, + method="assume-role", + ) + + def get_credentials(_: object) -> Any: + return credentials + + monkeypatch.setattr(session_module.Session, "get_credentials", get_credentials) + requests: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + return httpx.Response(500 if len(requests) == 1 else 200, request=request, json={}) + + with OpenAI( + provider=bedrock(), + max_retries=1, + http_client=httpx.Client(transport=httpx.MockTransport(handler), trust_env=False), + ) as client: + client.get("/models", cast_to=httpx.Response) + + assert refreshes == 2 + assert len(requests) == 2 + _assert_signed_with(requests[0], "refresh-1", region="us-east-1") + _assert_signed_with(requests[1], "refresh-2", region="us-east-1")