Generalized Tensor Parallelism (GTP) #3005
Greptile Summary
This PR introduces Generalized Tensor Parallelism (GTP) to TransformerEngine, enabling fine-grained weight sharding with quantize-then-gather collectives, computation-communication overlap, and CUDA-graph-safe in-place buffer management across
Confidence Score: 3/5
The PR should not be merged until the NVFP4+EGTP FC1 dgrad regression in grouped_mlp.py is fixed: the fused GroupedMLP backward passes sharded weight parameters directly into dgrad GEMMs when NVFP4 is active, breaking any run combining NVFP4 quantization with the new EGTP path.
The fused GroupedMLP backward dispatches on transformer_engine/pytorch/ops/fused/grouped_mlp.py: the FC1 dgrad backward section around lines 2050–2084 needs a GTP all-gather call before the
Important Files Changed
Sequence Diagram
sequenceDiagram
participant Mcore as Megatron (Mcore)
participant TE_Base as TE base.py
participant TE_Linear as TE Linear/LNLinear/GroupedLinear
participant GTP as GTPShardedParam (Mcore runtime)
participant NCCL as NCCL Collective
Note over Mcore,TE_Base: Import time — hook registration
Mcore->>TE_Base: register_gtp_hooks(slice_fn, finalize_fn, wrap_fn)
Note over TE_Linear,GTP: Module __init__ (gtp_group != None)
TE_Linear->>TE_Base: reset_parameters() → _gtp_slice_fn(param, expert_idx)
TE_Base->>GTP: Create GTPShardedParam (sharded weight)
TE_Linear->>TE_Base: maybe_wrap_gtp() → _gtp_wrap_fn(module, weight_names, gtp_group)
Note over TE_Linear,NCCL: Forward pass
TE_Linear->>GTP: "weight.all_gather_and_prefetch(fwd=True)"
GTP->>NCCL: All-Gather (quantize-then-gather)
NCCL-->>GTP: full weight materialized
GTP-->>TE_Linear: gathered weight tensor
TE_Linear->>TE_Linear: GEMM with gathered weight
Note over TE_Linear,NCCL: Backward pass
TE_Linear->>GTP: saved_weight.all_gather_and_prefetch_bwd()
GTP->>NCCL: All-Gather (columnwise for dgrad)
NCCL-->>GTP: full weight (columnwise)
TE_Linear->>TE_Linear: dgrad GEMM
TE_Linear->>GTP: saved_weight.wgrad_reduce_scatter(wgrad)
GTP->>NCCL: Reduce-Scatter wgrad to sharded main_grad
/te-ci L1 pytorch
3e70bdf to ed9ce68
Co-authored-by: Jieming Zhang <jiemingz@nvidia.com> Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
  # Fix the interleaved transposed data from gathering along first dim.
- out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size)
- out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size)
+ # In-place .copy_() (not `=` rebind) to keep the storage address stable
+ # for CUDA graph capture — replays see the same pointer they captured.
+ out._columnwise_scale_inv.copy_(_swap_first_dims(columnwise_scale_inv_interleaved, world_size))
+ out._columnwise_data.copy_(_swap_first_dims(columnwise_data_interleaved, world_size))

- # Optionally pad the scaling inverse if needed.
- out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv)
+ # Optionally pad the scaling inverse if needed (same in-place pattern).
+ out._columnwise_scale_inv.copy_(pad_columnwise_scale_inv(out._columnwise_scale_inv))
Shape mismatch in _post_process_nvfp4_gather breaks any K not a multiple of 128
out._columnwise_scale_inv is allocated by NVFP4Quantizer.make_empty with shape (round_up(K, 128), round_up(ceil(M_total/16), 4)) — the fully-padded shape. The intermediate result from _swap_first_dims(columnwise_scale_inv_interleaved, world_size) has the unpadded shape (K_stripped, world_size * unpadded_dim1), because the gather side strips padding before the NCCL collect. When K is not a multiple of 128 (e.g. K=64 → padded to 128), the dimensions diverge and out._columnwise_scale_inv.copy_(...) raises a RuntimeError at the first all-gather call.
The pre-PR code used = rebinding, which handled arbitrary shapes. Replacing it with .copy_() is only safe when the caller pre-allocates buffers with the correct unpadded intermediate shape — which make_empty does not do. The GTP-prefetched output_tensor path has the same problem on the step-1 copy before the pad_columnwise_scale_inv call can correct things.
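The shape divergence described above can be checked with plain arithmetic. The following is a minimal sketch, not TransformerEngine code: the padded shape follows the formula quoted in this comment, while the per-rank unpadded width `ceil(M_local / 16)` and the helper names are assumptions made for illustration.

```python
import math


def round_up(value: int, multiple: int) -> int:
    """Round value up to the nearest multiple."""
    return -(-value // multiple) * multiple


def padded_scale_inv_shape(k: int, m_total: int) -> tuple:
    # Shape this review attributes to NVFP4Quantizer.make_empty.
    return (round_up(k, 128), round_up(math.ceil(m_total / 16), 4))


def gathered_unpadded_shape(k: int, m_total: int, world_size: int) -> tuple:
    # ASSUMPTION: each rank contributes ceil(M_local / 16) unpadded columns
    # and K is left unpadded because padding is stripped before the gather.
    m_local = m_total // world_size
    return (k, world_size * math.ceil(m_local / 16))


def copy_is_shape_safe(k: int, m_total: int, world_size: int) -> bool:
    # Simplification: treat an in-place copy as safe only when the
    # destination and source shapes match exactly.
    dst = padded_scale_inv_shape(k, m_total)
    src = gathered_unpadded_shape(k, m_total, world_size)
    return dst == src
```

Under these assumptions, K=64 with M_total=256 on two ranks gives a destination of (128, 16) and a source of (64, 16), so an in-place copy cannot succeed, whereas K=128 lines up.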
Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
@@ -1627,10 +1677,23 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
with get_rng_state_tracker().fork():
    init_fn(param)

# GTP slice: shard the freshly-init weight into a GTPShardedParam;
Wrong expert_idx for LayerNormLinear (and GroupedLinear with bias) silently disables GTP weight slicing
expert_idx=idx uses the position of the parameter in named_parameters(recurse=False), which includes non-linear-weight parameters. For LayerNormLinear the iteration order is layer_norm_weight (idx=0), layer_norm_bias (idx=1 for non-RMSNorm), weight (idx=2 or 1). The linear weight therefore arrives at _gtp_slice_fn with expert_idx=2 (or 1 for RMSNorm) instead of expert_idx=0. A Mcore hook that maps expert_idx to a pre-registered shard slot would find no entry for idx=2 and return None, silently leaving the weight un-sharded while gtp_group is set — defeating GTP for the entire LayerNormLinear path this PR explicitly adds.
Similarly, for GroupedLinear with biases enabled, weight1 receives expert_idx=2 (interleaved with bias0), so every expert beyond the first is mis-indexed.
A correct counter only advances when gtp_sharded is not None, keeping it aligned with the weight-only registration slots.
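The difference between the positional index and a weight-only counter can be sketched in isolation. This is an illustrative sketch only: `hypothetical_slice_fn` stands in for the Mcore-registered hook, and its rule (shard only names starting with `weight`) is an assumption, not the real hook's behavior.

```python
def hypothetical_slice_fn(name, expert_idx):
    # ASSUMPTION: the hook shards linear weights and returns None otherwise.
    if name.startswith("weight"):
        return (name, expert_idx)
    return None


def positional_indices(param_names):
    # Index as described in the review: position in named_parameters().
    return {name: idx for idx, name in enumerate(param_names)
            if name.startswith("weight")}


def weight_only_indices(param_names, slice_fn=hypothetical_slice_fn):
    # The counter advances only when the hook actually produced a shard.
    slots = {}
    expert_idx = 0
    for name in param_names:
        gtp_sharded = slice_fn(name, expert_idx)
        if gtp_sharded is not None:
            slots[name] = expert_idx
            expert_idx += 1
    return slots


layernorm_linear = ["layer_norm_weight", "layer_norm_bias", "weight"]
grouped_linear = ["weight0", "bias0", "weight1", "bias1"]
```

With the parameter orders given in this comment, the positional scheme hands `weight` index 2 and `weight1` index 2, while the weight-only counter yields 0 and 1 respectively.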
def post_process_nvfp4_gather(self) -> None:
    """Fix interleaved transposed data + pad scale_inv after the async AG completes.

    Idempotent: gated by ``_synchronized`` in :meth:`wait`.
    """
    _post_process_nvfp4_gather(
        self.output,
        self.columnwise_data_interleaved,
        self.columnwise_scale_inv_interleaved,
        self.world_size,
    )

def wait(self) -> None:
    """Wait for the async operation to complete and post-process the tensor."""
    if self._synchronized:
        return
    if self.async_handle is not None:
        self.async_handle.wait()
    self.post_process_nvfp4_gather()
    self._synchronized = True
post_process_nvfp4_gather doesn't set _synchronized, enabling double-processing via wait()
post_process_nvfp4_gather() is a newly public method intended for callers using an outer coalescing manager (the grouped=True path). In that flow the GTP runtime is expected to call this method once the outer manager flushes, then later may call wait() for finalization. Because post_process_nvfp4_gather never sets self._synchronized = True, the _synchronized guard in wait() does not fire, and wait() calls post_process_nvfp4_gather() a second time. A double _swap_first_dims(..., world_size) reverts the data back to the interleaved format, silently producing corrupt weights for any forward pass that follows.
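One possible way to close this gap is to move the guard into the public method itself, so either entry point runs the post-processing at most once. The class below is a hypothetical stand-in for the handle in the diff (it counts calls instead of touching tensors) and is a sketch of the idea, not the fix adopted in the PR.

```python
class GatherHandleSketch:
    """Hypothetical stand-in for the async gather handle shown in the diff."""

    def __init__(self, async_handle=None):
        self.async_handle = async_handle
        self._synchronized = False
        self.post_process_calls = 0  # instrumentation for this sketch only

    def _run_post_process(self):
        # Placeholder for the real swap-and-pad step on the gathered tensor.
        self.post_process_calls += 1

    def post_process_nvfp4_gather(self):
        # Guard lives here, so a direct call also marks the handle as done.
        if self._synchronized:
            return
        self._run_post_process()
        self._synchronized = True

    def wait(self):
        if self._synchronized:
            return
        if self.async_handle is not None:
            self.async_handle.wait()
        self.post_process_nvfp4_gather()


handle = GatherHandleSketch()
handle.post_process_nvfp4_gather()  # outer coalescing manager path
handle.wait()                       # later finalization becomes a no-op
```

With the flag set inside `post_process_nvfp4_gather`, the later `wait()` returns early and the data is processed exactly once regardless of call order.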
Design doc: GTP.docx
Description
Core idea: add Generalized Tensor Parallelism (GTP), a flexible, fine-grained scheme for sharding and just-in-time materialization of both activations and parameters, with efficient computation-communication overlap.
Mission: improve LLM pretraining efficiency through generalized tensor parallelism, enabling high performance, memory efficiency, ease of use, and strong scalability.
Summary of features
How Mcore interacts with TE
① Mcore registers callbacks into TE at import time.
② TE calls back into Mcore runtime during te.Linear(gtp_group=…) init AND during fwd/bwd (weight.all_gather_and_prefetch / wgrad_reduce_scatter).
③ Mcore extensions forward gtp_group= at module init.
④ TE provides FP8 / MXFP8 / NVFP4 tensor types AND the quantize-then-AG / RS collectives (gather_along_first_dim, reduce_scatter_along_first_dim) — imported by Mcore runtime; GTP wraps them with its own schedule, buffer cache, and stream choreography.
Type of change
Changes
Please list the changes introduced in this PR:
wgrad_shape.
carving (with/without GTP);
Checklist: