fp8

package
v1.25.3
Published: Mar 27, 2026 License: Apache-2.0 Imports: 7 Imported by: 0

Documentation

Overview

Package fp8 implements FP8 mixed-precision training support.

Stability: alpha

Index

Constants

const NVFP4BlockSize = 32

NVFP4BlockSize is the number of elements per quantization block. Each block shares a single FP32 scale factor (absmax scaling).

Variables

This section is empty.

Functions

func DequantizeBlockNVFP4

func DequantizeBlockNVFP4(q []NVFloat4, scale float32) []float32

DequantizeBlockNVFP4 dequantizes a block of NVFP4 values back to float32 using the provided scale factor from QuantizeBlockNVFP4.

func DequantizeNVFP4

func DequantizeNVFP4(n NVFloat4) float32

DequantizeNVFP4 converts an NVFP4 (E2M1) value back to float32.

Types

type FP8Linear

type FP8Linear[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

FP8Linear implements a linear layer that uses FP8 quantized weights for the forward pass and maintains full-precision master weights for gradient updates.

Forward: quantize input and weight to FP8 (via per-tensor absmax scaling), compute the matmul (GPU engine dispatches to FP8 GEMM when both operands carry FP8E4M3Storage), output in full precision.

Backward: use full-precision master weights to compute standard gradients for both input and weight. After the optimizer step, call SyncFP8Weights to refresh the FP8 snapshot.

func NewFP8Linear

func NewFP8Linear[T tensor.Numeric](
	name string,
	engine compute.Engine[T],
	inFeatures, outFeatures int,
	initData []T,
) (*FP8Linear[T], error)

NewFP8Linear creates an FP8 linear layer with the given dimensions. initData provides the initial weight values in full precision (row-major, shape [outFeatures, inFeatures]). If nil, weights are zero-initialized.

func (*FP8Linear[T]) Attributes

func (l *FP8Linear[T]) Attributes() map[string]interface{}

Attributes returns layer attributes.

func (*FP8Linear[T]) Backward

func (l *FP8Linear[T]) Backward(ctx context.Context, mode types.BackwardMode, outputGradient *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)

Backward computes full-precision gradients using the master weights.

	grad_input  = outputGradient @ W    (shape: [batch, inFeatures])
	grad_weight = outputGradient^T @ x  (shape: [outFeatures, inFeatures])

func (*FP8Linear[T]) Forward

func (l *FP8Linear[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward computes y = x @ W^T using FP8 quantized weights. Input x has shape [batch, inFeatures], output has shape [batch, outFeatures].

func (*FP8Linear[T]) Name

func (l *FP8Linear[T]) Name() string

Name returns the layer name.

func (*FP8Linear[T]) OpType

func (l *FP8Linear[T]) OpType() string

OpType returns the operation type.

func (*FP8Linear[T]) OutputShape

func (l *FP8Linear[T]) OutputShape() []int

OutputShape returns the output shape [-1, outFeatures].

func (*FP8Linear[T]) Parameters

func (l *FP8Linear[T]) Parameters() []*graph.Parameter[T]

Parameters returns the trainable master weight parameter.

func (*FP8Linear[T]) SyncFP8Weights

func (l *FP8Linear[T]) SyncFP8Weights() error

SyncFP8Weights re-quantizes the master weights to FP8 after an optimizer step. Call this after each optimizer.Step().

type LossScaler

type LossScaler struct {
	Scale        float64
	GrowInterval int // steps between scale doublings (default: 2000)
	// contains filtered or unexported fields
}

LossScaler implements dynamic loss scaling for FP8 mixed-precision training. It scales the loss before backpropagation to prevent gradient underflow in low-precision formats, and dynamically adjusts the scale factor based on whether overflow (inf/NaN) is detected in the resulting gradients.

func NewLossScaler

func NewLossScaler(initialScale float64) *LossScaler

NewLossScaler creates a LossScaler with the given initial scale factor. GrowInterval defaults to 2000.

func (*LossScaler) CheckGradients

func (ls *LossScaler) CheckGradients(grads [][]float32) bool

CheckGradients inspects all gradient values for inf or NaN. If any are found, it halves the scale (with a floor of 1.0) and returns false. Returns true if all gradients are finite.

func (*LossScaler) ScaleLoss

func (ls *LossScaler) ScaleLoss(loss float64) float64

ScaleLoss returns loss multiplied by the current scale factor.

func (*LossScaler) UnscaleGradients

func (ls *LossScaler) UnscaleGradients(grads [][]float32)

UnscaleGradients divides all gradient values by the current scale factor, reversing the effect of ScaleLoss on the gradient magnitudes.

func (*LossScaler) Update

func (ls *LossScaler) Update(hadOverflow bool)

Update advances the step counter. If hadOverflow is true, the counter resets. After GrowInterval consecutive steps without overflow, the scale is doubled.

type MasterWeightStore

type MasterWeightStore[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

MasterWeightStore maintains FP32 master copies of FP8 model parameters. The optimizer updates the FP32 copies; the FP8 copies are refreshed by casting after each step.

func NewMasterWeightStore

func NewMasterWeightStore[T tensor.Numeric](layers []*FP8Linear[T]) (*MasterWeightStore[T], error)

NewMasterWeightStore creates a store for the given FP8Linear layers. It copies each layer's current master weights into a float32 tensor.

func (*MasterWeightStore[T]) FP32Params

func (s *MasterWeightStore[T]) FP32Params() []*tensor.TensorNumeric[float32]

FP32Params returns the FP32 master copy of all parameters. These are the parameters that the optimizer should update.

func (*MasterWeightStore[T]) MemoryBytes

func (s *MasterWeightStore[T]) MemoryBytes() int64

MemoryBytes returns total bytes used by FP32 master weight copies. Each float32 parameter uses 4 bytes.

func (*MasterWeightStore[T]) SyncToFP8

func (s *MasterWeightStore[T]) SyncToFP8() error

SyncToFP8 casts updated FP32 master weights back to FP8 in each FP8Linear. Call after each optimizer step.

type NVFloat4

type NVFloat4 uint8

NVFloat4 represents a 4-bit floating-point number in NVIDIA's NVFP4 (E2M1) format.

Bit layout (4 bits packed into a uint8):

  • 1 bit : Sign (0 = positive, 1 = negative)
  • 2 bits : Exponent (biased by 1, range [0, 3])
  • 1 bit : Mantissa (1 explicit bit)

Representable values (positive): 0, 0.5, 1, 1.5, 2, 3, 4, 6. Only the lower 4 bits of the uint8 are used.
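The bit layout above can be decoded directly. This is a sketch derived from the stated layout (sign | 2-bit exponent, bias 1 | 1-bit mantissa), with the exponent-0 codes treated as subnormals, which is what reproduces the listed value set; DequantizeNVFP4's internals may differ.

```go
package main

import "fmt"

// decodeE2M1 decodes the low 4 bits of b as sign(1) | exponent(2) | mantissa(1),
// exponent bias 1. Exponent 0 is treated as the subnormal case
// (value = mantissa * 0.5), which yields 0 and 0.5.
func decodeE2M1(b uint8) float32 {
	sign := float32(1)
	if b&0b1000 != 0 {
		sign = -1
	}
	exp := (b >> 1) & 0b11
	man := b & 0b1
	var mag float32
	if exp == 0 {
		mag = 0.5 * float32(man) // subnormal: 0 or 0.5
	} else {
		mag = (1 + 0.5*float32(man)) * float32(uint(1)<<(exp-1)) // (1.m) * 2^(exp-1)
	}
	return sign * mag
}

func main() {
	for b := uint8(0); b < 8; b++ {
		fmt.Printf("%04b -> %v\n", b, decodeE2M1(b))
	}
	// Prints the positive value set: 0, 0.5, 1, 1.5, 2, 3, 4, 6.
}
```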

func QuantizeBlockNVFP4

func QuantizeBlockNVFP4(data []float32) ([]NVFloat4, float32)

QuantizeBlockNVFP4 quantizes a block of float32 values to NVFP4 with a per-block absmax scale factor. Returns the quantized values and the scale used for dequantization.

func QuantizeToNVFP4

func QuantizeToNVFP4(f float32) NVFloat4

QuantizeToNVFP4 converts a float32 value to the nearest NVFP4 (E2M1) representation using round-to-nearest-even. Values exceeding the representable range are clamped (saturated).
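The clamping and rounding behavior can be sketched as a nearest-value search over the eight representable magnitudes. This is a hypothetical simplification: it uses a plain nearest search, whereas QuantizeToNVFP4 documents round-to-nearest-even, which additionally breaks exact midpoints toward the code with an even mantissa.

```go
package main

import (
	"fmt"
	"math"
)

// e2m1Magnitudes lists the positive values representable in E2M1,
// indexed by the 3-bit exponent|mantissa code.
var e2m1Magnitudes = [8]float32{0, 0.5, 1, 1.5, 2, 3, 4, 6}

// quantizeE2M1 maps f to the code of the nearest representable E2M1
// value, saturating beyond +-6. Plain nearest-value search (sketch);
// the real function uses round-to-nearest-even at midpoints.
func quantizeE2M1(f float32) uint8 {
	sign := uint8(0)
	if math.Signbit(float64(f)) {
		sign = 0b1000
	}
	mag := float32(math.Abs(float64(f)))
	if mag >= 6 {
		return sign | 0b0111 // saturate at +-6
	}
	best, bestDist := uint8(0), float32(math.Inf(1))
	for code, m := range e2m1Magnitudes {
		if d := float32(math.Abs(float64(mag - m))); d < bestDist {
			best, bestDist = uint8(code), d
		}
	}
	return sign | best
}

func main() {
	for _, f := range []float32{0.3, 2.6, -7.5} {
		fmt.Printf("%v -> code %04b\n", f, quantizeE2M1(f))
	}
}
```

Saturation means out-of-range inputs such as -7.5 land on the maximum-magnitude code (0b1111, i.e. -6) rather than overflowing.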
