Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@
[submodule "3rdparty/cutlass"]
path = 3rdparty/cutlass
url = https://github.com/NVIDIA/cutlass.git
[submodule "3rdparty/nccl"]
path = 3rdparty/nccl
url = https://github.com/NVIDIA/nccl.git
1 change: 1 addition & 0 deletions 3rdparty/nccl
Submodule nccl added at 808d24
46 changes: 41 additions & 5 deletions qa/L1_cpp_distributed/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,52 @@
#
# See LICENSE for license information.

set -e
# Record a failed sub-test without aborting the run.
#   $1 - name of the sub-test that failed
# Side effects: sets RET to 1, appends the name to FAILED_CASES, and prints
# an error line to stdout.
test_fail() {
    local failed_case="$1"
    echo "Error: sub-test failed: ${failed_case}"
    FAILED_CASES="${FAILED_CASES} ${failed_case}"
    RET=1
}

RET=0
FAILED_CASES=""

# Find TE
: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"

TE_LIB_PATH=$(pip3 show transformer-engine | grep -E "Location:|Editable project location:" | tail -n 1 | awk '{print $NF}')
export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH

if [[ $(nvidia-smi --list-gpus | wc -l) -ge 4 ]]; then
cd $TE_PATH/tests/cpp_distributed
cmake -GNinja -S. -Bbuild
cmake --build build
mpirun --allow-run-as-root --np 4 --oversubscribe ./build/test_comm_gemm
cd $TE_PATH/tests/cpp_distributed
configure_ok=1
cmake -GNinja -S. -Bbuild || { test_fail "configure"; configure_ok=0; }

# Build each suite independently so one suite's build failure does not
# mask the other's results. Skip mpirun when the binary is missing.
if [[ $configure_ok -eq 1 ]]; then
comm_gemm_ok=1
ep_ok=1
cmake --build build --target test_comm_gemm || { test_fail "test_comm_gemm_build"; comm_gemm_ok=0; }
cmake --build build --target test_ep || { test_fail "test_ep_build"; ep_ok=0; }

if [[ $comm_gemm_ok -eq 1 ]]; then
# Per-rank XML to avoid a write race on a shared path.
mpirun --allow-run-as-root --np 4 --oversubscribe bash -c \
"exec ./build/test_comm_gemm --gtest_output=xml:$XML_LOG_DIR/cpp_distributed_test_comm_gemm.rank\${OMPI_COMM_WORLD_RANK}.xml" \
|| test_fail "test_comm_gemm"
fi

if [[ $ep_ok -eq 1 ]]; then
# EP suites; runner self-skips on pre-Hopper GPUs.
GTEST_XML_PREFIX="$XML_LOG_DIR/cpp_distributed_test_ep" \
bash ./run_test_ep.sh 4 ./build || test_fail "test_ep"
fi
fi
fi

if [ "$RET" -ne 0 ]; then
echo "FAILED sub-tests:$FAILED_CASES"
fi
exit $RET
157 changes: 157 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,31 @@ def setup_common_extension() -> CMakeExtension:
cusolvermp_dir = os.getenv("CUSOLVERMP_HOME", "/usr")
cmake_flags.append(f"-DCUSOLVERMP_DIR={cusolvermp_dir}")

# NCCL EP (Hopper+): on by default; auto-skipped when no arch >= 90 is
# targeted. Set NVTE_WITH_NCCL_EP=0 to force off.
nccl_ep_env = os.getenv("NVTE_WITH_NCCL_EP")
nccl_ep_explicit = nccl_ep_env is not None
build_with_nccl_ep = bool(int(nccl_ep_env if nccl_ep_explicit else "1"))
if build_with_nccl_ep:
arch_tokens = [a.strip() for a in str(archs or "").split(";") if a.strip()]
has_hopper_or_newer = any(
t.lower() == "native" or (t.rstrip("af").isdigit() and int(t.rstrip("af")) >= 90)
for t in arch_tokens
)
if not has_hopper_or_newer:
if nccl_ep_explicit:
raise RuntimeError(
f"NVTE_WITH_NCCL_EP=1 was set but NVTE_CUDA_ARCHS ('{archs}') "
"contains no arch >= 90. NCCL EP requires Hopper or newer."
)
print(f"[NCCL EP] No arch >= 90 in NVTE_CUDA_ARCHS ('{archs}'); skipping build.")
build_with_nccl_ep = False
if build_with_nccl_ep:
nccl_home = build_nccl_ep_submodule()
cmake_flags.append(f"-DNCCL_INCLUDE_DIR={nccl_home}/include")
else:
cmake_flags.append("-DNVTE_WITH_NCCL_EP=OFF")

# Add custom CMake arguments from environment variable
nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
if nvte_cmake_extra_args:
Expand Down Expand Up @@ -130,6 +155,138 @@ def setup_requirements() -> Tuple[List[str], List[str]]:
return [remove_dups(reqs) for reqs in [install_reqs, test_reqs]]


def _discover_nccl_home() -> str:
"""Resolve NCCL_HOME: honor env var, else probe well-known prefixes, else ldconfig."""
env_home = os.environ.get("NCCL_HOME")
if env_home:
if (Path(env_home) / "include" / "nccl.h").exists():
return env_home
print(
f"[NCCL EP] WARNING: NCCL_HOME='{env_home}' is set but "
f"'{env_home}/include/nccl.h' was not found; falling back to system probes."
)

lib_names = ("libnccl.so", "libnccl.so.2")
# Include Debian/Ubuntu multiarch subdirs (e.g. lib/aarch64-linux-gnu).
lib_subdirs = ("lib", "lib64", "lib/aarch64-linux-gnu", "lib/x86_64-linux-gnu")

# pip-installed NCCL (nvidia-nccl-cu* wheel) lives under nvidia/nccl in
# site-packages and has no top-level include/lib layout.
try:
import importlib.util

spec = importlib.util.find_spec("nvidia.nccl")
if spec is not None and spec.submodule_search_locations:
pip_root = Path(next(iter(spec.submodule_search_locations)))
if (pip_root / "include" / "nccl.h").exists() and any(
(pip_root / sub / name).exists() for sub in lib_subdirs for name in lib_names
):
return str(pip_root)
except (ImportError, ValueError):
pass

for cand in ("/opt/nvidia/nccl", "/usr/local/nccl", "/usr"):
p = Path(cand)
if (p / "include" / "nccl.h").exists() and any(
(p / sub / name).exists() for sub in lib_subdirs for name in lib_names
):
return str(p)

try:
out = subprocess.check_output(["ldconfig", "-p"], stderr=subprocess.DEVNULL).decode()
for line in out.splitlines():
if "libnccl.so" in line and "=>" in line:
lib_path = Path(line.split("=>")[-1].strip())
# Walk upward so multiarch layouts (.../lib/<triplet>/libnccl.so)
# resolve to the prefix that contains include/nccl.h.
for root in (lib_path.parent.parent, lib_path.parent.parent.parent):
if (root / "include" / "nccl.h").exists():
return str(root)
except (subprocess.CalledProcessError, FileNotFoundError):
pass

raise RuntimeError(
"Could not locate NCCL core (nccl.h + libnccl.so). Set NCCL_HOME to the install prefix."
)


def _lock_nccl_ep_build(lock_dir: Path) -> "int | None":
    """Take an exclusive advisory ``flock`` on *lock_dir*.

    Returns the open file descriptor holding the lock (closing it releases the
    lock), or ``None`` when locking is unavailable (no ``fcntl`` module, the
    directory cannot be opened, or the filesystem rejects the lock).
    """
    try:
        import fcntl
    except ImportError:
        return None
    try:
        lock_fd = os.open(str(lock_dir), os.O_RDONLY)
    except OSError:
        return None
    try:
        fcntl.flock(lock_fd, fcntl.LOCK_EX)  # blocks until the current holder finishes
    except OSError:
        os.close(lock_fd)
        return None
    return lock_fd


def build_nccl_ep_submodule() -> str:
    """Build libnccl_ep.a from the 3rdparty/nccl submodule and return NCCL_HOME.

    The archive is (re)built only when it is missing or when the gencode flags
    recorded in the ``libnccl_ep.gencode`` stamp differ from the ones derived
    from ``cuda_archs()``; a gencode change triggers ``make clean`` first.

    Returns:
        The NCCL core install prefix found by ``_discover_nccl_home()``.

    Raises:
        RuntimeError: If the submodule is not checked out, if ``native`` was
            requested but ``nvidia-smi`` cannot be run, or if no requested
            arch is >= 90.
        subprocess.CalledProcessError: If a ``make`` invocation fails.
    """
    nccl_root = current_file_path / "3rdparty" / "nccl"
    if not (nccl_root / "Makefile").exists():
        raise RuntimeError(
            f"NCCL submodule not found at {nccl_root}. "
            "Run `git submodule update --init --recursive`."
        )

    build_dir = nccl_root / "build"
    nccl_ep_lib = build_dir / "lib" / "libnccl_ep.a"
    gencode_stamp = build_dir / "lib" / "libnccl_ep.gencode"

    # Caller gates on arch >= 90 or "native"; expand "native" to the host's
    # actual sm_XX so the build stamp distinguishes machines.
    arch_tokens = [a.strip() for a in str(cuda_archs() or "").split(";") if a.strip()]
    arch_list: list[str] = []
    for t in arch_tokens:
        if t.lower() == "native":
            try:
                out = subprocess.check_output(
                    ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
                    stderr=subprocess.DEVNULL,
                ).decode()
            except (subprocess.CalledProcessError, FileNotFoundError) as e:
                raise RuntimeError(
                    "NVTE_CUDA_ARCHS=native requires nvidia-smi to resolve the host arch."
                ) from e
            for line in out.splitlines():
                # "9.0" -> "90"; keep only Hopper-or-newer, de-duplicated.
                cap = line.strip().replace(".", "")
                if cap.isdigit() and int(cap) >= 90 and cap not in arch_list:
                    arch_list.append(cap)
        else:
            # Drop the trailing arch-specific suffix letters ("90a" -> "90").
            bare = t.rstrip("af")
            if bare.isdigit() and int(bare) >= 90 and bare not in arch_list:
                arch_list.append(bare)

    if not arch_list:
        raise RuntimeError(
            "NCCL EP requires Hopper or newer (SM >= 90); none found in"
            f" NVTE_CUDA_ARCHS={cuda_archs()!r}. Re-run with NVTE_WITH_NCCL_EP=0 to skip the NCCL"
            " EP build (the rest of TE still builds)."
        )
    gencode = " ".join(f"-gencode=arch=compute_{a},code=sm_{a}" for a in arch_list)

    nproc = os.cpu_count() or 8
    env = os.environ.copy()
    env["NVCC_GENCODE"] = gencode
    # NCCL EP needs the core NCCL headers + libnccl.so; write NCCL EP build
    # outputs to the submodule's local build/ tree.
    nccl_home = _discover_nccl_home()
    env["NCCL_HOME"] = nccl_home
    env["NCCL_EP_BUILDDIR"] = str(build_dir)

    # Checking the stamp, running make and writing the stamp are three separate
    # steps. Without serialisation, concurrent installs (e.g. several ranks or
    # parallel pip jobs sharing one checkout) can all see a stale stamp, all
    # run `make` in the same tree, and overwrite each other's stamp. Hold an
    # exclusive lock across the whole sequence; a process that waited then
    # re-reads the stamp and skips the build if it is already current. The
    # lock is taken on the submodule directory itself so `make clean` cannot
    # remove it and no extra file is left in the tree.
    lock_fd = _lock_nccl_ep_build(nccl_root)
    if lock_fd is None:
        print(
            "[NCCL EP] NOTE: could not lock the NCCL submodule directory; "
            "building without cross-process serialization."
        )
    try:
        prev_gencode = gencode_stamp.read_text().strip() if gencode_stamp.exists() else None
        if not nccl_ep_lib.exists() or prev_gencode != gencode:
            if nccl_ep_lib.exists() and prev_gencode != gencode:
                print(
                    f"[NCCL EP] gencode changed ('{prev_gencode}' -> '{gencode}'); "
                    "rebuilding libnccl_ep.a"
                )
                subprocess.check_call(
                    ["make", "-C", "contrib/nccl_ep", "clean"],
                    cwd=str(nccl_root),
                    env=env,
                )
            print(f"[NCCL EP] Building libnccl_ep.a (gencode='{gencode}')")
            subprocess.check_call(
                ["make", "-j", str(nproc), "-C", "contrib/nccl_ep", "lib"],
                cwd=str(nccl_root),
                env=env,
            )
            # Only stamp after a successful build so a failed make is retried.
            gencode_stamp.parent.mkdir(parents=True, exist_ok=True)
            gencode_stamp.write_text(gencode)
    finally:
        if lock_fd is not None:
            os.close(lock_fd)  # closing the descriptor releases the flock

    return nccl_home


def git_check_submodules() -> None:
"""
Attempt to checkout git submodules automatically during setup.
Expand Down
58 changes: 57 additions & 1 deletion tests/cpp_distributed/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,32 @@ target_include_directories(test_comm_gemm PRIVATE ${test_comm_gemm_INCLUDES})
find_package(CUDAToolkit REQUIRED)
find_package(OpenMP REQUIRED)
find_package(MPI REQUIRED)

# -- NCCL core ----------------------------------------------------------------
# Anchor on libnccl and derive nccl.h from the same install prefix so the
# header and library can't drift across installs.
find_library(NCCL_LIB
NAMES nccl libnccl
PATH_SUFFIXES lib
HINTS /opt/nvidia/nccl/lib /opt/nvidia/nccl/lib64
/usr/local/nccl/lib /usr/local/nccl/lib64
PATH_SUFFIXES lib lib64
REQUIRED)
get_filename_component(_nccl_lib_dir "${NCCL_LIB}" DIRECTORY)
set(NCCL_PREFIX "${_nccl_lib_dir}")
while(NCCL_PREFIX AND NOT EXISTS "${NCCL_PREFIX}/include/nccl.h")
get_filename_component(_nccl_parent "${NCCL_PREFIX}" DIRECTORY)
if(_nccl_parent STREQUAL NCCL_PREFIX)
break()
endif()
set(NCCL_PREFIX "${_nccl_parent}")
endwhile()
find_path(NCCL_INCLUDE_DIR nccl.h
HINTS "${NCCL_PREFIX}/include"
NO_DEFAULT_PATH)
if(NOT NCCL_INCLUDE_DIR)
message(FATAL_ERROR
"nccl.h not found under the prefix of ${NCCL_LIB}.")
endif()
list(APPEND test_comm_gemm_LINKER_LIBS
CUDA::cuda_driver
CUDA::cudart
Expand All @@ -74,3 +96,37 @@ target_compile_options(test_comm_gemm PRIVATE -O2 -fopenmp)

include(GoogleTest)
gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600)

# -- EP distributed tests ------------------------------------------------------
# Launched via mpirun; ncclUniqueId exchange uses MPI_Bcast (see test_ep_common.h).
# The test binary only uses NCCL core symbols (ncclMemAlloc, ncclCommWindow*);
# all ncclEp* calls live behind TE's public <transformer_engine/ep.h>, which is
# statically linked into libtransformer_engine.so.
message(STATUS "EP test: NCCL headers: ${NCCL_INCLUDE_DIR}")
set(EP_TEST_COMMON_INCLUDES
${NCCL_INCLUDE_DIR}
${MPI_CXX_INCLUDE_PATH}
../../transformer_engine/common/include
../../transformer_engine/common
${CMAKE_CURRENT_SOURCE_DIR})

# nvrtc must follow TE_LIB so symbols referenced from libtransformer_engine.so
# (loaded via dlopen in Python; not in its DT_NEEDED) resolve through nvrtc.
set(EP_TEST_COMMON_LIBS
CUDA::cuda_driver
CUDA::cudart
GTest::gtest
${TE_LIB}
CUDA::nvrtc
${NCCL_LIB}
MPI::MPI_CXX
OpenMP::OpenMP_CXX)

# -- EP distributed tests (per-op + full pipeline + zero-copy symm) -----------
add_executable(test_ep test_ep.cu ../cpp/test_common.cu)
target_include_directories(test_ep PRIVATE ${EP_TEST_COMMON_INCLUDES})
target_link_libraries(test_ep PUBLIC ${EP_TEST_COMMON_LIBS})

# Do NOT use gtest_discover_tests - these binaries require multi-process
# launch via run_test_ep.sh, not direct single-process execution.
message(STATUS "EP distributed tests enabled (NCCL EP statically linked into libtransformer_engine.so)")
63 changes: 63 additions & 0 deletions tests/cpp_distributed/run_test_ep.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#!/usr/bin/env bash
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
#
# Run TE EP distributed unit tests via mpirun. Each MPI rank pins to one GPU
# (rank % device_count) and exchanges ncclUniqueId through MPI_Bcast.
#
# Usage:
#   bash run_test_ep.sh [num_gpus] [build_dir]
#
# Defaults:
#   num_gpus  = number of GPUs visible to nvidia-smi
#   build_dir = <script_dir>/build
#
# Environment variables:
#   GTEST_FILTER      - forwarded to all processes (e.g., "EPPipelineTest.*")
#   GTEST_XML_PREFIX  - if set, each rank writes JUnit XML to
#                       ${GTEST_XML_PREFIX}.rank<N>.xml
#   MPIRUN            - override the mpirun binary (default: mpirun)
#   MPIRUN_EXTRA      - extra flags forwarded to mpirun
#
# Exit status:
#   0 when the tests pass or the run is skipped (pre-Hopper GPU, or fewer
#   than 2 GPUs); 1 when the test binary is missing; otherwise the status
#   of mpirun.

set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
BUILD_DIR="${2:-${SCRIPT_DIR}/build}"
# `|| true`: with errexit + pipefail, a missing or failing nvidia-smi would
# otherwise abort the script silently on this assignment, before the
# "fewer than 2 GPUs" skip below can report anything. wc still prints 0.
NUM_GPUS="${1:-$(nvidia-smi -L 2>/dev/null | wc -l || true)}"
MPIRUN="${MPIRUN:-mpirun}"

# Skip cleanly on pre-Hopper: NCCL EP requires SM>=90.
# nvidia-smi failures are tolerated here for the same reason as above: awk
# then sees no input and prints 0, which the MIN_SM > 0 guard treats as
# "unknown, do not skip".
MIN_SM=$( { nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null || true; } \
    | awk -F. 'NR==1 || ($1*10+$2)<min { min=$1*10+$2 } END { print min+0 }')
if (( MIN_SM > 0 && MIN_SM < 90 )); then
    echo "NCCL EP requires SM>=90 (lowest visible GPU is SM${MIN_SM}); SKIPPING."
    exit 0
fi

TEST_BIN="${BUILD_DIR}/test_ep"
if [[ ! -x "${TEST_BIN}" ]]; then
    echo "ERROR: binary not found: ${TEST_BIN}"
    echo "Build: cd ${SCRIPT_DIR} && mkdir -p build && cd build && cmake .. && make"
    exit 1
fi

if (( NUM_GPUS < 2 )); then
    echo "EP Tests: requires at least 2 GPUs, found ${NUM_GPUS}. Skipping."
    exit 0
fi

# Keep gtest arguments in an array so a filter such as "EPPipelineTest.*" is
# passed as one literal word (no word splitting, no pathname expansion).
GTEST_ARGS=()
if [[ -n "${GTEST_FILTER:-}" ]]; then
    GTEST_ARGS+=("--gtest_filter=${GTEST_FILTER}")
fi

echo "=== EP Tests ==="
echo " GPUs: ${NUM_GPUS} Binary: ${TEST_BIN}"
echo

# MPIRUN_EXTRA is expanded unquoted on purpose: it carries several flags.
# ${GTEST_ARGS[@]+...} keeps an empty array from tripping `set -u` on old bash.
if [[ -n "${GTEST_XML_PREFIX:-}" ]]; then
    # bash -c so OMPI_COMM_WORLD_RANK expands per-rank, avoiding a write race
    # on a single shared output path. The prefix, binary and gtest arguments
    # are handed over as positional parameters instead of being spliced into
    # the command string, so the inner shell never re-parses them.
    "${MPIRUN}" --allow-run-as-root --oversubscribe -n "${NUM_GPUS}" ${MPIRUN_EXTRA:-} bash -c \
        'xml_prefix="$1"; shift; exec "$@" "--gtest_output=xml:${xml_prefix}.rank${OMPI_COMM_WORLD_RANK}.xml"' \
        bash "${GTEST_XML_PREFIX}" "${TEST_BIN}" ${GTEST_ARGS[@]+"${GTEST_ARGS[@]}"}
else
    "${MPIRUN}" --allow-run-as-root --oversubscribe -n "${NUM_GPUS}" ${MPIRUN_EXTRA:-} \
        "${TEST_BIN}" ${GTEST_ARGS[@]+"${GTEST_ARGS[@]}"}
fi
Loading
Loading