Skip to content

Generalized Tensor Parallelism (GTP) #3005

Open
fanshiqing wants to merge 5 commits into
NVIDIA:mainfrom
fanshiqing:gtp_release
Open

Generalized Tensor Parallelism (GTP) #3005
fanshiqing wants to merge 5 commits into
NVIDIA:mainfrom
fanshiqing:gtp_release

Conversation

@fanshiqing

@fanshiqing fanshiqing commented May 18, 2026

Copy link
Copy Markdown
Member

Deisgn doc: GTP.docx

Description

Core-idea: add Generalized Tensor Parallelism (GTP), which is a flexible fine-grained sharding/just-in time materialization of both activations and parameters with efficient computation-communication overlap.

Mission: improve LLM pretraining efficiency through generalized tensor parallelism, enabling high performance, memory efficiency, ease of use, and strong scalability.

Summary of features

  1. Fine-grained materialization & gradient reduction
  • Weight, gradient, and optimizer states are sharded along the GTP group.
  • Weights are temporarily materialized through prefetching in both the forward and backward passes.
  1. Composability with TP / SP / EP / DDP with efficient overlapping of computation and communication
  • GEMM + TP/EP communication + GTP communication + DDP communication.
  1. GTP + partial Cudagraphs with fine-grained synchronization across graphs
  • GTP reduce-scatter overlapping across graphs.
  1. Low-Precision quantize-then-gather
  • MXFP8 / NVFP4
  • Auto-padding/stripping to satisfy low-precision alignment requirements.
  1. Parallel folding for MoE layer
  • Support configuring the GTP size for dense layers and MoE layers separately.
  1. Distributed checkpointing

How Mcore interacts with TE

① Mcore registers callbacks into TE at import time.

② TE calls back into Mcore runtime during te.Linear(gtp_group=…) init AND during fwd/bwd (weight.all_gather_and_prefetch / wgrad_reduce_scatter).

③ Mcore extensions forward gtp_group= at module init.

④ TE provides FP8 / MXFP8 / NVFP4 tensor types AND the quantize-then-AG / RS collectives (gather_along_first_dim, reduce_scatter_along_first_dim) — imported by Mcore runtime; GTP wraps them with its own schedule, buffer cache, and stream choreography.

image

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • transformer_engine/pytorch/module/base.py (+76 / −2)
    • GTP hook registry: register_gtp_hooks(), maybe_wrap_gtp()
  • transformer_engine/pytorch/module/linear.py (+72 / −2)
    • Linear(gtp_group=…) kwarg
    • fwd: optional all_gather_and_prefetch rebind and skip workspace save;
    • bwd: re-gather + wgrad_reduce_scatter + main_grad write-back guard + sharded
      wgrad_shape.
  • transformer_engine/pytorch/module/layernorm_linear.py (+60 / −5)
    • same pattern mirrored for the fused LN+Linear path
  • transformer_engine/pytorch/module/grouped_linear.py (+115 / −16)
    • GroupedLinear(gtp_group=…) + maybe_wrap_gtp(..., is_grouped=True); dual saved-tensor
      carving (with/without GTP);
    • batched_all_gather_and_prefetch + batched_all_gather_and_prefetch_bwd + batched_wgrad_reduce_scatter
  • transformer_engine/pytorch/distributed.py (+142 / −53)
    • in-place .copy_() for amax/scale_inv/data so storage addresses stay stable across CUDA-graph replay.
    • GTP runtime depends on this for prefetch overlap.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • [] I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps

greptile-apps Bot commented May 18, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR introduces Generalized Tensor Parallelism (GTP) to TransformerEngine, enabling fine-grained weight sharding with quantize-then-gather collectives, computation-communication overlap, and CUDA-graph-safe in-place buffer management across Linear, LayerNormLinear, GroupedLinear, and the fused GroupedMLP op.

  • Core hook architecture (base.py): A decoupled callback registry (register_gtp_hooks) lets Megatron register slice_fn/finalize_fn/wrap_fn at import time with no circular TE↔Megatron dependency; maybe_wrap_gtp wires the finalize step at module init.
  • Module integration (linear.py, layernorm_linear.py, grouped_linear.py): Each module gains a gtp_group kwarg; forward all-gathers and materializes the weight, backward re-gathers for dgrad and reduce-scatters the wgrad, with guards preventing double-accumulation into main_grad.
  • Fused GroupedMLP (grouped_mlp.py): FC2 weight preparation is deferred to after the FC1 kernel to allow GTP all-gather overlap; fine-grained per-op activation offload markers added; _compute_grad_params gains GTP-aware wgrad buffer allocation and batched RS dispatch.

Confidence Score: 3/5

The PR should not be merged until the NVFP4+EGTP FC1 dgrad regression in grouped_mlp.py is fixed; the fused GroupedMLP backward passes sharded weight parameters directly into dgrad GEMMs when NVFP4 is active, breaking any run combining NVFP4 quantization with the new EGTP path.

The fused GroupedMLP backward dispatches on use_nvfp4 before issuing the FC1 GTP all-gather: the gather call sits in the MXFP8 else branch but is missing from the if use_nvfp4: branch. Any training run with NVFP4 + EGTP feeds sharded parameter objects to general_gemm/general_grouped_gemm_for_grouped_tensor, which will either crash or silently produce wrong input gradients. The PR description explicitly lists NVFP4 as a first-class supported quantization format for GTP, so this affects an advertised configuration rather than an edge case. Multiple other issues identified in prior review rounds also remain open, compounding the risk.

transformer_engine/pytorch/ops/fused/grouped_mlp.py — the FC1 dgrad backward section around lines 2050–2084 needs a GTP all-gather call before the if use_nvfp4: branch, mirroring the FC2 pattern at line 1826.

Important Files Changed

Filename Overview
transformer_engine/pytorch/ops/fused/grouped_mlp.py Large EGTP integration for fused GroupedMLP: FC2 weight prep moved after FC1 kernel to enable overlap; fine-grained per-op CPU offload markers added; _compute_grad_params guards wgrad buffer and RS for GTP. FC1 dgrad GTP all-gather is missing from the NVFP4 branch (P1 bug).
transformer_engine/pytorch/distributed.py Adds output_tensor/grouped params to all-gather functions and switches _post_process_nvfp4_gather to in-place .copy_() for CUDA-graph stability; adds post_process_nvfp4_gather() public method and null-safe async_handle check in wait(); adds output kwarg to reduce_scatter_along_first_dim.
transformer_engine/pytorch/module/base.py Adds GTP hook registry (_gtp_slice_fn, _gtp_finalize_fn, _gtp_wrap_fn) and maybe_wrap_gtp(); wires slice/finalize hooks into reset_parameters() using the enumeration index as expert_idx.
transformer_engine/pytorch/module/linear.py Adds gtp_group init kwarg; gtp_size field on LinearFwdArgs/LinearBwdArgs; all-gather in fwd, re-gather + wgrad RS in bwd; GTP guard on wgrad GEMM accumulation path.
transformer_engine/pytorch/module/layernorm_linear.py Mirrors linear.py GTP wiring (all-gather in fwd, re-gather + wgrad RS in bwd, gtp_group kwarg); correctly guards main_grad write-back and wgrad accumulation with gtp_size == 1 checks.
transformer_engine/pytorch/module/grouped_linear.py Adds gtp_group kwarg, weight_names assignment before gtp_group availability check, batched AG/RS hooks in fwd/bwd; dual saved-tensor carving for GTP vs non-GTP paths.

Sequence Diagram

sequenceDiagram
    participant Mcore as Megatron (Mcore)
    participant TE_Base as TE base.py
    participant TE_Linear as TE Linear/LNLinear/GroupedLinear
    participant GTP as GTPShardedParam (Mcore runtime)
    participant NCCL as NCCL Collective

    Note over Mcore,TE_Base: Import time — hook registration
    Mcore->>TE_Base: register_gtp_hooks(slice_fn, finalize_fn, wrap_fn)

    Note over TE_Linear,GTP: Module __init__ (gtp_group != None)
    TE_Linear->>TE_Base: reset_parameters() → _gtp_slice_fn(param, expert_idx)
    TE_Base->>GTP: Create GTPShardedParam (sharded weight)
    TE_Linear->>TE_Base: maybe_wrap_gtp() → _gtp_wrap_fn(module, weight_names, gtp_group)

    Note over TE_Linear,NCCL: Forward pass
    TE_Linear->>GTP: "weight.all_gather_and_prefetch(fwd=True)"
    GTP->>NCCL: All-Gather (quantize-then-gather)
    NCCL-->>GTP: full weight materialized
    GTP-->>TE_Linear: gathered weight tensor
    TE_Linear->>TE_Linear: GEMM with gathered weight

    Note over TE_Linear,NCCL: Backward pass
    TE_Linear->>GTP: saved_weight.all_gather_and_prefetch_bwd()
    GTP->>NCCL: All-Gather (columnwise for dgrad)
    NCCL-->>GTP: full weight (columnwise)
    TE_Linear->>TE_Linear: dgrad GEMM
    TE_Linear->>GTP: saved_weight.wgrad_reduce_scatter(wgrad)
    GTP->>NCCL: Reduce-Scatter wgrad to sharded main_grad
Loading

Comments Outside Diff (1)

  1. transformer_engine/pytorch/ops/fused/grouped_mlp.py, line 2053-2084 (link)

    P1 EGTP FC1 dgrad uses sharded weights when use_nvfp4=True

    The GTP all-gather for FC1 backward (fc1_op.weight0.batched_all_gather_and_prefetch_bwd()) is placed inside the else: branch at line 2107, which only runs for non-NVFP4 (MXFP8) quantization. When use_nvfp4=True, this if use_nvfp4: block runs and grouped_fc1_weight is still the raw list of sharded params recovered from fc1_ctx.saved_tensors (line 1627). Those sharded params are then passed directly to general_gemm (line 2062) or general_grouped_gemm_for_grouped_tensor (line 2079), producing wrong input gradients or a hard crash.

    The FC2 backward correctly calls batched_all_gather_and_prefetch_bwd() before the if fc2_op.single_grouped_weight: dispatch (line 1826), so it works regardless of precision. The same pattern needs to be applied to FC1: move the GTP all-gather for FC1 to before the if use_nvfp4: check (around line 2052), identical to the FC2 fix.

Reviews (12): Last reviewed commit: "[fix] Respect per-op activation-offload ..." | Re-trigger Greptile

Comment thread transformer_engine/pytorch/module/generalized_tensor_parallelism.py Outdated
Comment thread transformer_engine/pytorch/module/grouped_linear.py Outdated
Comment thread transformer_engine/pytorch/module/grouped_linear.py Outdated
Comment thread transformer_engine/pytorch/module/generalized_tensor_parallelism.py Outdated
Comment thread transformer_engine/pytorch/module/generalized_tensor_parallelism.py Outdated
Comment thread transformer_engine/pytorch/module/generalized_tensor_parallelism.py Outdated
@fanshiqing

Copy link
Copy Markdown
Member Author

/te-ci L1 pytorch

Comment thread transformer_engine/pytorch/distributed.py
Comment thread transformer_engine/pytorch/csrc/extensions/cast.cpp Outdated
Comment thread transformer_engine/pytorch/module/grouped_linear.py Outdated
Co-authored-by: Jieming Zhang <jiemingz@nvidia.com>
Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
Comment on lines 1287 to 1295
# Fix the interleaved transposed data from gathering along first dim.
out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size)
out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size)
# In-place .copy_() (not `=` rebind) to keep the storage address stable
# for CUDA graph capture — replays see the same pointer they captured.
out._columnwise_scale_inv.copy_(_swap_first_dims(columnwise_scale_inv_interleaved, world_size))
out._columnwise_data.copy_(_swap_first_dims(columnwise_data_interleaved, world_size))

# Optionally pad the scaling inverse if needed.
out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv)
# Optionally pad the scaling inverse if needed (same in-place pattern).
out._columnwise_scale_inv.copy_(pad_columnwise_scale_inv(out._columnwise_scale_inv))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Shape mismatch in _post_process_nvfp4_gather breaks any K not a multiple of 128

out._columnwise_scale_inv is allocated by NVFP4Quantizer.make_empty with shape (round_up(K, 128), round_up(ceil(M_total/16), 4)) — the fully-padded shape. The intermediate result from _swap_first_dims(columnwise_scale_inv_interleaved, world_size) has the unpadded shape (K_stripped, world_size * unpadded_dim1), because the gather side strips padding before the NCCL collect. When K is not a multiple of 128 (e.g. K=64 → padded to 128), the dimensions diverge and out._columnwise_scale_inv.copy_(...) raises a RuntimeError at the first all-gather call.

The pre-PR code used = rebinding, which handled arbitrary shapes. Replacing it with .copy_() is only safe when the caller pre-allocates buffers with the correct unpadded intermediate shape — which make_empty does not do. The GTP-prefetched output_tensor path has the same problem on the step-1 copy before the pad_columnwise_scale_inv call can correct things.

Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
Comment on lines 1660 to +1680
@@ -1627,10 +1677,23 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
with get_rng_state_tracker().fork():
init_fn(param)

# GTP slice: shard the freshly-init weight into a GTPShardedParam;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Wrong expert_idx for LayerNormLinear (and GroupedLinear with bias) silently disables GTP weight slicing

expert_idx=idx uses the position of the parameter in named_parameters(recurse=False), which includes non-linear-weight parameters. For LayerNormLinear the iteration order is layer_norm_weight (idx=0), layer_norm_bias (idx=1 for non-RMSNorm), weight (idx=2 or 1). The linear weight therefore arrives at _gtp_slice_fn with expert_idx=2 (or 1 for RMSNorm) instead of expert_idx=0. A Mcore hook that maps expert_idx to a pre-registered shard slot would find no entry for idx=2 and return None, silently leaving the weight un-sharded while gtp_group is set — defeating GTP for the entire LayerNormLinear path this PR explicitly adds.

Similarly, for GroupedLinear with biases enabled, weight1 receives expert_idx=2 (interleaved with bias0), so every expert beyond the first is mis-indexed.

A correct counter only advances when gtp_sharded is not None, keeping it aligned with the weight-only registration slots.

Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
Comment on lines +1316 to 1336
def post_process_nvfp4_gather(self) -> None:
"""Fix interleaved transposed data + pad scale_inv after the async AG completes.

Idempotent: gated by ``_synchronized`` in :meth:`wait`.
"""
_post_process_nvfp4_gather(
self.output,
self.columnwise_data_interleaved,
self.columnwise_scale_inv_interleaved,
self.world_size,
)

def wait(self) -> None:
"""Wait for the async operation to complete and post-process the tensor."""
if self._synchronized:
return
if self.async_handle is not None:
self.async_handle.wait()
self.post_process_nvfp4_gather()
self._synchronized = True

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 post_process_nvfp4_gather doesn't set _synchronized, enabling double-processing via wait()

post_process_nvfp4_gather() is a newly public method intended for callers using an outer coalescing manager (the grouped=True path). In that flow the GTP runtime is expected to call this method once the outer manager flushes, then later may call wait() for finalization. Because post_process_nvfp4_gather never sets self._synchronized = True, the _synchronized guard in wait() does not fire, and wait() calls post_process_nvfp4_gather() a second time. A double _swap_first_dims(..., world_size) reverts the data back to the interleaved format, silently producing corrupt weights for any forward pass that follows.

Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
@greptile-apps

greptile-apps Bot commented Jun 14, 2026

Copy link
Copy Markdown
Contributor

Want your agent to iterate on Greptile's feedback? Try greploops.

Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants