Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion gsw/_utilities.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from functools import wraps
from functools import reduce, wraps
from itertools import chain

import numpy as np

Expand All @@ -16,6 +17,61 @@ def masked_to_nan(arg):
else:
return np.asarray(arg, dtype=float)

def masked_array_support(f):
"""Decorator which adds support for np.ma.masked_arrays to the _wrapped_ufuncs

When one or more masked arrays are encountered as arguments or keyword
arguments, the boolean masks are all logical ORed together then logical
NOT is applied to get the ufunc.where parameter.

If no masked arrays are found, the ufunc is immediately called without modification
of arguments.

If a where keyword argument is present, it will be used instead of the
masked derived value.

All args/kwargs are then passed directly to the wrapped function
"""

@wraps(f)
def wrapper(*args, **kwargs):
has_masked_args = any(
np.ma.isMaskedArray(arg) for arg in chain(args, kwargs.values())
)
if not has_masked_args:
return f(*args, **kwargs)

# The only thing done when a masked array is encountered is to figure out the correct value to set the where argument to.
# This logic inspired by how the np.ma wrapped ufuncs work.
# https://github.com/numpy/numpy/blob/cafec60a5e28af98fb8798049edd7942720d2d74/numpy/ma/core.py#L1016-L1025
# we want getmask rather than getmaskarray for performance reasons
mask = reduce(
np.logical_or,
(np.ma.getmask(arg) for arg in chain(args, kwargs.values())),
)
where = ~mask

new_kwargs = {"where": where}
new_kwargs.update(
**kwargs
) # Allow user override of the where kwarg if they passed it in.

ret = f(*args, **new_kwargs)

# I suspect based on __array_priority__ the returned values might
# not be masked arrays when mixed with other array subclasses with
# a higher priority.
#
# masked_invalid will retain the existing mask and mask
# any new invalid values (if e.g. the result of unmasked inputs
# was nan/inf)
if isinstance(ret, tuple):
return tuple(np.ma.masked_invalid(rv) for rv in ret)
return np.ma.masked_invalid(ret)

return wrapper


def match_args_return(f):
"""
Decorator for most functions that operate on profile data.
Expand Down
Loading
Loading