
Expert Parallelism: common C API + NCCL EP backend #3127

Open
timmoon10 wants to merge 1 commit into NVIDIA:main from timmoon10:phuong/ep-2-commwindow


timmoon10 (Member) commented Jun 14, 2026

This is a resubmission of #3034, which was reverted in #3126 due to CI issues.

Implementation is from @phu0ngng and it is already approved by @ptrendx and @timmoon10.

Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
timmoon10 requested a review from phu0ngng June 14, 2026 05:40
timmoon10 requested a review from ptrendx as a code owner June 14, 2026 05:40
timmoon10 added the enhancement (New feature or request), MoE, and 2.17 labels Jun 14, 2026
greptile-apps (Bot, Contributor) commented Jun 14, 2026

Greptile Summary

This PR introduces the Expert Parallelism (EP) C API and its NCCL EP backend for Hopper+ GPUs (SM≥90), adding nvte_ep_prepare, nvte_ep_dispatch, nvte_ep_combine, and their backward counterparts as allocation-free, CUDA-graph-capturable operations backed by a statically-linked libnccl_ep.a.

  • Core backend (ep_backend.cpp/h): A process-wide singleton (EPBackend) wraps an ncclEpGroup_t, caches ncclEpHandle_t objects in a pointer-keyed LRU, and gates all NCCL EP calls behind a runtime NCCL ≥ 2.30.4 version check; a fallback_layer_cfg_ WAR for XLA buffer relocation is included but currently blocks models that mix top_k values across EP layers.
  • Build integration (setup.py, CMakeLists.txt): Adds the 3rdparty/nccl submodule, builds libnccl_ep.a for the required CUDA arch(s), and auto-skips the feature on pre-Hopper targets; RTLD_LAZY is added to the core library dlopen to preserve lazy symbol resolution under environments that set LD_BIND_NOW.
  • Tests (test_ep.cu, run_test_ep.sh): New MPI-distributed GTest suite covering dispatch/combine forward and backward passes with exact value checks and a zero-copy symmetric-memory path; CI script refactored to independent per-target builds with per-rank XML output.

Confidence Score: 3/5

The core communication path is functional for the common single-top_k case, but the global fallback_layer_cfg_ in the LRU cache will cause hard failures for any model that uses EP layers with differing top_k values — a realistic pattern in heterogeneous MoE designs.

The prepare_handle_locked() function stores a process-wide fallback_layer_cfg_ and asserts every subsequent cache-miss call must match the first layer's top_k and alignment. In a transformer with two MoE layers using different top_k values, the second layer's first nvte_ep_prepare will always throw, making the feature non-functional for heterogeneous models. The forward declaration of NCCL's internal struct in the public comm_window.h header is a latent ABI fragility, and the pre-lock read of group_config_ in dispatch() is a theoretical data race with concurrent shutdown.

transformer_engine/common/ep/ep_backend.cpp (fallback_layer_cfg_ process-wide constraint and pre-lock group_config_ read), transformer_engine/common/include/transformer_engine/comm_window.h (NCCL internal struct forward declaration)

Important Files Changed

| Filename | Overview |
|---|---|
| transformer_engine/common/ep/ep_backend.cpp | Core NCCL EP backend implementation with a singleton LRU handle cache; contains a P1 design flaw where a single process-wide fallback_layer_cfg_ causes NVTE_CHECK failures for models with multiple EP layers using different top_k values, plus a pre-lock read of group_config_ that creates a data race with concurrent shutdown. |
| transformer_engine/common/ep/ep_backend.h | Singleton EPBackend class definition; mixes std::atomic initialized_ (used lock-free in get()) with mutex-guarded state; the design is consistent for its stated access patterns. |
| transformer_engine/common/ep/ep_api.cpp | Thin C API wrappers that delegate to EPBackend; includes clean no-op / throwing stubs for non-NCCL-EP builds; straightforward and correct. |
| transformer_engine/common/include/transformer_engine/ep.h | Public C API header for Expert Parallelism; well-documented with Doxygen; the re-init-after-shutdown contract stated here is contradicted by the error message in ep_backend.cpp's initialize(). |
| transformer_engine/common/include/transformer_engine/comm_window.h | Introduces NVTECommWindow for zero-copy symmem paths; forward-declares NCCL's internal struct ncclWindow_vidmem, creating a fragile ABI coupling that could break if NCCL renames the struct in a future release. |
| setup.py | Adds NCCL EP submodule build logic and arch-gating; the stamp-file mechanism for rebuild detection is not atomic against parallel invocations, though this is a minor concern in most single-build environments. |
| transformer_engine/common/__init__.py | Adds RTLD_LAZY to the core library dlopen flags to prevent LD_BIND_NOW environments from resolving NCCL EP symbols before the version gate is checked; change is correct and intentional. |
| tests/cpp_distributed/test_ep.cu | Comprehensive MPI-distributed GTest suite covering dispatch/combine forward and backward passes with exact value checks and a zero-copy symmetric-memory path; cudaMemcpy return values are not checked in a few helper functions but this is test-only code. |
| transformer_engine/common/CMakeLists.txt | Adds NCCL EP static-link block; --whole-archive wrapping of libnccl_ep.a is correct; comment about LD_BIND_NOW behaviour is accurate but relies on the RTLD_LAZY change in __init__.py to be effective from Python. |

Sequence Diagram

sequenceDiagram
    participant User as Framework (PyTorch/JAX)
    participant API as ep_api.cpp (C API)
    participant BE as EPBackend (singleton)
    participant Cache as LRU Handle Cache
    participant NCCL as ncclEp* (libnccl_ep.a)

    User->>API: nvte_ep_initialize(ep_comm, group_cfg)
    API->>BE: EPBackend::initialize()
    BE->>NCCL: ncclGetVersion() ≥ 2.30.4 check
    BE->>NCCL: ncclEpCreateGroup(ep_group_, ep_comm, cfg)
    BE-->>User: initialized

    loop Each training step
        User->>API: nvte_ep_prepare(handle_mem, topk_idx, ...)
        BE->>Cache: prepare_handle_locked()
        Cache->>NCCL: ncclEpInitHandle() [on miss]
        BE->>NCCL: ncclEpUpdateHandle() [AllGather routing]
        User->>API: nvte_ep_dispatch(...)
        BE->>NCCL: ncclEpDispatch()
        User->>API: nvte_ep_combine(...)
        BE->>NCCL: ncclEpCombine()
        User->>API: nvte_ep_combine_bwd + nvte_ep_dispatch_bwd
        BE->>NCCL: ncclEpDispatch(bwd) + ncclEpCombine(bwd)
    end

    User->>API: nvte_ep_shutdown()
    BE->>NCCL: ncclEpGroupDestroy(ep_group_)

Comments Outside Diff (2)

  1. transformer_engine/common/ep/ep_backend.cpp, line 2009-2019 (link)

    P1 Global fallback config blocks multi-layer heterogeneous models

    fallback_layer_cfg_ is a process-wide singleton that asserts every new prepare_handle_locked() call (i.e., every first-call-per-layer cache miss) uses the same top_k and alignment as the very first call. For a model with multiple EP transformer layers that happen to use different top_k values (or different per-expert alignments), the second layer's initial nvte_ep_prepare will trip the NVTE_CHECK — even when XLA buffer relocation is not involved and the pointers are stable. The fallback-cfg check fires before the pointer-cache lookup, so it runs on every new handle_mem pointer, not just on XLA relocation events.

    Concrete failure: a model that has top_k=2 on MoE layer 0 and top_k=4 on MoE layer 1 will succeed on layer 0's first step but throw on layer 1's first step with "EP prepare top_k=4 disagrees with process-wide cached top_k=2".

  2. transformer_engine/common/ep/ep_backend.cpp, line 2110-2120 (link)

    P2 group_config_ read without mutex before lock acquisition

    group_config_.max_token_dtype is read at lines 2111 and 2116 before the std::lock_guard is acquired at line 2157. If shutdown() is called concurrently on another thread (which zeroes out the backend state under the mutex), this is an unsynchronised read of a non-atomic struct field — technically a data race under the C++ memory model. The initialized_ atomic guards the post-lock code paths, but there is no corresponding guard for group_config_ in the pre-lock section.
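One way to address the P2 comment above is to copy the config while holding the mutex and then work only on the copy. The following is a minimal sketch of that pattern, not the PR's code: `BackendSketch`, `GroupConfigSketch`, and `snapshot_config` are hypothetical names, and only the `max_token_dtype` field mentioned in the comment is modelled (as a plain `int`).

```cpp
#include <cassert>
#include <mutex>

// Hypothetical stand-in for the group config; only the field named in the
// review comment is modelled, and its real type is not known here.
struct GroupConfigSketch {
  int max_token_dtype = 0;
};

// Hypothetical stand-in for the backend singleton's state handling.
class BackendSketch {
 public:
  void initialize(const GroupConfigSketch& cfg) {
    std::lock_guard<std::mutex> lock(mutex_);
    config_ = cfg;
    initialized_ = true;
  }

  // Zeroes the state under the mutex, as the comment says shutdown() does.
  void shutdown() {
    std::lock_guard<std::mutex> lock(mutex_);
    config_ = GroupConfigSketch{};
    initialized_ = false;
  }

  // Copies the config into *out while holding the mutex. Returns false if
  // the backend is not initialized. Callers read the copy, never config_.
  bool snapshot_config(GroupConfigSketch* out) {
    std::lock_guard<std::mutex> lock(mutex_);
    if (!initialized_) return false;
    *out = config_;
    return true;
  }

 private:
  std::mutex mutex_;
  bool initialized_ = false;  // guarded by mutex_ in this sketch
  GroupConfigSketch config_;  // guarded by mutex_
};
```

A caller would take the snapshot first and validate against the copy, so no field of the shared struct is read outside the lock. The cost is one extra lock acquisition per call if the later critical section is kept separate.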

Reviews (1): Last reviewed commit: "Expert Parallelism: common C API + NCCL ..." | Re-trigger Greptile
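For the P1 comment on fallback_layer_cfg_, the difference between a process-wide check and a per-handle check can be sketched as follows. This is an illustration under stated assumptions, not the PR's implementation: `HandleConfigCache` and `LayerConfigSketch` are hypothetical names, the cache is a plain map rather than an LRU, and the XLA buffer-relocation workaround that motivates the fallback config is deliberately left out.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Hypothetical per-layer configuration; field names are illustrative.
struct LayerConfigSketch {
  int top_k;
  std::size_t alignment;
};

// Hypothetical cache keyed by the handle_mem pointer. The config check is
// scoped to one pointer, so layers with different top_k values can coexist.
class HandleConfigCache {
 public:
  // Returns true on a cache miss (entry created) and false on a hit.
  // Throws only when the same pointer is reused with a different config.
  bool prepare(const void* handle_mem, const LayerConfigSketch& cfg) {
    auto it = entries_.find(handle_mem);
    if (it == entries_.end()) {
      entries_.emplace(handle_mem, cfg);
      return true;
    }
    if (it->second.top_k != cfg.top_k ||
        it->second.alignment != cfg.alignment) {
      throw std::runtime_error(
          "EP prepare top_k=" + std::to_string(cfg.top_k) +
          " disagrees with top_k=" + std::to_string(it->second.top_k) +
          " cached for this handle");
    }
    return false;
  }

  std::size_t size() const { return entries_.size(); }

 private:
  std::unordered_map<const void*, LayerConfigSketch> entries_;
};
```

With this scoping, a layer with top_k=2 and a layer with top_k=4 each get their own entry, and only a mismatch on the same handle_mem pointer is rejected. Whether this is compatible with the relocation workaround is an open question that the sketch does not answer.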

void EPBackend::initialize(ncclComm_t ep_comm, NVTEEpGroupConfig config) {
  EPBackend& inst = instance();
  std::lock_guard<std::mutex> lock(inst.mutex_);
  NVTE_CHECK(!inst.initialized_, "EP already initialized. Call initialize only once per process.");


P2 The error message "Call initialize only once per process" contradicts the public API contract documented in ep.h ("Re-init after shutdown is allowed; double-init throws"). The check fires only on a double-init without an intervening shutdown(), yet the message implies that a second initialize is never valid. A user who hits it is steered away from the documented fix (calling shutdown() first), which makes the failure harder to diagnose.

Suggested change
- NVTE_CHECK(!inst.initialized_, "EP already initialized. Call initialize only once per process.");
+ NVTE_CHECK(!inst.initialized_,
+            "EP already initialized. Call nvte_ep_shutdown() before re-initializing.");


Comment on lines +23 to +27
struct ncclWindow_vidmem;

/*! \brief NCCL window plus byte offset for a zero-copy payload tensor. */
typedef struct {
  struct ncclWindow_vidmem* window; /*!< NCCL window, or NULL to use the raw data pointer. */


P2 Forward-declaring an NCCL internal struct creates a fragile ABI dependency

struct ncclWindow_vidmem is not part of NCCL's stable public ABI — it is an implementation detail that happens to back ncclWindow_t. If NCCL ever renames the underlying struct (e.g., in a major refactor), this forward declaration will conflict with <nccl.h> in any translation unit that includes both headers, producing a type-mismatch error. A more stable approach is to declare an incomplete opaque struct that is explicitly TE-owned (e.g., struct NVTEWindowOpaque;) and cast to/from ncclWindow_t inside ep_backend.cpp where <nccl.h> is already included.
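The opaque-struct approach suggested above could look roughly like the sketch below. To keep it self-contained, `<nccl.h>` is replaced by a fake stand-in: `fake_window_impl` and `fakeNcclWindow_t` are hypothetical placeholders for NCCL's window type, and `CommWindowSketch`, `to_opaque`, and `from_opaque` are illustrative names, not part of the PR. Only `NVTEWindowOpaque` comes from the review comment.

```cpp
#include <cassert>
#include <cstddef>

// --- Stand-in for <nccl.h> (hypothetical; the real header defines
// ncclWindow_t and its backing struct). ---
struct fake_window_impl {
  int id;
};
typedef fake_window_impl* fakeNcclWindow_t;

// --- What a public header could expose: a TE-owned incomplete type, so no
// NCCL struct name leaks into the public API. ---
struct NVTEWindowOpaque;

typedef struct {
  NVTEWindowOpaque* window;  // opaque window, or NULL to use the raw pointer
  std::size_t offset;        // byte offset (field name is an assumption)
} CommWindowSketch;

// --- What the backend translation unit could do, where the NCCL header is
// already included: convert at the boundary. ---
inline NVTEWindowOpaque* to_opaque(fakeNcclWindow_t w) {
  return reinterpret_cast<NVTEWindowOpaque*>(w);
}

inline fakeNcclWindow_t from_opaque(NVTEWindowOpaque* w) {
  return reinterpret_cast<fakeNcclWindow_t>(w);
}
```

The trade-off is that callers who already hold an NCCL window must go through a conversion helper instead of assigning it directly, in exchange for a public header that no longer depends on the name of NCCL's backing struct.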


Comment thread setup.py
Comment on lines +229 to +248
    arch_list: list[str] = []
    for t in arch_tokens:
        if t.lower() == "native":
            try:
                out = subprocess.check_output(
                    ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
                    stderr=subprocess.DEVNULL,
                ).decode()
            except (subprocess.CalledProcessError, FileNotFoundError) as e:
                raise RuntimeError(
                    "NVTE_CUDA_ARCHS=native requires nvidia-smi to resolve the host arch."
                ) from e
            for line in out.splitlines():
                cap = line.strip().replace(".", "")
                if cap.isdigit() and int(cap) >= 90 and cap not in arch_list:
                    arch_list.append(cap)
        else:
            bare = t.rstrip("af")
            if bare.isdigit() and int(bare) >= 90 and bare not in arch_list:
                arch_list.append(bare)


P2 Stamp-file increment is not atomic: parallel builds can race

prev_gencode is checked, the make build is launched, and the stamp is written in three distinct steps with no locking between them. In a distributed training environment where each rank's setup script (or a pip install in a parallel job) invokes build_nccl_ep_submodule() simultaneously, two processes can both pass the stale-stamp check, both launch make -j, and one can overwrite the stamp while the other is still building, leaving a mismatched stamp if the builds produce different outputs. A file lock (e.g., fcntl.flock around the whole check-build-stamp block) or a build-level lock would make this robust.

timmoon10 (Member, Author) commented Jun 14, 2026

Pipeline 54707935
