# Source code for larray.core.ufuncs

# numpy ufuncs
# http://docs.scipy.org/doc/numpy/reference/routines.math.html

import numpy as np

from larray.core.array import LArray, make_numpy_broadcastable

# Names exported by a star-import of this module, grouped by the same categories as the wrapper definitions
# further down in this file. Entries which are commented out here correspond to numpy functions which are not
# (or not yet) wrapped by this module, so they are not part of its public API.
__all__ = [
    # Trigonometric functions
    'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan', 'hypot', 'arctan2', 'degrees', 'radians', 'unwrap',
    # 'deg2rad', 'rad2deg',

    # Hyperbolic functions
    'sinh', 'cosh', 'tanh', 'arcsinh', 'arccosh', 'arctanh',

    # Rounding
    'round', 'around', 'round_', 'rint', 'fix', 'floor', 'ceil', 'trunc',

    # Sums, products, differences
    # 'prod', 'sum', 'nansum', 'cumprod', 'cumsum',

    # cannot use a simple wrapped ufunc because those ufuncs do not preserve shape or dimensions so labels are wrong
    # 'diff', 'ediff1d', 'gradient', 'cross', 'trapz',

    # Exponents and logarithms
    'exp', 'expm1', 'exp2', 'log', 'log10', 'log2', 'log1p', 'logaddexp', 'logaddexp2',

    # Other special functions
    'i0', 'sinc',

    # Floating point routines
    'signbit', 'copysign', 'frexp', 'ldexp',

    # Arithmetic operations
    # 'add', 'reciprocal', 'negative', 'multiply', 'divide', 'power', 'subtract', 'true_divide', 'floor_divide',
    # 'fmod', 'mod', 'modf', 'remainder',

    # Handling complex numbers
    'angle', 'real', 'imag', 'conj',

    # Miscellaneous
    'convolve', 'clip', 'sqrt',
    # 'square',
    'absolute', 'fabs', 'sign', 'maximum', 'minimum', 'fmax', 'fmin', 'nan_to_num', 'real_if_close',
    'interp', 'where', 'isnan', 'isinf',
    # Linear algebra (wrapper around np.linalg.inv)
    'inverse',
]


def broadcastify(func):
    """Wrap a numpy function so that it can be called with LArray arguments.

    The returned wrapper passes its positional arguments through ``make_numpy_broadcastable``, converts the
    LArray ones to raw numpy arrays, calls `func` on them and, when there are combined axes, labels the result
    with those axes.

    Parameters
    ----------
    func : callable
        Function to wrap (typically a numpy ufunc or ufunc-like function).

    Returns
    -------
    callable
        Wrapper with the same ``__name__`` and ``__doc__`` as `func`. Calling it returns:

        * whatever `func` returned, unchanged, when the combined axes are empty/falsy (presumably when no
          positional argument is an LArray -- to be confirmed against ``make_numpy_broadcastable``),
        * a tuple of LArray when `func` returns a tuple (e.g. ``np.frexp``), each element using the combined axes,
        * a single LArray using the combined axes otherwise.

    Notes
    -----
    Only positional arguments are broadcast; keyword arguments are passed to `func` untouched.
    """
    # intentionally not using functools.wraps, because it does not work for wrapping a function from another module
    def wrapper(*args, **kwargs):
        # TODO: normalize args/kwargs like in LIAM2 so that we can also broadcast if args are given via kwargs
        #       (eg out=)
        args, combined_axes = make_numpy_broadcastable(args)

        # We pass only raw numpy arrays to the ufuncs even though numpy is normally meant to handle those cases
        # itself via __array_wrap__

        # There is a problem with np.clip though (and possibly other ufuncs): np.clip is roughly equivalent to
        # np.maximum(np.minimum(np.asarray(la), high), low)
        # the np.asarray(la) is problematic because it loses the original labels
        # and then tries to get them back from high, where they are possibly
        # incomplete if broadcasting happened

        # It fails on "np.minimum(ndarray, LArray)" because it calls __array_wrap__(high, result) which cannot work if
        # there was broadcasting involved (high has potentially less labels than result).
        # it does this because numpy calls __array_wrap__ on the argument with the highest __array_priority__
        raw_args = [np.asarray(a) if isinstance(a, LArray) else a
                    for a in args]
        res_data = func(*raw_args, **kwargs)
        if not combined_axes:
            return res_data
        if isinstance(res_data, tuple):
            # functions with several outputs (e.g. np.frexp) return a tuple of arrays: each output must be
            # labelled separately instead of building a single (mis-shaped) array from the whole tuple
            return tuple(LArray(res_arr, combined_axes) for res_arr in res_data)
        return LArray(res_data, combined_axes)
    # copy meaningful attributes (numpy ufuncs do not have __annotations__ nor __qualname__)
    wrapper.__name__ = func.__name__
    wrapper.__doc__ = func.__doc__
    # set __qualname__ explicitly (all these functions are supposed to be top-level function in the ufuncs module)
    wrapper.__qualname__ = func.__name__
    # we should not copy __module__
    return wrapper


# Trigonometric functions
# (each name below is the numpy function of the same name wrapped through broadcastify)

sin, cos, tan = map(broadcastify, (np.sin, np.cos, np.tan))
arcsin, arccos, arctan = map(broadcastify, (np.arcsin, np.arccos, np.arctan))
hypot, arctan2 = map(broadcastify, (np.hypot, np.arctan2))
degrees, radians = map(broadcastify, (np.degrees, np.radians))
unwrap = broadcastify(np.unwrap)
# deg2rad = broadcastify(np.deg2rad)
# rad2deg = broadcastify(np.rad2deg)

# Hyperbolic functions
# (direct functions first, then their inverses; all wrapped through broadcastify)

sinh, cosh, tanh = map(broadcastify, (np.sinh, np.cosh, np.tanh))
arcsinh, arccosh, arctanh = map(broadcastify, (np.arcsinh, np.arccosh, np.arctanh))

# Rounding

# round, around and round_ are equivalent, I am unsure I should support around and round_
round = broadcastify(np.round)
around = broadcastify(np.around)
# np.round_ was deprecated in NumPy 1.25 and removed in NumPy 2.0. Fall back to np.round when it is missing so
# that importing this module does not raise an AttributeError, while round_ stays available for existing callers.
round_ = broadcastify(getattr(np, 'round_', np.round))
rint = broadcastify(np.rint)
fix = broadcastify(np.fix)
floor = broadcastify(np.floor)
ceil = broadcastify(np.ceil)
trunc = broadcastify(np.trunc)

# Sums, products, differences

# prod = broadcastify(np.prod)
# sum = broadcastify(np.sum)
# nansum = broadcastify(np.nansum)
# cumprod = broadcastify(np.cumprod)
# cumsum = broadcastify(np.cumsum)

# cannot use a simple wrapped ufunc because those ufuncs do not preserve
# shape or dimensions so labels are wrong
# diff = broadcastify(np.diff)
# ediff1d = broadcastify(np.ediff1d)
# gradient = broadcastify(np.gradient)
# cross = broadcastify(np.cross)
# trapz = broadcastify(np.trapz)

# Exponents and logarithms
# (numpy functions of the same name, wrapped through broadcastify)

exp, expm1, exp2 = map(broadcastify, (np.exp, np.expm1, np.exp2))
log, log10, log2, log1p = map(broadcastify, (np.log, np.log10, np.log2, np.log1p))
logaddexp, logaddexp2 = map(broadcastify, (np.logaddexp, np.logaddexp2))

# Other special functions
# (np.i0 and np.sinc wrapped through broadcastify)

i0, sinc = map(broadcastify, (np.i0, np.sinc))

# Floating point routines
# (numpy functions of the same name, wrapped through broadcastify)

signbit, copysign = map(broadcastify, (np.signbit, np.copysign))
frexp, ldexp = map(broadcastify, (np.frexp, np.ldexp))

# Arithmetic operations

# add = broadcastify(np.add)
# reciprocal = broadcastify(np.reciprocal)
# negative = broadcastify(np.negative)
# multiply = broadcastify(np.multiply)
# divide = broadcastify(np.divide)
# power = broadcastify(np.power)
# subtract = broadcastify(np.subtract)
# true_divide = broadcastify(np.true_divide)
# floor_divide = broadcastify(np.floor_divide)
# fmod = broadcastify(np.fmod)
# mod = broadcastify(np.mod)
# modf = broadcastify(np.modf)
# remainder = broadcastify(np.remainder)

# Handling complex numbers
# (numpy functions of the same name, wrapped through broadcastify)

angle, real, imag, conj = map(broadcastify, (np.angle, np.real, np.imag, np.conj))

# Miscellaneous
# (numpy functions of the same name, wrapped through broadcastify)

# NOTE(review): some of these (e.g. convolve, interp) do not look like purely elementwise functions -- confirm
#               that labelling their result with the combined axes of their arguments is what is intended
convolve, clip, sqrt = map(broadcastify, (np.convolve, np.clip, np.sqrt))
# square = broadcastify(np.square)
absolute, fabs, sign = map(broadcastify, (np.absolute, np.fabs, np.sign))
maximum, minimum = map(broadcastify, (np.maximum, np.minimum))
fmax, fmin = map(broadcastify, (np.fmax, np.fmin))
nan_to_num, real_if_close = map(broadcastify, (np.nan_to_num, np.real_if_close))
interp, where = map(broadcastify, (np.interp, np.where))
isnan, isinf = map(broadcastify, (np.isnan, np.isinf))

# np.linalg.inv exposed under the name "inverse"
inverse = broadcastify(np.linalg.inv)