diff --git a/.gitmodules b/.gitmodules index 4b188d6bb1..495d8e3fe7 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,6 @@ [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass url = https://github.com/NVIDIA/cutlass.git +[submodule "3rdparty/nccl"] + path = 3rdparty/nccl + url = https://github.com/NVIDIA/nccl.git diff --git a/3rdparty/nccl b/3rdparty/nccl new file mode 160000 index 0000000000..808d2433dd --- /dev/null +++ b/3rdparty/nccl @@ -0,0 +1 @@ +Subproject commit 808d2433dda3cccc80f8172a94a6b117359e7102 diff --git a/qa/L1_cpp_distributed/test.sh b/qa/L1_cpp_distributed/test.sh index 8d767a4efb..4122163f69 100755 --- a/qa/L1_cpp_distributed/test.sh +++ b/qa/L1_cpp_distributed/test.sh @@ -2,16 +2,52 @@ # # See LICENSE for license information. -set -e +function test_fail() { + RET=1 + FAILED_CASES="$FAILED_CASES $1" + echo "Error: sub-test failed: $1" +} + +RET=0 +FAILED_CASES="" # Find TE : ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" + TE_LIB_PATH=$(pip3 show transformer-engine | grep -E "Location:|Editable project location:" | tail -n 1 | awk '{print $NF}') export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH if [[ $(nvidia-smi --list-gpus | wc -l) -ge 4 ]]; then - cd $TE_PATH/tests/cpp_distributed - cmake -GNinja -S. -Bbuild - cmake --build build - mpirun --allow-run-as-root --np 4 --oversubscribe ./build/test_comm_gemm + cd $TE_PATH/tests/cpp_distributed + configure_ok=1 + cmake -GNinja -S. -Bbuild || { test_fail "configure"; configure_ok=0; } + + # Build each suite independently so one suite's build failure does not + # mask the other's results. Skip mpirun when the binary is missing. 
+ if [[ $configure_ok -eq 1 ]]; then + comm_gemm_ok=1 + ep_ok=1 + cmake --build build --target test_comm_gemm || { test_fail "test_comm_gemm_build"; comm_gemm_ok=0; } + cmake --build build --target test_ep || { test_fail "test_ep_build"; ep_ok=0; } + + if [[ $comm_gemm_ok -eq 1 ]]; then + # Per-rank XML to avoid a write race on a shared path. + mpirun --allow-run-as-root --np 4 --oversubscribe bash -c \ + "exec ./build/test_comm_gemm --gtest_output=xml:$XML_LOG_DIR/cpp_distributed_test_comm_gemm.rank\${OMPI_COMM_WORLD_RANK}.xml" \ + || test_fail "test_comm_gemm" + fi + + if [[ $ep_ok -eq 1 ]]; then + # EP suites; runner self-skips on pre-Hopper GPUs. + GTEST_XML_PREFIX="$XML_LOG_DIR/cpp_distributed_test_ep" \ + bash ./run_test_ep.sh 4 ./build || test_fail "test_ep" + fi + fi +fi + +if [ "$RET" -ne 0 ]; then + echo "FAILED sub-tests:$FAILED_CASES" fi +exit $RET diff --git a/setup.py b/setup.py index 7f6b51c148..64ed120268 100644 --- a/setup.py +++ b/setup.py @@ -84,6 +84,31 @@ def setup_common_extension() -> CMakeExtension: cusolvermp_dir = os.getenv("CUSOLVERMP_HOME", "/usr") cmake_flags.append(f"-DCUSOLVERMP_DIR={cusolvermp_dir}") + # NCCL EP (Hopper+): on by default; auto-skipped when no arch >= 90 is + # targeted. Set NVTE_WITH_NCCL_EP=0 to force off. + nccl_ep_env = os.getenv("NVTE_WITH_NCCL_EP") + nccl_ep_explicit = nccl_ep_env is not None + build_with_nccl_ep = bool(int(nccl_ep_env if nccl_ep_explicit else "1")) + if build_with_nccl_ep: + arch_tokens = [a.strip() for a in str(archs or "").split(";") if a.strip()] + has_hopper_or_newer = any( + t.lower() == "native" or (t.rstrip("af").isdigit() and int(t.rstrip("af")) >= 90) + for t in arch_tokens + ) + if not has_hopper_or_newer: + if nccl_ep_explicit: + raise RuntimeError( + f"NVTE_WITH_NCCL_EP=1 was set but NVTE_CUDA_ARCHS ('{archs}') " + "contains no arch >= 90. NCCL EP requires Hopper or newer." 
+ ) + print(f"[NCCL EP] No arch >= 90 in NVTE_CUDA_ARCHS ('{archs}'); skipping build.") + build_with_nccl_ep = False + if build_with_nccl_ep: + nccl_home = build_nccl_ep_submodule() + cmake_flags.append(f"-DNCCL_INCLUDE_DIR={nccl_home}/include") + else: + cmake_flags.append("-DNVTE_WITH_NCCL_EP=OFF") + # Add custom CMake arguments from environment variable nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS") if nvte_cmake_extra_args: @@ -130,6 +155,138 @@ def setup_requirements() -> Tuple[List[str], List[str]]: return [remove_dups(reqs) for reqs in [install_reqs, test_reqs]] +def _discover_nccl_home() -> str: + """Resolve NCCL_HOME: honor env var, else probe well-known prefixes, else ldconfig.""" + env_home = os.environ.get("NCCL_HOME") + if env_home: + if (Path(env_home) / "include" / "nccl.h").exists(): + return env_home + print( + f"[NCCL EP] WARNING: NCCL_HOME='{env_home}' is set but " + f"'{env_home}/include/nccl.h' was not found; falling back to system probes." + ) + + lib_names = ("libnccl.so", "libnccl.so.2") + # Include Debian/Ubuntu multiarch subdirs (e.g. lib/aarch64-linux-gnu). + lib_subdirs = ("lib", "lib64", "lib/aarch64-linux-gnu", "lib/x86_64-linux-gnu") + + # pip-installed NCCL (nvidia-nccl-cu* wheel) lives under nvidia/nccl in + # site-packages and has no top-level include/lib layout. 
+ try: + import importlib.util + + spec = importlib.util.find_spec("nvidia.nccl") + if spec is not None and spec.submodule_search_locations: + pip_root = Path(next(iter(spec.submodule_search_locations))) + if (pip_root / "include" / "nccl.h").exists() and any( + (pip_root / sub / name).exists() for sub in lib_subdirs for name in lib_names + ): + return str(pip_root) + except (ImportError, ValueError): + pass + + for cand in ("/opt/nvidia/nccl", "/usr/local/nccl", "/usr"): + p = Path(cand) + if (p / "include" / "nccl.h").exists() and any( + (p / sub / name).exists() for sub in lib_subdirs for name in lib_names + ): + return str(p) + + try: + out = subprocess.check_output(["ldconfig", "-p"], stderr=subprocess.DEVNULL).decode() + for line in out.splitlines(): + if "libnccl.so" in line and "=>" in line: + lib_path = Path(line.split("=>")[-1].strip()) + # Walk upward so multiarch layouts (.../lib//libnccl.so) + # resolve to the prefix that contains include/nccl.h. + for root in (lib_path.parent.parent, lib_path.parent.parent.parent): + if (root / "include" / "nccl.h").exists(): + return str(root) + except (subprocess.CalledProcessError, FileNotFoundError): + pass + + raise RuntimeError( + "Could not locate NCCL core (nccl.h + libnccl.so). Set NCCL_HOME to the install prefix." + ) + + +def build_nccl_ep_submodule() -> str: + """Build libnccl_ep.a from the 3rdparty/nccl submodule and return NCCL_HOME.""" + nccl_root = current_file_path / "3rdparty" / "nccl" + if not (nccl_root / "Makefile").exists(): + raise RuntimeError( + f"NCCL submodule not found at {nccl_root}. " + "Run `git submodule update --init --recursive`." + ) + + build_dir = nccl_root / "build" + nccl_ep_lib = build_dir / "lib" / "libnccl_ep.a" + gencode_stamp = build_dir / "lib" / "libnccl_ep.gencode" + + # Caller gates on arch >= 90 or "native"; expand "native" to the host's + # actual sm_XX so the build stamp distinguishes machines. 
+ arch_tokens = [a.strip() for a in str(cuda_archs() or "").split(";") if a.strip()] + arch_list: list[str] = [] + for t in arch_tokens: + if t.lower() == "native": + try: + out = subprocess.check_output( + ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"], + stderr=subprocess.DEVNULL, + ).decode() + except (subprocess.CalledProcessError, FileNotFoundError) as e: + raise RuntimeError( + "NVTE_CUDA_ARCHS=native requires nvidia-smi to resolve the host arch." + ) from e + for line in out.splitlines(): + cap = line.strip().replace(".", "") + if cap.isdigit() and int(cap) >= 90 and cap not in arch_list: + arch_list.append(cap) + else: + bare = t.rstrip("af") + if bare.isdigit() and int(bare) >= 90 and bare not in arch_list: + arch_list.append(bare) + if not arch_list: + raise RuntimeError( + "NCCL EP requires Hopper or newer (SM >= 90); none found in" + f" NVTE_CUDA_ARCHS={cuda_archs()!r}. Re-run with NVTE_WITH_NCCL_EP=0 to skip the NCCL" + " EP build (the rest of TE still builds)." + ) + gencode = " ".join(f"-gencode=arch=compute_{a},code=sm_{a}" for a in arch_list) + + nproc = os.cpu_count() or 8 + env = os.environ.copy() + env["NVCC_GENCODE"] = gencode + # NCCL EP needs the core NCCL headers + libnccl.so; write NCCL EP build + # outputs to the submodule's local build/ tree. 
+ nccl_home = _discover_nccl_home() + env["NCCL_HOME"] = nccl_home + env["NCCL_EP_BUILDDIR"] = str(build_dir) + + prev_gencode = gencode_stamp.read_text().strip() if gencode_stamp.exists() else None + if not nccl_ep_lib.exists() or prev_gencode != gencode: + if nccl_ep_lib.exists() and prev_gencode != gencode: + print( + f"[NCCL EP] gencode changed ('{prev_gencode}' -> '{gencode}'); " + "rebuilding libnccl_ep.a" + ) + subprocess.check_call( + ["make", "-C", "contrib/nccl_ep", "clean"], + cwd=str(nccl_root), + env=env, + ) + print(f"[NCCL EP] Building libnccl_ep.a (gencode='{gencode}')") + subprocess.check_call( + ["make", "-j", str(nproc), "-C", "contrib/nccl_ep", "lib"], + cwd=str(nccl_root), + env=env, + ) + gencode_stamp.parent.mkdir(parents=True, exist_ok=True) + gencode_stamp.write_text(gencode) + + return nccl_home + + def git_check_submodules() -> None: """ Attempt to checkout git submodules automatically during setup. diff --git a/tests/cpp_distributed/CMakeLists.txt b/tests/cpp_distributed/CMakeLists.txt index 44ad7c7384..13b6242816 100644 --- a/tests/cpp_distributed/CMakeLists.txt +++ b/tests/cpp_distributed/CMakeLists.txt @@ -55,10 +55,32 @@ target_include_directories(test_comm_gemm PRIVATE ${test_comm_gemm_INCLUDES}) find_package(CUDAToolkit REQUIRED) find_package(OpenMP REQUIRED) find_package(MPI REQUIRED) + +# -- NCCL core ---------------------------------------------------------------- +# Anchor on libnccl and derive nccl.h from the same install prefix so the +# header and library can't drift across installs. 
find_library(NCCL_LIB NAMES nccl libnccl - PATH_SUFFIXES lib + HINTS /opt/nvidia/nccl/lib /opt/nvidia/nccl/lib64 + /usr/local/nccl/lib /usr/local/nccl/lib64 + PATH_SUFFIXES lib lib64 REQUIRED) +get_filename_component(_nccl_lib_dir "${NCCL_LIB}" DIRECTORY) +set(NCCL_PREFIX "${_nccl_lib_dir}") +while(NCCL_PREFIX AND NOT EXISTS "${NCCL_PREFIX}/include/nccl.h") + get_filename_component(_nccl_parent "${NCCL_PREFIX}" DIRECTORY) + if(_nccl_parent STREQUAL NCCL_PREFIX) + break() + endif() + set(NCCL_PREFIX "${_nccl_parent}") +endwhile() +find_path(NCCL_INCLUDE_DIR nccl.h + HINTS "${NCCL_PREFIX}/include" + NO_DEFAULT_PATH) +if(NOT NCCL_INCLUDE_DIR) + message(FATAL_ERROR + "nccl.h not found under the prefix of ${NCCL_LIB}.") +endif() list(APPEND test_comm_gemm_LINKER_LIBS CUDA::cuda_driver CUDA::cudart @@ -74,3 +96,37 @@ target_compile_options(test_comm_gemm PRIVATE -O2 -fopenmp) include(GoogleTest) gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600) + +# -- EP distributed tests ------------------------------------------------------ +# Launched via mpirun; ncclUniqueId exchange uses MPI_Bcast (see test_ep_common.h). +# The test binary only uses NCCL core symbols (ncclMemAlloc, ncclCommWindow*); +# all ncclEp* calls live behind TE's public , which is +# statically linked into libtransformer_engine.so. +message(STATUS "EP test: NCCL headers: ${NCCL_INCLUDE_DIR}") +set(EP_TEST_COMMON_INCLUDES + ${NCCL_INCLUDE_DIR} + ${MPI_CXX_INCLUDE_PATH} + ../../transformer_engine/common/include + ../../transformer_engine/common + ${CMAKE_CURRENT_SOURCE_DIR}) + +# nvrtc must follow TE_LIB so symbols referenced from libtransformer_engine.so +# (loaded via dlopen in Python; not in its DT_NEEDED) resolve through nvrtc. 
+set(EP_TEST_COMMON_LIBS + CUDA::cuda_driver + CUDA::cudart + GTest::gtest + ${TE_LIB} + CUDA::nvrtc + ${NCCL_LIB} + MPI::MPI_CXX + OpenMP::OpenMP_CXX) + +# -- EP distributed tests (per-op + full pipeline + zero-copy symm) ----------- +add_executable(test_ep test_ep.cu ../cpp/test_common.cu) +target_include_directories(test_ep PRIVATE ${EP_TEST_COMMON_INCLUDES}) +target_link_libraries(test_ep PUBLIC ${EP_TEST_COMMON_LIBS}) + +# Do NOT use gtest_discover_tests - these binaries require multi-process +# launch via run_test_ep.sh, not direct single-process execution. +message(STATUS "EP distributed tests enabled (NCCL EP statically linked into libtransformer_engine.so)") diff --git a/tests/cpp_distributed/run_test_ep.sh b/tests/cpp_distributed/run_test_ep.sh new file mode 100755 index 0000000000..d486d45f8a --- /dev/null +++ b/tests/cpp_distributed/run_test_ep.sh @@ -0,0 +1,63 @@ +#!/usr/bin/env bash +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +# +# Run TE EP distributed unit tests via mpirun. Each MPI rank pins to one GPU +# (rank % device_count) and exchanges ncclUniqueId through MPI_Bcast. +# +# Usage: +# bash run_test_ep.sh [num_gpus] [build_dir] +# +# Defaults: +# num_gpus = number of GPUs visible to nvidia-smi +# build_dir = /build +# +# Environment variables: +# GTEST_FILTER - forwarded to all processes (e.g., "EPPipelineTest.*") +# GTEST_XML_PREFIX - if set, each rank writes JUnit XML to +# ${GTEST_XML_PREFIX}.rank.xml +# MPIRUN - override the mpirun binary (default: mpirun) +# MPIRUN_EXTRA - extra flags forwarded to mpirun + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BUILD_DIR="${2:-${SCRIPT_DIR}/build}" +NUM_GPUS="${1:-$(nvidia-smi -L 2>/dev/null | wc -l)}" +MPIRUN="${MPIRUN:-mpirun}" + +# Skip cleanly on pre-Hopper: NCCL EP requires SM>=90. +MIN_SM=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null \ + | awk -F. 
'NR==1 || ($1*10+$2) 0 && MIN_SM < 90 )); then + echo "NCCL EP requires SM>=90 (lowest visible GPU is SM${MIN_SM}); SKIPPING." + exit 0 +fi + +TEST_BIN="${BUILD_DIR}/test_ep" +if [[ ! -x "${TEST_BIN}" ]]; then + echo "ERROR: binary not found: ${TEST_BIN}" + echo "Build: cd ${SCRIPT_DIR} && mkdir -p build && cd build && cmake .. && make" + exit 1 +fi + +if (( NUM_GPUS < 2 )); then + echo "EP Tests: requires at least 2 GPUs, found ${NUM_GPUS}. Skipping." + exit 0 +fi + +GTEST_ARGS="${GTEST_FILTER:+--gtest_filter=${GTEST_FILTER}}" + +echo "=== EP Tests ===" +echo " GPUs: ${NUM_GPUS} Binary: ${TEST_BIN}" +echo + +if [[ -n "${GTEST_XML_PREFIX:-}" ]]; then + # bash -c so OMPI_COMM_WORLD_RANK expands per-rank, avoiding a write race + # on a single shared output path. + "${MPIRUN}" --allow-run-as-root --oversubscribe -n "${NUM_GPUS}" ${MPIRUN_EXTRA:-} bash -c \ + "exec '${TEST_BIN}' ${GTEST_ARGS} --gtest_output=xml:${GTEST_XML_PREFIX}.rank\${OMPI_COMM_WORLD_RANK}.xml" +else + "${MPIRUN}" --allow-run-as-root --oversubscribe -n "${NUM_GPUS}" ${MPIRUN_EXTRA:-} "${TEST_BIN}" ${GTEST_ARGS} +fi diff --git a/tests/cpp_distributed/test_ep.cu b/tests/cpp_distributed/test_ep.cu new file mode 100644 index 0000000000..c7fee7720c --- /dev/null +++ b/tests/cpp_distributed/test_ep.cu @@ -0,0 +1,843 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/* + * EP pipeline tests: smallest-scope first. 
+ * + * EPDispatchTest/PrepareAndDispatch : exact recv values + per-expert counts + * EPCombineTest/Combine : round-trip: out == top_k * tokens + * EPCombineBwdTest/CombineBwdCheck : exact grad_expert values + * EPDispatchBwdTest/DispatchBwdCheck : exact grad_tokens + * EPDispatchBwdGradWeightsTest/RoundTrip : exact per-(t, k) grad_topk_weights + * EPPipelineTest/FullForwardBackward : fwd + bwd NaN/Inf check + * + * Routing: token t on rank r -> expert (r * num_tokens * top_k + t * top_k + k) % num_experts + * Token values: rank r, token t -> all hidden dims = (r+1)*0.01 + t*0.001 + * + * Closed-form expected values: + * dispatch recv: multiset of source-token values routed to this rank's experts + * combine: result[t] == top_k * tokens[t] + * combine_bwd: grad_expert[slot] == d_result[t] (no weighting) + * dispatch_bwd: grad_tokens[t] == top_k * d_result[t] + */ + +#include "test_ep_common.h" + +#include +#include +#include +#include + +// -- Deterministic routing helpers --------------------------------------------- + +// Token value for (rank, t): (rank * num_tokens + t + 1) / 256. Step 1/256 is +// bf16-exact and unique across (rank, t) when rank * num_tokens + t < 256. +static inline float token_value(int rank, int t, int num_tokens) { + return static_cast(rank * num_tokens + t + 1) * (1.0f / 256.0f); +} + +// Per-element host-side conversion helpers used by templated test code. 
+inline float tok_to_float(nv_bfloat16 v) { return __bfloat162float(v); } +inline float tok_to_float(__half v) { return __half2float(v); } +inline float tok_to_float(float v) { return v; } + +template T tok_from_float(float v); +template <> inline nv_bfloat16 tok_from_float(float v) { return __float2bfloat16(v); } +template <> inline __half tok_from_float<__half> (float v) { return __float2half(v); } +template <> inline float tok_from_float (float v) { return v; } + +template +static std::vector generate_tokens(int rank, int num_tokens, int hidden_dim) { + std::vector v(num_tokens * hidden_dim); + for (int t = 0; t < num_tokens; ++t) { + T val = tok_from_float(token_value(rank, t, num_tokens)); + for (int h = 0; h < hidden_dim; ++h) + v[t * hidden_dim + h] = val; + } + return v; +} + +static std::vector expected_token_counts( + int recv_rank, int num_processes, int num_tokens, int top_k, + int num_experts, int num_local_experts) { + int base = recv_rank * num_local_experts; + std::vector cnt(num_local_experts, 0); + for (int src = 0; src < num_processes; ++src) { + auto idx = routing_balanced(src, num_tokens, top_k, num_experts, num_local_experts); + for (int t = 0; t < num_tokens; ++t) + for (int k = 0; k < top_k; ++k) { + int64_t e = idx[t * top_k + k]; + if (e >= base && e < base + num_local_experts) ++cnt[e - base]; + } + } + return cnt; +} + +template +static std::vector expected_recv_values_sorted( + int recv_rank, int num_processes, int num_tokens, int top_k, + int num_experts, int num_local_experts) { + int base = recv_rank * num_local_experts; + std::vector vals; + for (int src = 0; src < num_processes; ++src) { + auto idx = routing_balanced(src, num_tokens, top_k, num_experts, num_local_experts); + for (int t = 0; t < num_tokens; ++t) + for (int k = 0; k < top_k; ++k) { + int64_t e = idx[t * top_k + k]; + if (e >= base && e < base + num_local_experts) { + float raw = token_value(src, t, num_tokens); + vals.push_back(tok_to_float(tok_from_float(raw))); + } 
+ } + } + std::sort(vals.begin(), vals.end()); + return vals; +} + +// 2^-5 relative tolerance for BF16 (matches mantissa precision with margin), +// plus a small atol floor for near-zero expected values. +static constexpr float kBf16Rtol = 1.0f / 32.0f; +static constexpr float kBf16Atol = 1e-3f; +static float bf16_tol(float magnitude) { + return kBf16Atol + kBf16Rtol * std::fabs(magnitude); +} + +template +static bool check_no_nan_inf(const T* dev, int count, const char* name) { + std::vector h(count); + cudaMemcpy(h.data(), dev, count * sizeof(T), cudaMemcpyDeviceToHost); + for (int i = 0; i < count; ++i) { + float v = tok_to_float(h[i]); + if (std::isnan(v) || std::isinf(v)) { + fprintf(stderr, "Rank %d: %s in %s[%d]\n", + g_process_id, std::isnan(v) ? "NaN" : "Inf", name, i); + return false; + } + } + return true; +} + +// -- Forward buffer set with RAII ---------------------------------------------- + +template +struct EPBuffers { + // Forward + DevBuf topk_idx; + DevBuf topk_weights; + DevBuf tokens; + DevBuf token_counts; + DevBuf handle_mem; + DevBuf recv_tokens; + DevBuf recv_topk_weights; + DevBuf result; + // Backward + DevBuf grad_result; + DevBuf grad_expert; + DevBuf grad_tokens; + DevBuf g_recv_topk_weights; + DevBuf grad_topk_weights; + + size_t handle_mem_size = 0; + size_t recv_capacity = 0; + int top_k_ = 0; + size_t alignment_ = 0; + + void alloc(int num_tokens, int top_k, int hidden_dim, int num_local_experts, + int ep_size, int max_tokens_per_rank, size_t alignment = 0) { + top_k_ = top_k; + alignment_ = alignment; + recv_capacity = static_cast(ep_size) * max_tokens_per_rank * 2; + + topk_idx.alloc(num_tokens * top_k); + topk_weights.alloc(num_tokens * top_k); + tokens.alloc(num_tokens * hidden_dim); + token_counts.alloc(num_local_experts); + recv_tokens.alloc(recv_capacity * hidden_dim); + recv_topk_weights.alloc(recv_capacity); + result.alloc(num_tokens * hidden_dim); + + handle_mem_size = nvte_ep_handle_mem_size(NVTEEpLayerConfig{top_k, 
alignment}); + handle_mem.alloc(handle_mem_size); + + grad_result.alloc(num_tokens * hidden_dim); + grad_expert.alloc(recv_capacity * hidden_dim); + grad_tokens.alloc(num_tokens * hidden_dim); + g_recv_topk_weights.alloc(recv_capacity); + grad_topk_weights.alloc(num_tokens * top_k); + } +}; + +// Bundled NVTETensor views over an EPBuffers, with the shapes the EP C API +// expects. +template +struct EPTensors { + TensorWrapper topk_idx, topk_weights, token_counts, handle_mem, tokens; + TensorWrapper recv_tokens, recv_topk_weights, result; + TensorWrapper grad_result, grad_expert, grad_tokens; + TensorWrapper g_recv_topk_weights, grad_topk_weights; + + int top_k_ = 0; + size_t alignment_ = 0; + + EPTensors(EPBuffers& b, int num_tokens, int top_k, int hidden_dim, + int num_local_experts) { + top_k_ = top_k; + alignment_ = b.alignment_; + constexpr DType kTokDType = test::TypeInfo::dtype; + using Shape = std::vector; + topk_idx = TensorWrapper(b.topk_idx.get(), + Shape{(size_t)num_tokens, (size_t)top_k}, DType::kInt64); + topk_weights = TensorWrapper(b.topk_weights.get(), + Shape{(size_t)num_tokens, (size_t)top_k}, DType::kFloat32); + token_counts = TensorWrapper(b.token_counts.get(), + Shape{(size_t)num_local_experts}, DType::kInt32); + handle_mem = TensorWrapper(b.handle_mem.get(), + Shape{b.handle_mem_size}, DType::kByte); + tokens = TensorWrapper(b.tokens.get(), + Shape{(size_t)num_tokens, (size_t)hidden_dim}, kTokDType); + recv_tokens = TensorWrapper(b.recv_tokens.get(), + Shape{b.recv_capacity, (size_t)hidden_dim}, kTokDType); + recv_topk_weights = TensorWrapper(b.recv_topk_weights.get(), + Shape{b.recv_capacity}, DType::kFloat32); + result = TensorWrapper(b.result.get(), + Shape{(size_t)num_tokens, (size_t)hidden_dim}, kTokDType); + grad_result = TensorWrapper(b.grad_result.get(), + Shape{(size_t)num_tokens, (size_t)hidden_dim}, kTokDType); + grad_expert = TensorWrapper(b.grad_expert.get(), + Shape{b.recv_capacity, (size_t)hidden_dim}, kTokDType); + grad_tokens 
= TensorWrapper(b.grad_tokens.get(), + Shape{(size_t)num_tokens, (size_t)hidden_dim}, kTokDType); + g_recv_topk_weights = TensorWrapper(b.g_recv_topk_weights.get(), + Shape{b.recv_capacity}, DType::kFloat32); + grad_topk_weights = TensorWrapper(b.grad_topk_weights.get(), + Shape{(size_t)num_tokens, (size_t)top_k}, DType::kFloat32); + } +}; + +// -- Shared fixture base ------------------------------------------------------- + +class EpOpTestBase : public ::testing::Test { + protected: + int ep_size_, num_experts_, num_local_experts_, hidden_dim_; + int max_tokens_per_rank_, top_k_, num_tokens_; + + void SetUp() override { + if (g_sm_major < 9) + GTEST_SKIP() << "EP requires SM_90+ (device is SM_" << g_sm_major << "0)"; + ASSERT_GE(g_num_processes, 2); + ASSERT_TRUE(g_ep_initialized); + + ep_size_ = g_ep_size; + num_experts_ = g_num_experts; + num_local_experts_ = num_experts_ / ep_size_; + hidden_dim_ = g_hidden_dim; + max_tokens_per_rank_ = g_max_tokens_per_rank; + top_k_ = 2; + num_tokens_ = 32; + } + + template + void upload_inputs(EPBuffers& buf, int rank = -1) { + if (rank < 0) rank = g_process_id; + auto h_idx = routing_balanced(rank, num_tokens_, top_k_, + num_experts_, num_local_experts_); + std::vector h_w(num_tokens_ * top_k_, 1.0f / top_k_); + auto h_tok = generate_tokens(rank, num_tokens_, hidden_dim_); + + NVTE_CHECK_CUDA(cudaMemcpy(buf.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy(buf.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy(buf.tokens.get(), h_tok.data(), + h_tok.size() * sizeof(T), cudaMemcpyHostToDevice)); + } + + // NVTE_CHECK_CUDA (fprintf+exit) so this non-void helper stays legal. 
+ template + int read_total_recv(const EPBuffers& buf) const { + std::vector cnt(num_local_experts_); + NVTE_CHECK_CUDA(cudaMemcpy(cnt.data(), buf.token_counts.get(), + num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); + int total = 0; + for (int c : cnt) total += c; + return total; + } +}; + +// Pull non-dependent base members into the typed-test scope as local consts so +// the bodies can reference them unqualified. +#define EP_PULL_FIXTURE() \ + const int ep_size_ = this->ep_size_; \ + const int num_experts_ = this->num_experts_; \ + const int num_local_experts_ = this->num_local_experts_; \ + const int hidden_dim_ = this->hidden_dim_; \ + const int max_tokens_per_rank_ = this->max_tokens_per_rank_; \ + const int top_k_ = this->top_k_; \ + const int num_tokens_ = this->num_tokens_ + +// ============================================================================= +// EPDispatchTest: exact recv values and per-expert counts. +// ============================================================================= + +template class EPDispatchTest : public EpOpTestBase {}; +using EPBf16Only = ::testing::Types; +TYPED_TEST_SUITE(EPDispatchTest, EPBf16Only); + +TYPED_TEST(EPDispatchTest, PrepareAndDispatch) { + using Tok = TypeParam; + EP_PULL_FIXTURE(); + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + this->template upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + NVTE_CHECK_CUDA(cudaMemset(buf.recv_tokens.get(), 0, buf.recv_tokens.bytes())); + + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); + + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), + t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), + NVTECommWindow{}, t.recv_tokens.data(), 
NVTECommWindow{}, + t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + + // 1. Per-expert counts. + std::vector got_counts(num_local_experts_); + NVTE_CHECK_CUDA(cudaMemcpy(got_counts.data(), buf.token_counts.get(), + num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); + auto exp_counts = expected_token_counts(g_process_id, g_num_processes, num_tokens_, top_k_, + num_experts_, num_local_experts_); + int total_recv = 0; + for (int i = 0; i < num_local_experts_; ++i) { + EXPECT_EQ(got_counts[i], exp_counts[i]) << "local expert " << i; + total_recv += exp_counts[i]; + } + ASSERT_LE(total_recv, static_cast(buf.recv_capacity)) + << "total_recv exceeded recv_capacity; overflow would corrupt downstream memory"; + + // 2. Recv values: read only the filled prefix per local-expert zone, not the + // whole recv buffer; avoids false positives from legitimate-zero token values. + std::vector h_recv(buf.recv_capacity * hidden_dim_); + NVTE_CHECK_CUDA(cudaMemcpy(h_recv.data(), buf.recv_tokens.get(), + h_recv.size() * sizeof(Tok), cudaMemcpyDeviceToHost)); + + std::vector got_vals; + got_vals.reserve(total_recv); + size_t slot = 0; + for (int e = 0; e < num_local_experts_; ++e) { + for (int i = 0; i < got_counts[e]; ++i) { + got_vals.push_back(tok_to_float(h_recv[slot * hidden_dim_])); + ++slot; + } + } + std::sort(got_vals.begin(), got_vals.end()); + + auto exp_vals = expected_recv_values_sorted(g_process_id, g_num_processes, num_tokens_, + top_k_, num_experts_, num_local_experts_); + + ASSERT_EQ(got_vals.size(), exp_vals.size()); + for (size_t i = 0; i < exp_vals.size(); ++i) + EXPECT_EQ(got_vals[i], exp_vals[i]) + << "recv value mismatch at sorted index " << i; + + // 3. recv_topk_weights: every filled slot must equal the per-token weight (1/top_k). 
+ std::vector h_w(buf.recv_capacity); + NVTE_CHECK_CUDA(cudaMemcpy(h_w.data(), buf.recv_topk_weights.get(), + h_w.size() * sizeof(float), cudaMemcpyDeviceToHost)); + const float exp_w = 1.0f / static_cast(top_k_); + for (int i = 0; i < total_recv; ++i) + EXPECT_NEAR(h_w[i], exp_w, 1e-6f) << "recv_topk_weights[" << i << "]"; + + if (g_process_id == 0) + printf(" PrepareAndDispatch: passed (recv=%d, values + weights exact)\n", total_recv); + + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPCombineTest: round-trip identity expert -> result == top_k * tokens. +// ============================================================================= + +template class EPCombineTest : public EpOpTestBase {}; +TYPED_TEST_SUITE(EPCombineTest, EPBf16Only); + +TYPED_TEST(EPCombineTest, Combine) { + using Tok = TypeParam; + EP_PULL_FIXTURE(); + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + this->template upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); + + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), + t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), + NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, + t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(t.handle_mem.data(), t.recv_tokens.data(), NVTECommWindow{}, + t.result.data(), stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_result(num_tokens_ * hidden_dim_); + NVTE_CHECK_CUDA(cudaMemcpy(h_result.data(), buf.result.get(), + h_result.size() * sizeof(Tok), cudaMemcpyDeviceToHost)); + auto h_tok = 
generate_tokens(g_process_id, num_tokens_, hidden_dim_); + for (int tok = 0; tok < num_tokens_; ++tok) { + float exp = tok_to_float(h_tok[tok * hidden_dim_]) * static_cast(top_k_); + for (int p = 0; p < hidden_dim_; ++p) { + float got = tok_to_float(h_result[tok * hidden_dim_ + p]); + EXPECT_NEAR(got, exp, bf16_tol(exp)) + << "token " << tok << " rank " << g_process_id << " hidden " << p; + } + } + + if (g_process_id == 0) + printf(" Combine: passed (result == top_k * tokens)\n"); + + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPCombineBwdTest: filled slots in grad_expert == d_result (unweighted). +// ============================================================================= + +template class EPCombineBwdTest : public EpOpTestBase {}; +TYPED_TEST_SUITE(EPCombineBwdTest, EPBf16Only); + +TYPED_TEST(EPCombineBwdTest, CombineBwdCheck) { + using Tok = TypeParam; + EP_PULL_FIXTURE(); + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + this->template upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); + + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), + t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), + NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, + t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(t.handle_mem.data(), t.recv_tokens.data(), NVTECommWindow{}, + t.result.data(), stream)); + + std::vector h_grad_r(num_tokens_ * hidden_dim_, tok_from_float(0.1f)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad_r.data(), + h_grad_r.size() * sizeof(Tok), + 
cudaMemcpyHostToDevice, stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + + ASSERT_NO_THROW(nvte_ep_combine_bwd(t.handle_mem.data(), t.grad_result.data(), NVTECommWindow{}, + t.grad_expert.data(), NVTECommWindow{}, stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + + int total_recv = this->template read_total_recv(buf); + + std::vector cnt(num_local_experts_); + NVTE_CHECK_CUDA(cudaMemcpy(cnt.data(), buf.token_counts.get(), + num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); + std::vector h_ge(buf.recv_capacity * hidden_dim_); + NVTE_CHECK_CUDA(cudaMemcpy(h_ge.data(), buf.grad_expert.get(), + h_ge.size() * sizeof(Tok), cudaMemcpyDeviceToHost)); + + // Walk filled slots by per-expert zone (no v != 0 heuristic). + const float kExpGrad = tok_to_float(tok_from_float(0.1f)); + size_t slot = 0; + int filled = 0; + for (int e = 0; e < num_local_experts_; ++e) { + for (int i = 0; i < cnt[e]; ++i) { + for (int p = 0; p < hidden_dim_; ++p) { + float v = tok_to_float(h_ge[slot * hidden_dim_ + p]); + EXPECT_NEAR(v, kExpGrad, bf16_tol(kExpGrad)) + << "grad_expert expert " << e << " slot " << i + << " (linear " << slot << ") hidden " << p; + } + ++filled; ++slot; + } + } + EXPECT_EQ(filled, total_recv); + + if (g_process_id == 0) + printf(" CombineBwdCheck: passed (filled=%d)\n", filled); + + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPDispatchBwdTest: grad_tokens == top_k * d_result. 
+// ============================================================================= + +template class EPDispatchBwdTest : public EpOpTestBase {}; +TYPED_TEST_SUITE(EPDispatchBwdTest, EPBf16Only); + +TYPED_TEST(EPDispatchBwdTest, DispatchBwdCheck) { + using Tok = TypeParam; + EP_PULL_FIXTURE(); + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + this->template upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); + + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), + t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), + NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, + t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(t.handle_mem.data(), t.recv_tokens.data(), NVTECommWindow{}, + t.result.data(), stream)); + + std::vector h_grad(num_tokens_ * hidden_dim_, tok_from_float(0.1f)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad.data(), + h_grad.size() * sizeof(Tok), + cudaMemcpyHostToDevice, stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); + + ASSERT_NO_THROW(nvte_ep_combine_bwd(t.handle_mem.data(), t.grad_result.data(), NVTECommWindow{}, + t.grad_expert.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(t.handle_mem.data(), t.grad_expert.data(), NVTECommWindow{}, + t.g_recv_topk_weights.data(), NVTECommWindow{}, + t.grad_tokens.data(), 
t.grad_topk_weights.data(), stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_gt(num_tokens_ * hidden_dim_); + NVTE_CHECK_CUDA(cudaMemcpy(h_gt.data(), buf.grad_tokens.get(), + h_gt.size() * sizeof(Tok), cudaMemcpyDeviceToHost)); + const float kExpGrad = static_cast(top_k_) * tok_to_float(tok_from_float(0.1f)); + for (int tok = 0; tok < num_tokens_; ++tok) + for (int p = 0; p < hidden_dim_; ++p) + EXPECT_NEAR(tok_to_float(h_gt[tok * hidden_dim_ + p]), kExpGrad, + bf16_tol(kExpGrad)) + << "grad_tokens token " << tok << " hidden " << p; + + if (g_process_id == 0) + printf(" DispatchBwdCheck: passed (grad_tokens == %.2f)\n", kExpGrad); + + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPDispatchBwdGradWeightsTest: round-trip per-(t, k) weights. +// ============================================================================= + +template class EPDispatchBwdGradWeightsTest : public EpOpTestBase {}; +TYPED_TEST_SUITE(EPDispatchBwdGradWeightsTest, EPBf16Only); + +TYPED_TEST(EPDispatchBwdGradWeightsTest, RoundTrip) { + using Tok = TypeParam; + EP_PULL_FIXTURE(); + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + this->template upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + // Distinct per-(rank, t, k) weights so each slot carries a unique value. + // Global integer counter over (rank, tok, k) keeps every slot unique. 
+ std::vector h_w(num_tokens_ * top_k_); + for (int tok = 0; tok < num_tokens_; ++tok) + for (int k = 0; k < top_k_; ++k) + h_w[tok * top_k_ + k] = static_cast( + (g_process_id * num_tokens_ + tok) * top_k_ + k + 1); + NVTE_CHECK_CUDA(cudaMemcpy(buf.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); + + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.recv_topk_weights.get(), 0, + buf.recv_topk_weights.bytes(), stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), + t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), + NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, + t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + + // Sentinel: NaN so any (t, k) the bwd kernel fails to write is immediately visible. + std::vector h_nan(num_tokens_ * top_k_, + std::numeric_limits::quiet_NaN()); + NVTE_CHECK_CUDA(cudaMemcpyAsync(buf.grad_topk_weights.get(), h_nan.data(), + h_nan.size() * sizeof(float), + cudaMemcpyHostToDevice, stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + + // g_recv_topk_weights := recv_topk_weights (the round-trip input). 
+ auto g_recv_t = TensorWrapper(buf.recv_topk_weights.get(), + std::vector{buf.recv_capacity}, DType::kFloat32); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(t.handle_mem.data(), t.grad_expert.data(), + NVTECommWindow{}, g_recv_t.data(), NVTECommWindow{}, + t.grad_tokens.data(), t.grad_topk_weights.data(), stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_grad_w(num_tokens_ * top_k_); + NVTE_CHECK_CUDA(cudaMemcpy(h_grad_w.data(), buf.grad_topk_weights.get(), + h_grad_w.size() * sizeof(float), cudaMemcpyDeviceToHost)); + + const float kTol = 1e-5f; + int errs = 0, k0_eq_k1 = 0; + for (int tok = 0; tok < num_tokens_; ++tok) { + for (int k = 0; k < top_k_; ++k) { + float got = h_grad_w[tok * top_k_ + k]; + float exp = h_w[tok * top_k_ + k]; + if (std::isnan(got) || std::fabs(got - exp) > kTol) { + if (errs < 8) + fprintf(stderr, "Rank %d: grad_topk_weights[%d, %d]: got %.6f, expected %.6f\n", + g_process_id, tok, k, got, exp); + ++errs; + } + } + if (top_k_ >= 2 && + std::fabs(h_grad_w[tok * top_k_ + 0] - h_grad_w[tok * top_k_ + 1]) < 1e-7f) + ++k0_eq_k1; + } + EXPECT_EQ(errs, 0); + EXPECT_EQ(k0_eq_k1, 0) << "per-token-average regression: grad[t, 0] == grad[t, 1]"; + + if (g_process_id == 0 && errs == 0 && k0_eq_k1 == 0) + printf(" RoundTrip: passed (%d (t, k) gradients)\n", num_tokens_ * top_k_); + + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// Integrated FwdBwd: NaN/Inf check end-to-end. 
+// ============================================================================= + +class EPPipelineTest : public EpOpTestBase, public ::testing::WithParamInterface { + protected: + template + void run_full_forward_backward() { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); + + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), + t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), + NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, + t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(t.handle_mem.data(), t.recv_tokens.data(), NVTECommWindow{}, + t.result.data(), stream)); + + std::vector h_grad(num_tokens_ * hidden_dim_, tok_from_float(0.1f)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad.data(), + h_grad.size() * sizeof(Tok), + cudaMemcpyHostToDevice, stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); + + ASSERT_NO_THROW(nvte_ep_combine_bwd(t.handle_mem.data(), t.grad_result.data(), NVTECommWindow{}, + t.grad_expert.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(t.handle_mem.data(), t.grad_expert.data(), NVTECommWindow{}, + t.g_recv_topk_weights.data(), NVTECommWindow{}, + t.grad_tokens.data(), t.grad_topk_weights.data(), stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + + 
ASSERT_TRUE(check_no_nan_inf(buf.result.get(), num_tokens_ * hidden_dim_, "result")); + ASSERT_TRUE(check_no_nan_inf(buf.grad_tokens.get(), num_tokens_ * hidden_dim_, "grad_tokens")); + + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); + } +}; + +TEST_P(EPPipelineTest, FullForwardBackward) { + const DType dtype = GetParam(); + // NCCL EP backend currently asserts ncclBfloat16 in ncclEpDispatch + // (contrib/nccl_ep/nccl_ep.cc); skip FP16/FP32 until the backend supports them. + if (dtype != DType::kBFloat16) { + GTEST_SKIP() << test::typeName(dtype) << " not yet supported by NCCL EP backend"; + } + switch (dtype) { + case DType::kBFloat16: run_full_forward_backward(); break; + case DType::kFloat16: run_full_forward_backward<__half> (); break; + case DType::kFloat32: run_full_forward_backward (); break; + default: FAIL() << "unsupported token dtype " << static_cast(dtype); + } + if (g_process_id == 0) + printf(" FullForwardBackward[%s]: passed\n", test::typeName(dtype).c_str()); +} + +INSTANTIATE_TEST_SUITE_P( + Dtypes, EPPipelineTest, + ::testing::Values(DType::kBFloat16, DType::kFloat16, DType::kFloat32), + [](const ::testing::TestParamInfo& info) { + return test::typeName(info.param); + }); + +// ============================================================================= +// EPZeroCopyTest: dispatch/combine with NCCL symmetric-memory windows attached +// to payload tensors (zero-copy fast path via ncclEpTensorCreateFromWindow). +// Symm-mem requirements per spec: input&output of Dispatch, input of Combine, +// input&output of Combine bwd, input of Dispatch bwd. +// ============================================================================= + +namespace { + +// Caller-owned ncclMemAlloc'd buffer with a registered symmetric window. +// Frees in destructor (deregister + ncclMemFree). Non-copyable, move-only. 
+struct SymmBuf { + void* ptr = nullptr; + size_t bytes = 0; + ncclWindow_t win = nullptr; + + SymmBuf() = default; + SymmBuf(const SymmBuf&) = delete; + SymmBuf& operator=(const SymmBuf&) = delete; + SymmBuf(SymmBuf&& o) noexcept : ptr(o.ptr), bytes(o.bytes), win(o.win) { + o.ptr = nullptr; o.win = nullptr; o.bytes = 0; + } + ~SymmBuf() { + if (win) ncclCommWindowDeregister(g_ep_comm, win); + if (ptr) ncclMemFree(ptr); + } + + void alloc(size_t n_bytes) { + bytes = n_bytes; + NVTE_CHECK_NCCL(ncclMemAlloc(&ptr, bytes)); + NVTE_CHECK_CUDA(cudaMemset(ptr, 0, bytes)); + NVTE_CHECK_NCCL(ncclCommWindowRegister(g_ep_comm, ptr, bytes, &win, + NCCL_WIN_COLL_SYMMETRIC)); + } +}; + +// Build an NVTECommWindow descriptor pointing at a SymmBuf's window (offset 0). +static inline NVTECommWindow symm_window(const SymmBuf& b) { + return NVTECommWindow{b.win, /*offset=*/0}; +} + +} // namespace + +// Tests rebootstrap the backend to zero_copy=ON for the symm phase via +// ep_reinitialize(); TearDown restores OFF for the rest of the suite. +template +class EPZeroCopyTest : public EpOpTestBase { + protected: + void TearDown() override { + if (g_ep_initialized) ep_reinitialize(/*zero_copy=*/0); + } +}; +TYPED_TEST_SUITE(EPZeroCopyTest, EPBf16Only); + +// Identity round-trip with symm-mem on dispatch i/o + combine input. Bit-exact +// vs HBM reference (same routing, same input). +TYPED_TEST(EPZeroCopyTest, IdentityAllSymm) { + using Tok = TypeParam; + EP_PULL_FIXTURE(); + constexpr DType kTokDType = test::TypeInfo::dtype; + + // HBM reference run. 
+ EPBuffers ref_buf; + ref_buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + this->template upload_inputs(ref_buf); + EPTensors ref_t(ref_buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); + + ASSERT_NO_THROW(nvte_ep_prepare(ref_t.handle_mem.data(), ref_t.topk_idx.data(), ref_t.token_counts.data(), NVTEEpLayerConfig{ref_t.top_k_, ref_t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(ref_t.handle_mem.data(), ref_t.topk_idx.data(), + ref_t.tokens.data(), NVTECommWindow{}, ref_t.topk_weights.data(), + NVTECommWindow{}, ref_t.recv_tokens.data(), NVTECommWindow{}, + ref_t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(ref_t.handle_mem.data(), ref_t.recv_tokens.data(), NVTECommWindow{}, + ref_t.result.data(), stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector ref_recv(ref_buf.recv_capacity * hidden_dim_); + std::vector ref_result(num_tokens_ * hidden_dim_); + NVTE_CHECK_CUDA(cudaMemcpy(ref_recv.data(), ref_buf.recv_tokens.get(), + ref_recv.size() * sizeof(Tok), cudaMemcpyDeviceToHost)); + NVTE_CHECK_CUDA(cudaMemcpy(ref_result.data(), ref_buf.result.get(), + ref_result.size() * sizeof(Tok), cudaMemcpyDeviceToHost)); + + // Switch backend to zero_copy=ON for the symm phase. + ep_reinitialize(/*zero_copy=*/1); + + // Symm-mem run: tokens, recv_tokens, combine_input (== recv_tokens) all symm. + EPBuffers sym_buf; // alloc all buffers except the symm ones. + sym_buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + this->template upload_inputs(sym_buf); + + SymmBuf sym_tokens, sym_recv; + sym_tokens.alloc(num_tokens_ * hidden_dim_ * sizeof(Tok)); + sym_recv .alloc(sym_buf.recv_capacity * hidden_dim_ * sizeof(Tok)); + + // Stage same tokens into the symm-mem input. 
+ auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); + NVTE_CHECK_CUDA(cudaMemcpy(sym_tokens.ptr, h_tok.data(), + h_tok.size() * sizeof(Tok), cudaMemcpyHostToDevice)); + + EPTensors sym_t(sym_buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + // Replace the tokens/recv_tokens views with ones pointing at the symm buffers. + sym_t.tokens = TensorWrapper(sym_tokens.ptr, + std::vector{(size_t)num_tokens_, (size_t)hidden_dim_}, kTokDType); + sym_t.recv_tokens = TensorWrapper(sym_recv.ptr, + std::vector{sym_buf.recv_capacity, (size_t)hidden_dim_}, kTokDType); + + ASSERT_NO_THROW(nvte_ep_prepare(sym_t.handle_mem.data(), sym_t.topk_idx.data(), sym_t.token_counts.data(), NVTEEpLayerConfig{sym_t.top_k_, sym_t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(sym_t.handle_mem.data(), sym_t.topk_idx.data(), + sym_t.tokens.data(), symm_window(sym_tokens), + sym_t.topk_weights.data(), NVTECommWindow{}, + sym_t.recv_tokens.data(), symm_window(sym_recv), + sym_t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(sym_t.handle_mem.data(), sym_t.recv_tokens.data(), + symm_window(sym_recv), sym_t.result.data(), stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector sym_recv_host(sym_buf.recv_capacity * hidden_dim_); + std::vector sym_result(num_tokens_ * hidden_dim_); + NVTE_CHECK_CUDA(cudaMemcpy(sym_recv_host.data(), sym_recv.ptr, + sym_recv_host.size() * sizeof(Tok), cudaMemcpyDeviceToHost)); + NVTE_CHECK_CUDA(cudaMemcpy(sym_result.data(), sym_buf.result.get(), + sym_result.size() * sizeof(Tok), cudaMemcpyDeviceToHost)); + + // Compare per filled recv slot (HBM ref vs symm) and full result. 
+ int total_recv = this->template read_total_recv(sym_buf); + for (int i = 0; i < total_recv * hidden_dim_; ++i) + ASSERT_EQ(tok_to_float(sym_recv_host[i]), tok_to_float(ref_recv[i])) + << "recv mismatch at " << i; + for (size_t i = 0; i < sym_result.size(); ++i) + ASSERT_EQ(tok_to_float(sym_result[i]), tok_to_float(ref_result[i])) + << "result mismatch at " << i; + + if (g_process_id == 0) + printf(" IdentityAllSymm: passed (recv_slots=%d, bit-exact vs HBM)\n", total_recv); + + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); +} + + +// -- main ---------------------------------------------------------------------- + +int main(int argc, char* argv[]) { + if (!ep_bootstrap(argc, argv)) return 0; + int ret = RUN_ALL_TESTS(); + ep_teardown(); + return ret; +} diff --git a/tests/cpp_distributed/test_ep_common.h b/tests/cpp_distributed/test_ep_common.h new file mode 100644 index 0000000000..d5e006cef6 --- /dev/null +++ b/tests/cpp_distributed/test_ep_common.h @@ -0,0 +1,201 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/* + * Shared TE EP test infrastructure. Include once per TU; ep_bootstrap() in + * each test binary's main() populates process-level globals. + * Defaults: 4 experts/rank, hidden_dim=256, max_tokens_per_rank=64. 
+ */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include "../cpp/test_common.h" +#include "util/logging.h" + +using transformer_engine::DType; +using transformer_engine::TensorWrapper; + +#define CHECK_MPI(expr) \ + do { \ + int _err_mpi = (expr); \ + NVTE_CHECK(_err_mpi == MPI_SUCCESS, "MPI error: ", _err_mpi); \ + } while (false) + +// -- Process-level state ------------------------------------------------------- + +static int g_process_id = -1; +static int g_num_processes = -1; + +static int g_sm_major = -1; // set by ep_bootstrap; -1 until then +static int g_ep_size = -1; +static int g_num_experts = -1; +static int g_hidden_dim = 256; +static int g_max_tokens_per_rank = 64; +static NVTEDType g_max_token_dtype = kNVTEFloat32; // staging-buffer sizing +static bool g_ep_initialized = false; +static ncclComm_t g_ep_comm = nullptr; // owned by harness, destroyed in ep_teardown + +// RAII owner for a cudaMalloc'd device buffer; element-count API on top of +// test::CudaPtr. +template +struct DevBuf { + test::CudaPtr ptr; + size_t count = 0; + + DevBuf() = default; + explicit DevBuf(size_t n) { alloc(n); } + + void alloc(size_t n) { + count = n; + ptr = (n > 0) ? test::cuda_alloc(n * sizeof(T)) : test::CudaPtr{}; + } + void reset() { + ptr.reset(); + count = 0; + } + + T* get() const { return ptr.get(); } + size_t bytes() const { return count * sizeof(T); } +}; + +// -- Shared routing helper ----------------------------------------------------- + +// Balanced round-robin routing: token t on rank r maps top_k experts to +// (r * num_tokens * top_k + t * top_k + k) % num_experts +// i.e. a single global counter over all (rank, t, k) triples mod num_experts. 
+static inline std::vector routing_balanced( + int rank, int num_tokens, int top_k, int num_experts, int /*num_local_experts*/) { + std::vector idx(num_tokens * top_k); + for (int t = 0; t < num_tokens; ++t) + for (int k = 0; k < top_k; ++k) + idx[t * top_k + k] = (rank * num_tokens * top_k + t * top_k + k) % num_experts; + return idx; +} + +// -- ncclUniqueId exchange via MPI --------------------------------------------- + +static void exchange_unique_id(ncclUniqueId* uid) { + if (g_process_id == 0) NVTE_CHECK_NCCL(ncclGetUniqueId(uid)); + CHECK_MPI(MPI_Bcast(uid, sizeof(*uid), MPI_BYTE, 0, MPI_COMM_WORLD)); +} + +// -- CLI parsing --------------------------------------------------------------- + +static void ep_parse_args(int argc, char* argv[]) { + for (int i = 1; i < argc; ++i) { + std::string a(argv[i]); + if (a.rfind("--max-token-dtype=", 0) == 0) + g_max_token_dtype = static_cast(std::stoi(a.substr(18))); + } +} + +// -- Bootstrap / teardown ------------------------------------------------------ + +// Returns false if the binary should exit without running tests (wrong SM, etc.). 
+static bool ep_bootstrap(int argc, char* argv[]) { + int mpi_initialized = 0; + MPI_Initialized(&mpi_initialized); + if (!mpi_initialized) CHECK_MPI(MPI_Init(&argc, &argv)); + CHECK_MPI(MPI_Comm_rank(MPI_COMM_WORLD, &g_process_id)); + CHECK_MPI(MPI_Comm_size(MPI_COMM_WORLD, &g_num_processes)); + + ep_parse_args(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + + int device_count; + cudaGetDeviceCount(&device_count); + cudaSetDevice(g_process_id % device_count); + + int device, major; + cudaGetDevice(&device); + cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device); + g_sm_major = major; + if (major < 9) { + if (g_process_id == 0) + printf("SKIP: EP requires SM_90+ (device is SM_%d0)\n", major); + return false; + } + if (g_num_processes < 2) { + if (g_process_id == 0) + printf("SKIP: at least 2 processes required\n"); + return false; + } + + g_ep_size = g_num_processes; + g_num_experts = g_ep_size * 4; // 4 experts per rank + + ncclUniqueId uid{}; + exchange_unique_id(&uid); + + NVTEEpGroupConfig group_config{}; + group_config.ep_size = g_ep_size; + group_config.num_experts = g_num_experts; + group_config.max_tokens_per_rank = g_max_tokens_per_rank; + // Worst-case for top_k fan-out: ep_size * max_tokens_per_rank * 2. + group_config.max_recv_tokens_per_rank = g_ep_size * g_max_tokens_per_rank * 2; + group_config.hidden_dim = g_hidden_dim; + group_config.max_token_dtype = g_max_token_dtype; + + NVTE_CHECK_NCCL(ncclCommInitRank(&g_ep_comm, g_num_processes, uid, g_process_id)); + nvte_ep_initialize(static_cast(g_ep_comm), group_config); + + if (g_process_id == 0) { + printf("EP initialized: ep_size=%d num_experts=%d " + "hidden_dim=%d max_tokens_per_rank=%d\n", + g_ep_size, g_num_experts, g_hidden_dim, g_max_tokens_per_rank); + } + + g_ep_initialized = true; + return true; +} + +// Re-bootstrap the EP backend on the existing g_ep_comm with a new zero_copy +// setting. 
+static void ep_reinitialize(int zero_copy) { + if (!g_ep_initialized) return; + nvte_ep_shutdown(); + NVTEEpGroupConfig group_config{}; + group_config.ep_size = g_ep_size; + group_config.num_experts = g_num_experts; + group_config.max_tokens_per_rank = g_max_tokens_per_rank; + group_config.max_recv_tokens_per_rank = g_ep_size * g_max_tokens_per_rank * 2; + group_config.hidden_dim = g_hidden_dim; + group_config.max_token_dtype = g_max_token_dtype; + group_config.zero_copy = zero_copy; + nvte_ep_initialize(static_cast(g_ep_comm), group_config); +} + +// Tear down in dependency order: backend's ep_group reads from ep_comm, +// so destroy the group first, then the comm. +static void ep_teardown() { + if (g_ep_initialized) { + nvte_ep_shutdown(); + if (g_ep_comm != nullptr) { + ncclCommDestroy(g_ep_comm); + g_ep_comm = nullptr; + } + g_ep_initialized = false; + } + int finalized = 0; + MPI_Finalized(&finalized); + if (!finalized) MPI_Finalize(); +} diff --git a/transformer_engine/__init__.py b/transformer_engine/__init__.py index 0175f04e2e..480a2e9a06 100644 --- a/transformer_engine/__init__.py +++ b/transformer_engine/__init__.py @@ -6,10 +6,54 @@ # pylint: disable=unused-import +import ctypes +import functools import os from importlib import metadata +from typing import Optional, Tuple import transformer_engine.common +# Minimum NCCL version for the statically-linked NCCL EP backend. 
_NCCL_EP_MIN_VERSION = (2, 30, 4)

# NCCL changed its integer version encoding at 2.9: releases up to 2.8 report
# major * 1000 + minor * 100 + patch, later ones major * 10000 + minor * 100 +
# patch. Every legacy code is therefore below this limit.
_NCCL_LEGACY_VERSION_LIMIT = 10000


def _decode_nccl_version(code: int) -> Tuple[int, int, int]:
    """Split an NCCL integer version code into ``(major, minor, patch)``.

    Handles both the legacy (<= 2.8) and the current encoding, so an old
    runtime is reported as e.g. ``(2, 8, 3)`` instead of ``(0, 28, 3)``.
    """
    major_scale = 1000 if code < _NCCL_LEGACY_VERSION_LIMIT else 10000
    major, remainder = divmod(code, major_scale)
    minor, patch = divmod(remainder, 100)
    return (major, minor, patch)


@functools.lru_cache(maxsize=1)
def _nccl_runtime_version() -> Optional[Tuple[int, int, int]]:
    """Return runtime (major, minor, patch) from libnccl.so.2, or None if unavailable.

    None is returned when the library cannot be loaded, does not export
    ``ncclGetVersion``, or the call reports a non-zero status. The result is
    cached for the lifetime of the process.
    """
    try:
        libnccl = ctypes.CDLL("libnccl.so.2", mode=ctypes.RTLD_LOCAL)
        nccl_get_version = libnccl.ncclGetVersion
    except (OSError, AttributeError):
        return None
    # Declare the C signature explicitly instead of relying on ctypes defaults.
    nccl_get_version.argtypes = [ctypes.POINTER(ctypes.c_int)]
    nccl_get_version.restype = ctypes.c_int
    code = ctypes.c_int(0)
    if nccl_get_version(ctypes.byref(code)) != 0:
        return None
    return _decode_nccl_version(code.value)


def is_nccl_ep_available() -> bool:
    """Return True if the runtime libnccl.so meets the NCCL EP minimum."""
    cur = _nccl_runtime_version()
    return cur is not None and cur >= _NCCL_EP_MIN_VERSION


def require_nccl_ep() -> None:
    """Raise RuntimeError if NCCL EP cannot run on the current libnccl.

    Raises:
        RuntimeError: libnccl.so.2 could not be loaded or queried, or its
            version is below the NCCL EP minimum.
    """
    mn = ".".join(str(x) for x in _NCCL_EP_MIN_VERSION)
    cur = _nccl_runtime_version()
    if cur is None:
        raise RuntimeError(
            f"NCCL EP requires NCCL >= {mn}; could not load libnccl.so.2 or query its "
            "version. Install NCCL or ensure libnccl.so.2 is on the loader path."
        )
    if cur < _NCCL_EP_MIN_VERSION:
        raise RuntimeError(
            f"NCCL EP requires NCCL >= {mn} at runtime; found "
            f"{'.'.join(str(x) for x in cur)}. Upgrade NCCL to a compatible version."
        )
+option(NVTE_WITH_NCCL_EP "Build NCCL EP into libtransformer_engine.so" ON) +if(NVTE_WITH_NCCL_EP) +# SM>=90 and NCCL>=2.30.4 are gated at runtime in EPBackend::initialize. +# -- NCCL EP headers -------------------------------------------------------- +# Headers + libs are produced by the in-tree 3rdparty/nccl submodule build +# (auto-built by setup.py via build_nccl_ep_submodule). +set(NCCL_EP_SUBMODULE_ROOT + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl") +set(NCCL_EP_INCLUDE_DIR "${NCCL_EP_SUBMODULE_ROOT}/build/include") +if(NOT EXISTS "${NCCL_EP_INCLUDE_DIR}/nccl_ep.h") + message(FATAL_ERROR + "NCCL EP header not found at ${NCCL_EP_INCLUDE_DIR}/nccl_ep.h. " + "Run `git submodule update --init --recursive` and rebuild TE.") +endif() +message(STATUS "NCCL EP headers: ${NCCL_EP_INCLUDE_DIR}") + +# -- libnccl_ep.a ----------------------------------------------------------- +# Statically linked into libtransformer_engine.so. EPBackend::initialize checks +# NCCL >= 2.30.4 before any nccl_ep call, so the newer NCCL symbols nccl_ep +# imports stay unresolved (and harmless) under default ELF lazy binding when +# the gate trips. LD_BIND_NOW environments lose this property. +set(NCCL_EP_LIB_DIR "${NCCL_EP_SUBMODULE_ROOT}/build/lib") +find_file(NCCL_EP_LIB + NAMES libnccl_ep.a + HINTS ${NCCL_EP_LIB_DIR} + NO_DEFAULT_PATH + REQUIRED) + +# -- NCCL core: nccl.h + libnccl.so ----------------------------------------- +# setup.py passes -DNCCL_INCLUDE_DIR; standalone CMake falls back to probing +# well-known NCCL install prefixes. +find_path(NCCL_INCLUDE_DIR nccl.h + HINTS /opt/nvidia/nccl/include /usr/local/nccl/include) +if(NOT NCCL_INCLUDE_DIR) + message(FATAL_ERROR + "nccl.h not found. 
Pass -DNCCL_INCLUDE_DIR=/include.") +endif() +if(NOT NCCL_LIB) + find_library(NCCL_LIB + NAMES nccl libnccl + PATH_SUFFIXES lib lib64 + REQUIRED) +endif() + +target_include_directories(transformer_engine PRIVATE + ${NCCL_EP_INCLUDE_DIR} + ${NCCL_INCLUDE_DIR}) + +# libnccl.so direct symbols (ncclGetVersion etc.) come from libnccl_ep.a's +# DT_NEEDED chain plus this TU's own references. CUDA::cuda_driver must follow +# the static archive on the link line so --as-needed records libcuda.so.1. +target_link_libraries(transformer_engine PUBLIC ${NCCL_LIB}) +target_link_libraries(transformer_engine PRIVATE + -Wl,--whole-archive ${NCCL_EP_LIB} -Wl,--no-whole-archive + CUDA::cuda_driver) + +target_sources(transformer_engine PRIVATE + ep/ep_backend.cpp + ep/ep_api.cpp) +target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_NCCL_EP) + +message(STATUS "NCCL EP enabled (static link): ${NCCL_EP_LIB}") +message(STATUS "NCCL EP include: ${NCCL_EP_INCLUDE_DIR}") +else() + # NCCL EP off: ep_api.cpp's #else branch exports throwing nvte_ep_* stubs. + target_sources(transformer_engine PRIVATE ep/ep_api.cpp) + message(STATUS "NCCL EP disabled (NVTE_WITH_NCCL_EP=OFF) - using nvte_ep_* stubs") +endif() + # Number of philox4x32 rounds for stochastic rounding (build-time constant). 
set(NVTE_BUILD_NUM_PHILOX_ROUNDS_STR $ENV{NVTE_BUILD_NUM_PHILOX_ROUNDS}) if (NOT NVTE_BUILD_NUM_PHILOX_ROUNDS_STR) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index fd2d146616..42b458bfc5 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -380,7 +380,7 @@ def _load_cuda_library(lib_name: str): @functools.lru_cache(maxsize=None) def _load_core_library(): """Load shared library with Transformer Engine C extensions""" - return ctypes.CDLL(_get_shared_object_file("core"), mode=ctypes.RTLD_GLOBAL) + return ctypes.CDLL(_get_shared_object_file("core"), mode=ctypes.RTLD_GLOBAL | os.RTLD_LAZY) if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): diff --git a/transformer_engine/common/ep/ep_api.cpp b/transformer_engine/common/ep/ep_api.cpp new file mode 100644 index 0000000000..66ee3dc8d9 --- /dev/null +++ b/transformer_engine/common/ep/ep_api.cpp @@ -0,0 +1,130 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep_api.cpp + * \brief nvte_ep_* C API: thin delegations to the EPBackend singleton. + * + * When NVTE_WITH_NCCL_EP is undefined, the entry points become throwing + * stubs so framework bindings still link without NCCL EP support. 
+ */ + +#include + +#include "../util/logging.h" + +#if defined(NVTE_WITH_NCCL_EP) + +#include + +#include "../common.h" +#include "ep_backend.h" + +using transformer_engine::ep::EPBackend; + +void nvte_ep_initialize(void* ep_comm, NVTEEpGroupConfig group_config) { + NVTE_CHECK(ep_comm != nullptr, "ep_comm must not be null"); + EPBackend::initialize(static_cast(ep_comm), group_config); +} + +void nvte_ep_shutdown(void) { EPBackend::shutdown(); } + +size_t nvte_ep_handle_mem_size(NVTEEpLayerConfig layer_cfg) { + return EPBackend::get().handle_mem_size(layer_cfg); +} + +namespace { +inline void* handle_mem_ptr(NVTETensor mem) { + void* p = nvte_tensor_data(mem); + NVTE_CHECK(p != nullptr, "handle_mem tensor data must not be null"); + return p; +} +} // namespace + +void nvte_ep_prepare(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor token_counts, + NVTEEpLayerConfig layer_cfg, cudaStream_t stream) { + EPBackend::get().prepare(handle_mem_ptr(handle_mem), topk_idx, token_counts, layer_cfg, stream); +} + +void nvte_ep_dispatch(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor tokens, + NVTECommWindow tokens_win, NVTETensor topk_weights, + NVTECommWindow topk_weights_win, NVTETensor recv_tokens, + NVTECommWindow recv_tokens_win, NVTETensor recv_topk_weights, + NVTECommWindow recv_topk_weights_win, cudaStream_t stream) { + EPBackend::get().dispatch(handle_mem_ptr(handle_mem), topk_idx, tokens, tokens_win, topk_weights, + topk_weights_win, recv_tokens, recv_tokens_win, recv_topk_weights, + recv_topk_weights_win, stream); +} + +void nvte_ep_combine(NVTETensor handle_mem, NVTETensor expert_out, NVTECommWindow expert_out_win, + NVTETensor result, cudaStream_t stream) { + EPBackend::get().combine(handle_mem_ptr(handle_mem), expert_out, expert_out_win, result, stream); +} + +void nvte_ep_dispatch_bwd(NVTETensor handle_mem, NVTETensor grad, NVTECommWindow grad_win, + NVTETensor g_recv_topk_weights, NVTECommWindow g_recv_topk_weights_win, + NVTETensor grad_tokens, 
NVTETensor grad_topk_weights, + cudaStream_t stream) { + EPBackend::get().dispatch_bwd(handle_mem_ptr(handle_mem), grad, grad_win, g_recv_topk_weights, + g_recv_topk_weights_win, grad_tokens, grad_topk_weights, stream); +} + +void nvte_ep_combine_bwd(NVTETensor handle_mem, NVTETensor grad, NVTECommWindow grad_win, + NVTETensor grad_expert_out, NVTECommWindow grad_expert_out_win, + cudaStream_t stream) { + EPBackend::get().combine_bwd(handle_mem_ptr(handle_mem), grad, grad_win, grad_expert_out, + grad_expert_out_win, stream); +} + +#else // !NVTE_WITH_NCCL_EP - throwing stubs. + +namespace { +[[noreturn]] void ep_not_built() { + NVTE_ERROR( + "NCCL EP is not built into this TransformerEngine. Rebuild TE with " + "NVTE_WITH_NCCL_EP=1 and CUDA arch >= 90 (e.g. NVTE_CUDA_ARCHS=\"90\")."); +} +} // namespace + +void nvte_ep_initialize(void* /*ep_comm*/, NVTEEpGroupConfig /*group_config*/) { ep_not_built(); } + +void nvte_ep_shutdown(void) {} + +size_t nvte_ep_handle_mem_size(NVTEEpLayerConfig /*layer_cfg*/) { ep_not_built(); } + +void nvte_ep_prepare(NVTETensor /*handle_mem*/, NVTETensor /*topk_idx*/, + NVTETensor /*token_counts*/, NVTEEpLayerConfig /*layer_cfg*/, + cudaStream_t /*stream*/) { + ep_not_built(); +} + +void nvte_ep_dispatch(NVTETensor /*handle_mem*/, NVTETensor /*topk_idx*/, NVTETensor /*tokens*/, + NVTECommWindow /*tokens_win*/, NVTETensor /*topk_weights*/, + NVTECommWindow /*topk_weights_win*/, NVTETensor /*recv_tokens*/, + NVTECommWindow /*recv_tokens_win*/, NVTETensor /*recv_topk_weights*/, + NVTECommWindow /*recv_topk_weights_win*/, cudaStream_t /*stream*/) { + ep_not_built(); +} + +void nvte_ep_combine(NVTETensor /*handle_mem*/, NVTETensor /*expert_out*/, + NVTECommWindow /*expert_out_win*/, NVTETensor /*result*/, + cudaStream_t /*stream*/) { + ep_not_built(); +} + +void nvte_ep_dispatch_bwd(NVTETensor /*handle_mem*/, NVTETensor /*grad*/, + NVTECommWindow /*grad_win*/, NVTETensor /*g_recv_topk_weights*/, + NVTECommWindow /*g_recv_topk_weights_win*/, 
NVTETensor /*grad_tokens*/, + NVTETensor /*grad_topk_weights*/, cudaStream_t /*stream*/) { + ep_not_built(); +} + +void nvte_ep_combine_bwd(NVTETensor /*handle_mem*/, NVTETensor /*grad*/, + NVTECommWindow /*grad_win*/, NVTETensor /*grad_expert_out*/, + NVTECommWindow /*grad_expert_out_win*/, cudaStream_t /*stream*/) { + ep_not_built(); +} + +#endif // NVTE_WITH_NCCL_EP diff --git a/transformer_engine/common/ep/ep_backend.cpp b/transformer_engine/common/ep/ep_backend.cpp new file mode 100644 index 0000000000..f1510693bb --- /dev/null +++ b/transformer_engine/common/ep/ep_backend.cpp @@ -0,0 +1,482 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep_backend.cpp + * \brief EPBackend implementation. See ep_backend.h for the op flow. + */ + +#include "ep_backend.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "../common.h" +#include "../util/cuda_runtime.h" +#include "../util/logging.h" + +namespace transformer_engine { +namespace ep { + +namespace { + +ncclDataType_t te_dtype_to_nccl_dtype(NVTEDType dtype) { + switch (dtype) { + case kNVTEFloat32: + return ncclFloat32; + case kNVTEFloat16: + return ncclFloat16; + case kNVTEBFloat16: + return ncclBfloat16; + case kNVTEInt32: + return ncclInt32; + case kNVTEInt64: + return ncclInt64; + case kNVTEByte: + return ncclUint8; + case kNVTEFloat8E4M3: + return ncclFloat8e4m3; + case kNVTEFloat8E5M2: + return ncclFloat8e5m2; + default: + NVTE_ERROR("Unsupported NVTEDType for NCCL dtype conversion: ", static_cast(dtype)); + } + return ncclFloat32; // unreachable +} + +// shape_out is caller-owned; desc.sizes aliases shape_out.data and must +// outlive the NCCL EP call. 
+inline ncclEpTensor_t make_nccl_ep_tensor(const NVTETensor t, NVTEShape& shape_out,
+                                          const NVTECommWindow& win = {}) {
+  // Build an NCCL EP tensor descriptor for `t`.
+  //   t         - tensor whose shape, dtype and (optionally) data pointer are described.
+  //   shape_out - receives nvte_tensor_shape(t); desc.sizes points into it.
+  //   win       - optional comm window; when win.window is non-null the descriptor
+  //               carries the window handle and offset instead of a raw data pointer.
+  // Fails via NVTE_CHECK when the raw-pointer path is taken and the tensor has no data.
+  shape_out = nvte_tensor_shape(t);
+  ncclEpTensor_t desc = NCCL_EP_TENSOR_INIT;
+  desc.ndim = shape_out.ndim;
+  desc.sizes = shape_out.data;
+  desc.datatype = te_dtype_to_nccl_dtype(nvte_tensor_type(t));
+  if (win.window != nullptr) {
+    // Window path: desc.data keeps whatever NCCL_EP_TENSOR_INIT set it to.
+    desc.win_hdl = win.window;
+    desc.win_offset = win.offset;
+  } else {
+    desc.data = nvte_tensor_data(t);
+    NVTE_CHECK(desc.data != nullptr, "tensor data must not be null");
+  }
+  return desc;
+}
+
+}  // namespace
+
+// ---------------------------------------------------------------------------
+// Singleton + bootstrap
+// ---------------------------------------------------------------------------
+
+// Function-local static singleton; performs no initialization check.
+EPBackend& EPBackend::instance() {
+  static EPBackend inst;
+  return inst;
+}
+
+// Checked accessor: fails via NVTE_CHECK unless initialize() has completed.
+EPBackend& EPBackend::get() {
+  EPBackend& inst = instance();
+  NVTE_CHECK(inst.initialized_, "EPBackend not initialized. Call nvte_ep_initialize() first.");
+  return inst;
+}
+
+// Validate the group configuration and the current device before any NCCL EP
+// object is created. Every failed check reports the offending value.
+void EPBackend::validate_config(const NVTEEpGroupConfig& config) {
+  NVTE_CHECK(config.ep_size > 0, "ep_size must be positive, got ", config.ep_size);
+  NVTE_CHECK(config.num_experts > 0, "num_experts must be positive, got ", config.num_experts);
+  NVTE_CHECK(config.max_tokens_per_rank > 0, "max_tokens_per_rank must be positive, got ",
+             config.max_tokens_per_rank);
+  NVTE_CHECK(config.max_recv_tokens_per_rank > 0, "max_recv_tokens_per_rank must be positive, got ",
+             config.max_recv_tokens_per_rank);
+  NVTE_CHECK(config.hidden_dim > 0, "hidden_dim must be positive, got ", config.hidden_dim);
+  // Range-check the dtype before it is handed to typeToSize().
+  NVTE_CHECK(config.max_token_dtype >= 0 && config.max_token_dtype < kNVTENumTypes,
+             "max_token_dtype out of range, got ", static_cast<int>(config.max_token_dtype));
+  const size_t elem_bytes = typeToSize(static_cast<DType>(config.max_token_dtype));
+  // Widen hidden_dim first so the row size is computed in size_t.
+  const size_t row_bytes = static_cast<size_t>(config.hidden_dim) * elem_bytes;
+  NVTE_CHECK(row_bytes >= 16,
+             "hidden_dim * sizeof(max_token_dtype) must be >= 16 (NCCL EP 16B row alignment); "
+             "got hidden_dim=",
+             config.hidden_dim, ", element_bytes=", elem_bytes);
+  // NCCL EP packs row size into ncclEpGroupConfig::max_token_bytes (unsigned int).
+  NVTE_CHECK(row_bytes <= static_cast<size_t>(UINT_MAX),
+             "hidden_dim * sizeof(max_token_dtype) exceeds 4 GiB; got ", row_bytes, " bytes");
+  NVTE_CHECK(config.num_experts % config.ep_size == 0, "num_experts (", config.num_experts,
+             ") must be divisible by ep_size (", config.ep_size, ")");
+  NVTE_CHECK(config.max_num_sms >= 0, "max_num_sms must be >= 0 (0 = auto), got ",
+             config.max_num_sms);
+
+  const int sm = cuda::sm_arch();
+  NVTE_CHECK(sm >= 90, "NCCL EP requires SM_90+ (Hopper or later), but current device is SM_", sm);
+}
+
+// Bootstrap the singleton from a borrowed communicator.
+// Order: reject double-init, check the runtime NCCL version, validate the
+// config, confirm the communicator spans exactly ep_size ranks, then init().
+// The singleton mutex is held for the whole sequence.
+void EPBackend::initialize(ncclComm_t ep_comm, NVTEEpGroupConfig config) {
+  EPBackend& inst = instance();
+  std::lock_guard<std::mutex> lock(inst.mutex_);
+  // shutdown() clears initialized_, so re-initialization after shutdown is allowed.
+  NVTE_CHECK(!inst.initialized_,
+             "EP already initialized. Call nvte_ep_shutdown() before re-initializing.");
+  NVTE_CHECK(ep_comm != nullptr, "ep_comm must not be null");
+
+  // Runtime gate: NCCL >= 2.30.4 (matches the submodule pin).
+  constexpr int kMinNcclVersion = 23004;
+  int nccl_version = 0;
+  NVTE_CHECK_NCCL(ncclGetVersion(&nccl_version));
+  NVTE_CHECK(nccl_version >= kMinNcclVersion, "NCCL EP requires NCCL >= 2.30.4, found ",
+             nccl_version / 10000, ".", (nccl_version / 100) % 100, ".", nccl_version % 100,
+             " at runtime.");
+
+  validate_config(config);
+
+  int comm_size = 0;
+  NVTE_CHECK_NCCL(ncclCommCount(ep_comm, &comm_size));
+  NVTE_CHECK(comm_size == config.ep_size, "ep_comm size (", comm_size, ") must equal ep_size (",
+             config.ep_size, "). Pass the EP sub-communicator, not the world comm.");
+
+  inst.init(ep_comm, config);
+}
+
+// Tear down the backend; a no-op when it is not initialized.
+// Destroys every cached handle, then the EP group, and forgets the borrowed
+// communicator without destroying it. Return codes of the NCCL EP destroy
+// calls are not checked.
+void EPBackend::shutdown() {
+  EPBackend& inst = instance();
+  std::lock_guard<std::mutex> lock(inst.mutex_);
+  if (!inst.initialized_) return;
+  for (auto& e : inst.lru_) {
+    if (e.handle != nullptr) ncclEpHandleDestroy(e.handle);
+  }
+  inst.lru_.clear();
+  inst.index_.clear();
+  inst.fallback_layer_cfg_.reset();
+  // ncclEpGroupDestroy reads from ep_comm_; destroy group while comm is still alive.
+  if (inst.ep_group_ != nullptr) {
+    ncclEpGroupDestroy(inst.ep_group_);
+    inst.ep_group_ = nullptr;
+  }
+  inst.ep_comm_ = nullptr;  // borrowed; caller destroys
+  inst.initialized_ = false;
+}
+
+// Open an NCCL EP handle over `handle_mem`, described to NCCL as a 1-D uint8
+// tensor of `handle_mem_size` elements, using the expert-major layout.
+// Fails via NVTE_CHECK_NCCL when ncclEpInitHandle reports an error.
+ncclEpHandle_t EPBackend::open_handle(void* handle_mem, size_t handle_mem_size, int num_topk,
+                                      size_t dispatch_output_per_expert_alignment) {
+  // routing_desc.sizes points at this local array; it only needs to live
+  // until ncclEpInitHandle returns below.
+  size_t hm_sizes[1] = {handle_mem_size};
+  ncclEpTensor_t routing_desc = NCCL_EP_TENSOR_INIT;
+  routing_desc.ndim = 1;
+  routing_desc.datatype = ncclUint8;
+  routing_desc.data = handle_mem;
+  routing_desc.sizes = hm_sizes;
+  ncclEpHandleConfig_t hcfg = NCCL_EP_HANDLE_CONFIG_INIT;
+  hcfg.dispatch_output_per_expert_alignment = dispatch_output_per_expert_alignment;
+  // Null-initialize so the out-parameter never holds an indeterminate value.
+  ncclEpHandle_t handle = nullptr;
+  NVTE_CHECK_NCCL(ncclEpInitHandle(&handle, ep_group_, NCCL_EP_LAYOUT_EXPERT_MAJOR, &hcfg, num_topk,
+                                   &routing_desc));
+  return handle;
+}
+
+// ---------------------------------------------------------------------------
+// Lifecycle
+// ---------------------------------------------------------------------------
+
+// Static-dtor teardown: skip NCCL calls (CUDA context / borrowed ep_comm_ may
+// already be gone) and release in-memory state only.
+EPBackend::~EPBackend() { + std::lock_guard lock(mutex_); + if (!initialized_) return; + lru_.clear(); + index_.clear(); + fallback_layer_cfg_.reset(); + ep_group_ = nullptr; + ep_comm_ = nullptr; + initialized_ = false; +} + +void EPBackend::init(ncclComm_t ep_comm, NVTEEpGroupConfig group_config) { + NVTE_CHECK(!initialized_, "EPBackend already initialized"); + + group_config_ = group_config; + + ncclEpGroupConfig_t cfg = NCCL_EP_GROUP_CONFIG_INIT; + cfg.algorithm = NCCL_EP_ALGO_HIGH_THROUGHPUT; + cfg.num_experts = static_cast(group_config.num_experts); + cfg.max_dispatch_tokens_per_rank = static_cast(group_config.max_tokens_per_rank); + const size_t elem_bytes = typeToSize(static_cast(group_config.max_token_dtype)); + cfg.max_token_bytes = static_cast(group_config.hidden_dim * elem_bytes); + cfg.rdma_buffer_size = NCCL_EP_AUTO; + cfg.num_qp_per_rank = NCCL_EP_AUTO; + cfg.num_channels = NCCL_EP_AUTO; + cfg.max_num_sms = group_config.max_num_sms > 0 + ? static_cast(group_config.max_num_sms) + : NCCL_EP_AUTO; + // Must be > 0; NCCL EP errors out on 0. + cfg.max_recv_tokens_per_rank = static_cast(group_config.max_recv_tokens_per_rank); + cfg.zero_copy = group_config.zero_copy ? NCCL_EP_ZERO_COPY_ON : NCCL_EP_ZERO_COPY_OFF; + + NVTE_CHECK_NCCL(ncclEpCreateGroup(&ep_group_, ep_comm, &cfg)); + + ep_comm_ = ep_comm; + + initialized_ = true; +} + +// --------------------------------------------------------------------------- +// Pointer-keyed LRU cache +// --------------------------------------------------------------------------- + +size_t EPBackend::cache_cap_locked() { + if (handle_cache_cap_ == 0) { + const char* cap_env = std::getenv("NVTE_EP_HANDLE_CACHE_SIZE"); + if (cap_env != nullptr) { + const int64_t v = static_cast(std::atol(cap_env)); + if (v < 0) { + // Unlimited cache. WAR for JAX until XLA fixes handle_mem + // reloc between runs. 
+ handle_cache_cap_ = SIZE_MAX; + } else { + NVTE_CHECK(v > 0, + "NVTE_EP_HANDLE_CACHE_SIZE=0 is invalid; use -1 for unlimited or a positive " + "cap."); + handle_cache_cap_ = static_cast(v); + } + } else { + handle_cache_cap_ = 4096; + } + } + return handle_cache_cap_; +} + +ncclEpHandle_t EPBackend::prepare_handle_locked(void* handle_mem, NVTEEpLayerConfig layer_cfg) { + // Update the program-wide fallback cfg so dispatch/combine/_bwd can + // reconstruct the handle on a pointer-cache miss (WAR for XLA buffer reloc + // between runs; one cfg per process). Remove this once XLA preserves the + // handle_mem device pointer across runs. + if (fallback_layer_cfg_.has_value()) { + NVTE_CHECK(fallback_layer_cfg_->top_k == layer_cfg.top_k, "EP prepare top_k=", layer_cfg.top_k, + " disagrees with process-wide cached top_k=", fallback_layer_cfg_->top_k); + NVTE_CHECK(fallback_layer_cfg_->dispatch_output_per_expert_alignment == + layer_cfg.dispatch_output_per_expert_alignment, + "EP prepare alignment=", layer_cfg.dispatch_output_per_expert_alignment, + " disagrees with process-wide cached alignment=", + fallback_layer_cfg_->dispatch_output_per_expert_alignment); + } else { + fallback_layer_cfg_ = layer_cfg; + } + + auto it = index_.find(handle_mem); + if (it != index_.end()) { + lru_.splice(lru_.begin(), lru_, it->second); + return it->second->handle; + } + ncclEpHandleConfig_t hcfg = NCCL_EP_HANDLE_CONFIG_INIT; + hcfg.dispatch_output_per_expert_alignment = layer_cfg.dispatch_output_per_expert_alignment; + size_t hm_size = 0; + NVTE_CHECK_NCCL(ncclEpHandleMemSize(ep_group_, NCCL_EP_LAYOUT_EXPERT_MAJOR, &hcfg, &hm_size, + layer_cfg.top_k)); + ncclEpHandle_t h = open_handle(handle_mem, hm_size, layer_cfg.top_k, + layer_cfg.dispatch_output_per_expert_alignment); + lru_.push_front(HandleEntry{handle_mem, h, layer_cfg, hm_size}); + index_.emplace(handle_mem, lru_.begin()); + while (lru_.size() > cache_cap_locked()) { + HandleEntry& victim = lru_.back(); + if (victim.handle != 
nullptr) ncclEpHandleDestroy(victim.handle); + index_.erase(victim.handle_mem); + lru_.pop_back(); + } + return h; +} + +ncclEpHandle_t EPBackend::lookup_handle_locked(void* handle_mem) { + auto it = index_.find(handle_mem); + if (it != index_.end()) { + lru_.splice(lru_.begin(), lru_, it->second); + return it->second->handle; + } + // Miss: reconstruct from the process-wide cached cfg. XLA may relocate + // handle_mem between runs, breaking the pointer key; the fallback cfg lets + // us open a fresh handle on the new buffer. Drop this branch once XLA + // preserves buffer pointers. + const uintptr_t hm_addr = reinterpret_cast(handle_mem); + NVTE_CHECK(fallback_layer_cfg_.has_value(), "ep op on handle_mem=0x", hm_addr, + " with no cached entry and no prior nvte_ep_prepare; call prepare first."); + return prepare_handle_locked(handle_mem, *fallback_layer_cfg_); +} + +// --------------------------------------------------------------------------- +// Per-step operations +// --------------------------------------------------------------------------- + +size_t EPBackend::handle_mem_size(NVTEEpLayerConfig layer_cfg) { + NVTE_CHECK(layer_cfg.top_k > 0, "top_k must be > 0, got ", layer_cfg.top_k); + std::lock_guard lock(mutex_); + NVTE_CHECK(initialized_, "EPBackend not initialized"); + ncclEpHandleConfig_t hcfg = NCCL_EP_HANDLE_CONFIG_INIT; + hcfg.dispatch_output_per_expert_alignment = layer_cfg.dispatch_output_per_expert_alignment; + size_t hm_size = 0; + NVTE_CHECK_NCCL(ncclEpHandleMemSize(ep_group_, NCCL_EP_LAYOUT_EXPERT_MAJOR, &hcfg, &hm_size, + layer_cfg.top_k)); + return hm_size; +} + +void EPBackend::prepare(void* handle_mem, const NVTETensor topk_idx, NVTETensor token_counts, + NVTEEpLayerConfig layer_cfg, cudaStream_t stream) { + NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); + NVTE_CHECK(layer_cfg.top_k > 0, "top_k must be > 0, got ", layer_cfg.top_k); + NVTE_CHECK(nvte_tensor_shape(topk_idx).ndim == 2, "topk_idx must be 2D [T, top_k]"); + + 
NVTEShape topk_idx_shape; + ncclEpTensor_t nccl_topk_idx = make_nccl_ep_tensor(topk_idx, topk_idx_shape); + + // ncclEpUpdateHandle writes per-expert counts via expert_counters. + NVTEShape token_counts_shape; + ncclEpTensor_t token_counts_desc; + if (token_counts != nullptr) { + token_counts_desc = make_nccl_ep_tensor(token_counts, token_counts_shape); + } + ncclEpLayoutInfo_t layout_info = NCCL_EP_LAYOUT_INFO_INIT; + layout_info.expert_counters = (token_counts != nullptr) ? &token_counts_desc : nullptr; + + std::lock_guard lock(mutex_); + NVTE_CHECK(initialized_, "EPBackend not initialized"); + ncclEpHandle_t h = prepare_handle_locked(handle_mem, layer_cfg); + NVTE_CHECK_NCCL(ncclEpUpdateHandle(h, &nccl_topk_idx, &layout_info, stream)); +} + +void EPBackend::dispatch(void* handle_mem, const NVTETensor topk_idx, const NVTETensor tokens, + const NVTECommWindow& tokens_win, const NVTETensor topk_weights, + const NVTECommWindow& topk_weights_win, NVTETensor recv_tokens, + const NVTECommWindow& recv_tokens_win, NVTETensor recv_topk_weights, + const NVTECommWindow& recv_topk_weights_win, cudaStream_t stream) { + NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); + NVTE_CHECK(nvte_tensor_shape(tokens).ndim == 2, "tokens must be 2D [T, hidden_dim]"); + NVTE_CHECK(nvte_tensor_shape(recv_tokens).ndim == 2, + "recv_tokens must be 2D [recv_T, hidden_dim]"); + + NVTEDType tok_dtype = nvte_tensor_type(tokens); + NVTE_CHECK(typeToSize(static_cast(tok_dtype)) <= + typeToSize(static_cast(group_config_.max_token_dtype)), + "tokens dtype (", static_cast(tok_dtype), ") wider than group max_token_dtype (", + static_cast(group_config_.max_token_dtype), ")"); + NVTEDType recv_dtype = nvte_tensor_type(recv_tokens); + NVTE_CHECK(typeToSize(static_cast(recv_dtype)) <= + typeToSize(static_cast(group_config_.max_token_dtype)), + "recv_tokens dtype (", static_cast(recv_dtype), + ") wider than group max_token_dtype (", + static_cast(group_config_.max_token_dtype), ")"); + + 
NVTEShape tokens_shape, recv_tokens_shape; + ncclEpTensor_t nccl_tokens_in = make_nccl_ep_tensor(tokens, tokens_shape, tokens_win); + ncclEpTensor_t nccl_tokens_out = + make_nccl_ep_tensor(recv_tokens, recv_tokens_shape, recv_tokens_win); + + // Routing is cached in handle_mem by ep_prepare; dispatch only needs + // topk_weights to reconstruct the sparse-to-dense prob map. + const bool is_forward = (topk_weights != nullptr); + NVTEShape topk_weights_shape, recv_topk_weights_shape; + ncclEpTensor_t nccl_topk_weights_in; + ncclEpTensor_t nccl_topk_weights_out; + if (is_forward) { + NVTE_CHECK(topk_idx != nullptr, "topk_idx required in forward dispatch"); + NVTE_CHECK(nvte_tensor_shape(topk_idx).ndim == 2, "topk_idx must be 2D [T, top_k]"); + NVTE_CHECK(nvte_tensor_shape(topk_weights).ndim == 2, "topk_weights must be 2D [T, top_k]"); + NVTE_CHECK(recv_topk_weights != nullptr, + "recv_topk_weights must not be null in forward dispatch"); + NVTE_CHECK(nvte_tensor_shape(recv_topk_weights).ndim == 1, + "recv_topk_weights must be 1D [recv_capacity]"); + nccl_topk_weights_in = make_nccl_ep_tensor(topk_weights, topk_weights_shape, topk_weights_win); + nccl_topk_weights_out = + make_nccl_ep_tensor(recv_topk_weights, recv_topk_weights_shape, recv_topk_weights_win); + } + + ncclEpDispatchInputs_t in_struct = NCCL_EP_DISPATCH_INPUTS_INIT; + in_struct.tokens = &nccl_tokens_in; + in_struct.topk_weights = is_forward ? &nccl_topk_weights_in : nullptr; + + ncclEpDispatchOutputs_t out_struct = NCCL_EP_DISPATCH_OUTPUTS_INIT; + out_struct.tokens = &nccl_tokens_out; + out_struct.topk_weights = is_forward ? &nccl_topk_weights_out : nullptr; + + ncclEpDispatchConfig_t dispatch_cfg = NCCL_EP_DISPATCH_CONFIG_INIT; + dispatch_cfg.pass_direction = is_forward ? 
NCCL_EP_FWD_PASS : NCCL_EP_BWD_PASS; + + std::lock_guard lock(mutex_); + NVTE_CHECK(initialized_, "EPBackend not initialized"); + ncclEpHandle_t h = lookup_handle_locked(handle_mem); + NVTE_CHECK_NCCL(ncclEpDispatch(h, &in_struct, &out_struct, + /*layout_info=*/nullptr, &dispatch_cfg, stream)); +} + +void EPBackend::combine(void* handle_mem, const NVTETensor expert_out, + const NVTECommWindow& expert_out_win, NVTETensor result, + cudaStream_t stream) { + NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); + NVTE_CHECK(nvte_tensor_shape(expert_out).ndim == 2, "expert_out must be 2D [recv_T, hidden_dim]"); + NVTE_CHECK(nvte_tensor_shape(result).ndim == 2, "result must be 2D [T, hidden_dim]"); + + NVTEShape expert_out_shape, result_shape; + ncclEpTensor_t nccl_expert_in = make_nccl_ep_tensor(expert_out, expert_out_shape, expert_out_win); + ncclEpTensor_t nccl_result_out = make_nccl_ep_tensor(result, result_shape); + + ncclEpCombineInputs_t in_struct = NCCL_EP_COMBINE_INPUTS_INIT; + in_struct.tokens = &nccl_expert_in; + + ncclEpCombineOutputs_t out_struct = NCCL_EP_COMBINE_OUTPUTS_INIT; + out_struct.tokens = &nccl_result_out; + + std::lock_guard lock(mutex_); + NVTE_CHECK(initialized_, "EPBackend not initialized"); + ncclEpHandle_t h = lookup_handle_locked(handle_mem); + NVTE_CHECK_NCCL(ncclEpCombine(h, &in_struct, &out_struct, /*config=*/nullptr, stream)); +} + +void EPBackend::dispatch_bwd(void* handle_mem, const NVTETensor grad, + const NVTECommWindow& grad_win, const NVTETensor g_recv_topk_weights, + const NVTECommWindow& g_recv_topk_weights_win, NVTETensor grad_tokens, + NVTETensor grad_topk_weights, cudaStream_t stream) { + NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); + NVTE_CHECK(nvte_tensor_shape(grad).ndim == 2, "grad must be 2D [recv_capacity, hidden_dim]"); + NVTE_CHECK(nvte_tensor_shape(grad_tokens).ndim == 2, "grad_tokens must be 2D [T, hidden_dim]"); + + // g_recv_topk_weights must be 1D [recv_capacity]; caller flattens. 
+ NVTE_CHECK(nvte_tensor_shape(g_recv_topk_weights).ndim == 1, + "g_recv_topk_weights must be 1D [recv_capacity]; caller must flatten leading dims"); + NVTE_CHECK(nvte_tensor_shape(grad_topk_weights).ndim == 2, + "grad_topk_weights must be 2D [T, top_k]"); + + NVTEShape grad_shape, g_recv_w_shape, grad_tokens_shape, grad_w_shape; + ncclEpTensor_t nccl_tok_in = make_nccl_ep_tensor(grad, grad_shape, grad_win); + ncclEpTensor_t nccl_w_in = + make_nccl_ep_tensor(g_recv_topk_weights, g_recv_w_shape, g_recv_topk_weights_win); + ncclEpTensor_t nccl_tok_out = make_nccl_ep_tensor(grad_tokens, grad_tokens_shape); + ncclEpTensor_t nccl_w_out = make_nccl_ep_tensor(grad_topk_weights, grad_w_shape); + + ncclEpCombineInputs_t in_struct = NCCL_EP_COMBINE_INPUTS_INIT; + in_struct.tokens = &nccl_tok_in; + in_struct.topk_weights = &nccl_w_in; + + ncclEpCombineOutputs_t out_struct = NCCL_EP_COMBINE_OUTPUTS_INIT; + out_struct.tokens = &nccl_tok_out; + out_struct.topk_weights = &nccl_w_out; + + ncclEpCombineConfig_t cfg = NCCL_EP_COMBINE_CONFIG_INIT; + cfg.pass_direction = NCCL_EP_BWD_PASS; + + std::lock_guard lock(mutex_); + NVTE_CHECK(initialized_, "EPBackend not initialized"); + ncclEpHandle_t h = lookup_handle_locked(handle_mem); + NVTE_CHECK_NCCL(ncclEpCombine(h, &in_struct, &out_struct, &cfg, stream)); +} + +void EPBackend::combine_bwd(void* handle_mem, const NVTETensor grad, const NVTECommWindow& grad_win, + NVTETensor grad_expert_out, const NVTECommWindow& grad_expert_out_win, + cudaStream_t stream) { + // Backward of combine = reverse-direction dispatch. 
+ dispatch(handle_mem, /*topk_idx=*/nullptr, grad, grad_win, + /*topk_weights=*/nullptr, /*topk_weights_win=*/NVTECommWindow{}, grad_expert_out, + grad_expert_out_win, + /*recv_topk_weights=*/nullptr, /*recv_topk_weights_win=*/NVTECommWindow{}, stream); +} + +} // namespace ep +} // namespace transformer_engine diff --git a/transformer_engine/common/ep/ep_backend.h b/transformer_engine/common/ep/ep_backend.h new file mode 100644 index 0000000000..2325baafca --- /dev/null +++ b/transformer_engine/common/ep/ep_backend.h @@ -0,0 +1,116 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep_backend.h + * \brief Internal NCCL EP singleton; not part of the public API. See ep.h. + */ + +#ifndef TRANSFORMER_ENGINE_COMMON_EP_EP_BACKEND_H_ +#define TRANSFORMER_ENGINE_COMMON_EP_EP_BACKEND_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace transformer_engine { +namespace ep { + +/*! \brief EP backend singleton; owns the NCCL EP group, borrows the comm. */ +class EPBackend { + public: + /*! \brief Access the singleton. Aborts if not initialized. */ + static EPBackend& get(); + + /*! \brief Bootstrap from an existing EP sub-communicator. + * ep_comm is borrowed; the caller keeps it alive until shutdown() returns + * and must span exactly config.ep_size ranks. + */ + static void initialize(ncclComm_t ep_comm, NVTEEpGroupConfig config); + + /*! \brief Tear down the backend. Idempotent. Does not destroy ep_comm_. */ + static void shutdown(); + + // Host-only: report handle_mem byte size for layer_cfg. + size_t handle_mem_size(NVTEEpLayerConfig layer_cfg); + + // Seeds the cache for handle_mem with layer_cfg and runs the routing AllGather. 
+ void prepare(void* handle_mem, const NVTETensor topk_idx, NVTETensor token_counts, + NVTEEpLayerConfig layer_cfg, cudaStream_t stream); + + // Per-step ops below require a prior prepare(). + void dispatch(void* handle_mem, const NVTETensor topk_idx, const NVTETensor tokens, + const NVTECommWindow& tokens_win, const NVTETensor topk_weights, + const NVTECommWindow& topk_weights_win, NVTETensor recv_tokens, + const NVTECommWindow& recv_tokens_win, NVTETensor recv_topk_weights, + const NVTECommWindow& recv_topk_weights_win, cudaStream_t stream); + + void combine(void* handle_mem, const NVTETensor expert_out, const NVTECommWindow& expert_out_win, + NVTETensor result, cudaStream_t stream); + + // g_recv_topk_weights: 1D [recv_capacity] f32; grad_topk_weights: 2D [T, top_k] f32. + void dispatch_bwd(void* handle_mem, const NVTETensor grad, const NVTECommWindow& grad_win, + const NVTETensor g_recv_topk_weights, + const NVTECommWindow& g_recv_topk_weights_win, NVTETensor grad_tokens, + NVTETensor grad_topk_weights, cudaStream_t stream); + + void combine_bwd(void* handle_mem, const NVTETensor grad, const NVTECommWindow& grad_win, + NVTETensor grad_expert_out, const NVTECommWindow& grad_expert_out_win, + cudaStream_t stream); + + private: + EPBackend() = default; + ~EPBackend(); + EPBackend(const EPBackend&) = delete; + EPBackend& operator=(const EPBackend&) = delete; + + // ep_comm is borrowed; caller retains ownership across the backend lifetime. + void init(ncclComm_t ep_comm, NVTEEpGroupConfig config); + + static EPBackend& instance(); // Meyers singleton accessor + static void validate_config(const NVTEEpGroupConfig& config); + + // Open a fresh ncclEpHandle over handle_mem. num_topk=-1 for paths + // that don't carry per-token weights. + ncclEpHandle_t open_handle(void* handle_mem, size_t handle_mem_size, int num_topk, + size_t dispatch_output_per_expert_alignment); + + // LRU cache: most-recently-used at the front of lru_, evict from the back. 
+ struct HandleEntry { + void* handle_mem; + ncclEpHandle_t handle; + NVTEEpLayerConfig layer_cfg; + size_t handle_mem_size; + }; + + ncclEpGroup_t ep_group_{nullptr}; + ncclComm_t ep_comm_{nullptr}; + NVTEEpGroupConfig group_config_{}; + std::atomic initialized_{false}; + std::mutex mutex_; + std::list lru_; + std::unordered_map::iterator> index_; + size_t handle_cache_cap_{0}; // set lazily from NVTE_EP_HANDLE_CACHE_SIZE + std::optional fallback_layer_cfg_; + + // Caller must hold mutex_. + ncclEpHandle_t prepare_handle_locked(void* handle_mem, NVTEEpLayerConfig layer_cfg); + ncclEpHandle_t lookup_handle_locked(void* handle_mem); + size_t cache_cap_locked(); +}; + +} // namespace ep +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMMON_EP_EP_BACKEND_H_ diff --git a/transformer_engine/common/include/transformer_engine/comm_window.h b/transformer_engine/common/include/transformer_engine/comm_window.h new file mode 100644 index 0000000000..424c350bbd --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/comm_window.h @@ -0,0 +1,35 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file comm_window.h + * \brief NCCL symmetric-memory window handle for zero-copy ops. Pass + * {NULL, 0} to use the raw-pointer path. + */ + +#ifndef TRANSFORMER_ENGINE_COMM_WINDOW_H_ +#define TRANSFORMER_ENGINE_COMM_WINDOW_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/* Forward-declare NCCL's opaque window struct so this header does not pull in + * <nccl.h>; matches NCCL's typedef (struct ncclWindow_vidmem* ncclWindow_t). */ +struct ncclWindow_vidmem; + +/*! \brief NCCL window plus byte offset for a zero-copy payload tensor.
*/ +typedef struct { + struct ncclWindow_vidmem* window; /*!< NCCL window, or NULL to use the raw data pointer. */ + uint64_t offset; /*!< Byte offset of the payload within window. */ +} NVTECommWindow; + +#ifdef __cplusplus +} +#endif + +#endif // TRANSFORMER_ENGINE_COMM_WINDOW_H_ diff --git a/transformer_engine/common/include/transformer_engine/ep.h b/transformer_engine/common/include/transformer_engine/ep.h new file mode 100644 index 0000000000..8928b92825 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/ep.h @@ -0,0 +1,203 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep.h + * \brief Public C API for Expert Parallelism. Per-step ops are + * allocation-free and CUDA graph-capturable. + * + * Per layer: call nvte_ep_handle_mem_size(layer_cfg) for the buffer size; + * allocate handle_mem as a kByte NVTETensor. Per step: nvte_ep_prepare seeds + * routing, then nvte_ep_dispatch / nvte_ep_combine and their _bwd variants consume it. + * Cache cap: NVTE_EP_HANDLE_CACHE_SIZE (default 4096; -1 disables eviction). + */ + +#ifndef TRANSFORMER_ENGINE_EP_H_ +#define TRANSFORMER_ENGINE_EP_H_ + +#include +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/* -- Config structs ------------------------------------------------------- */ +/* TODO: add a struct_size/version field to these configs (and align with other + * TE public structs) once a TE-wide convention for ABI versioning lands. */ + +/*! \brief Group-level EP configuration (fixed for the EP group lifetime). */ +typedef struct { + /*! EP world size. */ + int ep_size; + /*! Total experts across all ranks. */ + int num_experts; + /*! Upper bound on tokens this rank sends per dispatch. */ + int max_tokens_per_rank; + /*! 
Upper bound on tokens this rank receives per dispatch (must be > 0). */ + int max_recv_tokens_per_rank; + /*! Token hidden dimension. */ + int hidden_dim; + /*! Max SMs for EP kernels. 0 = auto. */ + int max_num_sms; + /*! Widest token dtype the group will dispatch; sizes staging buffers. + * Per-dispatch tensors may use any dtype with element size <= this. */ + NVTEDType max_token_dtype; + /*! Zero-copy dispatch/combine. When nonzero, payload tensors must be backed + * by NVTECommWindow handles and transfer in place (no staging copies); + * 0 (default) = staged. */ + int zero_copy; +} NVTEEpGroupConfig; + +/*! \brief Per-layer configuration consumed by nvte_ep_handle_mem_size and + * nvte_ep_prepare. Reserved for future per-call options (fp8 scale, + * overflow policy, ...). + */ +typedef struct { + /*! Per-token expert fan-out (> 0). */ + int top_k; + /*! Per-expert recv-slab alignment in tokens (power of two; 0/1 disables). + * When > 1, each expert's slab in recv_tokens is zero-padded up to a + * multiple of this for downstream per-expert GEMM alignment. */ + size_t dispatch_output_per_expert_alignment; +} NVTEEpLayerConfig; + +/* -- Bootstrap ------------------------------------------------------------ */ + +/*! \brief Bootstrap the EP backend from an existing NCCL EP sub-communicator. + * Requires SM>=90. + * + * ep_comm is borrowed and must span exactly group_config.ep_size ranks. The + * caller retains ownership and must keep it alive until nvte_ep_shutdown() + * returns. Re-init after shutdown is allowed; double-init throws. One EP + * group per process, bound to the current CUDA device. + * + * \param[in] ep_comm Opaque ncclComm_t for the EP sub-group (borrowed, not owned). + * \param[in] group_config Group-level EP configuration. + */ +void nvte_ep_initialize(void* ep_comm, NVTEEpGroupConfig group_config); + +/*! \brief Tear down the EP backend. Idempotent. Does not destroy ep_comm. 
*/ +void nvte_ep_shutdown(void); + +/* -- Layer sizing (host-only) --------------------------------------------- */ + +/*! \brief Report the handle_mem byte size required for the given layer config. + * + * handle_mem is a per-layer kByte routing-state buffer; allocate once and + * thread the same pointer through every prepare/dispatch/combine/_bwd call + * for that layer (the backend keys its cache on the pointer). Host-only; + * size is stable for a given (group, layer) pair. + * + * \param[in] layer_cfg Per-layer configuration. + * \return size in bytes for the handle_mem buffer. + */ +size_t nvte_ep_handle_mem_size(NVTEEpLayerConfig layer_cfg); + +/* -- Per-step ops (all allocation-free, CUDA graph-capturable) ------------ */ + +/*! \brief Seed handle_mem with this step's routing plan. + * + * AllGathers topk_idx across the EP group and stages per-expert offsets and + * counts into handle_mem so the matching dispatch/combine/_bwd can run with + * no further routing computation. Must precede every dispatch/combine/_bwd + * that uses this handle_mem. token_counts becomes host-valid after a stream + * sync. + * + * \param[in,out] handle_mem uint8 routing-state buffer; seeded by this call. + * \param[in] topk_idx [T, top_k] int64 routing indices. + * \param[out] token_counts [num_local_experts] int32 counts. + * \param[in] layer_cfg Per-layer configuration. + * \param[in] stream CUDA stream. + */ +void nvte_ep_prepare(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor token_counts, + NVTEEpLayerConfig layer_cfg, cudaStream_t stream); + +/*! \brief Dispatch tokens (and routing weights) to expert ranks. + * + * Each local token is sent to its top_k destinations; recv_tokens is laid out + * expert-major (contiguous per-expert slabs, padded per layer_cfg). The + * *_win arguments enable zero-copy via symmem windows; pass a zero-initialized + * NVTECommWindow (window == NULL) when unused. Requires a prior nvte_ep_prepare on this handle_mem. 
+ * + * \param[in] handle_mem uint8 routing-state buffer (from prepare). 
+ * \param[in] topk_idx [T, top_k] int64 sparse routing indices. + * \param[in] tokens [T, hidden_dim] input tokens. + * \param[in] tokens_win Optional symmem window for tokens. + * \param[in] topk_weights [T, top_k] float32 weights, or null in backward. + * \param[in] topk_weights_win Optional symmem window for topk_weights. + * \param[out] recv_tokens [recv_T, hidden_dim] received tokens. + * \param[in] recv_tokens_win Optional symmem window for recv_tokens. + * \param[out] recv_topk_weights [recv_T] float32 per-slot weights, or null in backward. + * \param[in] recv_topk_weights_win Optional symmem window for recv_topk_weights. + * \param[in] stream CUDA stream. + */ +void nvte_ep_dispatch(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor tokens, + NVTECommWindow tokens_win, NVTETensor topk_weights, + NVTECommWindow topk_weights_win, NVTETensor recv_tokens, + NVTECommWindow recv_tokens_win, NVTETensor recv_topk_weights, + NVTECommWindow recv_topk_weights_win, cudaStream_t stream); + +/*! \brief Scatter-sum expert outputs back to originating ranks. + * + * Inverse of dispatch: the top_k destination slots for token t are summed + * into result[t]. Sums are unweighted; pre-scale expert_out by + * recv_topk_weights (and the valid-slot mask) before calling. Requires a + * prior nvte_ep_prepare on this handle_mem. + * + * \param[in] handle_mem uint8 routing-state buffer (from prepare). + * \param[in] expert_out [recv_T, hidden_dim] pre-weighted expert outputs. + * \param[in] expert_out_win Optional symmem window for expert_out. + * \param[out] result [T, hidden_dim] combined output. + * \param[in] stream CUDA stream. + */ +void nvte_ep_combine(NVTETensor handle_mem, NVTETensor expert_out, NVTECommWindow expert_out_win, + NVTETensor result, cudaStream_t stream); + +/*! \brief Backward of dispatch: route per-recv-slot grads back to source. 
+ * + * Sums the top_k recv-slot grads into grad_tokens[t]; scatters per-slot + * recv-weight grads into grad_topk_weights[t, k]. Padded recv slots + * contribute nothing. Requires a prior nvte_ep_prepare on this handle_mem. + * + * \param[in] handle_mem uint8 routing-state buffer (from prepare). + * \param[in] grad [recv_capacity, hidden_dim] grad w.r.t. recv_tokens. + * \param[in] grad_win Optional symmem window for grad. + * \param[in] g_recv_topk_weights [recv_capacity] float32 grad w.r.t. recv_topk_weights. + * \param[in] g_recv_topk_weights_win Optional symmem window for g_recv_topk_weights. + * \param[out] grad_tokens [T, hidden_dim] grad w.r.t. tokens. + * \param[out] grad_topk_weights [T, top_k] float32 grad w.r.t. topk_weights. + * \param[in] stream CUDA stream. + */ +void nvte_ep_dispatch_bwd(NVTETensor handle_mem, NVTETensor grad, NVTECommWindow grad_win, + NVTETensor g_recv_topk_weights, NVTECommWindow g_recv_topk_weights_win, + NVTETensor grad_tokens, NVTETensor grad_topk_weights, + cudaStream_t stream); + +/*! \brief Backward of combine: replicate each source-token grad to its recv + * slots from the forward. + * + * Padded recv slots in grad_expert_out are zeroed. Requires a prior + * nvte_ep_prepare on this handle_mem. + * + * \param[in] handle_mem uint8 routing-state buffer (from prepare). + * \param[in] grad [T, hidden_dim] grad w.r.t. result. + * \param[in] grad_win Optional symmem window for grad. + * \param[out] grad_expert_out [recv_capacity, hidden_dim] grad w.r.t. expert_out. + * \param[in] grad_expert_out_win Optional symmem window for grad_expert_out. + * \param[in] stream CUDA stream. + */ +void nvte_ep_combine_bwd(NVTETensor handle_mem, NVTETensor grad, NVTECommWindow grad_win, + NVTETensor grad_expert_out, NVTECommWindow grad_expert_out_win, + cudaStream_t stream); + +#ifdef __cplusplus +} +#endif + +#endif // TRANSFORMER_ENGINE_EP_H_