Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 104 additions & 87 deletions src/otari/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,98 +154,16 @@ def __init__(
def _map_api_exception(self, error: ApiException) -> OtariError:
"""Map a generated ``ApiException`` to a typed otari exception.

``ApiException`` carries ``.status`` (int) and ``.body`` (the raw JSON
string the gateway returned) plus ``.headers``. The gateway encodes the
human-readable reason under the ``detail`` key (FastAPI convention).

Most status mappings only apply in platform mode; in non-platform mode
the generic :class:`OtariError` is raised so the caller still gets a
single SDK exception type. The one cross-mode case is
:class:`UnsupportedCapabilityError`, surfaced in both modes.
Thin wrapper over the module-level :func:`map_api_exception` so the
control-plane resources can reuse the same mapping without a client
instance.
"""
status = error.status if isinstance(error.status, int) else 0
headers = error.headers or {}
detail = self._extract_detail(error)
correlation_id = _header_get(headers, "x-correlation-id")
retry_after = _header_get(headers, "retry-after")

full = f"{detail} (correlation_id={correlation_id})" if correlation_id else detail

# Unsupported-capability is surfaced regardless of mode.
if status == 400 and _UNSUPPORTED_MODERATION_RE.search(detail):
provider = _parse_unsupported_provider(detail)
capability = "multimodal_moderation" if "multimodal" in detail else "moderation"
return UnsupportedCapabilityError(
full,
status_code=status,
original_error=error,
provider_name=PROVIDER_NAME,
provider=provider,
capability=capability,
)

if status in (401, 403):
return AuthenticationError(
full, status_code=status, original_error=error, provider_name=PROVIDER_NAME
)
if status == 402:
return InsufficientFundsError(
full, status_code=status, original_error=error, provider_name=PROVIDER_NAME
)
if status == 404:
return ModelNotFoundError(
full, status_code=status, original_error=error, provider_name=PROVIDER_NAME
)
if status == 409:
return BatchNotCompleteError(
full,
status_code=status,
original_error=error,
provider_name=PROVIDER_NAME,
batch_id=_extract_batch_id(detail),
batch_status=_extract_status(detail),
)
if status == 429:
return RateLimitError(
full,
status_code=status,
original_error=error,
provider_name=PROVIDER_NAME,
retry_after=retry_after,
)
if status == 504:
return GatewayTimeoutError(
full, status_code=status, original_error=error, provider_name=PROVIDER_NAME
)
# 502 and any other 5xx are upstream-provider failures.
if status == 502 or 500 <= status < 600:
return UpstreamProviderError(
full, status_code=status, original_error=error, provider_name=PROVIDER_NAME
)

return OtariError(
full, status_code=status, original_error=error, provider_name=PROVIDER_NAME
)
return map_api_exception(error)

@staticmethod
def _extract_detail(error: ApiException) -> str:
"""Pull the gateway's human-readable detail from an ``ApiException`` body."""
body = error.body
if isinstance(body, (bytes, bytearray)):
body = body.decode("utf-8", "replace")
if isinstance(body, str) and body:
try:
parsed = json.loads(body)
except (ValueError, TypeError):
return body
if isinstance(parsed, dict):
detail = parsed.get("detail") or parsed.get("message") or parsed.get("error")
if isinstance(detail, str):
return detail
if detail is not None:
return str(detail)
return body
return error.reason or "An error occurred"
return extract_detail(error)

def _map_streaming_response(self, response: httpx.Response, body: bytes) -> OtariError:
"""Map a failed raw streaming response to a typed otari exception.
Expand Down Expand Up @@ -331,3 +249,102 @@ def _extract_status(message: str) -> str | None:
def _url_encode(value: str) -> str:
"""Percent-encode a single URL component."""
return urllib.parse.quote(value, safe="")


def extract_detail(error: ApiException) -> str:
"""Pull the gateway's human-readable detail from an ``ApiException`` body."""
body = error.body
if isinstance(body, (bytes, bytearray)):
body = body.decode("utf-8", "replace")
if isinstance(body, str) and body:
try:
parsed = json.loads(body)
except (ValueError, TypeError):
return body
if isinstance(parsed, dict):
detail = parsed.get("detail") or parsed.get("message") or parsed.get("error")
if isinstance(detail, str):
return detail
if detail is not None:
return str(detail)
return body
return error.reason or "An error occurred"



def map_api_exception(error: ApiException) -> OtariError:
"""Map a generated ``ApiException`` to a typed otari exception.

``ApiException`` carries ``.status`` (int) and ``.body`` (the raw JSON
string the gateway returned) plus ``.headers``. The gateway encodes the
human-readable reason under the ``detail`` key (FastAPI convention).

Most status mappings only apply in platform mode; in non-platform mode
the generic :class:`OtariError` is raised so the caller still gets a
single SDK exception type. The one cross-mode case is
:class:`UnsupportedCapabilityError`, surfaced in both modes.
"""
status = error.status if isinstance(error.status, int) else 0
headers = error.headers or {}
detail = extract_detail(error)
correlation_id = _header_get(headers, "x-correlation-id")
retry_after = _header_get(headers, "retry-after")

full = f"{detail} (correlation_id={correlation_id})" if correlation_id else detail

# Unsupported-capability is surfaced regardless of mode.
if status == 400 and _UNSUPPORTED_MODERATION_RE.search(detail):
provider = _parse_unsupported_provider(detail)
capability = "multimodal_moderation" if "multimodal" in detail else "moderation"
return UnsupportedCapabilityError(
full,
status_code=status,
original_error=error,
provider_name=PROVIDER_NAME,
provider=provider,
capability=capability,
)

if status in (401, 403):
return AuthenticationError(
full, status_code=status, original_error=error, provider_name=PROVIDER_NAME
)
if status == 402:
return InsufficientFundsError(
full, status_code=status, original_error=error, provider_name=PROVIDER_NAME
)
if status == 404:
return ModelNotFoundError(
full, status_code=status, original_error=error, provider_name=PROVIDER_NAME
)
if status == 409:
return BatchNotCompleteError(
full,
status_code=status,
original_error=error,
provider_name=PROVIDER_NAME,
batch_id=_extract_batch_id(detail),
batch_status=_extract_status(detail),
)
if status == 429:
return RateLimitError(
full,
status_code=status,
original_error=error,
provider_name=PROVIDER_NAME,
retry_after=retry_after,
)
if status == 504:
return GatewayTimeoutError(
full, status_code=status, original_error=error, provider_name=PROVIDER_NAME
)
# 502 and any other 5xx are upstream-provider failures.
if status == 502 or 500 <= status < 600:
return UpstreamProviderError(
full, status_code=status, original_error=error, provider_name=PROVIDER_NAME
)

return OtariError(
full, status_code=status, original_error=error, provider_name=PROVIDER_NAME
)

52 changes: 50 additions & 2 deletions src/otari/control_plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,20 @@

from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Any, cast
from functools import cached_property, wraps
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast

from otari import _client as _cp
from otari._base import map_api_exception
from otari._client.api.budgets_api import BudgetsApi
from otari._client.api.keys_api import KeysApi
from otari._client.api.pricing_api import PricingApi
from otari._client.api.usage_api import UsageApi
from otari._client.api.users_api import UsersApi
from otari._client.exceptions import ApiException

if TYPE_CHECKING:
from collections.abc import Callable
from datetime import datetime

from otari._client import (
Expand All @@ -49,6 +52,29 @@
)


_P = ParamSpec("_P")
_R = TypeVar("_R")


def _translate(fn: Callable[_P, _R]) -> Callable[_P, _R]:
"""Map a generated ``ApiException`` to a typed :class:`otari.errors.OtariError`.

The inference client maps generated exceptions in ``client.py``; the
control-plane ergonomic aliases get the same treatment here so callers see a
single SDK error type instead of the raw generated ``ApiException``. The
``raw`` escape hatch is intentionally left unwrapped.
"""

@wraps(fn)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
try:
return fn(*args, **kwargs)
except ApiException as exc:
raise map_api_exception(exc) from exc

return wrapper


class KeysResource:
"""Ergonomic accessors for the API-keys management endpoints.

Expand All @@ -59,18 +85,23 @@ class KeysResource:
def __init__(self, api: KeysApi) -> None:
self.raw = api

@_translate
def create(self, request: CreateKeyRequest, **kwargs: Any) -> CreateKeyResponse:
return self.raw.create_key_v1_keys_post(request, **kwargs)

@_translate
def get(self, key_id: str, **kwargs: Any) -> KeyInfo:
return self.raw.get_key_v1_keys_key_id_get(key_id, **kwargs)

@_translate
def list(self, skip: int | None = None, limit: int | None = None, **kwargs: Any) -> list[KeyInfo]:
return self.raw.list_keys_v1_keys_get(skip, limit, **kwargs)

@_translate
def update(self, key_id: str, request: UpdateKeyRequest, **kwargs: Any) -> KeyInfo:
return self.raw.update_key_v1_keys_key_id_patch(key_id, request, **kwargs)

@_translate
def delete(self, key_id: str, **kwargs: Any) -> None:
self.raw.delete_key_v1_keys_key_id_delete(key_id, **kwargs)

Expand All @@ -85,23 +116,29 @@ class UsersResource:
def __init__(self, api: UsersApi) -> None:
self.raw = api

@_translate
def create(self, request: CreateUserRequest, **kwargs: Any) -> UserResponse:
return self.raw.create_user_v1_users_post(request, **kwargs)

@_translate
def get(self, user_id: str, **kwargs: Any) -> UserResponse:
return self.raw.get_user_v1_users_user_id_get(user_id, **kwargs)

@_translate
def update(self, user_id: str, request: UpdateUserRequest, **kwargs: Any) -> UserResponse:
return self.raw.update_user_v1_users_user_id_patch(user_id, request, **kwargs)

@_translate
def delete(self, user_id: str, **kwargs: Any) -> None:
self.raw.delete_user_v1_users_user_id_delete(user_id, **kwargs)

@_translate
def get_usage(self, user_id: str, **kwargs: Any) -> list[UsageLogResponse]:
return self.raw.get_user_usage_v1_users_user_id_usage_get(user_id, **kwargs)

# Defined last: a method named ``list`` shadows the ``list`` builtin for any
# ``list[...]`` annotation that follows it in this class body.
@_translate
def list(self, skip: int | None = None, limit: int | None = None, **kwargs: Any) -> list[UserResponse]:
return self.raw.list_users_v1_users_get(skip, limit, **kwargs)

Expand All @@ -116,18 +153,23 @@ class BudgetsResource:
def __init__(self, api: BudgetsApi) -> None:
self.raw = api

@_translate
def create(self, request: CreateBudgetRequest, **kwargs: Any) -> BudgetResponse:
return self.raw.create_budget_v1_budgets_post(request, **kwargs)

@_translate
def get(self, budget_id: str, **kwargs: Any) -> BudgetResponse:
return self.raw.get_budget_v1_budgets_budget_id_get(budget_id, **kwargs)

@_translate
def list(self, skip: int | None = None, limit: int | None = None, **kwargs: Any) -> list[BudgetResponse]:
return self.raw.list_budgets_v1_budgets_get(skip, limit, **kwargs)

@_translate
def update(self, budget_id: str, request: UpdateBudgetRequest, **kwargs: Any) -> BudgetResponse:
return self.raw.update_budget_v1_budgets_budget_id_patch(budget_id, request, **kwargs)

@_translate
def delete(self, budget_id: str, **kwargs: Any) -> None:
self.raw.delete_budget_v1_budgets_budget_id_delete(budget_id, **kwargs)

Expand All @@ -142,20 +184,25 @@ class PricingResource:
def __init__(self, api: PricingApi) -> None:
self.raw = api

@_translate
def get(self, model_key: str, **kwargs: Any) -> PricingResponse:
return self.raw.get_pricing_v1_pricing_model_key_get(model_key, **kwargs)

@_translate
def set(self, request: SetPricingRequest, **kwargs: Any) -> PricingResponse:
return self.raw.set_pricing_v1_pricing_post(request, **kwargs)

@_translate
def delete(self, model_key: str, **kwargs: Any) -> None:
self.raw.delete_pricing_v1_pricing_model_key_delete(model_key, **kwargs)

@_translate
def get_history(self, model_key: str, **kwargs: Any) -> list[PricingResponse]:
return self.raw.get_pricing_history_v1_pricing_model_key_history_get(model_key, **kwargs)

# Defined last: a method named ``list`` shadows the ``list`` builtin for any
# ``list[...]`` annotation that follows it in this class body.
@_translate
def list(self, skip: int | None = None, limit: int | None = None, **kwargs: Any) -> list[PricingResponse]:
return self.raw.list_pricing_v1_pricing_get(skip, limit, **kwargs)

Expand All @@ -170,6 +217,7 @@ class UsageResource:
def __init__(self, api: UsageApi) -> None:
self.raw = api

@_translate
def list(
self,
start_date: datetime | None = None,
Expand Down
Loading
Loading