diff --git a/src/auth/oauth_token_manager.py b/src/auth/oauth_token_manager.py index 84ae339b..313744d4 100644 --- a/src/auth/oauth_token_manager.py +++ b/src/auth/oauth_token_manager.py @@ -1,5 +1,7 @@ +import asyncio import datetime import logging +from collections import defaultdict from typing import Dict, Optional import tornado.ioloop @@ -15,6 +17,7 @@ class OAuthTokenManager: def __init__(self, enabled, fetch_token_callback) -> None: self._refresh_tokens = {} # type: Dict[str, str] self._pending_access_tokens = {} # type: Dict[str, OAuthTokenResponse] + self._refresh_locks = defaultdict(asyncio.Lock) # type: Dict[str, asyncio.Lock] self._scheduler = None self._enabled = enabled @@ -110,7 +113,7 @@ def remove_user(self, username): def _schedule_token_refresh(self, username, refresh_token, next_refresh_datetime): if not self._scheduler: - self.scheduler = Scheduler() + self._scheduler = Scheduler() token_expires_in = next_refresh_datetime - datetime.datetime.now() if token_expires_in < datetime.timedelta(seconds=30): @@ -120,31 +123,34 @@ def _schedule_token_refresh(self, username, refresh_token, next_refresh_datetime else: next_refresh_datetime_adjusted = next_refresh_datetime - datetime.timedelta(minutes=1) - self.scheduler.schedule( + self._scheduler.schedule( next_refresh_datetime_adjusted, tornado.ioloop.IOLoop.current().add_callback, (self._refresh_token, username, refresh_token)) async def _refresh_token(self, username, refresh_token, force=False): - if not force: - if (username not in self._refresh_tokens) or (self._refresh_tokens[username] != refresh_token): - return + # serialize refreshes per user: a concurrent refresh with the same (rotated) + # refresh token would get a 401 from the provider and log the user out + async with self._refresh_locks[username]: + if not force: + if (username not in self._refresh_tokens) or (self._refresh_tokens[username] != refresh_token): + return - token_response = await self._fetch_token_callback(refresh_token, username) + token_response = await self._fetch_token_callback(refresh_token, username) - if token_response is None: - return + if token_response is None: + return - LOGGER.info(f'Refreshed token for {username}') + LOGGER.info(f'Refreshed token for {username}') - self._refresh_tokens[username] = token_response.refresh_token - self._pending_access_tokens[username] = token_response + self._refresh_tokens[username] = token_response.refresh_token + self._pending_access_tokens[username] = token_response - if token_response.should_refresh(): - self._schedule_token_refresh( - username, - token_response.refresh_token, - token_response.resolve_next_refresh_datetime()) + if token_response.should_refresh(): + self._schedule_token_refresh( + username, + token_response.refresh_token, + token_response.resolve_next_refresh_datetime()) @staticmethod def _restore_token_response_from_cookies(request_handler) -> Optional[OAuthTokenResponse]: diff --git a/src/tests/auth/test_auth_keycloak_openid.py b/src/tests/auth/test_auth_keycloak_openid.py index 6f6f2e92..fca9ef71 100644 --- a/src/tests/auth/test_auth_keycloak_openid.py +++ b/src/tests/auth/test_auth_keycloak_openid.py @@ -119,7 +119,8 @@ def send_tokens(self, token_prefix, request_handler): 'refresh_expires_in': refresh_expiration_duration }) - self.cleanup_old_tokens(self.access_token_expiration_times, token_prefix) + # Real Keycloak rotates refresh tokens, but old access tokens are stateless JWTs + # and stay valid until their expiration, even after a refresh self.cleanup_old_tokens(self.refresh_token_expiration_times, token_prefix) self.access_token_expiration_times[access_token] = time.time() + access_expiration_duration @@ -193,13 +194,56 @@ async def test_success_validate_after_refresh(self): valid_1 = await self.authenticator.validate_user(username, mock_request_handler(previous_request=request_1)) self.assertTrue(valid_1) - for i in range(1, 20): - await gen.sleep(0.1) + await self.wait_for_groups('bugy', ['g3']) + + @gen_test + async def test_success_validate_when_refresh_races_with_validation(self): + # Regression test for a flaky failure of test_success_validate_after_refresh: + # the scheduler-driven token refresh fires on the IOLoop right before validate_user, + # so the userinfo request is sent with an access token from before the refresh + username, request_1 = await self.authenticate('qwerty123') + + self.oauth_server.set_groups('bugy', ['g3']) - if self.authenticator.get_groups('bugy') == ['g3']: + await gen.sleep(auth_info_ttl + 0.5) + + token_manager = self.authenticator._token_manager + current_refresh_token = token_manager._refresh_tokens[username] + self.io_loop.add_callback(token_manager._refresh_token, username, current_refresh_token) + + valid_1 = await self.authenticator.validate_user(username, mock_request_handler(previous_request=request_1)) + self.assertTrue(valid_1) + + await self.wait_for_groups('bugy', ['g3']) + + @gen_test + async def test_success_validate_when_concurrent_refreshes(self): + # Two refreshes in flight with the same refresh token: without per-user + # serialization, the second one gets 401 (token rotated) and logs the user out + username, request_1 = await self.authenticate('qwerty123') + + self.oauth_server.set_groups('bugy', ['g3']) + + await gen.sleep(auth_info_ttl + 0.5) + + token_manager = self.authenticator._token_manager + current_refresh_token = token_manager._refresh_tokens[username] + self.io_loop.add_callback(token_manager._refresh_token, username, current_refresh_token) + self.io_loop.add_callback(token_manager._refresh_token, username, current_refresh_token) + + valid_1 = await self.authenticator.validate_user(username, mock_request_handler(previous_request=request_1)) + self.assertTrue(valid_1) + + await self.wait_for_groups('bugy', ['g3']) + + async def wait_for_groups(self, username, expected_groups): + for i in range(1, 20): + if self.authenticator.get_groups(username) == expected_groups: break - self.assertEqual(['g3'], self.authenticator.get_groups('bugy')) + await gen.sleep(0.1) + + self.assertEqual(expected_groups, self.authenticator.get_groups(username)) @gen_test async def test_failed_validate_after_deactivate(self):