Mathematical functions

Pychop provides two ways to evaluate mathematical functions in reduced precision. The first approach (methods of a chop instance) requires specifying a backend first, whereas the second approach (the pychop.math_func module) infers the backend from the type of the input.

Mathematical functions I

The chop class provides a suite of mathematical functions that operate on floating-point numbers with custom precision chopping. These functions apply the chopping mechanism (via chop_wrapper) to inputs and outputs, so that results are rounded to the specified precision (e.g., fp16, fp32). Implementations are available for NumPy, PyTorch, and JAX, with the backend-specific differences noted below. Functions are grouped by category. This approach requires the user to specify the backend first.

Note

  • All functions use the chop_wrapper method to apply precision chopping before and after computation.

  • NumPy: Uses numpy (np) operations, operates on np.ndarray, and assumes a stateless implementation.

  • PyTorch: Uses torch operations, operates on torch.Tensor, and supports GPU acceleration.

  • JAX: Uses jax.numpy (jnp) operations, operates on jax.Array, requires a random key for chopping, and is JIT-compatible.

  • Examples assume a chop instance with half precision (prec='h') unless stated otherwise.
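The chop-before-and-after pattern described in the note can be illustrated in plain NumPy. This is a minimal sketch, not Pychop's chop_wrapper: fp16_chop and chopped_call are hypothetical helpers, and fp16 is emulated by casting to np.float16 and back, which rounds to nearest with ties to even.

```python
import numpy as np

def fp16_chop(a):
    """Hypothetical stand-in for a chopper: round to IEEE fp16, return float64."""
    return np.asarray(a, dtype=np.float64).astype(np.float16).astype(np.float64)

def chopped_call(func, chop, *args):
    """Sketch of the chop-before-and-after pattern (not Pychop's chop_wrapper)."""
    chopped_args = [chop(a) for a in args]  # chop every input
    result = func(*chopped_args)            # evaluate in float64
    return chop(result)                     # chop the output

x = np.array([0.0, 1.5708])  # ~[0, pi/2]
print(chopped_call(np.sin, fp16_chop, x))  # [0. 1.]
```

Note that the input itself changes: 1.5708 becomes 1.5703125 in fp16 before sin is evaluated.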

Trigonometric functions

sin(x)

Compute the sine of x with chopping.

Parameters:

x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).

Returns:

Chopped sine of x.

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)

Example (NumPy):

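import numpy as np
from pychop import chop  # import path assumed; adjust to your Pychop version

chopper = chop(prec='h')  # half precision (fp16)
x = np.array([0.0, 1.5708])  # ~[0, pi/2]
result = chopper.sin(x)
print(result)  # Expected: ~[0.0, 1.0] with chopping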

Example (PyTorch):

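import torch
from pychop import chop  # import path assumed; adjust to your Pychop version

chopper = chop(prec='h')  # half precision (fp16)
x = torch.tensor([0.0, 1.5708])  # ~[0, pi/2]
result = chopper.sin(x)
print(result)  # Expected: ~[0.0, 1.0] with chopping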

Example (JAX):

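import jax.numpy as jnp
from pychop import chop  # import path assumed; adjust to your Pychop version

chopper = chop(prec='h')  # half precision (fp16)
x = jnp.array([0.0, 1.5708])  # ~[0, pi/2]
result = chopper.sin(x)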
print(result)  # Expected: ~[0.0, 1.0] with chopping

cos(x)

Compute the cosine of x with chopping.

Parameters:

x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).

Returns:

Chopped cosine of x.

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)

tan(x)

Compute the tangent of x with chopping.

Parameters:

x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).

Returns:

Chopped tangent of x.

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
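Because inputs are chopped before evaluation, a function with a large derivative can magnify the input rounding error. A hedged sketch near the pole of tan at pi/2, using a hypothetical fp16 stand-in (fp16_chop, a cast to np.float16 and back) rather than Pychop itself:

```python
import numpy as np

def fp16_chop(a):
    # Hypothetical stand-in for a chopper: round to IEEE fp16 (nearest, ties to even)
    return np.asarray(a, dtype=np.float64).astype(np.float16).astype(np.float64)

x = np.pi / 2 - 1e-3                         # close to the pole of tan
exact = np.tan(x)                            # about 1000 in float64
chopped = fp16_chop(np.tan(fp16_chop(x)))    # input rounded to fp16 first
print(exact, chopped)                        # chopped result is roughly 685
```

Rounding moves the input by less than 5e-4, yet the chopped tangent drops from about 1000 to roughly 685.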

arcsin(x)

Compute the arcsine of x with chopping. Input must be in [-1, 1].

Parameters:

x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor in [-1, 1].

Returns:

Chopped arcsine of x.

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)

Raises:

ValueError – If any element of x is not in [-1, 1].
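One way the documented ValueError could be produced is an explicit domain check before evaluation. The sketch below is hypothetical (checked_arcsin and fp16_chop are not Pychop functions; fp16 is emulated with a NumPy cast) and only illustrates the documented contract.

```python
import numpy as np

def fp16_chop(a):
    # Hypothetical stand-in for a chopper: round to IEEE fp16
    return np.asarray(a, dtype=np.float64).astype(np.float16).astype(np.float64)

def checked_arcsin(x, chop=fp16_chop):
    """Hypothetical sketch: chop the input, validate the domain, chop the output."""
    x = chop(x)
    if np.any((x < -1.0) | (x > 1.0)):
        raise ValueError("arcsin input must be in [-1, 1]")
    return chop(np.arcsin(x))

print(checked_arcsin(np.array([-1.0, 0.5, 1.0])))
try:
    checked_arcsin(np.array([1.5]))
except ValueError as err:
    print("ValueError:", err)
```

Whether the check runs on the raw or the chopped input is a design choice. Here it runs after chopping, so 1.0001, which rounds to 1.0 in fp16, is accepted.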

arccos(x)

Compute the arccosine of x with chopping. Input must be in [-1, 1].

Parameters:

x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor in [-1, 1].

Returns:

Chopped arccosine of x.

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)

Raises:

ValueError – If any element of x is not in [-1, 1].

arctan(x)

Compute the arctangent of x with chopping.

Parameters:

x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).

Returns:

Chopped arctangent of x.

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)

Hyperbolic functions

sinh(x)

Compute the hyperbolic sine of x with chopping.

Parameters:

x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).

Returns:

Chopped hyperbolic sine of x.

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)

cosh(x)

Compute the hyperbolic cosine of x with chopping.

Parameters:

x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).

Returns:

Chopped hyperbolic cosine of x.

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)

tanh(x)

Compute the hyperbolic tangent of x with chopping.

Parameters:

x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).

Returns:

Chopped hyperbolic tangent of x.

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)

arcsinh(x)

Compute the inverse hyperbolic sine of x with chopping.

Parameters:

x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).

Returns:

Chopped inverse hyperbolic sine of x.

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)

arccosh(x)

Compute the inverse hyperbolic cosine of x with chopping. Input must be >= 1.

Parameters:

x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (>= 1).

Returns:

Chopped inverse hyperbolic cosine of x.

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)

Raises:

ValueError – If any element of x is < 1.

arctanh(x)

Compute the inverse hyperbolic tangent of x with chopping. Input must be in (-1, 1).

Parameters:

x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor in (-1, 1).

Returns:

Chopped inverse hyperbolic tangent of x.

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)

Raises:

ValueError – If any element of x is not in (-1, 1).

Exponential and logarithmic functions

exp(x)

Compute the exponential of x with chopping.

Parameters:

x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).

Returns:

Chopped exponential of x.

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
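In fp16 the largest finite value is 65504, so exp overflows for inputs a little above 11 (ln 65504 ≈ 11.09). A sketch using a hypothetical fp16 stand-in (fp16_chop, a NumPy cast to np.float16 and back):

```python
import numpy as np

def fp16_chop(a):
    # Hypothetical stand-in for a chopper: round to IEEE fp16
    return np.asarray(a, dtype=np.float64).astype(np.float16).astype(np.float64)

print(float(np.finfo(np.float16).max))  # 65504.0, largest finite fp16 value
x = np.array([10.0, 11.0, 12.0])
result = fp16_chop(np.exp(fp16_chop(x)))
print(result)  # exp(10) and exp(11) stay finite; exp(12) ~ 162755 becomes inf
```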

expm1(x)

Compute exp(x) - 1 with chopping. For small |x|, this is more accurate than evaluating exp(x) - 1 directly, which suffers from cancellation.

Parameters:

x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).

Returns:

Chopped exp(x) - 1.

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
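The cancellation that expm1 avoids is easy to see in fp16: for tiny x, exp(x) rounds to exactly 1.0, and subtracting 1 loses every significant digit. A sketch with a hypothetical fp16 stand-in (fp16_chop, a NumPy cast), not Pychop's implementation:

```python
import numpy as np

def fp16_chop(a):
    # Hypothetical stand-in for a chopper: round to IEEE fp16
    return np.asarray(a, dtype=np.float64).astype(np.float16).astype(np.float64)

x = fp16_chop(1e-4)                             # small input, rounded to fp16
naive = fp16_chop(fp16_chop(np.exp(x)) - 1.0)   # exp(x) rounds to 1.0, so this cancels to 0
accurate = fp16_chop(np.expm1(x))               # stays close to 1e-4
print(naive, accurate)
```

The same reasoning applies to log1p versus log(1 + x).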

log(x)

Compute the natural logarithm of x with chopping. Input must be positive.

Parameters:

x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (> 0).

Returns:

Chopped natural logarithm of x.

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)

Raises:

ValueError – If any element of x is <= 0.

log10(x)

Compute the base-10 logarithm of x with chopping. Input must be positive.

Parameters:

x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (> 0).

Returns:

Chopped base-10 logarithm of x.

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)

Raises:

ValueError – If any element of x is <= 0.

log2(x)

Compute the base-2 logarithm of x with chopping. Input must be positive.

Parameters:

x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (> 0).

Returns:

Chopped base-2 logarithm of x.

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)

Raises:

ValueError – If any element of x is <= 0.

log1p(x)

Compute log(1 + x) with chopping. For small |x|, this is more accurate than evaluating log(1 + x) directly. Input must be > -1.

Parameters:

x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (> -1).

Returns:

Chopped log(1 + x).

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)

Raises:

ValueError – If any element of x is <= -1.

Power and root functions

sqrt(x)

Compute the square root of x with chopping. Input must be non-negative.

Parameters:

x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (>= 0).

Returns:

Chopped square root of x.

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)

Raises:

ValueError – If any element of x is < 0.

cbrt(x)

Compute the cube root of x with chopping.

Parameters:

x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).

Returns:

Chopped cube root of x.

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
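Unlike a fractional power, cbrt accepts negative real inputs. A sketch in NumPy with a hypothetical fp16 stand-in (fp16_chop) in place of a Pychop chopper:

```python
import numpy as np

def fp16_chop(a):
    # Hypothetical stand-in for a chopper: round to IEEE fp16
    return np.asarray(a, dtype=np.float64).astype(np.float16).astype(np.float64)

x = np.array([-8.0, 27.0])
roots = fp16_chop(np.cbrt(fp16_chop(x)))
print(roots)  # -> -2.0 and 3.0

with np.errstate(invalid="ignore"):
    print(np.power(x, 1.0 / 3.0))  # nan for the negative entry
```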

Miscellaneous functions

abs(x)

Compute the absolute value of x with chopping.

Parameters:

x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real or complex).

Returns:

Chopped absolute value of x.

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)

reciprocal(x)

Compute the reciprocal (1/x) of x with chopping. Input must not be zero.

Parameters:

x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (!= 0).

Returns:

Chopped reciprocal of x.

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)

Raises:

ValueError – If any element of x is 0.

square(x)

Compute the square of x with chopping.

Parameters:

x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).

Returns:

Chopped square of x.

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
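Squaring reaches the fp16 overflow threshold quickly: 255 squared still rounds to a finite value, while 256 squared (65536) exceeds the largest finite fp16 value, 65504. A sketch with a hypothetical fp16 stand-in (fp16_chop, a NumPy cast):

```python
import numpy as np

def fp16_chop(a):
    # Hypothetical stand-in for a chopper: round to IEEE fp16
    return np.asarray(a, dtype=np.float64).astype(np.float16).astype(np.float64)

x = np.array([255.0, 256.0])
sq = fp16_chop(np.square(fp16_chop(x)))
print(sq)  # 65025 rounds to 65024 (fp16 spacing is 32 here); 65536 overflows to inf
```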

Additional mathematical functions

frexp(x)

Decompose x into a mantissa and an exponent, with chopping applied to the mantissa.

Parameters:

x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).

Returns:

Tuple of (chopped mantissa, exponent).

Return type:

tuple (np.ndarray, np.ndarray) (NumPy), (torch.Tensor, torch.Tensor) (PyTorch), or (jax.Array, jax.Array) (JAX)
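The decomposition follows x = mantissa * 2**exponent with 0.5 <= |mantissa| < 1, so chopping the mantissa alone controls the precision of x. A sketch using NumPy's frexp and ldexp with a hypothetical fp16 stand-in (fp16_chop), not Pychop's own code:

```python
import numpy as np

def fp16_chop(a):
    # Hypothetical stand-in for a chopper: round to IEEE fp16
    return np.asarray(a, dtype=np.float64).astype(np.float16).astype(np.float64)

x = np.array([8.0, 0.3, -5.5])
mantissa, exponent = np.frexp(x)        # x == mantissa * 2**exponent
chopped_mantissa = fp16_chop(mantissa)  # chopping applied to the mantissa only
print(chopped_mantissa, exponent)
print(np.ldexp(chopped_mantissa, exponent))  # reassemble: 8.0, ~0.3, -5.5
```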

hypot(x, y)

Compute the Euclidean norm sqrt(x^2 + y^2) with chopping.

Parameters:
  • x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – First input array/tensor (real-valued).

  • y (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Second input array/tensor (real-valued).

Returns:

Chopped Euclidean norm.

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
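A dedicated hypot matters in low precision because the naive formula squares its inputs first. A sketch with a hypothetical fp16 stand-in (fp16_chop, a NumPy cast), chopping after each step of the naive formula:

```python
import numpy as np

def fp16_chop(a):
    # Hypothetical stand-in for a chopper: round to IEEE fp16
    return np.asarray(a, dtype=np.float64).astype(np.float16).astype(np.float64)

x = fp16_chop(300.0)
y = fp16_chop(400.0)
# Naive: x*x = 90000 exceeds the fp16 maximum (65504) and becomes inf
naive = fp16_chop(np.sqrt(fp16_chop(fp16_chop(x * x) + fp16_chop(y * y))))
stable = fp16_chop(np.hypot(x, y))
print(naive, stable)  # inf 500.0
```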

diff(x, n=1)

Compute the n-th order difference of x with chopping.

Parameters:
  • x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Input array/tensor (real-valued).

  • n (int) – Order of difference (default: 1).

Returns:

Chopped n-th order difference.

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
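Differences of nearby values are sensitive to input chopping, since the inputs are rounded before they are subtracted. A sketch with NumPy's diff and a hypothetical fp16 stand-in (fp16_chop); near 1000 the fp16 spacing is 0.5:

```python
import numpy as np

def fp16_chop(a):
    # Hypothetical stand-in for a chopper: round to IEEE fp16
    return np.asarray(a, dtype=np.float64).astype(np.float16).astype(np.float64)

x = np.array([1000.0, 1000.1, 1000.3])
print(np.diff(x))                       # ~[0.1, 0.2] in float64

xc = fp16_chop(x)                       # rounds to 1000.0, 1000.0, 1000.5
d1 = fp16_chop(np.diff(xc))             # first-order difference
d2 = fp16_chop(np.diff(xc, n=2))        # second-order difference
print(d1, d2)
```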

power(x, y)

Compute x raised to the power y with chopping.

Parameters:
  • x (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Base array/tensor (real-valued).

  • y (np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)) – Exponent array/tensor (real-valued).

Returns:

Chopped x^y.

Return type:

np.ndarray (NumPy), torch.Tensor (PyTorch), or jax.Array (JAX)
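At the other end of the range, small powers underflow. In fp16 the smallest positive subnormal is 2**-24 (about 5.96e-8), so 1e-8 rounds to zero while 1e-7 survives as a subnormal. A sketch with a hypothetical fp16 stand-in (fp16_chop, a NumPy cast):

```python
import numpy as np

def fp16_chop(a):
    # Hypothetical stand-in for a chopper: round to IEEE fp16 (subnormals kept)
    return np.asarray(a, dtype=np.float64).astype(np.float16).astype(np.float64)

base = fp16_chop(10.0)
exps = fp16_chop(np.array([-4.0, -7.0, -8.0]))
p = fp16_chop(np.power(base, exps))
print(p)  # 1e-4 is normal, 1e-7 becomes 2 * 2**-24, 1e-8 flushes to 0
```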

Mathematical functions II

The pychop.math_func module provides a suite of backend-aware mathematical functions that operate on floating-point numbers or arrays with custom precision chopping. These functions apply a provided chop callable to inputs and outputs, ensuring results adhere to the specified precision (e.g., fp16, fp32).

Supported backends are NumPy, PyTorch, and JAX. The backend is inferred from the type of the input array/tensor. Functions are grouped by category.

Note

  • All functions apply the chop callable before and after computation.

  • Backend detection is automatic:
      - NumPy: np.ndarray
      - PyTorch: torch.Tensor
      - JAX: jax.Array

  • Inputs must satisfy domain constraints (e.g., positive for log, non-zero for reciprocal).

  • matmul requires inputs to be at least 1-dimensional; scalars are not allowed.

Example (NumPy):

import numpy as np
import pychop.math_func as mf
from pychop import Chop

chopper = Chop(exp_bits=5, sig_bits=10, rmode=3)
x = np.array([0.0, 1.5708])  # ~ [0, pi/2]
result = mf.sin(x, chopper)
print(result)  # Expected: ~ [0.0, 1.0] with chopping

Example (PyTorch):

import torch
import pychop.math_func as mf
from pychop import Chop

chopper = Chop(exp_bits=5, sig_bits=10, rmode=3)
x = torch.tensor([0.0, 1.5708])
result = mf.sin(x, chopper)
print(result)  # Expected: ~ [0.0, 1.0] with chopping

Example (JAX):

import jax.numpy as jnp
import pychop.math_func as mf
from pychop import Chop

chopper = Chop(exp_bits=5, sig_bits=10, rmode=3)
x = jnp.array([0.0, 1.5708])
result = mf.sin(x, chopper)
print(result)  # Expected: ~ [0.0, 1.0] with chopping
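Backend inference by input type can be sketched with isinstance checks. infer_backend is a hypothetical helper, not part of pychop.math_func; torch and jax are imported only if they are installed.

```python
import numpy as np

def infer_backend(x):
    """Hypothetical sketch: map an input's type to a backend name."""
    if isinstance(x, np.ndarray):
        return "numpy"
    try:
        import torch
        if isinstance(x, torch.Tensor):
            return "torch"
    except ImportError:
        pass
    try:
        import jax
        if isinstance(x, jax.Array):
            return "jax"
    except (ImportError, AttributeError):  # jax missing, or too old to have jax.Array
        pass
    raise TypeError(f"unsupported input type: {type(x).__name__}")

print(infer_backend(np.array([0.0, 1.5708])))  # numpy
```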

Trigonometric functions

sin(x, chop)

Compute the sine of x with chopping.

Parameters:
  • x – Real-valued input.

  • chop – Callable that applies precision chopping.

Returns:

Chopped sine of x.

Return type:

Same type as x (NumPy, PyTorch, or JAX)
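Per the parameter description, chop only needs to be a callable that applies precision chopping. As a hedged illustration, make_round_to_bits below is a hypothetical helper (not part of Pychop) that builds such a callable for NumPy arrays. It rounds the significand to t bits with ties to even and leaves the exponent range unbounded, so overflow and subnormals are not modeled.

```python
import numpy as np

def make_round_to_bits(t):
    """Hypothetical chop callable: keep t significand bits (round to nearest, ties to even).
    The exponent range is unbounded, so overflow and subnormals are not modeled."""
    def chop(x):
        m, e = np.frexp(np.asarray(x, dtype=np.float64))  # x = m * 2**e, 0.5 <= |m| < 1
        return np.ldexp(np.round(np.ldexp(m, t)), e - t)  # np.round rounds halves to even
    return chop

chop11 = make_round_to_bits(11)  # 11 bits, matching the fp16 significand precision
x = np.array([1.0 / 3.0, 1.5708, -2.7, 1000.1])
print(chop11(x))
print(chop11(np.sin(chop11(x))))  # chop the input, evaluate, chop the output
```

For values inside the normal fp16 range, chop11 agrees with rounding to np.float16.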

cos(x, chop)
tan(x, chop)
arcsin(x, chop) – Input must be in [-1, 1].
arccos(x, chop) – Input must be in [-1, 1].
arctan(x, chop)

Hyperbolic functions

sinh(x, chop)
cosh(x, chop)
tanh(x, chop)
arcsinh(x, chop)
arccosh(x, chop) – Input must be >= 1.
arctanh(x, chop) – Input must be in (-1, 1).

Exponential and logarithmic functions

exp(x, chop)
expm1(x, chop)
log(x, chop) – Input must be positive.
log10(x, chop) – Input must be positive.
log2(x, chop) – Input must be positive.
log1p(x, chop) – Input must be > -1.

Power and root functions

sqrt(x, chop) – Input must be non-negative.
cbrt(x, chop)
square(x, chop)
power(x, y, chop) – Compute x raised to the power y.

Arithmetic functions

add(x, y, chop)
subtract(x, y, chop)
multiply(x, y, chop)
divide(x, y, chop) – Divisor must be non-zero.
floor_divide(x, y, chop) – Divisor must be non-zero.
mod(x, y, chop) – Divisor must be non-zero.
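Chopping after each arithmetic operation makes addition non-associative, just as in hardware fp16. The add below is a hypothetical sketch that mirrors the documented signature add(x, y, chop) (chop the inputs, add, chop the result); it is not Pychop's implementation, and fp16_chop is a NumPy-cast stand-in for a real chopper.

```python
import numpy as np

def fp16_chop(a):
    # Hypothetical stand-in for a chopper: round to IEEE fp16 (nearest, ties to even)
    return np.asarray(a, dtype=np.float64).astype(np.float16).astype(np.float64)

def add(x, y, chop):
    """Hypothetical sketch mirroring add(x, y, chop): chop inputs, add, chop result."""
    return chop(chop(x) + chop(y))

u = 2.0 ** -11  # half the fp16 spacing just above 1.0
left = add(add(1.0, u, fp16_chop), u, fp16_chop)   # each step ties and rounds back to 1.0
right = add(1.0, add(u, u, fp16_chop), fp16_chop)  # u + u = 2**-10 is exactly representable
print(left, right)  # 1.0 1.0009765625
```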

Linear algebra functions

dot(x, y, chop)
matmul(x, y, chop) – Inputs must be at least 1-dimensional; scalars are not allowed.

Reduction and aggregation functions

sum(x, chop, axis=None)
prod(x, chop, axis=None)
mean(x, chop, axis=None)
std(x, chop, axis=None)
var(x, chop, axis=None)
cumsum(x, chop, axis=None)
cumprod(x, chop, axis=None)
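This section does not specify whether partial results inside a reduction are chopped. The sketch below contrasts two hypothetical strategies for a sum, using a NumPy-cast stand-in (fp16_chop) rather than Pychop: chopping only the inputs and the final result, versus chopping after every addition (recursive summation carried out entirely in fp16).

```python
import numpy as np

def fp16_chop(a):
    # Hypothetical stand-in for a chopper: round to IEEE fp16 (nearest, ties to even)
    return np.asarray(a, dtype=np.float64).astype(np.float16).astype(np.float64)

x = fp16_chop(np.concatenate(([1.0], np.full(1000, 2.0 ** -11))))

# Strategy A: accumulate in float64, chop only the final result
final_only = fp16_chop(np.sum(x))

# Strategy B: chop every partial sum; each 1.0 + 2**-11 ties and rounds back to 1.0
acc = 0.0
for v in x:
    acc = fp16_chop(acc + v)

print(final_only, acc)  # 1.48828125 1.0
```

Strategy B stagnates because every small addend is lost to rounding, which is why the choice matters for long reductions in low precision.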

Rounding functions

floor(x, chop)
ceil(x, chop)
round(x, chop)
sign(x, chop)

Comparison functions

maximum(x, y, chop)
minimum(x, y, chop)

Miscellaneous functions

frexp(x, chop) – Returns a tuple (chopped mantissa, exponent).
hypot(x, y, chop)
diff(x, chop, n=1) – Compute the n-th order difference.
reciprocal(x, chop) – Input must be non-zero.