Source code for droplets.tools.misc

"""Miscellaneous functions.

.. autosummary::
   :nosignatures:

   enable_scalar_args

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""

from __future__ import annotations

import functools
from collections.abc import Callable
from typing import Any, TypeVar

from pde.tools.misc import number_array

TFunc = TypeVar("TFunc", bound=Callable[..., Any])


[docs] def enable_scalar_args(method: TFunc) -> TFunc: """Decorator that makes vectorized methods work with scalars. This decorator allows to call functions that are written to work on numpy arrays to also accept python scalars, like `int` and `float`. Essentially, this wrapper turns them into an array and unboxes the result. Note that the dtype of the returned value will always be double or cdouble even if the function is called with an integer. Args: method: The method being decorated Returns: The decorated method """ @functools.wraps(method) def wrapper(self, *args): args = [number_array(arg, copy=None) for arg in args] if args[0].ndim == 0: args = [arg[None] for arg in args] return method(self, *args)[0] return method(self, *args) return wrapper # type: ignore