feat: add NemotronH hybrid Mamba2-Transformer architecture adapter #1434
mukund1985 wants to merge 10 commits into
…RoPE)

Closes TransformerLensOrg#1400.

DeepSeek-V2, V2-Lite, and Coder-V2 all use DeepseekV2ForCausalLM. This adds a bridge adapter covering three V2-specific differences from V3:

1. Complex-exponential RoPE: V2's rotary embedding returns freqs_cis (a complex tensor via torch.polar) rather than a (cos, sin) tuple.
   - RotaryEmbeddingBridge.forward() now passes complex tensors through without raising, leaving them for the attention bridge to consume.
   - MLAAttentionBridge.forward() detects complex position_embeddings and dispatches to a new _apply_rotary_complex() helper that mirrors DeepSeek-V2's apply_rotary_emb (view_as_complex, multiply, flatten).

2. Optional Q LoRA path: V2-Lite sets q_lora_rank=None, skipping q_a_proj/q_a_layernorm/q_b_proj and using q_proj directly instead. All three Q-path submodules are marked optional=True in the adapter; q_a_layernorm uses GeneralizedComponent (which already supports optional) rather than RMSNormalizationBridge. MLAAttentionBridge already branches on q_lora_rank at runtime.

3. Gate not hookable: DeepseekV2Moe.forward() routes via nn.functional.linear(..., self.gate.weight) rather than self.gate(hidden_states), so the gate module's forward() is never called and bridge hooks cannot fire. The gate is omitted from MoEBridge submodules; shared_experts uses __call__ and hooks fine.

Files changed:
- supported_architectures/deepseek_v2.py (new)
- supported_architectures/__init__.py: register adapter
- factories/architecture_adapter_factory.py: map DeepseekV2ForCausalLM
- generalized_components/mla_attention.py: complex RoPE support
- generalized_components/rotary_embedding.py: complex tensor pass-through
- tests/integration/model_bridge/test_deepseek_v2_adapter.py (new, 17 tests)
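The relationship between the complex form and the (cos, sin) form in item 1 can be shown with a minimal pure-Python sketch. This is not the bridge's code: the function names, the adjacent-pair layout, and the base of 10000 are assumptions for illustration, and the real helper works on tensors rather than lists. It shows only that multiplying each pair, viewed as one complex number, by a unit complex number is the same rotation as the explicit cos/sin formula.

```python
import cmath
import math


def rope_angles(position, head_dim, base=10000.0):
    """Rotation angle for each adjacent pair at a given position (assumed layout)."""
    return [position * base ** (-2 * i / head_dim) for i in range(head_dim // 2)]


def rotate_complex(x, angles):
    """Rotate adjacent pairs of x by treating each pair as one complex number."""
    out = []
    for i, theta in enumerate(angles):
        z = complex(x[2 * i], x[2 * i + 1])  # "view as complex"
        z *= cmath.rect(1.0, theta)          # multiply by e^{i*theta}
        out.extend([z.real, z.imag])         # flatten back to real values
    return out


def rotate_cos_sin(x, angles):
    """The same rotation written with an explicit (cos, sin) pair."""
    out = []
    for i, theta in enumerate(angles):
        a, b = x[2 * i], x[2 * i + 1]
        c, s = math.cos(theta), math.sin(theta)
        out.extend([a * c - b * s, a * s + b * c])
    return out
```

Both functions return the same values up to floating-point rounding, which is why an attention bridge that only understands (cos, sin) tuples needs a separate dispatch path when it receives a complex tensor.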
Implements TransformerBridge support for NemotronHForCausalLM (nvidia/Nemotron-H-8B-Base, Nemotron-H-47B-A13B).

Architecture overview:
- Heterogeneous layers defined by config.layers_block_type: each element is one of mamba, attention, moe, or mlp (~8% attention, ~92% SSM/MLP/MoE)
- Single pre-norm (block.norm) and single residual path per block; no ln2
- Single .mixer attribute per block whose type varies by layer
- No model-level rotary embedding module; attention handles RoPE internally
- Stateful generation via DynamicCache (transformers >= 5.12)

Key adapter decisions:
- SSMBlockBridge as block container: delegates full forward to HF block, avoids ln2 enforcement that BlockBridge would apply incorrectly here
- SSM2MixerBridge(name=mixer) as passthrough wrapper: works for all four mixer types since forward calls original_component(*args, **kwargs)
- Mamba-specific submodules (in_proj, conv1d, inner_norm, out_proj) marked optional so component_setup skips them gracefully on non-Mamba layers
- GatedRMSNormBridge.optional set post-init (its __init__ does not accept the kwarg, unlike the GeneralizedComponent base class)
- positional_embedding_type=none: no model-level rotary to wire
- gated_mlp=False: MLP layers use relu2, not SwiGLU
- applicable_phases=[]: verify_models is transformer-shaped; integration tests cover forward-pass correctness instead

Registration:
- architecture_adapter_factory.py: NemotronHForCausalLM key added
- supported_architectures/__init__.py: export added
- tools/model_registry/__init__.py: HF_SUPPORTED_ARCHITECTURES and CANONICAL_AUTHORS_BY_ARCH entries added (canonical author: nvidia)

Tests (52 unit tests, all passing):
- Config attribute propagation (normalization_type, positional_embedding_type, gated_mlp, is_stateful, final_rms, mamba_intermediate_size, conv_dim, layers_block_type, applicable_phases, weight_processing_conversions)
- Top-level component mapping bridge types and HF path names
- Block submodule bridge types (norm, mixer; no ln2)
- Mixer submodule types, names, and optional flags for all four Mamba keys
- create_stateful_cache returns DynamicCache; independent per call
- Factory registration and model registry constants
- Guard tests: SSMBlockBridge not BlockBridge, no weight conversions, mamba_intermediate_size and conv_dim formulas

Closes TransformerLensOrg#1402
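The effect of marking the Mamba-specific submodules optional can be sketched in plain Python. This is an illustrative stand-in, not the actual component_setup logic: resolve_submodules, the toy mixers, and their attribute values are all hypothetical, and the real setup code may resolve names differently.

```python
from types import SimpleNamespace

MAMBA_SUBMODULES = ("in_proj", "conv1d", "inner_norm", "out_proj")


def resolve_submodules(mixer, names, optional):
    """Return {name: module} for each name found on mixer.

    Missing names are skipped when optional is True and raise otherwise.
    """
    found = {}
    for name in names:
        module = getattr(mixer, name, None)
        if module is None:
            if optional:
                continue  # e.g. an attention or MLP layer has no conv1d
            raise AttributeError(
                f"{type(mixer).__name__} has no submodule {name!r}"
            )
        found[name] = module
    return found


# Toy stand-ins for two mixer types; the attribute values are placeholders.
mamba_mixer = SimpleNamespace(
    in_proj="W_in", conv1d="conv", inner_norm="norm", out_proj="W_out"
)
mlp_mixer = SimpleNamespace(up_proj="W_up", down_proj="W_down")
```

With optional=True the same submodule declaration can be applied to every layer: Mamba layers resolve all four names, while the other layer types resolve none and do not raise.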
@jlarson4 PR is open, please review this.
Solid adapter! Great reuse of the existing SSM bridge components, and the hybrid handling (single pre-norm, passthrough mixer for all four layer types, optional Mamba submodules driven by layers_block_type) is a clean solution. Setting applicable_phases=[] is the standard I set with Mamba/Mamba2. Eventually we will update the verification system to verify SSM architectures, but that is a future consideration, not something you need to worry about here. I should probably open an issue for that.

The one thing I'd like to see before merging: a forward-vs-HF parity test. Right now all the tests are structural; there's no check that the bridge matches HF numerically. Please add a forward-pass logit match against HF on a NemotronH checkpoint and, if possible, a short multi-token generation match to exercise the state handling. If you run into issues, document them in a PR comment and we can consider next steps.
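For reference, the shape of the requested parity check can be sketched without any model dependencies. Everything here is a hypothetical stand-in: the two callables represent the HF model and the bridge, each mapping a token list to next-token logits, and the 1e-4 tolerance is an assumed value rather than a project standard. A real test would load a NemotronH checkpoint and compare tensors.

```python
def max_abs_diff(a, b):
    """Largest elementwise absolute difference between two equal-length lists."""
    if len(a) != len(b):
        raise ValueError("logit vectors differ in length")
    return max(abs(x - y) for x, y in zip(a, b))


def greedy_generate(next_token_logits, prompt, n_new):
    """Append n_new greedily chosen tokens to a copy of prompt."""
    tokens = list(prompt)
    for _ in range(n_new):
        logits = next_token_logits(tokens)
        tokens.append(max(range(len(logits)), key=logits.__getitem__))
    return tokens


def check_parity(reference, candidate, prompt, n_new=4, atol=1e-4):
    """True if single-step logits match within atol and greedy tokens agree."""
    logits_close = max_abs_diff(reference(prompt), candidate(prompt)) <= atol
    same_tokens = (
        greedy_generate(reference, prompt, n_new)
        == greedy_generate(candidate, prompt, n_new)
    )
    return logits_close and same_tokens


# Toy "models" over a 5-token vocabulary, for demonstration only.
def toy_reference(tokens):
    return [float((sum(tokens) * (k + 1)) % 7) for k in range(5)]


def toy_candidate(tokens):
    return [v + 1e-6 for v in toy_reference(tokens)]  # tiny numerical drift


def toy_broken(tokens):
    return list(reversed(toy_reference(tokens)))  # wrong logits
```

The generation comparison matters for a stateful architecture because a bug in how conv or recurrent state is carried between steps can leave the first forward pass correct while later tokens diverge.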
Added the forward-pass and generation parity tests in
Implements TransformerBridge support for NemotronHForCausalLM (nvidia/Nemotron-H-8B-Base, Nemotron-H-47B-A13B).

Architecture

NemotronH is a hybrid with ~8% attention layers and ~92% Mamba-2 SSM/MoE/MLP layers. Each block has a single pre-norm (block.norm) and a single .mixer attribute whose type varies by layer (layers_block_type config field). There is no ln2 or model-level rotary module.

Adapter design

- SSMBlockBridge is used as the block container: it delegates the full forward to the HF block without enforcing the ln2 that BlockBridge requires.
- SSM2MixerBridge(name="mixer") wraps .mixer for all four layer types. Its forward is a pure passthrough (original_component(*args, **kwargs)), so it works correctly for attention, MLP, and MoE mixers as well as Mamba ones.
- Mamba-specific submodules (in_proj, conv1d, inner_norm, out_proj) are declared optional=True so component_setup skips them gracefully on non-Mamba layers. Note: GatedRMSNormBridge.__init__ does not accept optional= (unlike the GeneralizedComponent base), so the attribute is set directly post-construction.
- Stateful generation uses DynamicCache (transformers >= 5.12), which carries both KV-cache entries and SSM conv/recurrent states in one object.
- applicable_phases = []: verify_models is transformer-shaped; forward-pass correctness lives in integration tests.

Changes

- transformer_lens/model_bridge/supported_architectures/nemotron_h.py: new adapter
- transformer_lens/model_bridge/supported_architectures/__init__.py: export
- transformer_lens/factories/architecture_adapter_factory.py: registration
- transformer_lens/tools/model_registry/__init__.py: HF_SUPPORTED_ARCHITECTURES + CANONICAL_AUTHORS_BY_ARCH
- tests/unit/model_bridge/supported_architectures/test_nemotron_h_adapter.py: 52 unit tests, all passing

Closes #1402
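Two of the adapter design points above, the passthrough forward and the post-construction optional flag, can be sketched as follows. The classes and toy mixers are hypothetical stand-ins written for illustration; they are not the real SSM2MixerBridge or GatedRMSNormBridge, and the keyword names on the toy mixers are assumptions.

```python
class PassthroughMixer:
    """Hypothetical stand-in for a passthrough wrapper around a mixer."""

    def __init__(self, name, original_component):
        self.name = name
        self.original_component = original_component

    def __call__(self, *args, **kwargs):
        # Forward everything untouched, so the wrapper never needs to know
        # which signature the wrapped mixer expects.
        return self.original_component(*args, **kwargs)


def toy_mamba_mixer(hidden, cache_params=None):
    """Toy mixer with a Mamba-style keyword argument."""
    return [h * 2.0 for h in hidden]


def toy_attention_mixer(hidden, attention_mask=None, past_key_values=None):
    """Toy mixer with attention-style keyword arguments."""
    mask = attention_mask or [1] * len(hidden)
    return [h * m for h, m in zip(hidden, mask)]


class NormWithoutOptionalKwarg:
    """Toy component whose constructor has no 'optional' parameter."""

    def __init__(self, name):
        self.name = name
        self.optional = False


mamba = PassthroughMixer("mixer", toy_mamba_mixer)
attention = PassthroughMixer("mixer", toy_attention_mixer)

# The flag is assigned after construction because __init__ cannot take it.
inner_norm = NormWithoutOptionalKwarg("inner_norm")
inner_norm.optional = True
```

Because the wrapper forwards *args and **kwargs unchanged, one wrapper class covers mixers with different call signatures, which is the property the adapter relies on for its four layer types.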