From 4e4e43889c1f4fce38319492e85568e4c95ca599 Mon Sep 17 00:00:00 2001 From: salimtb Date: Thu, 25 Jun 2026 16:23:41 +0200 Subject: [PATCH 1/4] fix: fix balance update after transactions --- .../src/AssetsController.test.ts | 76 ++++ .../assets-controller/src/AssetsController.ts | 174 +++++++-- .../AccountsApiDataSource.test.ts | 16 + .../src/data-sources/AccountsApiDataSource.ts | 11 +- .../BackendWebsocketDataSource.test.ts | 333 +++++++++++++++++- .../BackendWebsocketDataSource.ts | 216 ++++++++++-- packages/assets-controller/src/types.ts | 5 + 7 files changed, 767 insertions(+), 64 deletions(-) diff --git a/packages/assets-controller/src/AssetsController.test.ts b/packages/assets-controller/src/AssetsController.test.ts index b81c411902..df9d52a6b1 100644 --- a/packages/assets-controller/src/AssetsController.test.ts +++ b/packages/assets-controller/src/AssetsController.test.ts @@ -1720,6 +1720,82 @@ describe('AssetsController', () => { }); }); + it('does not let subscription polling overwrite a recent websocket balance update', async () => { + const initialState: Partial = { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [MOCK_ASSET_ID]: { amount: '8.185173' }, + }, + }, + }; + + await withController({ state: initialState }, async ({ controller }) => { + await controller.handleAssetsUpdate( + { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [MOCK_ASSET_ID]: { amount: '7.185173' }, + }, + }, + }, + 'BackendWebsocketDataSource', + ); + + await controller.handleAssetsUpdate( + { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [MOCK_ASSET_ID]: { amount: '8.185173' }, + }, + }, + }, + 'AccountsApiDataSource', + ); + + expect( + controller.state.assetsBalance[MOCK_ACCOUNT_ID]?.[MOCK_ASSET_ID], + ).toStrictEqual({ amount: '7.185173' }); + }); + }); + + it('applies getAssets forceUpdate over a recent websocket balance update', async () => { + const initialState: Partial = { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [MOCK_ASSET_ID]: { amount: '8.185173' }, + }, + }, + }; + + await withController({ state: initialState }, async ({ controller }) => { + await controller.handleAssetsUpdate( + { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [MOCK_ASSET_ID]: { amount: '7.185173' }, + }, + }, + }, + 'BackendWebsocketDataSource', + ); + + await controller.handleAssetsUpdate( + { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [MOCK_ASSET_ID]: { amount: '8.185173' }, + }, + }, + }, + 'getAssets:forceUpdate', + ); + + expect( + controller.state.assetsBalance[MOCK_ACCOUNT_ID]?.[MOCK_ASSET_ID], + ).toStrictEqual({ amount: '8.185173' }); + }); + }); + it('replaces state when full update has authoritative data', async () => { const initialState: Partial = { assetsBalance: { diff --git a/packages/assets-controller/src/AssetsController.ts b/packages/assets-controller/src/AssetsController.ts index 818658766b..ada20a462b 100644 --- a/packages/assets-controller/src/AssetsController.ts +++ b/packages/assets-controller/src/AssetsController.ts @@ -188,6 +188,17 @@ const MESSENGER_EXPOSED_METHODS = [ /** Default polling interval hint for data sources (30 seconds) */ const DEFAULT_POLLING_INTERVAL_MS = 30_000; +/** Sources whose passive polling must not overwrite recent websocket balances. */ +const POLLING_BALANCE_SOURCES = new Set([ + 'AccountsApiDataSource', + 'RpcDataSource', + 'SnapDataSource', + 'StakedBalanceDataSource', +]); + +/** How long websocket balance updates block stale polling overwrites. */ +const WS_BALANCE_FRESHNESS_MS = 120_000; + // ============================================================================ // TRACE NAMES — used in Sentry spans (search these strings in Discover) // ============================================================================ @@ -525,6 +536,10 @@ function normalizeResponse(response: DataResponse): DataResponse { normalized.updateMode = response.updateMode; } + if (response.sourceId) { + normalized.sourceId = response.sourceId; + } + return normalized; } @@ -667,6 +682,9 @@ export class AssetsController extends BaseController< readonly #controllerMutex = new Mutex(); + /** Serializes account-switch fetch + subscribe to prevent overlapping races. */ + readonly #accountRefreshMutex = new Mutex(); + /** * Active balance subscriptions keyed by account ID. * Each account has one logical subscription that may span multiple data sources. @@ -685,6 +703,12 @@ export class AssetsController extends BaseController< */ #lastKnownAccountIds: ReadonlySet = new Set(); + /** + * Per `accountId:assetId`, timestamp until which websocket balance updates + * should not be overwritten by polling/API fetches. + */ + readonly #wsBalanceFreshUntil = new Map(); + /** * Get the currently selected accounts from AccountTreeController. * This includes all accounts in the same group as the selected account @@ -1139,26 +1163,45 @@ export class AssetsController extends BaseController< return; } + const hasOverlap = [...currentIds].some((id) => + this.#lastKnownAccountIds.has(id), + ); + if (!hasOverlap && this.#lastKnownAccountIds.size > 0) { + return; + } + log('Account tree changed with new accounts, re-subscribing', { previousCount: this.#lastKnownAccountIds.size, currentCount: currentIds.size, }); this.#lastKnownAccountIds = currentIds; - this.#subscribeAssets(); this.#ensureNativeBalancesDefaultZero(); this.#ensureDefaultTrackedAssetsSeeded(); - this.getAssets(accounts, { - chainIds: [...this.#enabledChains], - forceUpdate: true, - }).catch((error) => { - log('Failed to fetch assets after tree change', error); + this.#runAccountTreeRefresh(accounts).catch((error) => { + log('Failed to refresh assets after tree change', error); }); } else { this.#start(); } } + async #runAccountTreeRefresh(accounts: InternalAccount[]): Promise { + const releaseLock = await this.#accountRefreshMutex.acquire(); + try { + await this.getAssets(accounts, { + chainIds: [...this.#enabledChains], + forceUpdate: true, + }); + this.#subscribeAssets(); + } catch (error) { + log('Failed to fetch assets after tree change', error); + this.#subscribeAssets(); + } finally { + releaseLock(); + } + } + #registerActionHandlers(): void { this.messenger.registerMethodActionHandlers( this, @@ -1428,7 +1471,11 @@ export class AssetsController extends BaseController< // The fast pipeline only contains a subset of data sources (AccountsApi + // StakedBalance), so it must always merge to avoid wiping Snap/RPC // balances that the background pipeline hasn't yet replaced. - await this.#updateState({ ...response, updateMode: 'merge' }); + await this.#updateState({ + ...response, + updateMode: 'merge', + sourceId: 'getAssets:forceUpdate', + }); // Background pipeline: snap and RPC run in parallel after the fast path // commits to state. Their balances are merged together before detection. @@ -1454,7 +1501,11 @@ export class AssetsController extends BaseController< request, ) .then(({ response: slowResponse }) => - this.#updateState({ ...slowResponse, updateMode: 'merge' }), + this.#updateState({ + ...slowResponse, + updateMode: 'merge', + sourceId: 'getAssets:forceUpdate', + }), ) .catch((error) => log('Background pipeline failed', { error })); @@ -2113,10 +2164,56 @@ export class AssetsController extends BaseController< }); } + #filterBalancesRespectingWsFreshness( + accountId: string, + accountBalances: Record, + sourceId?: string, + ): Record { + if (!sourceId || !POLLING_BALANCE_SOURCES.has(sourceId)) { + return accountBalances; + } + + const now = Date.now(); + const filtered: Record = {}; + + for (const [assetId, balance] of Object.entries(accountBalances)) { + const freshUntil = this.#wsBalanceFreshUntil.get(`${accountId}:${assetId}`); + if (freshUntil !== undefined && now < freshUntil) { + continue; + } + filtered[assetId] = balance; + } + + return filtered; + } + + #markWsBalancesFresh( + assetsBalance: Record>, + ): void { + const freshUntil = Date.now() + WS_BALANCE_FRESHNESS_MS; + for (const [accountId, balances] of Object.entries(assetsBalance)) { + for (const assetId of Object.keys(balances)) { + this.#wsBalanceFreshUntil.set(`${accountId}:${assetId}`, freshUntil); + } + } + } + + #clearWsBalanceFreshness( + assetsBalance: Record>, + ): void { + for (const [accountId, balances] of Object.entries(assetsBalance)) { + for (const assetId of Object.keys(balances)) { + this.#wsBalanceFreshUntil.delete(`${accountId}:${assetId}`); + } + } + } + async #updateState(response: DataResponse): Promise { const normalizedResponse = normalizeResponse(response); const mode: AssetsUpdateMode = normalizedResponse.updateMode ?? 'merge'; + const assetsBalanceToApply = normalizedResponse.assetsBalance; + const releaseLock = await this.#controllerMutex.acquire(); try { @@ -2199,10 +2296,16 @@ export class AssetsController extends BaseController< } } - if (normalizedResponse.assetsBalance) { + if (assetsBalanceToApply) { for (const [accountId, accountBalances] of Object.entries( - normalizedResponse.assetsBalance, + assetsBalanceToApply, )) { + const filteredAccountBalances = + this.#filterBalancesRespectingWsFreshness( + accountId, + accountBalances, + normalizedResponse.sourceId, + ); const previousBalances = previousState.assetsBalance[accountId] ?? {}; const customAssetIds = @@ -2217,11 +2320,11 @@ export class AssetsController extends BaseController< // Merge: response overlays previous balances. const effective: Record = mode === 'merge' - ? { ...previousBalances, ...accountBalances } + ? { ...previousBalances, ...filteredAccountBalances } : ((): Record => { // Determine which chain namespaces this response covers. const coveredChains = new Set( - Object.keys(accountBalances).map( + Object.keys(filteredAccountBalances).map( (assetId) => assetId.split('/')[0], ), ); @@ -2238,7 +2341,7 @@ export class AssetsController extends BaseController< } // Apply the response (authoritative for covered chains). - Object.assign(next, accountBalances); + Object.assign(next, filteredAccountBalances); // Preserve custom assets that the response omitted. for (const customId of customAssetIds) { @@ -2385,6 +2488,15 @@ export class AssetsController extends BaseController< }); } } + + // Authoritative fetch on account switch — drop WS freshness locks so API + // balances (e.g. receiver +1 USDC) replace stale values from a prior send. + if ( + normalizedResponse.sourceId === 'getAssets:forceUpdate' && + assetsBalanceToApply + ) { + this.#clearWsBalanceFreshness(assetsBalanceToApply); + } } finally { releaseLock(); } @@ -3021,17 +3133,23 @@ export class AssetsController extends BaseController< this.#lastKnownAccountIds = new Set(accounts.map((a) => a.id)); - // Subscribe and fetch for the new account group - this.#subscribeAssets(); - if (accounts.length > 0) { - await this.getAssets(accounts, { - chainIds: [...this.#enabledChains], - forceUpdate: true, - }); - } + const releaseLock = await this.#accountRefreshMutex.acquire(); + try { + if (accounts.length > 0) { + await this.getAssets(accounts, { + chainIds: [...this.#enabledChains], + forceUpdate: true, + }); + } - this.#ensureNativeBalancesDefaultZero(); - this.#ensureDefaultTrackedAssetsSeeded(); + // Subscribe after fetch so WS notifications can recover state + this.#subscribeAssets(); + + this.#ensureNativeBalancesDefaultZero(); + this.#ensureDefaultTrackedAssetsSeeded(); + } finally { + releaseLock(); + } } async #handleEnabledNetworksChanged( @@ -3151,6 +3269,7 @@ export class AssetsController extends BaseController< request?: DataRequest, ): Promise { const updateStart = performance.now(); + log('Assets updated from data source', { sourceId, hasBalance: Boolean(response.assetsBalance), @@ -3202,7 +3321,14 @@ export class AssetsController extends BaseController< response, ); - await this.#updateState(enrichedResponse); + await this.#updateState({ ...enrichedResponse, sourceId }); + + if ( + sourceId === 'BackendWebsocketDataSource' && + enrichedResponse.assetsBalance + ) { + this.#markWsBalancesFresh(enrichedResponse.assetsBalance); + } this.#emitTrace(TRACE_UPDATE_PIPELINE, { source: sourceId, diff --git a/packages/assets-controller/src/data-sources/AccountsApiDataSource.test.ts b/packages/assets-controller/src/data-sources/AccountsApiDataSource.test.ts index 19c5d3fdf8..d964b6a744 100644 --- a/packages/assets-controller/src/data-sources/AccountsApiDataSource.test.ts +++ b/packages/assets-controller/src/data-sources/AccountsApiDataSource.test.ts @@ -300,6 +300,22 @@ describe('AccountsApiDataSource', () => { expect(apiClient.accounts.fetchV5MultiAccountBalances).toHaveBeenCalledWith( [`eip155:1:${MOCK_ADDRESS}`], + undefined, + undefined, + ); + + controller.destroy(); + }); + + it('fetch bypasses TanStack cache when forceUpdate is true', async () => { + const { controller, apiClient } = await setupController(); + + await controller.fetch(createDataRequest({ forceUpdate: true })); + + expect(apiClient.accounts.fetchV5MultiAccountBalances).toHaveBeenCalledWith( + [`eip155:1:${MOCK_ADDRESS}`], + undefined, + { staleTime: 0, gcTime: 0 }, ); controller.destroy(); diff --git a/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts b/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts index 57200f79b7..ecd614432a 100644 --- a/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts +++ b/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts @@ -319,8 +319,17 @@ export class AccountsApiDataSource extends AbstractDataSource< return response; } + const fetchOptions = request.forceUpdate + ? { staleTime: 0, gcTime: 0 } + : undefined; + const apiResponse = await fetchWithTimeout( - () => this.#apiClient.accounts.fetchV5MultiAccountBalances(accountIds), + () => + this.#apiClient.accounts.fetchV5MultiAccountBalances( + accountIds, + undefined, + fetchOptions, + ), this.#fetchTimeoutMs, ); diff --git a/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.test.ts b/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.test.ts index 5da516cd1f..5a27a02b4e 100644 --- a/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.test.ts +++ b/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.test.ts @@ -35,6 +35,8 @@ type SetupResult = { wsSubscribeMock: jest.Mock; getConnectionInfoMock: jest.Mock; findSubscriptionsMock: jest.Mock; + addChannelCallbackMock: jest.Mock; + removeChannelCallbackMock: jest.Mock; assetsUpdateHandler: jest.Mock; activeChainsUpdateHandler: jest.Mock; triggerConnectionStateChange: (state: WebSocketState) => void; @@ -78,10 +80,12 @@ function createDataRequest( }; } -function createMockWsSubscription(): WebSocketSubscription { +function createMockWsSubscription( + channels: string[] = [], +): WebSocketSubscription { return { unsubscribe: jest.fn().mockResolvedValue(undefined), - channels: [], + channels, } as unknown as WebSocketSubscription; } @@ -129,6 +133,8 @@ function setupController( 'BackendWebSocketService:subscribe', 'BackendWebSocketService:getConnectionInfo', 'BackendWebSocketService:findSubscriptionsByChannelPrefix', + 'BackendWebSocketService:addChannelCallback', + 'BackendWebSocketService:removeChannelCallback', ], events: ['BackendWebSocketService:connectionStateChanged'], }); @@ -138,6 +144,8 @@ function setupController( const wsSubscribeMock = jest .fn() .mockResolvedValue(createMockWsSubscription()); + const addChannelCallbackMock = jest.fn(); + const removeChannelCallbackMock = jest.fn().mockReturnValue(true); const getConnectionInfoMock = jest.fn().mockReturnValue({ state: connectionState, url: 'wss://test.example.com', @@ -161,6 +169,14 @@ function setupController( 'BackendWebSocketService:findSubscriptionsByChannelPrefix', findSubscriptionsMock, ); + rootMessenger.registerActionHandler( + 'BackendWebSocketService:addChannelCallback', + addChannelCallbackMock, + ); + rootMessenger.registerActionHandler( + 'BackendWebSocketService:removeChannelCallback', + removeChannelCallbackMock, + ); const queryApiClient = { accounts: { @@ -221,6 +237,8 @@ function setupController( wsSubscribeMock, getConnectionInfoMock, findSubscriptionsMock, + addChannelCallbackMock, + removeChannelCallbackMock, assetsUpdateHandler, activeChainsUpdateHandler, triggerConnectionStateChange, @@ -491,13 +509,116 @@ describe('BackendWebsocketDataSource', () => { controller.destroy(); }); - it('unsubscribe cleans up WebSocket subscription', async () => { - const mockWsSubscription = createMockWsSubscription(); + it('subscribe update treats checksummed and lowercase EVM addresses as unchanged', async () => { const { controller, wsSubscribeMock } = setupController({ initialActiveChains: [CHAIN_MAINNET], connectionState: WebSocketState.CONNECTED, }); + await controller.subscribe({ + subscriptionId: 'sub-1', + request: createDataRequest(), + isUpdate: false, + onAssetsUpdate: jest.fn(), + }); + + await controller.subscribe({ + subscriptionId: 'sub-1', + request: createDataRequest({ + accountsWithSupportedChains: [ + { + account: createMockAccount({ + address: `0x${MOCK_ADDRESS.slice(2).toUpperCase()}`, + }), + supportedChains: [CHAIN_MAINNET], + }, + ], + chainIds: [CHAIN_MAINNET], + }), + isUpdate: true, + onAssetsUpdate: jest.fn(), + }); + + expect(wsSubscribeMock).toHaveBeenCalledTimes(1); + + controller.destroy(); + }); + + it('serializes concurrent subscribe calls so the last address wins', async () => { + const addressA = MOCK_ADDRESS; + const addressB = '0xabcdef1234567890abcdef1234567890abcdef12'; + let resolveFirstSubscribe: (() => void) | undefined; + const firstSubscribeGate = new Promise((resolve) => { + resolveFirstSubscribe = resolve; + }); + + const { controller, wsSubscribeMock } = setupController({ + initialActiveChains: [CHAIN_MAINNET], + connectionState: WebSocketState.CONNECTED, + }); + + wsSubscribeMock + .mockImplementationOnce(async () => { + await firstSubscribeGate; + return createMockWsSubscription([ + `account-activity.v1.eip155:0:${addressA.toLowerCase()}`, + ]); + }) + .mockResolvedValue( + createMockWsSubscription([ + `account-activity.v1.eip155:0:${addressB.toLowerCase()}`, + ]), + ); + + const firstSubscribe = controller.subscribe({ + subscriptionId: 'sub-1', + request: createDataRequest({ + accountsWithSupportedChains: [ + { + account: createMockAccount({ address: addressA }), + supportedChains: [CHAIN_MAINNET], + }, + ], + }), + isUpdate: false, + onAssetsUpdate: jest.fn(), + }); + + const secondSubscribe = controller.subscribe({ + subscriptionId: 'sub-1', + request: createDataRequest({ + accountsWithSupportedChains: [ + { + account: createMockAccount({ address: addressB }), + supportedChains: [CHAIN_MAINNET], + }, + ], + }), + isUpdate: true, + onAssetsUpdate: jest.fn(), + }); + + await new Promise(process.nextTick); + resolveFirstSubscribe?.(); + await Promise.all([firstSubscribe, secondSubscribe]); + + expect(wsSubscribeMock).toHaveBeenCalledTimes(2); + expect(wsSubscribeMock.mock.calls[1][0].channels).toEqual([ + `account-activity.v1.eip155:0:${addressB.toLowerCase()}`, + ]); + + controller.destroy(); + }); + + it('unsubscribe cleans up WebSocket subscription', async () => { + const channel = `account-activity.v1.eip155:0:${MOCK_ADDRESS.toLowerCase()}`; + const mockWsSubscription = createMockWsSubscription([channel]); + const { controller, wsSubscribeMock, removeChannelCallbackMock } = + setupController({ + initialActiveChains: [CHAIN_MAINNET], + connectionState: WebSocketState.CONNECTED, + }); + wsSubscribeMock.mockResolvedValueOnce(mockWsSubscription); await controller.subscribe({ @@ -510,6 +631,189 @@ describe('BackendWebsocketDataSource', () => { await controller.unsubscribe('sub-1'); expect(mockWsSubscription.unsubscribe).toHaveBeenCalled(); + expect(removeChannelCallbackMock).toHaveBeenCalledWith(channel); + + controller.destroy(); + }); + + it('registers channel callbacks as fallback when subscriptionId does not match', async () => { + const channel = `account-activity.v1.eip155:0:${MOCK_ADDRESS.toLowerCase()}`; + const mockWsSubscription = createMockWsSubscription([channel]); + const onAssetsUpdate = jest.fn().mockResolvedValue(undefined); + const { + controller, + wsSubscribeMock, + addChannelCallbackMock, + } = setupController({ + initialActiveChains: [CHAIN_MAINNET], + connectionState: WebSocketState.CONNECTED, + }); + + wsSubscribeMock.mockResolvedValueOnce(mockWsSubscription); + + await controller.subscribe({ + subscriptionId: 'sub-1', + request: createDataRequest(), + isUpdate: false, + onAssetsUpdate, + }); + + expect(addChannelCallbackMock).toHaveBeenCalledWith( + expect.objectContaining({ channelName: channel }), + ); + + const channelCallback = addChannelCallbackMock.mock.calls.find( + ([args]) => args.channelName === channel, + )?.[0].callback; + + expect(channelCallback).toBeDefined(); + + channelCallback( + createMockNotification({ + channel, + subscriptionId: 'stale-server-sub-id', + data: { + address: MOCK_ADDRESS, + tx: { chain: CHAIN_MAINNET }, + updates: [ + { + asset: { + type: 'eip155:1/erc20:0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48', + decimals: 6, + }, + postBalance: { amount: '1000000' }, + }, + ], + }, + }), + ); + + await new Promise(process.nextTick); + + expect(onAssetsUpdate).toHaveBeenCalledWith( + expect.objectContaining({ + assetsBalance: expect.objectContaining({ + 'mock-account-id': expect.any(Object), + }), + }), + expect.objectContaining({ + accountsWithSupportedChains: expect.any(Array), + }), + ); + + controller.destroy(); + }); + + it('still stores subscription state when channel callback registration fails', async () => { + const channel = `account-activity.v1.eip155:0:${MOCK_ADDRESS.toLowerCase()}`; + const onAssetsUpdate = jest.fn().mockResolvedValue(undefined); + let notificationCallback: ( + notification: ServerNotificationMessage, + ) => void = () => undefined; + + const rootMessenger = new Messenger( + { namespace: MOCK_ANY_NAMESPACE }, + ); + const controllerMessenger = new Messenger< + 'BackendWebsocketDataSource', + AllActions, + AllEvents, + RootMessenger + >({ + namespace: 'BackendWebsocketDataSource', + parent: rootMessenger, + }); + + rootMessenger.delegate({ + messenger: controllerMessenger, + actions: [ + 'BackendWebSocketService:subscribe', + 'BackendWebSocketService:getConnectionInfo', + 'BackendWebSocketService:addChannelCallback', + ], + events: ['BackendWebSocketService:connectionStateChanged'], + }); + + rootMessenger.registerActionHandler( + 'BackendWebSocketService:subscribe', + ({ callback }) => { + notificationCallback = callback; + return Promise.resolve( + createMockWsSubscription([channel]), + ); + }, + ); + rootMessenger.registerActionHandler( + 'BackendWebSocketService:getConnectionInfo', + () => ({ + state: WebSocketState.CONNECTED, + url: 'wss://test.example.com', + reconnectAttempts: 0, + timeout: 30000, + reconnectDelay: 1000, + maxReconnectDelay: 30000, + requestTimeout: 30000, + }), + ); + rootMessenger.registerActionHandler( + 'BackendWebSocketService:addChannelCallback', + () => { + throw new Error( + 'A handler for BackendWebSocketService:addChannelCallback has not been delegated to AssetsController', + ); + }, + ); + + const controller = new BackendWebsocketDataSource({ + messenger: controllerMessenger as unknown as AssetsControllerMessenger, + queryApiClient: { + accounts: { + fetchV2SupportedNetworks: jest.fn().mockResolvedValue({ + fullSupport: [1], + }), + }, + } as unknown as ApiPlatformClient, + onActiveChainsUpdated: jest.fn(), + getAssetType: () => 'erc20', + state: { activeChains: [CHAIN_MAINNET] }, + }); + + await controller.subscribe({ + subscriptionId: 'sub-1', + request: createDataRequest(), + isUpdate: false, + onAssetsUpdate, + }); + + notificationCallback( + createMockNotification({ + channel, + data: { + address: MOCK_ADDRESS, + tx: { chain: CHAIN_MAINNET }, + updates: [ + { + asset: { + type: 'eip155:1/erc20:0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48', + decimals: 6, + }, + postBalance: { amount: '1000000' }, + }, + ], + }, + }), + ); + + await new Promise(process.nextTick); + + expect(onAssetsUpdate).toHaveBeenCalledWith( + expect.objectContaining({ + assetsBalance: expect.objectContaining({ + 'mock-account-id': expect.any(Object), + }), + }), + expect.objectContaining({ dataTypes: ['balance'] }), + ); controller.destroy(); }); @@ -636,6 +940,10 @@ describe('BackendWebsocketDataSource', () => { }), }), }), + expect.objectContaining({ + dataTypes: ['balance'], + accountsWithSupportedChains: expect.any(Array), + }), ); controller.destroy(); @@ -706,6 +1014,10 @@ describe('BackendWebsocketDataSource', () => { }), }), }), + expect.objectContaining({ + dataTypes: ['balance'], + accountsWithSupportedChains: expect.any(Array), + }), ); controller.destroy(); @@ -769,6 +1081,10 @@ describe('BackendWebsocketDataSource', () => { }), }), }), + expect.objectContaining({ + dataTypes: ['balance'], + accountsWithSupportedChains: expect.any(Array), + }), ); controller.destroy(); @@ -835,6 +1151,10 @@ describe('BackendWebsocketDataSource', () => { }), }), }), + expect.objectContaining({ + dataTypes: ['balance'], + accountsWithSupportedChains: expect.any(Array), + }), ); controller.destroy(); @@ -888,7 +1208,10 @@ describe('BackendWebsocketDataSource', () => { await new Promise(process.nextTick); // No valid updates → response has only updateMode, no assetsBalance - expect(assetsUpdateHandler).toHaveBeenCalledWith({ updateMode: 'merge' }); + expect(assetsUpdateHandler).toHaveBeenCalledWith( + { updateMode: 'merge' }, + expect.objectContaining({ dataTypes: ['balance'] }), + ); controller.destroy(); }); diff --git a/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.ts b/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.ts index 7ea547fd31..b3a513a44a 100644 --- a/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.ts +++ b/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.ts @@ -159,6 +159,43 @@ function buildAccountActivityChannel( return `${CHANNEL_TYPE}.${namespace}:0:${formatted}`; } +/** + * Normalize addresses for stable comparison when detecting account changes. + * + * @param address - Account address (hex or base58). + * @returns Normalized address for comparison. + */ +function normalizeAddressForComparison(address: string): string { + return address.startsWith('0x') ? address.toLowerCase() : address; +} + +/** + * Check whether subscribed account addresses changed (case-insensitive for EVM). + * + * @param nextAddresses - Addresses from the incoming subscribe request. + * @param existingAddresses - Addresses from the active subscription. + * @returns True when the address sets differ. + */ +function haveAddressesChanged( + nextAddresses: string[], + existingAddresses: string[], +): boolean { + if (nextAddresses.length !== existingAddresses.length) { + return true; + } + + const normalizedNext = nextAddresses + .map(normalizeAddressForComparison) + .sort(); + const normalizedExisting = existingAddresses + .map(normalizeAddressForComparison) + .sort(); + + return normalizedNext.some( + (address, index) => address !== normalizedExisting[index], + ); +} + /** * Normalize API chain identifier to CAIP-2 ChainId. * Passes through strings already in CAIP-2 form (e.g. eip155:1, solana:5eykt...). @@ -209,6 +246,8 @@ function toChainId(chainIdOrDecimal: number | string): ChainId { * - BackendWebSocketService:subscribe * - BackendWebSocketService:getConnectionInfo * - BackendWebSocketService:findSubscriptionsByChannelPrefix + * - BackendWebSocketService:addChannelCallback + * - BackendWebSocketService:removeChannelCallback */ const DEFAULT_CHAINS_REFRESH_INTERVAL_MS = 20 * 60 * 1000; // 20 minutes @@ -248,6 +287,12 @@ export class BackendWebsocketDataSource extends AbstractDataSource< /** Store original subscription requests for reconnection */ readonly #subscriptionRequests: Map = new Map(); + /** Channels with registered BackendWebSocketService channel callbacks */ + readonly #registeredChannelCallbacks: Set = new Set(); + + /** Serializes subscribe/unsubscribe so account switches cannot interleave. */ + #subscribeLock: Promise = Promise.resolve(); + constructor(options: BackendWebsocketDataSourceOptions) { super(CONTROLLER_NAME, { ...defaultState, @@ -448,6 +493,23 @@ export class BackendWebsocketDataSource extends AbstractDataSource< // ============================================================================ async subscribe(subscriptionRequest: SubscriptionRequest): Promise { + const previousLock = this.#subscribeLock; + let releaseLock: () => void = () => undefined; + this.#subscribeLock = new Promise((resolve) => { + releaseLock = resolve; + }); + + await previousLock; + try { + await this.#subscribeInternal(subscriptionRequest); + } finally { + releaseLock(); + } + } + + async #subscribeInternal( + subscriptionRequest: SubscriptionRequest, + ): Promise { const { request, subscriptionId, isUpdate } = subscriptionRequest; // Filter to active chains only @@ -473,7 +535,7 @@ export class BackendWebsocketDataSource extends AbstractDataSource< this.#pendingSubscriptions.set(subscriptionId, subscriptionRequest); return; } - } catch { + } catch (error) { // Store anyway - will be processed when we can connect this.#pendingSubscriptions.set(subscriptionId, subscriptionRequest); return; @@ -488,21 +550,21 @@ export class BackendWebsocketDataSource extends AbstractDataSource< if (existing) { // Check if accounts changed - if so, we need to re-subscribe to different channels const existingAddresses = existing.addresses ?? []; - const addressesChanged = - addresses.length !== existingAddresses.length || - addresses.some((addr) => !existingAddresses.includes(addr)); + const addressesChanged = haveAddressesChanged(addresses, existingAddresses); if (!addressesChanged) { - // Only chains changed - just update chains and return + // Only chains changed - update chains, request, and callback existing.chains = chainsToSubscribe; + existing.onAssetsUpdate = subscriptionRequest.onAssetsUpdate; + this.#subscriptionRequests.set(subscriptionId, subscriptionRequest); return; } // Accounts changed - fall through to re-subscribe with new channels } } - // Clean up existing subscription if any - await this.unsubscribe(subscriptionId); + // Clean up existing subscription if any (inline teardown — subscribe holds the lock) + await this.#teardownSubscription(subscriptionId); // Always subscribe to eip155 and solana account activity, plus any namespaces from requested chains const namespaces = getNamespacesForAccountActivity(chainsToSubscribe); @@ -518,6 +580,10 @@ export class BackendWebsocketDataSource extends AbstractDataSource< } } + if (channels.length === 0) { + return; + } + try { // Create WebSocket subscription const wsSubscription = await this.#messenger.call( @@ -531,29 +597,29 @@ export class BackendWebsocketDataSource extends AbstractDataSource< }, ); - // Store WebSocket subscription + // Store WebSocket subscription and subscription state before optional + // channel callbacks — wsCallback routing works without them. this.#wsSubscriptions.set(subscriptionId, wsSubscription); - // Store in abstract class tracking this.activeSubscriptions.set(subscriptionId, { cleanup: () => { - const wsSub = this.#wsSubscriptions.get(subscriptionId); - if (wsSub) { - wsSub.unsubscribe().catch((unsubErr: unknown) => { - log('Error unsubscribing', { subscriptionId, error: unsubErr }); - }); - this.#wsSubscriptions.delete(subscriptionId); - } - // Also clean up the stored request - this.#subscriptionRequests.delete(subscriptionId); + this.#teardownSubscription(subscriptionId).catch(() => undefined); }, chains: chainsToSubscribe, addresses, onAssetsUpdate: subscriptionRequest.onAssetsUpdate, }); - // Store original request for reconnection this.#subscriptionRequests.set(subscriptionId, subscriptionRequest); + + try { + this.#registerChannelCallbacks(subscriptionId, channels); + } catch (channelCallbackError) { + log( + 'Channel callback registration failed; ws subscription still active', + { subscriptionId, error: channelCallbackError }, + ); + } } catch (error) { log('WebSocket subscription FAILED', { subscriptionId, @@ -606,7 +672,7 @@ export class BackendWebsocketDataSource extends AbstractDataSource< const response = this.#processBalanceUpdates(updates, chainId, accountId); if (Object.keys(response).length > 0 && subscription) { - Promise.resolve(subscription.onAssetsUpdate(response)).catch( + Promise.resolve(subscription.onAssetsUpdate(response, request)).catch( console.error, ); } @@ -682,6 +748,93 @@ export class BackendWebsocketDataSource extends AbstractDataSource< return response; } + // ============================================================================ + // UNSUBSCRIBE + // ============================================================================ + + /** + * Unsubscribe and await server-side teardown so a re-subscribe does not race + * with stale subscription IDs on incoming notifications. + * + * @param subscriptionId - The ID of the subscription to cancel. + */ + async unsubscribe(subscriptionId: string): Promise { + const previousLock = this.#subscribeLock; + let releaseLock: () => void = () => undefined; + this.#subscribeLock = new Promise((resolve) => { + releaseLock = resolve; + }); + + await previousLock; + try { + await this.#teardownSubscription(subscriptionId); + } finally { + releaseLock(); + } + } + + async #teardownSubscription(subscriptionId: string): Promise { + const wsSub = this.#wsSubscriptions.get(subscriptionId); + + if (wsSub) { + const channels = [...wsSub.channels]; + try { + await wsSub.unsubscribe(); + } catch (unsubErr: unknown) { + log('Error unsubscribing', { subscriptionId, error: unsubErr }); + } + this.#wsSubscriptions.delete(subscriptionId); + this.#removeChannelCallbacks(channels); + } + + this.#subscriptionRequests.delete(subscriptionId); + this.activeSubscriptions.delete(subscriptionId); + } + + #registerChannelCallbacks( + subscriptionId: string, + channels: string[], + ): void { + for (const channel of channels) { + this.#unregisterChannelCallback(channel); + + try { + this.#messenger.call('BackendWebSocketService:addChannelCallback', { + channelName: channel, + callback: (notification: ServerNotificationMessage) => { + this.#handleNotification(notification, subscriptionId); + }, + }); + this.#registeredChannelCallbacks.add(channel); + } catch { + // Channel callbacks are optional; ws subscription still works without them. + } + } + } + + #unregisterChannelCallback(channel: string): void { + if (!this.#registeredChannelCallbacks.has(channel)) { + return; + } + + try { + this.#messenger.call( + 'BackendWebSocketService:removeChannelCallback', + channel, + ); + } catch { + // Best-effort cleanup when the channel callback was never registered. + } + + this.#registeredChannelCallbacks.delete(channel); + } + + #removeChannelCallbacks(channels: string[]): void { + for (const channel of channels) { + this.#unregisterChannelCallback(channel); + } + } + // ============================================================================ // CLEANUP // ============================================================================ @@ -692,22 +845,17 @@ export class BackendWebsocketDataSource extends AbstractDataSource< this.#chainsRefreshTimer = null; } - // Clean up WebSocket subscriptions - // Convert to array first to avoid modifying map during iteration - const subscriptions = [...this.#wsSubscriptions.values()]; - for (const wsSub of subscriptions) { - try { - // Fire and forget - don't await in destroy - wsSub.unsubscribe().catch(() => { - // Ignore errors during cleanup - }); - } catch { - // Ignore errors during cleanup - } + const subscriptionIds = [ + ...new Set([ + ...this.#wsSubscriptions.keys(), + ...this.activeSubscriptions.keys(), + ]), + ]; + for (const subscriptionId of subscriptionIds) { + this.#teardownSubscription(subscriptionId).catch(() => undefined); } - this.#wsSubscriptions.clear(); - // Clean up base class subscriptions + // Clean up base class subscriptions (no-op if already torn down) super.destroy(); } } diff --git a/packages/assets-controller/src/types.ts b/packages/assets-controller/src/types.ts index 038fe74df9..d53207bd72 100644 --- a/packages/assets-controller/src/types.ts +++ b/packages/assets-controller/src/types.ts @@ -371,6 +371,11 @@ export type DataResponse = { * Defaults to `'merge'` if omitted. */ updateMode?: AssetsUpdateMode; + /** + * @internal Set by AssetsController when applying updates. Data sources must + * not populate this field. + */ + sourceId?: string; }; /** From 8b42ded93c79debdeb6bdefa7bdbc867ca2c7e7c Mon Sep 17 00:00:00 2001 From: salimtb Date: Thu, 25 Jun 2026 16:26:26 +0200 Subject: [PATCH 2/4] fix: add changelog --- packages/assets-controller/CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/packages/assets-controller/CHANGELOG.md b/packages/assets-controller/CHANGELOG.md index bc08d2bcb0..3e9c4373da 100644 --- a/packages/assets-controller/CHANGELOG.md +++ b/packages/assets-controller/CHANGELOG.md @@ -11,6 +11,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Bump `@metamask/transaction-controller` from `^68.1.1` to `^68.2.0` ([#9253](https://github.com/MetaMask/core/pull/9253)) +### Fixed + +- Fix stale token balances after transactions when switching accounts or when websocket subscriptions reconnect; `AssetsController` now fetches before re-subscribing on account switch, serializes overlapping refresh work, treats `getAssets({ forceUpdate: true })` as authoritative over recent websocket freshness guards, and prevents passive polling from overwriting websocket balances for 120 seconds ([#9265](https://github.com/MetaMask/core/pull/9265)) +- `AccountsApiDataSource` bypasses the TanStack Query balance cache when `forceUpdate` is true so forced refreshes return up-to-date balances instead of 60-second cached values ([#9265](https://github.com/MetaMask/core/pull/9265)) +- `BackendWebsocketDataSource` re-subscribes when subscribed accounts change (case-insensitive EVM address matching), serializes subscribe/unsubscribe to prevent races on account switch, and registers optional channel callbacks for more reliable notification delivery ([#9265](https://github.com/MetaMask/core/pull/9265)) + ## [9.1.0] ### Added From b363be1b81f56b1641e8cd9e4de576527b654dcc Mon Sep 17 00:00:00 2001 From: salimtb Date: Thu, 25 Jun 2026 16:38:05 +0200 Subject: [PATCH 3/4] fix: fix lint --- .../src/data-sources/BackendWebsocketDataSource.test.ts | 4 ++-- .../src/data-sources/BackendWebsocketDataSource.ts | 2 +- packages/assets-controller/src/types.ts | 4 +++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.test.ts b/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.test.ts index 5a27a02b4e..df3264f0da 100644 --- a/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.test.ts +++ b/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.test.ts @@ -603,7 +603,7 @@ describe('BackendWebsocketDataSource', () => { await Promise.all([firstSubscribe, secondSubscribe]); expect(wsSubscribeMock).toHaveBeenCalledTimes(2); - expect(wsSubscribeMock.mock.calls[1][0].channels).toEqual([ + expect(wsSubscribeMock.mock.calls[1][0].channels).toStrictEqual([ `account-activity.v1.eip155:0:${addressB.toLowerCase()}`, ]); @@ -774,7 +774,7 @@ describe('BackendWebsocketDataSource', () => { }, } as unknown as ApiPlatformClient, onActiveChainsUpdated: jest.fn(), - getAssetType: () => 'erc20', + getAssetType: (): 'erc20' => 'erc20', state: { activeChains: [CHAIN_MAINNET] }, }); diff --git a/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.ts b/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.ts index b3a513a44a..f000ea0db2 100644 --- a/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.ts +++ b/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.ts @@ -535,7 +535,7 @@ export class BackendWebsocketDataSource extends AbstractDataSource< this.#pendingSubscriptions.set(subscriptionId, subscriptionRequest); return; } - } catch (error) { + } catch { // Store anyway - will be processed when we can connect this.#pendingSubscriptions.set(subscriptionId, subscriptionRequest); return; diff --git a/packages/assets-controller/src/types.ts b/packages/assets-controller/src/types.ts index d53207bd72..90acb98f7a 100644 --- a/packages/assets-controller/src/types.ts +++ b/packages/assets-controller/src/types.ts @@ -372,8 +372,10 @@ export type DataResponse = { */ updateMode?: AssetsUpdateMode; /** - * @internal Set by AssetsController when applying updates. Data sources must + * Set by AssetsController when applying updates. Data sources must * not populate this field. + * + * @internal */ sourceId?: string; }; From 0fc4cba5d702896a12df48e5b3fc82be45538163 Mon Sep 17 00:00:00 2001 From: salimtb Date: Thu, 25 Jun 2026 16:45:39 +0200 Subject: [PATCH 4/4] fix: fix lint --- .../assets-controller/src/AssetsController.ts | 4 ++- .../BackendWebsocketDataSource.test.ts | 31 +++++++++---------- .../BackendWebsocketDataSource.ts | 10 +++--- 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/packages/assets-controller/src/AssetsController.ts b/packages/assets-controller/src/AssetsController.ts index ada20a462b..d559705351 100644 --- a/packages/assets-controller/src/AssetsController.ts +++ b/packages/assets-controller/src/AssetsController.ts @@ -2177,7 +2177,9 @@ export class AssetsController extends BaseController< const filtered: Record = {}; for (const [assetId, balance] of Object.entries(accountBalances)) { - const freshUntil = this.#wsBalanceFreshUntil.get(`${accountId}:${assetId}`); + const freshUntil = this.#wsBalanceFreshUntil.get( + `${accountId}:${assetId}`, + ); if (freshUntil !== undefined && now < freshUntil) { continue; } diff --git a/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.test.ts b/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.test.ts index df3264f0da..676bff2c8f 100644 --- a/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.test.ts +++ b/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.test.ts @@ -615,9 +615,9 @@ describe('BackendWebsocketDataSource', () => { const mockWsSubscription = createMockWsSubscription([channel]); const { controller, wsSubscribeMock, removeChannelCallbackMock } = setupController({ - initialActiveChains: [CHAIN_MAINNET], - connectionState: WebSocketState.CONNECTED, - }); + initialActiveChains: [CHAIN_MAINNET], + connectionState: WebSocketState.CONNECTED, + }); wsSubscribeMock.mockResolvedValueOnce(mockWsSubscription); @@ -640,14 +640,11 @@ describe('BackendWebsocketDataSource', () => { const channel = `account-activity.v1.eip155:0:${MOCK_ADDRESS.toLowerCase()}`; const mockWsSubscription = createMockWsSubscription([channel]); const onAssetsUpdate = jest.fn().mockResolvedValue(undefined); - const { - controller, - wsSubscribeMock, - addChannelCallbackMock, - } = setupController({ - initialActiveChains: [CHAIN_MAINNET], - connectionState: WebSocketState.CONNECTED, - }); + const { controller, wsSubscribeMock, addChannelCallbackMock } = + setupController({ + initialActiveChains: [CHAIN_MAINNET], + connectionState: WebSocketState.CONNECTED, + }); wsSubscribeMock.mockResolvedValueOnce(mockWsSubscription); @@ -711,9 +708,11 @@ describe('BackendWebsocketDataSource', () => { notification: ServerNotificationMessage, ) => void = () => undefined; - const rootMessenger = new Messenger( - { namespace: MOCK_ANY_NAMESPACE }, - ); + const rootMessenger = new Messenger< + MockAnyNamespace, + AllActions, + AllEvents + >({ namespace: MOCK_ANY_NAMESPACE }); const controllerMessenger = new Messenger< 'BackendWebsocketDataSource', AllActions, @@ -738,9 +737,7 @@ describe('BackendWebsocketDataSource', () => { 'BackendWebSocketService:subscribe', ({ callback }) => { notificationCallback = callback; - return Promise.resolve( - createMockWsSubscription([channel]), - ); + return Promise.resolve(createMockWsSubscription([channel])); }, ); rootMessenger.registerActionHandler( diff --git a/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.ts b/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.ts index f000ea0db2..945272d8e0 100644 --- a/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.ts +++ b/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.ts @@ -550,7 +550,10 @@ export class BackendWebsocketDataSource extends AbstractDataSource< if (existing) { // Check if accounts changed - if so, we need to re-subscribe to different channels const existingAddresses = existing.addresses ?? []; - const addressesChanged = haveAddressesChanged(addresses, existingAddresses); + const addressesChanged = haveAddressesChanged( + addresses, + existingAddresses, + ); if (!addressesChanged) { // Only chains changed - update chains, request, and callback @@ -791,10 +794,7 @@ export class BackendWebsocketDataSource extends AbstractDataSource< this.activeSubscriptions.delete(subscriptionId); } - #registerChannelCallbacks( - subscriptionId: string, - channels: string[], - ): void { + #registerChannelCallbacks(subscriptionId: string, channels: string[]): void { for (const channel of channels) { this.#unregisterChannelCallback(channel);