Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from collections.abc import AsyncGenerator, Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any, Protocol
from urllib.parse import quote, urlencode, urljoin, urlparse
from urllib.parse import parse_qsl, quote, urlencode, urljoin, urlparse, urlsplit, urlunsplit

import anyio
import httpx
Expand Down Expand Up @@ -353,7 +353,14 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]:
if "offline_access" in self.context.client_metadata.scope.split():
auth_params["prompt"] = "consent"

authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}"
auth_endpoint_parts = urlsplit(auth_endpoint)
authorization_query = urlencode(
[
*parse_qsl(auth_endpoint_parts.query, keep_blank_values=True),
*auth_params.items(),
]
)
authorization_url = urlunsplit(auth_endpoint_parts._replace(query=authorization_query))
await self.context.redirect_handler(authorization_url)

# Wait for callback
Expand Down
44 changes: 44 additions & 0 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,50 @@ def test_clear_tokens(self, oauth_provider: OAuthClientProvider, valid_tokens: O
class TestOAuthFlow:
"""Test OAuth flow methods."""

@pytest.mark.anyio
async def test_authorization_endpoint_query_params_are_preserved(
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
):
"""OAuth authorization endpoints may already carry provider-specific query params."""
captured_state: str | None = None

async def redirect_handler(url: str) -> None:
nonlocal captured_state
parsed = urlparse(url)
params = parse_qs(parsed.query)

assert params["prompt"] == ["select_account"]
assert params["response_type"] == ["code"]
assert params["client_id"] == ["test_client"]

captured_state = params.get("state", [None])[0]

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", captured_state

provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)
provider.context.oauth_metadata = OAuthMetadata(
issuer=AnyHttpUrl("https://auth.example.com"),
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize?prompt=select_account"),
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
)
provider.context.client_info = OAuthClientInformationFull(
client_id="test_client",
client_secret="test_secret",
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
)

auth_code, code_verifier = await provider._perform_authorization_code_grant()

assert auth_code == "test_auth_code"
assert code_verifier

@pytest.mark.anyio
async def test_build_protected_resource_discovery_urls(
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
Expand Down
Loading