diff --git a/frontend/csi/grpc.go b/frontend/csi/grpc.go index 7d231cfc9..450a96e1d 100644 --- a/frontend/csi/grpc.go +++ b/frontend/csi/grpc.go @@ -29,13 +29,17 @@ type NonBlockingGRPCServer interface { Stop() } -func NewNonBlockingGRPCServer() NonBlockingGRPCServer { - return &nonBlockingGRPCServer{} +// NewNonBlockingGRPCServer creates a gRPC server with optional extra interceptors inserted between the timeout +// and metrics interceptors. For CSINode and CSIAllInOne roles the caller should pass nodeRegistrationInterceptor; +// for CSIController it should be omitted so the interceptor is never in the chain. +func NewNonBlockingGRPCServer(extraInterceptors ...grpc.UnaryServerInterceptor) NonBlockingGRPCServer { + return &nonBlockingGRPCServer{extraInterceptors: extraInterceptors} } // NonBlocking server type nonBlockingGRPCServer struct { - server *grpc.Server + server *grpc.Server + extraInterceptors []grpc.UnaryServerInterceptor } func (s *nonBlockingGRPCServer) Start( @@ -45,11 +49,28 @@ func (s *nonBlockingGRPCServer) Start( } func (s *nonBlockingGRPCServer) GracefulStop() { - s.server.GracefulStop() + if s.server != nil { + s.server.GracefulStop() + } } func (s *nonBlockingGRPCServer) Stop() { - s.server.Stop() + if s.server != nil { + s.server.Stop() + } +} + +// buildInterceptorChain constructs the gRPC unary interceptor chain. The log and timeout interceptors are always +// present. Any extra interceptors (e.g. nodeRegistrationInterceptor for Node/AllInOne roles) are inserted between +// the timeout and metrics interceptors so they only run for the roles that need them. +func (s *nonBlockingGRPCServer) buildInterceptorChain() []grpc.UnaryServerInterceptor { + chain := []grpc.UnaryServerInterceptor{ + logGRPCInterceptor, + timeoutInterceptor, + } + chain = append(chain, s.extraInterceptors...) 
+ chain = append(chain, incomingRequestMetricsInterceptor) + return chain } func (s *nonBlockingGRPCServer) serve( @@ -107,7 +128,7 @@ func (s *nonBlockingGRPCServer) serve( // The first interceptor is always the outermost. // When CSI calls come in, the outermost interceptor is hit first. // The log gRPC and timeout interceptors should always be the first in the chain. - grpc.ChainUnaryInterceptor(logGRPCInterceptor, timeoutInterceptor, incomingRequestMetricsInterceptor), + grpc.ChainUnaryInterceptor(s.buildInterceptorChain()...), } server := grpc.NewServer(opts...) s.server = server diff --git a/frontend/csi/interceptor.go b/frontend/csi/interceptor.go index e574b2560..24b8d58d0 100644 --- a/frontend/csi/interceptor.go +++ b/frontend/csi/interceptor.go @@ -10,6 +10,8 @@ import ( "github.com/container-storage-interface/spec/lib/go/csi" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" . "github.com/netapp/trident/logging" ) @@ -60,6 +62,20 @@ var operationRegistry = map[string]operationMeta{ csi.GroupController_DeleteVolumeGroupSnapshot_FullMethodName: {WorkflowGroupSnapshotDelete, ContextRequestClientCSISnapshotter, http.MethodDelete}, } +var nodeRegistrationAllowedMethods = map[string]struct{}{ + csi.Node_NodeGetCapabilities_FullMethodName: {}, + csi.Node_NodeGetInfo_FullMethodName: {}, +} + +func nodeMethodAllowedBeforeRegistration(fullMethod string) bool { + _, ok := nodeRegistrationAllowedMethods[fullMethod] + return ok +} + +func isNodeMethod(fullMethod string) bool { + return len(fullMethod) >= len("/csi.v1.Node/") && fullMethod[:len("/csi.v1.Node/")] == "/csi.v1.Node/" +} + func incomingRequestMetricsInterceptor( ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler, ) (resp any, err error) { @@ -116,6 +132,31 @@ func timeoutInterceptor( return handler(ctx, req) } +// nodeRegistrationInterceptor rejects Node RPCs until node registration with the controller completes, except for +// the 
small allow-list required during startup probing. This keeps the gRPC socket available early for +// node-driver-registrar while remaining safe by default as new Node RPCs are added. +// +// This interceptor is only added to the gRPC chain for CSINode and CSIAllInOne roles; +// CSIController never includes it, so there is no role check here. +func nodeRegistrationInterceptor( + ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler, +) (any, error) { + plugin, ok := info.Server.(*Plugin) + if !ok { + return handler(ctx, req) + } + + if plugin.IsReady() || !isNodeMethod(info.FullMethod) || nodeMethodAllowedBeforeRegistration(info.FullMethod) { + return handler(ctx, req) + } + + Logc(ctx).WithFields(LogFields{ + "method": info.FullMethod, + "node": plugin.nodeName, + }).Debug("Rejecting node RPC before node registration completes.") + return nil, status.Error(codes.Unavailable, "node registration with controller is still in progress") +} + // logGRPCInterceptor sets the base context, logs and audit logs all incoming gRPC requests. // It should always be the first interceptor in the chain. // All gRPCs, regardless of timeout, should always be logged. diff --git a/frontend/csi/interceptor_test.go b/frontend/csi/interceptor_test.go index 8ff6a7ce3..1dd52e8dc 100644 --- a/frontend/csi/interceptor_test.go +++ b/frontend/csi/interceptor_test.go @@ -12,6 +12,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" . "github.com/netapp/trident/logging" ) @@ -28,6 +30,13 @@ func serverInfo(server interface{}, fullMethod string) *grpc.UnaryServerInfo { return &grpc.UnaryServerInfo{Server: server, FullMethod: fullMethod} } +// closedCh returns a pre-closed channel (plugin is ready). 
+func closedCh() chan struct{} { + ch := make(chan struct{}) + close(ch) + return ch +} + func TestOperationRegistry_ContainsAllControllerMethods(t *testing.T) { controllerMethods := []string{ csi.Controller_CreateVolume_FullMethodName, @@ -277,6 +286,164 @@ func TestTimeoutInterceptor_PreservesExistingDeadline(t *testing.T) { assert.Equal(t, existingDeadline, deadline) } +func TestNodeRegistrationInterceptor_NodeDataPathBlockedUntilReady(t *testing.T) { + plugin := &Plugin{role: CSINode, nodeName: "node-a", nodeReadyCh: make(chan struct{})} + called := false + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + called = true + return "ok", nil + } + info := serverInfo(plugin, csi.Node_NodeStageVolume_FullMethodName) + + resp, err := nodeRegistrationInterceptor(context.Background(), nil, info, handler) + + assert.Nil(t, resp) + assert.False(t, called, "handler should not be invoked before node registration completes") + require.Error(t, err) + assert.Equal(t, codes.Unavailable, status.Code(err)) +} + +func TestNodeRegistrationInterceptor_NodeInfoAllowedBeforeReady(t *testing.T) { + plugin := &Plugin{role: CSINode, nodeName: "node-a", nodeReadyCh: make(chan struct{})} + called := false + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + called = true + return "ok", nil + } + info := serverInfo(plugin, csi.Node_NodeGetInfo_FullMethodName) + + resp, err := nodeRegistrationInterceptor(context.Background(), nil, info, handler) + + require.NoError(t, err) + assert.Equal(t, "ok", resp) + assert.True(t, called, "handler should be invoked for non-data-path node methods") +} + +func TestNodeRegistrationInterceptor_NodeCapabilitiesAllowedBeforeReady(t *testing.T) { + plugin := &Plugin{role: CSINode, nodeName: "node-a", nodeReadyCh: make(chan struct{})} + called := false + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + called = true + return "ok", nil + } + info := serverInfo(plugin, 
csi.Node_NodeGetCapabilities_FullMethodName) + + resp, err := nodeRegistrationInterceptor(context.Background(), nil, info, handler) + + require.NoError(t, err) + assert.NotNil(t, resp) + assert.True(t, called, "handler should be invoked for startup-safe node methods") +} + +func TestNodeRegistrationInterceptor_IdentityMethodsAllowedBeforeReady(t *testing.T) { + identityMethods := []string{ + csi.Identity_Probe_FullMethodName, + csi.Identity_GetPluginInfo_FullMethodName, + csi.Identity_GetPluginCapabilities_FullMethodName, + } + + for _, method := range identityMethods { + t.Run(method, func(t *testing.T) { + plugin := &Plugin{role: CSINode, nodeName: "node-a", nodeReadyCh: make(chan struct{})} + called := false + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + called = true + return "ok", nil + } + info := serverInfo(plugin, method) + + resp, err := nodeRegistrationInterceptor(context.Background(), nil, info, handler) + + require.NoError(t, err) + assert.Equal(t, "ok", resp) + assert.True(t, called, "identity methods must remain available while registration is in progress") + }) + } +} + +func TestNodeRegistrationInterceptor_AllInOneControllerMethodAllowedBeforeReady(t *testing.T) { + plugin := &Plugin{role: CSIAllInOne, nodeName: "node-a", nodeReadyCh: make(chan struct{})} + called := false + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + called = true + return "ok", nil + } + info := serverInfo(plugin, csi.Controller_CreateVolume_FullMethodName) + + resp, err := nodeRegistrationInterceptor(context.Background(), nil, info, handler) + + require.NoError(t, err) + assert.Equal(t, "ok", resp) + assert.True(t, called, "controller methods in all-in-one mode should not be gated by node registration") +} + +func TestNodeRegistrationInterceptor_NodeUnstageBlockedUntilReady(t *testing.T) { + plugin := &Plugin{role: CSINode, nodeName: "node-a", nodeReadyCh: make(chan struct{})} + called := false + handler := 
func(ctx context.Context, req interface{}) (interface{}, error) { + called = true + return "ok", nil + } + info := serverInfo(plugin, csi.Node_NodeUnstageVolume_FullMethodName) + + resp, err := nodeRegistrationInterceptor(context.Background(), nil, info, handler) + + assert.Nil(t, resp) + assert.False(t, called, "handler should not be invoked before node registration completes") + require.Error(t, err) + assert.Equal(t, codes.Unavailable, status.Code(err)) +} + +func TestNodeRegistrationInterceptor_NodeUnpublishBlockedUntilReady(t *testing.T) { + plugin := &Plugin{role: CSINode, nodeName: "node-a", nodeReadyCh: make(chan struct{})} + called := false + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + called = true + return "ok", nil + } + info := serverInfo(plugin, csi.Node_NodeUnpublishVolume_FullMethodName) + + resp, err := nodeRegistrationInterceptor(context.Background(), nil, info, handler) + + assert.Nil(t, resp) + assert.False(t, called, "handler should not be invoked before node registration completes") + require.Error(t, err) + assert.Equal(t, codes.Unavailable, status.Code(err)) +} + +func TestNodeRegistrationInterceptor_UnknownNodeMethodBlockedUntilReady(t *testing.T) { + plugin := &Plugin{role: CSINode, nodeName: "node-a", nodeReadyCh: make(chan struct{})} + called := false + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + called = true + return "ok", nil + } + info := serverInfo(plugin, "/csi.v1.Node/NodeFutureMethod") + + resp, err := nodeRegistrationInterceptor(context.Background(), nil, info, handler) + + assert.Nil(t, resp) + assert.False(t, called, "handler should not be invoked for unknown node methods before registration completes") + require.Error(t, err) + assert.Equal(t, codes.Unavailable, status.Code(err)) +} + +func TestNodeRegistrationInterceptor_NodeDataPathAllowedWhenReady(t *testing.T) { + plugin := &Plugin{role: CSINode, nodeName: "node-a", nodeReadyCh: closedCh()} + called := 
false + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + called = true + return "ok", nil + } + info := serverInfo(plugin, csi.Node_NodePublishVolume_FullMethodName) + + resp, err := nodeRegistrationInterceptor(context.Background(), nil, info, handler) + + require.NoError(t, err) + assert.Equal(t, "ok", resp) + assert.True(t, called, "handler should be invoked after node registration completes") +} + func initAuditForTest(t *testing.T) { t.Helper() InitAuditLogger(true) @@ -332,15 +499,17 @@ func TestLogGRPCInterceptor_PropagatesError(t *testing.T) { } func TestChainOrder_TimeoutContextVisibleToMetricsInterceptor(t *testing.T) { - plugin := &Plugin{role: CSINode} + plugin := &Plugin{role: CSINode, nodeReadyCh: closedCh()} var handlerCtx context.Context handler := stubHandler(&handlerCtx, nil, nil) info := serverInfo(plugin, csi.Node_NodeStageVolume_FullMethodName) - // Simulate the chain: timeout -> metrics -> handler + // Simulate the chain: timeout -> registration -> metrics -> handler chained := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, h grpc.UnaryHandler) (interface{}, error) { return timeoutInterceptor(ctx, req, info, func(ctx context.Context, req interface{}) (interface{}, error) { - return incomingRequestMetricsInterceptor(ctx, req, info, h) + return nodeRegistrationInterceptor(ctx, req, info, func(ctx context.Context, req interface{}) (interface{}, error) { + return incomingRequestMetricsInterceptor(ctx, req, info, h) + }) }) } @@ -352,3 +521,33 @@ func TestChainOrder_TimeoutContextVisibleToMetricsInterceptor(t *testing.T) { assert.True(t, hasDeadline, "timeout interceptor's deadline should be visible through the chain") assert.Equal(t, WorkflowNodeStage, handlerCtx.Value(ContextKeyWorkflow)) } + +func TestChainOrder_BlockedNodeCallBypassesMetricsInterceptor(t *testing.T) { + plugin := &Plugin{role: CSINode, nodeName: "node-a", nodeReadyCh: make(chan struct{})} + metricsCalled := false + 
handlerCalled := false + info := serverInfo(plugin, csi.Node_NodeStageVolume_FullMethodName) + + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + handlerCalled = true + return "ok", nil + } + + // Simulate the chain segment relevant to metrics behavior: timeout -> node registration -> metrics -> handler. + chained := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, h grpc.UnaryHandler) (interface{}, error) { + return timeoutInterceptor(ctx, req, info, func(ctx context.Context, req interface{}) (interface{}, error) { + return nodeRegistrationInterceptor(ctx, req, info, func(ctx context.Context, req interface{}) (interface{}, error) { + metricsCalled = true + return incomingRequestMetricsInterceptor(ctx, req, info, h) + }) + }) + } + + resp, err := chained(context.Background(), nil, info, handler) + + assert.Nil(t, resp) + require.Error(t, err) + assert.Equal(t, codes.Unavailable, status.Code(err)) + assert.False(t, metricsCalled, "blocked node calls are rejected before metrics interceptor") + assert.False(t, handlerCalled, "blocked node calls should not reach handler") +} diff --git a/frontend/csi/node_server.go b/frontend/csi/node_server.go index 844ba8a26..80dd9585a 100644 --- a/frontend/csi/node_server.go +++ b/frontend/csi/node_server.go @@ -982,7 +982,7 @@ func (p *Plugin) nodeRegisterWithController(ctx context.Context, timeout time.Du Logc(ctx).WithField("node", p.nodeName).Debug("Topology labels found for node: ", topologyLabels) } - p.nodeIsRegistered = true + p.markNodeReady() } func (p *Plugin) nodeStageNFSVolume( diff --git a/frontend/csi/node_server_test.go b/frontend/csi/node_server_test.go index fd13b2944..2a3ca9039 100644 --- a/frontend/csi/node_server_test.go +++ b/frontend/csi/node_server_test.go @@ -1980,6 +1980,7 @@ func TestNodeRegisterWithController_Success(t *testing.T) { orchestrator: mockOrchestrator, osutils: osutils.New(), iscsi: iscsiClient, + nodeReadyCh: make(chan struct{}), } // Create 
a fake node response to be returned by controller @@ -2001,7 +2002,7 @@ func TestNodeRegisterWithController_Success(t *testing.T) { nodeServer.nodeRegisterWithController(ctx, 1*time.Second) // assert node is registered - assert.True(t, nodeServer.nodeIsRegistered, "expected node to be registered, but it is not") + assert.True(t, nodeServer.IsReady(), "expected node to be registered, but it is not") } func TestNodeRegisterWithController_TopologyLabels(t *testing.T) { @@ -2025,6 +2026,7 @@ func TestNodeRegisterWithController_TopologyLabels(t *testing.T) { orchestrator: mockOrchestrator, osutils: osutils.New(), iscsi: iscsiClient, + nodeReadyCh: make(chan struct{}), } // Create set of cases with varying topology labels @@ -2084,7 +2086,7 @@ func TestNodeRegisterWithController_TopologyLabels(t *testing.T) { nodeServer.nodeRegisterWithController(ctx, 1*time.Second) // assert node is registered and topology in use is as expected - assert.True(t, nodeServer.nodeIsRegistered, "expected node to be registered, but it is not") + assert.True(t, nodeServer.IsReady(), "expected node to be registered, but it is not") assert.Equal(t, data.expected, nodeServer.topologyInUse, "topologyInUse not as expected") }) } @@ -2111,6 +2113,7 @@ func TestNodeRegisterWithController_Failure(t *testing.T) { orchestrator: mockOrchestrator, iscsi: iscsiClient, osutils: osutils.New(), + nodeReadyCh: make(chan struct{}), } // Create a fake node response to be returned by controller @@ -2126,7 +2129,7 @@ func TestNodeRegisterWithController_Failure(t *testing.T) { nodeServer.nodeRegisterWithController(ctx, 1*time.Second) - assert.True(t, nodeServer.nodeIsRegistered, "expected node to be registered, but it is not") + assert.True(t, nodeServer.IsReady(), "expected node to be registered, but it is not") // Case: Error setting log level mockClient.EXPECT().CreateNode(ctx, gomock.Any()).Return(fakeNodeResponse, nil) @@ -2137,7 +2140,7 @@ func TestNodeRegisterWithController_Failure(t *testing.T) { 
nodeServer.nodeRegisterWithController(ctx, 1*time.Second) - assert.True(t, nodeServer.nodeIsRegistered, "expected node to be registered, but it is not") + assert.True(t, nodeServer.IsReady(), "expected node to be registered, but it is not") // Case: Error setting log layer mockClient.EXPECT().CreateNode(ctx, gomock.Any()).Return(fakeNodeResponse, nil) @@ -2148,7 +2151,7 @@ func TestNodeRegisterWithController_Failure(t *testing.T) { nodeServer.nodeRegisterWithController(ctx, 1*time.Second) - assert.True(t, nodeServer.nodeIsRegistered, "expected node to be registered, but it is not") + assert.True(t, nodeServer.IsReady(), "expected node to be registered, but it is not") // Case: Error setting log workflow mockClient.EXPECT().CreateNode(ctx, gomock.Any()).Return(fakeNodeResponse, nil) @@ -2159,7 +2162,7 @@ func TestNodeRegisterWithController_Failure(t *testing.T) { nodeServer.nodeRegisterWithController(ctx, 1*time.Second) - assert.True(t, nodeServer.nodeIsRegistered, "expected node to be registered, but it is not") + assert.True(t, nodeServer.IsReady(), "expected node to be registered, but it is not") } func TestNodeUnstageISCSIVolume(t *testing.T) { diff --git a/frontend/csi/plugin.go b/frontend/csi/plugin.go index fba9c092e..a55f30a41 100644 --- a/frontend/csi/plugin.go +++ b/frontend/csi/plugin.go @@ -12,6 +12,7 @@ import ( "time" "github.com/container-storage-interface/spec/lib/go/csi" + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -74,7 +75,8 @@ type Plugin struct { opCache sync.Map - nodeIsRegistered bool + nodeReadyCh chan struct{} // closed once node registration completes; broadcast to all consumers + nodeReadyOnce sync.Once limiterSharedMap map[string]limiter.Limiter @@ -126,6 +128,7 @@ func NewControllerPlugin( command: execCmd.NewCommand(), osutils: osutils.New(), activatedChan: make(chan struct{}, 1), + nodeReadyCh: func() chan struct{} { ch := make(chan struct{}); close(ch); return ch }(), } var err error @@ 
-225,6 +228,7 @@ func NewNodePlugin( command: execCmd.NewCommand(), osutils: osutils.New(), activatedChan: make(chan struct{}, 1), + nodeReadyCh: make(chan struct{}), } if runtime.GOOS == "windows" { @@ -336,6 +340,7 @@ func NewAllInOnePlugin( command: execCmd.NewCommand(), osutils: osutils.New(), activatedChan: make(chan struct{}, 1), + nodeReadyCh: make(chan struct{}), } port := "34571" @@ -409,11 +414,23 @@ func (p *Plugin) WaitForActivation(ctx context.Context) error { func (p *Plugin) Activate() error { go func() { ctx := GenerateRequestContext(nil, "", ContextSourceInternal, WorkflowPluginActivate, LogLayerCSIFrontend) - p.grpc = NewNonBlockingGRPCServer() + + // Only add the node registration interceptor for roles that serve Node RPCs. + // CSIController does not need it, so it is excluded from the interceptor chain entirely. + var extraInterceptors []grpc.UnaryServerInterceptor + if p.role == CSINode || p.role == CSIAllInOne { + extraInterceptors = append(extraInterceptors, nodeRegistrationInterceptor) + } + p.grpc = NewNonBlockingGRPCServer(extraInterceptors...) fields := LogFields{"nodeName": p.nodeName, "role": p.role} Logc(ctx).WithFields(fields).Info("Activating CSI frontend.") + // Start the gRPC server immediately so node-driver-registrar can connect to /plugin/csi.sock + // within its ~30s connection deadline, even if controller registration (nodeRegisterWithController) + // is slow or retrying. TRID-19339: Decouples socket availability from controller readiness. + p.grpc.Start(p.endpoint, p, p, p, p) + if p.role == CSINode || p.role == CSIAllInOne { p.nodeRegisterWithController(ctx, 0) // Retry indefinitely @@ -449,8 +466,6 @@ func (p *Plugin) Activate() error { // Communicate the plugin is activated. 
p.activatedChan <- struct{}{} - - p.grpc.Start(p.endpoint, p, p, p, p) }() return nil } @@ -466,7 +481,10 @@ func (p *Plugin) Deactivate() error { } Logc(ctx).Info("Deactivating CSI frontend.") - p.grpc.GracefulStop() + // Safely stop gRPC server only if it was initialized by Activate(). + if p.grpc != nil { + p.grpc.GracefulStop() + } // Stop iSCSI self-healing thread p.stopISCSISelfHealingThread(ctx) @@ -574,7 +592,16 @@ func ReadAESKey(ctx context.Context, aesKeyFile string) ([]byte, error) { } func (p *Plugin) IsReady() bool { - return p.nodeIsRegistered + select { + case <-p.nodeReadyCh: + return true + default: + return false + } +} + +func (p *Plugin) markNodeReady() { + p.nodeReadyOnce.Do(func() { close(p.nodeReadyCh) }) } // startISCSISelfHealingThread starts the iSCSI self-healing thread to heal faulty sessions. diff --git a/frontend/csi/plugin_concurrent_test.go b/frontend/csi/plugin_concurrent_test.go new file mode 100644 index 000000000..61ff5b410 --- /dev/null +++ b/frontend/csi/plugin_concurrent_test.go @@ -0,0 +1,199 @@ +// Copyright 2026 NetApp, Inc. All Rights Reserved. 
+ +package csi + +import ( + "context" + "testing" + + "github.com/container-storage-interface/spec/lib/go/csi" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + tridentconfig "github.com/netapp/trident/config" +) + +func TestPlugin_NodeRegistrationGate_ReproducedWithAndWithoutConcurrentMode(t *testing.T) { + modes := []struct { + name string + concurrent bool + }{ + {name: "without_concurrency", concurrent: false}, + {name: "with_concurrency", concurrent: true}, + } + + gatedMethods := []string{ + csi.Node_NodeStageVolume_FullMethodName, + csi.Node_NodePublishVolume_FullMethodName, + csi.Node_NodeExpandVolume_FullMethodName, + csi.Node_NodeGetVolumeStats_FullMethodName, + } + + allowedNodeMethods := []string{ + csi.Node_NodeGetInfo_FullMethodName, + csi.Node_NodeGetCapabilities_FullMethodName, + } + + for _, mode := range modes { + t.Run(mode.name, func(t *testing.T) { + previousConcurrent := tridentconfig.IsConcurrent + tridentconfig.IsConcurrent = mode.concurrent + t.Cleanup(func() { + tridentconfig.IsConcurrent = previousConcurrent + }) + + plugin := &Plugin{ + role: CSINode, + nodeName: "test-node", + endpoint: "unix:///tmp/test.sock", + nodeReadyCh: make(chan struct{}), + } + + for _, method := range gatedMethods { + t.Run(method+"_blocked_before_registration", func(t *testing.T) { + called := false + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + called = true + return method + "-ok", nil + } + + resp, err := nodeRegistrationInterceptor(context.Background(), nil, serverInfo(plugin, method), handler) + + require.Error(t, err) + assert.Nil(t, resp) + assert.False(t, called, "handler should not run before registration completes") + assert.Equal(t, codes.Unavailable, status.Code(err)) + }) + } + + for _, method := range allowedNodeMethods { + t.Run(method+"_allowed_before_registration", func(t *testing.T) { + called := false + 
handler := func(ctx context.Context, req interface{}) (interface{}, error) { + called = true + return method + "-ok", nil + } + + resp, err := nodeRegistrationInterceptor(context.Background(), nil, serverInfo(plugin, method), handler) + + require.NoError(t, err) + assert.Equal(t, method+"-ok", resp) + assert.True(t, called, "handler should run for informational node methods") + }) + } + + plugin.markNodeReady() + + for _, method := range gatedMethods { + t.Run(method+"_allowed_after_registration", func(t *testing.T) { + called := false + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + called = true + return method + "-ok", nil + } + + resp, err := nodeRegistrationInterceptor(context.Background(), nil, serverInfo(plugin, method), handler) + + require.NoError(t, err) + assert.Equal(t, method+"-ok", resp) + assert.True(t, called, "handler should run after registration completes") + }) + } + + aioPlugin := &Plugin{role: CSIAllInOne, nodeName: "aio-node", nodeReadyCh: make(chan struct{})} + for _, method := range gatedMethods { + t.Run("all_in_one_"+method+"_blocked_before_registration", func(t *testing.T) { + called := false + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + called = true + return method + "-ok", nil + } + + resp, err := nodeRegistrationInterceptor(context.Background(), nil, serverInfo(aioPlugin, method), handler) + + require.Error(t, err) + assert.Nil(t, resp) + assert.False(t, called, "all-in-one node data-path method should also be blocked before registration") + assert.Equal(t, codes.Unavailable, status.Code(err)) + }) + } + + t.Run("controller_method_still_allowed", func(t *testing.T) { + controllerPlugin := &Plugin{role: CSIController, nodeReadyCh: make(chan struct{})} + called := false + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + called = true + return "controller-ok", nil + } + + resp, err := nodeRegistrationInterceptor( + context.Background(), nil, 
serverInfo(controllerPlugin, csi.Controller_CreateVolume_FullMethodName), handler, + ) + + require.NoError(t, err) + assert.Equal(t, "controller-ok", resp) + assert.True(t, called, "controller methods should not be gated by node registration") + }) + + t.Run("non_plugin_server_passthrough", func(t *testing.T) { + called := false + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + called = true + return "passthrough-ok", nil + } + + resp, err := nodeRegistrationInterceptor( + context.Background(), nil, serverInfo("not-a-plugin", csi.Node_NodeStageVolume_FullMethodName), handler, + ) + + require.NoError(t, err) + assert.Equal(t, "passthrough-ok", resp) + assert.True(t, called, "non-plugin servers should not be gated") + }) + }) + } +} + +// TestPlugin_NodeGetInfoAllowedDuringRegistration verifies that informational RPCs +// (like NodeGetInfo) are allowed even before registration completes. This ensures +// node-driver-registrar can still get node info for its bookkeeping. +func TestPlugin_NodeGetInfoAllowedDuringRegistration(t *testing.T) { + plugin := &Plugin{ + role: CSINode, + nodeName: "test-node", + nodeReadyCh: make(chan struct{}), // Not registered yet + } + + infoHandler := func(ctx context.Context, req interface{}) (interface{}, error) { + return "node-info-ok", nil + } + + infoInfo := serverInfo(plugin, csi.Node_NodeGetInfo_FullMethodName) + + // NodeGetInfo should be allowed even before registration + resp, err := nodeRegistrationInterceptor(context.Background(), nil, infoInfo, infoHandler) + require.NoError(t, err) + assert.Equal(t, "node-info-ok", resp) +} + +// TestPlugin_ControllerRPCsUnaffectedByNodeRegistration ensures that controller-side +// RPCs are not gated by node registration status. Only node data-path RPCs should be blocked. 
+func TestPlugin_ControllerRPCsUnaffectedByNodeRegistration(t *testing.T) { + plugin := &Plugin{ + role: CSIController, + nodeReadyCh: make(chan struct{}), // Not registered (irrelevant for controller) + } + + ctrlHandler := func(ctx context.Context, req interface{}) (interface{}, error) { + return "controller-ok", nil + } + + ctrlInfo := serverInfo(plugin, csi.Controller_CreateVolume_FullMethodName) + + // Controller RPCs should pass through unaffected + resp, err := nodeRegistrationInterceptor(context.Background(), nil, ctrlInfo, ctrlHandler) + require.NoError(t, err) + assert.Equal(t, "controller-ok", resp) +} diff --git a/frontend/csi/plugin_test.go b/frontend/csi/plugin_test.go index c28accc3a..3b36b9d9b 100644 --- a/frontend/csi/plugin_test.go +++ b/frontend/csi/plugin_test.go @@ -5,18 +5,23 @@ package csi import ( "context" "os" + "path/filepath" "strings" "testing" "time" "github.com/container-storage-interface/spec/lib/go/csi" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" + "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/status" tridentconfig "github.com/netapp/trident/config" controllerAPI "github.com/netapp/trident/frontend/csi/controller_api" + . 
"github.com/netapp/trident/logging" "github.com/netapp/trident/mocks/mock_core" mockControllerAPI "github.com/netapp/trident/mocks/mock_frontend/mock_csi/mock_controller_api" mock_controller_helpers "github.com/netapp/trident/mocks/mock_frontend/mock_csi/mock_controller_helpers" @@ -24,7 +29,9 @@ import ( "github.com/netapp/trident/mocks/mock_utils/mock_iscsi" mock_nvme "github.com/netapp/trident/mocks/mock_utils/nvme" "github.com/netapp/trident/utils/errors" + execCmd "github.com/netapp/trident/utils/exec" "github.com/netapp/trident/utils/limiter" + "github.com/netapp/trident/utils/models" "github.com/netapp/trident/utils/osutils" ) @@ -436,6 +443,8 @@ func TestPlugin_Activate(t *testing.T) { iscsi: mockISCSIClient, nvmeHandler: mockNVMeHandler, osutils: osutils.New(), + activatedChan: make(chan struct{}, 1), + nodeReadyCh: make(chan struct{}), } err := plugin.Activate() @@ -452,6 +461,278 @@ func TestPlugin_Activate(t *testing.T) { } } +func TestPlugin_Activate_StartsGRPCBeforeSlowNodeRegistration(t *testing.T) { + InitAuditLogger(true) + ctrl := gomock.NewController(t) + + mockOrchestrator := mock_core.NewMockOrchestrator(ctrl) + mockOrchestrator.EXPECT().SetLogLevel(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockOrchestrator.EXPECT().SetLoggingWorkflows(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockOrchestrator.EXPECT().SetLogLayers(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + + mockNodeHelper := mock_node_helpers.NewMockNodeHelper(ctrl) + mockNodeHelper.EXPECT().ReadTrackingInfo(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockNodeHelper.EXPECT().ListVolumeTrackingInfo(gomock.Any()).Return(nil, nil).AnyTimes() + + // Record when CreateNode is invoked so the test can prove the socket exists + // before the slow controller registration completes. This validates TRID-19339's fix: + // gRPC socket must be available before node-driver-registrar's ~30s timeout deadline. 
+ registerStarted := make(chan struct{}, 1) + mockRestClient := mockControllerAPI.NewMockTridentController(ctrl) + mockRestClient.EXPECT().CreateNode(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, _ *models.Node) (controllerAPI.CreateNodeResponse, error) { + select { + case registerStarted <- struct{}{}: + default: + } + time.Sleep(2 * time.Second) + return controllerAPI.CreateNodeResponse{}, nil + }, + ).AnyTimes() + mockRestClient.EXPECT().ListVolumePublicationsForNode(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + + mockISCSIClient := mock_iscsi.NewMockISCSI(ctrl) + mockISCSIClient.EXPECT().ISCSIActiveOnHost(gomock.Any(), gomock.Any()).Return(false, nil).AnyTimes() + + mockNVMeHandler := mock_nvme.NewMockNVMeInterface(ctrl) + mockNVMeHandler.EXPECT().NVMeActiveOnHost(gomock.Any()).Return(false, nil).AnyTimes() + + // Use a short unique socket path. macOS limits Unix socket paths to 104 bytes; + // t.TempDir() + long test names easily exceed that, so use os.MkdirTemp with the + // default temp dir (typically /tmp on Linux) for a short path. 
+ socketDir, err := os.MkdirTemp("", "csi") + require.NoError(t, err) + socketPath := filepath.Join(socketDir, "csi.sock") + t.Cleanup(func() { os.RemoveAll(socketDir) }) + + plugin := &Plugin{ + orchestrator: mockOrchestrator, + nodeHelper: mockNodeHelper, + role: CSINode, + restClient: mockRestClient, + nodeName: "test-node", + endpoint: "unix://" + socketPath, + activatedChan: make(chan struct{}, 1), + nodeReadyCh: make(chan struct{}), + iSCSISelfHealingInterval: 0, // Disable to avoid background goroutine leaks in test + nvmeSelfHealingInterval: 0, + limiterSharedMap: make(map[string]limiter.Limiter), + iscsi: mockISCSIClient, + nvmeHandler: mockNVMeHandler, + osutils: osutils.New(), + } + + t.Cleanup(func() { + if plugin.grpc != nil { + plugin.grpc.Stop() + } + ctrl.Finish() + }) + + err = plugin.Activate() + assert.NoError(t, err) + + // Wait until registration is in flight before checking for the socket. This + // makes the test prove ordering rather than just eventual startup. Use a generous + // timeout because node registration performs real host probing via osutils before + // the first CreateNode call, which can vary in CI environments. + select { + case <-registerStarted: + case <-time.After(5 * time.Second): + t.Fatal("expected node registration attempt to start") + } + + // The node-driver-registrar sidecar has a hard ~30s connection deadline. + // CreateNode sleeps 2 seconds, so the socket must appear within 1 second of CreateNode starting to + // prove that Activate() prioritizes gRPC listener startup over controller registration. 
+ deadline := time.Now().Add(1 * time.Second) + socketFound := false + for time.Now().Before(deadline) { + if _, statErr := os.Stat(socketPath); statErr == nil { + socketFound = true + break + } + time.Sleep(10 * time.Millisecond) + } + + assert.True(t, socketFound, "expected gRPC socket to be created before slow node registration finishes") +} + +// TestPlugin_Activate_ReproduceCustomerIssue_TRID19339 reproduces the exact customer scenario: +// +// Customer symptom: node-driver-registrar (v2.15.0) enters CrashLoopBackOff because +// it cannot connect to the CSI socket within its hardcoded ~30s gRPC connection deadline. +// +// Root cause (pre-fix): Activate() called nodeRegisterWithController() BEFORE grpc.Start(), +// meaning the Unix socket didn't exist until registration completed (38-70s on busy clusters). +// +// Fix: grpc.Start() is now called immediately, before nodeRegisterWithController(). The +// nodeRegistrationInterceptor gates data-path RPCs until registration finishes. +// +// This test proves: +// 1. The gRPC socket is available within 2s (well under the registrar's 30s deadline), +// even while registration is still blocked. +// 2. A real gRPC client can connect and call Identity.Probe (the registrar's first call). +// 3. Node data-path RPCs (NodeStageVolume) are rejected with Unavailable during registration. +// 4. After registration completes, the plugin reports ready (IsReady() returns true). 
+func TestPlugin_Activate_ReproduceCustomerIssue_TRID19339(t *testing.T) { + InitAuditLogger(true) + ctrl := gomock.NewController(t) + + mockOrchestrator := mock_core.NewMockOrchestrator(ctrl) + mockOrchestrator.EXPECT().SetLogLevel(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockOrchestrator.EXPECT().SetLoggingWorkflows(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockOrchestrator.EXPECT().SetLogLayers(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockOrchestrator.EXPECT().GetVersion(gomock.Any()).Return("test", nil).AnyTimes() + + mockNodeHelper := mock_node_helpers.NewMockNodeHelper(ctrl) + mockNodeHelper.EXPECT().ReadTrackingInfo(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockNodeHelper.EXPECT().ListVolumeTrackingInfo(gomock.Any()).Return(nil, nil).AnyTimes() + + // Simulate a busy cluster: registration takes 5 seconds (customer saw 38-70s). + registrationDone := make(chan struct{}) + mockRestClient := mockControllerAPI.NewMockTridentController(ctrl) + mockRestClient.EXPECT().CreateNode(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, _ *models.Node) (controllerAPI.CreateNodeResponse, error) { + time.Sleep(5 * time.Second) // Simulates slow controller registration on busy cluster + close(registrationDone) // NOTE(review): panics if CreateNode runs a second time, which AnyTimes() permits — confirm it is called once + return controllerAPI.CreateNodeResponse{}, nil + }, + ).AnyTimes() + mockRestClient.EXPECT().ListVolumePublicationsForNode(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + + mockISCSIClient := mock_iscsi.NewMockISCSI(ctrl) + mockISCSIClient.EXPECT().ISCSIActiveOnHost(gomock.Any(), gomock.Any()).Return(false, nil).AnyTimes() + + mockNVMeHandler := mock_nvme.NewMockNVMeInterface(ctrl) + mockNVMeHandler.EXPECT().NVMeActiveOnHost(gomock.Any()).Return(false, nil).AnyTimes() + + // Use a short socket path to avoid macOS 104-byte Unix socket path limit. 
+ socketDir, err := os.MkdirTemp("", "csi") + require.NoError(t, err) + socketPath := filepath.Join(socketDir, "csi.sock") + t.Cleanup(func() { os.RemoveAll(socketDir) }) + + plugin := &Plugin{ + orchestrator: mockOrchestrator, + nodeHelper: mockNodeHelper, + role: CSINode, + restClient: mockRestClient, + nodeName: "customer-node", + endpoint: "unix://" + socketPath, + activatedChan: make(chan struct{}, 1), + nodeReadyCh: make(chan struct{}), + iSCSISelfHealingInterval: 0, // Disable to avoid background goroutine leaks in test + nvmeSelfHealingInterval: 0, + limiterSharedMap: make(map[string]limiter.Limiter), + iscsi: mockISCSIClient, + nvmeHandler: mockNVMeHandler, + osutils: osutils.New(), + } + + t.Cleanup(func() { + if plugin.grpc != nil { + plugin.grpc.Stop() + } + ctrl.Finish() + }) + + err = plugin.Activate() + assert.NoError(t, err) + + // --- PHASE 1: Prove socket available fast (customer's core issue) --- + // node-driver-registrar has a ~30s deadline. The socket must appear well before that. + socketAvailableDeadline := time.After(2 * time.Second) + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + socketFound := false + for !socketFound { + select { + case <-socketAvailableDeadline: + t.Fatal("REPRODUCTION: gRPC socket not available within 2s — " + + "this is the customer's CrashLoopBackOff scenario (TRID-19339)") + case <-ticker.C: + if _, statErr := os.Stat(socketPath); statErr == nil { + socketFound = true + } + } + } + t.Log("PASS: gRPC socket available within 2s (customer's registrar deadline is ~30s)") + + // --- PHASE 2: Real gRPC client connectivity (simulates node-driver-registrar) --- + // The registrar's first action after connecting is to call Identity.Probe. 
+ conn, dialErr := grpc.NewClient( + "unix://"+socketPath, + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if dialErr != nil { + t.Fatalf("REPRODUCTION: gRPC client dial failed: %v — registrar would CrashLoop", dialErr) + } + defer conn.Close() + + identityClient := csi.NewIdentityClient(conn) + probeCtx, probeCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer probeCancel() + probeResp, probeErr := identityClient.Probe(probeCtx, &csi.ProbeRequest{}) + assert.NoError(t, probeErr, "Identity.Probe must succeed during registration — registrar depends on this") + assert.NotNil(t, probeResp, "Probe response should not be nil") + t.Log("PASS: Identity.Probe succeeds while registration is in progress") + + // --- PHASE 3: Verify data-path RPCs blocked during registration --- + nodeClient := csi.NewNodeClient(conn) + stageCtx, stageCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer stageCancel() + _, stageErr := nodeClient.NodeStageVolume(stageCtx, &csi.NodeStageVolumeRequest{ + VolumeId: "vol-test", + StagingTargetPath: "/tmp/staging", + VolumeCapability: &csi.VolumeCapability{}, + }) + assert.Error(t, stageErr, "NodeStageVolume must be rejected before registration completes") + assert.Equal(t, codes.Unavailable, status.Code(stageErr), + "blocked RPCs must return Unavailable, not a different error") + t.Log("PASS: NodeStageVolume correctly blocked with Unavailable during registration") + + // --- PHASE 4: After registration, the plugin must report ready (no data-path RPC is re-issued here) --- + select { + case <-registrationDone: + t.Log("Registration completed, verifying data-path RPCs are unblocked...") + case <-time.After(10 * time.Second): + t.Fatal("Registration did not complete within expected time") + } + + // Give a small window for markNodeReady() to execute after CreateNode returns (NOTE(review): fixed sleep is timing-sensitive; consider polling IsReady()) + time.Sleep(100 * time.Millisecond) + + 
assert.True(t, plugin.IsReady(), "Plugin must be ready after registration completes") + t.Log("PASS: 
Plugin.IsReady() returns true after registration") +} + +// TestPlugin_Deactivate_SafeWithoutActivate validates that Deactivate() can be called +// safely even if Activate() was never called or hasn't completed yet (p.grpc is nil). +// This prevents nil pointer panics in shutdown scenarios. TRID-19339 safe shutdown. +func TestPlugin_Deactivate_SafeWithoutActivate(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockOrchestrator := mock_core.NewMockOrchestrator(ctrl) + + plugin := &Plugin{ + orchestrator: mockOrchestrator, + role: CSINode, + nodeName: "test-node", + endpoint: "unix:///tmp/test.sock", + command: execCmd.NewCommand(), + osutils: osutils.New(), + activatedChan: make(chan struct{}, 1), + grpc: nil, // Simulate p.grpc not yet initialized + } + + // Deactivate should not panic even though p.grpc is nil. + err := plugin.Deactivate() + assert.NoError(t, err, "Deactivate() should not panic or error when called before Activate() initializes gRPC") +} + func TestPlugin_GetName(t *testing.T) { plugin := &Plugin{} result := plugin.GetName() @@ -739,26 +1020,26 @@ func TestReadAESKey(t *testing.T) { func TestPlugin_IsReady(t *testing.T) { testCases := []struct { - name string - nodeIsRegistered bool - expectedReady bool + name string + nodeReadyCh chan struct{} + expectedReady bool }{ { - name: "Ready - Node registered", - nodeIsRegistered: true, - expectedReady: true, + name: "Ready - Node registered", + nodeReadyCh: closedCh(), + expectedReady: true, }, { - name: "Not ready - Node not registered", - nodeIsRegistered: false, - expectedReady: false, + name: "Not ready - Node not registered", + nodeReadyCh: make(chan struct{}), + expectedReady: false, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { plugin := &Plugin{ - nodeIsRegistered: tc.nodeIsRegistered, + nodeReadyCh: tc.nodeReadyCh, } result := plugin.IsReady() diff --git a/utils/mount/mount_linux.go b/utils/mount/mount_linux.go index 1b8614b73..23a7ddaa4 
100644 --- a/utils/mount/mount_linux.go +++ b/utils/mount/mount_linux.go @@ -198,7 +198,7 @@ func (client *LinuxClient) MountNFSPath(ctx context.Context, exportPath, mountpo } // Create the mount point dir if necessary - if _, err := client.command.Execute(ctx, "mkdir", "-p", mountpoint); err != nil { + if err := client.os.MkdirAll(mountpoint, 0o755); err != nil { Logc(ctx).WithField("error", err).Warning("Mkdir failed.") }