From ea0777cda9e4781c37a5b44508b403a2ba70609e Mon Sep 17 00:00:00 2001 From: Gabriel Igliozzi Date: Mon, 15 Jun 2026 14:48:54 +0200 Subject: [PATCH 1/2] credential vending impl --- .../catalog/rest/credentials_provider.py | 126 ++++++++++++ pyiceberg/io/__init__.py | 11 ++ tests/catalog/test_credentials_provider.py | 185 ++++++++++++++++++ 3 files changed, 322 insertions(+) create mode 100644 pyiceberg/catalog/rest/credentials_provider.py create mode 100644 tests/catalog/test_credentials_provider.py diff --git a/pyiceberg/catalog/rest/credentials_provider.py b/pyiceberg/catalog/rest/credentials_provider.py new file mode 100644 index 0000000000..d4044f3fdd --- /dev/null +++ b/pyiceberg/catalog/rest/credentials_provider.py @@ -0,0 +1,126 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# +from datetime import datetime + +from pydantic import Field +from requests import HTTPError, Session + +from pyiceberg.catalog import URI +from pyiceberg.catalog.rest.response import _handle_non_200_response +from pyiceberg.catalog.rest.scan_planning import StorageCredential +from pyiceberg.exceptions import ValidationException +from pyiceberg.io import ( + AWS_ACCESS_KEY_ID, + AWS_SECRET_ACCESS_KEY, + AWS_SESSION_TOKEN, + S3_ACCESS_KEY_ID, + S3_SECRET_ACCESS_KEY, + S3_SESSION_TOKEN, +) +from pyiceberg.typedef import IcebergBaseModel, Properties +from pyiceberg.utils.properties import get_first_property_value + +S3_SESSION_TOKEN_EXPIRES_AT_MS = "s3.session-token-expires-at-ms" +CREDENTIALS_ENDPOINT = "client.refresh-credentials-endpoint" +REFRESH_CREDENTIALS_ENABLED = "client.refresh-credentials-enabled" + + +class LoadCredentialsResponse(IcebergBaseModel): + credentials: list[StorageCredential] = Field(alias="storage-credentials") + + +class VendedCredentialsProvider: + _session: Session + _properties: Properties + + def __init__(self, session: Session, properties: Properties): + self._session = session + self._properties = properties + + def _extract_s3_credentials_from(self, props: Properties) -> tuple[str | None, str | None, str | None, str | None]: + """Extract only S3 credentials from properties.""" + access_key = get_first_property_value(props, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID) + secret_key = get_first_property_value(props, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY) + session_token = get_first_property_value(props, S3_SESSION_TOKEN, AWS_SESSION_TOKEN) + expiry = get_first_property_value(props, S3_SESSION_TOKEN_EXPIRES_AT_MS) + + return access_key, secret_key, session_token, expiry + + def _to_credentials_property_map( + self, access_key: str | None, secret_key: str | None, session_token: str | None, expiry: str | None + ) -> Properties: + return { + S3_ACCESS_KEY_ID: access_key, + S3_SECRET_ACCESS_KEY: secret_key, + S3_SESSION_TOKEN: session_token, + S3_SESSION_TOKEN_EXPIRES_AT_MS: expiry, + } + + def needs_refresh(self) -> bool: + """Return True if the S3 session token expires within 300s.""" + expiry = get_first_property_value(self._properties, S3_SESSION_TOKEN_EXPIRES_AT_MS) + if expiry is None: + return False + expires_at = datetime.fromtimestamp(int(expiry) / 1000) + seconds_remaining = (expires_at - datetime.now()).total_seconds() + return seconds_remaining < 300 + + def _build_refresh_endpoint(self) -> str: + """Build credential refresh endpoint from properties.""" + catalog_uri = get_first_property_value(self._properties, URI) + credentials_path = get_first_property_value(self._properties, CREDENTIALS_ENDPOINT) + + if catalog_uri is None: + raise ValidationException("Invalid catalog endpoint: None") + + if credentials_path is None: + raise ValidationException("Invalid credentials endpoint: None") + + return str(catalog_uri).rstrip("/") + "/" + str(credentials_path).lstrip("/") + + def _get_new_credentials(self) -> LoadCredentialsResponse | None: + try: + http_response = self._session.get(self._build_refresh_endpoint()) + http_response.raise_for_status() + return LoadCredentialsResponse.model_validate_json(http_response.text) + except HTTPError as exc: + _handle_non_200_response(exc, {}) + return None + + def get_credentials(self) -> Properties: + """Retrieve current S3 credentials, refreshing from the endpoint if near expiry.""" + access_key, secret_key, session_token, expiry = self._extract_s3_credentials_from(self._properties) + + if not self.needs_refresh(): + return self._to_credentials_property_map(access_key, secret_key, session_token, expiry) + + creds = self._get_new_credentials() + + if creds is None: + raise ValidationException("Load credential response is None") + if not creds.credentials: + raise ValueError("Invalid S3 Credentials: empty") + if len(creds.credentials) > 1: + raise ValueError("Invalid S3 Credentials: only one S3 credential should exists") + + updated_creds = self._extract_s3_credentials_from(creds.credentials[0].config) + updated_map = self._to_credentials_property_map(*updated_creds) + + # Update internal properties with new credentials + self._properties = {**self._properties, **updated_map} + + return updated_map diff --git a/pyiceberg/io/__init__.py b/pyiceberg/io/__init__.py index 7dbc651214..255da19b21 100644 --- a/pyiceberg/io/__init__.py +++ b/pyiceberg/io/__init__.py @@ -32,9 +32,13 @@ from io import SEEK_SET from types import TracebackType from typing import ( + TYPE_CHECKING, Protocol, runtime_checkable, ) + +if TYPE_CHECKING: + from pyiceberg.catalog.rest.credentials_provider import VendedCredentialsProvider from urllib.parse import urlparse from pyiceberg.typedef import EMPTY_DICT, Properties @@ -291,6 +295,13 @@ def delete(self, location: str | InputFile | OutputFile) -> None: FileNotFoundError: When the file at the provided location does not exist. """ + def set_credentials_provider(self, provider: VendedCredentialsProvider) -> None: # noqa: B027 + """Inject a credentials provider for refreshing vended storage credentials. + + Args: + provider (VendedCredentialsProvider): A concrete type of VendedCredentialsProvider (e.g S3VendedCredentialsProvider) + """ + LOCATION = "location" WAREHOUSE = "warehouse" diff --git a/tests/catalog/test_credentials_provider.py b/tests/catalog/test_credentials_provider.py new file mode 100644 index 0000000000..a828af263d --- /dev/null +++ b/tests/catalog/test_credentials_provider.py @@ -0,0 +1,185 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import time +from unittest.mock import MagicMock + +import pytest + +from pyiceberg.catalog.rest.credentials_provider import ( + CREDENTIALS_ENDPOINT, + LoadCredentialsResponse, + VendedCredentialsProvider, +) +from pyiceberg.catalog.rest.scan_planning import StorageCredential + +CATALOG_URI = "http://localhost:8181" +CREDENTIALS_PATH = "v1/credentials" + +BASE_PROPS = { + "uri": CATALOG_URI, + CREDENTIALS_ENDPOINT: CREDENTIALS_PATH, + "s3.access-key-id": "initial-key", + "s3.secret-access-key": "initial-secret", + "s3.session-token": "initial-token", +} + +REFRESH_RESPONSE = LoadCredentialsResponse( + credentials=[ + StorageCredential( + prefix="s3://", + config={ + "s3.access-key-id": "refreshed-key", + "s3.secret-access-key": "refreshed-secret", + "s3.session-token": "refreshed-token", + }, + ) + ] +) + + +def _make_session(response: LoadCredentialsResponse = REFRESH_RESPONSE) -> MagicMock: + session = MagicMock() + mock_response = MagicMock() + mock_response.text = response.model_dump_json(by_alias=True) + mock_response.raise_for_status.return_value = None + session.get.return_value = mock_response + return session + + +def test_get_credentials_no_expiry_returns_static_creds() -> None: + """When no expiry is set, credentials are returned from properties without an HTTP call.""" + session = _make_session() + provider = VendedCredentialsProvider(session, BASE_PROPS) + creds = provider.get_credentials() + + session.get.assert_not_called() + assert creds["s3.access-key-id"] == "initial-key" + assert creds["s3.secret-access-key"] == "initial-secret" + assert creds["s3.session-token"] == "initial-token" + + +def test_get_credentials_far_expiry_returns_static_creds() -> None: + """When expiry is far in the future (>300s), no refresh is triggered.""" + far_future_ms = str(int((time.time() + 3600) * 1000)) # expires in 1 hour + props = {**BASE_PROPS, "s3.session-token-expires-at-ms": far_future_ms} + session = _make_session() + provider = VendedCredentialsProvider(session, props) + creds = provider.get_credentials() + + session.get.assert_not_called() + assert creds["s3.access-key-id"] == "initial-key" + + +def test_get_credentials_near_expiry_calls_refresh_endpoint() -> None: + """When expiry is within 300s, the refresh endpoint is called and new creds returned.""" + near_expiry_ms = str(int((time.time() + 60) * 1000)) # expires in 60s + props = {**BASE_PROPS, "s3.session-token-expires-at-ms": near_expiry_ms} + session = _make_session() + provider = VendedCredentialsProvider(session, props) + creds = provider.get_credentials() + + session.get.assert_called_once_with(f"{CATALOG_URI}/{CREDENTIALS_PATH}") + assert creds["s3.access-key-id"] == "refreshed-key" + assert creds["s3.secret-access-key"] == "refreshed-secret" + assert creds["s3.session-token"] == "refreshed-token" + + +def test_get_credentials_raises_on_empty_credentials() -> None: + """An empty credentials list in the refresh response raises ValueError.""" + near_expiry_ms = str(int((time.time() + 60) * 1000)) + props = {**BASE_PROPS, "s3.session-token-expires-at-ms": near_expiry_ms} + empty_response = LoadCredentialsResponse(credentials=[]) + provider = VendedCredentialsProvider(_make_session(empty_response), props) + + with pytest.raises(ValueError, match="empty"): + provider.get_credentials() + + +def test_get_credentials_raises_on_multiple_credentials() -> None: + """More than one credential in the refresh response raises ValueError.""" + near_expiry_ms = str(int((time.time() + 60) * 1000)) + props = {**BASE_PROPS, "s3.session-token-expires-at-ms": near_expiry_ms} + multi_response = LoadCredentialsResponse( + credentials=[ + StorageCredential(prefix="s3://", config={}), + StorageCredential(prefix="s3://b", config={}), + ] + ) + provider = VendedCredentialsProvider(_make_session(multi_response), props) + + with pytest.raises(ValueError, match="only one"): + provider.get_credentials() + + +def test_build_refresh_endpoint_strips_trailing_slash() -> None: + props = {**BASE_PROPS, "uri": "http://localhost:8181/"} + provider = VendedCredentialsProvider(MagicMock(), props) + assert provider._build_refresh_endpoint() == f"http://localhost:8181/{CREDENTIALS_PATH}" + + +def test_build_refresh_endpoint_raises_without_uri() -> None: + props = {CREDENTIALS_ENDPOINT: CREDENTIALS_PATH} + provider = VendedCredentialsProvider(MagicMock(), props) + + from pyiceberg.exceptions import ValidationException + + with pytest.raises(ValidationException): + provider._build_refresh_endpoint() + + +def test_needs_refresh_true_when_near_expiry() -> None: + near_expiry_ms = str(int((time.time() + 60) * 1000)) + provider = VendedCredentialsProvider(MagicMock(), {**BASE_PROPS, "s3.session-token-expires-at-ms": near_expiry_ms}) + assert provider.needs_refresh() is True + + +def test_needs_refresh_false_when_far_expiry() -> None: + far_expiry_ms = str(int((time.time() + 3600) * 1000)) + provider = VendedCredentialsProvider(MagicMock(), {**BASE_PROPS, "s3.session-token-expires-at-ms": far_expiry_ms}) + assert provider.needs_refresh() is False + + +def test_needs_refresh_false_when_no_expiry() -> None: + provider = VendedCredentialsProvider(MagicMock(), BASE_PROPS) + assert provider.needs_refresh() is False + + +def test_get_credentials_updates_internal_properties_after_refresh() -> None: + """After a refresh, _properties holds the new expiry so needs_refresh() sees the updated state.""" + far_future_ms = str(int((time.time() + 3600) * 1000)) + refreshed_response = LoadCredentialsResponse( + credentials=[ + StorageCredential( + prefix="s3://", + config={ + "s3.access-key-id": "new-key", + "s3.secret-access-key": "new-secret", + "s3.session-token": "new-token", + "s3.session-token-expires-at-ms": far_future_ms, + }, + ) + ] + ) + near_expiry_ms = str(int((time.time() + 60) * 1000)) + props = {**BASE_PROPS, "s3.session-token-expires-at-ms": near_expiry_ms} + provider = VendedCredentialsProvider(_make_session(refreshed_response), props) + + assert provider.needs_refresh() is True + provider.get_credentials() + assert provider.needs_refresh() is False + assert provider._properties["s3.session-token-expires-at-ms"] == far_future_ms From 36935a1a8e49a5253d2430a65cf2952d9513d98d Mon Sep 17 00:00:00 2001 From: Gabriel Igliozzi Date: Mon, 15 Jun 2026 14:54:46 +0200 Subject: [PATCH 2/2] move s3 property to io module --- pyiceberg/catalog/rest/credentials_provider.py | 2 +- pyiceberg/io/__init__.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pyiceberg/catalog/rest/credentials_provider.py b/pyiceberg/catalog/rest/credentials_provider.py index d4044f3fdd..d89550f8bb 100644 --- a/pyiceberg/catalog/rest/credentials_provider.py +++ b/pyiceberg/catalog/rest/credentials_provider.py @@ -30,11 +30,11 @@ S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY, S3_SESSION_TOKEN, + S3_SESSION_TOKEN_EXPIRES_AT_MS, ) from pyiceberg.typedef import IcebergBaseModel, Properties from pyiceberg.utils.properties import get_first_property_value -S3_SESSION_TOKEN_EXPIRES_AT_MS = "s3.session-token-expires-at-ms" CREDENTIALS_ENDPOINT = "client.refresh-credentials-endpoint" REFRESH_CREDENTIALS_ENABLED = "client.refresh-credentials-enabled" diff --git a/pyiceberg/io/__init__.py b/pyiceberg/io/__init__.py index 255da19b21..c70adca0bd 100644 --- a/pyiceberg/io/__init__.py +++ b/pyiceberg/io/__init__.py @@ -58,6 +58,7 @@ S3_ACCESS_KEY_ID = "s3.access-key-id" S3_SECRET_ACCESS_KEY = "s3.secret-access-key" S3_SESSION_TOKEN = "s3.session-token" +S3_SESSION_TOKEN_EXPIRES_AT_MS = "s3.session-token-expires-at-ms" S3_REGION = "s3.region" S3_RESOLVE_REGION = "s3.resolve-region" S3_PROXY_URI = "s3.proxy-uri"