Documentation ¶
Overview ¶
Package fp8 implements FP8 mixed-precision training support.
Stability: alpha
Index ¶
- Constants
- func DequantizeBlockNVFP4(q []NVFloat4, scale float32) []float32
- func DequantizeNVFP4(n NVFloat4) float32
- type FP8Linear
- func (l *FP8Linear[T]) Attributes() map[string]interface{}
- func (l *FP8Linear[T]) Backward(ctx context.Context, mode types.BackwardMode, ...) ([]*tensor.TensorNumeric[T], error)
- func (l *FP8Linear[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
- func (l *FP8Linear[T]) Name() string
- func (l *FP8Linear[T]) OpType() string
- func (l *FP8Linear[T]) OutputShape() []int
- func (l *FP8Linear[T]) Parameters() []*graph.Parameter[T]
- func (l *FP8Linear[T]) SyncFP8Weights() error
- type LossScaler
- type MasterWeightStore
- type NVFloat4
Constants ¶
const NVFP4BlockSize = 32
NVFP4BlockSize is the number of elements per quantization block. Each block shares a single FP32 scale factor (absmax scaling).
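To make the per-block absmax scheme concrete, here is a minimal sketch of computing one block's scale. The normalization shown (scale = absmax / 6, since 6 is the largest NVFP4 magnitude) is an assumption for illustration; the package's quantizer may use a different convention.

```go
package main

import (
	"fmt"
	"math"
)

// blockScale computes a per-block absmax scale: each NVFP4BlockSize-sized
// block of 32 values shares one FP32 scale. Dividing by 6 (the largest
// representable NVFP4 magnitude) is an illustrative choice.
func blockScale(block []float32) float32 {
	var absmax float64
	for _, v := range block {
		if a := math.Abs(float64(v)); a > absmax {
			absmax = a
		}
	}
	return float32(absmax / 6.0)
}

func main() {
	block := make([]float32, 32) // one quantization block
	block[7] = -12.0             // absmax element
	block[20] = 3.0
	fmt.Println(blockScale(block)) // 12 / 6 = 2
}
```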
Variables ¶
This section is empty.
Functions ¶
func DequantizeBlockNVFP4 ¶
func DequantizeBlockNVFP4(q []NVFloat4, scale float32) []float32
DequantizeBlockNVFP4 dequantizes a block of NVFP4 values back to float32 using the provided scale factor from QuantizeBlockNVFP4.
func DequantizeNVFP4 ¶
func DequantizeNVFP4(n NVFloat4) float32
DequantizeNVFP4 converts an NVFP4 (E2M1) value back to float32.
Types ¶
type FP8Linear ¶
FP8Linear implements a linear layer that uses FP8 quantized weights for the forward pass and maintains full-precision master weights for gradient updates.
Forward: quantize input and weight to FP8 (via per-tensor absmax scaling), compute the matmul (GPU engine dispatches to FP8 GEMM when both operands carry FP8E4M3Storage), output in full precision.
Backward: use full-precision master weights to compute standard gradients for both input and weight. After the optimizer step, call SyncFP8Weights to refresh the FP8 snapshot.
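The quantized-forward / full-precision-backward split can be sketched with a toy stand-in quantizer. This is an illustration of the pattern only, not the package's actual FP8 codec or layer code; `quant` and the 0.25 step size are made up for the example.

```go
package main

import (
	"fmt"
	"math"
)

// quant stands in for FP8 per-tensor absmax quantization: it snaps a
// value to a coarse grid of step `scale`.
func quant(v, scale float32) float32 {
	return float32(math.Round(float64(v/scale))) * scale
}

func main() {
	// Full-precision master weights; a quantized snapshot is used only
	// in the forward pass (the real layer keeps an FP8 snapshot).
	master := []float32{0.31, -0.27} // W, shape [1, 2]
	scale := float32(0.25)           // hypothetical quantization step
	snapshot := []float32{quant(master[0], scale), quant(master[1], scale)}

	x := []float32{1.0, 2.0} // input, shape [1, 2]

	// Forward: y = x @ W^T computed from the quantized snapshot.
	y := x[0]*snapshot[0] + x[1]*snapshot[1]

	// Backward: gradients use the full-precision master weights, so
	// quantization error does not leak into the update direction.
	dy := float32(1.0)
	gradInput := []float32{dy * master[0], dy * master[1]} // dY @ W
	gradW := []float32{dy * x[0], dy * x[1]}               // dY^T @ x

	fmt.Println(y)         // forward uses snapshot: 0.25 - 0.5 = -0.25
	fmt.Println(gradInput) // [0.31 -0.27]
	fmt.Println(gradW)     // [1 2]
	// After an optimizer step on master, SyncFP8Weights refreshes the snapshot.
}
```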
func NewFP8Linear ¶
func NewFP8Linear[T tensor.Numeric]( name string, engine compute.Engine[T], inFeatures, outFeatures int, initData []T, ) (*FP8Linear[T], error)
NewFP8Linear creates an FP8 linear layer with the given dimensions. initData provides the initial weight values in full precision (row-major, shape [outFeatures, inFeatures]). If nil, weights are zero-initialized.
func (*FP8Linear[T]) Attributes ¶
Attributes returns layer attributes.
func (*FP8Linear[T]) Backward ¶
func (l *FP8Linear[T]) Backward(ctx context.Context, mode types.BackwardMode, outputGradient *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
Backward computes full-precision gradients using the master weights.
grad_input = outputGradient @ W      // shape: [batch, inFeatures]
grad_weight = outputGradient^T @ x   // shape: [outFeatures, inFeatures]
func (*FP8Linear[T]) Forward ¶
func (l *FP8Linear[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
Forward computes y = x @ W^T using FP8 quantized weights. Input x has shape [batch, inFeatures], output has shape [batch, outFeatures].
func (*FP8Linear[T]) OutputShape ¶
OutputShape returns the output shape [-1, outFeatures].
func (*FP8Linear[T]) Parameters ¶
Parameters returns the trainable master weight parameter.
func (*FP8Linear[T]) SyncFP8Weights ¶
SyncFP8Weights re-quantizes the master weights to FP8 after an optimizer step. Call this after each optimizer.Step().
type LossScaler ¶
type LossScaler struct {
Scale float64
GrowInterval int // steps between scale doublings (default: 2000)
// contains filtered or unexported fields
}
LossScaler implements dynamic loss scaling for FP8 mixed-precision training. It scales the loss before backpropagation to prevent gradient underflow in low-precision formats, and dynamically adjusts the scale factor based on whether overflow (inf/NaN) is detected in the resulting gradients.
func NewLossScaler ¶
func NewLossScaler(initialScale float64) *LossScaler
NewLossScaler creates a LossScaler with the given initial scale factor. GrowInterval defaults to 2000.
func (*LossScaler) CheckGradients ¶
func (ls *LossScaler) CheckGradients(grads [][]float32) bool
CheckGradients inspects all gradient values for inf or NaN. If any are found, it halves the scale (with a floor of 1.0) and returns false. Returns true if all gradients are finite.
func (*LossScaler) ScaleLoss ¶
func (ls *LossScaler) ScaleLoss(loss float64) float64
ScaleLoss returns loss multiplied by the current scale factor.
func (*LossScaler) UnscaleGradients ¶
func (ls *LossScaler) UnscaleGradients(grads [][]float32)
UnscaleGradients divides all gradient values by the current scale factor, reversing the effect of ScaleLoss on the gradient magnitudes.
func (*LossScaler) Update ¶
func (ls *LossScaler) Update(hadOverflow bool)
Update advances the step counter. If hadOverflow is true, the counter resets. After GrowInterval consecutive steps without overflow, the scale is doubled.
type MasterWeightStore ¶
MasterWeightStore maintains FP32 master copies of FP8 model parameters. The optimizer updates the FP32 copies; the FP8 copies are then refreshed by casting after each step.
func NewMasterWeightStore ¶
func NewMasterWeightStore[T tensor.Numeric](layers []*FP8Linear[T]) (*MasterWeightStore[T], error)
NewMasterWeightStore creates a store for the given FP8Linear layers. It copies each layer's current master weights into a float32 tensor.
func (*MasterWeightStore[T]) FP32Params ¶
func (s *MasterWeightStore[T]) FP32Params() []*tensor.TensorNumeric[float32]
FP32Params returns the FP32 master copy of all parameters. These are the parameters that the optimizer should update.
func (*MasterWeightStore[T]) MemoryBytes ¶
func (s *MasterWeightStore[T]) MemoryBytes() int64
MemoryBytes returns total bytes used by FP32 master weight copies. Each float32 parameter uses 4 bytes.
func (*MasterWeightStore[T]) SyncToFP8 ¶
func (s *MasterWeightStore[T]) SyncToFP8() error
SyncToFP8 casts updated FP32 master weights back to FP8 in each FP8Linear. Call after each optimizer step.
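As a quick sanity check of the 4-bytes-per-float32 accounting that MemoryBytes reports, here is the master-copy cost for a hypothetical square layer (the 4096x4096 dimensions are made up for the example):

```go
package main

import "fmt"

func main() {
	// The FP32 master copy of a [outFeatures, inFeatures] weight costs
	// outFeatures * inFeatures * 4 bytes.
	outFeatures, inFeatures := 4096, 4096
	bytes := int64(outFeatures) * int64(inFeatures) * 4
	fmt.Println(bytes, bytes>>20) // 67108864 bytes = 64 MiB
}
```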
type NVFloat4 ¶
type NVFloat4 uint8
NVFloat4 represents a 4-bit floating-point number in NVIDIA's NVFP4 (E2M1) format.
Bit layout (4 bits packed into a uint8):
- 1 bit : Sign (0 = positive, 1 = negative)
- 2 bits : Exponent (biased by 1, range [0, 3])
- 1 bit : Mantissa (1 explicit bit)
Representable values (positive): 0, 0.5, 1, 1.5, 2, 3, 4, 6.
Only the lower 4 bits of the uint8 are used.
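The bit layout above can be decoded in a few lines. The sign-high packing order within the nibble is an assumption for illustration; the package's DequantizeNVFP4 is the authoritative decoder.

```go
package main

import "fmt"

// dequantNVFP4 decodes a 4-bit E2M1 code: 1 sign bit, 2 exponent bits
// (bias 1), 1 mantissa bit. Exponent 0 is subnormal (0 or 0.5).
func dequantNVFP4(n uint8) float32 {
	sign := float32(1)
	if n&0x8 != 0 {
		sign = -1
	}
	e := (n >> 1) & 0x3
	m := float32(n & 0x1)
	var mag float32
	if e == 0 {
		mag = 0.5 * m // subnormal: 0 or 0.5
	} else {
		mag = (1 + 0.5*m) * float32(int(1)<<(e-1)) // normal: 1 .. 6
	}
	return sign * mag
}

func main() {
	// Positive codes 0..7 decode to the documented value set.
	for code := uint8(0); code < 8; code++ {
		fmt.Print(dequantNVFP4(code), " ") // 0 0.5 1 1.5 2 3 4 6
	}
	fmt.Println()
	fmt.Println(dequantNVFP4(0x8 | 7)) // sign bit set: -6
}
```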
func QuantizeBlockNVFP4 ¶
QuantizeBlockNVFP4 quantizes a block of float32 values to NVFP4 with a per-block absmax scale factor. Returns the quantized values and the scale used for dequantization.
func QuantizeToNVFP4 ¶
QuantizeToNVFP4 converts a float32 value to the nearest NVFP4 (E2M1) representation using round-to-nearest-even. Values exceeding the representable range are clamped (saturated).
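A simplified version of this mapping is a nearest-neighbor search over the eight representable magnitudes with saturation beyond 6. Ties here go to the first (smaller) candidate, which simplifies away the round-to-nearest-even tie-breaking that QuantizeToNVFP4 uses.

```go
package main

import (
	"fmt"
	"math"
)

// The positive NVFP4 (E2M1) value set.
var nvfp4Values = []float32{0, 0.5, 1, 1.5, 2, 3, 4, 6}

// nearestNVFP4 maps a float32 to the nearest representable NVFP4 value,
// saturating beyond +-6.
func nearestNVFP4(v float32) float32 {
	sign := float32(1)
	if v < 0 {
		sign = -1
		v = -v
	}
	best := nvfp4Values[0]
	bestDist := math.Abs(float64(v - best))
	for _, cand := range nvfp4Values[1:] {
		if d := math.Abs(float64(v - cand)); d < bestDist {
			best, bestDist = cand, d
		}
	}
	return sign * best
}

func main() {
	fmt.Println(nearestNVFP4(1.7))  // 1.5 (closer than 2)
	fmt.Println(nearestNVFP4(-9))   // clamped to -6
	fmt.Println(nearestNVFP4(0.26)) // 0.5
}
```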