Integer quantization
Integer quantization is another important feature of Pychop. Its purpose is to convert floating-point numbers into low bit-width integers, which can speed up computation on hardware with efficient integer arithmetic. Quantization is performed with a user-defined bit-width. The following example illustrates its usage.
Basic usage
Integer arithmetic emulation in Pychop is provided by the Chopi interface. It can be used in many situations and offers flexible options, such as whether to use symmetric quantization and how many bits to use. Its usage is illustrated below:
- class Chopi(num_bits=8, symmetric=False, per_channel=False, channel_dim=0)
A class for quantizing arrays to integer representations and dequantizing them back.
This class supports both symmetric and asymmetric quantization, with optional per-channel quantization along a specified axis. It is designed for inference-style quantization in JAX, PyTorch, and NumPy, using each framework's array type (jnp.ndarray, torch.Tensor, np.ndarray).
- Parameters:
num_bits (int) – Bit-width for quantization (e.g., 8 for INT8). Default is 8.
symmetric (bool) – If True, use symmetric quantization (zero_point = 0). If False, use asymmetric quantization. Default is False.
per_channel (bool) – If True, quantize each channel separately along channel_dim. If False, quantize the entire array at once. Default is False.
channel_dim (int) – Dimension to treat as the channel axis for per-channel quantization. Default is 0.
- Variables:
qmin (int) – Minimum quantized value (e.g., -128 for symmetric INT8, 0 for asymmetric INT8).
qmax (int) – Maximum quantized value (e.g., 127 for INT8).
scale – Scaling factor(s) for quantization, computed during calibration. A scalar, or an array of per-channel values when per_channel is True.
zero_point – Zero-point offset(s) for quantization, computed during calibration. None if symmetric; otherwise it has the same shape as scale.
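The value of num_bits fixes how many integer levels are available (2**num_bits), and qmin/qmax are the ends of that range. The hypothetical helper below (not part of Pychop) computes the two standard ranges, signed two's-complement and unsigned; which one Chopi uses for a given setting is described by qmin and qmax above.

```python
def int_range(num_bits, signed):
    """Hypothetical helper: (qmin, qmax) of a signed or unsigned num_bits-bit integer."""
    if num_bits < 1:
        raise ValueError("num_bits must be positive")
    if signed:
        # Two's complement: one more negative value than positive values.
        return -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    return 0, 2 ** num_bits - 1


for bits in (4, 8, 16):
    print(bits, int_range(bits, signed=True), int_range(bits, signed=False))
```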
- calibrate(x)
Calibrate the Chopi instance by computing the scale and zero-point from the input array.
- Parameters:
x – Input array to calibrate from (jnp.ndarray for JAX, torch.Tensor for PyTorch, np.ndarray for NumPy).
- Raises:
TypeError – If the input is not of the expected array type for the framework.
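Calibration reduces the data to its minimum and maximum and derives scale and zero_point from them. The sketch below uses the common min-max scheme; the function name calibrate_minmax, the choice to always keep 0.0 inside the range, and the exact formulas are assumptions for illustration, not Pychop's documented implementation.

```python
def calibrate_minmax(values, qmin, qmax, symmetric):
    """Hypothetical min-max calibration returning (scale, zero_point)."""
    lo, hi = min(values), max(values)
    # Assumption: extend the range to include 0.0 so zero is exactly representable.
    lo, hi = min(lo, 0.0), max(hi, 0.0)
    if symmetric:
        # Symmetric: zero_point is 0 and the largest magnitude maps to qmax.
        bound = max(abs(lo), abs(hi))
        scale = bound / qmax if bound > 0 else 1.0
        return scale, 0
    # Asymmetric: spread [lo, hi] over all qmax - qmin steps, then shift so lo maps to qmin.
    scale = (hi - lo) / (qmax - qmin) if hi > lo else 1.0
    zero_point = qmin - round(lo / scale)
    return scale, zero_point


x = [0.1, -0.2, 0.3, 0.4]  # the 2x2 array from the Examples section, flattened
print(calibrate_minmax(x, -128, 127, symmetric=True))
print(calibrate_minmax(x, 0, 255, symmetric=False))
```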
- quantize(x)
Quantize the input array to integers.
If the Chopi instance has not been calibrated, it automatically calibrates using the input array.
- Parameters:
x – Input array to quantize (jnp.ndarray for JAX, torch.Tensor for PyTorch, np.ndarray for NumPy).
- Returns:
Quantized integer array (jnp.ndarray with dtype int8 for JAX, torch.Tensor with dtype torch.int8 for PyTorch, np.ndarray with dtype int8 for NumPy).
- Raises:
TypeError – If the input is not of the expected array type for the framework.
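Quantization applies q = round(x / scale + zero_point) and clips the result to [qmin, qmax]. The minimal sketch below assumes a calibration result for data in [-0.2, 0.4] with an unsigned 8-bit range; the values are hypothetical. It also shows that inputs outside the calibrated range saturate at qmin or qmax instead of wrapping around.

```python
def quantize(values, scale, zero_point, qmin, qmax):
    """q = clip(round(x / scale + zero_point), qmin, qmax); Python's round() is half-to-even."""
    return [min(max(round(v / scale + zero_point), qmin), qmax) for v in values]


# Hypothetical calibration result: unsigned range [0, 255] covering [-0.2, 0.4].
scale, zero_point = 0.6 / 255, 85

print(quantize([-0.2, 0.0, 0.4], scale, zero_point, 0, 255))
# Values outside the calibrated range are clipped to qmin / qmax.
print(quantize([-1.0, 1.0], scale, zero_point, 0, 255))
```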
- dequantize(q)
Dequantize the integer array back to floating point.
- Parameters:
q – Quantized integer array (jnp.ndarray for JAX, torch.Tensor for PyTorch, np.ndarray for NumPy).
- Returns:
Dequantized floating-point array (jnp.ndarray for JAX, torch.Tensor for PyTorch, np.ndarray for NumPy).
- Raises:
TypeError – If the input is not of the expected array type for the framework.
ValueError – If the Chopi instance has not been calibrated (i.e., scale is None).
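Dequantization only approximately inverts quantization: for inputs inside the calibrated range, rounding moves each value by at most half a quantization step, so the round-trip error is bounded by scale / 2. The sketch below checks that bound with the same hypothetical scale and zero_point as above; the helper names are illustrative.

```python
def quantize(values, scale, zero_point, qmin, qmax):
    # q = clip(round(x / scale + zero_point), qmin, qmax)
    return [min(max(round(v / scale + zero_point), qmin), qmax) for v in values]


def dequantize(codes, scale, zero_point):
    # x_hat = (q - zero_point) * scale, returned as floats
    return [(q - zero_point) * scale for q in codes]


scale, zero_point, qmin, qmax = 0.6 / 255, 85, 0, 255  # covers [-0.2, 0.4]
x = [0.1, -0.2, 0.3, 0.4]
x_hat = dequantize(quantize(x, scale, zero_point, qmin, qmax), scale, zero_point)
max_err = max(abs(a - b) for a, b in zip(x, x_hat))
print(x_hat)
print(max_err, "<=", scale / 2)
```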
Principle
Quantization reduces floating-point values to low-precision integers to save memory and accelerate computation, especially on hardware with integer arithmetic support. The process involves:
- Calibration: determine the range of the input array (its minimum and maximum values) to compute a scaling factor (scale) and an offset (zero_point).
- Quantization: map floats to integers using q = round(x / scale + zero_point), clipped to [qmin, qmax].
- Dequantization: recover approximate floats using x ≈ (q - zero_point) * scale.
The following variants are available:
- Symmetric: assumes zero_point = 0 (e.g., range [-128, 127] for INT8), suitable for weights with zero-centered distributions.
- Asymmetric: allows zero_point to shift the range (e.g., [0, 255] for INT8), better suited to activations with non-zero minima.
- Per-channel: applies a separate scale and zero_point to each channel, improving accuracy for multi-channel data (e.g., CNN weights).
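The benefit of per-channel quantization shows up when channels have very different magnitudes: a single per-tensor scale is set by the largest channel, so small channels are left with only a few usable integer levels. The sketch below compares both choices for symmetric INT8 on a hypothetical two-row weight matrix, treating rows as channels (the analogue of channel_dim=0); the helper names are illustrative, not Pychop's API.

```python
def fake_quant_symmetric(row, scale, qmax):
    """Symmetric quantize/dequantize of one channel: zero_point = 0, codes in [-qmax - 1, qmax]."""
    qmin = -qmax - 1
    return [min(max(round(v / scale), qmin), qmax) * scale for v in row]


def per_row_max_error(rows, scales, qmax):
    """Largest absolute reconstruction error within each row (channel)."""
    return [max(abs(a - b) for a, b in zip(row, fake_quant_symmetric(row, s, qmax)))
            for row, s in zip(rows, scales)]


# Two channels (rows) with very different magnitudes.
w = [[0.01, -0.02, 0.015],
     [5.0, -3.0, 4.0]]
qmax = 127  # signed INT8

# Per-tensor: one scale from the global maximum magnitude.
tensor_scale = max(abs(v) for row in w for v in row) / qmax
err_tensor = per_row_max_error(w, [tensor_scale] * len(w), qmax)

# Per-channel: one scale per row from that row's maximum magnitude.
channel_scales = [max(abs(v) for v in row) / qmax for row in w]
err_channel = per_row_max_error(w, channel_scales, qmax)

print("per-tensor :", err_tensor)
print("per-channel:", err_channel)
```

The large row gets the same scale either way, so only the small row changes; with per-tensor scaling its entries are rounded to 0 or one coarse step.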
Examples
JAX Example:
import jax.numpy as jnp
import pychop
from pychop import Chopi
pychop.backend('jax')
x = jnp.array([[0.1, -0.2], [0.3, 0.4]])
ch = Chopi(num_bits=8, symmetric=False)  # avoid shadowing the class name
q = ch.quantize(x)
dq = ch.dequantize(q)
print(q)   # e.g., [[ 85  42] [106 127]], dtype=int8
print(dq)  # e.g., [[ 0.098 -0.196] [ 0.294  0.392]]
PyTorch Example:
import torch
import pychop
from pychop import Chopi
pychop.backend('torch')
x = torch.tensor([[0.1, -0.2], [0.3, 0.4]])
ch = Chopi(num_bits=8, symmetric=False)
q = ch.quantize(x)  # inference mode
dq = ch.dequantize(q)
print(q)   # e.g., tensor([[ 85,  42], [106, 127]], dtype=torch.int8)
print(dq)  # e.g., tensor([[ 0.098, -0.196], [ 0.294,  0.392]])
NumPy Example:
import numpy as np
import pychop
from pychop import Chopi
pychop.backend('numpy')
x = np.array([[0.1, -0.2], [0.3, 0.4]])
ch = Chopi(num_bits=8, symmetric=False)
q = ch.quantize(x)
dq = ch.dequantize(q)
print(q)   # e.g., [[ 85  42] [106 127]], dtype=int8
print(dq)  # e.g., [[ 0.098 -0.196] [ 0.294  0.392]]
Note
The PyTorch version also supports a training mode via forward(x, training=True), which performs fake quantization. It is not shown here, but it is useful for quantization-aware training.
Exact integer values may vary slightly due to rounding and differences in the calibrated range.
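Fake quantization keeps values in floating point but snaps them to the quantization grid, so a network can be trained while seeing quantization error. Because round() has zero gradient almost everywhere, training usually relies on a straight-through estimator (STE). The sketch below shows the common clipped STE in plain Python with hypothetical helpers and values; it illustrates the general technique and is not necessarily what Pychop's forward(x, training=True) does internally.

```python
def fake_quantize(x, scale, zero_point, qmin, qmax):
    """Forward pass: quantize, then immediately dequantize, so the output stays floating point."""
    return [(min(max(round(v / scale + zero_point), qmin), qmax) - zero_point) * scale for v in x]


def fake_quantize_backward(x, grad_out, scale, zero_point, qmin, qmax):
    """Clipped STE: treat round() as identity, but block gradients where the input was clipped."""
    lo = (qmin - zero_point) * scale  # smallest representable value
    hi = (qmax - zero_point) * scale  # largest representable value
    return [g if lo <= v <= hi else 0.0 for v, g in zip(x, grad_out)]


scale, zero_point, qmin, qmax = 0.6 / 255, 85, 0, 255  # grid covering [-0.2, 0.4]
x = [0.1, -0.5, 0.3, 0.9]  # -0.5 and 0.9 fall outside the representable range
y = fake_quantize(x, scale, zero_point, qmin, qmax)
grad_x = fake_quantize_backward(x, [1.0] * len(x), scale, zero_point, qmin, qmax)
print(y)
print(grad_x)
```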
import numpy as np
import torch
import jax.numpy as jnp
import pychop

X_np = np.random.randn(500, 500)  # NumPy array (float64)
X_th = torch.Tensor(X_np)         # PyTorch tensor (float32)
X_jx = jnp.asarray(X_np)          # JAX array (float32 unless 64-bit mode is enabled)
print(X_np)

# NumPy backend
pychop.backend('numpy')
pyq_f = pychop.Chopi(num_bits=8)  # the more bits, the more accurate the reconstruction
X_q = pyq_f.quantize(X_np)        # quantize: floating point -> integers
X_inv = pyq_f.dequantize(X_q)     # dequantize: integers -> floating point
print(np.linalg.norm(X_inv - X_np))

# PyTorch backend: compare tensors with torch.linalg.norm
pychop.backend('torch')
pyq_f = pychop.Chopi(num_bits=8)
X_q = pyq_f.quantize(X_th)
X_inv = pyq_f.dequantize(X_q)
print(torch.linalg.norm(X_inv - X_th))

# JAX backend: compare arrays with jnp.linalg.norm
pychop.backend('jax')
pyq_f = pychop.Chopi(num_bits=8)
X_q = pyq_f.quantize(X_jx)
X_inv = pyq_f.dequantize(X_q)
print(jnp.linalg.norm(X_inv - X_jx))
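The same experiment can be mirrored without any framework to see how the reconstruction error shrinks as num_bits grows: each extra bit halves the quantization step. The sketch below assumes asymmetric min-max calibration over an unsigned range of 2**num_bits levels; the helper roundtrip_error is hypothetical and does not reproduce Pychop's exact numbers.

```python
import math
import random


def roundtrip_error(values, num_bits):
    """L2 norm of x - dequantize(quantize(x)) under asymmetric min-max quantization."""
    qmin, qmax = 0, 2 ** num_bits - 1
    lo, hi = min(min(values), 0.0), max(max(values), 0.0)
    scale = (hi - lo) / (qmax - qmin)
    zero_point = qmin - round(lo / scale)
    err2 = 0.0
    for v in values:
        q = min(max(round(v / scale + zero_point), qmin), qmax)
        err2 += (v - (q - zero_point) * scale) ** 2
    return math.sqrt(err2)


random.seed(0)
data = [random.gauss(0.0, 1.0) for _ in range(10_000)]  # standard normal samples
errors = {bits: roundtrip_error(data, bits) for bits in (4, 8, 12, 16)}
for bits, err in errors.items():
    print(f"{bits:2d} bits: error = {err:.6f}")
```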
Quantization-aware training
Pychop provides an easy-to-use API, IntQuantizedLayer, for quantization-aware training.
Simply load the module via:
from pychop import IntQuantizedLayer
Quantization-aware training is performed by plugging IntQuantizedLayer into the neural network. We illustrate its usage by training a fully connected layer: