Mathematical functions¶
Pychop provides two ways to evaluate mathematical functions in reduced precision.
The first approach requires the backend to be specified explicitly, whereas the second infers it from the type of the input.
Mathematical functions I¶
The chop class provides a suite of mathematical functions that operate on floating-point numbers with custom precision chopping. These functions apply the chopping mechanism (via chop_wrapper) to inputs and outputs, so that results adhere to the specified precision (e.g., fp16, fp32). Implementations are available for NumPy, PyTorch, and JAX, with slight variations noted below. The functions are grouped by category. Note that this approach requires the user to specify the backend first.
Note
All functions use the chop_wrapper method to apply precision chopping before and after computation.
NumPy: Uses numpy (np) operations, operates on np.ndarray, and assumes a stateless implementation.
PyTorch: Uses torch operations, operates on torch.Tensor, and supports GPU acceleration.
JAX: Uses jax.numpy (jnp) operations, operates on jax.Array, requires a random key for chopping, and is JIT-compatible.
Examples assume a chop instance with half-precision (prec='h') unless stated otherwise.
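The chop-before-and-after pattern described in the note can be sketched in plain NumPy. The helper names `to_half` and `chopped_apply` are hypothetical, and a cast to `np.float16` stands in for a real chop instance. This illustrates the idea only and is not pychop's implementation of chop_wrapper.

```python
import numpy as np

def to_half(x):
    """Stand-in chopper (assumption): round to IEEE half precision, return float64."""
    return np.asarray(x, dtype=np.float64).astype(np.float16).astype(np.float64)

def chopped_apply(func, x, chop=to_half):
    """Chop the input, apply func, then chop the output."""
    return chop(func(chop(x)))

x = np.array([0.0, 1.5708])  # ~[0, pi/2]
result = chopped_apply(np.sin, x)
print(result)
```

Because the output is chopped last, every element of `result` is exactly representable in the target precision, so chopping it again changes nothing.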
Trigonometric functions¶
- sin(x)¶
Compute the sine of x with chopping.
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).
- Returns:
Chopped sine of x.
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
Example (NumPy):
import numpy as np
chopper = chop(prec='h')
x = np.array([0.0, 1.5708]) # ~[0, pi/2]
result = chopper.sin(x)
print(result) # Expected: ~[0.0, 1.0] with chopping
Example (PyTorch):
import torch
chopper = chop(prec='h')
x = torch.tensor([0.0, 1.5708])
result = chopper.sin(x)
print(result) # Expected: ~[0.0, 1.0] with chopping
Example (JAX):
import jax.numpy as jnp
chopper = chop(prec='h')
x = jnp.array([0.0, 1.5708])
result = chopper.sin(x)
print(result) # Expected: ~[0.0, 1.0] with chopping
- cos(x)¶
Compute the cosine of x with chopping.
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).
- Returns:
Chopped cosine of x.
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
- tan(x)¶
Compute the tangent of x with chopping.
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).
- Returns:
Chopped tangent of x.
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
- arcsin(x)¶
Compute the arcsine of x with chopping. Input must be in [-1, 1].
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor in [-1, 1].
- Returns:
Chopped arcsine of x.
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
- Raises:
ValueError – If any element of x is not in [-1, 1].
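The domain check behind the ValueError can be pictured with a small NumPy sketch. `checked_arcsin` is a hypothetical helper written for illustration. It omits the chopping step and is not taken from pychop.

```python
import numpy as np

def checked_arcsin(x):
    """Hypothetical helper: reject inputs outside [-1, 1], then take arcsin."""
    x = np.asarray(x, dtype=np.float64)
    if np.any((x < -1.0) | (x > 1.0)):
        raise ValueError("arcsin input must be in [-1, 1]")
    return np.arcsin(x)

print(checked_arcsin([0.0, 1.0]))  # boundary values are accepted

try:
    checked_arcsin([0.5, 1.5])     # one element is out of range
except ValueError as err:
    print("rejected:", err)
```

The same pattern applies to the other domain-restricted functions in this section, with the inequality adjusted to the stated interval.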
- arccos(x)¶
Compute the arccosine of x with chopping. Input must be in [-1, 1].
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor in [-1, 1].
- Returns:
Chopped arccosine of x.
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
- Raises:
ValueError – If any element of x is not in [-1, 1].
- arctan(x)¶
Compute the arctangent of x with chopping.
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).
- Returns:
Chopped arctangent of x.
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
Hyperbolic functions¶
- sinh(x)¶
Compute the hyperbolic sine of x with chopping.
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).
- Returns:
Chopped hyperbolic sine of x.
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
- cosh(x)¶
Compute the hyperbolic cosine of x with chopping.
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).
- Returns:
Chopped hyperbolic cosine of x.
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
- tanh(x)¶
Compute the hyperbolic tangent of x with chopping.
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).
- Returns:
Chopped hyperbolic tangent of x.
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
- arcsinh(x)¶
Compute the inverse hyperbolic sine of x with chopping.
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).
- Returns:
Chopped inverse hyperbolic sine of x.
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
- arccosh(x)¶
Compute the inverse hyperbolic cosine of x with chopping. Input must be >= 1.
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (>= 1).
- Returns:
Chopped inverse hyperbolic cosine of x.
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
- Raises:
ValueError – If any element of x is < 1.
- arctanh(x)¶
Compute the inverse hyperbolic tangent of x with chopping. Input must be in (-1, 1).
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor in (-1, 1).
- Returns:
Chopped inverse hyperbolic tangent of x.
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
- Raises:
ValueError – If any element of x is not in (-1, 1).
Exponential and logarithmic functions¶
- exp(x)¶
Compute the exponential of x with chopping.
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).
- Returns:
Chopped exponential of x.
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
- expm1(x)¶
Compute exp(x) - 1 with chopping, optimized for small x.
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).
- Returns:
Chopped exp(x) - 1.
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
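The phrase "optimized for small x" refers to cancellation: for tiny x, exp(x) is very close to 1, so subtracting 1 afterwards discards most of the significant digits. The sketch below uses plain NumPy in double precision (no chopping) to compare the two routes, using x itself as a reference value since exp(x) - 1 ≈ x for small x.

```python
import numpy as np

x = 1e-10
naive = np.exp(x) - 1.0  # subtracts two nearly equal numbers
stable = np.expm1(x)     # evaluates exp(x) - 1 without forming exp(x) first

# Relative deviation from x, which approximates exp(x) - 1 for small x
rel_err_naive = abs(naive - x) / x
rel_err_stable = abs(stable - x) / x
print(rel_err_naive, rel_err_stable)
```

The same reasoning motivates log1p for log(1 + x) with small x.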
- log(x)¶
Compute the natural logarithm of x with chopping. Input must be positive.
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (> 0).
- Returns:
Chopped natural logarithm of x.
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
- Raises:
ValueError – If any element of x is <= 0.
- log10(x)¶
Compute the base-10 logarithm of x with chopping. Input must be positive.
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (> 0).
- Returns:
Chopped base-10 logarithm of x.
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
- Raises:
ValueError – If any element of x is <= 0.
- log2(x)¶
Compute the base-2 logarithm of x with chopping. Input must be positive.
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (> 0).
- Returns:
Chopped base-2 logarithm of x.
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
- Raises:
ValueError – If any element of x is <= 0.
- log1p(x)¶
Compute log(1 + x) with chopping, optimized for small x. Input must be > -1.
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (> -1).
- Returns:
Chopped log(1 + x).
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
- Raises:
ValueError – If any element of x is <= -1.
Power and root functions¶
- sqrt(x)¶
Compute the square root of x with chopping. Input must be non-negative.
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (>= 0).
- Returns:
Chopped square root of x.
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
- Raises:
ValueError – If any element of x is < 0.
- cbrt(x)¶
Compute the cube root of x with chopping.
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).
- Returns:
Chopped cube root of x.
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
Miscellaneous functions¶
- abs(x)¶
Compute the absolute value of x with chopping.
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real or complex).
- Returns:
Chopped absolute value of x.
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
- reciprocal(x)¶
Compute the reciprocal (1/x) of x with chopping. Input must not be zero.
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (!= 0).
- Returns:
Chopped reciprocal of x.
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
- Raises:
ValueError – If any element of x is 0.
- square(x)¶
Compute the square of x with chopping.
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).
- Returns:
Chopped square of x.
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
Additional mathematical functions¶
- frexp(x)¶
Decompose x into a mantissa and an exponent, with chopping applied to the mantissa.
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).
- Returns:
Tuple of (chopped mantissa, exponent).
- Return type:
tuple (np.ndarray, np.ndarray) (NumPy), (torch.Tensor, torch.Tensor) (PyTorch), or (jax.Array, jax.Array) (JAX)
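For reference, the underlying decomposition can be seen with NumPy's own `np.frexp`, which writes each element as mantissa * 2**exponent with 0.5 <= |mantissa| < 1. The sketch below involves no chopping. Since the exponent is an integer, only the mantissa is a candidate for chopping, which matches the description above.

```python
import numpy as np

x = np.array([8.0, 0.15625])
mantissa, exponent = np.frexp(x)
print(mantissa)  # [0.5   0.625]
print(exponent)  # [ 4 -2]

# ldexp reverses the decomposition
rebuilt = np.ldexp(mantissa, exponent)
print(rebuilt)
```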
- hypot(x, y)¶
Compute the Euclidean norm sqrt(x^2 + y^2) with chopping.
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – First input array/tensor (real-valued).
y (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Second input array/tensor (real-valued).
- Returns:
Chopped Euclidean norm.
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
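A dedicated hypot matters in reduced precision because the intermediate squares can overflow even when the final norm is representable. The sketch below uses NumPy's native float16 rather than pychop, so it says nothing about how pychop evaluates its intermediates. It only illustrates the hazard.

```python
import numpy as np

x = np.array([300.0], dtype=np.float16)
y = np.array([400.0], dtype=np.float16)

with np.errstate(over="ignore"):
    # 300**2 = 90000 exceeds the largest finite fp16 value (65504)
    naive = np.sqrt(x * x + y * y)

direct = np.hypot(x, y)
print(naive, direct)
```

Here the naive formula returns infinity, while `np.hypot` returns 500, which half precision represents exactly.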
- diff(x, n=1)¶
Compute the n-th order difference of x with chopping.
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).
n (int) – Order of difference (default: 1).
- Returns:
Chopped n-th order difference.
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
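The meaning of the order n can be checked with NumPy's `np.diff`, without chopping: the n-th order difference applies the first difference n times, shortening the array by one element each time.

```python
import numpy as np

x = np.array([1.0, 4.0, 9.0, 16.0, 25.0])  # squares of 1..5
first = np.diff(x)        # [3. 5. 7. 9.]
second = np.diff(x, n=2)  # [2. 2. 2.]
print(first, second)
```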
- power(x, y)¶
Compute x raised to the power y with chopping.
- Parameters:
x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Base array/tensor (real-valued).
y (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Exponent array/tensor (real-valued).
- Returns:
Chopped x^y.
- Return type:
np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
Mathematical functions II¶
The pychop.math_func module provides a suite of backend-aware mathematical functions that operate on floating-point numbers or arrays with custom precision chopping. These functions apply a provided chop callable to inputs and outputs, ensuring results adhere to the specified precision (e.g., fp16, fp32).
Supported backends: NumPy, PyTorch, and JAX. The backend is inferred from the type of the input array/tensor. Functions are grouped by category.
Note
All functions apply the chop callable before and after computation.
Backend detection is automatic:
- NumPy: np.ndarray
- PyTorch: torch.Tensor
- JAX: jax.Array
Inputs must satisfy domain constraints (e.g., positive for log, non-zero for reciprocal).
matmul requires inputs to be at least 1-dimensional; scalars are not allowed.
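One way to picture type-based backend detection is to inspect the module that defines the input's type, which avoids importing PyTorch or JAX when they are not in use. `infer_backend` is a hypothetical helper written for this illustration. The dispatch logic inside pychop.math_func may differ.

```python
import numpy as np

def infer_backend(x):
    """Hypothetical helper: guess the backend from the module defining type(x)."""
    root = type(x).__module__.split(".")[0]
    if root == "numpy":
        return "numpy"
    if root == "torch":
        return "torch"
    if root.startswith("jax"):  # covers array types defined in jax or jaxlib
        return "jax"
    raise TypeError(f"unsupported input type: {type(x).__name__}")

print(infer_backend(np.array([0.0, 1.5708])))  # numpy
```

Under this scheme a plain Python list or float matches none of the three backends and raises a TypeError.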
Example (NumPy):
import numpy as np
import pychop.math_func as mf
from pychop import Chop
chopper = Chop(exp_bits=5, sig_bits=10, rmode=3)
x = np.array([0.0, 1.5708]) # ~ [0, pi/2]
result = mf.sin(x, chopper)
print(result) # Expected: ~ [0.0, 1.0] with chopping
Example (PyTorch):
import torch
import pychop.math_func as mf
from pychop import Chop
chopper = Chop(exp_bits=5, sig_bits=10, rmode=3)
x = torch.tensor([0.0, 1.5708])
result = mf.sin(x, chopper)
print(result) # Expected: ~ [0.0, 1.0] with chopping
Example (JAX):
import jax.numpy as jnp
import pychop.math_func as mf
from pychop import Chop
chopper = Chop(exp_bits=5, sig_bits=10, rmode=3)
x = jnp.array([0.0, 1.5708])
result = mf.sin(x, chopper)
print(result) # Expected: ~ [0.0, 1.0] with chopping
Trigonometric functions¶
- sin(x, chop)¶
Compute sine of x with chopping.
- Parameters:
x – Real-valued input.
chop – Callable that applies precision chopping.
- Returns:
Chopped sine of x.
- Return type:
Same type as x (NumPy, PyTorch, or JAX)
- cos(x, chop)¶
- tan(x, chop)¶
- arcsin(x, chop)¶
- Input must be in [-1, 1].
- arccos(x, chop)¶
- Input must be in [-1, 1].
- arctan(x, chop)¶
Hyperbolic functions¶
- sinh(x, chop)¶
- cosh(x, chop)¶
- tanh(x, chop)¶
- arcsinh(x, chop)¶
- arccosh(x, chop)¶
- Input must be >= 1.
- arctanh(x, chop)¶
- Input must be in (-1, 1).
Exponential and logarithmic functions¶
- exp(x, chop)¶
- expm1(x, chop)¶
- log(x, chop)¶
- Input must be positive.
- log10(x, chop)¶
- Input must be positive.
- log2(x, chop)¶
- Input must be positive.
- log1p(x, chop)¶
- Input must be > -1.
Power and root functions¶
- sqrt(x, chop)¶
- Input must be non-negative.
- cbrt(x, chop)¶
- square(x, chop)¶
- power(x, y, chop)¶
- Compute x raised to the power y.
Arithmetic functions¶
- add(x, y, chop)¶
- subtract(x, y, chop)¶
- multiply(x, y, chop)¶
- divide(x, y, chop)¶
- Divisor must be non-zero.
- floor_divide(x, y, chop)¶
- Divisor must be non-zero.
- mod(x, y, chop)¶
- Divisor must be non-zero.
Linear algebra functions¶
- dot(x, y, chop)¶
- matmul(x, y, chop)¶
- Inputs must be at least 1-dimensional; scalars are not allowed.
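The restriction on scalars mirrors the behaviour of NumPy's own `np.matmul`, shown below without chopping. Whether pychop raises the same exception type is an assumption not confirmed here.

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([4.0, 5.0, 6.0])

inner = np.matmul(a, b)  # 1-D inputs give the inner product: 32.0
scalar_dot = np.dot(2.0, 3.0)  # dot accepts scalars: 6.0
print(inner, scalar_dot)

try:
    np.matmul(2.0, 3.0)  # 0-dimensional operands are rejected
except ValueError as err:
    print("matmul rejected scalars:", err)
```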
Reduction and aggregation functions¶
- sum(x, chop, axis=None)¶
- prod(x, chop, axis=None)¶
- mean(x, chop, axis=None)¶
- std(x, chop, axis=None)¶
- var(x, chop, axis=None)¶
- cumsum(x, chop, axis=None)¶
- cumprod(x, chop, axis=None)¶
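The effect of chopping a reduction can be sketched with a stand-in chopper. `to_half` and `chopped_sum` are hypothetical helpers that follow the chop-before-and-after rule stated in the note. They are not pychop's implementation, and `to_half` merely casts through `np.float16`.

```python
import numpy as np

def to_half(x):
    """Stand-in chopper (assumption): round to IEEE half precision, return float64."""
    return np.asarray(x, dtype=np.float64).astype(np.float16).astype(np.float64)

def chopped_sum(x, chop, axis=None):
    """Chop the input, reduce along axis, then chop the result."""
    return chop(np.sum(chop(x), axis=axis))

x = np.array([[2048.0, 1.0], [1.0, 2.0]])
rows = chopped_sum(x, to_half, axis=1)  # [2048. 3.]: 2049 is not representable in fp16
total = chopped_sum(x, to_half)         # 2052.0
print(rows, total)
```

With axis=None the reduction runs over all elements and returns a single value. With an integer axis it collapses only that dimension.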
Rounding functions¶
- floor(x, chop)¶
- ceil(x, chop)¶
- round(x, chop)¶
- sign(x, chop)¶
Comparison functions¶
- maximum(x, y, chop)¶
- minimum(x, y, chop)¶
Miscellaneous functions¶
- frexp(x, chop)¶
- Returns tuple of (chopped mantissa, exponent).
- hypot(x, y, chop)¶
- diff(x, n=1, chop)¶
- Compute n-th order difference.
- reciprocal(x, chop)¶
- Input must be non-zero.