Documentation
Overview ¶
Package tabular provides tabular ML model types. (Stability: alpha)
Index ¶
- func Save(model *Model, path string) error
- func SparsemaxDirect(input *tensor.TensorNumeric[float32]) (*tensor.TensorNumeric[float32], error)
- type Activation
- type Adapter
- type BaseModel
- type Direction
- type Ensemble
- type FTTransformer
- type FTTransformerConfig
- type LoRAConfig
- type MetaLearnerConfig
- type Model
- func Load(path string, engine compute.Engine[float32], ops numeric.Arithmetic[float32]) (*Model, error)
- func MergeAdapter(base *BaseModel, adapter *Adapter, engine compute.Engine[float32]) (*Model, error)
- func NewModel(config ModelConfig, engine compute.Engine[float32], ...) (*Model, error)
- func Train(data [][]float64, labels []int, config TrainConfig, mc ModelConfig, ...) (*Model, error)
- type ModelConfig
- type NormMode
- type PreTrainConfig
- type SAINT
- type SAINTConfig
- type TabNet
- func (t *TabNet) AttentionMasks() []*tensor.TensorNumeric[float32]
- func (t *TabNet) FeatureImportance(ctx context.Context) (*tensor.TensorNumeric[float32], error)
- func (t *TabNet) Forward(ctx context.Context, input *tensor.TensorNumeric[float32]) (*tensor.TensorNumeric[float32], error)
- func (t *TabNet) Predict(features []float64) (Direction, float64, error)
- type TabNetConfig
- type TabResNet
- type TabResNetConfig
- type TrainConfig
Constants ¶
This section is empty.
Variables ¶
This section is empty.
Functions ¶
func Save ¶
func Save(model *Model, path string) error
Save writes a Model to the given path in the ZTAB binary format.
Format:
- 4-byte magic ("ZTAB")
- 4-byte version (uint32 little-endian, currently 1)
- 4-byte config length (uint32 little-endian)
- JSON-encoded ModelConfig
- Weight data: for each hidden layer then the output head, weights tensor data followed by biases tensor data as raw float32 little-endian bytes.
func SparsemaxDirect ¶
func SparsemaxDirect(input *tensor.TensorNumeric[float32]) (*tensor.TensorNumeric[float32], error)
SparsemaxDirect exposes sparsemax for direct testing.
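Sparsemax projects logits onto the probability simplex and, unlike softmax, can assign exactly zero to low-scoring entries. The following is a minimal float32 sketch of the standard algorithm (sort descending, find the support size, subtract a threshold), illustrative only and not the package's implementation:

```go
package main

import (
	"fmt"
	"sort"
)

// sparsemax computes the sparsemax projection of z onto the simplex:
// p_i = max(z_i - tau, 0), where tau is chosen so the result sums to 1.
func sparsemax(z []float32) []float32 {
	// Sort a copy in descending order.
	s := append([]float32(nil), z...)
	sort.Slice(s, func(i, j int) bool { return s[i] > s[j] })

	// Find the largest k with 1 + k*s[k-1] > sum(s[:k]); tau follows from it.
	var cumsum, tau float32
	for i, v := range s {
		cumsum += v
		if 1+float32(i+1)*v > cumsum {
			tau = (cumsum - 1) / float32(i+1)
		}
	}

	// Threshold each input; entries below tau become exactly zero.
	out := make([]float32, len(z))
	for i, v := range z {
		if v > tau {
			out[i] = v - tau
		}
	}
	return out
}

func main() {
	fmt.Println(sparsemax([]float32{1.0, 0.5, -1.0})) // [0.75 0.25 0]
}
```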
Types ¶
type Activation ¶
type Activation int
Activation selects the activation function used between hidden layers.
const (
	// ActivationReLU uses the ReLU activation function.
	ActivationReLU Activation = iota
	// ActivationGELU uses the GELU activation function.
	ActivationGELU
)
type Adapter ¶
type Adapter struct {
Layers map[int]loraLayerAdapter // hidden layer index → adapter
Config LoRAConfig
// Model architecture metadata for validation during merge.
InputDim int
HiddenDims []int
}
Adapter holds LoRA adapter weights produced by FineTuneLoRA.
func FineTuneLoRA ¶
func FineTuneLoRA(
	base *BaseModel,
	data [][]float64,
	labels []int,
	config LoRAConfig,
	engine compute.Engine[float32],
	ops numeric.Arithmetic[float32],
) (*Adapter, error)
FineTuneLoRA applies Low-Rank Adaptation to a pre-trained BaseModel. Only the LoRA A and B matrices are trained; base model weights are frozen. This enables fast adaptation on small per-source datasets.
type BaseModel ¶
type BaseModel struct {
Model *Model
}
BaseModel wraps a pre-trained Model whose weights serve as an initialisation point for fine-tuning on a specific data source.
func PreTrain ¶
func PreTrain(
	allData [][][]float64,
	allLabels [][]int,
	config PreTrainConfig,
	engine compute.Engine[float32],
	ops numeric.Arithmetic[float32],
) (*BaseModel, error)
PreTrain trains a tabular model on data from multiple sources so the model learns universal feature patterns. allData[source][sample][feature] contains the feature vectors; allLabels[source][sample] contains the corresponding labels. All sources must share the same feature dimensionality and label space.
func (*BaseModel) FineTune ¶
func (bm *BaseModel) FineTune(data [][]float64, labels []int, config TrainConfig, engine compute.Engine[float32], ops numeric.Arithmetic[float32]) (*Model, error)
FineTune creates a new Model initialised from the BaseModel's pre-trained weights and trains it on the given data. The fine-tuning data must have the same feature dimensionality as the pre-training data.
type Ensemble ¶
type Ensemble struct {
// contains filtered or unexported fields
}
Ensemble combines multiple tabular Models and an optional tree ensemble via stacking. A learned meta-learner MLP fuses sub-model softmax outputs and tree predictions into a final Direction prediction.
func NewEnsemble ¶
func NewEnsemble(models []*Model, treePredictions func([]float64) []float64, engine compute.Engine[float32], ops numeric.Arithmetic[float32]) (*Ensemble, error)
NewEnsemble creates an Ensemble from trained sub-models and an optional tree prediction callback. treePredictions may be nil if no tree ensemble is used. The callback receives raw features and returns tree ensemble outputs (e.g., class probabilities), decoupling the ensemble from any specific tree library.
func (*Ensemble) Predict ¶
Predict runs all sub-models and the tree prediction callback on the given features, concatenates their outputs, and feeds the result through the trained meta-learner to produce a final Direction and confidence.
func (*Ensemble) TrainMetaLearner ¶
func (e *Ensemble) TrainMetaLearner(subModelOutputs [][]float64, labels []int, tc TrainConfig, mlc MetaLearnerConfig) error
TrainMetaLearner trains the stacking meta-learner on pre-computed sub-model outputs. subModelOutputs[i] is the concatenated softmax outputs from all sub-models and tree predictions for sample i. labels[i] is in [0, 3).
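Assembling one row of subModelOutputs can be sketched as: softmax each sub-model's logits, then append the tree-ensemble probabilities. This is an illustrative sketch with hypothetical helper names, not the package's code:

```go
package main

import (
	"fmt"
	"math"
)

// softmax converts logits to probabilities (numerically stabilised).
func softmax(logits []float64) []float64 {
	max := logits[0]
	for _, v := range logits {
		if v > max {
			max = v
		}
	}
	var sum float64
	out := make([]float64, len(logits))
	for i, v := range logits {
		out[i] = math.Exp(v - max)
		sum += out[i]
	}
	for i := range out {
		out[i] /= sum
	}
	return out
}

// stackingRow concatenates sub-model softmax outputs and tree
// predictions into one meta-learner input row.
func stackingRow(subModelLogits [][]float64, treeProbs []float64) []float64 {
	var row []float64
	for _, logits := range subModelLogits {
		row = append(row, softmax(logits)...)
	}
	return append(row, treeProbs...)
}

func main() {
	row := stackingRow(
		[][]float64{{2, 1, 0}, {0, 0, 0}}, // two sub-models, 3 classes each
		[]float64{0.6, 0.3, 0.1},          // tree ensemble probabilities
	)
	fmt.Println(len(row)) // 9
}
```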
type FTTransformer ¶
type FTTransformer struct {
// contains filtered or unexported fields
}
FTTransformer implements the Feature Tokenizer + Transformer architecture for tabular data. Each numeric feature is tokenized via a learned embedding, a CLS token is prepended, and the sequence is processed by a stack of transformer encoder layers. The CLS token output feeds a linear head for 3-class classification (Long/Short/Flat).
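The tokenization step can be sketched as: each scalar feature x_i becomes a d-dimensional token w_i*x_i + b_i, and a learned CLS token is prepended. A shape-only illustration under that common convention, not the package's code:

```go
package main

import "fmt"

// tokenize maps each scalar feature to a dToken-dim token via a learned
// per-feature weight vector and bias, then prepends the CLS token.
// w and b have shape [numFeatures][dToken]; cls has shape [dToken].
func tokenize(x []float64, w, b [][]float64, cls []float64) [][]float64 {
	tokens := [][]float64{cls} // CLS token first
	for i, xi := range x {
		tok := make([]float64, len(cls))
		for j := range tok {
			tok[j] = w[i][j]*xi + b[i][j]
		}
		tokens = append(tokens, tok)
	}
	return tokens
}

func main() {
	const numFeatures, dToken = 4, 8
	w := make([][]float64, numFeatures)
	b := make([][]float64, numFeatures)
	for i := range w {
		w[i] = make([]float64, dToken)
		b[i] = make([]float64, dToken)
	}
	cls := make([]float64, dToken)
	seq := tokenize([]float64{0.1, 0.2, 0.3, 0.4}, w, b, cls)
	fmt.Println(len(seq), len(seq[0])) // numFeatures+1 tokens of dToken dims
}
```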
func NewFTTransformer ¶
func NewFTTransformer(config FTTransformerConfig, engine compute.Engine[float32], ops numeric.Arithmetic[float32]) (*FTTransformer, error)
NewFTTransformer creates a new FTTransformer with the given configuration.
func (*FTTransformer) Forward ¶
func (ft *FTTransformer) Forward(ctx context.Context, input *tensor.TensorNumeric[float32]) (*tensor.TensorNumeric[float32], error)
Forward runs the FTTransformer forward pass on a batch of inputs. Input shape: [batch, NumFeatures]. Output shape: [batch, 3] (logits).
type FTTransformerConfig ¶
type FTTransformerConfig struct {
NumFeatures int // number of numeric input features
DToken int // embedding dimension per feature token
NHeads int // number of attention heads
NLayers int // number of transformer encoder layers
DFFN int // hidden dimension of the feed-forward network
DropoutRate float64 // dropout rate (unused at inference, reserved for training)
}
FTTransformerConfig holds the configuration for an FTTransformer.
type LoRAConfig ¶
type LoRAConfig struct {
Rank int // Low-rank dimension (typically 2-16 for tabular).
Alpha float32 // LoRA scaling factor.
TargetLayers []int // Hidden layer indices to adapt; nil means all hidden layers.
Epochs int
BatchSize int
LearningRate float64
WeightDecay float64
}
LoRAConfig holds configuration for LoRA fine-tuning of a tabular model.
type MetaLearnerConfig ¶
type MetaLearnerConfig struct {
HiddenDims []int
}
MetaLearnerConfig holds configuration for the ensemble's meta-learner MLP.
type Model ¶
type Model struct {
// contains filtered or unexported fields
}
Model is a configurable MLP for tabular prediction built on ztensor.
func Load ¶
func Load(path string, engine compute.Engine[float32], ops numeric.Arithmetic[float32]) (*Model, error)
Load reads a Model from the given path in the ZTAB binary format.
func MergeAdapter ¶
func MergeAdapter(base *BaseModel, adapter *Adapter, engine compute.Engine[float32]) (*Model, error)
MergeAdapter merges LoRA adapter weights into a BaseModel to produce a regular Model with no LoRA overhead during inference. The merged model produces identical predictions to running the base model with the adapter.
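Merging commonly amounts to folding the scaled low-rank product into the frozen weight, W' = W + (alpha/rank)·B·A, so inference needs no extra matmuls. A small sketch under that convention (the package's exact weight layout may differ):

```go
package main

import "fmt"

// mergeLoRA returns W' = W + scale*B·A, where W is [out][in],
// B is [out][rank], A is [rank][in], and scale = alpha/rank.
func mergeLoRA(w, b, a [][]float64, scale float64) [][]float64 {
	out := make([][]float64, len(w))
	for i := range w {
		out[i] = append([]float64(nil), w[i]...) // copy frozen base weight
		for j := range w[i] {
			for r := range a {
				out[i][j] += scale * b[i][r] * a[r][j]
			}
		}
	}
	return out
}

func main() {
	w := [][]float64{{1, 0}, {0, 1}}  // frozen base weight (identity)
	b := [][]float64{{1}, {2}}        // B: [2][rank=1]
	a := [][]float64{{3, 4}}          // A: [rank=1][2]
	merged := mergeLoRA(w, b, a, 0.5) // scale = alpha/rank = 0.5
	fmt.Println(merged)               // [[2.5 2] [3 5]]
}
```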
func NewModel ¶
func NewModel(config ModelConfig, engine compute.Engine[float32], ops numeric.Arithmetic[float32]) (*Model, error)
NewModel creates a new tabular Model with the given configuration.
func Train ¶
func Train(data [][]float64, labels []int, config TrainConfig, mc ModelConfig, engine compute.Engine[float32], ops numeric.Arithmetic[float32]) (*Model, error)
Train trains a tabular Model on the given data and labels using AdamW and cross-entropy loss. labels[i] must be in [0, 3) corresponding to Long, Short, Flat. It returns a trained Model ready for Predict.
type ModelConfig ¶
type ModelConfig struct {
InputDim int
HiddenDims []int
DropoutRate float64
Activation Activation
}
ModelConfig holds the configuration for a tabular Model.
type NormMode ¶
type NormMode int
NormMode selects the normalization applied after each residual block.
type PreTrainConfig ¶
type PreTrainConfig struct {
Epochs int
BatchSize int
LearningRate float64
WeightDecay float64
HiddenDims []int
DropoutRate float64
Activation Activation
}
PreTrainConfig holds hyperparameters for pre-training a base model on multi-source data.
type SAINT ¶
type SAINT struct {
// contains filtered or unexported fields
}
SAINT implements Self-Attention and Intersample Attention for tabular data.
func NewSAINT ¶
func NewSAINT(config SAINTConfig, engine compute.Engine[float32], ops numeric.Arithmetic[float32]) (*SAINT, error)
NewSAINT creates a new SAINT model with the given configuration.
func TrainSAINT ¶
func TrainSAINT(data [][]float64, labels []int, tc TrainConfig, sc SAINTConfig, engine compute.Engine[float32], ops numeric.Arithmetic[float32]) (*SAINT, error)
TrainSAINT trains a SAINT model on the given data and labels using SGD with manual gradient computation. labels[i] must be in [0, 3) corresponding to Long, Short, Flat. Returns a trained SAINT model.
type SAINTConfig ¶
type SAINTConfig struct {
NumFeatures int
DModel int
NHeads int
NLayers int
InterSampleAttention bool
}
SAINTConfig holds the configuration for a SAINT model.
type TabNet ¶
type TabNet struct {
// contains filtered or unexported fields
}
TabNet implements the TabNet architecture with sequential attention and sparsemax.
func NewTabNet ¶
func NewTabNet(config TabNetConfig, engine compute.Engine[float32], ops numeric.Arithmetic[float32]) (*TabNet, error)
NewTabNet creates a new TabNet model with the given configuration.
func (*TabNet) AttentionMasks ¶
func (t *TabNet) AttentionMasks() []*tensor.TensorNumeric[float32]
AttentionMasks returns the attention masks from the last forward pass. Each mask has shape [batch, inputDim] and represents the feature importance at each step. Returns nil if no forward pass has been run.
func (*TabNet) FeatureImportance ¶
func (t *TabNet) FeatureImportance(ctx context.Context) (*tensor.TensorNumeric[float32], error)
FeatureImportance returns the aggregate feature importance from the last forward pass. The result is the elementwise sum of the attention masks across all steps, with shape [batch, inputDim]. It returns an error if no forward pass has been run.
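The aggregation described above reduces to summing the per-step masks elementwise. A minimal sketch on plain slices (the real method operates on tensors):

```go
package main

import "fmt"

// sumMasks sums per-step attention masks elementwise. Each mask is
// [batch][inputDim]; the aggregate has the same shape.
func sumMasks(masks [][][]float64) [][]float64 {
	agg := make([][]float64, len(masks[0]))
	for i := range agg {
		agg[i] = make([]float64, len(masks[0][i]))
	}
	for _, mask := range masks {
		for i, row := range mask {
			for j, v := range row {
				agg[i][j] += v
			}
		}
	}
	return agg
}

func main() {
	// Two steps, batch of 1, three features.
	masks := [][][]float64{
		{{0.5, 0.25, 0.0}},
		{{0.0, 0.25, 0.5}},
	}
	fmt.Println(sumMasks(masks)) // [[0.5 0.5 0.5]]
}
```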
type TabNetConfig ¶
type TabNetConfig struct {
InputDim int
OutputDim int
NSteps int
RelaxationFactor float64
SparsityCoefficient float64
FeatureTransformerDim int // hidden dim for feature transformer blocks
}
TabNetConfig holds the configuration for a TabNet model.
type TabResNet ¶
type TabResNet struct {
// contains filtered or unexported fields
}
TabResNet is an MLP with skip connections between hidden layers. It is a simple but surprisingly strong baseline for tabular data.
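A block of this kind can be sketched as y = activation(W·x + b) + shortcut(x), where the shortcut is the identity when dimensions match and a linear projection otherwise. An illustrative sketch of the identity-shortcut case, not the package's implementation:

```go
package main

import "fmt"

// resBlock applies y = relu(W·x + b) + x, the identity-shortcut case.
// w is [dim][dim]; x and b are [dim].
func resBlock(x []float64, w [][]float64, b []float64) []float64 {
	out := make([]float64, len(x))
	for i := range out {
		h := b[i]
		for j, xj := range x {
			h += w[i][j] * xj
		}
		if h < 0 { // ReLU
			h = 0
		}
		out[i] = h + x[i] // skip connection
	}
	return out
}

func main() {
	x := []float64{1, -2}
	w := [][]float64{{0, 0}, {0, 0}} // zero weights: block reduces to relu(b)+x
	b := []float64{3, -1}
	fmt.Println(resBlock(x, w, b)) // [4 -2]
}
```

With zero weights the linear path contributes only relu(b), making the skip connection's effect easy to see: the input passes through unchanged and is added back.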
func NewTabResNet ¶
func NewTabResNet(config TabResNetConfig, engine compute.Engine[float32], ops numeric.Arithmetic[float32]) (*TabResNet, error)
NewTabResNet creates a new TabResNet model with the given configuration.
func (*TabResNet) Predict ¶
Predict runs inference on the given features and returns a Direction and confidence score.
func (*TabResNet) TabResNetParams ¶
func (m *TabResNet) TabResNetParams() (inputLayer mlpLayer, blocks []resBlock, head mlpLayer)
TabResNetParams returns the model parameters as mlpLayer values for training integration: the input layer, the block layers (linear plus optional shortcut), and the head.
type TabResNetConfig ¶
type TabResNetConfig struct {
InputDim int
OutputDim int // number of output classes (default 3: Long/Short/Flat)
HiddenDims []int
DropoutRate float64
Activation Activation
Norm NormMode
}
TabResNetConfig holds the configuration for a TabResNet model.