Quantized optimizers module
This module provides custom PyTorch optimizers whose momentum and accumulator states are quantized, for simulating low-precision training. Each optimizer extends torch.optim.Optimizer and uses Chop or Chopf for quantization.
Classes
- class QuantizedSGD(params, lr=0.01, momentum=0.9, weight_decay=0, exp_bits=8, sig_bits=7, rmode=1)
SGD optimizer with quantized momentum.
- Parameters:
params (iterable) – Iterable of parameters to optimize or dicts defining parameter groups.
lr (float) – Learning rate (default: 0.01).
momentum (float) – Momentum factor (default: 0.9).
weight_decay (float) – Weight decay (L2 penalty) (default: 0).
exp_bits (int) – Number of exponent bits for quantization (default: 8).
sig_bits (int) – Number of mantissa bits for quantization (default: 7).
rmode (int) – Rounding mode for quantization (e.g., 1, “nearest”) (default: 1).
Quantizes the momentum buffer (if momentum > 0) and the parameter update.
Example:
    optimizer = QuantizedSGD(model.parameters(), lr=0.01, momentum=0.9, rmode=1)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
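To illustrate where quantization enters an SGD step, here is a minimal plain-Python sketch on scalars. It is not this module's implementation: `quantize` is a hypothetical stand-in for Chop/Chopf that rounds to nearest (ties to even), assumes sig_bits counts stored significand bits with one implicit leading bit, and ignores the exponent range. `quantized_sgd_step` is likewise an illustrative name.

```python
import math


def quantize(x, sig_bits=7):
    """Hypothetical stand-in for Chop/Chopf: round x to sig_bits stored
    significand bits (plus one implicit bit), ignoring exponent limits."""
    if x == 0.0 or not math.isfinite(x):
        return x
    m, e = math.frexp(x)              # x = m * 2**e with 0.5 <= |m| < 1
    scale = 2 ** (sig_bits + 1)
    return math.ldexp(round(m * scale) / scale, e)


def quantized_sgd_step(param, grad, buf, lr=0.01, momentum=0.9,
                       weight_decay=0.0, sig_bits=7):
    """One scalar SGD step with a quantized momentum buffer and update."""
    grad = grad + weight_decay * param            # L2 penalty
    if momentum > 0:
        buf = quantize(momentum * buf + grad, sig_bits)
        grad = buf
    update = quantize(lr * grad, sig_bits)        # quantized parameter update
    return param - update, buf


param, buf = 1.0, 0.0
for _ in range(3):
    grad = 2.0 * param                            # gradient of f(p) = p**2
    param, buf = quantized_sgd_step(param, grad, buf)
```

In this sketch only the buffer and the update pass through the quantizer; the parameter itself is left in full precision.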
- class QuantizedRMSprop(params, lr=0.01, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, exp_bits=8, sig_bits=7, rmode=1)
RMSprop optimizer with quantized accumulator and optional momentum.
- Parameters:
params (iterable) – Iterable of parameters to optimize or dicts defining parameter groups.
lr (float) – Learning rate (default: 0.01).
alpha (float) – Smoothing constant for accumulator (default: 0.99).
eps (float) – Term added to denominator for numerical stability (default: 1e-8).
weight_decay (float) – Weight decay (L2 penalty) (default: 0).
momentum (float) – Momentum factor (default: 0).
exp_bits (int) – Number of exponent bits for quantization (default: 8).
sig_bits (int) – Number of mantissa bits for quantization (default: 7).
rmode (int) – Rounding mode for quantization (default: 1).
Quantizes the square average accumulator and momentum buffer (if used), as well as the final update.
Example:
    optimizer = QuantizedRMSprop(model.parameters(), lr=0.01, momentum=0.9, rmode=5)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
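The RMSprop variant can be sketched the same way on scalars. This is an assumption-laden illustration, not the module's code: `quantize` is a hypothetical round-to-nearest stand-in for Chop/Chopf that ignores the exponent range, and `quantized_rmsprop_step` is an illustrative name that follows the standard RMSprop recurrence.

```python
import math


def quantize(x, sig_bits=7):
    """Hypothetical stand-in for Chop/Chopf: round x to sig_bits stored
    significand bits (plus one implicit bit), ignoring exponent limits."""
    if x == 0.0 or not math.isfinite(x):
        return x
    m, e = math.frexp(x)              # x = m * 2**e with 0.5 <= |m| < 1
    scale = 2 ** (sig_bits + 1)
    return math.ldexp(round(m * scale) / scale, e)


def quantized_rmsprop_step(param, grad, square_avg, buf, lr=0.01, alpha=0.99,
                           eps=1e-8, momentum=0.0, sig_bits=7):
    """One scalar RMSprop step; accumulator, buffer and update are quantized."""
    square_avg = quantize(alpha * square_avg + (1 - alpha) * grad * grad,
                          sig_bits)
    step = grad / (math.sqrt(square_avg) + eps)
    if momentum > 0:
        buf = quantize(momentum * buf + step, sig_bits)
        step = buf
    update = quantize(lr * step, sig_bits)
    return param - update, square_avg, buf


param, square_avg, buf = quantized_rmsprop_step(1.0, 2.0, 0.0, 0.0)
```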
- class QuantizedAdagrad(params, lr=0.01, lr_decay=0, weight_decay=0, eps=1e-10, exp_bits=8, sig_bits=7, rmode=1)
Adagrad optimizer with quantized accumulator.
- Parameters:
params (iterable) – Iterable of parameters to optimize or dicts defining parameter groups.
lr (float) – Learning rate (default: 0.01).
lr_decay (float) – Learning rate decay (default: 0).
weight_decay (float) – Weight decay (L2 penalty) (default: 0).
eps (float) – Term added to denominator for numerical stability (default: 1e-10).
exp_bits (int) – Number of exponent bits for quantization (default: 8).
sig_bits (int) – Number of mantissa bits for quantization (default: 7).
rmode (int) – Rounding mode for quantization (default: 1).
Quantizes the sum of squared gradients (accumulator) and the parameter update.
Example:
    optimizer = QuantizedAdagrad(model.parameters(), lr=0.01, rmode=4)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
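A scalar sketch of an Adagrad step with a quantized accumulator might look like the following. It assumes the standard Adagrad recurrence and uses a hypothetical `quantize` helper (round to nearest, exponent range ignored) in place of Chop/Chopf; `quantized_adagrad_step` is an illustrative name, not part of this module.

```python
import math


def quantize(x, sig_bits=7):
    """Hypothetical stand-in for Chop/Chopf: round x to sig_bits stored
    significand bits (plus one implicit bit), ignoring exponent limits."""
    if x == 0.0 or not math.isfinite(x):
        return x
    m, e = math.frexp(x)              # x = m * 2**e with 0.5 <= |m| < 1
    scale = 2 ** (sig_bits + 1)
    return math.ldexp(round(m * scale) / scale, e)


def quantized_adagrad_step(param, grad, state_sum, step, lr=0.01,
                           lr_decay=0.0, eps=1e-10, sig_bits=7):
    """One scalar Adagrad step; step counts from 1."""
    clr = lr / (1 + (step - 1) * lr_decay)        # decayed learning rate
    state_sum = quantize(state_sum + grad * grad, sig_bits)
    update = quantize(clr * grad / (math.sqrt(state_sum) + eps), sig_bits)
    return param - update, state_sum


param, state_sum = quantized_adagrad_step(1.0, 2.0, 0.0, step=1)
```

Because the accumulator only grows, its spacing between representable values grows too, so later squared gradients that are small relative to it may be rounded away in a sketch like this.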
- class QuantizedAdam(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, exp_bits=8, sig_bits=7, rmode=1)
Adam optimizer with quantized momentum and accumulator.
- Parameters:
params (iterable) – Iterable of parameters to optimize or dicts defining parameter groups.
lr (float) – Learning rate (default: 1e-3).
betas (tuple[float, float]) – Coefficients for computing running averages of gradient and its square (default: (0.9, 0.999)).
eps (float) – Term added to denominator for numerical stability (default: 1e-8).
weight_decay (float) – Weight decay (L2 penalty) (default: 0).
exp_bits (int) – Number of exponent bits for quantization (default: 8).
sig_bits (int) – Number of mantissa bits for quantization (default: 7).
rmode (int) – Rounding mode for quantization (default: 1).
Quantizes the first moment (momentum), second moment (accumulator), and the parameter update.
Example:
    optimizer = QuantizedAdam(model.parameters(), lr=0.001, rmode=6)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
Notes
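For Adam, the three quantization points (first moment, second moment, update) can be sketched on scalars as below. This assumes the standard bias-corrected Adam recurrence and a hypothetical `quantize` helper standing in for Chop/Chopf (round to nearest, exponent range ignored); whether the module quantizes before or after bias correction is not stated in this page, so the placement here is an assumption.

```python
import math


def quantize(x, sig_bits=7):
    """Hypothetical stand-in for Chop/Chopf: round x to sig_bits stored
    significand bits (plus one implicit bit), ignoring exponent limits."""
    if x == 0.0 or not math.isfinite(x):
        return x
    m, e = math.frexp(x)              # x = m * 2**e with 0.5 <= |m| < 1
    scale = 2 ** (sig_bits + 1)
    return math.ldexp(round(m * scale) / scale, e)


def quantized_adam_step(param, grad, m, v, step, lr=1e-3,
                        betas=(0.9, 0.999), eps=1e-8, sig_bits=7):
    """One scalar Adam step; step counts from 1."""
    beta1, beta2 = betas
    m = quantize(beta1 * m + (1 - beta1) * grad, sig_bits)         # first moment
    v = quantize(beta2 * v + (1 - beta2) * grad * grad, sig_bits)  # second moment
    m_hat = m / (1 - beta1 ** step)                                # bias correction
    v_hat = v / (1 - beta2 ** step)
    update = quantize(lr * m_hat / (math.sqrt(v_hat) + eps), sig_bits)
    return param - update, m, v


param, m, v = quantized_adam_step(1.0, 2.0, 0.0, 0.0, step=1)
```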
All optimizers rely on Chop or Chopf for quantization; each must be imported from its respective module.
These optimizers are designed for low-precision training, so their convergence may differ from that of their full-precision counterparts.
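One common reason convergence can differ is that an update much smaller than the value it is added to may be rounded away entirely. The sketch below shows this effect with a hypothetical `quantize` helper (round to nearest, 7 stored significand bits, exponent range ignored) standing in for Chop/Chopf; it illustrates the general phenomenon and is not a measurement of this module.

```python
import math


def quantize(x, sig_bits=7):
    """Hypothetical stand-in for Chop/Chopf: round x to sig_bits stored
    significand bits (plus one implicit bit), ignoring exponent limits."""
    if x == 0.0 or not math.isfinite(x):
        return x
    m, e = math.frexp(x)              # x = m * 2**e with 0.5 <= |m| < 1
    scale = 2 ** (sig_bits + 1)
    return math.ldexp(round(m * scale) / scale, e)


low = 1.0    # accumulator re-quantized after every addition
full = 1.0   # same accumulation in double precision
for _ in range(100):
    low = quantize(low + 0.001)
    full = full + 0.001

print(low)   # 1.0: each 0.001 increment is below half the spacing near 1.0
print(full)  # close to 1.1
```

With 7 stored significand bits the spacing between representable values just above 1.0 is 2**-7, so an increment of 0.001 rounds back to 1.0 every time under round to nearest.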