Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ a specific version, or vendor the library inside your own.

```{note}
This library depends on `array-api-compat`. We aim for compatibility with
the latest released version of array-api-compat, and your mileage may vary
with older or dev versions.
the latest released versions of the standard and array-api-compat,
and your mileage may vary with older or dev/draft versions.
```

(vendoring)=
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ reportUnusedParameter = false
reportImportCycles = false
# PyRight can't trace types in lambdas
reportUnknownLambdaType = false
# conflicts with https://docs.astral.sh/ruff/rules/explicit-string-concatenation/
reportImplicitStringConcatenation = false

executionEnvironments = [
{ root = "tests", reportPrivateUsage = false, reportUnknownArgumentType = false },
Expand Down
33 changes: 22 additions & 11 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
is_torch_namespace,
)
from ._lib._utils._compat import device as get_device
from ._lib._utils._helpers import asarrays, eager_shape
from ._lib._utils._helpers import asarrays, deprecated, eager_shape
from ._lib._utils._typing import Array, DType

__all__ = [
Expand Down Expand Up @@ -83,19 +83,23 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array
return _funcs.atleast_nd(x, ndim=ndim, xp=xp)


@deprecated(
"`xpx.broadcast_shapes` is deprecated and will be removed in v1.0.0. "
"`xp.broadcast_shapes` exists in the standard as of v2025.12."
)
def broadcast_shapes(
*shapes: tuple[float | None, ...], xp: ModuleType | None = None
) -> tuple[int | None, ...]:
"""
Compute the shape of the broadcasted arrays.

.. deprecated:: 0.11.0
:func:`broadcast_shapes` is deprecated and will be removed in v1.0.0.
:func:`array_api.broadcast_shapes` exists in the standard as of v2025.12.

Duplicates :func:`numpy.broadcast_shapes`, with additional support for
None and NaN sizes.

This is equivalent to ``xp.broadcast_arrays(arr1, arr2, ...)[0].shape``
without needing to worry about the backend potentially deep copying
the arrays.

Parameters
----------
*shapes : tuple[int | None, ...]
Expand Down Expand Up @@ -300,18 +304,25 @@ def create_diagonal(
return _funcs.create_diagonal(x, offset=offset, xp=xp)


@deprecated(
"`xpx.expand_dims` is deprecated and will be removed in v1.0.0. "
"`xp.expand_dims` with support for a tuple of ints in `axis` "
"exists in the standard as of v2025.12."
)
def expand_dims(
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None
) -> Array:
"""
Expand the shape of an array.

.. deprecated:: 0.11.0
:func:`expand_dims` is deprecated and will be removed in v1.0.0.
:func:`array_api.expand_dims` with support for a tuple of ints in `axis`
exists in the standard as of v2025.12.

Insert (a) new axis/axes that will appear at the position(s) specified by
`axis` in the expanded array shape.

This is ``xp.expand_dims`` for `axis` an int *or a tuple of ints*.
Roughly equivalent to ``numpy.expand_dims`` for NumPy arrays.

Parameters
----------
a : array
Expand Down Expand Up @@ -804,7 +815,7 @@ def searchsorted(
Find the indices into a sorted array ``x1`` such that if the elements in ``x2``
were inserted before the indices, the resulting array would remain sorted.

The behavior of this function is similar to that of `array_api.searchsorted`,
The behavior of this function is similar to that of :func:`array_api.searchsorted`,
but it relaxes the requirement that `x1` must be one-dimensional.
This function is vectorized, treating slices along the last axis
as elements and preceding axes as batch (or "loop") dimensions.
Expand Down Expand Up @@ -1220,8 +1231,8 @@ def isin(
"""
Determine whether each element in `a` is present in `b`.

Return a boolean array of the same shape as `a` that is True for elements
that are in `b` and False otherwise.
This is :func:`array_api.isin`, with additional `assume_unique`
and `kind` parameters.

Parameters
----------
Expand Down
23 changes: 23 additions & 0 deletions src/array_api_extra/_lib/_utils/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from __future__ import annotations

import functools
import io
import math
import pickle
import types
import warnings
from collections.abc import Callable, Generator, Iterable, Iterator
from functools import wraps
from types import ModuleType
Expand Down Expand Up @@ -48,6 +50,7 @@ def override(func):
__all__ = [
"asarrays",
"capabilities",
"deprecated",
"eager_shape",
"in1d",
"is_python_scalar",
Expand All @@ -58,6 +61,26 @@ def override(func):
]


def deprecated(
msg: str, stacklevel: int = 2
) -> Callable[[Callable[P, T]], Callable[P, T]]: # numpydoc ignore=PR01,RT01
"""Deprecate a function by emitting a warning on use."""

def decorate(func: Callable[P, T]) -> Callable[P, T]: # numpydoc ignore=GL08
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
warnings.warn(
msg,
category=DeprecationWarning,
stacklevel=stacklevel,
)
return func(*args, **kwargs)

return wrapper

return decorate


def in1d(
x1: Array,
x2: Array,
Expand Down
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ def xp(
# Possibly wrap module with array_api_compat
xp = array_namespace(xp.empty(0))

if library.like(Backend.ARRAY_API_STRICT):
xp.set_array_api_strict_flags(api_version="2025.12")

if library == Backend.ARRAY_API_STRICTEST:
with xp.ArrayAPIStrictFlags(
boolean_indexing=False,
Expand Down
1 change: 1 addition & 0 deletions tests/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ py.install_sources([
'__init__.py',
'conftest.py',
'test_at.py',
'test_deprecation.py',
'test_funcs.py',
'test_helpers.py',
'test_lazy.py',
Expand Down
15 changes: 15 additions & 0 deletions tests/test_deprecation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from types import ModuleType

import pytest

from array_api_extra import broadcast_shapes, expand_dims


class TestDeprecatedFunctions:
def test_broadcast_shapes(self, xp: ModuleType):
with pytest.raises(DeprecationWarning, match=r"removed in v1.0.0"):
_ = broadcast_shapes((2, 3), (2, 1), xp=xp)

def test_expand_dims(self, xp: ModuleType):
with pytest.raises(DeprecationWarning, match=r"removed in v1.0.0"):
_ = expand_dims(xp.ones(2), axis=0, xp=xp)
2 changes: 2 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ def test_5D_values(self, xp: ModuleType):
assert_equal(y, xp.asarray([[[[[[[[[3.0]], [[2.0]]]]]]]]]))


@pytest.mark.filterwarnings("ignore:.*removed in v1.0.0.*:DeprecationWarning")
class TestBroadcastShapes:
def test_delegates_known_integer_shapes(self, monkeypatch: pytest.MonkeyPatch):
calls = []
Expand Down Expand Up @@ -828,6 +829,7 @@ def test_torch(self, torch: ModuleType):
assert default_dtype(xp, "complex floating") == xp.complex64


@pytest.mark.filterwarnings(r"ignore:.*removed in v1.0.0.*:DeprecationWarning")
class TestExpandDims:
def test_single_axis(self, xp: ModuleType):
"""Trivial case where xpx.expand_dims doesn't add anything to xp.expand_dims"""
Expand Down