Skip to content

Add solver option and use SciPy solver by default#10

Open
KenyaOtsuka wants to merge 8 commits into
KamitaniLab:devfrom
KenyaOtsuka:refactor/use-cholesky
Open

Add solver option and use SciPy solver by default#10
KenyaOtsuka wants to merge 8 commits into
KamitaniLab:devfrom
KenyaOtsuka:refactor/use-cholesky

Conversation

@KenyaOtsuka

@KenyaOtsuka KenyaOtsuka commented May 21, 2026

Copy link
Copy Markdown
Contributor

This PR adds a solver option to FastL2LiR and uses the SciPy-based solver by default.

Previously, FastL2LiR used np.linalg.solve to solve the normal equations. This treats the coefficient matrix as a general dense matrix. The new default solver uses scipy.linalg.solve with assume_a="pos", which allows SciPy to use a Cholesky-based solver.

The speedup seems particularly noticeable in environments where NumPy/SciPy use the OpenBLAS backend. In my local benchmarks, the SciPy-based solver achieved up to about a 10x speedup over the NumPy-based solver.

Changes

  1. Added a solver option.
  2. Set the SciPy-based solver as the default.
  3. Added tests for the NumPy solver to ensure backward-compatible behavior.

Checks

Following CONTRIBUTING.md, I confirmed that pytest passes.

I also ran ruff check . and ruff format ., and all checks pass.

Note

The assume_a="pos" solver may fail when the coefficient matrix has a large condition number. In such cases, the implementation falls back to assume_a="sym".

However, even with assume_a="sym" or the previous np.linalg.solve-based implementation, there is no guarantee that a good numerical solution is obtained in such cases.

A large condition number can occur when alpha=0 or when alpha is very small. In particular, when alpha=0, returning an OLS estimator would be another possible implementation, but this PR does not address it because it would be a behavior change.

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR introduces a configurable linear-system solver for FastL2LiR and switches the default from NumPy’s generic solver to SciPy’s solver with assume_a="pos" to enable Cholesky-based solving for the normal equations.

Changes:

  • Added a solver option to FastL2LiR ('scipy' default, 'numpy' alternative) and routed all linear solves through it.
  • Implemented a SciPy assume_a="pos" solve with fallback to assume_a="sym" on LinAlgError.
  • Added tests to validate solver selection and to compare NumPy vs SciPy solver outputs.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

File Description
src/fastl2lir/fastl2lir.py Adds solver selection in __init__ and replaces direct np.linalg.solve calls with a chosen solver function.
tests/test_fastl2lir.py Adds coverage for solver selection and a cross-solver equivalence check.
pyproject.toml Adds SciPy as a runtime dependency.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread src/fastl2lir/fastl2lir.py Outdated
Comment thread src/fastl2lir/fastl2lir.py Outdated
Comment thread pyproject.toml Outdated
Comment thread tests/test_fastl2lir.py
@KenyaOtsuka KenyaOtsuka marked this pull request as draft May 21, 2026 09:17
@KenyaOtsuka KenyaOtsuka changed the title Refactor/use cholesky Add solver option and use SciPy solver by default May 21, 2026
@KenyaOtsuka KenyaOtsuka force-pushed the refactor/use-cholesky branch from 515f98f to e2fd04b Compare June 8, 2026 11:58
@KenyaOtsuka KenyaOtsuka force-pushed the refactor/use-cholesky branch from 8aec090 to 439c73e Compare June 8, 2026 12:10
@KenyaOtsuka KenyaOtsuka marked this pull request as ready for review June 9, 2026 05:03
@KenyaOtsuka KenyaOtsuka requested a review from ganow June 9, 2026 05:03
@KenyaOtsuka

Copy link
Copy Markdown
Contributor Author

Benchmark: scipy.linalg.solve(assume_a='pos') vs numpy.linalg.solve

Solve step only (A = 1000×1000 positive definite, B = 1000×100), 1 BLAS thread, median of 10 runs.

Environment Python NumPy SciPy BLAS Architecture Speedup (scipy / numpy)
RISC-V 3.12.1 1.26.4 1.12.0 OpenBLAS 0.3.26 (system) RISCV64_GENERIC (no SIMD) 3.06x
x86-64 3.11.15 1.26.4 1.11.4 OpenBLAS 0.3.23 (bundled) Prescott (SSE2) 1.36x
x86-64 3.11.15 2.4.6 1.17.1 scipy-openblas 0.3.31 (shared) SkylakeX (AVX-512) 1.03x

The speedup is largest in environments without SIMD optimization (RISC-V), where the algorithmic advantage of Cholesky decomposition (~half the FLOPs of LU) translates directly to wall-clock time. On modern x86 with AVX-512, SIMD optimization narrows the gap significantly.

Note: FastL2LiR's feature selection path runs with 1 BLAS thread (threadpool_limits(limits=1)), so the 1-thread figures above reflect actual runtime behavior for that code path.

@ganow ganow left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding the solver option and the speedup. The overall direction looks good, the target matrices are symmetric positive definite so assume_a="pos" is appropriate, and the pos to sym fallback is sound.

I am requesting changes mainly for one blocking issue. Storing the solver as an inline closure on self.__solve makes fitted models unpicklable, which is a backward-compatibility regression for the common workflow of fitting a model, pickling it, and running prediction in another process. Moving the solve functions to top-level scope, as noted in the inline comment, resolves this while preserving the early validation and the one-time branch selection.

The remaining inline comments are smaller. Removing the unused self.__solver, extending the solver tests to cover the feature-selection and dual-form paths as well as the pos to sym fallback, and adding a pickle round-trip test.

Comment thread src/fastl2lir/fastl2lir.py
Comment thread src/fastl2lir/fastl2lir.py Outdated
Comment thread tests/test_fastl2lir.py
…ck tests

- Extract _solve_scipy and _solve_numpy as module-level functions so that
  FastL2LiR instances are picklable and solver logic is statically analysable
- Remove unused self.__solver instance variable
- Add test_solver_pickle: verifies pickle round-trip preserves W and b
- Add test_solver_scipy_fallback: verifies pos->sym fallback completes
  without error and satisfies the linear system

https://claude.ai/code/session_016xWp5Yz2EQz2F4TnWuaymM

@ganow ganow left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for carefully addressing all the previous review comments. The blocking pickling issue is resolved by moving the solver functions to module level, and the added tests (the pos-to-sym fallback, the pickle round-trip, and the solver-level equivalence check) land exactly where our discussion pointed. I confirmed locally that the full test suite passes.

That said, I have to apologize: while re-reviewing the updated code, I noticed a few points that I had missed in my initial review pass, including one that stems from the design I myself suggested. I am sorry for not catching these the first time and for extending the review cycle as a result.

None of the new comments are blocking. Two are small suggestions (the SciPy version floor and the attribute access in the new test), and the pickling one is an open question where I would value your opinion on how (or whether) to handle it, rather than a requested change.

Comment on lines +13 to +21
def _solve_scipy(a, b):
try:
return sp_linalg.solve(a, b, assume_a="pos", check_finite=False)
except sp_linalg.LinAlgError:
return sp_linalg.solve(a, b, assume_a="sym", check_finite=False)


def _solve_numpy(a, b):
return np.linalg.solve(a, b)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for bringing this up at this stage. While verifying the fix, I noticed a remaining subtlety in the design I myself suggested, so this is on me as much as anything. Moving the solver to module-level functions fixes pickling within a single environment, but pickle stores functions by reference (module path + qualified name), which leaves an asymmetric constraint across package versions:

  • A model pickled with the old version and loaded with this code works for predict(), but calling fit() again raises AttributeError because __solve is missing from the unpickled __dict__.
  • A model pickled with this version cannot be unpickled at all under the old package, since _solve_scipy does not exist there.

One can reasonably argue that pickles should never cross package versions in the first place, so I don't consider this blocking. On the other hand, the fit-on-one-machine / predict-on-another workflow makes version skew plausible in practice. If we wanted to harden this, one option is __getstate__/__setstate__ that stores the solver name and re-resolves the function on load. What do you think? Would this be worth addressing here, in a follow-up, or just documenting?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for pointing this out.
Since this issue was introduced by this PR, I thought it would be better to address it here rather than only document it or leave it for a follow-up.

In 7fc620c, the model stores the solver name as state, excludes the resolved solver function from the pickle state, and re-resolves it in __setstate__.
For legacy pickles without a stored solver name, it falls back to the previous NumPy behavior and emits a warning. The new pickle state also avoids storing references to the solver helper functions, so it should be more robust to version skew.

Comment thread pyproject.toml Outdated
Comment thread tests/test_fastl2lir.py Outdated
@ganow

ganow commented Jun 10, 2026

Copy link
Copy Markdown
Contributor

Follow-up: solver benchmark on a typical compute server

Since switching the default solver is a behavior change, I benchmarked the two solvers on hardware representative of where we run feature decoding.

What I measured

I modeled the path FastL2LiR actually takes for feature decoding (feature-decoding / bdpy / brain-decoding-cookbook): feature selection is ON, so each output unit solves one (n_feat+1) x (n_feat+1) SPD normal-equation system with a single RHS column, under threadpool_limits(limits=1). Parameters match the configs: n_samples=6000, alpha=100, sweeping n_feat, for both float32 (the dtype the pipelines fit with) and float64. The matrix is built as newX.T @ newX + alpha*I, the same as __sub_fit, so the conditioning is realistic. Ratio is scipy(pos) / numpy; a value > 1 means scipy is slower.

Script: https://gist.github.com/ganow/4f89c91006689a7955fac34881b3af2d

Run on a 32-core OpenBLAS server with 1 BLAS thread actually enforced (the real decoding setting), on two dependency stacks:

# current resolve: numpy 2.x, scipy 1.15.3, OpenBLAS 0.3.31
uv run python bench_solver.py

# the stack many users get under bdpy[all] (numpy<1.24): numpy 1.23.5, scipy 1.10.1, OpenBLAS 0.3.18
uv run --with 'numpy<1.24' --with 'scipy<1.11' python bench_solver.py

Results (ratio = scipy / numpy; > 1 means scipy is slower)

float32 (the dtype the pipelines actually use)

n_feat matrix current stack bdpy[all] stack (numpy<1.24)
50 51² 1.82 1.57
100 101² 1.46 1.47
200 201² 1.29 1.40
500 501² 0.88 0.97
1000 1001² 0.70 0.83
2000 2001² 0.53 0.56

float64

n_feat matrix current stack bdpy[all] stack (numpy<1.24)
50 51² 1.95 1.66
100 101² 1.52 1.60
200 201² 1.31 1.50
500 501² 1.11 1.32
1000 1001² 1.09 1.33
2000 2001² 1.20 1.17

Takeaways:

  • At the realistic operating point (float32, n_feat=500, the VC num:500 configs), scipy is about 12% faster on the current stack and essentially a tie (0.97) on the bdpy[all]-constrained stack many users run.
  • In float64, scipy is slower everywhere on both stacks.
  • For small systems (n_feat ≲ 200) scipy is 1.3 to 1.95x slower in both dtypes.
  • scipy clearly wins only for large float32 systems (n_feat ≥ 1000).

Why scipy only helps in float32 (and why that is mostly a NumPy quirk)

The flip is not Cholesky vs LU. numpy.linalg always computes in double precision regardless of input dtype, then casts the result back, so a float32 input gets no single-precision speedup, while scipy.linalg.solve runs genuine single-precision LAPACK.

Source (numpy/linalg/_linalg.py, numpy 2.2.6) [1, 2]: _commonType always returns double as the computation type (inline comment: "in lite version, use higher precision (always double or cdouble)"), and solve casts the result to the input dtype afterward:

def _commonType(*arrays):
    # in lite version, use higher precision (always double or cdouble)
    ...
    return double, result_type   # first element (compute type) is always double

# in solve():
t, result_t = _commonType(a, b)  # t == double, result_t == single for float32 input
...
return wrap(r.astype(result_t, copy=False))   # compute in double, cast back to float32

This is observable in the result accuracy, independent of platform and threading:

import numpy as np
from scipy import linalg as sp
n = 1200; rng = np.random.default_rng(0)
X = rng.standard_normal((6000, n)).astype(np.float32)
A = np.ascontiguousarray(X.T @ X + 100*np.eye(n, dtype=np.float32))
b = X.T @ rng.standard_normal(6000).astype(np.float32)
A64, b64 = A.astype(np.float64), b.astype(np.float64)
x_ref = np.linalg.solve(A64, b64)
err = lambda x: np.linalg.norm(x.astype(np.float64) - x_ref) / np.linalg.norm(x_ref)
print("numpy fp32 rel.err", err(np.linalg.solve(A, b)))                               # ~2.4e-8 (below float32 eps ~1.2e-7) -> computed in double
print("scipy fp32 rel.err", err(sp.solve(A, b, assume_a='pos', check_finite=False)))  # ~2.6e-7 (genuine fp32)

numpy's float32 result is accurate to below float32 machine epsilon, which is only possible if the computation ran in double. This reproduces on both a laptop and the OpenBLAS server.

I am deliberately not quoting absolute per-routine solve times here: they depend heavily on the BLAS build and thread settings and are not portable (on the server, numpy's multi-threaded LAPACK path was far slower than scipy's, which is a separate effect from the dtype upcast). The ratio tables above are the relevant comparison, since they were measured under threadpool_limits(limits=1), the setting the feature-selection path actually runs in.

So the float32 advantage for scipy is partly that numpy.linalg.solve computes in double regardless of dtype, not that Cholesky is fundamentally faster. In float64, where both compute in double, scipy is comparable or slower.

Suggestion

The Cholesky path is a genuine win for large float32 problems, so the solver option is valuable. But which solver is faster depends on dtype, matrix size, BLAS build, and CPU, and the crossover moves around (near n_feat≈500 here, but that is hardware/stack specific), so I don't think a reliable automatic dispatcher is feasible.

Given that, defaulting to scipy looks a bit opinionated: on the stack many users run, it is a tie at the typical size in float32 and a regression for small or float64 cases. Would it make sense to keep numpy as the default (preserving the existing behavior and avoiding any silent slowdown) and expose the SciPy solver as an opt-in through the solver option you added? Pipelines running large float32 decoders can switch to it explicitly.

@ganow

ganow commented Jun 10, 2026

Copy link
Copy Markdown
Contributor

Correction to the comment above: I removed a per-routine timing snippet that quoted absolute solve times measured on my laptop. Re-running that exact snippet on the OpenBLAS server gave very different absolute numbers (numpy's multi-threaded LAPACK path came out roughly 16x slower than scipy's there), so those laptop numbers were not portable and could mislead.

What does not change:

  • The mechanism (numpy computes in double regardless of input dtype, then casts back) is established by the numpy source and by the accuracy check, both platform-independent. The accuracy numbers reproduce on the server (numpy fp32 rel.err ~2.4e-8, below float32 eps; scipy ~2.6e-7).
  • The recommendation is unchanged. It rests on the ratio tables, which were measured under threadpool_limits(limits=1), the setting the feature-selection path actually runs in. The laptop snippet was run without that limit, so it was not representative of the decoding hot loop anyway.

If anything, the server result reinforces the point that absolute performance is highly sensitive to BLAS build and threading, which is part of why I don't think an automatic solver dispatcher would be reliable.

@KenyaOtsuka

Copy link
Copy Markdown
Contributor Author

@ganow
Thank you for running this benchmark. The float64 slowdown with SciPy is an important observation.

Before deciding the default, I wanted to check one point about the cost-benefit assessment. I agree that the relative ratios are important, but the absolute impact may also matter. My understanding is that most users of the actual feature decoding pipelines use float32, so the impact on the main float32 use case may deserve more weight than the float64 case.

For float32, SciPy is indeed slower for small n_feat, but the absolute overhead seems relatively small when accumulated over a typical number of output units. On the other hand, for realistic or larger decoding workloads such as n_feat >= 500, the potential savings from SciPy can amount to minutes or even hours.

So I wanted to add that the trade-off may not be simply “sometimes faster, sometimes slower.” For the main float32 use case, the cost when SciPy is worse and the benefit when SciPy is better may be quite asymmetric.

Even taking this into account, I am slightly leaning toward reverting the default to NumPy to avoid changing the existing behavior. Still, I wanted to check whether this asymmetric absolute impact changes your view on what would be most beneficial for the main users of this model.

- Add _get_solver_func() to centralize solver selection logic
- Reintroduce self.__solver for serializable solver name
- Add __getstate__/__setstate__ for cross-version pickle compatibility
- Legacy pickles without solver info fall back to numpy with UserWarning
- Raise scipy lower bound from >=1.2.3 to >=1.8
- Update test_solver_helpers_agree to use module-level helpers directly
- Strengthen test_solver_pickle with predict comparison
- Add test_solver_getstate and test_solver_legacy_setstate

https://claude.ai/code/session_016xWp5Yz2EQz2F4TnWuaymM
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants