diff --git a/problems/linalg/qr_py/reference.py b/problems/linalg/qr_py/reference.py index fc8ace77..cc70f83f 100644 --- a/problems/linalg/qr_py/reference.py +++ b/problems/linalg/qr_py/reference.py @@ -20,18 +20,21 @@ def _band_mask(n: int, bandwidth: int, device: torch.device) -> torch.Tensor: return (idx[:, None] - idx[None, :]).abs() <= bandwidth -def generate_input(batch: int, n: int, cond: int, seed: int, case: str = "dense") -> input_t: - assert batch > 0, "batch must be positive" - assert n > 0, "n must be positive" - assert cond >= 0, "cond must be non-negative" - - device = "cuda" if torch.cuda.is_available() else "cpu" - gen = torch.Generator(device=device) - gen.manual_seed(seed) +_MIXED_PROFILES = ( + "dense", + "rankdef", + "nearrank", + "clustered", + "band", + "rowscale", + "nearcollinear", +) +_MIXED_WEIGHTS = (6.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0) - case = case.lower() - a = torch.randn((batch, n, n), device=device, dtype=torch.float32, generator=gen) +def _apply_case(a: torch.Tensor, case: str, cond: int, gen: torch.Generator) -> torch.Tensor: + batch, n = a.shape[0], a.shape[-1] + device = a.device if case == "dense": a = _apply_column_scaling(a, cond) elif case == "upper": @@ -83,6 +86,48 @@ def generate_input(batch: int, n: int, cond: int, seed: int, case: str = "dense" a = scales.reshape(1, n, 1) * a else: raise ValueError(f"unknown QR test case: {case}") + return a + + +def _generate_mixed(a: torch.Tensor, cond: int, gen: torch.Generator) -> torch.Tensor: + batch = a.shape[0] + device = a.device + weights = torch.tensor(_MIXED_WEIGHTS, dtype=torch.float32, device=device) + labels = torch.multinomial(weights, batch, replacement=True, generator=gen) + + if batch >= 2: + is_dense = labels == 0 + if not bool(is_dense.any()): + labels[int(torch.randint(0, batch, (1,), device=device, generator=gen))] = 0 + elif bool(is_dense.all()): + pos = int(torch.randint(0, batch, (1,), device=device, generator=gen)) + labels[pos] = int( + torch.randint(1, 
len(_MIXED_PROFILES), (1,), device=device, generator=gen) + ) + + for idx, profile in enumerate(_MIXED_PROFILES): + mask = labels == idx + if bool(mask.any()): + a[mask] = _apply_case(a[mask], profile, cond, gen) + return a + + +def generate_input(batch: int, n: int, cond: int, seed: int, case: str = "dense") -> input_t: + assert batch > 0, "batch must be positive" + assert n > 0, "n must be positive" + assert cond >= 0, "cond must be non-negative" + + device = "cuda" if torch.cuda.is_available() else "cpu" + gen = torch.Generator(device=device) + gen.manual_seed(seed) + + case = case.lower() + a = torch.randn((batch, n, n), device=device, dtype=torch.float32, generator=gen) + + if case == "mixed": + a = _generate_mixed(a, cond, gen) + else: + a = _apply_case(a, case, cond, gen) return a.contiguous() @@ -143,27 +188,44 @@ def check_implementation(data: input_t, output: output_t) -> tuple[bool, str]: q = torch.linalg.householder_product(h, tau) r = torch.triu(h) + if not torch.isfinite(q).all().item(): + return False, "Q materialized from `(H, tau)` contains NaN or Inf" + if not torch.isfinite(r).all().item(): + return False, "R extracted from `triu(H)` contains NaN or Inf" + a_check = a.double() q_check = q.double() r_check = r.double() projected = q_check.transpose(-1, -2) @ a_check - factor_residual = _matrix_l1_norm(r_check - projected).amax() - factor_scale = _matrix_l1_norm(a_check).amax() + if not torch.isfinite(projected).all().item(): + return False, "Q.T @ A contains NaN or Inf" + + factor_residual = _matrix_l1_norm(r_check - projected) + factor_scale = _matrix_l1_norm(a_check) factor_allowed = factor_rtol * factor_scale factor_scaled = _scaled_residual(factor_residual, factor_scale, n) - if factor_residual.item() > factor_allowed.item(): + if not torch.isfinite(factor_scaled).all().item(): + return False, "R - Q.T @ A residual produced NaN or Inf" + factor_failed = factor_residual > factor_allowed + if bool(factor_failed.any().item()): + worst = 
int(factor_scaled.argmax().item()) return False, ( "R - Q.T @ A is too large: " - f"residual={factor_residual.item():.3g}, allowed={factor_allowed.item():.3g}, " - f"scaled={factor_scaled.item():.3g}" + f"matrix={worst}, residual={factor_residual[worst].item():.3g}, " + f"allowed={factor_allowed[worst].item():.3g}, " + f"scaled={factor_scaled[worst].item():.3g}" ) eye = torch.eye(n, device=a.device, dtype=torch.float64).expand(batch, n, n) qtq = q_check.transpose(-1, -2) @ q_check + if not torch.isfinite(qtq).all().item(): + return False, "Q.T @ Q contains NaN or Inf" orth_residual = _matrix_l1_norm(qtq - eye).amax() orth_scale = _matrix_l1_norm(eye).amax() orth_allowed = orth_rtol * orth_scale orth_scaled = _scaled_residual(orth_residual, orth_scale, n) + if not torch.isfinite(orth_scaled).all().item(): + return False, "Q.T @ Q residual produced NaN or Inf" if orth_residual.item() > orth_allowed.item(): return False, ( "Q is not orthogonal enough: " @@ -177,6 +239,8 @@ def check_implementation(data: input_t, output: output_t) -> tuple[bool, str]: tri_scaled = _scaled_residual(tri_residual, tri_scale, n) recon = q_check @ r_check + if not torch.isfinite(recon).all().item(): + return False, "Q @ R contains NaN or Inf" recon_residual = _matrix_l1_norm(recon - a_check).amax() recon_scale = _matrix_l1_norm(a_check).amax() recon_scaled = _scaled_residual(recon_residual, recon_scale, n) @@ -184,7 +248,7 @@ def check_implementation(data: input_t, output: output_t) -> tuple[bool, str]: return True, ( f"factor_rtol={factor_rtol:.3g}; " f"orth_rtol={orth_rtol:.3g}; " - f"scaled_factor_residual={factor_scaled.item():.3g}; " + f"scaled_factor_residual={factor_scaled.amax().item():.3g}; " f"scaled_reconstruction_residual={recon_scaled.item():.3g}; " f"scaled_triangular_residual={tri_scaled.item():.3g}; " f"scaled_orthogonality_residual={orth_scaled.item():.3g}; " diff --git a/problems/linalg/qr_py/task.yml b/problems/linalg/qr_py/task.yml index 8e935eba..c1868dfe 100644 --- 
a/problems/linalg/qr_py/task.yml +++ b/problems/linalg/qr_py/task.yml @@ -39,6 +39,14 @@ description: | structure, such as rank-deficient, near-rank-deficient, banded, row-scaled, near-collinear, upper-triangular, or clustered-scale inputs. + The `mixed` case builds a heterogeneous batch: each matrix is assigned a + conditioning profile by a seeded random draw, so profiles vary by position. This + mirrors the optimizer-statistics regime, where factors batched into one call + can have very different conditioning, rather than all sharing one structure. + Benchmarks include both `mixed` batches and fully ill-conditioned homogeneous + batches, so conditioning robustness is ranked, not only gated. Each matrix in + the batch must be factored correctly on its own merits. + Correctness is a hard gate against the original FP32 input and the FP32 `torch.geqrf` compact-factor contract. Low-bit FP16, FP8, or NVFP4 work is allowed only as an internal implementation strategy: returned factors must @@ -89,6 +97,9 @@ tests: - {"batch": 2, "n": 2048, "cond": 2, "seed": 224466, "case": "dense"} - {"batch": 2, "n": 2048, "cond": 0, "seed": 224467, "case": "rankdef"} - {"batch": 1, "n": 4096, "cond": 0, "seed": 75343, "case": "upper"} + - {"batch": 16, "n": 512, "cond": 2, "seed": 32530, "case": "mixed"} + - {"batch": 4, "n": 1024, "cond": 2, "seed": 4332, "case": "mixed"} + - {"batch": 2, "n": 2048, "cond": 2, "seed": 224468, "case": "mixed"} benchmarks: - {"batch": 20, "n": 32, "cond": 1, "seed": 43214} @@ -98,3 +109,8 @@ benchmarks: - {"batch": 60, "n": 1024, "cond": 2, "seed": 75342} - {"batch": 8, "n": 2048, "cond": 1, "seed": 224466} - {"batch": 2, "n": 4096, "cond": 1, "seed": 32412} + - {"batch": 640, "n": 512, "cond": 2, "seed": 770001, "case": "mixed"} + - {"batch": 60, "n": 1024, "cond": 2, "seed": 770002, "case": "mixed"} + - {"batch": 640, "n": 512, "cond": 0, "seed": 770003, "case": "rankdef"} + - {"batch": 640, "n": 512, "cond": 0, "seed": 770004, "case": "clustered"} 
+ - {"batch": 60, "n": 1024, "cond": 0, "seed": 770005, "case": "nearrank"}