Post-Training Quantization
Overview
PyChop provides Post-Training Quantization (PTQ) methods that quantize pre-trained models without retraining. PTQ suits quick deployment and model compression, typically at the cost of a small accuracy drop (see the examples below).
Key Features:
✅ 4 PTQ Methods: Basic, Static, Dynamic, Mixed-Precision
✅ 4 Calibration Algorithms: MinMax, Percentile, KL-Divergence, MSE
✅ Dual Backend Support: PyTorch and JAX/Flax
✅ Flexible Quantization: FP16, INT8, INT4, Custom Precision
✅ Easy API: Unified interface across backends
PTQ Methods Comparison
Legend:
✅ Quantized = Converted to specified precision (INT8, FP16, custom)
⚫ Original precision = Not quantized, keeps model’s current precision
⚫ Preserved (FP32) = Always kept as FP32 for numerical stability
Quantization Components
Understanding what gets quantized in each method:
┌─────────────────────────────────────────────────────────────────┐
│                        Model Components                         │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│   ┌──────────────┐  ┌──────────────┐  ┌──────────────┐          │
│   │   Weights    │  │ Activations  │  │ Batch Stats  │          │
│   │  (Conv/FC)   │  │    (ReLU)    │  │ (BN mean/var)│          │
│   └──────────────┘  └──────────────┘  └──────────────┘          │
│          ▲                 ▲                 ▲                  │
│          │                 │                 │                  │
│     ┌────┴─────┬───────────┴──────┬──────────┴────┐             │
│     │          │                  │               │             │
│   Basic     Static             Dynamic        Preserved         │
│    PTQ        PTQ                PTQ          (all PTQ)         │
│     │          │                  │                             │
│     ▼          ▼                  ▼                             │
│  W-only      W + A         W + A (dynamic)                      │
└─────────────────────────────────────────────────────────────────┘
Component Quantization Details:
Weights (Conv/Linear): Always quantized in all PTQ methods
Biases: Always quantized in all PTQ methods
Activations (ReLU/GELU): Quantized in Static/Dynamic/Mixed PTQ
BatchNorm Stats: Never quantized (preserved as FP32)
LayerNorm: Quantized in Static/Dynamic PTQ
API Reference
Basic PTQ: post_quantization
Weight-only quantization (fastest, simplest).
pychop.ptq.post_quantization(
model,
chop,
eval_mode=True,
verbose=False
)
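Parameters:
| Parameter | Description |
|---|---|
| model | Model to quantize (PyTorch) or parameter variables (JAX), as in the example below |
| chop | Quantizer instance |
| eval_mode | Set model to eval mode (PyTorch only) |
| verbose | Print quantization details |
Returns:
PyTorch: Quantized nn.Module
JAX: Quantized params dict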
Example:
import pychop
from pychop import Chopi
from pychop.ptq import post_quantization
# PyTorch
pychop.backend('torch')
chop = Chopi(bits=8, symmetric=True)
model_q = post_quantization(model, chop, verbose=True)
# JAX
pychop.backend('jax')
from pychop.jx.layers import ChopiSTE
chop = ChopiSTE(bits=8, symmetric=True)
quantized_params = post_quantization(variables, chop, verbose=True)
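To make the effect of weight-only quantization concrete, the following is a minimal pure-Python sketch of symmetric 8-bit quantize-dequantize on a handful of weights. It is an illustration under stated assumptions, not PyChop's implementation: the helper name `quantize_symmetric` is hypothetical, and the scale is taken from the largest absolute weight.

```python
def quantize_symmetric(values, bits=8):
    """Quantize-dequantize floats on a symmetric signed integer grid (illustrative helper)."""
    qmax = 2 ** (bits - 1) - 1                      # 127 for 8 bits
    max_abs = max(abs(v) for v in values)
    scale = max_abs / qmax if max_abs > 0 else 1.0  # real-valued step between integer levels
    # Round to the nearest integer level and clamp to [-qmax, qmax].
    codes = [max(-qmax, min(qmax, round(v / scale))) for v in values]
    dequantized = [c * scale for c in codes]
    return codes, scale, dequantized


weights = [0.42, -1.27, 0.03, 0.88, -0.51]
codes, scale, weights_q = quantize_symmetric(weights, bits=8)
max_error = max(abs(w - wq) for w, wq in zip(weights, weights_q))

print("integer codes:", codes)
print(f"scale={scale:.4f}, max round-trip error={max_error:.6f}")
```

The round-trip error of every weight is bounded by half a quantization step (`scale / 2`), which is why lower bit widths (larger steps) cost more accuracy.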
Static PTQ: static_post_quantization
Weights + Activations quantization with calibration (best accuracy).
pychop.ptq.static_post_quantization(
model,
chop,
calibration_data,
calibration_method='minmax',
percentile=99.99,
fuse_bn=True, # PyTorch only
eval_mode=True, # PyTorch only
verbose=False,
model_apply_fn=None # JAX only
)
Parameters:
| Parameter | Description |
|---|---|
| model | Model to quantize |
| chop | Quantizer instance |
| calibration_data | Iterable of input batches (50-1000 batches recommended) |
| calibration_method | One of 'minmax', 'percentile', 'kl_divergence', 'mse' (default 'minmax') |
| percentile | Percentile for 'percentile' calibration |
| fuse_bn | Fuse Conv+BN layers (PyTorch only, improves accuracy ~1-2%) |
| eval_mode | Set model to eval mode (PyTorch only) |
| verbose | Print quantization details |
| model_apply_fn | Model’s apply function (JAX only, required for activation stats) |
Returns:
PyTorch: Quantized nn.Module with static activation hooks
JAX: Dict with params, batch_stats, quant_config
Calibration Methods:
| Method | Description | When to Use |
|---|---|---|
| minmax | Simple min/max clipping | Quick tests, simple models |
| percentile | Percentile-based clipping (e.g., 99.99%) | Data with outliers |
| kl_divergence | TensorRT-style KL-divergence optimization | Production, best accuracy |
| mse | MSE-based threshold search | Balance between accuracy and speed |
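As a rough illustration of how three of these calibration methods differ, the sketch below computes a clipping threshold for the same synthetic activations using min/max, a nearest-rank percentile, and a brute-force MSE search. All function names are hypothetical and the logic is a simplified stand-in, not PyChop's code; the KL-divergence method is omitted because it needs a longer histogram-based procedure.

```python
import math
import random


def minmax_threshold(samples):
    """Clip at the largest absolute value seen during calibration."""
    return max(abs(x) for x in samples)


def percentile_threshold(samples, percentile=99.99):
    """Clip at the given percentile of absolute values (nearest-rank definition)."""
    ordered = sorted(abs(x) for x in samples)
    rank = max(1, math.ceil(percentile / 100 * len(ordered)))
    return ordered[rank - 1]


def quantization_mse(samples, threshold, bits=8):
    """Mean squared error after symmetric quantize-dequantize with the given threshold."""
    qmax = 2 ** (bits - 1) - 1
    scale = threshold / qmax
    total = 0.0
    for x in samples:
        code = max(-qmax, min(qmax, round(x / scale)))
        total += (x - code * scale) ** 2
    return total / len(samples)


def mse_threshold(samples, bits=8, num_candidates=100):
    """Sweep candidate thresholds up to the min/max value and keep the lowest-MSE one."""
    max_abs = minmax_threshold(samples)
    candidates = [max_abs * (i + 1) / num_candidates for i in range(num_candidates)]
    return min(candidates, key=lambda t: quantization_mse(samples, t, bits))


random.seed(0)
# Synthetic activations: standard-normal values plus a single large outlier.
activations = [random.gauss(0.0, 1.0) for _ in range(5000)] + [40.0]

t_minmax = minmax_threshold(activations)
t_pct = percentile_threshold(activations, 99.9)
t_mse = mse_threshold(activations)

for name, t in [("minmax", t_minmax), ("percentile", t_pct), ("mse", t_mse)]:
    print(f"{name:10s} threshold={t:8.4f}  mse={quantization_mse(activations, t):.6f}")
```

With a single outlier, min/max stretches the range to cover it and wastes most integer levels, while the percentile and MSE thresholds stay near the bulk of the distribution. The MSE search evaluates every candidate on the full sample set, which is why search-based methods take longer than min/max or percentile.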
Example (PyTorch):
import torch
import pychop
from pychop import Chopi
from pychop.ptq import static_post_quantization
pychop.backend('torch')
chop = Chopi(bits=8, symmetric=True)
# Prepare calibration data
calibration_data = [
torch.randn(4, 3, 224, 224) for _ in range(100)
]
# Option 1: MinMax calibration (fastest)
model_q = static_post_quantization(
model, chop,
calibration_data=calibration_data,
calibration_method='minmax',
fuse_bn=True,
verbose=True
)
# Option 2: Percentile calibration (better)
model_q = static_post_quantization(
model, chop,
calibration_data=calibration_data,
calibration_method='percentile',
percentile=99.9, # Clip 0.1% outliers
verbose=True
)
# Option 3: KL-Divergence calibration (best)
model_q = static_post_quantization(
model, chop,
calibration_data=calibration_data,
calibration_method='kl_divergence',
fuse_bn=True,
verbose=True
)
# Use quantized model
output = model_q(input)
Example (JAX):
import pychop
import jax
import jax.numpy as jnp
from pychop.jx.layers import ChopiSTE
from pychop.ptq import static_post_quantization
pychop.backend('jax')
chop = ChopiSTE(bits=8, symmetric=True)
# Prepare calibration data
calibration_data = [
jax.random.normal(jax.random.PRNGKey(i), (4, 224, 224, 3))
for i in range(100)
]
# Define apply function
def apply_fn(params, x):
    return model.apply(params, x, train=False)
# Static PTQ with KL-divergence
result = static_post_quantization(
variables, chop,
calibration_data=calibration_data,
calibration_method='kl_divergence',
model_apply_fn=apply_fn,
verbose=True
)
# Use quantized model
output = model.apply(result, input, train=False)
Dynamic PTQ: dynamic_post_quantization
Weight and activation quantization with activation ranges computed per batch at inference time (no calibration needed).
pychop.ptq.dynamic_post_quantization(
model,
chop,
eval_mode=True,
verbose=False
)
Parameters:
| Parameter | Description |
|---|---|
| model | Model to quantize |
| chop | Quantizer instance |
| eval_mode | Set model to eval mode (PyTorch only) |
| verbose | Print quantization details |
Returns:
PyTorch: Quantized nn.Module with dynamic activation hooks
JAX: Dict with params, batch_stats, quant_config
Key Differences from Static PTQ:
| Aspect | Static PTQ | Dynamic PTQ |
|---|---|---|
| Calibration | Required (50-1000 batches) | Not needed |
| Inference Speed | ★★★★★ Fast | ★★★★☆ ~5-10% slower |
| Accuracy | ★★★★★ Best | ★★★★☆ 0.5-1% lower |
| Use Case | Vision models (fixed input) | NLP models (variable input) |
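The difference between the two modes comes down to where the activation scale comes from. The sketch below is a simplified pure-Python illustration with hypothetical helper names, not PyChop's hooks: the static path freezes one scale from calibration batches, while the dynamic path recomputes the scale from each incoming batch.

```python
def scale_for(values, bits=8):
    """Symmetric scale that maps the largest absolute value to the top integer level."""
    qmax = 2 ** (bits - 1) - 1
    return max(abs(v) for v in values) / qmax


def fake_quantize(values, scale, bits=8):
    """Quantize-dequantize with a given scale, clamping to the signed integer range."""
    qmax = 2 ** (bits - 1) - 1
    return [max(-qmax, min(qmax, round(v / scale))) * scale for v in values]


# Static: the scale is fixed once, from calibration batches.
calibration_batches = [[0.1, -0.8, 0.5], [0.9, -0.2, 0.3]]
static_scale = max(scale_for(batch) for batch in calibration_batches)

# An inference batch whose range exceeds anything seen during calibration.
inference_batch = [0.4, -2.5, 1.0]
static_out = fake_quantize(inference_batch, static_scale)

# Dynamic: the scale is recomputed from the batch being processed.
dynamic_scale = scale_for(inference_batch)
dynamic_out = fake_quantize(inference_batch, dynamic_scale)

print("static :", [round(v, 3) for v in static_out])
print("dynamic:", [round(v, 3) for v in dynamic_out])
```

The static path clips the out-of-range value to the calibrated maximum, whereas the dynamic path represents it at the price of an extra range computation per batch. This is one way to read the speed and use-case rows of the table above.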
Example:
import pychop
from pychop import Chopi
from pychop.ptq import dynamic_post_quantization
pychop.backend('torch')
chop = Chopi(bits=8, symmetric=True)
# Dynamic PTQ (no calibration needed)
model_q = dynamic_post_quantization(model, chop, verbose=True)
# Activations are quantized dynamically per batch
output = model_q(input)
Mixed-Precision PTQ: mixed_post_quantization
Separate quantizers for weights and activations (W8A16, W4A8, etc.).
pychop.ptq.mixed_post_quantization(
model,
weight_chop,
activation_chop,
calibration_data=None,
calibration_method='minmax',
percentile=99.99,
dynamic=True, # PyTorch only
eval_mode=True, # PyTorch only
verbose=False,
model_apply_fn=None # JAX only
)
Parameters:
| Parameter | Description |
|---|---|
| model | Model to quantize |
| weight_chop | Quantizer for weights |
| activation_chop | Quantizer for activations |
| calibration_data | Calibration data (optional for dynamic mode) |
| calibration_method | Calibration algorithm |
| percentile | Percentile for 'percentile' calibration |
| dynamic | Use dynamic activation quantization (PyTorch only) |
| eval_mode | Set model to eval mode (PyTorch only) |
| verbose | Print quantization details |
| model_apply_fn | Model’s apply function (JAX only) |
Popular Mixed-Precision Configurations:
| Config | Weight Quantizer | Activation Quantizer | Use Case |
|---|---|---|---|
| W8A16 | 8-bit integer | FP16 (5 exponent bits, 10 significand bits) | LLM quantization (minimal accuracy loss) |
| W4A8 | 4-bit integer | 8-bit integer | Extreme compression (75% size reduction) |
| W2A8 | 2-bit integer | 8-bit integer | Experimental (87.5% size reduction) |
| W8A8 | 8-bit integer | 8-bit integer | Standard INT8 quantization |
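The size-reduction figures depend on the baseline precision. As a quick arithmetic check, the sketch below computes the ideal weight-storage reduction for a given bit width, ignoring scales, zero-points, and any tensors that stay unquantized; the function name is hypothetical. The 75% and 87.5% figures in the table match a 16-bit baseline, while against FP32 the ideal reductions would be larger.

```python
def size_reduction(weight_bits, baseline_bits):
    """Ideal percentage reduction in weight storage relative to a baseline bit width."""
    return (1 - weight_bits / baseline_bits) * 100


for name, bits in [("W8", 8), ("W4", 4), ("W2", 2)]:
    print(f"{name}: {size_reduction(bits, 16):.1f}% vs 16-bit, "
          f"{size_reduction(bits, 32):.2f}% vs 32-bit")

# Prints:
# W8: 50.0% vs 16-bit, 75.00% vs 32-bit
# W4: 75.0% vs 16-bit, 87.50% vs 32-bit
# W2: 87.5% vs 16-bit, 93.75% vs 32-bit
```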
Example (W8A16 - LLM Quantization):
import torch
import pychop
from pychop import Chopi, Chop
from pychop.ptq import mixed_post_quantization
pychop.backend('torch')
# W8A16 configuration
weight_chop = Chopi(bits=8, symmetric=True) # 8-bit weights
activation_chop = Chop(exp_bits=5, sig_bits=10) # FP16 activations
# Option 1: Dynamic (no calibration)
model_q = mixed_post_quantization(
model, weight_chop, activation_chop,
dynamic=True,
verbose=True
)
# Option 2: Static (with calibration)
calibration_data = [torch.randn(4, 3, 224, 224) for _ in range(50)]
model_q = mixed_post_quantization(
model, weight_chop, activation_chop,
calibration_data=calibration_data,
calibration_method='percentile',
dynamic=False,
verbose=True
)
Example (W4A8 - Extreme Compression):
# W4A8 configuration
weight_chop = Chopi(bits=4, symmetric=True) # 4-bit weights (75% size reduction)
activation_chop = Chopi(bits=8, symmetric=True) # 8-bit activations
model_q = mixed_post_quantization(
model, weight_chop, activation_chop,
dynamic=True,
verbose=True
)
Complete Examples
Example 1: ResNet-18 INT8 PTQ (PyTorch)
import torch
import torchvision.models as models
import pychop
from pychop import Chopi
from pychop.ptq import static_post_quantization
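# Load pre-trained ResNet-18
pychop.backend('torch')
# weights=... replaces the deprecated pretrained=True (torchvision >= 0.13)
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.eval()
# Prepare calibration data (100 batches from ImageNet)
# train_loader is assumed to be an existing ImageNet DataLoader yielding (images, labels)
calibration_data = []
for images, _ in train_loader:
    calibration_data.append(images)
    if len(calibration_data) >= 100:
        break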
# INT8 quantization with percentile calibration
chop = Chopi(bits=8, symmetric=True)
model_q = static_post_quantization(
model, chop,
calibration_data=calibration_data,
calibration_method='percentile',
percentile=99.9,
fuse_bn=True,
verbose=True
)
# Evaluate accuracy (test_loader is assumed to be an existing validation DataLoader)
def evaluate(model, test_loader):
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total
fp32_acc = evaluate(model, test_loader)
int8_acc = evaluate(model_q, test_loader)
print(f"FP32 Accuracy: {fp32_acc:.2f}%")
print(f"INT8 Accuracy: {int8_acc:.2f}%")
print(f"Accuracy Drop: {fp32_acc - int8_acc:.2f}%")
# Expected output:
# FP32 Accuracy: 69.76%
# INT8 Accuracy: 69.34%
# Accuracy Drop: 0.42%
Example 2: BERT-Base W8A16 PTQ (PyTorch)
import torch
from transformers import BertForSequenceClassification
import pychop
from pychop import Chopi, Chop
from pychop.ptq import mixed_post_quantization
# Load pre-trained BERT
pychop.backend('torch')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
model.eval()
# Prepare calibration data
calibration_data = [
torch.randint(0, 30522, (4, 128)) for _ in range(50) # 50 batches
]
# W8A16 configuration
weight_chop = Chopi(bits=8, symmetric=True) # 8-bit weights
activation_chop = Chop(exp_bits=5, sig_bits=10) # FP16 activations
# Mixed-precision PTQ
model_q = mixed_post_quantization(
model, weight_chop, activation_chop,
calibration_data=calibration_data,
calibration_method='percentile',
percentile=99.99,
dynamic=False,
verbose=True
)
# Model size comparison
import os
torch.save(model.state_dict(), 'bert_fp32.pth')
torch.save(model_q.state_dict(), 'bert_w8a16.pth')
fp32_size = os.path.getsize('bert_fp32.pth') / (1024**2)
w8a16_size = os.path.getsize('bert_w8a16.pth') / (1024**2)
print(f"FP32 Model Size: {fp32_size:.2f} MB")
print(f"W8A16 Model Size: {w8a16_size:.2f} MB")
print(f"Size Reduction: {(1 - w8a16_size/fp32_size)*100:.1f}%")
# Expected output:
# FP32 Model Size: 438.00 MB
# W8A16 Model Size: 219.00 MB
# Size Reduction: 50.0%
Example 3: Vision Transformer (ViT) PTQ (JAX)
import jax
import jax.numpy as jnp
from flax import linen as nn
import pychop
from pychop.jx.layers import ChopiSTE
from pychop.ptq import static_post_quantization
# Define a simple ViT model
class SimpleViT(nn.Module):
    num_classes: int = 1000

    @nn.compact
    def __call__(self, x, train=False):
        # Patch embedding
        x = nn.Conv(features=768, kernel_size=(16, 16), strides=16)(x)
        # Transformer blocks (simplified)
        for _ in range(12):
            # Self-attention
            attn = nn.MultiHeadDotProductAttention(num_heads=12)(x, x)
            x = x + attn
            x = nn.LayerNorm()(x)
            # MLP
            mlp = nn.Dense(features=3072)(x)
            mlp = nn.gelu(mlp)
            mlp = nn.Dense(features=768)(mlp)
            x = x + mlp
            x = nn.LayerNorm()(x)
        # Classification head
        x = jnp.mean(x, axis=(1, 2))
        x = nn.Dense(features=self.num_classes)(x)
        return x
# Initialize model
pychop.backend('jax')
model = SimpleViT()
rng = jax.random.PRNGKey(0)
variables = model.init(rng, jnp.ones((1, 224, 224, 3)), train=False)
# Prepare calibration data
calibration_data = [
jax.random.normal(jax.random.PRNGKey(i), (4, 224, 224, 3))
for i in range(100)
]
# Define apply function
def apply_fn(params, x):
    return model.apply(params, x, train=False)
# INT8 quantization with KL-divergence
chop = ChopiSTE(bits=8, symmetric=True)
result = static_post_quantization(
variables, chop,
calibration_data=calibration_data,
calibration_method='kl_divergence',
model_apply_fn=apply_fn,
verbose=True
)
# Use quantized model
test_input = jax.random.normal(rng, (1, 224, 224, 3))
output = model.apply(result, test_input, train=False)
print(f"Output shape: {output.shape}")
print(f"Quantization config: {result['quant_config']}")
Example 4: Comparing Calibration Methods
import torch
import torchvision.models as models
import pychop
from pychop import Chopi
from pychop.ptq import static_post_quantization
# Load model
pychop.backend('torch')
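# weights=... replaces the deprecated pretrained=True (torchvision >= 0.13)
model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)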
model.eval()
# Prepare calibration data
calibration_data = [
torch.randn(4, 3, 224, 224) for _ in range(100)
]
chop = Chopi(bits=8, symmetric=True)
# Test all calibration methods
methods = ['minmax', 'percentile', 'kl_divergence', 'mse']
results = {}
for method in methods:
    print(f"\n{'='*60}")
    print(f"Testing {method.upper()} calibration")
    print('='*60)
    model_q = static_post_quantization(
        model, chop,
        calibration_data=calibration_data,
        calibration_method=method,
        percentile=99.9 if method == 'percentile' else 99.99,
        fuse_bn=True,
        verbose=True
    )
    # Evaluate (evaluate() and test_loader as defined in Example 1)
    acc = evaluate(model_q, test_loader)
    results[method] = acc
    print(f"{method} Accuracy: {acc:.2f}%")
# Summary
print(f"\n{'='*60}")
print("Summary: Calibration Method Comparison")
print('='*60)
for method, acc in sorted(results.items(), key=lambda x: x[1], reverse=True):
    print(f"{method:20s}: {acc:.2f}%")
# Expected output:
# ============================================================
# Summary: Calibration Method Comparison
# ============================================================
# kl_divergence : 71.85%
# mse : 71.72%
# percentile : 71.58%
# minmax : 71.34%
Best Practices
1. Choosing PTQ Method
Decision Tree:
Calibration data available?
├─ No → Use Dynamic PTQ
└─ Yes →
   └─ High accuracy required?
      ├─ Yes → Use Static PTQ (KL-divergence or MSE)
      └─ No → Use Static PTQ (MinMax or Percentile)
Need mixed precision?
└─ Use Mixed PTQ (W8A16 for LLMs, W4A8 for extreme compression)
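The decision tree can also be written as a small helper for use in scripts. This is a sketch with a hypothetical function name, and the returned strings are labels for the reader rather than PyChop identifiers.

```python
def choose_ptq_method(has_calibration_data, high_accuracy=False, mixed_precision=False):
    """Map the decision tree above onto a recommended PTQ method (illustrative only)."""
    if mixed_precision:
        return "mixed (W8A16 for LLMs, W4A8 for extreme compression)"
    if not has_calibration_data:
        return "dynamic"
    if high_accuracy:
        return "static (kl_divergence or mse)"
    return "static (minmax or percentile)"


print(choose_ptq_method(has_calibration_data=False))
print(choose_ptq_method(has_calibration_data=True, high_accuracy=True))
print(choose_ptq_method(has_calibration_data=True, mixed_precision=True))
```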
2. Calibration Data Guidelines
Size:
Vision models: 50-200 batches (200-800 images at batch size 4)
NLP models: 100-500 batches
Small models (<10M params): 50 batches
Large models (>100M params): 200-1000 batches
Diversity:
# Good: Diverse calibration data
calibration_data = sample_diverse_batches(train_loader, n=100)
# Bad: Only one class
calibration_data = [cat_images for _ in range(100)] # Only cats!
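`sample_diverse_batches` in the snippet above is not defined in this page. One possible sketch, assuming the loader yields `(inputs, labels)` pairs, is to draw batches uniformly at random with reservoir sampling, so that a loader ordered by class does not contribute only its first few classes. The stand-in loader below uses strings instead of image tensors so the example runs without any dependencies.

```python
import random


def sample_diverse_batches(loader, n=100, seed=0):
    """Hypothetical helper: pick n input batches uniformly at random from an iterable loader."""
    rng = random.Random(seed)
    reservoir = []
    for index, (inputs, _labels) in enumerate(loader):
        if index < n:
            reservoir.append(inputs)          # fill the reservoir first
        else:
            j = rng.randint(0, index)         # inclusive bounds
            if j < n:
                reservoir[j] = inputs         # replace with probability n / (index + 1)
    return reservoir


# Stand-in for a DataLoader sorted by class: 10 classes, 20 batches per class.
fake_loader = [([f"img_{c}_{b}"], [c]) for c in range(10) for b in range(20)]

batches = sample_diverse_batches(fake_loader, n=50)
classes_seen = {batch[0].split("_")[1] for batch in batches}
print(f"{len(batches)} batches covering {len(classes_seen)} classes")
```

Taking the first 50 batches of this loader would cover only three classes. With a shuffled loader, simply taking the first n batches achieves a similar effect.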
Preprocessing:
from torchvision import transforms
# Apply the same preprocessing the model was trained and evaluated with
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
3. Conv+BN Fusion (PyTorch)
Always enable for better accuracy:
# Good
model_q = static_post_quantization(
model, chop,
calibration_data=data,
fuse_bn=True, # ✅ Improves accuracy ~1-2%
verbose=True
)
# Bad
model_q = static_post_quantization(
model, chop,
calibration_data=data,
fuse_bn=False, # ❌ Lower accuracy
verbose=True
)
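For readers unfamiliar with fusion, the sketch below shows the standard identity for folding inference-mode BatchNorm into the preceding layer, reduced to a single scalar channel for clarity. Whether `fuse_bn=True` performs exactly this computation inside PyChop is an assumption; the function names are illustrative.

```python
import math


def fold_bn(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Fold BatchNorm statistics into the preceding layer's weight and bias."""
    factor = gamma / math.sqrt(var + eps)
    return weight * factor, (bias - mean) * factor + beta


def conv_then_bn(x, weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Reference path: linear op followed by inference-mode BatchNorm."""
    y = weight * x + bias
    return gamma * (y - mean) / math.sqrt(var + eps) + beta


w, b = 0.8, 0.1                                  # layer parameters (one channel)
gamma, beta, mean, var = 1.5, -0.2, 0.3, 4.0     # BatchNorm parameters and running stats

w_fused, b_fused = fold_bn(w, b, gamma, beta, mean, var)

x = 2.0
reference = conv_then_bn(x, w, b, gamma, beta, mean, var)
fused = w_fused * x + b_fused
print(f"reference={reference:.6f}, fused={fused:.6f}")
```

After folding, only one set of weights and one output range need to be quantized, instead of quantizing the layer output and then applying a separate normalization step.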
4. Percentile Selection
Guidelines:
99.99%: Default, works for most cases
99.9%: More aggressive clipping, better for outlier-heavy data
99.0%: Very aggressive, use with caution
# Data with many outliers (e.g., sensor data)
model_q = static_post_quantization(
model, chop,
calibration_data=data,
calibration_method='percentile',
percentile=99.9, # Clip 0.1% outliers
verbose=True
)
5. JAX Backend Tips
Always provide model_apply_fn for activation quantization:
# Good
def apply_fn(params, x):
    return model.apply(params, x, train=False)
result = static_post_quantization(
variables, chop,
calibration_data=data,
model_apply_fn=apply_fn, # ✅ Required for activation stats
verbose=True
)
# Bad
result = static_post_quantization(
variables, chop,
calibration_data=data,
# ❌ No model_apply_fn = no activation quantization
verbose=True
)
Troubleshooting
Common Issues
Issue 1: Large Accuracy Drop (>5%)
# Solution 1: Use better calibration method
model_q = static_post_quantization(
model, chop,
calibration_data=data,
calibration_method='kl_divergence', # Try KL instead of minmax
fuse_bn=True,
verbose=True
)
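# Solution 2: Increase calibration data
# A DataLoader cannot be sliced, so take the first 200 batches with islice
from itertools import islice
calibration_data = [images for images, _ in islice(train_loader, 200)]  # More data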
# Solution 3: Use mixed-precision
weight_chop = Chopi(bits=8)
activation_chop = Chop(exp_bits=5, sig_bits=10) # FP16 activations
model_q = mixed_post_quantization(
model, weight_chop, activation_chop,
calibration_data=data,
verbose=True
)
Issue 2: Slow Calibration
# Solution 1: Use faster calibration method
model_q = static_post_quantization(
model, chop,
calibration_data=data,
calibration_method='percentile', # Faster than KL/MSE
verbose=True
)
# Solution 2: Reduce calibration data
calibration_data = calibration_data[:50] # Use fewer batches
Issue 3: Out of Memory (OOM)
# Solution 1: Reduce batch size in calibration data
calibration_data = [
torch.randn(2, 3, 224, 224) # Smaller batch size
for _ in range(100)
]
# Solution 2: Use dynamic PTQ (no calibration)
model_q = dynamic_post_quantization(model, chop, verbose=True)
Issue 4: JAX “model_apply_fn required” Error
# Solution: Always provide apply function
def apply_fn(params, x):
    return model.apply(params, x, train=False)
result = static_post_quantization(
variables, chop,
calibration_data=data,
model_apply_fn=apply_fn, # ✅ Add this
verbose=True
)