.. _ptq_guide:
========================================
Post-Training Quantization
========================================
Overview
========
PyChop provides **Post-Training Quantization (PTQ)** methods that quantize pre-trained models
without retraining. PTQ is well suited to fast deployment and model compression, usually at the cost of a small accuracy drop.
**Key Features:**
- ✅ **4 PTQ Methods**: Basic, Static, Dynamic, Mixed-Precision
- ✅ **4 Calibration Algorithms**: MinMax, Percentile, KL-Divergence, MSE
- ✅ **Multi-Backend Support**: PyTorch, JAX/Flax, and TensorFlow (including PTQ with calibration)
- ✅ **Flexible Quantization**: FP16, INT8, INT4, Custom Precision
- ✅ **Easy API**: Unified interface across backends
.. _ptq_comparison:
PTQ Methods Comparison
======================
**Legend:**
- ✅ Quantized = Converted to specified precision (INT8, FP16, custom)
- ⚫ Original precision = Not quantized, keeps model's current precision
- ⚫ Preserved (FP32) = Always kept as FP32 for numerical stability
Quantization Components
=======================
Understanding what gets quantized in each method:
.. code-block:: text
┌─────────────────────────────────────────────────────────────────┐
│ Model Components │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
│ │ Weights │ │ Activations │ │ Batch Stats │ │
│ │ (Conv/FC) │ │ (ReLU) │ │ (BN mean/var)│ │
│ └──────────────┘ └──────────────┘ └──────────────┘ │
│ ▲ ▲ ▲ │
│ │ │ │ │
│ ┌────┴─────┬───────────┴──────┬──────────┴────┐ │
│ │ │ │ │ │
│ Basic Static Dynamic Preserved │
│ PTQ PTQ PTQ (all PTQ) │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ W-only W + A W + A (dynamic) │
└─────────────────────────────────────────────────────────────────┘
**Component Quantization Details:**
1. **Weights (Conv/Linear)**: Always quantized in all PTQ methods
2. **Biases**: Always quantized in all PTQ methods
3. **Activations (ReLU/GELU)**: Quantized in Static/Dynamic/Mixed PTQ
4. **BatchNorm Stats**: Never quantized (preserved as FP32)
5. **LayerNorm**: Quantized in Static/Dynamic PTQ
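
The component rules above can be illustrated with a small, framework-free sketch of weight-only PTQ over a nested dictionary laid out like a Flax ``variables`` dict. This is not PyChop's implementation: tensors are modelled as flat Python lists, ``fake_quantize_int8`` and ``quantize_variables`` are hypothetical names, and only the ``params`` collection (weights and biases) is quantized while ``batch_stats`` passes through unchanged, mirroring the rule that BatchNorm statistics stay in FP32.

.. code-block:: python

   from typing import Any, Dict, List

   def fake_quantize_int8(values: List[float]) -> List[float]:
       """Symmetric per-tensor INT8 quantize-dequantize (illustrative only)."""
       max_abs = max((abs(v) for v in values), default=0.0)
       if max_abs == 0.0:
           return list(values)
       scale = max_abs / 127.0
       return [max(-127, min(127, round(v / scale))) * scale for v in values]

   def quantize_variables(variables: Dict[str, Any]) -> Dict[str, Any]:
       """Quantize every leaf under 'params'; pass other collections through."""
       def walk(tree: Any) -> Any:
           if isinstance(tree, dict):
               return {name: walk(sub) for name, sub in tree.items()}
           return fake_quantize_int8(tree)  # leaf: a flat list of numbers

       quantized = {}
       for collection, tree in variables.items():
           if collection == "params":
               quantized[collection] = walk(tree)  # weights and biases
           else:
               # e.g. 'batch_stats': BatchNorm running mean/var are kept as-is
               quantized[collection] = tree
       return quantized

   # Toy variables dict with the same layout as a Flax model's variables
   variables = {
       "params": {"Dense_0": {"kernel": [0.5, -1.27, 0.031], "bias": [0.1, -0.2]}},
       "batch_stats": {"BatchNorm_0": {"mean": [0.123456], "var": [1.000001]}},
   }
   quantized = quantize_variables(variables)
   print(quantized["params"]["Dense_0"]["kernel"])

Each quantized value differs from the original by at most half a quantization step (``max|w| / 127 / 2``), while the BatchNorm statistics are returned untouched.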
.. _ptq_api:
API Reference
=============
.. _post_quantization:
Basic PTQ: ``post_quantization``
----------------------------------------------------
**Weight-only quantization** (fastest, simplest).
.. code-block:: python
pychop.ptq.post_quantization(
model,
chop,
eval_mode=True,
verbose=False
)
**Parameters:**
.. list-table::
:widths: 20 80
:header-rows: 1
* - Parameter
- Description
* - ``model``
- **PyTorch**: ``torch.nn.Module`` | **JAX**: Flax variables dict | **TensorFlow**: ``tf.keras.Model``
* - ``chop``
- Quantizer instance (``Chop``, ``Chopf``, or ``Chopi``)
* - ``eval_mode``
- Set model to eval mode (PyTorch only)
* - ``verbose``
- Print quantization details
**Returns:**
- **PyTorch**: Quantized ``nn.Module``
- **JAX**: Quantized params dict
- **TensorFlow**: Quantized ``tf.keras.Model``
**Example:**
.. code-block:: python

   import pychop
   from pychop import Chopi
   from pychop.ptq import post_quantization

   # PyTorch: `model` is a pre-trained torch.nn.Module
   pychop.backend('torch')
   chop = Chopi(bits=8, symmetric=True)
   model_q = post_quantization(model, chop, verbose=True)

   # JAX: `variables` is the Flax variables dict returned by model.init(...)
   pychop.backend('jax')
   from pychop.jx.layers import ChopiSTE
   chop = ChopiSTE(bits=8, symmetric=True)
   quantized_params = post_quantization(variables, chop, verbose=True)

   # TensorFlow: `tf_model` is a pre-trained tf.keras.Model
   pychop.backend('tensorflow')
   chop = Chopi(bits=8, symmetric=True)
   model_q = post_quantization(tf_model, chop, verbose=True)
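
``Chopi(bits=8, symmetric=True)`` presumably selects symmetric integer quantization. The plain-Python sketch below shows the usual textbook difference between the symmetric scheme (zero-point fixed at 0) and the asymmetric scheme (a zero-point shifts the integer range) for a single tensor. The function names are hypothetical and this is an illustration of the general idea, not ``Chopi``'s internals.

.. code-block:: python

   from typing import List, Tuple

   def quant_params(values: List[float], bits: int,
                    symmetric: bool) -> Tuple[float, int, int, int]:
       """Return (scale, zero_point, qmin, qmax) for per-tensor quantization."""
       lo, hi = min(values), max(values)
       if symmetric:
           qmax = 2 ** (bits - 1) - 1           # 127 for 8 bits
           qmin = -qmax                         # symmetric range, zero_point = 0
           max_abs = max(abs(lo), abs(hi)) or 1.0
           return max_abs / qmax, 0, qmin, qmax
       qmin, qmax = 0, 2 ** bits - 1            # [0, 255] for 8 bits
       lo, hi = min(lo, 0.0), max(hi, 0.0)      # the range must contain 0
       scale = (hi - lo) / (qmax - qmin) or 1.0
       zero_point = qmin - round(lo / scale)    # integer code representing 0.0
       return scale, zero_point, qmin, qmax

   def quantize_dequantize(values: List[float], bits: int = 8,
                           symmetric: bool = True) -> Tuple[List[float], float]:
       scale, zp, qmin, qmax = quant_params(values, bits, symmetric)
       codes = [min(qmax, max(qmin, round(v / scale) + zp)) for v in values]
       return [(c - zp) * scale for c in codes], scale

   x = [0.0, 0.2, 0.4, 3.0]  # non-negative data, e.g. after a ReLU
   sym, step_sym = quantize_dequantize(x, symmetric=True)
   asym, step_asym = quantize_dequantize(x, symmetric=False)
   print(f"symmetric step {step_sym:.5f}, asymmetric step {step_asym:.5f}")

For non-negative data the symmetric scheme never uses its negative codes, so its step here is about twice as large (``3/127`` versus ``3/255``); for weights, which are usually centred on zero, the two schemes behave much more alike.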
.. _static_post_quantization:
Static PTQ: ``static_post_quantization``
----------------------------------------------------
**Weights + Activations quantization with calibration** (best accuracy).
.. code-block:: python
pychop.ptq.static_post_quantization(
model,
chop,
calibration_data,
calibration_method='minmax',
percentile=99.99,
fuse_bn=True, # PyTorch only
eval_mode=True, # PyTorch only
verbose=False,
model_apply_fn=None # JAX only
)
**Parameters:**
.. list-table::
:widths: 20 80
:header-rows: 1
* - Parameter
- Description
* - ``model``
- Model to quantize
* - ``chop``
- Quantizer instance
* - ``calibration_data``
- Iterable of input batches (50-1000 batches recommended)
* - ``calibration_method``
- ``'minmax'`` | ``'percentile'`` | ``'kl_divergence'`` | ``'mse'``
* - ``percentile``
- Percentile for ``'percentile'`` method (e.g., 99.99)
* - ``fuse_bn``
- Fuse Conv+BN layers (PyTorch only, improves accuracy ~1-2%)
* - ``eval_mode``
- Set model to eval mode (PyTorch only)
* - ``verbose``
- Print quantization details
* - ``model_apply_fn``
- Model's apply function (JAX only, required for activation stats)
**Returns:**
- **PyTorch**: Quantized ``nn.Module`` with static activation hooks
- **JAX**: Dict with ``params``, ``batch_stats``, ``quant_config``
**Calibration Methods:**
.. list-table::
:widths: 20 40 40
:header-rows: 1
* - Method
- Description
- When to Use
* - ``minmax``
- Simple min/max clipping
- Quick tests, simple models
* - ``percentile``
- Percentile-based clipping (e.g., 99.99%)
- Data with outliers
* - ``kl_divergence``
- TensorRT-style KL-divergence optimization
- Production, best accuracy
* - ``mse``
- MSE-based threshold search
- Balance between accuracy and speed
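
The ``minmax``, ``percentile`` and ``mse`` strategies can each be reduced to choosing a clipping threshold ``t`` for symmetric quantization over ``[-t, t]``. The plain-Python sketch below illustrates these general techniques under simplifying assumptions (per-tensor thresholds, a nearest-rank percentile, a brute-force grid search for MSE). It is not PyChop's calibration code, and all function names are hypothetical.

.. code-block:: python

   import math
   import random
   from typing import List

   def fake_quant(v: float, t: float, bits: int = 8) -> float:
       """Symmetric quantize-dequantize of v with clipping threshold t."""
       if t <= 0:
           return 0.0
       qmax = 2 ** (bits - 1) - 1
       scale = t / qmax
       return max(-qmax, min(qmax, round(v / scale))) * scale

   def quant_mse(xs: List[float], t: float, bits: int = 8) -> float:
       return sum((x - fake_quant(x, t, bits)) ** 2 for x in xs) / len(xs)

   def minmax_threshold(xs: List[float]) -> float:
       # Cover the full observed range
       return max(abs(x) for x in xs)

   def percentile_threshold(xs: List[float], pct: float = 99.99) -> float:
       # Nearest-rank percentile of |x|: values above it will be clipped
       mags = sorted(abs(x) for x in xs)
       rank = max(1, math.ceil(pct / 100 * len(mags)))
       return mags[rank - 1]

   def mse_threshold(xs: List[float], bits: int = 8, steps: int = 100) -> float:
       # Grid search over t = t_max * k / steps; keep the lowest quantization MSE
       t_max = minmax_threshold(xs)
       candidates = [t_max * k / steps for k in range(1, steps + 1)]
       return min(candidates, key=lambda t: quant_mse(xs, t, bits))

   random.seed(0)
   acts = [random.gauss(0.0, 1.0) for _ in range(2000)] + [40.0]  # one outlier
   print("minmax    ", minmax_threshold(acts))
   print("percentile", percentile_threshold(acts, 99.9))
   print("mse       ", mse_threshold(acts))

With a single large outlier, ``minmax`` stretches the range to cover it, ``percentile`` ignores it, and the MSE search clips it only when doing so lowers the overall squared error.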
**Example (PyTorch):**
.. code-block:: python

   import torch
   import pychop
   from pychop import Chopi
   from pychop.ptq import static_post_quantization

   pychop.backend('torch')
   chop = Chopi(bits=8, symmetric=True)

   # `model` is a pre-trained torch.nn.Module.
   # Random tensors are placeholders: use real, preprocessed input batches.
   calibration_data = [
       torch.randn(4, 3, 224, 224) for _ in range(100)
   ]

   # Choose ONE of the options below.

   # Option 1: MinMax calibration (fastest)
   model_q = static_post_quantization(
       model, chop,
       calibration_data=calibration_data,
       calibration_method='minmax',
       fuse_bn=True,
       verbose=True
   )

   # Option 2: Percentile calibration (more robust to outliers)
   model_q = static_post_quantization(
       model, chop,
       calibration_data=calibration_data,
       calibration_method='percentile',
       percentile=99.9,  # Clip values beyond the 99.9th percentile
       verbose=True
   )

   # Option 3: KL-divergence calibration (usually the most accurate, but slower)
   model_q = static_post_quantization(
       model, chop,
       calibration_data=calibration_data,
       calibration_method='kl_divergence',
       fuse_bn=True,
       verbose=True
   )

   # Run inference with the quantized model
   x = torch.randn(1, 3, 224, 224)
   output = model_q(x)
**Example (JAX):**
.. code-block:: python

   import jax
   import pychop
   from pychop.jx.layers import ChopiSTE
   from pychop.ptq import static_post_quantization

   pychop.backend('jax')
   chop = ChopiSTE(bits=8, symmetric=True)

   # `model` is a Flax module and `variables = model.init(...)`.
   # Calibration batches use the NHWC layout expected by Flax convolutions.
   calibration_data = [
       jax.random.normal(jax.random.PRNGKey(i), (4, 224, 224, 3))
       for i in range(100)
   ]

   # Apply function used to collect activation statistics
   def apply_fn(params, x):
       return model.apply(params, x, train=False)

   # Static PTQ with KL-divergence calibration
   result = static_post_quantization(
       variables, chop,
       calibration_data=calibration_data,
       calibration_method='kl_divergence',
       model_apply_fn=apply_fn,
       verbose=True
   )

   # Run inference with the quantized variables
   x = jax.random.normal(jax.random.PRNGKey(1000), (1, 224, 224, 3))
   output = model.apply(result, x, train=False)
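
The ``kl_divergence`` method follows the entropy-calibration idea popularized by TensorRT: pick the threshold whose clipped histogram P is best approximated, in the KL-divergence sense, by a coarsened version Q that has only one value per quantization level. The sketch below is a simplified plain-Python version written for illustration, not PyChop's implementation; the 512-bin histogram, the 128 levels (the positive half of a symmetric 8-bit range) and the ``eps`` floor are arbitrary choices, and production calibrators typically use finer histograms.

.. code-block:: python

   import math
   import random
   from typing import List

   def normalize(h: List[float]) -> List[float]:
       total = sum(h)
       return [v / total for v in h] if total > 0 else list(h)

   def kl_divergence(p: List[float], q: List[float], eps: float = 1e-8) -> float:
       """KL(p || q) over bins where p > 0; q is floored at eps to stay finite."""
       return sum(pi * math.log(pi / max(qi, eps)) for pi, qi in zip(p, q) if pi > 0)

   def kl_threshold(xs: List[float], num_bins: int = 512, levels: int = 128) -> float:
       """Choose a clipping threshold for |x| that minimizes KL(P || Q)."""
       max_abs = max(abs(x) for x in xs)
       width = max_abs / num_bins
       hist = [0.0] * num_bins
       for x in xs:
           hist[min(int(abs(x) / width), num_bins - 1)] += 1

       best_i, best_kl = num_bins, float("inf")
       for i in range(levels, num_bins + 1):
           # Reference P: first i bins, with the clipped tail folded into the last one
           p = hist[:i]
           p[-1] += sum(hist[i:])
           # Candidate Q: merge the first i bins into `levels` groups, then spread
           # each group's count evenly over the bins that are non-empty in P
           q = [0.0] * i
           for j in range(levels):
               start, stop = j * i // levels, (j + 1) * i // levels
               nonzero = [k for k in range(start, stop) if p[k] > 0]
               mass = sum(hist[start:stop])
               for k in nonzero:
                   q[k] = mass / len(nonzero)
           kl = kl_divergence(normalize(p), normalize(q))
           if kl < best_kl:
               best_i, best_kl = i, kl
       return best_i * width  # upper edge of the last kept bin

   random.seed(0)
   acts = [random.gauss(0.0, 1.0) for _ in range(20000)] + [50.0, -45.0]
   print(f"KL threshold: {kl_threshold(acts):.3f} (max |x| = 50.0)")

Because every candidate is scored on the whole distribution, the method can clip rare outliers without being told a percentile, at the cost of many more operations than ``minmax`` or ``percentile``.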
.. _dynamic_post_quantization:
Dynamic PTQ: ``dynamic_post_quantization``
-------------------------------------------
**Weights + Activations quantization without calibration**: activation ranges are computed on the fly for every batch, so no calibration data is needed.
.. code-block:: python
pychop.ptq.dynamic_post_quantization(
model,
chop,
eval_mode=True,
verbose=False
)
**Parameters:**
.. list-table::
:widths: 20 80
:header-rows: 1
* - Parameter
- Description
* - ``model``
- Model to quantize
* - ``chop``
- Quantizer instance
* - ``eval_mode``
- Set model to eval mode (PyTorch only)
* - ``verbose``
- Print quantization details
**Returns:**
- **PyTorch**: Quantized ``nn.Module`` with dynamic activation hooks
- **JAX**: Dict with ``params``, ``batch_stats``, ``quant_config``
**Key Differences from Static PTQ:**
.. list-table::
   :widths: 30 35 35
   :header-rows: 1

   * - Aspect
     - Static PTQ
     - Dynamic PTQ
   * - **Calibration**
     - Required (50-1000 batches)
     - Not needed
   * - **Inference Speed**
     - ★★★★★ Fast
     - ★★★★☆ Typically ~5-10% slower (ranges are computed per batch)
   * - **Accuracy**
     - ★★★★★ Best
     - ★★★★☆ Typically 0.5-1% lower
   * - **Use Case**
     - Vision models (fixed input)
     - NLP models (variable input)
**Example:**
.. code-block:: python

   import torch
   import pychop
   from pychop import Chopi
   from pychop.ptq import dynamic_post_quantization

   pychop.backend('torch')
   chop = Chopi(bits=8, symmetric=True)

   # Dynamic PTQ on a pre-trained torch.nn.Module `model` (no calibration needed)
   model_q = dynamic_post_quantization(model, chop, verbose=True)

   # Activations are quantized dynamically per batch
   x = torch.randn(1, 3, 224, 224)  # placeholder input; use your model's input shape
   output = model_q(x)
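
The practical difference between static and dynamic PTQ is *when* the activation scale is computed. The plain-Python sketch below assumes symmetric 8-bit quantization with a max-abs scale, and its class names are hypothetical: the static quantizer freezes its scale from calibration batches, while the dynamic one recomputes the scale from every incoming batch, which is where the extra inference cost of dynamic PTQ comes from.

.. code-block:: python

   from typing import List

   QMAX = 127  # symmetric 8-bit range [-127, 127]

   def quantize(batch: List[float], scale: float) -> List[float]:
       return [max(-QMAX, min(QMAX, round(x / scale))) * scale for x in batch]

   class StaticActQuant:
       """Scale fixed once from calibration data (max of |x|)."""
       def __init__(self, calibration_batches: List[List[float]]):
           max_abs = max(abs(x) for b in calibration_batches for x in b)
           self.scale = max_abs / QMAX

       def __call__(self, batch: List[float]) -> List[float]:
           return quantize(batch, self.scale)

   class DynamicActQuant:
       """Scale recomputed from every batch; no calibration needed."""
       def __call__(self, batch: List[float]) -> List[float]:
           max_abs = max(abs(x) for x in batch) or 1.0
           return quantize(batch, max_abs / QMAX)

   static_q = StaticActQuant([[0.1, -0.5, 1.0], [0.3, 0.9]])
   dynamic_q = DynamicActQuant()
   big_batch = [4.0, -2.0, 0.5]   # wider range than seen during calibration
   print(static_q(big_batch))     # values beyond the calibrated range are clipped
   print(dynamic_q(big_batch))    # the range adapts to this batch

This is why dynamic PTQ suits inputs whose range varies from batch to batch, while static PTQ avoids the per-batch range computation when calibration data represents the deployment inputs well.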
.. _mixed_post_quantization:
Mixed-Precision PTQ: ``mixed_post_quantization``
-------------------------------------------------
**Separate quantizers for weights and activations** (W8A16, W4A8, etc.).
.. code-block:: python
pychop.ptq.mixed_post_quantization(
model,
weight_chop,
activation_chop,
calibration_data=None,
calibration_method='minmax',
percentile=99.99,
dynamic=True, # PyTorch only
eval_mode=True, # PyTorch only
verbose=False,
model_apply_fn=None # JAX only
)
**Parameters:**
.. list-table::
:widths: 20 80
:header-rows: 1
* - Parameter
- Description
* - ``model``
- Model to quantize
* - ``weight_chop``
- Quantizer for weights (e.g., ``Chopi(bits=8)``)
* - ``activation_chop``
- Quantizer for activations (e.g., ``Chop(exp_bits=5, sig_bits=10)`` for FP16)
* - ``calibration_data``
- Calibration data (optional for dynamic mode)
* - ``calibration_method``
- Calibration algorithm (if ``calibration_data`` provided)
* - ``percentile``
- Percentile for ``'percentile'`` calibration
* - ``dynamic``
- Use dynamic activation quantization (PyTorch only)
* - ``eval_mode``
- Set model to eval mode (PyTorch only)
* - ``verbose``
- Print quantization details
* - ``model_apply_fn``
- Model's apply function (JAX only)
**Popular Mixed-Precision Configurations:**
.. list-table::
   :widths: 15 30 25 30
   :header-rows: 1

   * - Config
     - Weight Quantizer
     - Activation Quantizer
     - Use Case
   * - **W8A16**
     - ``Chopi(bits=8)``
     - ``Chop(exp_bits=5, sig_bits=10)``
     - LLM quantization (minimal accuracy loss)
   * - **W4A8**
     - ``Chopi(bits=4)``
     - ``Chopi(bits=8)``
     - Extreme compression (weights 75% smaller than FP16)
   * - **W2A8**
     - ``Chopi(bits=2)``
     - ``Chopi(bits=8)``
     - Experimental (weights 87.5% smaller than FP16)
   * - **W8A8**
     - ``Chopi(bits=8)``
     - ``Chopi(bits=8)``
     - Standard INT8 quantization
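
The size-reduction percentages follow from simple bit arithmetic on weight storage, and they depend on the baseline: 75% (4-bit) and 87.5% (2-bit) match a 16-bit baseline, while against FP32 the same bit widths save 87.5% and 93.75%. The helpers below are a back-of-the-envelope sketch with hypothetical names; they ignore scales, zero-points and any layers left unquantized, and they assume the low-bit weights are actually stored packed (quantized values kept in floating-point tensors do not shrink a checkpoint).

.. code-block:: python

   def weight_size_mib(num_params: int, bits: int) -> float:
       """Storage for num_params weights at `bits` bits each, in MiB."""
       return num_params * bits / 8 / 2**20

   def reduction_pct(bits: int, baseline_bits: int) -> float:
       """Percentage saved by storing weights at `bits` instead of `baseline_bits`."""
       return 100.0 * (1 - bits / baseline_bits)

   for bits in (8, 4, 2):
       print(f"W{bits}: {reduction_pct(bits, 16):.2f}% vs FP16, "
             f"{reduction_pct(bits, 32):.2f}% vs FP32")

   # A hypothetical 110M-parameter model
   n = 110_000_000
   print(f"{weight_size_mib(n, 32):.1f} MiB (FP32) -> {weight_size_mib(n, 8):.1f} MiB (8-bit)")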
**Example (W8A16 - LLM Quantization):**
.. code-block:: python

   import torch
   import pychop
   from pychop import Chopi, Chop
   from pychop.ptq import mixed_post_quantization

   pychop.backend('torch')

   # W8A16 configuration
   weight_chop = Chopi(bits=8, symmetric=True)      # 8-bit weights
   activation_chop = Chop(exp_bits=5, sig_bits=10)  # FP16 activations

   # Option 1: Dynamic (no calibration); `model` is a pre-trained torch.nn.Module
   model_q = mixed_post_quantization(
       model, weight_chop, activation_chop,
       dynamic=True,
       verbose=True
   )

   # Option 2: Static (with calibration)
   # Image-shaped tensors are placeholders; for an LLM, pass tokenized text batches
   calibration_data = [torch.randn(4, 3, 224, 224) for _ in range(50)]
   model_q = mixed_post_quantization(
       model, weight_chop, activation_chop,
       calibration_data=calibration_data,
       calibration_method='percentile',
       dynamic=False,
       verbose=True
   )
**Example (W4A8 - Extreme Compression):**
.. code-block:: python

   # W4A8 configuration (imports and `model` as in the previous example)
   weight_chop = Chopi(bits=4, symmetric=True)      # 4-bit weights (75% smaller than FP16)
   activation_chop = Chopi(bits=8, symmetric=True)  # 8-bit activations

   model_q = mixed_post_quantization(
       model, weight_chop, activation_chop,
       dynamic=True,
       verbose=True
   )
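
In the W8A16 configuration, ``Chop(exp_bits=5, sig_bits=10)`` describes the IEEE half-precision layout (5 exponent bits, 10 stored significand bits). As a rough, standard-library-only illustration of what rounding activations to that format does, Python's ``struct`` module can round-trip a value through binary16 (format code ``'e'``). This shows the format's behaviour only; it is not how ``Chop`` is implemented, and ``to_fp16`` is a hypothetical helper.

.. code-block:: python

   import struct

   def to_fp16(x: float) -> float:
       """Round a Python float to the nearest IEEE binary16 value."""
       return struct.unpack("<e", struct.pack("<e", x))[0]

   for x in (1.0, 0.1, 1 / 3, 2049.0, 65504.0):
       y = to_fp16(x)
       print(f"{x!r:>22} -> {y!r:<22} rel. error {abs(y - x) / abs(x):.2e}")

   # Values above the largest finite half-precision number (65504) cannot be packed
   try:
       to_fp16(70000.0)
   except OverflowError as exc:
       print("overflow:", exc)

The relative rounding error stays below about ``2**-11`` for normal values, which is why FP16 activations usually cost little accuracy; the narrow exponent range (maximum 65504) is the more common source of problems.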
.. _ptq_examples:
Complete Examples
=================
Example 1: ResNet-18 INT8 PTQ (PyTorch)
----------------------------------------
.. code-block:: python

   import torch
   import torchvision.models as models
   import pychop
   from pychop import Chopi
   from pychop.ptq import static_post_quantization

   # Assumes `train_loader` and `test_loader` are ImageNet DataLoaders that use
   # the standard evaluation preprocessing (see Best Practices).

   # Load pre-trained ResNet-18
   pychop.backend('torch')
   model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
   model.eval()

   # Top-1 accuracy in percent
   def evaluate(model, test_loader):
       correct = 0
       total = 0
       with torch.no_grad():
           for images, labels in test_loader:
               outputs = model(images)
               _, predicted = torch.max(outputs, 1)
               total += labels.size(0)
               correct += (predicted == labels).sum().item()
       return 100 * correct / total

   # Measure the FP32 baseline first, in case quantization modifies `model` in place
   fp32_acc = evaluate(model, test_loader)

   # Prepare calibration data (100 batches from the ImageNet training set)
   calibration_data = []
   for images, _ in train_loader:
       calibration_data.append(images)
       if len(calibration_data) >= 100:
           break

   # INT8 quantization with percentile calibration
   chop = Chopi(bits=8, symmetric=True)
   model_q = static_post_quantization(
       model, chop,
       calibration_data=calibration_data,
       calibration_method='percentile',
       percentile=99.9,
       fuse_bn=True,
       verbose=True
   )

   int8_acc = evaluate(model_q, test_loader)
   print(f"FP32 Accuracy: {fp32_acc:.2f}%")
   print(f"INT8 Accuracy: {int8_acc:.2f}%")
   print(f"Accuracy Drop: {fp32_acc - int8_acc:.2f}%")

   # Example output (exact numbers depend on the calibration batches):
   # FP32 Accuracy: 69.76%
   # INT8 Accuracy: 69.34%
   # Accuracy Drop: 0.42%
Example 2: BERT-Base W8A16 PTQ (PyTorch)
-----------------------------------------
.. code-block:: python

   import os
   import torch
   from transformers import BertForSequenceClassification
   import pychop
   from pychop import Chopi, Chop
   from pychop.ptq import mixed_post_quantization

   # Load pre-trained BERT
   pychop.backend('torch')
   model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
   model.eval()

   # Save the FP32 checkpoint first, in case PTQ modifies `model` in place
   torch.save(model.state_dict(), 'bert_fp32.pth')

   # Calibration data: 50 batches of 4 sequences x 128 token IDs
   # (random IDs from the 30,522-token vocabulary are placeholders; use real tokenized text)
   calibration_data = [
       torch.randint(0, 30522, (4, 128)) for _ in range(50)
   ]

   # W8A16 configuration
   weight_chop = Chopi(bits=8, symmetric=True)      # 8-bit weights
   activation_chop = Chop(exp_bits=5, sig_bits=10)  # FP16 activations

   # Mixed-precision PTQ
   model_q = mixed_post_quantization(
       model, weight_chop, activation_chop,
       calibration_data=calibration_data,
       calibration_method='percentile',
       percentile=99.99,
       dynamic=False,
       verbose=True
   )

   # Model size comparison
   torch.save(model_q.state_dict(), 'bert_w8a16.pth')
   fp32_size = os.path.getsize('bert_fp32.pth') / (1024**2)
   w8a16_size = os.path.getsize('bert_w8a16.pth') / (1024**2)
   print(f"FP32 Model Size: {fp32_size:.2f} MB")
   print(f"W8A16 Model Size: {w8a16_size:.2f} MB")
   print(f"Size Reduction: {(1 - w8a16_size/fp32_size)*100:.1f}%")

   # Example output (the on-disk size depends on the dtype in which the
   # quantized tensors are stored):
   # FP32 Model Size: 438.00 MB
   # W8A16 Model Size: 219.00 MB
   # Size Reduction: 50.0%
Example 3: Vision Transformer (ViT) PTQ (JAX)
----------------------------------------------
.. code-block:: python

   import jax
   import jax.numpy as jnp
   from flax import linen as nn
   import pychop
   from pychop.jx.layers import ChopiSTE
   from pychop.ptq import static_post_quantization

   # Define a simple ViT model (no class token or position embeddings)
   class SimpleViT(nn.Module):
       num_classes: int = 1000

       @nn.compact
       def __call__(self, x, train=False):
           # Patch embedding: (B, 224, 224, 3) -> (B, 14, 14, 768)
           x = nn.Conv(features=768, kernel_size=(16, 16), strides=16)(x)
           # Flatten the patch grid into a token sequence: (B, 196, 768)
           x = x.reshape(x.shape[0], -1, x.shape[-1])

           # Transformer blocks (simplified, post-LayerNorm)
           for _ in range(12):
               # Self-attention over the token axis
               attn = nn.MultiHeadDotProductAttention(num_heads=12)(x, x)
               x = x + attn
               x = nn.LayerNorm()(x)
               # MLP
               mlp = nn.Dense(features=3072)(x)
               mlp = nn.gelu(mlp)
               mlp = nn.Dense(features=768)(mlp)
               x = x + mlp
               x = nn.LayerNorm()(x)

           # Classification head: mean-pool over tokens
           x = jnp.mean(x, axis=1)
           x = nn.Dense(features=self.num_classes)(x)
           return x

   # Initialize model
   pychop.backend('jax')
   model = SimpleViT()
   rng = jax.random.PRNGKey(0)
   variables = model.init(rng, jnp.ones((1, 224, 224, 3)), train=False)

   # Prepare calibration data (placeholders; use real, preprocessed images)
   calibration_data = [
       jax.random.normal(jax.random.PRNGKey(i), (4, 224, 224, 3))
       for i in range(100)
   ]

   # Apply function used to collect activation statistics
   def apply_fn(params, x):
       return model.apply(params, x, train=False)

   # INT8 quantization with KL-divergence calibration
   chop = ChopiSTE(bits=8, symmetric=True)
   result = static_post_quantization(
       variables, chop,
       calibration_data=calibration_data,
       calibration_method='kl_divergence',
       model_apply_fn=apply_fn,
       verbose=True
   )

   # Run inference with the quantized variables
   test_input = jax.random.normal(rng, (1, 224, 224, 3))
   output = model.apply(result, test_input, train=False)
   print(f"Output shape: {output.shape}")
   print(f"Quantization config: {result['quant_config']}")
Example 4: Comparing Calibration Methods
-----------------------------------------
.. code-block:: python

   import copy
   import torch
   import torchvision.models as models
   import pychop
   from pychop import Chopi
   from pychop.ptq import static_post_quantization

   # Reuses `evaluate` and `test_loader` from Example 1.

   # Load model
   pychop.backend('torch')
   model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
   model.eval()

   # Prepare calibration data (random tensors are placeholders; use real,
   # preprocessed images to get a meaningful comparison)
   calibration_data = [
       torch.randn(4, 3, 224, 224) for _ in range(100)
   ]

   chop = Chopi(bits=8, symmetric=True)

   # Test all calibration methods
   methods = ['minmax', 'percentile', 'kl_divergence', 'mse']
   results = {}

   for method in methods:
       print(f"\n{'=' * 60}")
       print(f"Testing {method.upper()} calibration")
       print('=' * 60)

       # Quantize a fresh copy so every method starts from the same FP32 weights
       model_q = static_post_quantization(
           copy.deepcopy(model), chop,
           calibration_data=calibration_data,
           calibration_method=method,
           percentile=99.9 if method == 'percentile' else 99.99,
           fuse_bn=True,
           verbose=True
       )

       # Evaluate
       acc = evaluate(model_q, test_loader)
       results[method] = acc
       print(f"{method} Accuracy: {acc:.2f}%")

   # Summary, sorted from highest to lowest accuracy
   print(f"\n{'=' * 60}")
   print("Summary: Calibration Method Comparison")
   print('=' * 60)
   for method, acc in sorted(results.items(), key=lambda x: x[1], reverse=True):
       print(f"{method:20s}: {acc:.2f}%")

   # Example output (accuracies vary with the calibration data):
   # ============================================================
   # Summary: Calibration Method Comparison
   # ============================================================
   # kl_divergence       : 71.85%
   # mse                 : 71.72%
   # percentile          : 71.58%
   # minmax              : 71.34%
.. _ptq_best_practices:
Best Practices
==============
1. Choosing PTQ Method
-----------------------
.. code-block:: text

   Decision Tree:

   Calibration data available?
   ├─ No  → Use Dynamic PTQ
   └─ Yes →
      └─ High accuracy required?
         ├─ Yes → Use Static PTQ (KL-divergence or MSE)
         └─ No  → Use Static PTQ (MinMax or Percentile)

   Need different precisions for weights and activations?
   └─ Use Mixed PTQ (W8A16 for LLMs, W4A8 for extreme compression)
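
For scripts that choose a method automatically, the decision tree can be encoded as a small helper. The function below is hypothetical (it is not part of PyChop) and only restates the guidance in the tree; checking the mixed-precision question first is a choice made here, not something the tree prescribes.

.. code-block:: python

   def choose_ptq_method(has_calibration_data: bool,
                         high_accuracy: bool = False,
                         mixed_precision: bool = False) -> str:
       """Return the PTQ entry point suggested by the decision tree."""
       if mixed_precision:
           return "mixed_post_quantization (W8A16 for LLMs, W4A8 for extreme compression)"
       if not has_calibration_data:
           return "dynamic_post_quantization"
       if high_accuracy:
           return "static_post_quantization (kl_divergence or mse)"
       return "static_post_quantization (minmax or percentile)"

   print(choose_ptq_method(has_calibration_data=False))
   print(choose_ptq_method(has_calibration_data=True, high_accuracy=True))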
2. Calibration Data Guidelines
-------------------------------
**Size:**
- Vision models: 50-200 batches (200-800 images)
- NLP models: 100-500 batches
- Small models (<10M params): 50 batches
- Large models (>100M params): 200-1000 batches
**Diversity:**
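.. code-block:: python

   # Good: diverse calibration data covering all classes/domains
   # (`sample_diverse_batches` stands for any class-balanced sampling helper)
   calibration_data = sample_diverse_batches(train_loader, n=100)

   # Bad: the same single-class batch repeated 100 times
   calibration_data = [cat_images for _ in range(100)]  # Only cats!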
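
``sample_diverse_batches`` in the snippet above is a placeholder, not a PyChop function. One way to implement the idea is class-balanced sampling: shuffle the examples of each class, take them round-robin so every class contributes equally, and group the result into batches. This sketch assumes a dataset of individual ``(input, label)`` pairs (for example a ``torch.utils.data.Dataset`` rather than a ``DataLoader`` that already yields batches) and returns batches as plain lists; with PyTorch you would still stack each batch into a tensor.

.. code-block:: python

   import random
   from collections import defaultdict
   from typing import Any, Hashable, Iterable, List, Tuple

   def sample_diverse_batches(dataset: Iterable[Tuple[Any, Hashable]],
                              n: int, batch_size: int = 4,
                              seed: int = 0) -> List[List[Any]]:
       """Return up to n batches whose examples are spread evenly across classes."""
       by_class = defaultdict(list)
       for x, label in dataset:
           by_class[label].append(x)
       rng = random.Random(seed)
       for examples in by_class.values():
           rng.shuffle(examples)

       # Round-robin over classes so each class contributes equally
       pools = list(by_class.values())
       picked: List[Any] = []
       needed = n * batch_size
       while len(picked) < needed and any(pools):
           for pool in pools:
               if pool and len(picked) < needed:
                   picked.append(pool.pop())
       rng.shuffle(picked)  # mix classes within and across batches
       return [picked[i:i + batch_size] for i in range(0, len(picked), batch_size)]

   data = [(f"img{i}", i % 10) for i in range(1000)]  # toy dataset, 10 classes
   batches = sample_diverse_batches(data, n=25, batch_size=4)
   print(len(batches), batches[0])

This version reads the whole dataset into memory once; for large datasets, sampling indices per class and loading only the selected examples is the cheaper variant of the same idea.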
**Preprocessing:**
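.. code-block:: python

   from torchvision import transforms

   # Use the same preprocessing the model sees at inference time
   # (standard ImageNet evaluation transforms, not training-time augmentation)
   transform = transforms.Compose([
       transforms.Resize(256),
       transforms.CenterCrop(224),
       transforms.ToTensor(),
       transforms.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
   ])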
3. Conv+BN Fusion (PyTorch)
----------------------------
**Keep it enabled (``fuse_bn=True`` is the default) for better accuracy:**

.. code-block:: python

   # Good
   model_q = static_post_quantization(
       model, chop,
       calibration_data=data,
       fuse_bn=True,  # ✅ Typically improves accuracy by ~1-2%
       verbose=True
   )

   # Bad
   model_q = static_post_quantization(
       model, chop,
       calibration_data=data,
       fuse_bn=False,  # ❌ Usually lower accuracy
       verbose=True
   )
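
Fusion folds the BatchNorm affine transform into the preceding layer, so the quantizer sees the weights that are actually used at inference time. For each output channel ``c``, with BN scale ``gamma``, shift ``beta``, running mean ``mu``, running variance ``var`` and epsilon ``eps``, the standard folding is ``w'_c = w_c * gamma_c / sqrt(var_c + eps)`` and ``b'_c = (b_c - mu_c) * gamma_c / sqrt(var_c + eps) + beta_c``. The plain-Python sketch below applies it to a fully connected layer (a convolution is folded the same way, per output channel); it illustrates the technique and is not PyChop's fusion code.

.. code-block:: python

   import math
   from typing import List, Tuple

   def fold_bn(weight: List[List[float]], bias: List[float],
               gamma: List[float], beta: List[float],
               mean: List[float], var: List[float],
               eps: float = 1e-5) -> Tuple[List[List[float]], List[float]]:
       """Fold BatchNorm into a linear layer (one weight row per output channel)."""
       new_w, new_b = [], []
       for c, row in enumerate(weight):
           factor = gamma[c] / math.sqrt(var[c] + eps)
           new_w.append([w * factor for w in row])
           new_b.append((bias[c] - mean[c]) * factor + beta[c])
       return new_w, new_b

   def linear(weight, bias, x):
       return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]

   def batchnorm(y, gamma, beta, mean, var, eps=1e-5):
       return [g * (v - m) / math.sqrt(s + eps) + b
               for v, g, b, m, s in zip(y, gamma, beta, mean, var)]

   W, b = [[0.5, -1.0], [2.0, 0.25]], [0.1, -0.3]
   gamma, beta, mean, var = [1.5, 0.8], [0.0, 0.2], [0.05, -0.1], [0.9, 4.0]
   x = [1.0, 2.0]

   unfused = batchnorm(linear(W, b, x), gamma, beta, mean, var)
   Wf, bf = fold_bn(W, b, gamma, beta, mean, var)
   fused = linear(Wf, bf, x)
   print(unfused, fused)  # equal up to floating-point rounding

Without fusion, the weights are quantized before the BN scaling is applied, so the per-channel factor ``gamma / sqrt(var + eps)`` can amplify quantization error; folding first lets calibration see the final weight ranges.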
4. Percentile Selection
------------------------
**Guidelines:**
- **99.99%**: Default, works for most cases
- **99.9%**: More aggressive clipping, better for outlier-heavy data
- **99.0%**: Very aggressive, use with caution
.. code-block:: python
# Data with many outliers (e.g., sensor data)
model_q = static_post_quantization(
model, chop,
calibration_data=data,
calibration_method='percentile',
percentile=99.9, # Clip 0.1% outliers
verbose=True
)
5. JAX Backend Tips
-------------------
**Always provide ``model_apply_fn`` for activation quantization:**
.. code-block:: python
# Good
def apply_fn(params, x):
return model.apply(params, x, train=False)
result = static_post_quantization(
variables, chop,
calibration_data=data,
model_apply_fn=apply_fn, # ✅ Required for activation stats
verbose=True
)
# Bad
result = static_post_quantization(
variables, chop,
calibration_data=data,
# ❌ No model_apply_fn = no activation quantization
verbose=True
)
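
Activation ranges can only be measured by running the model on calibration inputs, which is why the JAX path needs a callable. The plain-Python sketch below illustrates that dependency under a clearly stated assumption: ``toy_apply_fn`` is a hypothetical stand-in that returns the model output together with a dict of named intermediate activations (in Flax, ``Module.apply(..., capture_intermediates=True, mutable=['intermediates'])`` is one real way to obtain such intermediates). It is not PyChop's calibration loop.

.. code-block:: python

   from typing import Callable, Dict, Iterable, List, Tuple

   Activations = Dict[str, List[float]]

   def collect_activation_ranges(
       apply_fn: Callable[[dict, List[float]], Tuple[List[float], Activations]],
       params: dict,
       calibration_data: Iterable[List[float]],
   ) -> Dict[str, float]:
       """Track the running max |activation| per named layer over all batches."""
       ranges: Dict[str, float] = {}
       for batch in calibration_data:
           _, acts = apply_fn(params, batch)
           for name, values in acts.items():
               batch_max = max((abs(v) for v in values), default=0.0)
               ranges[name] = max(ranges.get(name, 0.0), batch_max)
       return ranges

   # Hypothetical toy model: y = relu(w * x), which records its hidden activation
   def toy_apply_fn(params, x):
       hidden = [max(0.0, params["w"] * xi) for xi in x]
       return hidden, {"relu_0": hidden}

   ranges = collect_activation_ranges(toy_apply_fn, {"w": 2.0}, [[0.5, -1.0], [1.5, 0.2]])
   print(ranges)  # {'relu_0': 3.0}

Without an apply function there is nothing to execute, so only the parameters themselves (weights) can be inspected and quantized.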
.. _ptq_troubleshooting:
Troubleshooting
===============
Common Issues
-------------
**Issue 1: Large Accuracy Drop (>5%)**
.. code-block:: python

   import itertools
   from pychop import Chop, Chopi
   from pychop.ptq import static_post_quantization, mixed_post_quantization

   # Solution 1: Use a stronger calibration method
   model_q = static_post_quantization(
       model, chop,
       calibration_data=data,
       calibration_method='kl_divergence',  # Try KL instead of minmax
       fuse_bn=True,
       verbose=True
   )

   # Solution 2: Use more calibration data
   # (DataLoaders cannot be sliced, so take the first 200 batches with islice)
   calibration_data = [images for images, _ in itertools.islice(train_loader, 200)]

   # Solution 3: Use mixed precision (keep activations in FP16)
   weight_chop = Chopi(bits=8)
   activation_chop = Chop(exp_bits=5, sig_bits=10)  # FP16 activations
   model_q = mixed_post_quantization(
       model, weight_chop, activation_chop,
       calibration_data=data,
       verbose=True
   )
**Issue 2: Slow Calibration**
.. code-block:: python
# Solution 1: Use faster calibration method
model_q = static_post_quantization(
model, chop,
calibration_data=data,
calibration_method='percentile', # Faster than KL/MSE
verbose=True
)
# Solution 2: Reduce calibration data
calibration_data = calibration_data[:50] # Use fewer batches
**Issue 3: Out of Memory (OOM)**
.. code-block:: python
# Solution 1: Reduce batch size in calibration data
calibration_data = [
torch.randn(2, 3, 224, 224) # Smaller batch size
for _ in range(100)
]
# Solution 2: Use dynamic PTQ (no calibration)
model_q = dynamic_post_quantization(model, chop, verbose=True)
**Issue 4: JAX "model_apply_fn required" Error**
.. code-block:: python
# Solution: Always provide apply function
def apply_fn(params, x):
return model.apply(params, x, train=False)
result = static_post_quantization(
variables, chop,
calibration_data=data,
model_apply_fn=apply_fn, # ✅ Add this
verbose=True
)
References
==========
- `TensorRT Documentation `_
- `PyTorch Quantization `_
- `ZeroQuant: Efficient INT8 Quantization `_
- `GPTQ: Accurate Quantization for GPT Models `_