import os
def _get_backend():
    """Get current backend from environment variable.

    Reads ``chop_backend`` and normalises it: surrounding whitespace and
    letter case are ignored, so ``"JAX"`` or ``" torch "`` select the intended
    backend instead of falling through to the PyTorch default used for
    unrecognised names. An unset or empty value means ``"auto"``, and
    ``"tf"`` is accepted as shorthand for ``"tensorflow"``.

    Returns
    -------
    str
        The normalised backend name (``"auto"`` when nothing is configured).
    """
    backend = os.environ.get("chop_backend", "").strip().lower() or "auto"
    return "tensorflow" if backend == "tf" else backend


def _resolve_backend_for_model(model):
    """Pick the backend used to process *model*.

    An explicit (non-``"auto"``) backend from :func:`_get_backend` always
    wins. In ``"auto"`` mode the backend is inferred from the defining modules
    of ``type(model)`` and its base classes, walked in MRO order. User-defined
    models usually live in the user's own module (e.g. ``"__main__"``), so
    only the framework base class they inherit from reveals the framework.

    Parameters
    ----------
    model : object
        Model instance to inspect.

    Returns
    -------
    str
        ``"tensorflow"``, ``"torch"`` or ``"jax"`` when inferred, falling back
        to ``"torch"`` when no class in the MRO matches; otherwise the
        explicitly configured backend name.
    """
    backend = _get_backend()
    if backend != "auto":
        return backend
    # The most-derived class comes first, so a direct match on type(model)
    # gives the same answer as before; bases are consulted only afterwards.
    for cls in type(model).__mro__:
        module = (getattr(cls, "__module__", None) or "").lower()
        if "tensorflow" in module or "keras" in module:
            return "tensorflow"
        if "torch" in module:
            return "torch"
        if "flax" in module or "jax" in module:
            return "jax"
    return "torch"
def _is_missing_dependency(exc, packages):
    """Return True if the text of import error *exc* mentions any of *packages*.

    Matching is a case-insensitive substring test on ``str(exc)``; Python's
    standard message for an absent package reads ``"No module named 'flax'"``.
    """
    message = str(exc).lower()
    return any(package in message for package in packages)


def _import_backend_layers(backend: str):
    """Dynamically import backend-specific layer implementations.

    Parameters
    ----------
    backend : str
        ``"jax"``, ``"torch"`` or ``"tensorflow"``. ``"auto"`` (what
        ``_get_backend`` returns when nothing is configured) and any
        unrecognised name resolve to the PyTorch implementation.

    Returns
    -------
    module
        The ``layers`` module of the selected backend subpackage
        (``.jx``, ``.tch`` or ``.tf``).

    Raises
    ------
    ImportError
        If the backend's layers module cannot be imported (e.g., missing
        dependencies). When the failure mentions the backend's third-party
        packages, an error with installation hints is raised instead, with the
        original error chained as ``__cause__``.
    """
    if backend == "jax":
        try:
            from .jx import layers as backend_module
        except ImportError as e:
            # A missing jax/jaxlib breaks the backend just like a missing flax.
            if _is_missing_dependency(e, ("flax", "jax")):
                raise ImportError(
                    "JAX backend requires 'jax' and 'flax' to be installed. "
                    "Install them with: pip install flax jax jaxlib\n"
                    "Or switch to PyTorch backend: pychop.backend('torch')"
                ) from e
            raise
    elif backend == "tensorflow":
        try:
            from .tf import layers as backend_module
        except ImportError as e:
            if _is_missing_dependency(e, ("tensorflow",)):
                raise ImportError(
                    "TensorFlow backend requires 'tensorflow' to be installed. "
                    "Install it with: pip install tensorflow\n"
                    "Or switch to PyTorch backend: pychop.backend('torch')"
                ) from e
            raise
    elif backend in ("torch", "auto"):
        # "auto" arrives here from callers that pass _get_backend() unresolved
        # (e.g. the layer factories); PyTorch is their default, so report a
        # missing torch as such rather than as an "unknown backend".
        try:
            from .tch import layers as backend_module
        except ImportError as e:
            if _is_missing_dependency(e, ("torch",)):
                raise ImportError(
                    "PyTorch backend requires 'torch' to be installed. "
                    "Install it with: pip install torch\n"
                    "Or switch to JAX backend: pychop.backend('jax')"
                ) from e
            raise
    else:
        # Unrecognised names keep the historical fallback to PyTorch.
        try:
            from .tch import layers as backend_module
        except ImportError as e:
            raise ImportError(
                f"Unknown backend '{backend}' and PyTorch backend is not available. "
                "Valid backends: 'torch', 'jax', 'tensorflow'"
            ) from e
    return backend_module
# ==================================================================
# Factory Functions with Better Error Messages
# ==================================================================
def _create_layer_factory(layer_name: str):
    """Build a callable that instantiates *layer_name* from the active backend.

    The backend is looked up each time the returned callable is invoked, and
    all arguments are forwarded to the backend class's constructor. Import
    failures are re-raised with the layer name prepended; a backend that does
    not provide the class yields a descriptive ``AttributeError``.

    NOTE: this module defines ``_create_layer_factory`` a second time further
    down, and that later definition is the one in effect when the layer
    factories are created.
    """
    def factory(*args, **kwargs):
        active = _get_backend()
        try:
            layers_mod = _import_backend_layers(active)
        except ImportError as exc:
            raise ImportError(
                f"Cannot create {layer_name}: {exc}"
            ) from exc
        cls = getattr(layers_mod, layer_name, None)
        if cls is not None:
            return cls(*args, **kwargs)
        raise AttributeError(
            f"{layer_name} is not available in {active} backend. "
            f"Please check the documentation or try a different backend."
        )

    factory.__doc__ = f"Create a {layer_name} for the current backend."
    factory.__name__ = layer_name
    return factory
# ==================================================================
# STE Wrappers (frontend - backend agnostic in concept)
# ==================================================================
[docs]
def ChopSTE(*args, **kwargs):
    """Build a ``ChopSTE`` from the layers module of the active backend.

    All positional and keyword arguments are forwarded unchanged to the
    backend's ``ChopSTE``.

    Raises
    ------
    ImportError
        If the current backend's dependencies are not installed.
    """
    layers_module = _import_backend_layers(_get_backend())
    return layers_module.ChopSTE(*args, **kwargs)
[docs]
def ChopfSTE(*args, **kwargs):
    """Build a ``ChopfSTE`` from the layers module of the active backend.

    All positional and keyword arguments are forwarded unchanged to the
    backend's ``ChopfSTE``.

    Raises
    ------
    ImportError
        If the current backend's dependencies are not installed.
    """
    layers_module = _import_backend_layers(_get_backend())
    return layers_module.ChopfSTE(*args, **kwargs)
[docs]
def ChopiSTE(*args, **kwargs):
    """Build a ``ChopiSTE`` from the layers module of the active backend.

    All positional and keyword arguments are forwarded unchanged to the
    backend's ``ChopiSTE``.

    Raises
    ------
    ImportError
        If the current backend's dependencies are not installed.
    """
    layers_module = _import_backend_layers(_get_backend())
    return layers_module.ChopiSTE(*args, **kwargs)
# ==================================================================
# Post-Quantization (dispatch to backend-specific implementation)
# ==================================================================
[docs]
def post_quantization(model, chop, eval_mode: bool = True, verbose: bool = False):
    """Post-training quantization (PTQ) wrapper.

    The backend is chosen for *model* (the explicit ``chop_backend`` setting,
    or inferred from the model's class in ``"auto"`` mode), and the call is
    forwarded to that backend's ``post_quantization``.

    Parameters
    ----------
    model : torch.nn.Module, flax.linen.Module, or Keras model
        Neural network model.
    chop : Chop, Chopf, or Chopi
        Quantizer instance.
    eval_mode : bool, default=True
        Whether to set model to evaluation mode (PyTorch only).
    verbose : bool, default=False
        Whether to print quantization details.

    Returns
    -------
    model
        Quantized model, as returned by the backend implementation.

    Raises
    ------
    ImportError
        If the dependencies of the selected backend are not installed.
    """
    layers_module = _import_backend_layers(_resolve_backend_for_model(model))
    return layers_module.post_quantization(model, chop, eval_mode, verbose)
# ==================================================================
# Static Post-Quantization (dispatch to backend)
# ==================================================================
def static_post_quantization(model, chop, calibration_data,
                             eval_mode: bool = True, verbose: bool = False):
    """Static post-training quantization with activation calibration.

    Weights, biases and activations are all quantized with the same *chop*.
    *calibration_data* is used to collect per-layer activation min/max;
    activations are then clamped and quantized during inference. The work is
    done by ``static_post_quantization`` of the backend selected for *model*.

    Parameters
    ----------
    model : torch.nn.Module, flax.linen.Module, or Keras model
        Neural network model.
    chop : Chop, Chopf, or Chopi
        Quantizer instance (used for both weights and activations).
    calibration_data : iterable
        Input data for calibration (DataLoader, list of tensors, etc.).
    eval_mode : bool, default=True
        Set model to eval mode (PyTorch only).
    verbose : bool, default=False
        Print quantization details.

    Returns
    -------
    model or dict
        PyTorch: quantized ``nn.Module`` with static activation hooks.
        JAX: dict with ``params``, ``activation_stats``, ``quantized_apply``.

    Raises
    ------
    ImportError
        If the dependencies of the selected backend are not installed.
    """
    layers_module = _import_backend_layers(_resolve_backend_for_model(model))
    return layers_module.static_post_quantization(
        model, chop, calibration_data, eval_mode, verbose
    )
# ==================================================================
# Dynamic Post-Quantization (dispatch to backend)
# ==================================================================
def dynamic_post_quantization(model, chop,
                              eval_mode: bool = True, verbose: bool = False):
    """Dynamic post-training quantization — no calibration needed.

    Weights and biases are quantized once, offline, with *chop*; activations
    are quantized on the fly at every inference step with the same
    quantizer. The work is done by ``dynamic_post_quantization`` of the
    backend selected for *model*.

    Parameters
    ----------
    model : torch.nn.Module, flax.linen.Module, or Keras model
        Neural network model.
    chop : Chop, Chopf, or Chopi
        Quantizer instance (used for both weights and activations).
    eval_mode : bool, default=True
        Set model to eval mode (PyTorch only).
    verbose : bool, default=False
        Print quantization details.

    Returns
    -------
    model or dict
        PyTorch: quantized ``nn.Module`` with dynamic activation hooks.
        JAX: dict with ``params`` and ``dynamic_apply``.

    Raises
    ------
    ImportError
        If the dependencies of the selected backend are not installed.
    """
    layers_module = _import_backend_layers(_resolve_backend_for_model(model))
    return layers_module.dynamic_post_quantization(
        model, chop, eval_mode, verbose
    )
# ==================================================================
# Mixed-Precision Post-Quantization (dispatch to backend)
# ==================================================================
def mixed_post_quantization(model, weight_chop, activation_chop,
                            calibration_data=None, dynamic: bool = True,
                            eval_mode: bool = True, verbose: bool = False):
    """Mixed-precision post-training quantization (e.g. W8A8, W4A16).

    Uses **separate quantizers** for weights and activations, enabling
    fine-grained control over precision allocation such as W8A8, W4A8,
    W8A16, etc.
    Dispatches to backend-specific implementation.

    Parameters
    ----------
    model : torch.nn.Module, flax.linen.Module, or Keras model
        Neural network model.
    weight_chop : Chop, Chopf, Chopi, or None
        Quantizer for weights/biases. ``None`` = keep full-precision.
    activation_chop : Chop, Chopf, Chopi, or None
        Quantizer for activations. ``None`` = keep full-precision.
    calibration_data : iterable or None
        Required when ``dynamic=False``. Ignored when ``dynamic=True``.
    dynamic : bool, default=True
        ``True`` = dynamic activation quantization (no calibration).
        ``False`` = static calibration + clamp + quantize.
    eval_mode : bool, default=True
        Set model to eval mode (PyTorch only).
    verbose : bool, default=False
        Print quantization details.

    Returns
    -------
    model or dict
        PyTorch: quantized ``nn.Module``.
        JAX: dict with ``params`` and ``mixed_apply``.

    Raises
    ------
    ValueError
        If ``dynamic=False`` and ``calibration_data`` is None. This is checked
        here, before any backend is resolved or imported.
    ImportError
        If the dependencies of the selected backend are not installed.
    """
    # Enforce the documented contract up front so a bad call fails fast and
    # consistently, even when the backend's dependencies are missing.
    if not dynamic and calibration_data is None:
        raise ValueError(
            "calibration_data is required when dynamic=False "
            "(static activation calibration)."
        )
    backend = _resolve_backend_for_model(model)
    module = _import_backend_layers(backend)
    return module.mixed_post_quantization(
        model, weight_chop, activation_chop,
        calibration_data=calibration_data,
        dynamic=dynamic,
        eval_mode=eval_mode,
        verbose=verbose,
    )
# ==================================================================
# Layer Factory Functions
# ==================================================================
def _create_layer_factory(layer_name: str):
    """Create a factory function for a specific layer type.

    The returned callable resolves the active backend each time it is called
    and forwards all arguments to that backend's *layer_name* class.

    Parameters
    ----------
    layer_name : str
        Name of the class to look up in the backend's ``layers`` module.

    Returns
    -------
    callable
        Factory named *layer_name*.

    The factory raises:

    ImportError
        If the backend cannot be imported; the message is prefixed with the
        layer being created and the original error is chained.
    AttributeError
        If the active backend does not provide *layer_name*.
    """
    def factory(*args, **kwargs):
        backend = _get_backend()
        try:
            module = _import_backend_layers(backend)
        except ImportError as e:
            raise ImportError(f"Cannot create {layer_name}: {e}") from e
        layer_class = getattr(module, layer_name, None)
        if layer_class is None:
            raise AttributeError(
                f"{layer_name} is not available in {backend} backend. "
                "Please check the documentation or try a different backend."
            )
        return layer_class(*args, **kwargs)

    factory.__name__ = layer_name
    # Keep reprs/tracebacks showing the public layer name, not "...<locals>.factory".
    factory.__qualname__ = layer_name
    factory.__doc__ = f"Create a {layer_name} for the current backend."
    return factory
# ==================================================================
# Floating-Point Quantized Layers
# ==================================================================
# Each name below is a factory that looks up the active backend when it is
# called, not at import time.
# Linear and convolution layers
QuantizedLinear = _create_layer_factory("QuantizedLinear")
QuantizedConv1d = _create_layer_factory("QuantizedConv1d")
QuantizedConv2d = _create_layer_factory("QuantizedConv2d")
QuantizedConv3d = _create_layer_factory("QuantizedConv3d")
QuantizedConvTranspose1d = _create_layer_factory("QuantizedConvTranspose1d")
QuantizedConvTranspose2d = _create_layer_factory("QuantizedConvTranspose2d")
QuantizedConvTranspose3d = _create_layer_factory("QuantizedConvTranspose3d")
# Recurrent layers
QuantizedRNN = _create_layer_factory("QuantizedRNN")
QuantizedLSTM = _create_layer_factory("QuantizedLSTM")
QuantizedGRU = _create_layer_factory("QuantizedGRU")
# Pooling layers
QuantizedMaxPool1d = _create_layer_factory("QuantizedMaxPool1d")
QuantizedMaxPool2d = _create_layer_factory("QuantizedMaxPool2d")
QuantizedMaxPool3d = _create_layer_factory("QuantizedMaxPool3d")
QuantizedAvgPool1d = _create_layer_factory("QuantizedAvgPool1d")
QuantizedAvgPool2d = _create_layer_factory("QuantizedAvgPool2d")
QuantizedAvgPool3d = _create_layer_factory("QuantizedAvgPool3d")
QuantizedAdaptiveAvgPool2d = _create_layer_factory("QuantizedAdaptiveAvgPool2d")
# Normalization layers
QuantizedBatchNorm1d = _create_layer_factory("QuantizedBatchNorm1d")
QuantizedBatchNorm2d = _create_layer_factory("QuantizedBatchNorm2d")
QuantizedBatchNorm3d = _create_layer_factory("QuantizedBatchNorm3d")
QuantizedLayerNorm = _create_layer_factory("QuantizedLayerNorm")
QuantizedInstanceNorm1d = _create_layer_factory("QuantizedInstanceNorm1d")
QuantizedInstanceNorm2d = _create_layer_factory("QuantizedInstanceNorm2d")
QuantizedInstanceNorm3d = _create_layer_factory("QuantizedInstanceNorm3d")
QuantizedGroupNorm = _create_layer_factory("QuantizedGroupNorm")
# Attention layers
QuantizedMultiheadAttention = _create_layer_factory("QuantizedMultiheadAttention")
# Alias: reuse the same factory object rather than building a duplicate.
QuantizedAttention = QuantizedMultiheadAttention
# Activation layers
QuantizedReLU = _create_layer_factory("QuantizedReLU")
QuantizedSigmoid = _create_layer_factory("QuantizedSigmoid")
QuantizedTanh = _create_layer_factory("QuantizedTanh")
QuantizedLeakyReLU = _create_layer_factory("QuantizedLeakyReLU")
QuantizedSoftmax = _create_layer_factory("QuantizedSoftmax")
QuantizedGELU = _create_layer_factory("QuantizedGELU")
QuantizedELU = _create_layer_factory("QuantizedELU")
QuantizedPReLU = _create_layer_factory("QuantizedPReLU")
# Dropout
QuantizedDropout = _create_layer_factory("QuantizedDropout")
# Embedding
QuantizedEmbedding = _create_layer_factory("QuantizedEmbedding")
# Aliases
QuantizedAvgPool = QuantizedAvgPool2d
# ==================================================================
# Integer Quantized Layers
# ==================================================================
# Each name below is a factory that looks up the active backend when it is
# called, not at import time.
# Linear and convolution layers
IQuantizedLinear = _create_layer_factory("IQuantizedLinear")
IQuantizedConv1d = _create_layer_factory("IQuantizedConv1d")
IQuantizedConv2d = _create_layer_factory("IQuantizedConv2d")
IQuantizedConv3d = _create_layer_factory("IQuantizedConv3d")
IQuantizedConvTranspose1d = _create_layer_factory("IQuantizedConvTranspose1d")
IQuantizedConvTranspose2d = _create_layer_factory("IQuantizedConvTranspose2d")
IQuantizedConvTranspose3d = _create_layer_factory("IQuantizedConvTranspose3d")
# Recurrent layers
IQuantizedRNN = _create_layer_factory("IQuantizedRNN")
IQuantizedLSTM = _create_layer_factory("IQuantizedLSTM")
IQuantizedGRU = _create_layer_factory("IQuantizedGRU")
# Pooling layers
IQuantizedMaxPool1d = _create_layer_factory("IQuantizedMaxPool1d")
IQuantizedMaxPool2d = _create_layer_factory("IQuantizedMaxPool2d")
IQuantizedMaxPool3d = _create_layer_factory("IQuantizedMaxPool3d")
IQuantizedAvgPool1d = _create_layer_factory("IQuantizedAvgPool1d")
IQuantizedAvgPool2d = _create_layer_factory("IQuantizedAvgPool2d")
IQuantizedAvgPool3d = _create_layer_factory("IQuantizedAvgPool3d")
IQuantizedAdaptiveAvgPool1d = _create_layer_factory("IQuantizedAdaptiveAvgPool1d")
IQuantizedAdaptiveAvgPool2d = _create_layer_factory("IQuantizedAdaptiveAvgPool2d")
IQuantizedAdaptiveAvgPool3d = _create_layer_factory("IQuantizedAdaptiveAvgPool3d")
# Normalization layers
IQuantizedBatchNorm1d = _create_layer_factory("IQuantizedBatchNorm1d")
IQuantizedBatchNorm2d = _create_layer_factory("IQuantizedBatchNorm2d")
IQuantizedBatchNorm3d = _create_layer_factory("IQuantizedBatchNorm3d")
IQuantizedLayerNorm = _create_layer_factory("IQuantizedLayerNorm")
IQuantizedInstanceNorm1d = _create_layer_factory("IQuantizedInstanceNorm1d")
IQuantizedInstanceNorm2d = _create_layer_factory("IQuantizedInstanceNorm2d")
IQuantizedInstanceNorm3d = _create_layer_factory("IQuantizedInstanceNorm3d")
IQuantizedGroupNorm = _create_layer_factory("IQuantizedGroupNorm")
# Attention layers
IQuantizedMultiheadAttention = _create_layer_factory("IQuantizedMultiheadAttention")
# Alias: reuse the same factory object rather than building a duplicate.
IQuantizedAttention = IQuantizedMultiheadAttention
# Activation layers
IQuantizedReLU = _create_layer_factory("IQuantizedReLU")
IQuantizedSigmoid = _create_layer_factory("IQuantizedSigmoid")
IQuantizedTanh = _create_layer_factory("IQuantizedTanh")
IQuantizedLeakyReLU = _create_layer_factory("IQuantizedLeakyReLU")
IQuantizedSoftmax = _create_layer_factory("IQuantizedSoftmax")
IQuantizedGELU = _create_layer_factory("IQuantizedGELU")
IQuantizedELU = _create_layer_factory("IQuantizedELU")
IQuantizedPReLU = _create_layer_factory("IQuantizedPReLU")
IQuantizedSiLU = _create_layer_factory("IQuantizedSiLU")
# Dropout
IQuantizedDropout = _create_layer_factory("IQuantizedDropout")
# Embedding
IQuantizedEmbedding = _create_layer_factory("IQuantizedEmbedding")
# Aliases
IQuantizedAvgPool = IQuantizedAvgPool2d
# ==================================================================
# Export all symbols
# ==================================================================
# Public API of this module; order is kept stable (grouped by layer family).
__all__ = [
    # --- Straight-through-estimator wrappers ---
    "ChopSTE",
    "ChopfSTE",
    "ChopiSTE",
    # --- Post-training quantization entry points ---
    "post_quantization",
    "static_post_quantization",
    "dynamic_post_quantization",
    "mixed_post_quantization",
    # --- Floating-point quantized layers ---
    # linear / convolution
    "QuantizedLinear",
    "QuantizedConv1d", "QuantizedConv2d", "QuantizedConv3d",
    "QuantizedConvTranspose1d", "QuantizedConvTranspose2d",
    "QuantizedConvTranspose3d",
    # recurrent
    "QuantizedRNN", "QuantizedLSTM", "QuantizedGRU",
    # pooling
    "QuantizedMaxPool1d", "QuantizedMaxPool2d", "QuantizedMaxPool3d",
    "QuantizedAvgPool1d", "QuantizedAvgPool2d", "QuantizedAvgPool3d",
    "QuantizedAdaptiveAvgPool2d",
    # normalization
    "QuantizedBatchNorm1d", "QuantizedBatchNorm2d", "QuantizedBatchNorm3d",
    "QuantizedLayerNorm",
    "QuantizedInstanceNorm1d", "QuantizedInstanceNorm2d",
    "QuantizedInstanceNorm3d",
    "QuantizedGroupNorm",
    # attention
    "QuantizedMultiheadAttention", "QuantizedAttention",
    # activations
    "QuantizedReLU", "QuantizedSigmoid", "QuantizedTanh",
    "QuantizedLeakyReLU", "QuantizedSoftmax", "QuantizedGELU",
    "QuantizedELU", "QuantizedPReLU",
    # dropout / embedding / aliases
    "QuantizedDropout",
    "QuantizedEmbedding",
    "QuantizedAvgPool",
    # --- Integer quantized layers ---
    # linear / convolution
    "IQuantizedLinear",
    "IQuantizedConv1d", "IQuantizedConv2d", "IQuantizedConv3d",
    "IQuantizedConvTranspose1d", "IQuantizedConvTranspose2d",
    "IQuantizedConvTranspose3d",
    # recurrent
    "IQuantizedRNN", "IQuantizedLSTM", "IQuantizedGRU",
    # pooling
    "IQuantizedMaxPool1d", "IQuantizedMaxPool2d", "IQuantizedMaxPool3d",
    "IQuantizedAvgPool1d", "IQuantizedAvgPool2d", "IQuantizedAvgPool3d",
    "IQuantizedAdaptiveAvgPool1d", "IQuantizedAdaptiveAvgPool2d",
    "IQuantizedAdaptiveAvgPool3d",
    # normalization
    "IQuantizedBatchNorm1d", "IQuantizedBatchNorm2d", "IQuantizedBatchNorm3d",
    "IQuantizedLayerNorm",
    "IQuantizedInstanceNorm1d", "IQuantizedInstanceNorm2d",
    "IQuantizedInstanceNorm3d",
    "IQuantizedGroupNorm",
    # attention
    "IQuantizedMultiheadAttention", "IQuantizedAttention",
    # activations
    "IQuantizedReLU", "IQuantizedSigmoid", "IQuantizedTanh",
    "IQuantizedLeakyReLU", "IQuantizedSoftmax", "IQuantizedGELU",
    "IQuantizedELU", "IQuantizedPReLU", "IQuantizedSiLU",
    # dropout / embedding / aliases
    "IQuantizedDropout",
    "IQuantizedEmbedding",
    "IQuantizedAvgPool",
]