Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions frontend/csi/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,17 @@ type NonBlockingGRPCServer interface {
Stop()
}

func NewNonBlockingGRPCServer() NonBlockingGRPCServer {
return &nonBlockingGRPCServer{}
// NewNonBlockingGRPCServer creates a gRPC server with optional extra interceptors inserted between the timeout
// and metrics interceptors. For CSINode and CSIAllInOne roles the caller should pass nodeRegistrationInterceptor;
// for CSIController it should be omitted so the interceptor is never in the chain.
func NewNonBlockingGRPCServer(extraInterceptors ...grpc.UnaryServerInterceptor) NonBlockingGRPCServer {
return &nonBlockingGRPCServer{extraInterceptors: extraInterceptors}
}

// NonBlocking server
type nonBlockingGRPCServer struct {
server *grpc.Server
server *grpc.Server
extraInterceptors []grpc.UnaryServerInterceptor
}

func (s *nonBlockingGRPCServer) Start(
Expand All @@ -45,11 +49,28 @@ func (s *nonBlockingGRPCServer) Start(
}

func (s *nonBlockingGRPCServer) GracefulStop() {
s.server.GracefulStop()
if s.server != nil {
s.server.GracefulStop()
}
}

func (s *nonBlockingGRPCServer) Stop() {
s.server.Stop()
if s.server != nil {
s.server.Stop()
}
}

// buildInterceptorChain constructs the gRPC unary interceptor chain. The log and timeout interceptors are always
// present. Any extra interceptors (e.g. nodeRegistrationInterceptor for Node/AllInOne roles) are inserted between
// the timeout and metrics interceptors so they only run for the roles that need them.
func (s *nonBlockingGRPCServer) buildInterceptorChain() []grpc.UnaryServerInterceptor {
chain := []grpc.UnaryServerInterceptor{
logGRPCInterceptor,
timeoutInterceptor,
}
chain = append(chain, s.extraInterceptors...)
chain = append(chain, incomingRequestMetricsInterceptor)
return chain
}

func (s *nonBlockingGRPCServer) serve(
Expand Down Expand Up @@ -107,7 +128,7 @@ func (s *nonBlockingGRPCServer) serve(
// The first interceptor is always the outermost.
// When CSI calls come in, the outermost interceptor is hit first.
// The log gRPC and timeout interceptors should always be the first in the chain.
grpc.ChainUnaryInterceptor(logGRPCInterceptor, timeoutInterceptor, incomingRequestMetricsInterceptor),
grpc.ChainUnaryInterceptor(s.buildInterceptorChain()...),
}
server := grpc.NewServer(opts...)
s.server = server
Expand Down
41 changes: 41 additions & 0 deletions frontend/csi/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (

"github.com/container-storage-interface/spec/lib/go/csi"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

. "github.com/netapp/trident/logging"
)
Expand Down Expand Up @@ -60,6 +62,20 @@ var operationRegistry = map[string]operationMeta{
csi.GroupController_DeleteVolumeGroupSnapshot_FullMethodName: {WorkflowGroupSnapshotDelete, ContextRequestClientCSISnapshotter, http.MethodDelete},
}

var nodeRegistrationAllowedMethods = map[string]struct{}{
csi.Node_NodeGetCapabilities_FullMethodName: {},
csi.Node_NodeGetInfo_FullMethodName: {},
}

func nodeMethodAllowedBeforeRegistration(fullMethod string) bool {
_, ok := nodeRegistrationAllowedMethods[fullMethod]
return ok
}

func isNodeMethod(fullMethod string) bool {
return len(fullMethod) >= len("/csi.v1.Node/") && fullMethod[:len("/csi.v1.Node/")] == "/csi.v1.Node/"
}

func incomingRequestMetricsInterceptor(
ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler,
) (resp any, err error) {
Expand Down Expand Up @@ -116,6 +132,31 @@ func timeoutInterceptor(
return handler(ctx, req)
}

// nodeRegistrationInterceptor rejects Node RPCs until node registration with the controller completes, except for
// the small allow-list required during startup probing. This keeps the gRPC socket available early for
// node-driver-registrar while remaining safe by default as new Node RPCs are added.
//
// This interceptor is only added to the gRPC chain for CSINode and CSIAllInOne roles;
// CSIController never includes it, so there is no role check here.
func nodeRegistrationInterceptor(
ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler,
) (any, error) {
plugin, ok := info.Server.(*Plugin)
if !ok {
return handler(ctx, req)
}

if plugin.IsReady() || !isNodeMethod(info.FullMethod) || nodeMethodAllowedBeforeRegistration(info.FullMethod) {
return handler(ctx, req)
}

Logc(ctx).WithFields(LogFields{
"method": info.FullMethod,
"node": plugin.nodeName,
}).Debug("Rejecting node RPC before node registration completes.")
return nil, status.Error(codes.Unavailable, "node registration with controller is still in progress")
}

// logGRPCInterceptor sets the base context, logs and audit logs all incoming gRPC requests.
// It should always be the first interceptor in the chain.
// All gRPCs, regardless of timeout, should always be logged.
Expand Down
205 changes: 202 additions & 3 deletions frontend/csi/interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

. "github.com/netapp/trident/logging"
)
Expand All @@ -28,6 +30,13 @@ func serverInfo(server interface{}, fullMethod string) *grpc.UnaryServerInfo {
return &grpc.UnaryServerInfo{Server: server, FullMethod: fullMethod}
}

// closedCh returns a pre-closed channel (plugin is ready).
func closedCh() chan struct{} {
ch := make(chan struct{})
close(ch)
return ch
}

func TestOperationRegistry_ContainsAllControllerMethods(t *testing.T) {
controllerMethods := []string{
csi.Controller_CreateVolume_FullMethodName,
Expand Down Expand Up @@ -277,6 +286,164 @@ func TestTimeoutInterceptor_PreservesExistingDeadline(t *testing.T) {
assert.Equal(t, existingDeadline, deadline)
}

func TestNodeRegistrationInterceptor_NodeDataPathBlockedUntilReady(t *testing.T) {
plugin := &Plugin{role: CSINode, nodeName: "node-a", nodeReadyCh: make(chan struct{})}
called := false
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
called = true
return "ok", nil
}
info := serverInfo(plugin, csi.Node_NodeStageVolume_FullMethodName)

resp, err := nodeRegistrationInterceptor(context.Background(), nil, info, handler)

assert.Nil(t, resp)
assert.False(t, called, "handler should not be invoked before node registration completes")
require.Error(t, err)
assert.Equal(t, codes.Unavailable, status.Code(err))
}

func TestNodeRegistrationInterceptor_NodeInfoAllowedBeforeReady(t *testing.T) {
plugin := &Plugin{role: CSINode, nodeName: "node-a", nodeReadyCh: make(chan struct{})}
called := false
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
called = true
return "ok", nil
}
info := serverInfo(plugin, csi.Node_NodeGetInfo_FullMethodName)

resp, err := nodeRegistrationInterceptor(context.Background(), nil, info, handler)

require.NoError(t, err)
assert.Equal(t, "ok", resp)
assert.True(t, called, "handler should be invoked for non-data-path node methods")
}

func TestNodeRegistrationInterceptor_NodeCapabilitiesAllowedBeforeReady(t *testing.T) {
plugin := &Plugin{role: CSINode, nodeName: "node-a", nodeReadyCh: make(chan struct{})}
called := false
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
called = true
return "ok", nil
}
info := serverInfo(plugin, csi.Node_NodeGetCapabilities_FullMethodName)

resp, err := nodeRegistrationInterceptor(context.Background(), nil, info, handler)

require.NoError(t, err)
assert.NotNil(t, resp)
assert.True(t, called, "handler should be invoked for startup-safe node methods")
}

func TestNodeRegistrationInterceptor_IdentityMethodsAllowedBeforeReady(t *testing.T) {
identityMethods := []string{
csi.Identity_Probe_FullMethodName,
csi.Identity_GetPluginInfo_FullMethodName,
csi.Identity_GetPluginCapabilities_FullMethodName,
}

for _, method := range identityMethods {
t.Run(method, func(t *testing.T) {
plugin := &Plugin{role: CSINode, nodeName: "node-a", nodeReadyCh: make(chan struct{})}
called := false
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
called = true
return "ok", nil
}
info := serverInfo(plugin, method)

resp, err := nodeRegistrationInterceptor(context.Background(), nil, info, handler)

require.NoError(t, err)
assert.Equal(t, "ok", resp)
assert.True(t, called, "identity methods must remain available while registration is in progress")
})
}
}

func TestNodeRegistrationInterceptor_AllInOneControllerMethodAllowedBeforeReady(t *testing.T) {
plugin := &Plugin{role: CSIAllInOne, nodeName: "node-a", nodeReadyCh: make(chan struct{})}
called := false
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
called = true
return "ok", nil
}
info := serverInfo(plugin, csi.Controller_CreateVolume_FullMethodName)

resp, err := nodeRegistrationInterceptor(context.Background(), nil, info, handler)

require.NoError(t, err)
assert.Equal(t, "ok", resp)
assert.True(t, called, "controller methods in all-in-one mode should not be gated by node registration")
}

func TestNodeRegistrationInterceptor_NodeUnstageBlockedUntilReady(t *testing.T) {
plugin := &Plugin{role: CSINode, nodeName: "node-a", nodeReadyCh: make(chan struct{})}
called := false
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
called = true
return "ok", nil
}
info := serverInfo(plugin, csi.Node_NodeUnstageVolume_FullMethodName)

resp, err := nodeRegistrationInterceptor(context.Background(), nil, info, handler)

assert.Nil(t, resp)
assert.False(t, called, "handler should not be invoked before node registration completes")
require.Error(t, err)
assert.Equal(t, codes.Unavailable, status.Code(err))
}

func TestNodeRegistrationInterceptor_NodeUnpublishBlockedUntilReady(t *testing.T) {
plugin := &Plugin{role: CSINode, nodeName: "node-a", nodeReadyCh: make(chan struct{})}
called := false
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
called = true
return "ok", nil
}
info := serverInfo(plugin, csi.Node_NodeUnpublishVolume_FullMethodName)

resp, err := nodeRegistrationInterceptor(context.Background(), nil, info, handler)

assert.Nil(t, resp)
assert.False(t, called, "handler should not be invoked before node registration completes")
require.Error(t, err)
assert.Equal(t, codes.Unavailable, status.Code(err))
}

func TestNodeRegistrationInterceptor_UnknownNodeMethodBlockedUntilReady(t *testing.T) {
plugin := &Plugin{role: CSINode, nodeName: "node-a", nodeReadyCh: make(chan struct{})}
called := false
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
called = true
return "ok", nil
}
info := serverInfo(plugin, "/csi.v1.Node/NodeFutureMethod")

resp, err := nodeRegistrationInterceptor(context.Background(), nil, info, handler)

assert.Nil(t, resp)
assert.False(t, called, "handler should not be invoked for unknown node methods before registration completes")
require.Error(t, err)
assert.Equal(t, codes.Unavailable, status.Code(err))
}

func TestNodeRegistrationInterceptor_NodeDataPathAllowedWhenReady(t *testing.T) {
plugin := &Plugin{role: CSINode, nodeName: "node-a", nodeReadyCh: closedCh()}
called := false
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
called = true
return "ok", nil
}
info := serverInfo(plugin, csi.Node_NodePublishVolume_FullMethodName)

resp, err := nodeRegistrationInterceptor(context.Background(), nil, info, handler)

require.NoError(t, err)
assert.Equal(t, "ok", resp)
assert.True(t, called, "handler should be invoked after node registration completes")
}

func initAuditForTest(t *testing.T) {
t.Helper()
InitAuditLogger(true)
Expand Down Expand Up @@ -332,15 +499,17 @@ func TestLogGRPCInterceptor_PropagatesError(t *testing.T) {
}

func TestChainOrder_TimeoutContextVisibleToMetricsInterceptor(t *testing.T) {
plugin := &Plugin{role: CSINode}
plugin := &Plugin{role: CSINode, nodeReadyCh: closedCh()}
var handlerCtx context.Context
handler := stubHandler(&handlerCtx, nil, nil)
info := serverInfo(plugin, csi.Node_NodeStageVolume_FullMethodName)

// Simulate the chain: timeout -> metrics -> handler
// Simulate the chain: timeout -> registration -> metrics -> handler
chained := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, h grpc.UnaryHandler) (interface{}, error) {
return timeoutInterceptor(ctx, req, info, func(ctx context.Context, req interface{}) (interface{}, error) {
return incomingRequestMetricsInterceptor(ctx, req, info, h)
return nodeRegistrationInterceptor(ctx, req, info, func(ctx context.Context, req interface{}) (interface{}, error) {
return incomingRequestMetricsInterceptor(ctx, req, info, h)
})
})
}

Expand All @@ -352,3 +521,33 @@ func TestChainOrder_TimeoutContextVisibleToMetricsInterceptor(t *testing.T) {
assert.True(t, hasDeadline, "timeout interceptor's deadline should be visible through the chain")
assert.Equal(t, WorkflowNodeStage, handlerCtx.Value(ContextKeyWorkflow))
}

func TestChainOrder_BlockedNodeCallBypassesMetricsInterceptor(t *testing.T) {
plugin := &Plugin{role: CSINode, nodeName: "node-a", nodeReadyCh: make(chan struct{})}
metricsCalled := false
handlerCalled := false
info := serverInfo(plugin, csi.Node_NodeStageVolume_FullMethodName)

handler := func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCalled = true
return "ok", nil
}

// Simulate the chain segment relevant to metrics behavior: timeout -> node registration -> metrics -> handler.
chained := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, h grpc.UnaryHandler) (interface{}, error) {
return timeoutInterceptor(ctx, req, info, func(ctx context.Context, req interface{}) (interface{}, error) {
return nodeRegistrationInterceptor(ctx, req, info, func(ctx context.Context, req interface{}) (interface{}, error) {
metricsCalled = true
return incomingRequestMetricsInterceptor(ctx, req, info, h)
})
})
}

resp, err := chained(context.Background(), nil, info, handler)

assert.Nil(t, resp)
require.Error(t, err)
assert.Equal(t, codes.Unavailable, status.Code(err))
assert.False(t, metricsCalled, "blocked node calls are rejected before metrics interceptor")
assert.False(t, handlerCalled, "blocked node calls should not reach handler")
}
2 changes: 1 addition & 1 deletion frontend/csi/node_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,7 @@ func (p *Plugin) nodeRegisterWithController(ctx context.Context, timeout time.Du
Logc(ctx).WithField("node", p.nodeName).Debug("Topology labels found for node: ", topologyLabels)
}

p.nodeIsRegistered = true
p.markNodeReady()
}

func (p *Plugin) nodeStageNFSVolume(
Expand Down
Loading