"""
Microscaling (MX) Formats - Backend Agnostic Entry Point
OCP Microscaling format with automatic backend detection.
Supports NumPy, JAX, PyTorch, and TensorFlow backends.
MX Format Structure:
- Block of N elements (typically 32)
- One shared scale factor (exponent) per block
- Each element has its own exponent and mantissa
- Significantly better dynamic range than BFP
Reference:
OCP Microscaling Formats (MX) v1.0 Specification
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
Usage:
>>> import pychop
>>> pychop.backend('auto') # Auto-detect
>>>
>>> # NumPy
>>> import numpy as np
>>> X = np.random.randn(1024, 768)
>>> X_q = mx_quantize(X, format='mxfp8_e4m3')
>>>
>>> # PyTorch (with STE)
>>> import torch
>>> X = torch.randn(128, 768, requires_grad=True)
>>> X_q = mx_quantize(X, format='mxfp8_e4m3')
>>> loss = X_q.sum()
>>> loss.backward() # Automatic STE!
>>>
>>> # JAX
>>> import jax.numpy as jnp
>>> X = jnp.array(np.random.randn(512, 512))
>>> X_q = mx_quantize(X, format='mxfp8_e4m3')
Author: Xinye Chen
"""
import os
from typing import Union, Tuple, Optional, Any, Dict
from dataclasses import dataclass
# ============================================================================
# Backend Detection (inline to avoid import issues)
# ============================================================================
def _detect_array_type(x: Any) -> str:
    """Infer the backend name from the module that defines ``type(x)``.

    The dotted module name of the object's type is searched for the
    substrings ``"torch"``, ``"jax"`` and ``"tensorflow"``, in that order, and
    the first one found is returned. Anything else is treated as NumPy data.

    Parameters
    ----------
    x : Any
        Object whose type is inspected.

    Returns
    -------
    str
        One of ``"torch"``, ``"jax"``, ``"tensorflow"`` or ``"numpy"``.
    """
    origin = type(x).__module__
    # The tuple order fixes which name wins when several substrings match.
    for framework in ("torch", "jax", "tensorflow"):
        if framework in origin:
            return framework
    return "numpy"
def _get_backend_env() -> str:
    """Read the backend name from the ``chop_backend`` environment variable.

    Returns
    -------
    str
        The variable's value exactly as stored, or ``'auto'`` when the
        variable is not set.
    """
    configured = os.environ.get('chop_backend')
    if configured is None:
        return 'auto'
    return configured
# ============================================================================
# MX Format Specification (Backend-Independent)
# ============================================================================
@dataclass
class MXSpec:
    """
    Description of a Microscaling (MX) number format.

    An MX block stores:
    - one shared scale (exponent) for the whole block of elements
    - a private exponent + mantissa for every element

    Attributes
    ----------
    name : str
        Format name (e.g., 'MXFP8_E4M3')
    exp_bits : int
        Element exponent bits
    sig_bits : int
        Element significand bits (excluding implicit 1)
    block_size : int
        Elements per block
    scale_exp_bits : int
        Scale factor exponent bits
    scale_sig_bits : int
        Scale factor significand bits
    """

    name: str
    exp_bits: int
    sig_bits: int
    block_size: int = 32
    scale_exp_bits: int = 8
    scale_sig_bits: int = 0  # Scale is typically just exponent

    @property
    def element_bits(self) -> int:
        """Width of one element in bits: sign bit + exponent + significand."""
        sign_bit = 1
        return sign_bit + self.exp_bits + self.sig_bits

    @property
    def total_bits_per_block(self) -> int:
        """Storage for one block in bits: all elements plus the shared scale."""
        payload = self.element_bits * self.block_size
        shared_scale = self.scale_exp_bits + self.scale_sig_bits
        return payload + shared_scale

    @property
    def compression_vs_fp32(self) -> float:
        """Size of the same block held as FP32, divided by its MX size."""
        return (32 * self.block_size) / self.total_bits_per_block

    @property
    def compression_vs_fp16(self) -> float:
        """Size of the same block held as FP16, divided by its MX size."""
        return (16 * self.block_size) / self.total_bits_per_block

    def __repr__(self):
        # Compact form: only the name, the ExMy layout and the block size.
        layout = f"E{self.exp_bits}M{self.sig_bits}"
        return f"MXSpec(name='{self.name}', {layout}, block_size={self.block_size})"
# Predefined MX formats (OCP standard).
# Each entry is registered under the lower-cased form of its own name, so the
# lookup key and ``MXSpec.name`` cannot drift apart.
MX_FORMATS = {
    spec.name.lower(): spec
    for spec in (
        # MXFP8 formats
        MXSpec('MXFP8_E5M2', exp_bits=5, sig_bits=2, block_size=32),
        MXSpec('MXFP8_E4M3', exp_bits=4, sig_bits=3, block_size=32),
        # MXFP6 formats
        MXSpec('MXFP6_E3M2', exp_bits=3, sig_bits=2, block_size=32),
        MXSpec('MXFP6_E2M3', exp_bits=2, sig_bits=3, block_size=32),
        # MXFP4 format
        MXSpec('MXFP4_E2M1', exp_bits=2, sig_bits=1, block_size=32),
        # MXINT8 (integer format with MX scaling)
        MXSpec('MXINT8', exp_bits=0, sig_bits=7, block_size=32),
    )
}
def create_mx_spec(
    exp_bits: int,
    sig_bits: int,
    block_size: int = 32,
    scale_exp_bits: int = 8,
    name: Optional[str] = None
) -> MXSpec:
    """Build an :class:`MXSpec` for a user-defined MX format.

    Parameters
    ----------
    exp_bits : int
        Element exponent bits.
    sig_bits : int
        Element significand bits.
    block_size : int, default 32
        Elements per block.
    scale_exp_bits : int, default 8
        Scale factor exponent bits.
    name : str, optional
        Format name. When omitted, a name of the form
        ``CUSTOM_MX<total>_E<exp>M<sig>`` is generated, where ``<total>`` is
        ``1 + exp_bits + sig_bits``.

    Returns
    -------
    MXSpec
        The new specification; ``scale_sig_bits`` keeps its dataclass default.
    """
    label = name
    if label is None:
        # Sign bit + exponent + significand gives the per-element width.
        width = 1 + exp_bits + sig_bits
        label = f"CUSTOM_MX{width}_E{exp_bits}M{sig_bits}"
    return MXSpec(
        name=label,
        exp_bits=exp_bits,
        sig_bits=sig_bits,
        block_size=block_size,
        scale_exp_bits=scale_exp_bits,
    )
# ============================================================================
# Backend Detection and Routing
# ============================================================================
def _resolve_backend(X: Any = None) -> str:
    """Decide which backend name to use.

    The ``chop_backend`` environment setting takes precedence. Only when it is
    ``'auto'`` is the backend inferred from ``X``.

    Parameters
    ----------
    X : Any, optional
        Input data used for auto-detection. With ``'auto'`` and no data the
        result is ``'numpy'``.

    Returns
    -------
    str
        ``'numpy'``, ``'jax'``, ``'torch'`` or ``'tensorflow'``.

    Raises
    ------
    ValueError
        If the configured backend is neither ``'auto'`` nor a supported name.
    """
    requested = _get_backend_env()

    if requested != 'auto':
        # An explicit setting wins over the input type, but must be valid.
        if requested in {'numpy', 'jax', 'torch', 'tensorflow'}:
            return requested
        raise ValueError(
            f"Invalid backend: {requested}. "
            "Must be 'numpy', 'jax', 'torch', 'tensorflow', or 'auto'."
        )

    return 'numpy' if X is None else _detect_array_type(X)
def _get_backend_module(backend: str):
    """Import and return the MX implementation module for ``backend``.

    The import happens lazily inside this function, so only the module of the
    requested backend is loaded.

    Parameters
    ----------
    backend : str
        One of ``'torch'``, ``'jax'``, ``'numpy'`` or ``'tensorflow'``.

    Returns
    -------
    module
        The package-local ``mx_formats`` module of the selected backend
        (``.tch``, ``.jx``, ``.np`` or ``.tf``).

    Raises
    ------
    ImportError
        If the backend module cannot be imported. For the torch, jax and
        tensorflow backends a friendlier message is raised and the original
        error is attached as ``__cause__``.
    ValueError
        If ``backend`` is not one of the supported names.
    """
    if backend == 'torch':
        try:
            from .tch import mx_formats as backend_module
        except ImportError as err:
            # Chain explicitly so the real failure (missing framework or a
            # broken import inside the backend module) stays visible.
            raise ImportError(
                "PyTorch backend not available. "
                "Install with: pip install torch"
            ) from err
    elif backend == 'jax':
        try:
            from .jx import mx_formats as backend_module
        except ImportError as err:
            raise ImportError(
                "JAX backend not available. "
                "Install with: pip install jax jaxlib flax"
            ) from err
    elif backend == 'numpy':
        # No friendly wrapper here: an import failure propagates unchanged.
        from .np import mx_formats as backend_module
    elif backend == 'tensorflow':
        try:
            from .tf import mx_formats as backend_module
        except ImportError as err:
            raise ImportError(
                "TensorFlow backend not available. "
                "Install with: pip install tensorflow"
            ) from err
    else:
        raise ValueError(f"Unsupported backend: {backend}")
    return backend_module
# ============================================================================
# User-Facing Functions
# ============================================================================
def mx_quantize(
    data: Any,
    format: Union[str, MXSpec, Tuple[int, int]] = 'mxfp8_e4m3',
    block_size: int = 32,
    scale_exp_bits: Optional[int] = None,
    scale_sig_bits: Optional[int] = None,
    backend: Optional[str] = None
) -> Any:
    """
    Quantize array to MX format.

    Selects a backend and delegates to that backend's ``mx_quantize``.

    Parameters
    ----------
    data : Any
        Input array/tensor; also used for backend auto-detection.
    format : str, MXSpec or tuple of (int, int), default 'mxfp8_e4m3'
        Target format, forwarded unchanged to the backend.
    block_size : int, default 32
        Forwarded unchanged to the backend.
    scale_exp_bits : int, optional
        Forwarded unchanged to the backend.
    scale_sig_bits : int, optional
        Forwarded unchanged to the backend.
    backend : str, optional
        Backend name. When ``None`` it is resolved from the environment
        setting and, if that is ``'auto'``, from the type of ``data``.

    Returns
    -------
    Any
        Whatever the backend's ``mx_quantize`` returns.

    Examples
    --------
    >>> import numpy as np
    >>> X = np.random.randn(1024, 768)
    >>> X_q = mx_quantize(X, format='mxfp8_e4m3')
    """
    # An explicit backend argument bypasses resolution entirely.
    chosen = _resolve_backend(data) if backend is None else backend
    impl = _get_backend_module(chosen)

    options = {
        'format': format,
        'block_size': block_size,
        'scale_exp_bits': scale_exp_bits,
        'scale_sig_bits': scale_sig_bits,
    }
    return impl.mx_quantize(data, **options)
class MXTensor:
    """Backend-agnostic MX tensor wrapper.

    On construction a backend is selected and that backend's ``MXTensor_``
    object is created and kept in ``_impl``. ``dequantize`` and ``statistics``
    delegate to it.

    Attributes
    ----------
    backend : str
        Name of the backend in use.
    """

    def __init__(
        self,
        data: Any,
        format: Union[str, MXSpec, Tuple[int, int]] = 'mxfp8_e4m3',
        block_size: int = 32,
        scale_exp_bits: Optional[int] = None,
        scale_sig_bits: Optional[int] = None,
        backend: Optional[str] = None
    ):
        # An explicit backend argument bypasses resolution entirely.
        self.backend = _resolve_backend(data) if backend is None else backend

        impl_module = _get_backend_module(self.backend)

        # All format options are handed to the backend untouched.
        options = {
            'format': format,
            'block_size': block_size,
            'scale_exp_bits': scale_exp_bits,
            'scale_sig_bits': scale_sig_bits,
        }
        self._impl = impl_module.MXTensor_(data, **options)

    def dequantize(self) -> Any:
        """Return the result of the backend tensor's ``dequantize()``."""
        impl = self._impl
        return impl.dequantize()

    def statistics(self) -> dict:
        """Return the result of the backend tensor's ``statistics()``."""
        impl = self._impl
        return impl.statistics()

    def __repr__(self):
        backend, impl = self.backend, self._impl
        return f"MXTensor(backend={backend}, impl={impl})"
# Names exported by a star-import of this module.
# NOTE(review): 'compare_mx_formats' and 'print_mx_format_table' are not
# defined in the part of this module visible here. Confirm they exist
# elsewhere; a star-import raises AttributeError for any listed name that the
# module does not define.
__all__ = [
    'MXSpec',
    'MXTensor',
    'MX_FORMATS',
    'create_mx_spec',
    'mx_quantize',
    'compare_mx_formats',
    'print_mx_format_table',
]