Enable static quantization for Qwen3-0.6B decoder (transformer-only)#836
Enable static quantization for Qwen3-0.6B decoder (transformer-only)#836spalne wants to merge 4 commits into
Conversation
| from .qwen3_modeling import ( | ||
| WinMLQwen3Attention, | ||
| WinMLQwen3DecoderLayer, | ||
| WinMLQwen3MLP, | ||
| WinMLQwen3Model, | ||
| WinMLQwen3RMSNorm, | ||
| ) |
| from .qwen3_export_ops import ( | ||
| GroupQueryAttentionOnnxExport, | ||
| LpNormOnnxExport, | ||
| TransposeConv2d1x1Transpose, | ||
| ) |
|
|
||
| COMPOSITE_MODEL_REGISTRY[("qwen3", "text-generation")] = WinMLQwen3TransformerOnlyModel | ||
|
|
||
| _INSTALLED = True |
|
@spalne please read the following Contributor License Agreement(CLA). If you agree with the CLA, please reply with the following information.
Contributor License AgreementContribution License AgreementThis Contribution License Agreement (“Agreement”) is agreed to by the party signing below (“You”),
|
DingmaomaoBJTU
left a comment
There was a problem hiding this comment.
Summary - structurally sound export, but registration/test/quant integration don't match repo conventions, and w8a16 accuracy regresses.
Nice work getting a fused GQA + LpNorm RMSNorm + 1x1-Conv transformer-only export running end-to-end on QNN, and the export itself is faithful - the FP optimized graph reproduces HF eager's next-token exactly. Three things to address before this is review-ready:
1. Registration is non-standard (highest priority). qwen_transformer_only.install() hot-patches the global registries at runtime and isn't imported by models/hf/__init__.py. Every other model registers declaratively at import time (@register_onnx_overwrite / @register_composite_model, merged in __init__.py). Please make this a first-class variant (distinct task/model_type or a build-config flag) instead of monkey-patching; it also removes the "must call install() before importing the composite machinery" ordering trap and the no-way-back override of the eager path.
2. Test & quant entry points violate repo layout. test_qwen.py and qwen3_transformer_only_quantize.py are standalone scripts at the repo root; test_qwen.py is a subprocess driver that judges success by artifact mtime and uses os._exit(0) to mask a native QNN/ORT teardown crash. Convention (tests/CLAUDE.md) is pytest under tests/. Move the runner to tests/e2e/ (or examples/), and wire the calibration reader into the config-driven quant flow (WinMLBuildConfig.quant) rather than a bespoke quantizer.
3. w8a16 accuracy is not yet acceptable. Measured against the FP graph on the same GSM8K-style input, the quantized model flips the top-1 next token on both prefill and decode (top-5 overlap 0-1/5, KL 0.66/2.75; hidden-state cosine 0.64-0.72), while present-KV stays ~0.999 - i.e. the residual stream is the casualty. Likely minmax + all-zero KV calibration + only 30 samples. Please try percentile/entropy calibration with a realistic non-zero KV feed and report an actual task metric, not just QDQ node count.
Naming and the custom-op export pattern look good and match the codebase.
| @@ -0,0 +1,235 @@ | |||
| """E2E test for the transformer-only Qwen3 export path. | |||
There was a problem hiding this comment.
This is a standalone runner at the repo root that drives the build via subprocess and judges success by "did a fresh artifact file appear". Repo convention (and tests/CLAUDE.md) is pytest under tests/ with code-generated expectations - there are no other root-level test_*.py scripts. Could this move under tests/e2e/ as a real pytest (marked e2e/npu/qnn), or under examples/ if it's really a demo rather than a test? As-is it'll get picked up by name but isn't a pytest, and it lives outside the tree the suite runs from.
| print(f"BUILD COMPLETE: task={task} seq_len={seq_len}", flush=True) | ||
| sys.stdout.flush() | ||
| sys.stderr.flush() | ||
| os._exit(0) |
There was a problem hiding this comment.
os._exit(0) to skip interpreter teardown (because QNN/ORT segfaults on shutdown), combined with judging success purely by artifact mtime, hides a real native crash and makes the subprocess exit code meaningless. If the teardown crash is reproducible it's worth a tracked issue / fix at the EP layer rather than papering over it in the test harness. At minimum this deserves a code comment pointing at a tracking issue, otherwise a genuine build failure that still happens to touch the file would read as success.
| @@ -0,0 +1,230 @@ | |||
| """Transformer-only w8a16 quantization for Qwen3. | |||
There was a problem hiding this comment.
Quantization in winml-cli is normally config-driven through WinMLBuildConfig.quant and runs as part of the build pipeline. This adds a parallel standalone quant entry point at the repo root that reaches into sub_models[*]._onnx_path directly and is "run via test_qwen.py". Could the transformer-only calibration reader be wired into the standard quant flow so it's reachable from winml build / the config instead of a bespoke script? Also minor: Qwen3TransformerOnlyCalibReader structurally satisfies winml.modelkit.quant.config.CalibrationDataReader but doesn't declare it - worth importing/typing against the protocol so it stays in sync.
| samples=num_samples, | ||
| weight_type=weight_type, # type: ignore[arg-type] | ||
| activation_type=activation_type, # type: ignore[arg-type] | ||
| calibration_method="minmax", |
There was a problem hiding this comment.
Accuracy concern worth resolving before this lands. I ran the produced w8a16 graphs against the FP optimized graphs on the same GSM8K-style input (ORT CPU EP): the FP export matches HF eager exactly (top-1 next token identical), but the w8a16 output flips the top-1 token on both prefill and decode - top-5 overlap 0-1/5, KL(FP||quant) 0.66 / 2.75, output_hidden_states cosine 0.64-0.72. The present-KV path is ~0.999, so the damage is concentrated in the residual stream.
Likely causes: minmax calibration over a residual stream with large outliers (+/-76), calibrating with an all-zero KV cache, and only 30 samples. Suggest trying calibration_method="percentile" (or entropy), feeding a realistic non-zero KV during calibration, and reporting an actual task metric (e.g. GSM8K logits/top-1 agreement) so we can see the quant is acceptable, not just that QDQ nodes were inserted.
DingmaomaoBJTU
left a comment
There was a problem hiding this comment.
Code Review — PR #836 (Draft)
Well-structured PR. The transformer-only export topology (fused GQA, LpNorm RMSNorm, 1x1 Conv), GSM8K calibration pipeline, and model_type override mechanism are solid. A few correctness bugs and infrastructure concerns should be resolved before marking ready for merge.
Not approving since this is a draft PR.
| if torch.any(self.weight.data != torch.ones_like(self.weight)).item(): | ||
| new_w = scale * self.weight | ||
| else: | ||
| new_w = scale |
There was a problem hiding this comment.
Bug: RMSNorm weight shape mismatch when weights are all 1.0
When self.weight is all ones (the default init), new_w = scale produces a [1]-shaped tensor, not [hidden_size]. The ONNX initializer exports with shape [1] instead of [hidden_size], which broadcasts silently in PyTorch but may cause shape errors in downstream ONNX tooling.
The branch is also logically redundant (scale * ones == scale). Simplify to:
self.weight = nn.Parameter(scale * self.weight)|
|
||
| Run:: | ||
|
|
||
| python test_qwen_transformer_only.py |
There was a problem hiding this comment.
Bug: Wrong filename in docstring — says python test_qwen_transformer_only.py but the file is test_qwen.py.
|
|
||
| @staticmethod | ||
| def forward(ctx, input, axis, p): # noqa: ARG004 | ||
| return input # placeholder — real compute happens in symbolic |
There was a problem hiding this comment.
Warning: Eager-mode forward returns incorrect (un-normalized) results
LpNormOnnxExport.forward returns input unchanged (identity). This is only correct during ONNX tracing where symbolic runs instead. Any eager execution (unit tests, calibration debug runs) silently gets un-normalized values. Consider computing the real norm for eager mode or raising NotImplementedError to make misuse obvious.
| kv_num_heads, | ||
| num_heads, | ||
| ): # noqa: ARG004 | ||
| return query, past_key, past_value # placeholder shapes |
There was a problem hiding this comment.
Warning: Stale KV cache in eager mode
GroupQueryAttentionOnnxExport.forward returns (query, past_key, past_value) — the present_keys/present_values are the old un-updated tensors. Eager execution silently produces a KV cache that never advances. A NotImplementedError here would be safer than a silently-wrong placeholder.
| # Identify Qwen3 submodules by their (stock HF) class name so we don't | ||
| # depend on importing ``transformers.models.qwen3`` here. | ||
| def _is(module: nn.Module, name: str) -> bool: | ||
| return type(module).__name__ == name |
There was a problem hiding this comment.
Warning: Fragile class-name string matching
type(module).__name__ == name breaks silently if HuggingFace renames a Qwen3 module in a future release — the forward won't be bound and the export will be silently broken. Consider adding a post-patch assertion that the expected number of attention/MLP/RMSNorm modules were patched.
| print("\n=== Loading HF embed_tokens for calibration ===") | ||
| hf_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32) | ||
| hf_model.eval() | ||
| embed_tokens = hf_model.get_input_embeddings() |
There was a problem hiding this comment.
Warning: HF model not freed after calibration data is pre-built
hf_model stays live for the entire quantization loop. Since Qwen3TransformerOnlyCalibReader.__init__ materializes all samples in self._samples, the model weights are no longer needed after reader construction. Add del hf_model; gc.collect() before the quantization loop to free the model memory.
|
|
||
| seq_len = seq_by_sub[sub_name] | ||
| quant_path = fused_path.with_name( | ||
| fused_path.stem + f"_w{weight_type[-1]}a{activation_type[-2:]}.quant.onnx" |
There was a problem hiding this comment.
Warning: Brittle string slicing for output filename
weight_type[-1] and activation_type[-2:] work for int8/uint16 but produce wrong suffixes for other valid types (e.g. activation_type='uint4' -> _t4). Use an explicit dict mapping or strip the numeric suffix with a regex/lstrip instead.
| print(f"\n########## BUILD {name} (task={task}, seq_len={seq_len}) ##########", flush=True) | ||
| before = _latest_ctx_mtime(prefix) | ||
| start = _time.time() | ||
| rc = subprocess.run( |
There was a problem hiding this comment.
Warning: subprocess.run() has no timeout
If the QNN/ORT build stalls, this blocks indefinitely. Add timeout=1800 (or similar) and catch subprocess.TimeoutExpired to surface a clear CI failure.
| print(f"BUILD COMPLETE: task={task} seq_len={seq_len}", flush=True) | ||
| sys.stdout.flush() | ||
| sys.stderr.flush() | ||
| os._exit(0) |
There was a problem hiding this comment.
Info: os._exit(0) as a crash workaround
Documented, but worth tracking as an upstream ORT/QNN issue. Consider adding a TODO comment with a link to a filed issue so this workaround isn't forgotten.
| @@ -0,0 +1,229 @@ | |||
| """E2E test for the transformer-only Qwen3 export path. | |||
There was a problem hiding this comment.
Info: Not a pytest test
Despite the test_ prefix, this file uses __main__, sys.path mutations at import time, subprocess orchestration, and os._exit. It lives in the repo root and won't be collected by uv run pytest tests/. Consider renaming to scripts/run_qwen3_quant.py to avoid accidental pytest collection, or convert to a proper pytest integration test with hardware skip markers.
Adds a transformer-only ONNX export path for Qwen3 that emits a fused (GQA) GroupQueryAttention op (with built-in rotary), LpNormalization RMSNorm, and 1×1 Conv projections, backed by an FP16 KV cache. The path is opt-in via install(), which hot-patches the build registries to produce two graphs (prefill seq=64, decode seq=1) without embeddings or lm_head. Quantization runs w8a16 static PTQ on these graphs using GSM8K calibration
Results
Produces two transformer-only ONNX files (prefill + decode) plus their w8a16-quantized variants.