diff --git a/xds/src/main/java/io/grpc/xds/FaultFilter.java b/xds/src/main/java/io/grpc/xds/FaultFilter.java index ce764c7e943..79bcf7e42a0 100644 --- a/xds/src/main/java/io/grpc/xds/FaultFilter.java +++ b/xds/src/main/java/io/grpc/xds/FaultFilter.java @@ -197,7 +197,8 @@ private static FaultConfig.FractionalPercent parsePercent(FractionalPercent prot @Override public ClientInterceptor buildClientInterceptor( FilterConfig config, @Nullable FilterConfig overrideConfig, - final ScheduledExecutorService scheduler) { + final ScheduledExecutorService scheduler, + Filter.ResourceCleanupRegistry cleanupRegistry) { checkNotNull(config, "config"); if (overrideConfig != null) { config = overrideConfig; diff --git a/xds/src/main/java/io/grpc/xds/Filter.java b/xds/src/main/java/io/grpc/xds/Filter.java index 4fa56beb1de..f7fb11e8462 100644 --- a/xds/src/main/java/io/grpc/xds/Filter.java +++ b/xds/src/main/java/io/grpc/xds/Filter.java @@ -108,11 +108,21 @@ ConfigOrError parseFilterConfigOverride( Message rawProtoMessage, FilterConfigParseContext context); } - /** Uses the FilterConfigs produced above to produce an HTTP filter interceptor for clients. */ + /** + * Builds an HTTP filter interceptor for this route. + * + *

Filters that create stateful resources (e.g., shared channels) should register + * cleanup tasks via {@code cleanupRegistry}. These tasks execute in the xDS + * {@code SynchronizationContext} when the route's reference count reaches zero, + * meaning no in-flight RPCs reference the route and the control plane has released it. + * + * @param cleanupRegistry registry for cleanup tasks; never null + */ @Nullable default ClientInterceptor buildClientInterceptor( FilterConfig config, @Nullable FilterConfig overrideConfig, - ScheduledExecutorService scheduler) { + ScheduledExecutorService scheduler, + ResourceCleanupRegistry cleanupRegistry) { return null; } @@ -193,4 +203,15 @@ public String toString() { .toString(); } } + + /** + * Registry for cleanup tasks associated with a route's resource scope. + */ + @FunctionalInterface + interface ResourceCleanupRegistry { + /** + * Registers a task to run when the route is no longer in use. + */ + void addCleanupTask(Runnable task); + } } diff --git a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java index e87c402fcb0..9d58fc03e94 100644 --- a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java +++ b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java @@ -130,7 +130,8 @@ public ConfigOrError parseFilterConfigOverride( @Nullable @Override public ClientInterceptor buildClientInterceptor(FilterConfig config, - @Nullable FilterConfig overrideConfig, ScheduledExecutorService scheduler) { + @Nullable FilterConfig overrideConfig, ScheduledExecutorService scheduler, + Filter.ResourceCleanupRegistry cleanupRegistry) { ComputeEngineCredentials credentials = ComputeEngineCredentials.create(); synchronized (callCredentialsCache) { diff --git a/xds/src/main/java/io/grpc/xds/SharedResourceManager.java b/xds/src/main/java/io/grpc/xds/SharedResourceManager.java new file mode 100644 index 00000000000..939d558acb3 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/SharedResourceManager.java @@ -0,0 +1,205 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import com.google.common.base.Preconditions; +import io.grpc.Internal; +import io.grpc.ManagedChannel; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import javax.annotation.concurrent.ThreadSafe; + +/** + * Manages generic reference-counted shared resources for xDS filters. + * + *

Similar to {@code io.grpc.xds.internal.security.ReferenceCountingMap}, but provides + * additional lifecycle management ({@link #close()}) and a simpler key-only + * {@link #release(Object)} API designed for xDS filter state cleanup tasks. + */ +@Internal +public final class SharedResourceManager { + + /** + * An AutoCloseable resource that explicitly guarantees its close operation + * will not throw checked exceptions. + */ + public interface ResourceCloseable extends AutoCloseable { + @Override + void close(); + } + + /** + * Adapts {@link ManagedChannel} to {@link ResourceCloseable} for management by + * {@link SharedResourceManager}. + */ + public static final class ManagedChannelResource implements ResourceCloseable { + private final ManagedChannel channel; + + public ManagedChannelResource(ManagedChannel channel) { + this.channel = Preconditions.checkNotNull(channel, "channel"); + } + + @Override + public void close() { + channel.shutdown(); + } + + public ManagedChannel getChannel() { + return channel; + } + } + + /** + * An internal pure reference-counting container managing a stateful ResourceCloseable. + */ + @ThreadSafe + static final class SharedResource { + private final T resource; + private final AtomicInteger refCount = new AtomicInteger(1); + + SharedResource(T resource) { + this.resource = Preconditions.checkNotNull(resource, "resource"); + } + + /** + * Retains the resource. Returns false if the resource has hit 0 and is being closed. + */ + boolean retain() { + int count; + do { + count = refCount.get(); + if (count == 0) { + return false; + } + } while (!refCount.compareAndSet(count, count + 1)); + return true; + } + + /** + * Decrements reference count. Closes underlying resource if count hits 0. + * @return true if the count reached 0 and the resource was closed; false otherwise. + */ + boolean release() { + int count = refCount.decrementAndGet(); + if (count < 0) { + throw new AssertionError("SharedResourceManager reference count dropped below 0"); + } + if (count == 0) { + resource.close(); + return true; + } + return false; + } + + T get() { + return resource; + } + + int getRefCount() { + return refCount.get(); + } + } + + private final ConcurrentMap> resources = new ConcurrentHashMap<>(); + private final Function resourceCreator; + private final Object closeLock = new Object(); + private volatile boolean closed; + + public SharedResourceManager(Function resourceCreator) { + this.resourceCreator = resourceCreator; + } + + /** + * Acquires a resource for the given key, incrementing its reference count. + */ + public V acquire(K key) { + while (true) { + SharedResource shared = resources.get(key); + if (shared == null) { + SharedResource newShared = new SharedResource<>(resourceCreator.apply(key)); + synchronized (closeLock) { + if (closed) { + newShared.release(); + throw new IllegalStateException("SharedResourceManager is closed"); + } + shared = resources.putIfAbsent(key, newShared); + } + if (shared == null) { + return newShared.get(); + } + // Lost the race: close the resource we created to prevent leaks. + newShared.release(); + } + if (shared.retain()) { + return shared.get(); + } + // If retain failed, it's being evicted concurrently. Loop to compute a new one. + resources.remove(key, shared); + } + } + + /** + * Releases a resource for the given key, decrementing its reference count. + * Closes and evicts the resource if the reference count reaches 0. + * + *

In the xDS filter state lifecycle, this method is invoked from cleanup + * tasks that execute within the {@code SynchronizationContext}. + * + * @return true if the resource was closed; false otherwise. + */ + public boolean release(K key) { + SharedResource shared = resources.get(key); + if (shared == null) { + return false; + } + try { + if (shared.release()) { + return resources.remove(key, shared); + } + } catch (Throwable t) { + resources.remove(key, shared); + throw t; + } + return false; + } + + /** + * Removes all entries from the cache and releases the manager's reference for each. + * + *

This intentionally performs a single {@code release()} per entry, decrementing the + * manager's own reference count contribution. If in-flight RPCs still hold references, + * the underlying resource remains open until those references are released. This avoids + * pulling resources out from under active operations. + */ + public void close() { + synchronized (closeLock) { + closed = true; + } + for (K key : resources.keySet()) { + SharedResource shared = resources.remove(key); + if (shared != null) { + try { + shared.release(); + } catch (Throwable t) { + // Ignore exceptions during final close-all to ensure we try to close other resources + } + } + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java index 69b0b824433..c7deae93de0 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java @@ -79,8 +79,11 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Supplier; +import java.util.logging.Level; +import java.util.logging.Logger; import javax.annotation.Nullable; /** @@ -411,7 +414,7 @@ private final class ConfigSelector extends InternalConfigSelector { @Override public Result selectConfig(PickSubchannelArgs args) { RoutingConfig routingCfg; - RouteData selectedRoute; + RefCountedRoute selectedRoute; String cluster; ClientInterceptor filters; Metadata headers = args.getHeaders(); @@ -422,8 +425,8 @@ public Result selectConfig(PickSubchannelArgs args) { return Result.forError(routingCfg.errorStatus); } selectedRoute = null; - for (RouteData route : routingCfg.routes) { - if (RoutingUtils.matchRoute(route.routeMatch, path, headers, random)) { + for (RefCountedRoute route : routingCfg.routes) { + if (RoutingUtils.matchRoute(route.getRouteData().routeMatch, path, headers, random)) { selectedRoute = route; break; } @@ -432,14 +435,14 @@ public Result selectConfig(PickSubchannelArgs args) { return Result.forError( Status.UNAVAILABLE.withDescription("Could not find xDS route matching RPC")); } - if (selectedRoute.routeAction == null) { + if (selectedRoute.getRouteData().routeAction == null) { return Result.forError(Status.UNAVAILABLE.withDescription( "Could not route RPC to Route with non-forwarding action")); } - RouteAction action = selectedRoute.routeAction; + RouteAction action = selectedRoute.getRouteData().routeAction; if (action.cluster() != null) { cluster = prefixedClusterName(action.cluster()); - filters = selectedRoute.filterChoices.get(0); + filters = selectedRoute.getRouteData().filterChoices.get(0); } else if (action.weightedClusters() != null) { // XdsRouteConfigureResource verifies the total weight will not be 0 or exceed uint32 long totalWeight = 0; @@ -453,21 +456,29 @@ public Result selectConfig(PickSubchannelArgs args) { accumulator += weightedCluster.weight(); if (select < accumulator) { cluster = prefixedClusterName(weightedCluster.name()); - filters = selectedRoute.filterChoices.get(i); + filters = selectedRoute.getRouteData().filterChoices.get(i); break; } } } else if (action.namedClusterSpecifierPluginConfig() != null) { cluster = prefixedClusterSpecifierPluginName(action.namedClusterSpecifierPluginConfig().name()); - filters = selectedRoute.filterChoices.get(0); + filters = selectedRoute.getRouteData().filterChoices.get(0); } else { // updateRoutes() discards routes with unknown actions throw new AssertionError(); } - } while (!retainCluster(cluster)); + if (!retainCluster(cluster)) { + continue; + } + if (!selectedRoute.retain()) { + releaseCluster(cluster); + continue; + } + break; + } while (true); - final RouteAction routeAction = selectedRoute.routeAction; + final RouteAction routeAction = selectedRoute.getRouteData().routeAction; Long timeoutNanos = null; if (enableTimeout) { timeoutNanos = routeAction.timeoutNano(); @@ -486,6 +497,7 @@ public Result selectConfig(PickSubchannelArgs args) { Object config = parsedServiceConfig.getConfig(); if (config == null) { releaseCluster(cluster); + selectedRoute.release(); return Result.forError( parsedServiceConfig.getError().augmentDescription( "Failed to parse service config (method config)")); @@ -533,12 +545,10 @@ public void onClose(Status status, Metadata trailers) { } } - return - Result.newBuilder() - .setConfig(config) - .setInterceptor(combineInterceptors( - ImmutableList.of(new ClusterSelectionInterceptor(), filters))) - .build(); + ClientInterceptor refCountedRouteInterceptor = new RefCountedRouteInterceptor(selectedRoute); + return Result.newBuilder().setConfig(config).setInterceptor(combineInterceptors( + ImmutableList.of(refCountedRouteInterceptor, new ClusterSelectionInterceptor(), filters))) + .build(); } private boolean retainCluster(String cluster) { @@ -695,6 +705,10 @@ private void shutdown() { stopped = true; xdsDependencyManager.shutdown(); updateActiveFilters(null); + RoutingConfig cfg = XdsNameResolver.this.routingConfig; + if (cfg != null) { + cfg.close(); + } } @Override @@ -761,7 +775,7 @@ private void updateRoutes( long httpMaxStreamDurationNano, @Nullable List filterConfigs) { List routes = virtualHost.routes(); - ImmutableList.Builder routesData = ImmutableList.builder(); + ImmutableList.Builder routesData = ImmutableList.builder(); // Populate all clusters to which requests can be routed to through the virtual host. Set clusters = new HashSet<>(); @@ -772,24 +786,32 @@ private void updateRoutes( for (Route route : routes) { RouteAction action = route.routeAction(); String prefixedName; + + CleanupCollector collector = new CleanupCollector(); + if (action == null) { - routesData.add(new RouteData(route.routeMatch(), null, ImmutableList.of())); + RouteData pureRouteData = new RouteData(route.routeMatch(), null, ImmutableList.of()); + routesData.add(new RefCountedRoute(pureRouteData, collector.tasks.build(), syncContext)); } else if (action.cluster() != null) { prefixedName = prefixedClusterName(action.cluster()); clusters.add(prefixedName); clusterNameMap.put(prefixedName, action.cluster()); - ClientInterceptor filters = createFilters(filterConfigs, virtualHost, route, null); - routesData.add(new RouteData(route.routeMatch(), route.routeAction(), filters)); + ClientInterceptor filters = + createFilters(filterConfigs, virtualHost, route, null, collector); + RouteData pureRouteData = new RouteData(route.routeMatch(), route.routeAction(), filters); + routesData.add(new RefCountedRoute(pureRouteData, collector.tasks.build(), syncContext)); } else if (action.weightedClusters() != null) { ImmutableList.Builder filterList = ImmutableList.builder(); for (ClusterWeight weightedCluster : action.weightedClusters()) { prefixedName = prefixedClusterName(weightedCluster.name()); clusters.add(prefixedName); clusterNameMap.put(prefixedName, weightedCluster.name()); - filterList.add(createFilters(filterConfigs, virtualHost, route, weightedCluster)); + filterList.add( + createFilters(filterConfigs, virtualHost, route, weightedCluster, collector)); } - routesData.add( - new RouteData(route.routeMatch(), route.routeAction(), filterList.build())); + RouteData pureRouteData = + new RouteData(route.routeMatch(), route.routeAction(), filterList.build()); + routesData.add(new RefCountedRoute(pureRouteData, collector.tasks.build(), syncContext)); } else if (action.namedClusterSpecifierPluginConfig() != null) { PluginConfig pluginConfig = action.namedClusterSpecifierPluginConfig().config(); if (pluginConfig instanceof RlsPluginConfig) { @@ -798,8 +820,10 @@ private void updateRoutes( clusters.add(prefixedName); rlsPluginConfigMap.put(prefixedName, (RlsPluginConfig) pluginConfig); } - ClientInterceptor filters = createFilters(filterConfigs, virtualHost, route, null); - routesData.add(new RouteData(route.routeMatch(), route.routeAction(), filters)); + ClientInterceptor filters = + createFilters(filterConfigs, virtualHost, route, null, collector); + RouteData pureRouteData = new RouteData(route.routeMatch(), route.routeAction(), filters); + routesData.add(new RefCountedRoute(pureRouteData, collector.tasks.build(), syncContext)); } else { // Discard route } @@ -854,7 +878,12 @@ private void updateRoutes( } // Make newly added clusters selectable by config selector and deleted clusters no longer // selectable. - routingConfig = new RoutingConfig(xdsConfig, httpMaxStreamDurationNano, routesData.build()); + RoutingConfig oldConfig = XdsNameResolver.this.routingConfig; + XdsNameResolver.this.routingConfig = + new RoutingConfig(xdsConfig, httpMaxStreamDurationNano, routesData.build()); + if (oldConfig != null) { + oldConfig.close(); + } for (String cluster : deletedClusters) { int count = clusterRefs.get(cluster).refCount.decrementAndGet(); if (count == 0) { @@ -871,7 +900,8 @@ private ClientInterceptor createFilters( @Nullable List filterConfigs, VirtualHost virtualHost, Route route, - @Nullable ClusterWeight weightedCluster) { + @Nullable ClusterWeight weightedCluster, + Filter.ResourceCleanupRegistry cleanupRegistry) { if (filterConfigs == null) { return new PassthroughClientInterceptor(); } @@ -893,7 +923,7 @@ private ClientInterceptor createFilters( Filter filter = activeFilters.get(filterKey); checkNotNull(filter, "activeFilters.get(%s)", filterKey); ClientInterceptor interceptor = - filter.buildClientInterceptor(config, overrideConfig, scheduler); + filter.buildClientInterceptor(config, overrideConfig, scheduler, cleanupRegistry); if (interceptor != null) { filterInterceptors.add(interceptor); @@ -906,7 +936,11 @@ private ClientInterceptor createFilters( } private void cleanUpRoutes(Status error) { - routingConfig = new RoutingConfig(error); + RoutingConfig oldConfig = XdsNameResolver.this.routingConfig; + XdsNameResolver.this.routingConfig = new RoutingConfig(error); + if (oldConfig != null) { + oldConfig.close(); + } if (existingClusters != null) { for (String cluster : existingClusters) { int count = clusterRefs.get(cluster).refCount.decrementAndGet(); @@ -934,11 +968,11 @@ private void cleanUpRoutes(Status error) { private static class RoutingConfig { final XdsConfig xdsConfig; final long fallbackTimeoutNano; - final ImmutableList routes; + final ImmutableList routes; final Status errorStatus; private RoutingConfig( - XdsConfig xdsConfig, long fallbackTimeoutNano, ImmutableList routes) { + XdsConfig xdsConfig, long fallbackTimeoutNano, ImmutableList routes) { this.xdsConfig = checkNotNull(xdsConfig, "xdsConfig"); this.fallbackTimeoutNano = fallbackTimeoutNano; this.routes = checkNotNull(routes, "routes"); @@ -952,6 +986,24 @@ private RoutingConfig(Status errorStatus) { this.errorStatus = checkNotNull(errorStatus, "errorStatus"); checkArgument(!errorStatus.isOk(), "errorStatus should not be okay"); } + + void close() { + if (routes != null) { + for (RefCountedRoute route : routes) { + route.release(); + } + } + } + } + + /** Collects cleanup tasks registered by filters during route construction. */ + private static final class CleanupCollector implements Filter.ResourceCleanupRegistry { + final ImmutableList.Builder tasks = ImmutableList.builder(); + + @Override + public void addCleanupTask(Runnable task) { + tasks.add(task); + } } static final class RouteData { @@ -1093,6 +1145,138 @@ public XdsClient returnObject(XdsClient xdsClient) { } } + /** + * Wraps a RouteData with reference counting. Held by the control plane (refCount=1). + * In-flight RPCs retain +1. When refCount hits 0, runs registered cleanup tasks in syncContext. + */ + @VisibleForTesting + static class RefCountedRoute { + private static final Logger routeLogger = + Logger.getLogger(RefCountedRoute.class.getName()); + private final RouteData routeData; + private final ImmutableList cleanupTasks; + private final SynchronizationContext syncContext; + // Starts at 1 representing Control Plane configuration ownership + private final AtomicInteger refCount = new AtomicInteger(1); + + public RefCountedRoute( + RouteData routeData, ImmutableList cleanupTasks, + SynchronizationContext syncContext) { + this.routeData = checkNotNull(routeData, "routeData"); + this.cleanupTasks = checkNotNull(cleanupTasks, "cleanupTasks"); + this.syncContext = checkNotNull(syncContext, "syncContext"); + } + + public RouteData getRouteData() { + return routeData; + } + + public boolean retain() { + int count; + do { + count = refCount.get(); + if (count == 0) { + routeLogger.log( + Level.WARNING, + "RefCountedRoute retain called on a dead route (refCount == 0), " + + "ignoring redundant retain"); + return false; + } + } while (!refCount.compareAndSet(count, count + 1)); + return true; + } + + public void release() { + int count = refCount.decrementAndGet(); + if (count < 0) { + throw new AssertionError(); + } + if (count == 0) { + syncContext.execute(new Runnable() { + @Override + public void run() { + if (refCount.get() != 0) { + throw new AssertionError(); + } + for (Runnable task : cleanupTasks) { + try { + task.run(); + } catch (Throwable t) { + routeLogger.log( + Level.SEVERE, "Exception running cleanup task", t); + } + } + } + }); + } + } + } + + /** + * Standalone interceptor guaranteeing reliable RefCountedRoute release + * on stream close or early cancellation. + */ + @VisibleForTesting + static final class RefCountedRouteInterceptor implements ClientInterceptor { + private final RefCountedRoute refCountedRoute; + + RefCountedRouteInterceptor(RefCountedRoute refCountedRoute) { + this.refCountedRoute = refCountedRoute; + } + + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + final AtomicBoolean released = new AtomicBoolean(false); + + try { + ClientCall call = next.newCall(method, callOptions); + return new SimpleForwardingClientCall(call) { + @Override + public void start(Listener responseListener, Metadata headers) { + try { + super.start( + new SimpleForwardingClientCallListener(responseListener) { + @Override + public void onClose(Status status, Metadata trailers) { + try { + super.onClose(status, trailers); + } finally { + if (released.compareAndSet(false, true)) { + refCountedRoute.release(); + } + } + } + }, + headers); + } catch (Throwable t) { + if (released.compareAndSet(false, true)) { + refCountedRoute.release(); + } + throw t; + } + } + + @Override + public void cancel(String message, Throwable cause) { + try { + super.cancel(message, cause); + } finally { + if (released.compareAndSet(false, true)) { + refCountedRoute.release(); + } + } + } + }; + } catch (Throwable t) { + if (released.compareAndSet(false, true)) { + refCountedRoute.release(); + } + throw t; + } + } + } + private static final class SupplierXdsClientPool implements XdsClientPool { private final Supplier xdsClientSupplier; diff --git a/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java b/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java index 11745c01fc2..6589c43c6c0 100644 --- a/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java +++ b/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java @@ -161,7 +161,7 @@ public void testClientInterceptor_success() throws ResourceInvalidException { .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); - ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null, task -> { }); MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); Channel mockChannel = Mockito.mock(Channel.class); ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(CallOptions.class); @@ -190,7 +190,7 @@ public void testClientInterceptor_createsAndReusesCachedCredentials() .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); - ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null, task -> { }); MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); Channel mockChannel = Mockito.mock(Channel.class); ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(CallOptions.class); @@ -211,7 +211,7 @@ public void testClientInterceptor_createsAndReusesCachedCredentials() public void testClientInterceptor_withoutClusterSelectionKey() throws Exception { GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); - ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null, task -> { }); MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); Channel mockChannel = mock(Channel.class); CallOptions callOptionsWithXds = CallOptions.DEFAULT; @@ -242,7 +242,7 @@ public void testClientInterceptor_clusterSelectionKeyWithoutPrefix() throws Exce GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); - ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null, task -> { }); MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); @@ -253,7 +253,7 @@ public void testClientInterceptor_clusterSelectionKeyWithoutPrefix() throws Exce public void testClientInterceptor_xdsConfigDoesNotExist() throws Exception { GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); - ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null, task -> { }); MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); Channel mockChannel = mock(Channel.class); CallOptions callOptionsWithXds = CallOptions.DEFAULT @@ -283,7 +283,7 @@ public void testClientInterceptor_incorrectClusterName() throws Exception { .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); - ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null, task -> { }); MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); Channel mockChannel = mock(Channel.class); @@ -309,7 +309,7 @@ public void testClientInterceptor_statusOrError() throws Exception { .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); - ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null, task -> { }); MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); Channel mockChannel = mock(Channel.class); @@ -337,7 +337,7 @@ public void testClientInterceptor_notAudienceWrapper() throws ResourceInvalidExc .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); - ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null, task -> { }); MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); Channel mockChannel = Mockito.mock(Channel.class); @@ -363,7 +363,7 @@ public void testLruCacheAcrossInterceptors() throws ResourceInvalidException { .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 2); ClientInterceptor interceptor1 - = filter.buildClientInterceptor(new GcpAuthenticationConfig(2), null, null); + = filter.buildClientInterceptor(new GcpAuthenticationConfig(2), null, null, task -> { }); MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); Channel mockChannel = Mockito.mock(Channel.class); ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(CallOptions.class); @@ -373,7 +373,7 @@ public void testLruCacheAcrossInterceptors() throws ResourceInvalidException { CallOptions capturedOptions1 = callOptionsCaptor.getAllValues().get(0); assertNotNull(capturedOptions1.getCredentials()); ClientInterceptor interceptor2 - = filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null); + = filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null, task -> { }); interceptor2.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); verify(mockChannel, times(2)) .newCall(eq(methodDescriptor), callOptionsCaptor.capture()); @@ -399,7 +399,7 @@ public void testLruCacheEvictionOnResize() throws ResourceInvalidException { MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); ClientInterceptor interceptor1 = - filter.buildClientInterceptor(new GcpAuthenticationConfig(2), null, null); + filter.buildClientInterceptor(new GcpAuthenticationConfig(2), null, null, task -> { }); Channel mockChannel1 = Mockito.mock(Channel.class); ArgumentCaptor captor = ArgumentCaptor.forClass(CallOptions.class); interceptor1.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel1); @@ -407,7 +407,7 @@ public void testLruCacheEvictionOnResize() throws ResourceInvalidException { CallOptions options1 = captor.getValue(); // This will recreate the cache with max size of 1 and copy the credential for audience1. ClientInterceptor interceptor2 = - filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null); + filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null, task -> { }); Channel mockChannel2 = Mockito.mock(Channel.class); interceptor2.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel2); verify(mockChannel2).newCall(eq(methodDescriptor), captor.capture()); @@ -428,7 +428,7 @@ public void testLruCacheEvictionOnResize() throws ResourceInvalidException { // This will evict the credential for audience1 and add new credential for audience2 ClientInterceptor interceptor3 = - filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null); + filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null, task -> { }); Channel mockChannel3 = Mockito.mock(Channel.class); interceptor3.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel3); verify(mockChannel3).newCall(eq(methodDescriptor), captor.capture()); @@ -449,7 +449,7 @@ public void testLruCacheEvictionOnResize() throws ResourceInvalidException { // This will create new credential for audience1 because it has been evicted ClientInterceptor interceptor4 = - filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null); + filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null, task -> { }); Channel mockChannel4 = Mockito.mock(Channel.class); interceptor4.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel4); verify(mockChannel4).newCall(eq(methodDescriptor), captor.capture()); diff --git a/xds/src/test/java/io/grpc/xds/SharedResourceManagerTest.java b/xds/src/test/java/io/grpc/xds/SharedResourceManagerTest.java new file mode 100644 index 00000000000..72d09cae91a --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/SharedResourceManagerTest.java @@ -0,0 +1,389 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import io.grpc.ManagedChannel; +import io.grpc.xds.SharedResourceManager.ManagedChannelResource; +import io.grpc.xds.SharedResourceManager.ResourceCloseable; +import io.grpc.xds.SharedResourceManager.SharedResource; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** Tests for {@link SharedResourceManager}. */ +@RunWith(JUnit4.class) +public class SharedResourceManagerTest { + + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + + @Mock private ResourceCloseable mockResourceA; + @Mock private ResourceCloseable mockResourceB; + @Mock private ResourceCloseable mockResource; + @Mock private ManagedChannel mockChannel; + + private SharedResourceManager manager; + private final AtomicInteger createCount = new AtomicInteger(); + + @Before + public void setUp() { + manager = new SharedResourceManager<>(key -> { + createCount.incrementAndGet(); + if ("keyA".equals(key)) { + return mockResourceA; + } else if ("keyB".equals(key)) { + return mockResourceB; + } + throw new IllegalArgumentException("Unexpected key: " + key); + }); + } + + // ---- SharedResourceManager tests ---- + + @Test + public void acquire_createsNewResourceOnFirstCall() { + ResourceCloseable resource = manager.acquire("keyA"); + assertThat(resource).isSameInstanceAs(mockResourceA); + assertThat(createCount.get()).isEqualTo(1); + } + + @Test + public void acquire_returnsSameResourceOnSubsequentCalls() { + ResourceCloseable first = manager.acquire("keyA"); + ResourceCloseable second = manager.acquire("keyA"); + assertThat(second).isSameInstanceAs(first); + assertThat(createCount.get()).isEqualTo(1); + } + + @Test + public void acquire_differentKeysCreateIndependentResources() { + ResourceCloseable resourceA = manager.acquire("keyA"); + ResourceCloseable resourceB = manager.acquire("keyB"); + assertThat(resourceA).isSameInstanceAs(mockResourceA); + assertThat(resourceB).isSameInstanceAs(mockResourceB); + assertThat(createCount.get()).isEqualTo(2); + } + + @Test + public void release_decrementsRefCount() { + manager.acquire("keyA"); + manager.acquire("keyA"); // refCount = 2 + boolean closed = manager.release("keyA"); // refCount = 1 + assertThat(closed).isFalse(); + verify(mockResourceA, never()).close(); + } + + @Test + public void release_removesEntryWhenRefCountReachesZero() { + manager.acquire("keyA"); // refCount = 1 + boolean closed = manager.release("keyA"); // refCount = 0, should close and evict + assertThat(closed).isTrue(); + verify(mockResourceA).close(); + // Acquiring again should create a new resource. + createCount.set(0); + manager.acquire("keyA"); + assertThat(createCount.get()).isEqualTo(1); + } + + @Test + public void release_evictsAndRethrowsOnCloseException() { + RuntimeException boom = new RuntimeException("boom"); + doThrow(boom).when(mockResourceA).close(); + manager.acquire("keyA"); // refCount = 1 + RuntimeException thrown = assertThrows( + RuntimeException.class, + () -> manager.release("keyA")); + assertThat(thrown).isSameInstanceAs(boom); + // Entry should be evicted even on exception; re-acquire creates new. + createCount.set(0); + manager.acquire("keyA"); + assertThat(createCount.get()).isEqualTo(1); + } + + @Test + public void release_returnsFalseWhenKeyDoesNotExist() { + boolean closed = manager.release("nonExistentKey"); + assertThat(closed).isFalse(); + } + + @Test + public void close_releasesAllCachedResources() { + manager.acquire("keyA"); + manager.acquire("keyB"); + manager.close(); + verify(mockResourceA).close(); + verify(mockResourceB).close(); + } + + @Test + public void independentKeyLifecycle() { + manager.acquire("keyA"); + manager.acquire("keyB"); + // Release keyA fully. + manager.release("keyA"); + verify(mockResourceA).close(); + verify(mockResourceB, never()).close(); + // keyB should still be accessible. + ResourceCloseable stillB = manager.acquire("keyB"); + assertThat(stillB).isSameInstanceAs(mockResourceB); + } + + @Test + public void acquire_afterFullRelease_createsNewResource() { + // Acquire then fully release — exercises the eviction path where retain() fails. + manager.acquire("keyA"); // refCount = 1 + manager.release("keyA"); // refCount = 0, evicted + verify(mockResourceA).close(); + + // Re-acquire should create a new resource via the factory. + createCount.set(0); + ResourceCloseable reacquired = manager.acquire("keyA"); + assertThat(createCount.get()).isEqualTo(1); + // The new resource is a fresh mock created by the factory. + assertThat(reacquired).isSameInstanceAs(mockResourceA); + } + + @Test + public void close_onEmptyManager_doesNotThrow() { + // close() on a manager with no resources should be a no-op. + manager.close(); + // No exceptions thrown, no mocks invoked. + verify(mockResourceA, never()).close(); + verify(mockResourceB, never()).close(); + } + + @Test + public void close_afterPartialRelease_releasesRemaining() { + manager.acquire("keyA"); + manager.acquire("keyA"); // refCount = 2 + manager.acquire("keyB"); // refCount = 1 + + // Partially release keyA (refCount 2 -> 1). + manager.release("keyA"); + verify(mockResourceA, never()).close(); // Still alive. + + // close() should release everything remaining. + manager.close(); + verify(mockResourceA).close(); + verify(mockResourceB).close(); + } + + @Test + public void release_multipleReleasesOnSameKey() { + manager.acquire("keyA"); + manager.acquire("keyA"); // refCount = 2 + manager.acquire("keyA"); // refCount = 3 + + manager.release("keyA"); // 3 -> 2 + verify(mockResourceA, never()).close(); + + manager.release("keyA"); // 2 -> 1 + verify(mockResourceA, never()).close(); + + manager.release("keyA"); // 1 -> 0, close and evict + verify(mockResourceA).close(); + } + + @Test + public void acquire_retriesWhenRetainFailsOnStaleEntry() { + AtomicInteger callCount = new AtomicInteger(); + ResourceCloseable firstResource = new ResourceCloseable() { + @Override + public void close() {} + }; + ResourceCloseable secondResource = new ResourceCloseable() { + @Override + public void close() {} + }; + SharedResourceManager testManager = + new SharedResourceManager<>(key -> { + int c = callCount.incrementAndGet(); + return c == 1 ? firstResource : secondResource; + }); + + // First acquire creates and caches the resource. + ResourceCloseable got = testManager.acquire("keyA"); + assertThat(got).isSameInstanceAs(firstResource); + + // Fully release it: refCount -> 0 (resource is now dead). + testManager.release("keyA"); + + // Now acquire again — this should create a new resource. + ResourceCloseable got2 = testManager.acquire("keyA"); + assertThat(got2).isSameInstanceAs(secondResource); + assertThat(callCount.get()).isEqualTo(2); + } + + @Test + public void acquire_throwsIllegalStateExceptionAfterClose() { + manager.close(); + assertThrows(IllegalStateException.class, () -> manager.acquire("keyA")); + } + + @Test + public void acquire_losesPutIfAbsentRace_closesLoser() { + AtomicInteger callCount = new AtomicInteger(); + ResourceCloseable firstResource = mock(ResourceCloseable.class); + ResourceCloseable secondResource = mock(ResourceCloseable.class); + + final AtomicReference> racingManagerRef = + new AtomicReference<>(); + + racingManagerRef.set( + new SharedResourceManager<>(new java.util.function.Function() { + @Override + public ResourceCloseable apply(String key) { + int c = callCount.incrementAndGet(); + if (c == 1) { + racingManagerRef.get().acquire(key); + return secondResource; + } + return firstResource; + } + })); + + callCount.set(0); + ResourceCloseable got = racingManagerRef.get().acquire("keyA"); + + // We should get firstResource + assertThat(got).isSameInstanceAs(firstResource); + // secondResource should have been closed because it lost the putIfAbsent race + verify(secondResource).close(); + // firstResource should NOT have been closed + verify(firstResource, never()).close(); + } + + // ---- Tests for SharedResource (nested class) ---- + + @Test + public void sharedResource_initialRefCount_isOne() { + SharedResource shared = new SharedResource<>(mockResource); + assertThat(shared.getRefCount()).isEqualTo(1); + } + + @Test + public void sharedResource_get_returnsWrappedResource() { + SharedResource shared = new SharedResource<>(mockResource); + assertThat(shared.get()).isSameInstanceAs(mockResource); + } + + @Test + public void sharedResource_retain_incrementsRefCount() { + SharedResource shared = new SharedResource<>(mockResource); + boolean retained = shared.retain(); + assertThat(retained).isTrue(); + assertThat(shared.getRefCount()).isEqualTo(2); + } + + @Test + public void sharedResource_retain_multipleIncrements() { + SharedResource shared = new SharedResource<>(mockResource); + shared.retain(); + shared.retain(); + assertThat(shared.getRefCount()).isEqualTo(3); + } + + @Test + public void sharedResource_retain_returnsFalseWhenDead() { + SharedResource shared = new SharedResource<>(mockResource); + shared.release(); + boolean retained = shared.retain(); + assertThat(retained).isFalse(); + assertThat(shared.getRefCount()).isEqualTo(0); + } + + @Test + public void sharedResource_release_decrementsRefCount() { + SharedResource shared = new SharedResource<>(mockResource); + shared.retain(); // refCount = 2 + boolean closed = shared.release(); // refCount = 1 + assertThat(closed).isFalse(); + assertThat(shared.getRefCount()).isEqualTo(1); + verify(mockResource, never()).close(); + } + + @Test + public void sharedResource_release_closesResourceWhenRefCountReachesZero() { + SharedResource shared = new SharedResource<>(mockResource); + boolean closed = shared.release(); + assertThat(closed).isTrue(); + assertThat(shared.getRefCount()).isEqualTo(0); + verify(mockResource).close(); + } + + @Test + public void sharedResource_release_throwsAssertionErrorOnUnderflow() { + SharedResource shared = new SharedResource<>(mockResource); + shared.release(); // refCount = 0, closes resource + verify(mockResource).close(); + + assertThrows( + AssertionError.class, + () -> shared.release()); + + // Verify close was only called once (the first time) + verify(mockResource, times(1)).close(); + } + + // ---- Tests for ManagedChannelResource (nested class) ---- + + @Test + public void channelResource_constructor_rejectsNullChannel() { + assertThrows(NullPointerException.class, () -> new ManagedChannelResource(null)); + } + + @Test + public void channelResource_getChannel_returnsWrappedChannel() { + ManagedChannelResource resource = new ManagedChannelResource(mockChannel); + assertThat(resource.getChannel()).isSameInstanceAs(mockChannel); + } + + @Test + public void channelResource_close_callsChannelShutdown() { + ManagedChannelResource resource = new ManagedChannelResource(mockChannel); + resource.close(); + verify(mockChannel).shutdown(); + } + + @Test + public void channelResource_close_doesNotCallShutdownNow() { + ManagedChannelResource resource = new ManagedChannelResource(mockChannel); + resource.close(); + verify(mockChannel, never()).shutdownNow(); + } + + @Test + public void channelResource_implementsResourceCloseable() { + ManagedChannelResource resource = new ManagedChannelResource(mockChannel); + assertThat(resource).isInstanceOf(ResourceCloseable.class); + assertThat(resource).isInstanceOf(AutoCloseable.class); + } +} diff --git a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java index df3a0af5111..ea6c855f290 100644 --- a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java @@ -96,6 +96,9 @@ import io.grpc.xds.XdsClusterResource.CdsUpdate; import io.grpc.xds.XdsEndpointResource.EdsUpdate; import io.grpc.xds.XdsListenerResource.LdsUpdate; +import io.grpc.xds.XdsNameResolver.RefCountedRoute; +import io.grpc.xds.XdsNameResolver.RefCountedRouteInterceptor; +import io.grpc.xds.XdsNameResolver.RouteData; import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; import io.grpc.xds.client.Bootstrapper.AuthorityInfo; import io.grpc.xds.client.Bootstrapper.BootstrapInfo; @@ -115,6 +118,7 @@ import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import javax.annotation.Nullable; import org.junit.After; @@ -2440,6 +2444,407 @@ public void resolved_faultConfigOverrideInLdsAndInRdsUpdate() { observer, Status.UNKNOWN.withDescription("RPC terminated due to fault injection")); } + @Test + public void updateRoutes_withInFlightStream_delaysCleanupUntilStreamClose() { + // 1. Create a test filter that registers a mocked cleanup Runnable + Runnable mockCleanup = mock(Runnable.class); + Filter mockFilter = new Filter() { + @Override + public ClientInterceptor buildClientInterceptor( + FilterConfig config, + FilterConfig overrideConfig, + ScheduledExecutorService scheduler, + Filter.ResourceCleanupRegistry cleanupRegistry) { + cleanupRegistry.addCleanupTask(mockCleanup); + return new ClientInterceptor() { + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + return next.newCall(method, callOptions); + } + }; + } + }; + Filter.Provider mockProvider = mock(Filter.Provider.class); + when(mockProvider.typeUrls()).thenReturn(new String[]{"type.googleapis.com/dummy"}); + when(mockProvider.newInstance(any(String.class))).thenReturn(mockFilter); + FilterRegistry customRegistry = FilterRegistry.newRegistry().register(mockProvider); + + XdsNameResolver customResolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, + serviceConfigParser, syncContext, scheduler, + xdsClientPoolFactory, mockRandom, customRegistry, rawBootstrap, metricRecorder, + nameResolverArgs); + customResolver.start(mockListener); + + FakeXdsClient fakeXdsClient = (FakeXdsClient) xdsClientPoolFactory.xdsClient; + + // Deliver initial LDS update + VirtualHost vhost = VirtualHost.create( + "virtual-host", Collections.singletonList(expectedLdsResourceName), + Collections.singletonList(Route.forAction( + RouteMatch.create(PathMatcher.fromPrefix("/", false), Collections.emptyList(), null), + RouteAction.forCluster("cluster0", Collections.emptyList(), null, null, false), + Collections.emptyMap())), + Collections.emptyMap()); + NamedFilterConfig namedConfig = + new NamedFilterConfig("dummy_filter", () -> "type.googleapis.com/dummy"); + fakeXdsClient.deliverLdsUpdateWithFilters(vhost, Collections.singletonList(namedConfig)); + createAndDeliverClusterUpdates(fakeXdsClient, "cluster0"); + + verify(mockListener).onResult2(resolutionResultCaptor.capture()); + ResolutionResult result = resolutionResultCaptor.getValue(); + InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); + + // 2. Start ClientCall (stream 1) via configSelector.selectConfig. + // RefCountedRoute refCount is now 2. + startNewCall(TestMethodDescriptors.voidMethod(), + configSelector, Collections.emptyMap(), CallOptions.DEFAULT); + verify(mockCleanup, never()).run(); + + // 3. Deliver LDS 2 with a new route configuration. XdsNameResolver calls oldConfig.close(), + // decrementing the old route's refCount from 2 to 1. Verify cleanup task has NOT run yet. + fakeXdsClient.deliverLdsUpdateWithFilters(vhost, Collections.emptyList()); + verify(mockCleanup, never()).run(); + + // 4. Terminate stream 1 (Listener#onClose). RefCountedRouteInterceptor calls release(), + // decrementing refCount from 1 to 0. Verify cleanup task is now executed exactly once. + testCall.deliverCompleted(); + verify(mockCleanup, times(1)).run(); + } + + @Test + public void updateRoutes_multipleConfigUpdates_retainsAndReleasesCorrectly() { + // 1. Create a test filter that registers a mocked cleanup Runnable + Runnable mockCleanup = mock(Runnable.class); + Filter mockFilter = new Filter() { + @Override + public ClientInterceptor buildClientInterceptor( + FilterConfig config, + FilterConfig overrideConfig, + ScheduledExecutorService scheduler, + Filter.ResourceCleanupRegistry cleanupRegistry) { + cleanupRegistry.addCleanupTask(mockCleanup); + return new ClientInterceptor() { + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + return next.newCall(method, callOptions); + } + }; + } + }; + Filter.Provider mockProvider = mock(Filter.Provider.class); + when(mockProvider.typeUrls()).thenReturn(new String[]{"type.googleapis.com/dummy"}); + when(mockProvider.newInstance(any(String.class))).thenReturn(mockFilter); + FilterRegistry customRegistry = FilterRegistry.newRegistry().register(mockProvider); + + XdsNameResolver customResolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, + serviceConfigParser, syncContext, scheduler, + xdsClientPoolFactory, mockRandom, customRegistry, rawBootstrap, metricRecorder, + nameResolverArgs); + customResolver.start(mockListener); + + FakeXdsClient fakeXdsClient = (FakeXdsClient) xdsClientPoolFactory.xdsClient; + + // Deliver initial LDS update (Config 1 with filter) + VirtualHost vhost = VirtualHost.create( + "virtual-host", Collections.singletonList(expectedLdsResourceName), + Collections.singletonList(Route.forAction( + RouteMatch.create(PathMatcher.fromPrefix("/", false), Collections.emptyList(), null), + RouteAction.forCluster("cluster0", Collections.emptyList(), null, null, false), + Collections.emptyMap())), + Collections.emptyMap()); + NamedFilterConfig namedConfig = + new NamedFilterConfig("dummy_filter", () -> "type.googleapis.com/dummy"); + fakeXdsClient.deliverLdsUpdateWithFilters(vhost, Collections.singletonList(namedConfig)); + createAndDeliverClusterUpdates(fakeXdsClient, "cluster0"); + + verify(mockListener).onResult2(resolutionResultCaptor.capture()); + InternalConfigSelector configSelector1 = + resolutionResultCaptor.getValue().getAttributes().get(InternalConfigSelector.KEY); + + // Start Call 1 on Config 1. + // RefCountedRoute 1 refCount becomes 2 (1 from control plane, 1 from stream) + startNewCall(TestMethodDescriptors.voidMethod(), + configSelector1, Collections.emptyMap(), CallOptions.DEFAULT); + TestCall call1 = testCall; // Capture Call 1 + verify(mockCleanup, never()).run(); + + // Deliver LDS 2 (Config 2 - no filters) + // This updates the resolution result, replacing the old config selector and closing it. + // Config 1 is closed. RefCountedRoute 1 refCount decrements 2 -> 1 (held by Call 1). + fakeXdsClient.deliverLdsUpdateWithFilters(vhost, Collections.emptyList()); + verify(mockCleanup, never()).run(); + + // Capture Config Selector 2 + verify(mockListener, times(2)).onResult2(resolutionResultCaptor.capture()); + InternalConfigSelector configSelector2 = + resolutionResultCaptor.getAllValues().get(1) + .getAttributes().get(InternalConfigSelector.KEY); + + // Start Call 2 on Config 2. + startNewCall(TestMethodDescriptors.voidMethod(), + configSelector2, Collections.emptyMap(), CallOptions.DEFAULT); + TestCall call2 = testCall; // Capture Call 2 + verify(mockCleanup, never()).run(); + + // Deliver LDS 3 (Config 3 - no filters) + // This closes Config 2. + fakeXdsClient.deliverLdsUpdateWithFilters(vhost, Collections.emptyList()); + + // At this point: + // - Config 1 is closed. RefCountedRoute 1 is at refCount=1 (held by active Call 1). + // - Config 2 is closed. RefCountedRoute 2 is at refCount=1 (held by active Call 2). + // - Config 3 is active (refCount=1). + // Cleanup has not run yet. + verify(mockCleanup, never()).run(); + + // Now terminate Call 2. This decrements RefCountedRoute 2 refCount from 1 to 0. + // Route 2 has no filters, so no cleanup should run. + call2.deliverCompleted(); + verify(mockCleanup, never()).run(); + + // Now terminate Call 1. This decrements RefCountedRoute 1 refCount from 1 to 0. + // This should trigger Route 1's cleanup task (mockCleanup) exactly once. + call1.deliverCompleted(); + verify(mockCleanup, times(1)).run(); + } + + @Test + public void + cleanUpRoutes_withInFlightStream_delaysCleanupUntilStreamClose() { + // 1. Create a test filter that registers a mocked cleanup Runnable + Runnable mockCleanup = mock(Runnable.class); + Filter mockFilter = new Filter() { + @Override + public ClientInterceptor buildClientInterceptor( + FilterConfig config, + FilterConfig overrideConfig, + ScheduledExecutorService scheduler, + Filter.ResourceCleanupRegistry cleanupRegistry) { + cleanupRegistry.addCleanupTask(mockCleanup); + return new ClientInterceptor() { + @Override + public ClientCall interceptCall( + MethodDescriptor method, + CallOptions callOptions, Channel next) { + return next.newCall(method, callOptions); + } + }; + } + }; + Filter.Provider mockProvider = mock(Filter.Provider.class); + when(mockProvider.typeUrls()) + .thenReturn(new String[]{"type.googleapis.com/dummy"}); + when(mockProvider.newInstance(any(String.class))) + .thenReturn(mockFilter); + FilterRegistry customRegistry = + FilterRegistry.newRegistry().register(mockProvider); + + XdsNameResolver customResolver = new XdsNameResolver( + targetUri, null, AUTHORITY, null, + serviceConfigParser, syncContext, scheduler, + xdsClientPoolFactory, mockRandom, customRegistry, + rawBootstrap, metricRecorder, nameResolverArgs); + customResolver.start(mockListener); + + FakeXdsClient fakeXdsClient = + (FakeXdsClient) xdsClientPoolFactory.xdsClient; + + // 2. Deliver initial LDS update with the filter + VirtualHost vhost = VirtualHost.create( + "virtual-host", + Collections.singletonList(expectedLdsResourceName), + Collections.singletonList(Route.forAction( + RouteMatch.create( + PathMatcher.fromPrefix("/", false), + Collections.emptyList(), null), + RouteAction.forCluster( + "cluster0", Collections.emptyList(), + null, null, false), + Collections.emptyMap())), + Collections.emptyMap()); + NamedFilterConfig namedConfig = new NamedFilterConfig( + "dummy_filter", () -> "type.googleapis.com/dummy"); + fakeXdsClient.deliverLdsUpdateWithFilters( + vhost, Collections.singletonList(namedConfig)); + createAndDeliverClusterUpdates(fakeXdsClient, "cluster0"); + + verify(mockListener).onResult2(resolutionResultCaptor.capture()); + InternalConfigSelector configSelector = + resolutionResultCaptor.getValue().getAttributes() + .get(InternalConfigSelector.KEY); + + // 3. Start a call. RefCountedRoute refCount becomes 2. + startNewCall(TestMethodDescriptors.voidMethod(), + configSelector, Collections.emptyMap(), + CallOptions.DEFAULT); + verify(mockCleanup, never()).run(); + + // 4. Deliver error — triggers cleanUpRoutes(), closing the + // old config. RefCountedRoute refCount goes from 2 to 1. + fakeXdsClient.deliverError( + Status.UNAVAILABLE.withDescription("server unreachable")); + verify(mockCleanup, never()).run(); + + // 5. Terminate the stream. RefCountedRoute refCount goes + // from 1 to 0. Cleanup task should run exactly once. + testCall.deliverCompleted(); + verify(mockCleanup, times(1)).run(); + } + + @Test + public void + shutdown_withInFlightStream_delaysCleanupUntilStreamClose() { + // 1. Create a test filter that registers a mocked cleanup Runnable + Runnable mockCleanup = mock(Runnable.class); + Filter mockFilter = new Filter() { + @Override + public ClientInterceptor buildClientInterceptor( + FilterConfig config, + FilterConfig overrideConfig, + ScheduledExecutorService scheduler, + Filter.ResourceCleanupRegistry cleanupRegistry) { + cleanupRegistry.addCleanupTask(mockCleanup); + return new ClientInterceptor() { + @Override + public ClientCall interceptCall( + MethodDescriptor method, + CallOptions callOptions, Channel next) { + return next.newCall(method, callOptions); + } + }; + } + }; + Filter.Provider mockProvider = mock(Filter.Provider.class); + when(mockProvider.typeUrls()) + .thenReturn(new String[]{"type.googleapis.com/dummy"}); + when(mockProvider.newInstance(any(String.class))) + .thenReturn(mockFilter); + FilterRegistry customRegistry = + FilterRegistry.newRegistry().register(mockProvider); + + XdsNameResolver customResolver = new XdsNameResolver( + targetUri, null, AUTHORITY, null, + serviceConfigParser, syncContext, scheduler, + xdsClientPoolFactory, mockRandom, customRegistry, + rawBootstrap, metricRecorder, nameResolverArgs); + customResolver.start(mockListener); + + FakeXdsClient fakeXdsClient = + (FakeXdsClient) xdsClientPoolFactory.xdsClient; + + // 2. Deliver initial LDS update with the filter + VirtualHost vhost = VirtualHost.create( + "virtual-host", + Collections.singletonList(expectedLdsResourceName), + Collections.singletonList(Route.forAction( + RouteMatch.create( + PathMatcher.fromPrefix("/", false), + Collections.emptyList(), null), + RouteAction.forCluster( + "cluster0", Collections.emptyList(), + null, null, false), + Collections.emptyMap())), + Collections.emptyMap()); + NamedFilterConfig namedConfig = new NamedFilterConfig( + "dummy_filter", () -> "type.googleapis.com/dummy"); + fakeXdsClient.deliverLdsUpdateWithFilters( + vhost, Collections.singletonList(namedConfig)); + createAndDeliverClusterUpdates(fakeXdsClient, "cluster0"); + + verify(mockListener).onResult2(resolutionResultCaptor.capture()); + InternalConfigSelector configSelector = + resolutionResultCaptor.getValue().getAttributes() + .get(InternalConfigSelector.KEY); + + // 3. Start a call. RefCountedRoute refCount becomes 2. + startNewCall(TestMethodDescriptors.voidMethod(), + configSelector, Collections.emptyMap(), + CallOptions.DEFAULT); + verify(mockCleanup, never()).run(); + + // 4. Shut down the resolver. This closes the current config. + // RefCountedRoute refCount goes from 2 to 1. + customResolver.shutdown(); + verify(mockCleanup, never()).run(); + + // 5. Terminate the stream. RefCountedRoute refCount goes + // from 1 to 0. Cleanup task should run exactly once. + testCall.deliverCompleted(); + verify(mockCleanup, times(1)).run(); + } + + @Test + public void shutdown_noInFlightStreams_runsCleanupImmediately() { + // 1. Create a test filter that registers a mocked cleanup Runnable + Runnable mockCleanup = mock(Runnable.class); + Filter mockFilter = new Filter() { + @Override + public ClientInterceptor buildClientInterceptor( + FilterConfig config, + FilterConfig overrideConfig, + ScheduledExecutorService scheduler, + Filter.ResourceCleanupRegistry cleanupRegistry) { + cleanupRegistry.addCleanupTask(mockCleanup); + return new ClientInterceptor() { + @Override + public ClientCall interceptCall( + MethodDescriptor method, + CallOptions callOptions, Channel next) { + return next.newCall(method, callOptions); + } + }; + } + }; + Filter.Provider mockProvider = mock(Filter.Provider.class); + when(mockProvider.typeUrls()) + .thenReturn(new String[]{"type.googleapis.com/dummy"}); + when(mockProvider.newInstance(any(String.class))) + .thenReturn(mockFilter); + FilterRegistry customRegistry = + FilterRegistry.newRegistry().register(mockProvider); + + XdsNameResolver customResolver = new XdsNameResolver( + targetUri, null, AUTHORITY, null, + serviceConfigParser, syncContext, scheduler, + xdsClientPoolFactory, mockRandom, customRegistry, + rawBootstrap, metricRecorder, nameResolverArgs); + customResolver.start(mockListener); + + FakeXdsClient fakeXdsClient = + (FakeXdsClient) xdsClientPoolFactory.xdsClient; + + // 2. Deliver initial LDS update with the filter. + // RefCountedRoute refCount is 1 (control-plane only). + VirtualHost vhost = VirtualHost.create( + "virtual-host", + Collections.singletonList(expectedLdsResourceName), + Collections.singletonList(Route.forAction( + RouteMatch.create( + PathMatcher.fromPrefix("/", false), + Collections.emptyList(), null), + RouteAction.forCluster( + "cluster0", Collections.emptyList(), + null, null, false), + Collections.emptyMap())), + Collections.emptyMap()); + NamedFilterConfig namedConfig = new NamedFilterConfig( + "dummy_filter", () -> "type.googleapis.com/dummy"); + fakeXdsClient.deliverLdsUpdateWithFilters( + vhost, Collections.singletonList(namedConfig)); + createAndDeliverClusterUpdates(fakeXdsClient, "cluster0"); + + verify(mockListener).onResult2(resolutionResultCaptor.capture()); + verify(mockCleanup, never()).run(); + + // 3. Shut down the resolver without starting any RPCs. + // RefCountedRoute refCount goes from 1 to 0. + // Cleanup task should run immediately. + customResolver.shutdown(); + verify(mockCleanup, times(1)).run(); + } + private ClientCall.Listener startNewCall( MethodDescriptor method, InternalConfigSelector selector, Map headers, CallOptions callOptions) { @@ -2957,4 +3362,384 @@ void deliverErrorStatus() { listener.onClose(Status.UNAVAILABLE, new Metadata()); } } + + // ---- RefCountedRoute unit tests ---- + + @Test + public void refCountedRoute_initialRefCountIsOne() { + RouteData routeData = new RouteData( + RouteMatch.withPathExactOnly("/svc/method"), + null, ImmutableList.of()); + RefCountedRoute route = + new RefCountedRoute(routeData, ImmutableList.of(), syncContext); + // Release from 1 -> 0; should not throw + route.release(); + // Retain on a dead route (refCount==0) should return false + assertThat(route.retain()).isFalse(); + } + + @Test + public void refCountedRoute_retainAndRelease() { + AtomicBoolean cleaned = new AtomicBoolean(false); + RouteData routeData = new RouteData( + RouteMatch.withPathExactOnly("/svc/method"), + null, ImmutableList.of()); + RefCountedRoute route = new RefCountedRoute( + routeData, + ImmutableList.of(() -> cleaned.set(true)), syncContext); + + // retain: 1 -> 2 + route.retain(); + // release: 2 -> 1 + route.release(); + assertThat(cleaned.get()).isFalse(); + + // release: 1 -> 0 => cleanup fires + route.release(); + assertThat(cleaned.get()).isTrue(); + } + + @Test + public void refCountedRoute_retainOnDeadRoute_returnsFalse() { + AtomicBoolean cleaned = new AtomicBoolean(false); + RouteData routeData = new RouteData( + RouteMatch.withPathExactOnly("/svc/method"), + null, ImmutableList.of()); + RefCountedRoute route = new RefCountedRoute( + routeData, + ImmutableList.of(() -> cleaned.set(true)), syncContext); + route.release(); // refCount 1 -> 0 + assertThat(cleaned.get()).isTrue(); + // retain on dead route returns false + assertThat(route.retain()).isFalse(); + // subsequent release throws AssertionError + try { + route.release(); + assertWithMessage("Expected AssertionError").fail(); + } catch (AssertionError e) { + assertThat(e).hasMessageThat().isNull(); + } + } + + @Test + public void refCountedRoute_releaseAtZero_runsAllCleanupTasks() { + AtomicBoolean task1Ran = new AtomicBoolean(false); + AtomicBoolean task2Ran = new AtomicBoolean(false); + AtomicBoolean task3Ran = new AtomicBoolean(false); + RouteData routeData = new RouteData( + RouteMatch.withPathExactOnly("/svc/method"), + null, ImmutableList.of()); + RefCountedRoute route = new RefCountedRoute( + routeData, + ImmutableList.of( + () -> task1Ran.set(true), + () -> task2Ran.set(true), + () -> task3Ran.set(true)), syncContext); + route.release(); + assertThat(task1Ran.get()).isTrue(); + assertThat(task2Ran.get()).isTrue(); + assertThat(task3Ran.get()).isTrue(); + } + + @Test + public void refCountedRoute_cleanupTaskException_doesNotAbortRemainingTasks() { + AtomicBoolean task1Ran = new AtomicBoolean(false); + AtomicBoolean task3Ran = new AtomicBoolean(false); + RouteData routeData = new RouteData( + RouteMatch.withPathExactOnly("/svc/method"), + null, ImmutableList.of()); + RefCountedRoute route = new RefCountedRoute( + routeData, + ImmutableList.of( + () -> task1Ran.set(true), + () -> { + throw new RuntimeException("boom"); + }, + () -> task3Ran.set(true)), syncContext); + route.release(); + assertThat(task1Ran.get()).isTrue(); + assertThat(task3Ran.get()).isTrue(); + } + + @Test + public void refCountedRoute_redundantRelease_throwsAssertionError() { + AtomicLong runCount = new AtomicLong(0); + RouteData routeData = new RouteData( + RouteMatch.withPathExactOnly("/svc/method"), + null, ImmutableList.of()); + RefCountedRoute route = new RefCountedRoute( + routeData, + ImmutableList.of(runCount::incrementAndGet), syncContext); + route.release(); // refCount 1 -> 0, cleanup fires + assertThat(runCount.get()).isEqualTo(1); + try { + route.release(); // underflow + assertWithMessage("Expected AssertionError").fail(); + } catch (AssertionError e) { + assertThat(e).hasMessageThat().isNull(); + } + } + + @Test + public void refCountedRoute_multipleRetainRelease_tasksFireOnlyOnFinalRelease() { + AtomicLong runCount = new AtomicLong(0); + RouteData routeData = new RouteData( + RouteMatch.withPathExactOnly("/svc/method"), + null, ImmutableList.of()); + RefCountedRoute route = new RefCountedRoute( + routeData, + ImmutableList.of(runCount::incrementAndGet), syncContext); + route.retain(); // 1 -> 2 + route.retain(); // 2 -> 3 + route.release(); // 3 -> 2 + assertThat(runCount.get()).isEqualTo(0); + route.release(); // 2 -> 1 + assertThat(runCount.get()).isEqualTo(0); + route.release(); // 1 -> 0, cleanup fires + assertThat(runCount.get()).isEqualTo(1); + } + + // ---- RefCountedRouteInterceptor tests ---- + + @Test + public void refCountedRouteInterceptor_onClose_releasesExactlyOnce() { + AtomicLong cleanupCount = new AtomicLong(0); + RouteData routeData = new RouteData( + RouteMatch.withPathExactOnly("/svc/method"), + null, ImmutableList.of()); + RefCountedRoute route = new RefCountedRoute( + routeData, + ImmutableList.of(cleanupCount::incrementAndGet), syncContext); + route.retain(); // explicitly retain for in-flight RPC + RefCountedRouteInterceptor interceptor = + new RefCountedRouteInterceptor(route); + Channel mockChannel = mock(Channel.class); + @SuppressWarnings("unchecked") + ClientCall mockCall = mock(ClientCall.class); + org.mockito.Mockito.doReturn(mockCall) + .when(mockChannel).newCall(any(), any()); + // interceptCall + ClientCall wrappedCall = interceptor.interceptCall( + TestMethodDescriptors.voidMethod(), CallOptions.DEFAULT, + mockChannel); + // Release control-plane ref: 2 -> 1 + route.release(); + assertThat(cleanupCount.get()).isEqualTo(0); + // Start the call to wire up the listener + @SuppressWarnings("unchecked") + ArgumentCaptor> listenerCaptor = + ArgumentCaptor.forClass(ClientCall.Listener.class); + wrappedCall.start( + new NoopClientCallListener<>(), new Metadata()); + verify(mockCall).start(listenerCaptor.capture(), any()); + // Trigger onClose: interceptor releases 1 -> 0 + listenerCaptor.getValue().onClose(Status.OK, new Metadata()); + assertThat(cleanupCount.get()).isEqualTo(1); + } + + @Test + public void refCountedRouteInterceptor_cancel_releasesExactlyOnce() { + AtomicLong cleanupCount = new AtomicLong(0); + RouteData routeData = new RouteData( + RouteMatch.withPathExactOnly("/svc/method"), + null, ImmutableList.of()); + RefCountedRoute route = new RefCountedRoute( + routeData, + ImmutableList.of(cleanupCount::incrementAndGet), syncContext); + route.retain(); // explicitly retain for in-flight RPC + RefCountedRouteInterceptor interceptor = + new RefCountedRouteInterceptor(route); + Channel mockChannel = mock(Channel.class); + @SuppressWarnings("unchecked") + ClientCall mockCall = mock(ClientCall.class); + org.mockito.Mockito.doReturn(mockCall) + .when(mockChannel).newCall(any(), any()); + // interceptCall + ClientCall wrappedCall = interceptor.interceptCall( + TestMethodDescriptors.voidMethod(), CallOptions.DEFAULT, + mockChannel); + // Release control-plane ref: 2 -> 1 + route.release(); + assertThat(cleanupCount.get()).isEqualTo(0); + // Cancel the call: interceptor releases 1 -> 0 + wrappedCall.cancel("test cancel", null); + assertThat(cleanupCount.get()).isEqualTo(1); + } + + @Test + public void + refCountedRouteInterceptor_cancelAfterOnClose_noDoubleRelease() { + AtomicLong cleanupCount = new AtomicLong(0); + RouteData routeData = new RouteData( + RouteMatch.withPathExactOnly("/svc/method"), + null, ImmutableList.of()); + RefCountedRoute route = new RefCountedRoute( + routeData, + ImmutableList.of(cleanupCount::incrementAndGet), syncContext); + route.retain(); // explicitly retain + RefCountedRouteInterceptor interceptor = + new RefCountedRouteInterceptor(route); + Channel mockChannel = mock(Channel.class); + @SuppressWarnings("unchecked") + ClientCall mockCall = mock(ClientCall.class); + org.mockito.Mockito.doReturn(mockCall) + .when(mockChannel).newCall(any(), any()); + ClientCall wrappedCall = interceptor.interceptCall( + TestMethodDescriptors.voidMethod(), CallOptions.DEFAULT, + mockChannel); + // Release control-plane ref + route.release(); + // Start the call + wrappedCall.start( + new NoopClientCallListener<>(), new Metadata()); + @SuppressWarnings("unchecked") + ArgumentCaptor> listenerCaptor = + ArgumentCaptor.forClass(ClientCall.Listener.class); + verify(mockCall).start(listenerCaptor.capture(), any()); + // onClose releases once + listenerCaptor.getValue().onClose(Status.OK, new Metadata()); + assertThat(cleanupCount.get()).isEqualTo(1); + // cancel after onClose should not double-release + wrappedCall.cancel("late cancel", null); + assertThat(cleanupCount.get()).isEqualTo(1); + } + + @Test + public void + refCountedRouteInterceptor_onCloseAfterCancel_noDoubleRelease() { + AtomicLong cleanupCount = new AtomicLong(0); + RouteData routeData = new RouteData( + RouteMatch.withPathExactOnly("/svc/method"), + null, ImmutableList.of()); + RefCountedRoute route = new RefCountedRoute( + routeData, + ImmutableList.of(cleanupCount::incrementAndGet), syncContext); + route.retain(); // explicitly retain + RefCountedRouteInterceptor interceptor = + new RefCountedRouteInterceptor(route); + Channel mockChannel = mock(Channel.class); + @SuppressWarnings("unchecked") + ClientCall mockCall = mock(ClientCall.class); + org.mockito.Mockito.doReturn(mockCall) + .when(mockChannel).newCall(any(), any()); + ClientCall wrappedCall = interceptor.interceptCall( + TestMethodDescriptors.voidMethod(), CallOptions.DEFAULT, + mockChannel); + // Release control-plane ref + route.release(); + // Start the call + wrappedCall.start( + new NoopClientCallListener<>(), new Metadata()); + @SuppressWarnings("unchecked") + ArgumentCaptor> listenerCaptor = + ArgumentCaptor.forClass(ClientCall.Listener.class); + verify(mockCall).start(listenerCaptor.capture(), any()); + // cancel releases once (1 -> 0) + wrappedCall.cancel("cancel first", null); + assertThat(cleanupCount.get()).isEqualTo(1); + // onClose after cancel should not double-release + listenerCaptor.getValue().onClose(Status.CANCELLED, new Metadata()); + assertThat(cleanupCount.get()).isEqualTo(1); + } + + @Test + public void + refCountedRouteInterceptor_newCallException_releasesReliably() { + AtomicLong cleanupCount = new AtomicLong(0); + RouteData routeData = new RouteData( + RouteMatch.withPathExactOnly("/svc/method"), + null, ImmutableList.of()); + RefCountedRoute route = new RefCountedRoute( + routeData, + ImmutableList.of(cleanupCount::incrementAndGet), syncContext); + route.retain(); // explicitly retain + Channel mockChannel = mock(Channel.class); + when(mockChannel.newCall(any(), any())) + .thenThrow(new RuntimeException("newCall failed")); + // Release control-plane ref: 2 -> 1 (held by interceptCall next) + route.release(); + assertThat(cleanupCount.get()).isEqualTo(0); + // But first interceptCall will retain (already done above: 1 -> 2), + // then newCall throws, so it releases (2 -> 1) + // Let's execute interceptCall directly to see it throw and release: + RefCountedRouteInterceptor interceptor = new RefCountedRouteInterceptor(route); + try { + interceptor.interceptCall( + TestMethodDescriptors.voidMethod(), CallOptions.DEFAULT, + mockChannel); + assertWithMessage("Expected RuntimeException").fail(); + } catch (RuntimeException e) { + assertThat(e).hasMessageThat().isEqualTo("newCall failed"); + } + // interceptor released on exception: 1 -> 0, cleanup fires! + assertThat(cleanupCount.get()).isEqualTo(1); + + // Reset for the second part + cleanupCount.set(0); + RouteData routeData2 = new RouteData( + RouteMatch.withPathExactOnly("/svc/method2"), + null, ImmutableList.of()); + RefCountedRoute route2 = new RefCountedRoute( + routeData2, + ImmutableList.of(cleanupCount::incrementAndGet), syncContext); + route2.retain(); // explicitly retain + RefCountedRouteInterceptor interceptor2 = + new RefCountedRouteInterceptor(route2); + try { + interceptor2.interceptCall( + TestMethodDescriptors.voidMethod(), CallOptions.DEFAULT, + mockChannel); + assertWithMessage("Expected RuntimeException").fail(); + } catch (RuntimeException e) { + assertThat(e).hasMessageThat().isEqualTo("newCall failed"); + } + // The interceptor should have released after the exception (2 -> 1). + // Control-plane ref is still held (refCount=1), so cleanup should NOT have fired yet. + assertThat(cleanupCount.get()).isEqualTo(0); + // Now release the control-plane ref (1 -> 0) + route2.release(); + assertThat(cleanupCount.get()).isEqualTo(1); + } + + @Test + public void + refCountedRouteInterceptor_startException_releasesReliably() { + AtomicLong cleanupCount = new AtomicLong(0); + RouteData routeData = new RouteData( + RouteMatch.withPathExactOnly("/svc/method"), + null, ImmutableList.of()); + RefCountedRoute route = new RefCountedRoute( + routeData, + ImmutableList.of(cleanupCount::incrementAndGet), syncContext); + route.retain(); // explicitly retain + RefCountedRouteInterceptor interceptor = + new RefCountedRouteInterceptor(route); + Channel mockChannel = mock(Channel.class); + @SuppressWarnings("unchecked") + ClientCall mockCall = mock(ClientCall.class); + org.mockito.Mockito.doReturn(mockCall) + .when(mockChannel).newCall(any(), any()); + // Make start() throw + org.mockito.stubbing.Answer throwOnStart = invocation -> { + throw new RuntimeException("start failed"); + }; + org.mockito.Mockito.doAnswer(throwOnStart) + .when(mockCall).start(any(), any()); + ClientCall wrappedCall = interceptor.interceptCall( + TestMethodDescriptors.voidMethod(), CallOptions.DEFAULT, + mockChannel); + // Release control-plane ref: 2 -> 1 + route.release(); + assertThat(cleanupCount.get()).isEqualTo(0); + try { + wrappedCall.start( + new NoopClientCallListener<>(), new Metadata()); + assertWithMessage("Expected RuntimeException").fail(); + } catch (RuntimeException e) { + assertThat(e).hasMessageThat().isEqualTo("start failed"); + } + // The interceptor should release on start exception: 1 -> 0 + assertThat(cleanupCount.get()).isEqualTo(1); + } } +