Skip to content

Add more array API functions#3684

Open
katlun-lgtm wants to merge 1 commit into
ml-explore:mainfrom
katlun-lgtm:array-api-batch
Open

Add more array API functions#3684
katlun-lgtm wants to merge 1 commit into
ml-explore:mainfrom
katlun-lgtm:array-api-batch

Conversation

@katlun-lgtm

Copy link
Copy Markdown

Proposed changes

Adds a batch of array API functions to mlx.core, all built on existing primitives (no core/Metal changes):

Elementwise / utility

  • positive(a), logical_xor(a, b), trunc(a)
  • count_nonzero(a, /, *, axis=None, keepdims=False)
  • diff(a, /, n=1, axis=-1, *, prepend=None, append=None)

Creation

  • full_like(a, vals, dtype=None)
  • empty(shape, dtype=...), empty_like(a, dtype=None) — these return zeros, since MLX does not expose uninitialized memory

Free-function wrappers

  • astype(a, dtype) and matrix_transpose(a) mirror the existing array method/property
  • cumulative_sum / cumulative_prod wrap cumsum / cumprod with the array API axis (flatten when None), dtype, and include_initial semantics

Inspection

  • __array_namespace_info__() returning an object with capabilities(), default_device(), default_dtypes(), devices(), and dtypes(kind=...)

All added to the ops docs and tested in test_ops.py / test_array.py. Part of #3484.

Checklist

  • I have read the CONTRIBUTING document
  • clang-format and black (the formatters configured in .pre-commit-config.yaml) report no changes on the modified files
  • Added tests (test_ops.py: test_array_api_elementwise, test_diff, test_array_api_creation, test_astype_and_matrix_transpose, test_cumulative_sum_prod; test_array.py: test_array_namespace_info)
  • Built from source and ran the tests locally: the new tests pass, and the full test_ops.py (144) and test_array.py (75) suites pass with no regressions

Adds array API functions toward ml-explore#3484, all built on existing primitives
(no core changes):

- Elementwise / utility: positive, logical_xor, trunc, count_nonzero, diff
- Creation: full_like, empty, empty_like (empty / empty_like return zeros
  since MLX does not expose uninitialized memory)
- Free functions: astype, matrix_transpose, cumulative_sum, cumulative_prod
- Inspection: __array_namespace_info__ (capabilities, default_device,
  default_dtypes, devices, dtypes)

Adds them to the ops docs and tests in test_ops.py / test_array.py.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant